fal 1.2.1__py3-none-any.whl → 1.7.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of fal might be problematic. (More details were linked from the original page.)

Files changed (45)
  1. fal/__main__.py +3 -1
  2. fal/_fal_version.py +2 -2
  3. fal/api.py +88 -20
  4. fal/app.py +221 -27
  5. fal/apps.py +147 -3
  6. fal/auth/__init__.py +50 -2
  7. fal/cli/_utils.py +40 -0
  8. fal/cli/apps.py +5 -3
  9. fal/cli/create.py +26 -0
  10. fal/cli/deploy.py +97 -16
  11. fal/cli/main.py +2 -2
  12. fal/cli/parser.py +11 -7
  13. fal/cli/run.py +12 -1
  14. fal/cli/runners.py +44 -0
  15. fal/config.py +23 -0
  16. fal/container.py +1 -1
  17. fal/exceptions/__init__.py +7 -1
  18. fal/exceptions/_base.py +51 -0
  19. fal/exceptions/_cuda.py +44 -0
  20. fal/files.py +81 -0
  21. fal/sdk.py +67 -6
  22. fal/toolkit/file/file.py +103 -13
  23. fal/toolkit/file/providers/fal.py +572 -24
  24. fal/toolkit/file/providers/gcp.py +8 -1
  25. fal/toolkit/file/providers/r2.py +8 -1
  26. fal/toolkit/file/providers/s3.py +80 -0
  27. fal/toolkit/file/types.py +28 -3
  28. fal/toolkit/image/__init__.py +71 -0
  29. fal/toolkit/image/image.py +25 -2
  30. fal/toolkit/image/nsfw_filter/__init__.py +11 -0
  31. fal/toolkit/image/nsfw_filter/env.py +9 -0
  32. fal/toolkit/image/nsfw_filter/inference.py +77 -0
  33. fal/toolkit/image/nsfw_filter/model.py +18 -0
  34. fal/toolkit/image/nsfw_filter/requirements.txt +4 -0
  35. fal/toolkit/image/safety_checker.py +107 -0
  36. fal/toolkit/types.py +140 -0
  37. fal/toolkit/utils/download_utils.py +4 -0
  38. fal/toolkit/utils/retry.py +45 -0
  39. fal/utils.py +20 -4
  40. fal/workflows.py +10 -4
  41. {fal-1.2.1.dist-info → fal-1.7.2.dist-info}/METADATA +47 -40
  42. {fal-1.2.1.dist-info → fal-1.7.2.dist-info}/RECORD +45 -30
  43. {fal-1.2.1.dist-info → fal-1.7.2.dist-info}/WHEEL +1 -1
  44. {fal-1.2.1.dist-info → fal-1.7.2.dist-info}/entry_points.txt +0 -0
  45. {fal-1.2.1.dist-info → fal-1.7.2.dist-info}/top_level.txt +0 -0
fal/sdk.py CHANGED
@@ -5,7 +5,7 @@ from contextlib import ExitStack
5
5
  from dataclasses import dataclass, field
6
6
  from datetime import datetime, timedelta
7
7
  from enum import Enum
8
- from typing import Any, Callable, Generic, Iterator, Literal, TypeVar
8
+ from typing import Any, Callable, Generic, Iterator, Literal, Optional, TypeVar
9
9
 
10
10
  import grpc
11
11
  import isolate_proto
@@ -214,7 +214,7 @@ class AliasInfo:
214
214
  class RunnerInfo:
215
215
  runner_id: str
216
216
  in_flight_requests: int
217
- expiration_countdown: int
217
+ expiration_countdown: Optional[int]
218
218
  uptime: timedelta
219
219
 
220
220
 
@@ -275,6 +275,33 @@ class KeyScope(enum.Enum):
275
275
  raise ValueError(f"Unknown KeyScope: {proto}")
276
276
 
277
277
 
278
+ class DeploymentStrategy(enum.Enum):
279
+ RECREATE = "recreate"
280
+ ROLLING = "rolling"
281
+
282
+ @staticmethod
283
+ def from_proto(
284
+ proto: isolate_proto.DeploymentStrategy.ValueType | None,
285
+ ) -> DeploymentStrategy:
286
+ if proto is None:
287
+ return DeploymentStrategy.RECREATE
288
+
289
+ if proto is isolate_proto.DeploymentStrategy.RECREATE:
290
+ return DeploymentStrategy.RECREATE
291
+ elif proto is isolate_proto.DeploymentStrategy.ROLLING:
292
+ return DeploymentStrategy.ROLLING
293
+ else:
294
+ raise ValueError(f"Unknown DeploymentStrategy: {proto}")
295
+
296
+ def to_proto(self) -> isolate_proto.DeploymentStrategy.ValueType:
297
+ if self is DeploymentStrategy.RECREATE:
298
+ return isolate_proto.DeploymentStrategy.RECREATE
299
+ elif self is DeploymentStrategy.ROLLING:
300
+ return isolate_proto.DeploymentStrategy.ROLLING
301
+ else:
302
+ raise ValueError(f"Unknown DeploymentStrategy: {self}")
303
+
304
+
278
305
  @from_grpc.register(isolate_proto.ApplicationInfo)
