fal 1.49.1__py3-none-any.whl → 1.57.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fal/_fal_version.py +2 -2
- fal/_serialization.py +1 -0
- fal/api/__init__.py +1 -0
- fal/api/api.py +32 -2
- fal/api/apps.py +23 -1
- fal/api/client.py +72 -1
- fal/api/deploy.py +16 -28
- fal/api/keys.py +31 -0
- fal/api/runners.py +10 -0
- fal/api/secrets.py +29 -0
- fal/app.py +50 -14
- fal/cli/_utils.py +11 -3
- fal/cli/api.py +4 -2
- fal/cli/apps.py +56 -2
- fal/cli/deploy.py +17 -3
- fal/cli/files.py +16 -24
- fal/cli/keys.py +47 -50
- fal/cli/queue.py +12 -10
- fal/cli/run.py +11 -7
- fal/cli/runners.py +189 -27
- fal/cli/secrets.py +28 -30
- fal/files.py +32 -8
- fal/logging/__init__.py +0 -5
- fal/sdk.py +39 -23
- fal/sync.py +22 -12
- fal/toolkit/__init__.py +10 -0
- fal/toolkit/compilation.py +220 -0
- fal/toolkit/file/file.py +10 -9
- fal/utils.py +65 -31
- fal/workflows.py +6 -2
- {fal-1.49.1.dist-info → fal-1.57.2.dist-info}/METADATA +6 -6
- {fal-1.49.1.dist-info → fal-1.57.2.dist-info}/RECORD +35 -33
- fal/rest_client.py +0 -25
- {fal-1.49.1.dist-info → fal-1.57.2.dist-info}/WHEEL +0 -0
- {fal-1.49.1.dist-info → fal-1.57.2.dist-info}/entry_points.txt +0 -0
- {fal-1.49.1.dist-info → fal-1.57.2.dist-info}/top_level.txt +0 -0
fal/sdk.py
CHANGED
@@ -246,6 +246,7 @@ class ApplicationInfo:
     min_concurrency: int
     concurrency_buffer: int
     concurrency_buffer_perc: int
+    scaling_delay: int
     machine_types: list[str]
     request_timeout: int
     startup_timeout: int
@@ -265,6 +266,7 @@ class AliasInfo:
     min_concurrency: int
     concurrency_buffer: int
     concurrency_buffer_perc: int
+    scaling_delay: int
     machine_types: list[str]
     request_timeout: int
     startup_timeout: int
@@ -272,27 +274,14 @@ class AliasInfo:
 
 
 class RunnerState(Enum):
-    RUNNING = "
-    PENDING = "
-    SETUP = "
-    DOCKER_PULL = "
-    DEAD = "
-
-
-
-    def from_proto(proto: isolate_proto.RunnerInfo.State) -> RunnerState:
-        if proto is isolate_proto.RunnerInfo.State.RUNNING:
-            return RunnerState.RUNNING
-        elif proto is isolate_proto.RunnerInfo.State.PENDING:
-            return RunnerState.PENDING
-        elif proto is isolate_proto.RunnerInfo.State.SETUP:
-            return RunnerState.SETUP
-        elif proto is isolate_proto.RunnerInfo.State.DEAD:
-            return RunnerState.DEAD
-        elif proto is isolate_proto.RunnerInfo.State.DOCKER_PULL:
-            return RunnerState.DOCKER_PULL
-        else:
-            return RunnerState.UNKNOWN
+    RUNNING = "RUNNING"
+    PENDING = "PENDING"
+    SETUP = "SETUP"
+    DOCKER_PULL = "DOCKER_PULL"
+    DEAD = "DEAD"
+    DRAINING = "DRAINING"
+    TERMINATING = "TERMINATING"
+    TERMINATED = "TERMINATED"
 
 
 @dataclass
@@ -414,6 +403,7 @@ def _from_grpc_application_info(
         min_concurrency=message.min_concurrency,
         concurrency_buffer=message.concurrency_buffer,
         concurrency_buffer_perc=message.concurrency_buffer_perc,
+        scaling_delay=message.scaling_delay_seconds,
         machine_types=list(message.machine_types),
         request_timeout=message.request_timeout,
         startup_timeout=message.startup_timeout,
@@ -444,6 +434,7 @@ def _from_grpc_alias_info(message: isolate_proto.AliasInfo) -> AliasInfo:
         min_concurrency=message.min_concurrency,
         concurrency_buffer=message.concurrency_buffer,
         concurrency_buffer_perc=message.concurrency_buffer_perc,
+        scaling_delay=message.scaling_delay_seconds,
         machine_types=list(message.machine_types),
         request_timeout=message.request_timeout,
         startup_timeout=message.startup_timeout,
@@ -468,7 +459,7 @@ def _from_grpc_runner_info(message: isolate_proto.RunnerInfo) -> RunnerInfo:
         external_metadata=external_metadata,
         revision=message.revision,
         alias=message.alias,
-        state=RunnerState.from_proto(message.state),
+        state=RunnerState(isolate_proto.RunnerInfo.State.Name(message.state)),
     )
 
 
@@ -537,8 +528,10 @@ class MachineRequirements:
     min_concurrency: int | None = None
     concurrency_buffer: int | None = None
     concurrency_buffer_perc: int | None = None
+    scaling_delay: int | None = None
     request_timeout: int | None = None
     startup_timeout: int | None = None
+    valid_regions: list[str] | None = None
 
     def __post_init__(self):
         if isinstance(self.machine_types, str):
@@ -633,6 +626,7 @@ class FalServerlessConnection:
         auth_mode: Optional[AuthModeLiteral] = None,
         *,
         source_code: str | None = None,
+        health_check_path: str | None = None,
         serialization_method: str = _DEFAULT_SERIALIZATION_METHOD,
         machine_requirements: MachineRequirements | None = None,
         metadata: dict[str, Any] | None = None,
@@ -640,7 +634,7 @@ class FalServerlessConnection:
         scale: bool = True,
         private_logs: bool = False,
         files: list[File] | None = None,
-    ) -> Iterator[
+    ) -> Iterator[RegisterApplicationResult]:
         wrapped_function = to_serialized_object(function, serialization_method)
         if machine_requirements:
             wrapped_requirements = isolate_proto.MachineRequirements(
@@ -659,9 +653,11 @@ class FalServerlessConnection:
                 min_concurrency=machine_requirements.min_concurrency,
                 concurrency_buffer=machine_requirements.concurrency_buffer,
                 concurrency_buffer_perc=machine_requirements.concurrency_buffer_perc,
+                scaling_delay_seconds=machine_requirements.scaling_delay,
                 max_multiplexing=machine_requirements.max_multiplexing,
                 request_timeout=machine_requirements.request_timeout,
                 startup_timeout=machine_requirements.startup_timeout,
+                valid_regions=machine_requirements.valid_regions,
             )
         else:
             wrapped_requirements = None
@@ -702,6 +698,7 @@ class FalServerlessConnection:
             private_logs=private_logs,
             files=files,
             source_code=source_code,
+            health_check_path=health_check_path,
         )
         for partial_result in self.stub.RegisterApplication(request):
             yield from_grpc(partial_result)
@@ -718,6 +715,7 @@ class FalServerlessConnection:
         min_concurrency: int | None = None,
         concurrency_buffer: int | None = None,
         concurrency_buffer_perc: int | None = None,
+        scaling_delay: int | None = None,
         request_timeout: int | None = None,
         startup_timeout: int | None = None,
         valid_regions: list[str] | None = None,
@@ -731,6 +729,7 @@ class FalServerlessConnection:
             min_concurrency=min_concurrency,
             concurrency_buffer=concurrency_buffer,
             concurrency_buffer_perc=concurrency_buffer_perc,
+            scaling_delay_seconds=scaling_delay,
            request_timeout=request_timeout,
            startup_timeout=startup_timeout,
            valid_regions=valid_regions,
@@ -757,6 +756,17 @@ class FalServerlessConnection:
        request = isolate_proto.DeleteApplicationRequest(application_id=application_id)
        self.stub.DeleteApplication(request)
 
+    def rollout_application(
+        self,
+        application_name: str,
+        force: bool = False,
+    ) -> None:
+        request = isolate_proto.RolloutApplicationRequest(
+            application_name=application_name,
+            force=force,
+        )
+        self.stub.RolloutApplication(request)
+
     def run(
         self,
         function: Callable[..., ResultT],
@@ -786,8 +796,10 @@ class FalServerlessConnection:
                 min_concurrency=machine_requirements.min_concurrency,
                 concurrency_buffer=machine_requirements.concurrency_buffer,
                 concurrency_buffer_perc=machine_requirements.concurrency_buffer_perc,
+                scaling_delay_seconds=machine_requirements.scaling_delay,
                 request_timeout=machine_requirements.request_timeout,
                 startup_timeout=machine_requirements.startup_timeout,
+                valid_regions=machine_requirements.valid_regions,
             )
         else:
             wrapped_requirements = None
@@ -884,6 +896,10 @@ class FalServerlessConnection:
             for secret in response.secrets
         ]
 
+    def stop_runner(self, runner_id: str) -> None:
+        request = isolate_proto.StopRunnerRequest(runner_id=runner_id)
+        self.stub.StopRunner(request)
+
     def kill_runner(self, runner_id: str) -> None:
        request = isolate_proto.KillRunnerRequest(runner_id=runner_id)
        self.stub.KillRunner(request)
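Taken together, the sdk.py changes add a scaling_delay autoscaling knob (carried over gRPC as scaling_delay_seconds), optional valid_regions and health_check_path on application registration, a RolloutApplication RPC, and a graceful stop_runner next to the existing kill_runner. A minimal usage sketch, assuming an already-established FalServerlessConnection and that machine_types is the only required MachineRequirements field (how the connection is built is outside this diff):

    # Illustrative sketch; only the field and method names shown here
    # (scaling_delay, valid_regions, rollout_application, stop_runner)
    # come from the diff above -- everything else is assumed.
    from fal.sdk import FalServerlessConnection, MachineRequirements

    def tune_and_rollout(conn: FalServerlessConnection) -> None:
        reqs = MachineRequirements(
            machine_types=["GPU-H100"],  # pre-existing field
            scaling_delay=60,            # new; seconds, per the proto name
            valid_regions=["us-east"],   # new; restricts runner placement
        )
        # reqs would be passed to register()/run(); both now forward
        # scaling_delay as scaling_delay_seconds on the proto message.

        conn.rollout_application("my-app", force=False)  # new RPC
        conn.stop_runner("runner-1234")  # graceful, vs. the hard kill_runner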
fal/sync.py
CHANGED
@@ -4,21 +4,21 @@ import hashlib
 import os
 import zipfile
 from pathlib import Path
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from openapi_fal_rest.client import Client
 
-import openapi_fal_rest.api.files.check_dir_hash as check_dir_hash_api
-import openapi_fal_rest.api.files.upload_local_file as upload_local_file_api
-import openapi_fal_rest.models.body_upload_local_file as upload_file_model
-import openapi_fal_rest.models.hash_check as hash_check_model
-import openapi_fal_rest.types as rest_types
 from pathspec import PathSpec
 
-from fal.rest_client import REST_CLIENT
 
+def _check_hash(client: Client, target_path: str, hash_string: str) -> bool:
+    import openapi_fal_rest.api.files.check_dir_hash as check_dir_hash_api
+    import openapi_fal_rest.models.hash_check as hash_check_model
 
-def _check_hash(target_path: str, hash_string: str) -> bool:
     response = check_dir_hash_api.sync_detailed(
         target_path,
-        client=REST_CLIENT,
+        client=client,
         json_body=hash_check_model.HashCheck(hash_string),
     )
 
@@ -26,7 +26,13 @@ def _check_hash(target_path: str, hash_string: str) -> bool:
     return response.status_code == 200 and res
 
 
-def _upload_file(source_path: str, target_path: str, unzip: bool = False):
+def _upload_file(
+    client: Client, source_path: str, target_path: str, unzip: bool = False
+):
+    import openapi_fal_rest.api.files.upload_local_file as upload_local_file_api
+    import openapi_fal_rest.models.body_upload_local_file as upload_file_model
+    import openapi_fal_rest.types as rest_types
+
     with open(source_path, "rb") as file_to_upload:
         body = upload_file_model.BodyUploadLocalFile(
             rest_types.File(
@@ -39,7 +45,7 @@ def _upload_file(source_path: str, target_path: str, unzip: bool = False):
 
     response = upload_local_file_api.sync_detailed(
         target_path,
-        client=REST_CLIENT,
+        client=client,
         unzip=unzip,
         multipart_data=body,
     )
@@ -94,6 +100,8 @@ def _zip_directory(dir_path: str, zip_path: str) -> None:
 
 
 def sync_dir(local_dir: str | Path, remote_dir: str, force_upload=False) -> str:
+    from fal.api.client import SyncServerlessClient
+
     local_dir_abs = os.path.expanduser(local_dir)
     if not os.path.isabs(remote_dir) or not remote_dir.startswith("/data"):
         raise ValueError(
@@ -106,9 +114,11 @@ def sync_dir(local_dir: str | Path, remote_dir: str, force_upload=False) -> str:
     # Compute the local directory hash
     local_hash = _compute_directory_hash(local_dir_abs)
 
+    client = SyncServerlessClient()._create_rest_client()
+
     print(f"Syncing {local_dir} with {remote_dir}...")
 
-    if _check_hash(remote_dir, local_hash) and not force_upload:
+    if _check_hash(client, remote_dir, local_hash) and not force_upload:
         print(f"{remote_dir} already uploaded and matches {local_dir}")
         return remote_dir
 
@@ -121,7 +131,7 @@ def sync_dir(local_dir: str | Path, remote_dir: str, force_upload=False) -> str:
     _zip_directory(local_dir_abs, zip_path)
 
     # Upload the zipped directory to the serverless environment
-    _upload_file(zip_path, remote_dir, unzip=True)
+    _upload_file(client, zip_path, remote_dir, unzip=True)
 
     os.remove(zip_path)
 
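The through-line in sync.py is a dependency-injection refactor: the module-level REST_CLIENT singleton and the top-level openapi_fal_rest imports are gone, sync_dir builds the REST client once and threads it through _check_hash and _upload_file, and the REST-stack imports become lazy, function-local imports. A self-contained sketch of the same shape, with illustrative stand-in names:

    # Stand-alone sketch of the pattern used above; Client here is a
    # stand-in for openapi_fal_rest.client.Client.
    from __future__ import annotations

    class Client:
        def check(self, path: str, digest: str) -> bool:
            return False  # pretend the remote copy is stale

    def _check_hash(client: Client, target_path: str, hash_string: str) -> bool:
        # In the real module the openapi_fal_rest imports live here, so
        # merely importing fal.sync no longer pulls in the REST stack.
        return client.check(target_path, hash_string)

    def sync_dir(local_dir: str, remote_dir: str) -> str:
        client = Client()  # real code: SyncServerlessClient()._create_rest_client()
        if _check_hash(client, remote_dir, "deadbeef"):
            print(f"{remote_dir} already up to date")
        return remote_dir

    print(sync_dir("./weights", "/data/weights"))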
fal/toolkit/__init__.py
CHANGED
@@ -1,6 +1,12 @@
 from __future__ import annotations
 
 from fal.toolkit.audio.audio import Audio, AudioField
+from fal.toolkit.compilation import (
+    get_gpu_type,
+    load_inductor_cache,
+    sync_inductor_cache,
+    synchronized_inductor_cache,
+)
 from fal.toolkit.file import CompressedFile, File, FileField
 from fal.toolkit.image.image import Image, ImageField, ImageSizeInput, get_image_size
 from fal.toolkit.optimize import optimize
@@ -33,4 +39,8 @@ __all__ = [
     "clone_repository",
     "download_file",
     "download_model_weights",
+    "get_gpu_type",
+    "load_inductor_cache",
+    "sync_inductor_cache",
+    "synchronized_inductor_cache",
 ]
fal/toolkit/compilation.py
ADDED
@@ -0,0 +1,220 @@
+"""PyTorch compilation cache management utilities.
+
+This module provides utilities for managing PyTorch Inductor compilation caches
+across workers. When using torch.compile(), PyTorch generates optimized CUDA kernels
+on first run, which can take 20-30 seconds. By sharing these compiled kernels across
+workers, subsequent workers can load pre-compiled kernels in ~2 seconds instead of
+recompiling.
+
+Typical usage in a model setup:
+
+    Manual cache management:
+        dir_hash = load_inductor_cache("mymodel/v1")
+        self.model = torch.compile(self.model)
+        self.warmup()  # Triggers compilation
+        sync_inductor_cache("mymodel/v1", dir_hash)
+
+    Context manager (automatic):
+        with synchronized_inductor_cache("mymodel/v1"):
+            self.model = torch.compile(self.model)
+            self.warmup()  # Compilation is automatically synced after
+"""
+
+from __future__ import annotations
+
+import hashlib
+import os
+import re
+import shutil
+import subprocess
+import tempfile
+from collections.abc import Iterator
+from contextlib import contextmanager
+from pathlib import Path
+
+LOCAL_INDUCTOR_CACHE_DIR = Path("/tmp/inductor-cache/")
+GLOBAL_INDUCTOR_CACHES_DIR = Path("/data/inductor-caches/")
+PERSISTENT_TMP_DIR = Path("/data/tmp/")
+
+
+def get_gpu_type() -> str:
+    """Detect the GPU type using nvidia-smi.
+
+    Returns:
+        The GPU model name (e.g., "H100", "A100", "H200") or "UNKNOWN"
+        if detection fails.
+
+    Example:
+        >>> gpu_type = get_gpu_type()
+        >>> print(f"Running on: {gpu_type}")
+        Running on: H100
+    """
+    try:
+        gpu_type_string = subprocess.run(
+            ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"],
+            capture_output=True,
+            text=True,
+            check=False,
+        ).stdout
+        matches = re.search(r"NVIDIA [a-zA-Z0-9]*", gpu_type_string)
+        # check for matches - if there are none, return "UNKNOWN"
+        if matches:
+            gpu_type = matches.group(0)
+            return gpu_type[7:]  # remove `NVIDIA `
+        else:
+            return "UNKNOWN"
+    except Exception:
+        return "UNKNOWN"
+
+
+def _dir_hash(path: Path) -> str:
+    """Compute a hash of all filenames in a directory (recursively).
+
+    Args:
+        path: Directory to hash.
+
+    Returns:
+        SHA256 hex digest of sorted filenames.
+    """
+    # Hash of all the filenames in the directory, recursively, sorted
+    filenames = {str(file) for file in path.rglob("*") if file.is_file()}
+    return hashlib.sha256("".join(sorted(filenames)).encode()).hexdigest()
+
+
+def load_inductor_cache(cache_key: str) -> str:
+    """Load PyTorch Inductor compilation cache from global storage.
+
+    This function:
+    1. Sets TORCHINDUCTOR_CACHE_DIR environment variable
+    2. Looks for cached compiled kernels in GPU-specific global storage
+    3. Unpacks the cache to local temporary directory
+    4. Returns a hash of the unpacked directory for change detection
+
+    Args:
+        cache_key: Unique identifier for this cache (e.g., "flux/2", "mymodel/v1")
+
+    Returns:
+        Hash of the unpacked cache directory, or empty string if cache not found.
+
+    Example:
+        >>> dir_hash = load_inductor_cache("flux/2")
+        Found compilation cache at /data/inductor-caches/H100/flux/2.zip, unpacking...
+        Cache unpacked successfully.
+    """
+    gpu_type = get_gpu_type()
+
+    os.environ["TORCHINDUCTOR_CACHE_DIR"] = str(LOCAL_INDUCTOR_CACHE_DIR)
+
+    cache_source_path = GLOBAL_INDUCTOR_CACHES_DIR / gpu_type / f"{cache_key}.zip"
+
+    try:
+        next(cache_source_path.parent.iterdir(), None)
+    except Exception as e:
+        # Check for cache without gpu_type in the path
+        try:
+            old_source_path = GLOBAL_INDUCTOR_CACHES_DIR / f"{cache_key}.zip"
+            # Since old source exists, copy it over to global caches
+            os.makedirs(cache_source_path.parent, exist_ok=True)
+            shutil.copy(old_source_path, cache_source_path)
+        except Exception:
+            print(f"Failed to list: {e}")
+
+    if not cache_source_path.exists():
+        print(f"Couldn't find compilation cache at {cache_source_path}")
+        return ""
+
+    print(f"Found compilation cache at {cache_source_path}, unpacking...")
+    try:
+        shutil.unpack_archive(cache_source_path, LOCAL_INDUCTOR_CACHE_DIR)
+    except Exception as e:
+        print(f"Failed to unpack cache: {e}")
+        return ""
+
+    print("Cache unpacked successfully.")
+    return _dir_hash(LOCAL_INDUCTOR_CACHE_DIR)
+
+
+def sync_inductor_cache(cache_key: str, unpacked_dir_hash: str) -> None:
+    """Sync updated PyTorch Inductor cache back to global storage.
+
+    This function:
+    1. Checks if the local cache has changed (by comparing hashes)
+    2. If changed, creates a zip archive of the new cache
+    3. Saves it to GPU-specific global storage
+
+    Args:
+        cache_key: Unique identifier for this cache (same as used in
+            load_inductor_cache)
+        unpacked_dir_hash: Hash returned from load_inductor_cache
+            (for change detection)
+
+    Example:
+        >>> sync_inductor_cache("flux/2", dir_hash)
+        No changes in the cache dir, skipping sync.
+        # or
+        Changes detected in the cache dir, syncing...
+    """
+    gpu_type = get_gpu_type()
+    if not LOCAL_INDUCTOR_CACHE_DIR.exists():
+        print(f"No cache to sync, {LOCAL_INDUCTOR_CACHE_DIR} doesn't exist.")
+        return
+
+    if not GLOBAL_INDUCTOR_CACHES_DIR.exists():
+        GLOBAL_INDUCTOR_CACHES_DIR.mkdir(parents=True)
+
+    # If we updated the cache (the hashes of LOCAL_INDUCTOR_CACHE_DIR and
+    # unpacked_dir_hash differ), we pack the cache and move it to the
+    # global cache directory.
+    new_dir_hash = _dir_hash(LOCAL_INDUCTOR_CACHE_DIR)
+    if new_dir_hash == unpacked_dir_hash:
+        print("No changes in the cache dir, skipping sync.")
+        return
+
+    print("Changes detected in the cache dir, syncing...")
+    os.makedirs(
+        PERSISTENT_TMP_DIR, exist_ok=True
+    )  # Non fal-ai users do not have this directory
+    with tempfile.TemporaryDirectory(dir=PERSISTENT_TMP_DIR) as temp_dir:
+        temp_dir_path = Path(temp_dir)
+        cache_path = GLOBAL_INDUCTOR_CACHES_DIR / gpu_type / f"{cache_key}.zip"
+        cache_path.parent.mkdir(parents=True, exist_ok=True)
+
+        try:
+            zip_name = shutil.make_archive(
+                str(temp_dir_path / "inductor_cache"),
+                "zip",
+                LOCAL_INDUCTOR_CACHE_DIR,
+            )
+            os.rename(
+                zip_name,
+                cache_path,
+            )
+        except Exception as e:
+            print(f"Failed to sync cache: {e}")
+            return
+
+
+@contextmanager
+def synchronized_inductor_cache(cache_key: str) -> Iterator[None]:
+    """Context manager to automatically load and sync PyTorch Inductor cache.
+
+    This wraps load_inductor_cache and sync_inductor_cache for convenience.
+    The cache is loaded on entry and synced on exit (even if an exception occurs).
+
+    Args:
+        cache_key: Unique identifier for this cache (e.g., "flux/2", "mymodel/v1")
+
+    Yields:
+        None
+
+    Example:
+        >>> with synchronized_inductor_cache("mymodel/v1"):
+        ...     self.model = torch.compile(self.model)
+        ...     self.warmup()  # Compilation happens here
+        # Cache is automatically synced after the with block
+    """
+    unpacked_dir_hash = load_inductor_cache(cache_key)
+    try:
+        yield
+    finally:
+        sync_inductor_cache(cache_key, unpacked_dir_hash)
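The module docstring above shows the bare calls; in a fal app the context manager typically wraps the compile-plus-warmup step inside setup(). A sketch under assumed names (the app and model are illustrative; only the fal.toolkit imports are from this release):

    # Illustrative app wiring for the new toolkit helpers; the model
    # and app name are made up, and a GPU runner with torch is assumed.
    import fal
    from fal.toolkit import get_gpu_type, synchronized_inductor_cache

    class MyModelApp(fal.App):  # hypothetical app
        machine_type = "GPU-H100"

        def setup(self):
            import torch

            model = torch.nn.Linear(8, 8).cuda()
            print(f"Compiling on {get_gpu_type()}")
            # First worker pays the 20-30s compile; later workers on the
            # same GPU type unpack the shared cache instead (zips are
            # keyed by GPU type under /data/inductor-caches/).
            with synchronized_inductor_cache("mymodel/v1"):
                self.model = torch.compile(model)
                self.model(torch.randn(1, 8, device="cuda"))  # warmup run

Note the cache key doubles as a version tag: bumping "mymodel/v1" to "mymodel/v2" effectively invalidates the shared cache.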
fal/toolkit/file/file.py
CHANGED
@@ -16,8 +16,7 @@ from fastapi import Request
 if not hasattr(pydantic, "__version__") or pydantic.__version__.startswith("1."):
     IS_PYDANTIC_V2 = False
 else:
-    from pydantic import
-    from pydantic_core import CoreSchema, core_schema
+    from pydantic import model_validator
 
     IS_PYDANTIC_V2 = True
 
@@ -137,14 +136,16 @@ class File(BaseModel):
     # Pydantic custom validator for input type conversion
     if IS_PYDANTIC_V2:
 
+        @model_validator(mode="before")
         @classmethod
-        def
-
-
-
-
-
-
+        def __convert_from_str_v2(cls, value: Any):
+            if isinstance(value, str):
+                parsed_url = urlparse(value)
+                if parsed_url.scheme not in ["http", "https", "data"]:
+                    raise ValueError("value must be a valid URL")
+                # Return a mapping so the model can be constructed normally
+                return {"url": parsed_url.geturl()}
+            return value
 
     else:
 
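The net effect of the rebuilt validator is that, under pydantic v2, a bare URL string is accepted anywhere a File is expected: the mode="before" validator rewrites the string into a {"url": ...} mapping before normal field validation runs. A sketch of the observable behavior (exact error wrapping may differ from the package):

    from fal.toolkit import File

    f = File.model_validate("https://example.com/x.png")  # str coerced to {"url": ...}
    assert f.url == "https://example.com/x.png"

    try:
        File.model_validate("not a url")  # scheme is not http/https/data
    except ValueError as exc:  # pydantic's ValidationError subclasses ValueError
        print(exc)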
fal/utils.py
CHANGED
@@ -1,11 +1,10 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
+from typing import TYPE_CHECKING
 
-
-from
-
-from .api import FalServerlessError, FalServerlessHost, IsolatedFunction
+if TYPE_CHECKING:
+    from .api import FalServerlessHost, IsolatedFunction
 
 
 @dataclass
@@ -17,6 +16,62 @@ class LoadedFunction:
     source_code: str | None
 
 
+def _find_target(
+    module: dict[str, object], function_name: str | None = None
+) -> tuple[object, str | None, str | None]:
+    import fal
+    from fal.api import FalServerlessError, IsolatedFunction
+
+    if function_name is not None:
+        if function_name not in module:
+            raise FalServerlessError(f"Function '{function_name}' not found in module")
+
+        target = module[function_name]
+
+        if isinstance(target, type) and issubclass(target, fal.App):
+            return target, target.app_name, target.app_auth
+
+        if isinstance(target, IsolatedFunction):
+            return target, function_name, None
+
+        raise FalServerlessError(
+            f"Function '{function_name}' is not a fal.App or a fal.function"
+        )
+
+    fal_apps = {
+        obj_name: obj
+        for obj_name, obj in module.items()
+        if isinstance(obj, type) and issubclass(obj, fal.App) and obj is not fal.App
+    }
+
+    if len(fal_apps) == 1:
+        [(function_name, target)] = fal_apps.items()
+        return target, target.app_name, target.app_auth
+    elif len(fal_apps) > 1:
+        raise FalServerlessError(
+            f"Multiple fal.Apps found in the module: {list(fal_apps.keys())}. "
+            "Please specify the name of the app."
+        )
+
+    fal_functions = {
+        obj_name: obj
+        for obj_name, obj in module.items()
+        if isinstance(obj, IsolatedFunction)
+    }
+
+    if len(fal_functions) == 0:
+        raise FalServerlessError("No fal.App or fal.function found in the module.")
+    elif len(fal_functions) > 1:
+        raise FalServerlessError(
+            "Multiple fal.functions found in the module: "
+            f"{list(fal_functions.keys())}. "
+            "Please specify the name of the function."
+        )
+
+    [(function_name, target)] = fal_functions.items()
+    return target, function_name, None
+
+
 def load_function_from(
     host: FalServerlessHost,
     file_path: str,
@@ -26,45 +81,24 @@ def load_function_from(
     import runpy
     import sys
 
+    import fal._serialization
+    from fal import App, wrap_app
+
+    from .api import FalServerlessError, IsolatedFunction
+
     sys.path.append(os.getcwd())
     module = runpy.run_path(file_path)
-
-    fal_objects = {
-        obj_name: obj
-        for obj_name, obj in module.items()
-        if isinstance(obj, type) and issubclass(obj, fal.App) and obj is not fal.App
-    }
-    if len(fal_objects) == 0:
-        raise FalServerlessError("No fal.App found in the module.")
-    elif len(fal_objects) > 1:
-        raise FalServerlessError(
-            "Multiple fal.Apps found in the module. "
-            "Please specify the name of the app."
-        )
-    if function_name is None:
-        [(function_name, obj)] = fal_objects.items()
-        app_name = obj.app_name
-        app_auth = obj.app_auth
-    else:
-        app_name = None
-        app_auth = None
-
-    if function_name not in module:
-        raise FalServerlessError(f"Function '{function_name}' not found in module")
+    target, app_name, app_auth = _find_target(module, function_name)
 
     # The module for the function is set to <run_path> when runpy is used, in which
     # case we want to manually include the package it is defined in.
     fal._serialization.include_package_from_path(file_path)
 
-    target = module[function_name]
-
     with open(file_path) as f:
         source_code = f.read()
 
     endpoints = ["/"]
     if isinstance(target, type) and issubclass(target, App):
-        app_name = target.app_name
-        app_auth = target.app_auth
         endpoints = target.get_endpoints() or ["/"]
         target = wrap_app(target, host=host)
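The new _find_target centralizes target resolution with a clear precedence: an explicit function_name must name a fal.App subclass or an IsolatedFunction; otherwise a single App in the module wins, then a single isolated function, and any ambiguity raises with the candidate names listed. An illustrative single-App module that now resolves without an explicit name (the app itself is made up, and how the CLI references it is an assumption):

    # app.py -- hypothetical module; _find_target picks EchoApp
    # automatically because it is the only fal.App defined here.
    import fal

    class EchoApp(fal.App):
        app_name = "echo"

        @fal.endpoint("/")
        def run(self, payload: dict) -> dict:
            return payload

A module containing a single @fal.function-decorated function (an IsolatedFunction) now resolves the same way, which the old Apps-only lookup did not support.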
fal/workflows.py
CHANGED
@@ -19,7 +19,6 @@ from rich.syntax import Syntax
 import fal
 from fal import flags
 from fal.exceptions import FalServerlessException
-from fal.rest_client import REST_CLIENT
 
 JSONType = Union[Dict[str, Any], List[Any], str, int, float, bool, None, "Leaf"]
 SchemaType = Dict[str, Any]
@@ -372,6 +371,11 @@ class Workflow:
     to_dict = to_json
 
     def publish(self, title: str, *, is_public: bool = True):
+        from fal.api.client import SyncServerlessClient
+
+        client = SyncServerlessClient()
+        rest_client = client._create_rest_client()
+
         workflow_contents = publish_workflow.TypedWorkflow(
             name=self.name,
             title=title,
@@ -379,7 +383,7 @@
             is_public=is_public,
         )
         published_workflow = publish_workflow.sync(
-            client=REST_CLIENT,
+            client=rest_client,
             json_body=workflow_contents,
        )
        if isinstance(published_workflow, Exception):