fal 1.2.1__py3-none-any.whl → 1.7.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of fal might be problematic; see the registry's advisory page for more details.

Files changed (45)
  1. fal/__main__.py +3 -1
  2. fal/_fal_version.py +2 -2
  3. fal/api.py +88 -20
  4. fal/app.py +221 -27
  5. fal/apps.py +147 -3
  6. fal/auth/__init__.py +50 -2
  7. fal/cli/_utils.py +40 -0
  8. fal/cli/apps.py +5 -3
  9. fal/cli/create.py +26 -0
  10. fal/cli/deploy.py +97 -16
  11. fal/cli/main.py +2 -2
  12. fal/cli/parser.py +11 -7
  13. fal/cli/run.py +12 -1
  14. fal/cli/runners.py +44 -0
  15. fal/config.py +23 -0
  16. fal/container.py +1 -1
  17. fal/exceptions/__init__.py +7 -1
  18. fal/exceptions/_base.py +51 -0
  19. fal/exceptions/_cuda.py +44 -0
  20. fal/files.py +81 -0
  21. fal/sdk.py +67 -6
  22. fal/toolkit/file/file.py +103 -13
  23. fal/toolkit/file/providers/fal.py +572 -24
  24. fal/toolkit/file/providers/gcp.py +8 -1
  25. fal/toolkit/file/providers/r2.py +8 -1
  26. fal/toolkit/file/providers/s3.py +80 -0
  27. fal/toolkit/file/types.py +28 -3
  28. fal/toolkit/image/__init__.py +71 -0
  29. fal/toolkit/image/image.py +25 -2
  30. fal/toolkit/image/nsfw_filter/__init__.py +11 -0
  31. fal/toolkit/image/nsfw_filter/env.py +9 -0
  32. fal/toolkit/image/nsfw_filter/inference.py +77 -0
  33. fal/toolkit/image/nsfw_filter/model.py +18 -0
  34. fal/toolkit/image/nsfw_filter/requirements.txt +4 -0
  35. fal/toolkit/image/safety_checker.py +107 -0
  36. fal/toolkit/types.py +140 -0
  37. fal/toolkit/utils/download_utils.py +4 -0
  38. fal/toolkit/utils/retry.py +45 -0
  39. fal/utils.py +20 -4
  40. fal/workflows.py +10 -4
  41. {fal-1.2.1.dist-info → fal-1.7.2.dist-info}/METADATA +47 -40
  42. {fal-1.2.1.dist-info → fal-1.7.2.dist-info}/RECORD +45 -30
  43. {fal-1.2.1.dist-info → fal-1.7.2.dist-info}/WHEEL +1 -1
  44. {fal-1.2.1.dist-info → fal-1.7.2.dist-info}/entry_points.txt +0 -0
  45. {fal-1.2.1.dist-info → fal-1.7.2.dist-info}/top_level.txt +0 -0
fal/__main__.py CHANGED
@@ -1,4 +1,6 @@
1
+ import sys
2
+
1
3
  from .cli import main
2
4
 
3
5
  if __name__ == "__main__":
4
- main()
6
+ sys.exit(main())
fal/_fal_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '1.2.1'
16
- __version_tuple__ = version_tuple = (1, 2, 1)
15
+ __version__ = version = '1.7.2'
16
+ __version_tuple__ = version_tuple = (1, 7, 2)
fal/api.py CHANGED
@@ -44,7 +44,13 @@ from typing_extensions import Concatenate, ParamSpec
44
44
  import fal.flags as flags
45
45
  from fal._serialization import include_modules_from, patch_pickle
46
46
  from fal.container import ContainerImage
47
- from fal.exceptions import FalServerlessException
47
+ from fal.exceptions import (
48
+ AppException,
49
+ CUDAOutOfMemoryException,
50
+ FalServerlessException,
51
+ FieldException,
52
+ )
53
+ from fal.exceptions._cuda import _is_cuda_oom_exception
48
54
  from fal.logging.isolate import IsolateLogPrinter