279
306
  def _from_grpc_application_info(
280
307
  message: isolate_proto.ApplicationInfo,
@@ -317,7 +344,9 @@ def _from_grpc_runner_info(message: isolate_proto.RunnerInfo) -> RunnerInfo:
317
344
  return RunnerInfo(
318
345
  runner_id=message.runner_id,
319
346
  in_flight_requests=message.in_flight_requests,
320
- expiration_countdown=message.expiration_countdown,
347
+ expiration_countdown=message.expiration_countdown
348
+ if message.HasField("expiration_countdown")
349
+ else None,
321
350
  uptime=timedelta(seconds=message.uptime),
322
351
  )
323
352
 
@@ -362,7 +391,8 @@ def _from_grpc_hosted_run_result(
362
391
 
363
392
  @dataclass
364
393
  class MachineRequirements:
365
- machine_type: str
394
+ machine_types: list[str]
395
+ num_gpus: int | None = field(default=None)
366
396
  keep_alive: int = FAL_SERVERLESS_DEFAULT_KEEP_ALIVE
367
397
  base_image: str | None = None
368
398
  exposed_port: int | None = None
@@ -371,6 +401,17 @@ class MachineRequirements:
371
401
  max_concurrency: int | None = None
372
402
  max_multiplexing: int | None = None
373
403
  min_concurrency: int | None = None
404
+ request_timeout: int | None = None
405
+
406
+ def __post_init__(self):
407
+ if isinstance(self.machine_types, str):
408
+ self.machine_types = [self.machine_types]
409
+
410
+ if not isinstance(self.machine_types, list):
411
+ raise ValueError("machine_types must be a list of strings.")
412
+
413
+ if not self.machine_types:
414
+ raise ValueError("No machine type provided.")
374
415
 
375
416
 
376
417
  @dataclass
@@ -457,11 +498,16 @@ class FalServerlessConnection:
457
498
  serialization_method: str = _DEFAULT_SERIALIZATION_METHOD,
458
499
  machine_requirements: MachineRequirements | None = None,
459
500
  metadata: dict[str, Any] | None = None,
501
+ deployment_strategy: Literal["recreate", "rolling"] = "recreate",
502
+ scale: bool = True,
460
503
  ) -> Iterator[isolate_proto.RegisterApplicationResult]:
461
504
  wrapped_function = to_serialized_object(function, serialization_method)
462
505
  if machine_requirements:
463
506
  wrapped_requirements = isolate_proto.MachineRequirements(
464
- machine_type=machine_requirements.machine_type,
507
+ # NOTE: backwards compatibility with old API
508
+ machine_type=machine_requirements.machine_types[0],
509
+ machine_types=machine_requirements.machine_types,
510
+ num_gpus=machine_requirements.num_gpus,
465
511
  keep_alive=machine_requirements.keep_alive,
466
512
  base_image=machine_requirements.base_image,
467
513
  exposed_port=machine_requirements.exposed_port,
@@ -472,6 +518,7 @@ class FalServerlessConnection:
472
518
  max_concurrency=machine_requirements.max_concurrency,
473
519
  min_concurrency=machine_requirements.min_concurrency,
474
520
  max_multiplexing=machine_requirements.max_multiplexing,
521
+ request_timeout=machine_requirements.request_timeout,
475
522
  )
476
523
  else:
477
524
  wrapped_requirements = None
@@ -488,6 +535,10 @@ class FalServerlessConnection:
488
535
  struct_metadata = isolate_proto.Struct()
489
536
  struct_metadata.update(metadata)
490
537
 
538
+ deployment_strategy_proto = DeploymentStrategy[
539
+ deployment_strategy.upper()
540
+ ].to_proto()
541
+
491
542
  request = isolate_proto.RegisterApplicationRequest(
492
543
  function=wrapped_function,
493
544
  environments=environments,
@@ -495,6 +546,8 @@ class FalServerlessConnection:
495
546
  application_name=application_name,
496
547
  auth_mode=auth_mode,
497
548
  metadata=struct_metadata,
549
+ deployment_strategy=deployment_strategy_proto,
550
+ scale=scale,
498
551
  )
499
552
  for partial_result in self.stub.RegisterApplication(request):
500
553
  yield from_grpc(partial_result)
@@ -546,7 +599,10 @@ class FalServerlessConnection:
546
599
  wrapped_function = to_serialized_object(function, serialization_method)
547
600
  if machine_requirements:
548
601
  wrapped_requirements = isolate_proto.MachineRequirements(
549
- machine_type=machine_requirements.machine_type,
602
+ # NOTE: backwards compatibility with old API
603
+ machine_type=machine_requirements.machine_types[0],
604
+ machine_types=machine_requirements.machine_types,
605
+ num_gpus=machine_requirements.num_gpus,
550
606
  keep_alive=machine_requirements.keep_alive,
551
607
  base_image=machine_requirements.base_image,
552
608
  exposed_port=machine_requirements.exposed_port,
@@ -557,6 +613,7 @@ class FalServerlessConnection:
557
613
  max_concurrency=machine_requirements.max_concurrency,
558
614
  max_multiplexing=machine_requirements.max_multiplexing,
559
615
  min_concurrency=machine_requirements.min_concurrency,
616
+ request_timeout=machine_requirements.request_timeout,
560
617
  )
561
618
  else:
562
619
  wrapped_requirements = None
@@ -629,3 +686,7 @@ class FalServerlessConnection:
629
686
  )
