fal 1.49.1__py3-none-any.whl → 1.57.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
fal/sdk.py CHANGED
@@ -246,6 +246,7 @@ class ApplicationInfo:
  min_concurrency: int
  concurrency_buffer: int
  concurrency_buffer_perc: int
+ scaling_delay: int
  machine_types: list[str]
  request_timeout: int
  startup_timeout: int
@@ -265,6 +266,7 @@ class AliasInfo:
  min_concurrency: int
  concurrency_buffer: int
  concurrency_buffer_perc: int
+ scaling_delay: int
  machine_types: list[str]
  request_timeout: int
  startup_timeout: int
@@ -272,27 +274,14 @@ class AliasInfo:

  class RunnerState(Enum):
- RUNNING = "running"
- PENDING = "pending"
- SETUP = "setup"
- DOCKER_PULL = "docker_pull"
- DEAD = "dead"
- UNKNOWN = "unknown"
-
- @staticmethod
- def from_proto(proto: isolate_proto.RunnerInfo.State) -> RunnerState:
- if proto is isolate_proto.RunnerInfo.State.RUNNING:
- return RunnerState.RUNNING
- elif proto is isolate_proto.RunnerInfo.State.PENDING:
- return RunnerState.PENDING
- elif proto is isolate_proto.RunnerInfo.State.SETUP:
- return RunnerState.SETUP
- elif proto is isolate_proto.RunnerInfo.State.DEAD:
- return RunnerState.DEAD
- elif proto is isolate_proto.RunnerInfo.State.DOCKER_PULL:
- return RunnerState.DOCKER_PULL
- else:
- return RunnerState.UNKNOWN
+ RUNNING = "RUNNING"
+ PENDING = "PENDING"
+ SETUP = "SETUP"
+ DOCKER_PULL = "DOCKER_PULL"
+ DEAD = "DEAD"
+ DRAINING = "DRAINING"
+ TERMINATING = "TERMINATING"
+ TERMINATED = "TERMINATED"

  @dataclass
@@ -414,6 +403,7 @@ def _from_grpc_application_info(
  min_concurrency=message.min_concurrency,
  concurrency_buffer=message.concurrency_buffer,
  concurrency_buffer_perc=message.concurrency_buffer_perc,
+ scaling_delay=message.scaling_delay_seconds,
  machine_types=list(message.machine_types),
  request_timeout=message.request_timeout,
  startup_timeout=message.startup_timeout,
@@ -444,6 +434,7 @@ def _from_grpc_alias_info(message: isolate_proto.AliasInfo) -> AliasInfo:
  min_concurrency=message.min_concurrency,
  concurrency_buffer=message.concurrency_buffer,
  concurrency_buffer_perc=message.concurrency_buffer_perc,
+ scaling_delay=message.scaling_delay_seconds,
  machine_types=list(message.machine_types),
  request_timeout=message.request_timeout,
  startup_timeout=message.startup_timeout,
@@ -468,7 +459,7 @@ def _from_grpc_runner_info(message: isolate_proto.RunnerInfo) -> RunnerInfo:
  external_metadata=external_metadata,
  revision=message.revision,
  alias=message.alias,
- state=RunnerState.from_proto(message.state),
+ state=RunnerState(isolate_proto.RunnerInfo.State.Name(message.state)),
  )

@@ -537,8 +528,10 @@ class MachineRequirements:
  min_concurrency: int | None = None
  concurrency_buffer: int | None = None
  concurrency_buffer_perc: int | None = None
+ scaling_delay: int | None = None
  request_timeout: int | None = None
  startup_timeout: int | None = None
+ valid_regions: list[str] | None = None

  def __post_init__(self):
  if isinstance(self.machine_types, str):
@@ -633,6 +626,7 @@ class FalServerlessConnection:
  auth_mode: Optional[AuthModeLiteral] = None,
  *,
  source_code: str | None = None,
+ health_check_path: str | None = None,
  serialization_method: str = _DEFAULT_SERIALIZATION_METHOD,
  machine_requirements: MachineRequirements | None = None,
  metadata: dict[str, Any] | None = None,
@@ -640,7 +634,7 @@ class FalServerlessConnection:
  scale: bool = True,
  private_logs: bool = False,
  files: list[File] | None = None,
- ) -> Iterator[isolate_proto.RegisterApplicationResult]:
+ ) -> Iterator[RegisterApplicationResult]:
  wrapped_function = to_serialized_object(function, serialization_method)
  if machine_requirements:
  wrapped_requirements = isolate_proto.MachineRequirements(
@@ -659,9 +653,11 @@ class FalServerlessConnection:
  min_concurrency=machine_requirements.min_concurrency,
  concurrency_buffer=machine_requirements.concurrency_buffer,
  concurrency_buffer_perc=machine_requirements.concurrency_buffer_perc,
+ scaling_delay_seconds=machine_requirements.scaling_delay,
  max_multiplexing=machine_requirements.max_multiplexing,
  request_timeout=machine_requirements.request_timeout,
  startup_timeout=machine_requirements.startup_timeout,
+ valid_regions=machine_requirements.valid_regions,
  )
  else:
  wrapped_requirements = None
@@ -702,6 +698,7 @@ class FalServerlessConnection:
  private_logs=private_logs,
  files=files,
  source_code=source_code,
+ health_check_path=health_check_path,
  )
  for partial_result in self.stub.RegisterApplication(request):
  yield from_grpc(partial_result)
@@ -718,6 +715,7 @@ class FalServerlessConnection:
  min_concurrency: int | None = None,
  concurrency_buffer: int | None = None,
  concurrency_buffer_perc: int | None = None,
+ scaling_delay: int | None = None,
  request_timeout: int | None = None,
  startup_timeout: int | None = None,
  valid_regions: list[str] | None = None,
@@ -731,6 +729,7 @@ class FalServerlessConnection:
  min_concurrency=min_concurrency,
  concurrency_buffer=concurrency_buffer,
  concurrency_buffer_perc=concurrency_buffer_perc,
+ scaling_delay_seconds=scaling_delay,
  request_timeout=request_timeout,
  startup_timeout=startup_timeout,
  valid_regions=valid_regions,
@@ -757,6 +756,17 @@ class FalServerlessConnection:
  request = isolate_proto.DeleteApplicationRequest(application_id=application_id)
  self.stub.DeleteApplication(request)

+ def rollout_application(
+ self,
+ application_name: str,
+ force: bool = False,
+ ) -> None:
+ request = isolate_proto.RolloutApplicationRequest(
+ application_name=application_name,
+ force=force,
+ )
+ self.stub.RolloutApplication(request)
+
  def run(
  self,
  function: Callable[..., ResultT],
@@ -786,8 +796,10 @@ class FalServerlessConnection:
  min_concurrency=machine_requirements.min_concurrency,
  concurrency_buffer=machine_requirements.concurrency_buffer,
  concurrency_buffer_perc=machine_requirements.concurrency_buffer_perc,
+ scaling_delay_seconds=machine_requirements.scaling_delay,
  request_timeout=machine_requirements.request_timeout,
  startup_timeout=machine_requirements.startup_timeout,
+ valid_regions=machine_requirements.valid_regions,
  )
  else:
  wrapped_requirements = None
@@ -884,6 +896,10 @@ class FalServerlessConnection:
  for secret in response.secrets
  ]

+ def stop_runner(self, runner_id: str) -> None:
+ request = isolate_proto.StopRunnerRequest(runner_id=runner_id)
+ self.stub.StopRunner(request)
+
  def kill_runner(self, runner_id: str) -> None:
  request = isolate_proto.KillRunnerRequest(runner_id=runner_id)
  self.stub.KillRunner(request)
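Note (editor's sketch, not part of the diff): the new MachineRequirements fields and connection methods above could be exercised roughly as follows, assuming an already-constructed FalServerlessConnection; the machine type name and runner ID are placeholders.

    from fal.sdk import MachineRequirements

    requirements = MachineRequirements(
        machine_types=["GPU-H100"],  # placeholder machine type name
        min_concurrency=1,
        scaling_delay=120,           # new: sent to the server as scaling_delay_seconds
        valid_regions=["us-east"],   # new: restrict where runners may be scheduled
    )

    # connection: FalServerlessConnection (construction not shown in this diff)
    # connection.rollout_application("my-app", force=False)  # new RPC wrapper
    # connection.stop_runner("runner-id")                    # new, alongside kill_runner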
fal/sync.py CHANGED
@@ -4,21 +4,21 @@ import hashlib
  import os
  import zipfile
  from pathlib import Path
+ from typing import TYPE_CHECKING
+
+ if TYPE_CHECKING:
+ from openapi_fal_rest.client import Client

- import openapi_fal_rest.api.files.check_dir_hash as check_dir_hash_api
- import openapi_fal_rest.api.files.upload_local_file as upload_local_file_api
- import openapi_fal_rest.models.body_upload_local_file as upload_file_model
- import openapi_fal_rest.models.hash_check as hash_check_model
- import openapi_fal_rest.types as rest_types
  from pathspec import PathSpec

- from fal.rest_client import REST_CLIENT

+ def _check_hash(client: Client, target_path: str, hash_string: str) -> bool:
+ import openapi_fal_rest.api.files.check_dir_hash as check_dir_hash_api
+ import openapi_fal_rest.models.hash_check as hash_check_model

- def _check_hash(target_path: str, hash_string: str) -> bool:
  response = check_dir_hash_api.sync_detailed(
  target_path,
- client=REST_CLIENT,
+ client=client,
  json_body=hash_check_model.HashCheck(hash_string),
  )

@@ -26,7 +26,13 @@ def _check_hash(target_path: str, hash_string: str) -> bool:
  return response.status_code == 200 and res

- def _upload_file(source_path: str, target_path: str, unzip: bool = False):
+ def _upload_file(
+ client: Client, source_path: str, target_path: str, unzip: bool = False
+ ):
+ import openapi_fal_rest.api.files.upload_local_file as upload_local_file_api
+ import openapi_fal_rest.models.body_upload_local_file as upload_file_model
+ import openapi_fal_rest.types as rest_types
+
  with open(source_path, "rb") as file_to_upload:
  body = upload_file_model.BodyUploadLocalFile(
  rest_types.File(
@@ -39,7 +45,7 @@ def _upload_file(source_path: str, target_path: str, unzip: bool = False):

  response = upload_local_file_api.sync_detailed(
  target_path,
- client=REST_CLIENT,
+ client=client,
  unzip=unzip,
  multipart_data=body,
  )
@@ -94,6 +100,8 @@ def _zip_directory(dir_path: str, zip_path: str) -> None:

  def sync_dir(local_dir: str | Path, remote_dir: str, force_upload=False) -> str:
+ from fal.api.client import SyncServerlessClient
+
  local_dir_abs = os.path.expanduser(local_dir)
  if not os.path.isabs(remote_dir) or not remote_dir.startswith("/data"):
  raise ValueError(
@@ -106,9 +114,11 @@ def sync_dir(local_dir: str | Path, remote_dir: str, force_upload=False) -> str:
  # Compute the local directory hash
  local_hash = _compute_directory_hash(local_dir_abs)

+ client = SyncServerlessClient()._create_rest_client()
+
  print(f"Syncing {local_dir} with {remote_dir}...")

- if _check_hash(remote_dir, local_hash) and not force_upload:
+ if _check_hash(client, remote_dir, local_hash) and not force_upload:
  print(f"{remote_dir} already uploaded and matches {local_dir}")
  return remote_dir

@@ -121,7 +131,7 @@ def sync_dir(local_dir: str | Path, remote_dir: str, force_upload=False) -> str:
  _zip_directory(local_dir_abs, zip_path)

  # Upload the zipped directory to the serverless environment
- _upload_file(zip_path, remote_dir, unzip=True)
+ _upload_file(client, zip_path, remote_dir, unzip=True)

  os.remove(zip_path)
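Note (editor's sketch, not part of the diff): the public entry point is unchanged by this refactor; sync_dir now builds its own REST client through SyncServerlessClient instead of the removed module-level REST_CLIENT. The paths below are placeholders.

    from fal.sync import sync_dir

    # The remote directory must be an absolute path under /data.
    remote_path = sync_dir("~/my-weights", "/data/my-weights")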
 
fal/toolkit/__init__.py CHANGED
@@ -1,6 +1,12 @@
  from __future__ import annotations

  from fal.toolkit.audio.audio import Audio, AudioField
+ from fal.toolkit.compilation import (
+ get_gpu_type,
+ load_inductor_cache,
+ sync_inductor_cache,
+ synchronized_inductor_cache,
+ )
  from fal.toolkit.file import CompressedFile, File, FileField
  from fal.toolkit.image.image import Image, ImageField, ImageSizeInput, get_image_size
  from fal.toolkit.optimize import optimize
@@ -33,4 +39,8 @@ __all__ = [
  "clone_repository",
  "download_file",
  "download_model_weights",
+ "get_gpu_type",
+ "load_inductor_cache",
+ "sync_inductor_cache",
+ "synchronized_inductor_cache",
  ]
fal/toolkit/compilation.py ADDED
@@ -0,0 +1,220 @@
+ """PyTorch compilation cache management utilities.
+
+ This module provides utilities for managing PyTorch Inductor compilation caches
+ across workers. When using torch.compile(), PyTorch generates optimized CUDA kernels
+ on first run, which can take 20-30 seconds. By sharing these compiled kernels across
+ workers, subsequent workers can load pre-compiled kernels in ~2 seconds instead of
+ recompiling.
+
+ Typical usage in a model setup:
+
+ Manual cache management:
+ dir_hash = load_inductor_cache("mymodel/v1")
+ self.model = torch.compile(self.model)
+ self.warmup() # Triggers compilation
+ sync_inductor_cache("mymodel/v1", dir_hash)
+
+ Context manager (automatic):
+ with synchronized_inductor_cache("mymodel/v1"):
+ self.model = torch.compile(self.model)
+ self.warmup() # Compilation is automatically synced after
+ """
+
+ from __future__ import annotations
+
+ import hashlib
+ import os
+ import re
+ import shutil
+ import subprocess
+ import tempfile
+ from collections.abc import Iterator
+ from contextlib import contextmanager
+ from pathlib import Path
+
+ LOCAL_INDUCTOR_CACHE_DIR = Path("/tmp/inductor-cache/")
+ GLOBAL_INDUCTOR_CACHES_DIR = Path("/data/inductor-caches/")
+ PERSISTENT_TMP_DIR = Path("/data/tmp/")
+
+
+ def get_gpu_type() -> str:
+ """Detect the GPU type using nvidia-smi.
+
+ Returns:
+ The GPU model name (e.g., "H100", "A100", "H200") or "UNKNOWN"
+ if detection fails.
+
+ Example:
+ >>> gpu_type = get_gpu_type()
+ >>> print(f"Running on: {gpu_type}")
+ Running on: H100
+ """
+ try:
+ gpu_type_string = subprocess.run(
+ ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"],
+ capture_output=True,
+ text=True,
+ check=False,
+ ).stdout
+ matches = re.search(r"NVIDIA [a-zA-Z0-9]*", gpu_type_string)
+ # check for matches - if there are none, return "UNKNOWN"
+ if matches:
+ gpu_type = matches.group(0)
+ return gpu_type[7:] # remove `NVIDIA `
+ else:
+ return "UNKNOWN"
+ except Exception:
+ return "UNKNOWN"
+
+
+ def _dir_hash(path: Path) -> str:
+ """Compute a hash of all filenames in a directory (recursively).
+
+ Args:
+ path: Directory to hash.
+
+ Returns:
+ SHA256 hex digest of sorted filenames.
+ """
+ # Hash of all the filenames in the directory, recursively, sorted
+ filenames = {str(file) for file in path.rglob("*") if file.is_file()}
+ return hashlib.sha256("".join(sorted(filenames)).encode()).hexdigest()
+
+
+ def load_inductor_cache(cache_key: str) -> str:
+ """Load PyTorch Inductor compilation cache from global storage.
+
+ This function:
+ 1. Sets TORCHINDUCTOR_CACHE_DIR environment variable
+ 2. Looks for cached compiled kernels in GPU-specific global storage
+ 3. Unpacks the cache to local temporary directory
+ 4. Returns a hash of the unpacked directory for change detection
+
+ Args:
+ cache_key: Unique identifier for this cache (e.g., "flux/2", "mymodel/v1")
+
+ Returns:
+ Hash of the unpacked cache directory, or empty string if cache not found.
+
+ Example:
+ >>> dir_hash = load_inductor_cache("flux/2")
+ Found compilation cache at /data/inductor-caches/H100/flux/2.zip, unpacking...
+ Cache unpacked successfully.
+ """
+ gpu_type = get_gpu_type()
+
+ os.environ["TORCHINDUCTOR_CACHE_DIR"] = str(LOCAL_INDUCTOR_CACHE_DIR)
+
+ cache_source_path = GLOBAL_INDUCTOR_CACHES_DIR / gpu_type / f"{cache_key}.zip"
+
+ try:
+ next(cache_source_path.parent.iterdir(), None)
+ except Exception as e:
+ # Check for cache without gpu_type in the path
+ try:
+ old_source_path = GLOBAL_INDUCTOR_CACHES_DIR / f"{cache_key}.zip"
+ # Since old source exists, copy it over to global caches
+ os.makedirs(cache_source_path.parent, exist_ok=True)
+ shutil.copy(old_source_path, cache_source_path)
+ except Exception:
+ print(f"Failed to list: {e}")
+
+ if not cache_source_path.exists():
+ print(f"Couldn't find compilation cache at {cache_source_path}")
+ return ""
+
+ print(f"Found compilation cache at {cache_source_path}, unpacking...")
+ try:
+ shutil.unpack_archive(cache_source_path, LOCAL_INDUCTOR_CACHE_DIR)
+ except Exception as e:
+ print(f"Failed to unpack cache: {e}")
+ return ""
+
+ print("Cache unpacked successfully.")
+ return _dir_hash(LOCAL_INDUCTOR_CACHE_DIR)
+
+
+ def sync_inductor_cache(cache_key: str, unpacked_dir_hash: str) -> None:
+ """Sync updated PyTorch Inductor cache back to global storage.
+
+ This function:
+ 1. Checks if the local cache has changed (by comparing hashes)
+ 2. If changed, creates a zip archive of the new cache
+ 3. Saves it to GPU-specific global storage
+
+ Args:
+ cache_key: Unique identifier for this cache (same as used in
+ load_inductor_cache)
+ unpacked_dir_hash: Hash returned from load_inductor_cache
+ (for change detection)
+
+ Example:
+ >>> sync_inductor_cache("flux/2", dir_hash)
+ No changes in the cache dir, skipping sync.
+ # or
+ Changes detected in the cache dir, syncing...
+ """
+ gpu_type = get_gpu_type()
+ if not LOCAL_INDUCTOR_CACHE_DIR.exists():
+ print(f"No cache to sync, {LOCAL_INDUCTOR_CACHE_DIR} doesn't exist.")
+ return
+
+ if not GLOBAL_INDUCTOR_CACHES_DIR.exists():
+ GLOBAL_INDUCTOR_CACHES_DIR.mkdir(parents=True)
+
+ # If we updated the cache (the hashes of LOCAL_INDUCTOR_CACHE_DIR and
+ # unpacked_dir_hash differ), we pack the cache and move it to the
+ # global cache directory.
+ new_dir_hash = _dir_hash(LOCAL_INDUCTOR_CACHE_DIR)
+ if new_dir_hash == unpacked_dir_hash:
+ print("No changes in the cache dir, skipping sync.")
+ return
+
+ print("Changes detected in the cache dir, syncing...")
+ os.makedirs(
+ PERSISTENT_TMP_DIR, exist_ok=True
+ ) # Non fal-ai users do not have this directory
+ with tempfile.TemporaryDirectory(dir=PERSISTENT_TMP_DIR) as temp_dir:
+ temp_dir_path = Path(temp_dir)
+ cache_path = GLOBAL_INDUCTOR_CACHES_DIR / gpu_type / f"{cache_key}.zip"
+ cache_path.parent.mkdir(parents=True, exist_ok=True)
+
+ try:
+ zip_name = shutil.make_archive(
+ str(temp_dir_path / "inductor_cache"),
+ "zip",
+ LOCAL_INDUCTOR_CACHE_DIR,
+ )
+ os.rename(
+ zip_name,
+ cache_path,
+ )
+ except Exception as e:
+ print(f"Failed to sync cache: {e}")
+ return
+
+
+ @contextmanager
+ def synchronized_inductor_cache(cache_key: str) -> Iterator[None]:
+ """Context manager to automatically load and sync PyTorch Inductor cache.
+
+ This wraps load_inductor_cache and sync_inductor_cache for convenience.
+ The cache is loaded on entry and synced on exit (even if an exception occurs).
+
+ Args:
+ cache_key: Unique identifier for this cache (e.g., "flux/2", "mymodel/v1")
+
+ Yields:
+ None
+
+ Example:
+ >>> with synchronized_inductor_cache("mymodel/v1"):
+ ... self.model = torch.compile(self.model)
+ ... self.warmup() # Compilation happens here
+ # Cache is automatically synced after the with block
+ """
+ unpacked_dir_hash = load_inductor_cache(cache_key)
+ try:
+ yield
+ finally:
+ sync_inductor_cache(cache_key, unpacked_dir_hash)
fal/toolkit/file/file.py CHANGED
@@ -16,8 +16,7 @@ from fastapi import Request
  if not hasattr(pydantic, "__version__") or pydantic.__version__.startswith("1."):
  IS_PYDANTIC_V2 = False
  else:
- from pydantic import GetCoreSchemaHandler
- from pydantic_core import CoreSchema, core_schema
+ from pydantic import model_validator

  IS_PYDANTIC_V2 = True

@@ -137,14 +136,16 @@ class File(BaseModel):
  # Pydantic custom validator for input type conversion
  if IS_PYDANTIC_V2:

+ @model_validator(mode="before")
  @classmethod
- def __get_pydantic_core_schema__(
- cls, source_type: Any, handler: GetCoreSchemaHandler
- ) -> CoreSchema:
- return core_schema.no_info_before_validator_function(
- cls.__convert_from_str,
- handler(source_type),
- )
+ def __convert_from_str_v2(cls, value: Any):
+ if isinstance(value, str):
+ parsed_url = urlparse(value)
+ if parsed_url.scheme not in ["http", "https", "data"]:
+ raise ValueError("value must be a valid URL")
+ # Return a mapping so the model can be constructed normally
+ return {"url": parsed_url.geturl()}
+ return value

  else:
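Note (editor's sketch, not part of the diff): the File change above replaces pydantic v2's __get_pydantic_core_schema__ hook with a mode="before" model validator. A standalone toy model showing the same coercion pattern:

    from typing import Any

    from pydantic import BaseModel, model_validator

    class FileLike(BaseModel):
        url: str

        @model_validator(mode="before")
        @classmethod
        def _coerce_from_str(cls, value: Any):
            # Accept a bare URL string and turn it into the field mapping.
            if isinstance(value, str):
                return {"url": value}
            return value

    print(FileLike.model_validate("https://example.com/x.png").url)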
 
fal/utils.py CHANGED
@@ -1,11 +1,10 @@
  from __future__ import annotations

  from dataclasses import dataclass
+ from typing import TYPE_CHECKING

- import fal._serialization
- from fal import App, wrap_app
-
- from .api import FalServerlessError, FalServerlessHost, IsolatedFunction
+ if TYPE_CHECKING:
+ from .api import FalServerlessHost, IsolatedFunction

  @dataclass
@@ -17,6 +16,62 @@ class LoadedFunction:
  source_code: str | None

+ def _find_target(
+ module: dict[str, object], function_name: str | None = None
+ ) -> tuple[object, str | None, str | None]:
+ import fal
+ from fal.api import FalServerlessError, IsolatedFunction
+
+ if function_name is not None:
+ if function_name not in module:
+ raise FalServerlessError(f"Function '{function_name}' not found in module")
+
+ target = module[function_name]
+
+ if isinstance(target, type) and issubclass(target, fal.App):
+ return target, target.app_name, target.app_auth
+
+ if isinstance(target, IsolatedFunction):
+ return target, function_name, None
+
+ raise FalServerlessError(
+ f"Function '{function_name}' is not a fal.App or a fal.function"
+ )
+
+ fal_apps = {
+ obj_name: obj
+ for obj_name, obj in module.items()
+ if isinstance(obj, type) and issubclass(obj, fal.App) and obj is not fal.App
+ }
+
+ if len(fal_apps) == 1:
+ [(function_name, target)] = fal_apps.items()
+ return target, target.app_name, target.app_auth
+ elif len(fal_apps) > 1:
+ raise FalServerlessError(
+ f"Multiple fal.Apps found in the module: {list(fal_apps.keys())}. "
+ "Please specify the name of the app."
+ )
+
+ fal_functions = {
+ obj_name: obj
+ for obj_name, obj in module.items()
+ if isinstance(obj, IsolatedFunction)
+ }
+
+ if len(fal_functions) == 0:
+ raise FalServerlessError("No fal.App or fal.function found in the module.")
+ elif len(fal_functions) > 1:
+ raise FalServerlessError(
+ "Multiple fal.functions found in the module: "
+ f"{list(fal_functions.keys())}. "
+ "Please specify the name of the function."
+ )
+
+ [(function_name, target)] = fal_functions.items()
+ return target, function_name, None
+
+
  def load_function_from(
  host: FalServerlessHost,
  file_path: str,
@@ -26,45 +81,24 @@ def load_function_from(
  import runpy
  import sys

+ import fal._serialization
+ from fal import App, wrap_app
+
+ from .api import FalServerlessError, IsolatedFunction
+
  sys.path.append(os.getcwd())
  module = runpy.run_path(file_path)
- if function_name is None:
- fal_objects = {
- obj_name: obj
- for obj_name, obj in module.items()
- if isinstance(obj, type) and issubclass(obj, fal.App) and obj is not fal.App
- }
- if len(fal_objects) == 0:
- raise FalServerlessError("No fal.App found in the module.")
- elif len(fal_objects) > 1:
- raise FalServerlessError(
- "Multiple fal.Apps found in the module. "
- "Please specify the name of the app."
- )
-
- [(function_name, obj)] = fal_objects.items()
- app_name = obj.app_name
- app_auth = obj.app_auth
- else:
- app_name = None
- app_auth = None
-
- if function_name not in module:
- raise FalServerlessError(f"Function '{function_name}' not found in module")
+ target, app_name, app_auth = _find_target(module, function_name)

  # The module for the function is set to <run_path> when runpy is used, in which
  # case we want to manually include the package it is defined in.
  fal._serialization.include_package_from_path(file_path)

- target = module[function_name]
-
  with open(file_path) as f:
  source_code = f.read()

  endpoints = ["/"]
  if isinstance(target, type) and issubclass(target, App):
- app_name = target.app_name
- app_auth = target.app_auth
  endpoints = target.get_endpoints() or ["/"]
  target = wrap_app(target, host=host)
 
fal/workflows.py CHANGED
@@ -19,7 +19,6 @@ from rich.syntax import Syntax
  import fal
  from fal import flags
  from fal.exceptions import FalServerlessException
- from fal.rest_client import REST_CLIENT

  JSONType = Union[Dict[str, Any], List[Any], str, int, float, bool, None, "Leaf"]
  SchemaType = Dict[str, Any]
@@ -372,6 +371,11 @@ class Workflow:
  to_dict = to_json

  def publish(self, title: str, *, is_public: bool = True):
+ from fal.api.client import SyncServerlessClient
+
+ client = SyncServerlessClient()
+ rest_client = client._create_rest_client()
+
  workflow_contents = publish_workflow.TypedWorkflow(
  name=self.name,
  title=title,
@@ -379,7 +383,7 @@ class Workflow:
  is_public=is_public,
  )
  published_workflow = publish_workflow.sync(
- client=REST_CLIENT,
+ client=rest_client,
  json_body=workflow_contents,
  )
  if isinstance(published_workflow, Exception):