49
55
  from fal.sdk import (
50
56
  FAL_SERVERLESS_DEFAULT_KEEP_ALIVE,
@@ -70,6 +76,8 @@ SERVE_REQUIREMENTS = [
70
76
  f"pydantic=={pydantic_version}",
71
77
  "uvicorn",
72
78
  "starlette_exporter",
79
+ "structlog",
80
+ "tomli",
73
81
  ]
74
82
 
75
83
 
@@ -164,6 +172,7 @@ class Host(Generic[ArgsT, ReturnT]):
164
172
  application_name: str | None = None,
165
173
  application_auth_mode: Literal["public", "shared", "private"] | None = None,
166
174
  metadata: dict[str, Any] | None = None,
175
+ scale: bool = True,
167
176
  ) -> str | None:
168
177
  """Register the given function on the host for API call execution."""
169
178
  raise NotImplementedError
@@ -383,12 +392,15 @@ class FalServerlessHost(Host):
383
392
  _SUPPORTED_KEYS = frozenset(
384
393
  {
385
394
  "machine_type",
395
+ "machine_types",
396
+ "num_gpus",
386
397
  "keep_alive",
387
398
  "max_concurrency",
388
399
  "min_concurrency",
389
400
  "max_multiplexing",
390
401
  "setup_function",
391
402
  "metadata",
403
+ "request_timeout",
392
404
  "_base_image",
393
405
  "_scheduler",
394
406
  "_scheduler_options",
@@ -419,25 +431,28 @@ class FalServerlessHost(Host):
419
431
  application_name: str | None = None,
420
432
  application_auth_mode: Literal["public", "shared", "private"] | None = None,
421
433
  metadata: dict[str, Any] | None = None,
434
+ deployment_strategy: Literal["recreate", "rolling"] = "recreate",
435
+ scale: bool = True,
422
436
  ) -> str | None:
423
437
  environment_options = options.environment.copy()
424
438
  environment_options.setdefault("python_version", active_python())
425
439
  environments = [self._connection.define_environment(**environment_options)]
426
440
 
427
- machine_type = options.host.get(
441
+ machine_type: list[str] | str = options.host.get(
428
442
  "machine_type", FAL_SERVERLESS_DEFAULT_MACHINE_TYPE
429
443
  )
430
444
  keep_alive = options.host.get("keep_alive", FAL_SERVERLESS_DEFAULT_KEEP_ALIVE)
431
- max_concurrency = options.host.get("max_concurrency")
432
- min_concurrency = options.host.get("min_concurrency")
433
- max_multiplexing = options.host.get("max_multiplexing")
434
445
  base_image = options.host.get("_base_image", None)
435
446
  scheduler = options.host.get("_scheduler", None)
436
447
  scheduler_options = options.host.get("_scheduler_options", None)
448
+ max_concurrency = options.host.get("max_concurrency")
449
+ min_concurrency = options.host.get("min_concurrency")
450
+ max_multiplexing = options.host.get("max_multiplexing")
437
451
  exposed_port = options.get_exposed_port()
438
-
452
+ request_timeout = options.host.get("request_timeout")
439
453
  machine_requirements = MachineRequirements(
440
- machine_type=machine_type,
454
+ machine_types=machine_type, # type: ignore
455
+ num_gpus=options.host.get("num_gpus"),
441
456
  keep_alive=keep_alive,
442
457
  base_image=base_image,
443
458
  exposed_port=exposed_port,
@@ -446,6 +461,7 @@ class FalServerlessHost(Host):
446
461
  max_multiplexing=max_multiplexing,
447
462
  max_concurrency=max_concurrency,
448
463
  min_concurrency=min_concurrency,
464
+ request_timeout=request_timeout,
449
465
  )
450
466
 
451
467
  partial_func = _prepare_partial_func(func)
@@ -471,6 +487,8 @@ class FalServerlessHost(Host):
471
487
  application_auth_mode=application_auth_mode,
472
488
  machine_requirements=machine_requirements,
473
489
  metadata=metadata,
490
+ deployment_strategy=deployment_strategy,
491
+ scale=scale,
474
492
  ):
475
493
  for log in partial_result.logs:
476
494
  self._log_printer.print(log)
@@ -493,7 +511,7 @@ class FalServerlessHost(Host):
493
511
  environment_options.setdefault("python_version", active_python())
494
512
  environments = [self._connection.define_environment(**environment_options)]
495
513
 
496
- machine_type = options.host.get(
514
+ machine_type: list[str] | str = options.host.get(
497
515
  "machine_type", FAL_SERVERLESS_DEFAULT_MACHINE_TYPE
498
516
  )
499
517
  keep_alive = options.host.get("keep_alive", FAL_SERVERLESS_DEFAULT_KEEP_ALIVE)
@@ -505,9 +523,11 @@ class FalServerlessHost(Host):
505
523
  scheduler_options = options.host.get("_scheduler_options", None)
506
524
  exposed_port = options.get_exposed_port()
507
525
  setup_function = options.host.get("setup_function", None)
526
+ request_timeout = options.host.get("request_timeout")
508
527
 
509
528
  machine_requirements = MachineRequirements(
510
- machine_type=machine_type,
529
+ machine_types=machine_type, # type: ignore
530
+ num_gpus=options.host.get("num_gpus"),
511
531
  keep_alive=keep_alive,
512
532
  base_image=base_image,
513
533
  exposed_port=exposed_port,
@@ -516,6 +536,7 @@ class FalServerlessHost(Host):
516
536
  max_multiplexing=max_multiplexing,
517
537
  max_concurrency=max_concurrency,
518
538
  min_concurrency=min_concurrency,
539
+ request_timeout=request_timeout,
519
540
  )
520
541
 
521
542
  return_value = _UNSET
@@ -676,10 +697,12 @@ def function(
676
697
  max_concurrency: int | None = None,
677
698
  # FalServerlessHost options
678
699
  metadata: dict[str, Any] | None = None,
679
- machine_type: str = FAL_SERVERLESS_DEFAULT_MACHINE_TYPE,
700
+ machine_type: str | list[str] = FAL_SERVERLESS_DEFAULT_MACHINE_TYPE,
701
+ num_gpus: int | None = None,
680
702
  keep_alive: int = FAL_SERVERLESS_DEFAULT_KEEP_ALIVE,
681
703
  max_multiplexing: int = FAL_SERVERLESS_DEFAULT_MAX_MULTIPLEXING,
682
704
  min_concurrency: int = FAL_SERVERLESS_DEFAULT_MIN_CONCURRENCY,
705
+ request_timeout: int | None = None,
683
706
  setup_function: Callable[..., None] | None = None,
684
707
  _base_image: str | None = None,
685
708
  _scheduler: str | None = None,
@@ -701,10 +724,12 @@ def function(
701
724
  max_concurrency: int | None = None,
702
725
  # FalServerlessHost options
703
726
  metadata: dict[str, Any] | None = None,
704
- machine_type: str = FAL_SERVERLESS_DEFAULT_MACHINE_TYPE,
727
+ machine_type: str | list[str] = FAL_SERVERLESS_DEFAULT_MACHINE_TYPE,
728
+ num_gpus: int | None = None,
705
729
  keep_alive: int = FAL_SERVERLESS_DEFAULT_KEEP_ALIVE,
706
730
  max_multiplexing: int = FAL_SERVERLESS_DEFAULT_MAX_MULTIPLEXING,
707
731
  min_concurrency: int = FAL_SERVERLESS_DEFAULT_MIN_CONCURRENCY,
732
+ request_timeout: int | None = None,
708
733
  setup_function: Callable[..., None] | None = None,
709
734
  _base_image: str | None = None,
710
735
  _scheduler: str | None = None,
@@ -776,10 +801,12 @@ def function(
776
801
  max_concurrency: int | None = None,
777
802
  # FalServerlessHost options
778
803
  metadata: dict[str, Any] | None = None,
779
- machine_type: str = FAL_SERVERLESS_DEFAULT_MACHINE_TYPE,
804
+ machine_type: str | list[str] = FAL_SERVERLESS_DEFAULT_MACHINE_TYPE,
805
+ num_gpus: int | None = None,
780
806
  keep_alive: int = FAL_SERVERLESS_DEFAULT_KEEP_ALIVE,
781
807
  max_multiplexing: int = FAL_SERVERLESS_DEFAULT_MAX_MULTIPLEXING,
782
808
  min_concurrency: int = FAL_SERVERLESS_DEFAULT_MIN_CONCURRENCY,
809
+ request_timeout: int | None = None,
783
810
  setup_function: Callable[..., None] | None = None,
784
811
  _base_image: str | None = None,
785
812
  _scheduler: str | None = None,
@@ -806,10 +833,12 @@ def function(
806
833
  max_concurrency: int | None = None,
807
834
  # FalServerlessHost options
808
835
  metadata: dict[str, Any] | None = None,
809
- machine_type: str = FAL_SERVERLESS_DEFAULT_MACHINE_TYPE,
836
+ machine_type: str | list[str] = FAL_SERVERLESS_DEFAULT_MACHINE_TYPE,
837
+ num_gpus: int | None = None,
810
838
  keep_alive: int = FAL_SERVERLESS_DEFAULT_KEEP_ALIVE,
811
839
  max_multiplexing: int = FAL_SERVERLESS_DEFAULT_MAX_MULTIPLEXING,
812
840
  min_concurrency: int = FAL_SERVERLESS_DEFAULT_MIN_CONCURRENCY,
841
+ request_timeout: int | None = None,
813
842
  setup_function: Callable[..., None] | None = None,
814
843
  _base_image: str | None = None,
815
844
  _scheduler: str | None = None,
@@ -830,10 +859,12 @@ def function(
830
859
  max_concurrency: int | None = None,
831
860
  # FalServerlessHost options
832
861
  metadata: dict[str, Any] | None = None,
833
- machine_type: str = FAL_SERVERLESS_DEFAULT_MACHINE_TYPE,
862
+ machine_type: str | list[str] = FAL_SERVERLESS_DEFAULT_MACHINE_TYPE,
863
+ num_gpus: int | None = None,
834
864
  keep_alive: int = FAL_SERVERLESS_DEFAULT_KEEP_ALIVE,
835
865
  max_multiplexing: int = FAL_SERVERLESS_DEFAULT_MAX_MULTIPLEXING,
836
866
  min_concurrency: int = FAL_SERVERLESS_DEFAULT_MIN_CONCURRENCY,
867
+ request_timeout: int | None = None,
837
868
  setup_function: Callable[..., None] | None = None,
838
869
  _base_image: str | None = None,
839
870
  _scheduler: str | None = None,
@@ -854,10 +885,12 @@ def function(
854
885
  max_concurrency: int | None = None,
855
886
  # FalServerlessHost options
856
887
  metadata: dict[str, Any] | None = None,
857
- machine_type: str = FAL_SERVERLESS_DEFAULT_MACHINE_TYPE,
888
+ machine_type: str | list[str] = FAL_SERVERLESS_DEFAULT_MACHINE_TYPE,
889
+ num_gpus: int | None = None,
858
890
  keep_alive: int = FAL_SERVERLESS_DEFAULT_KEEP_ALIVE,
859
891
  max_multiplexing: int = FAL_SERVERLESS_DEFAULT_MAX_MULTIPLEXING,
860
892
  min_concurrency: int = FAL_SERVERLESS_DEFAULT_MIN_CONCURRENCY,
893
+ request_timeout: int | None = None,
861
894
  setup_function: Callable[..., None] | None = None,
862
895
  _base_image: str | None = None,
863
896
  _scheduler: str | None = None,
@@ -942,6 +975,8 @@ class RouteSignature(NamedTuple):
942
975
 
943
976
 
944
977
  class BaseServable:
978
+ version: ClassVar[str] = "unknown"
979
+
945
980
  def collect_routes(self) -> dict[RouteSignature, Callable[..., Any]]:
946
981
  raise NotImplementedError
947
982
 
@@ -1002,13 +1037,39 @@ class BaseServable:
1002
1037
  # If it's not a generic 404, just return the original message.
1003
1038
  return JSONResponse({"detail": exc.detail}, 404)
1004
1039
 
1040
+ @_app.exception_handler(AppException)
1041
+ async def app_exception_handler(request: Request, exc: AppException):
1042
+ return JSONResponse({"detail": exc.message}, exc.status_code)
1043
+
1044
+ @_app.exception_handler(FieldException)
1045
+ async def field_exception_handler(request: Request, exc: FieldException):
1046
+ return JSONResponse(exc.to_pydantic_format(), exc.status_code)
1047
+
1048
+ @_app.exception_handler(CUDAOutOfMemoryException)
1049
+ async def cuda_out_of_memory_exception_handler(
1050
+ request: Request, exc: CUDAOutOfMemoryException
1051
+ ):
1052
+ return JSONResponse({"detail": exc.message}, exc.status_code)
1053
+
1005
1054
  @_app.exception_handler(Exception)
1006
1055
  async def traceback_logging_exception_handler(request: Request, exc: Exception):
1007
- print(
1008
- json.dumps(
1009
- {"traceback": "".join(traceback.format_exception(exc)[::-1])} # type: ignore
1056
+ _, MINOR, *_ = sys.version_info
1057
+
1058
+ # traceback.format_exception() has a different signature in Python >=3.10
1059
+ if MINOR >= 10:
1060
+ formatted_exception = traceback.format_exception(exc) # type: ignore
1061
+ else:
1062
+ formatted_exception = traceback.format_exception(
1063
+ type(exc), exc, exc.__traceback__
1010
1064
  )
1011
- )
1065
+
1066
+ print(json.dumps({"traceback": "".join(formatted_exception[::-1])}))
1067
+
1068
+ if _is_cuda_oom_exception(exc):
1069
+ return await cuda_out_of_memory_exception_handler(
1070
+ request, CUDAOutOfMemoryException()
1071
+ )
1072
+
1012
1073
  return JSONResponse({"detail": "Internal Server Error"}, 500)
1013
1074
 
1014
1075
  routes = self.collect_routes()
@@ -1044,11 +1105,18 @@ class BaseServable:
1044
1105
  def serve(self) -> None:
1045
1106
  import asyncio
1046
1107
 
1108
+ from prometheus_client import Gauge
1047
1109
  from starlette_exporter import handle_metrics
1048
1110
  from uvicorn import Config
1049
1111
 
1112
+ # NOTE: this uses the global prometheus registry
1113
+ app_info = Gauge("fal_app_info", "Fal application information", ["version"])
1114
+ app_info.labels(version=self.version).set(1)
1115
+
1050
1116
  app = self._build_app()
1051
- server = Server(config=Config(app, host="0.0.0.0", port=8080))
1117
+ server = Server(
1118
+ config=Config(app, host="0.0.0.0", port=8080, timeout_keep_alive=300)
1119
+ )
1052
1120
  metrics_app = FastAPI()
1053
1121
  metrics_app.add_route("/metrics", handle_metrics)
1054
1122
  metrics_server = Server(config=Config(metrics_app, host="0.0.0.0", port=9090))