630
687
  for secret in response.secrets
631
688
  ]
689
+
690
+ def kill_runner(self, runner_id: str) -> None:
691
+ request = isolate_proto.KillRunnerRequest(runner_id=runner_id)
692
+ self.stub.KillRunner(request)
fal/toolkit/file/file.py CHANGED
@@ -8,6 +8,7 @@ from urllib.parse import urlparse
8
8
  from zipfile import ZipFile
9
9
 
10
10
  import pydantic
11
+ from fastapi import Request
11
12
 
12
13
  # https://github.com/pydantic/pydantic/pull/2573
13
14
  if not hasattr(pydantic, "__version__") or pydantic.__version__.startswith("1."):
@@ -21,9 +22,11 @@ else:
21
22
  from pydantic import BaseModel, Field
22
23
 
23
24
  from fal.toolkit.file.providers.fal import (
25
+ LIFECYCLE_PREFERENCE,
24
26
  FalCDNFileRepository,
25
27
  FalFileRepository,
26
28
  FalFileRepositoryV2,
29
+ FalFileRepositoryV3,
27
30
  InMemoryRepository,
28
31
  )
29
32
  from fal.toolkit.file.providers.gcp import GoogleStorageRepository
@@ -36,6 +39,7 @@ FileRepositoryFactory = Callable[[], FileRepository]
36
39
  BUILT_IN_REPOSITORIES: dict[RepositoryId, FileRepositoryFactory] = {
37
40
  "fal": lambda: FalFileRepository(),
38
41
  "fal_v2": lambda: FalFileRepositoryV2(),
42
+ "fal_v3": lambda: FalFileRepositoryV3(),
39
43
  "in_memory": lambda: InMemoryRepository(),
40
44
  "gcp_storage": lambda: GoogleStorageRepository(),
41
45
  "r2": lambda: R2Repository(),
@@ -43,7 +47,10 @@ BUILT_IN_REPOSITORIES: dict[RepositoryId, FileRepositoryFactory] = {
43
47
  }
44
48
 
45
49
 
46
- def get_builtin_repository(id: RepositoryId) -> FileRepository:
50
+ def get_builtin_repository(id: RepositoryId | FileRepository) -> FileRepository:
51
+ if isinstance(id, FileRepository):
52
+ return id
53
+
47
54
  if id not in BUILT_IN_REPOSITORIES.keys():
48
55
  raise ValueError(f'"{id}" is not a valid built-in file repository')
49
56
  return BUILT_IN_REPOSITORIES[id]()
@@ -51,7 +58,9 @@ def get_builtin_repository(id: RepositoryId) -> FileRepository:
51
58
 
52
59
  get_builtin_repository.__module__ = "__main__"
53
60
 
54
- DEFAULT_REPOSITORY: FileRepository | RepositoryId = "fal"
61
+ DEFAULT_REPOSITORY: FileRepository | RepositoryId = "fal_v3"
62
+ FALLBACK_REPOSITORY: FileRepository | RepositoryId = "cdn"
63
+ OBJECT_LIFECYCLE_PREFERENCE_KEY = "x-fal-object-lifecycle-preference"
55
64
 
56
65
 
57
66
  class File(BaseModel):
@@ -116,7 +125,8 @@ class File(BaseModel):
116
125
  url=url,
117
126
  content_type=None,
118
127
  file_name=None,
119
- repository=DEFAULT_REPOSITORY,
128
+ file_size=None,
129
+ file_data=None,
120
130
  )
121
131
 
122
132
  @classmethod
@@ -126,17 +136,38 @@ class File(BaseModel):
126
136
  content_type: Optional[str] = None,
127
137
  file_name: Optional[str] = None,
128
138
  repository: FileRepository | RepositoryId = DEFAULT_REPOSITORY,
139
+ fallback_repository: Optional[
140
+ FileRepository | RepositoryId
141
+ ] = FALLBACK_REPOSITORY,
142
+ request: Optional[Request] = None,
143
+ save_kwargs: Optional[dict] = None,
144
+ fallback_save_kwargs: Optional[dict] = None,
129
145
  ) -> File:
130
- repo = (
131
- repository
132
- if isinstance(repository, FileRepository)
133
- else get_builtin_repository(repository)
134
- )
146
+ repo = get_builtin_repository(repository)
147
+
148
+ save_kwargs = save_kwargs or {}
149
+ fallback_save_kwargs = fallback_save_kwargs or {}
135
150
 
136
151
  fdata = FileData(data, content_type, file_name)
137
152
 
153
+ object_lifecycle_preference = (
154
+ request_lifecycle_preference(request) or LIFECYCLE_PREFERENCE.get()
155
+ )
156
+
157
+ try:
158
+ url = repo.save(fdata, object_lifecycle_preference, **save_kwargs)
159
+ except Exception:
160
+ if not fallback_repository:
161
+ raise
162
+
163
+ fallback_repo = get_builtin_repository(fallback_repository)
164
+
165
+ url = fallback_repo.save(
166
+ fdata, object_lifecycle_preference, **fallback_save_kwargs
167
+ )
168
+
138
169
  return cls(
139
- url=repo.save(fdata),
170
+ url=url,
140
171
  content_type=fdata.content_type,
141
172
  file_name=fdata.file_name,
142
173
  file_size=len(data),
@@ -149,14 +180,56 @@ class File(BaseModel):
149
180
  path: str | Path,
150
181
  content_type: Optional[str] = None,
151
182
  repository: FileRepository | RepositoryId = DEFAULT_REPOSITORY,
183
+ multipart: bool | None = None,
184
+ fallback_repository: Optional[
185
+ FileRepository | RepositoryId
186
+ ] = FALLBACK_REPOSITORY,
187
+ request: Optional[Request] = None,
188
+ save_kwargs: Optional[dict] = None,
189
+ fallback_save_kwargs: Optional[dict] = None,
152
190
  ) -> File:
153
191
  file_path = Path(path)
154
192
  if not file_path.exists():
155
193
  raise FileNotFoundError(f"File {file_path} does not exist")
156
- with open(file_path, "rb") as f:
157
- data = f.read()
158
- return File.from_bytes(
159
- data, content_type, file_name=file_path.name, repository=repository
194
+
195
+ repo = get_builtin_repository(repository)
196
+
197
+ save_kwargs = save_kwargs or {}
198
+ fallback_save_kwargs = fallback_save_kwargs or {}
199
+
200
+ content_type = content_type or "application/octet-stream"
201
+ object_lifecycle_preference = (
202
+ request_lifecycle_preference(request) or LIFECYCLE_PREFERENCE.get()
203
+ )
204
+
205
+ try:
206
+ url, data = repo.save_file(
207
+ file_path,
208
+ content_type=content_type,
209
+ multipart=multipart,
210
+ object_lifecycle_preference=object_lifecycle_preference,
211
+ **save_kwargs,
212
+ )
213
+ except Exception:
214
+ if not fallback_repository:
215
+ raise
216
+
217
+ fallback_repo = get_builtin_repository(fallback_repository)
218
+
219
+ url, data = fallback_repo.save_file(
220
+ file_path,
221
+ content_type=content_type,
222
+ multipart=multipart,
223
+ object_lifecycle_preference=object_lifecycle_preference,
224
+ **fallback_save_kwargs,
225
+ )
226
+
227
+ return cls(
228
+ url=url,
229
+ file_data=data.data if data else None,
230
+ content_type=content_type,
231
+ file_name=file_path.name,
232
+ file_size=file_path.stat().st_size,
160
233
  )
161
234
 
162
235
  def as_bytes(self) -> bytes:
@@ -206,3 +279,20 @@ class CompressedFile(File):
206
279
  def __del__(self):
207
280
  if self.extract_dir:
208
281
  shutil.rmtree(self.extract_dir)
282
+
283
+
284
+ def request_lifecycle_preference(request: Optional[Request]) -> dict[str, str] | None:
285
+ import json
286
+
287
+ if request is None:
288
+ return None
289
+
290
+ preference_str = request.headers.get(OBJECT_LIFECYCLE_PREFERENCE_KEY)
291
+ if preference_str is None:
292
+ return None
293
+
294
+ try:
295
+ return json.loads(preference_str)
296
+ except Exception as e:
297
+ print(f"Failed to parse object lifecycle preference: {e}")
298
+ return None