fal 1.50.1__py3-none-any.whl → 1.57.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
fal/cli/runners.py CHANGED
@@ -15,7 +15,10 @@ from datetime import datetime, timedelta, timezone
15
15
  from http import HTTPStatus
16
16
  from queue import Empty, Queue
17
17
  from threading import Thread
18
- from typing import Iterator, List
18
+ from typing import TYPE_CHECKING, Iterator, List
19
+
20
+ if TYPE_CHECKING:
21
+ from openapi_fal_rest.client import Client
19
22
 
20
23
  import grpc
21
24
  import httpx
@@ -24,7 +27,6 @@ from rich.console import Console
24
27
  from structlog.typing import EventDict
25
28
 
26
29
  from fal.api.client import SyncServerlessClient
27
- from fal.rest_client import REST_CLIENT
28
30
  from fal.sdk import RunnerInfo, RunnerState
29
31
 
30
32
  from .parser import FalClientParser, SinceAction, get_output_parser
@@ -154,8 +156,11 @@ def _shell(args):
154
156
 
155
157
  def stream_inputs():
156
158
  """Generate input stream for gRPC."""
157
- # Send initial message with runner_id and terminal size
158
- msg = isolate_proto.ShellRunnerInput(runner_id=runner_id)
159
+ # Send initial message with runner_id
160
+ yield isolate_proto.ShellRunnerInput(runner_id=runner_id)
161
+
162
+ # Send terminal size
163
+ msg = isolate_proto.ShellRunnerInput()
159
164
  h, w = _get_tty_size()
160
165
  msg.tty_size.height = h
161
166
  msg.tty_size.width = w
@@ -238,13 +243,24 @@ def _list(args):
238
243
  if args.state:
239
244
  states = set(args.state)
240
245
  if "all" not in states:
241
- runners = [r for r in runners if r.state.value in states]
246
+ runners = [
247
+ r
248
+ for r in runners
249
+ if r.state.value.lower() in states
250
+ or (
251
+ "terminated" in states and r.state.value.lower() == "dead"
252
+ ) # TODO for backwards compatibility. remove later
253
+ ]
242
254
 
243
255
  pending_runners = [
244
256
  runner for runner in runners if runner.state == RunnerState.PENDING
245
257
  ]
246
258
  setup_runners = [runner for runner in runners if runner.state == RunnerState.SETUP]
247
- dead_runners = [runner for runner in runners if runner.state == RunnerState.DEAD]
259
+ terminated_runners = [
260
+ runner
261
+ for runner in runners
262
+ if runner.state == RunnerState.DEAD or runner.state == RunnerState.TERMINATED
263
+ ]
248
264
  if args.output == "pretty":
249
265
  args.console.print(
250
266
  "Runners: "
@@ -252,7 +268,7 @@ def _list(args):
252
268
  len(runners)
253
269
  - len(pending_runners)
254
270
  - len(setup_runners)
255
- - len(dead_runners)
271
+ - len(terminated_runners)
256
272
  )
257
273
  )
258
274
  args.console.print(f"Runners Pending: {len(pending_runners)}")
@@ -312,14 +328,14 @@ def _add_list_parser(subparsers, parents):
312
328
  action=SinceAction,
313
329
  limit="1 day",
314
330
  help=(
315
- "Show dead runners since the given time. "
331
+ "Show terminated runners since the given time. "
316
332
  "Accepts 'now', relative like '30m', '1h', '1d', "
317
333
  "or an ISO timestamp. Max 24 hours."
318
334
  ),
319
335
  )
320
336
  parser.add_argument(
321
337
  "--state",
322
- choices=["all", "running", "pending", "setup", "dead"],
338
+ choices=["all", "running", "pending", "setup", "terminated"],
323
339
  nargs="+",
324
340
  default=None,
325
341
  help=("Filter by runner state(s). Choose one or more, or 'all'(default)."),
@@ -378,10 +394,10 @@ class RestRunnerInfo:
378
394
  ended_at: datetime | None
379
395
 
380
396
 
381
- def _get_runner_info(runner_id: str) -> RestRunnerInfo:
382
- headers = REST_CLIENT.get_headers()
397
+ def _get_runner_info(rest_client: Client, runner_id: str) -> RestRunnerInfo:
398
+ headers = rest_client.get_headers()
383
399
  with httpx.Client(
384
- base_url=REST_CLIENT.base_url, headers=headers, timeout=30
400
+ base_url=rest_client.base_url, headers=headers, timeout=30
385
401
  ) as client:
386
402
  resp = client.get(f"/runners/{runner_id}")
387
403
  if resp.status_code == HTTPStatus.NOT_FOUND:
@@ -415,16 +431,19 @@ def _get_runner_info(runner_id: str) -> RestRunnerInfo:
415
431
 
416
432
 
417
433
  def _stream_logs(
418
- base_params: dict[str, str], since: datetime | None, until: datetime | None
434
+ rest_client: Client,
435
+ base_params: dict[str, str],
436
+ since: datetime | None,
437
+ until: datetime | None,
419
438
  ) -> Iterator[dict]:
420
- headers = REST_CLIENT.get_headers()
439
+ headers = rest_client.get_headers()
421
440
  params: dict[str, str] = base_params.copy()
422
441
  if since is not None:
423
442
  params["since"] = _to_iso_naive(since)
424
443
  if until is not None:
425
444
  params["until"] = _to_iso_naive(until)
426
445
  with httpx.Client(
427
- base_url=REST_CLIENT.base_url,
446
+ base_url=rest_client.base_url,
428
447
  headers=headers,
429
448
  timeout=None,
430
449
  follow_redirects=True,
@@ -451,11 +470,14 @@ DEFAULT_PAGE_SIZE = 1000
451
470
 
452
471
 
453
472
  def _iter_logs(
454
- base_params: dict[str, str], start: datetime | None, end: datetime | None
473
+ rest_client: Client,
474
+ base_params: dict[str, str],
475
+ start: datetime | None,
476
+ end: datetime | None,
455
477
  ) -> Iterator[dict]:
456
- headers = REST_CLIENT.get_headers()
478
+ headers = rest_client.get_headers()
457
479
  with httpx.Client(
458
- base_url=REST_CLIENT.base_url,
480
+ base_url=rest_client.base_url,
459
481
  headers=headers,
460
482
  timeout=300,
461
483
  follow_redirects=True,
@@ -478,6 +500,7 @@ def _iter_logs(
478
500
 
479
501
 
480
502
  def _get_logs(
503
+ rest_client: Client,
481
504
  params: dict[str, str],
482
505
  since: datetime | None,
483
506
  until: datetime | None,
@@ -486,12 +509,12 @@ def _get_logs(
486
509
  oldest: bool = False,
487
510
  ) -> Iterator[dict]:
488
511
  if lines_count is None:
489
- yield from _iter_logs(params, since, until)
512
+ yield from _iter_logs(rest_client, params, since, until)
490
513
  return
491
514
 
492
515
  if oldest:
493
516
  produced = 0
494
- for log in _iter_logs(params, since, until):
517
+ for log in _iter_logs(rest_client, params, since, until):
495
518
  if produced >= lines_count:
496
519
  break
497
520
  produced += 1
@@ -500,7 +523,7 @@ def _get_logs(
500
523
 
501
524
  # newest tail: collect into a fixed-size deque, then yield
502
525
  tail: deque[dict] = deque(maxlen=lines_count)
503
- for log in _iter_logs(params, since, until):
526
+ for log in _iter_logs(rest_client, params, since, until):
504
527
  tail.append(log)
505
528
  for log in tail:
506
529
  yield log
@@ -543,7 +566,9 @@ def _logs(args):
543
566
  if args.search is not None:
544
567
  params["search"] = args.search
545
568
 
546
- runner_info = _get_runner_info(args.id)
569
+ client = SyncServerlessClient(host=args.host, team=args.team)
570
+ rest_client = client._create_rest_client()
571
+ runner_info = _get_runner_info(rest_client, args.id)
547
572
  follow: bool = args.follow
548
573
  since = args.since
549
574
  if follow:
@@ -589,9 +614,11 @@ def _logs(args):
589
614
  args.parser.error("Invalid -n|--lines value. Use an integer or +integer.")
590
615
 
591
616
  if follow:
592
- logs_gen = _stream_logs(params, since, until)
617
+ logs_gen = _stream_logs(rest_client, params, since, until)
593
618
  else:
594
- logs_gen = _get_logs(params, since, until, lines_count, oldest=lines_oldest)
619
+ logs_gen = _get_logs(
620
+ rest_client, params, since, until, lines_count, oldest=lines_oldest
621
+ )
595
622
 
596
623
  printer = LogPrinter(args.console)
597
624
 
fal/cli/secrets.py CHANGED
@@ -1,12 +1,12 @@
1
- from ._utils import get_client
1
+ from fal.api.client import SyncServerlessClient
2
+
2
3
  from .parser import DictAction, FalClientParser
3
4
 
4
5
 
5
6
  def _set(args):
6
- client = get_client(args.host, args.team)
7
- with client.connect() as connection:
8
- for name, value in args.secrets.items():
9
- connection.set_secret(name, value)
7
+ client = SyncServerlessClient(host=args.host, team=args.team)
8
+ for name, value in args.secrets.items():
9
+ client.secrets.set(name, value)
10
10
 
11
11
 
12
12
  def _add_set_parser(subparsers, parents):
@@ -33,32 +33,31 @@ def _add_set_parser(subparsers, parents):
33
33
  def _list(args):
34
34
  import json
35
35
 
36
- client = get_client(args.host, args.team)
37
- with client.connect() as connection:
38
- secrets = list(connection.list_secrets())
36
+ client = SyncServerlessClient(host=args.host, team=args.team)
37
+ secrets = client.secrets.list()
39
38
 
40
- if args.output == "json":
41
- json_secrets = [
42
- {
43
- "name": secret.name,
44
- "created_at": str(secret.created_at),
45
- }
46
- for secret in secrets
47
- ]
48
- args.console.print(json.dumps({"secrets": json_secrets}))
49
- elif args.output == "pretty":
50
- from rich.table import Table
39
+ if args.output == "json":
40
+ json_secrets = [
41
+ {
42
+ "name": secret.name,
43
+ "created_at": str(secret.created_at),
44
+ }
45
+ for secret in secrets
46
+ ]
47
+ args.console.print(json.dumps({"secrets": json_secrets}))
48
+ elif args.output == "pretty":
49
+ from rich.table import Table
51
50
 
52
- table = Table()
53
- table.add_column("Secret Name")
54
- table.add_column("Created At")
51
+ table = Table()
52
+ table.add_column("Secret Name")
53
+ table.add_column("Created At")
55
54
 
56
- for secret in secrets:
57
- table.add_row(secret.name, str(secret.created_at))
55
+ for secret in secrets:
56
+ table.add_row(secret.name, str(secret.created_at))
58
57
 
59
- args.console.print(table)
60
- else:
61
- raise AssertionError(f"Invalid output format: {args.output}")
58
+ args.console.print(table)
59
+ else:
60
+ raise AssertionError(f"Invalid output format: {args.output}")
62
61
 
63
62
 
64
63
  def _add_list_parser(subparsers, parents):
@@ -75,9 +74,8 @@ def _add_list_parser(subparsers, parents):
75
74
 
76
75
 
77
76
  def _unset(args):
78
- client = get_client(args.host, args.team)
79
- with client.connect() as connection:
80
- connection.delete_secret(args.secret)
77
+ client = SyncServerlessClient(host=args.host, team=args.team)
78
+ client.secrets.unset(args.secret)
81
79
 
82
80
 
83
81
  def _add_unset_parser(subparsers, parents):
fal/files.py CHANGED
@@ -4,7 +4,7 @@ import os
4
4
  import posixpath
5
5
  from concurrent.futures import ThreadPoolExecutor
6
6
  from functools import cached_property
7
- from typing import TYPE_CHECKING
7
+ from typing import TYPE_CHECKING, Optional
8
8
 
9
9
  from fsspec import AbstractFileSystem
10
10
 
@@ -14,7 +14,7 @@ if TYPE_CHECKING:
14
14
  USER_AGENT = "fal-sdk/1.14.0 (python)"
15
15
  MULTIPART_THRESHOLD = 10 * 1024 * 1024 # 10MB
16
16
  MULTIPART_CHUNK_SIZE = 10 * 1024 * 1024 # 10MB
17
- MULTIPART_WORKERS = 2 # only 2 because our REST is currently struggling with more
17
+ MULTIPART_WORKERS = 10
18
18
 
19
19
 
20
20
  def _compute_md5(lpath, chunk_size=8192):
@@ -28,20 +28,43 @@ def _compute_md5(lpath, chunk_size=8192):
28
28
 
29
29
 
30
30
  class FalFileSystem(AbstractFileSystem):
31
+ def __init__(
32
+ self,
33
+ *,
34
+ host: Optional[str] = None,
35
+ team: Optional[str] = None,
36
+ profile: Optional[str] = None,
37
+ **kwargs,
38
+ ):
39
+ self.host = host
40
+ self.team = team
41
+ self.profile = profile
42
+ super().__init__(**kwargs)
43
+
31
44
  @cached_property
32
45
  def _client(self) -> "httpx.Client":
33
- from httpx import Client
46
+ from httpx import Client, Timeout
47
+
48
+ from fal.api.client import SyncServerlessClient
34
49
 
35
- from fal.flags import REST_URL
36
- from fal.sdk import get_default_credentials
50
+ client = SyncServerlessClient(
51
+ host=self.host,
52
+ team=self.team,
53
+ profile=self.profile,
54
+ )
37
55
 
38
- creds = get_default_credentials()
39
56
  return Client(
40
- base_url=REST_URL,
57
+ base_url=client._rest_url,
41
58
  headers={
42
- **creds.to_headers(),
59
+ **client._credentials.to_headers(),
43
60
  "User-Agent": USER_AGENT,
44
61
  },
62
+ timeout=Timeout(
63
+ connect=30,
64
+ read=4 * 60, # multipart complete can take time
65
+ write=5 * 60, # we could be uploading slowly
66
+ pool=30,
67
+ ),
45
68
  )
46
69
 
47
70
  def _request(self, method, path, **kwargs):
@@ -224,6 +247,7 @@ class FalFileSystem(AbstractFileSystem):
224
247
  "POST",
225
248
  f"/files/file/url/{abs_rpath}",
226
249
  json={"url": url},
250
+ timeout=10 * 60, # 10 minutes in seconds
227
251
  )
228
252
  self.dircache.clear()
229
253
 
fal/sdk.py CHANGED
@@ -246,6 +246,7 @@ class ApplicationInfo:
246
246
  min_concurrency: int
247
247
  concurrency_buffer: int
248
248
  concurrency_buffer_perc: int
249
+ scaling_delay: int
249
250
  machine_types: list[str]
250
251
  request_timeout: int
251
252
  startup_timeout: int
@@ -265,6 +266,7 @@ class AliasInfo:
265
266
  min_concurrency: int
266
267
  concurrency_buffer: int
267
268
  concurrency_buffer_perc: int
269
+ scaling_delay: int
268
270
  machine_types: list[str]
269
271
  request_timeout: int
270
272
  startup_timeout: int
@@ -272,27 +274,14 @@ class AliasInfo:
272
274
 
273
275
 
274
276
  class RunnerState(Enum):
275
- RUNNING = "running"
276
- PENDING = "pending"
277
- SETUP = "setup"
278
- DOCKER_PULL = "docker_pull"
279
- DEAD = "dead"
280
- UNKNOWN = "unknown"
281
-
282
- @staticmethod
283
- def from_proto(proto: isolate_proto.RunnerInfo.State) -> RunnerState:
284
- if proto is isolate_proto.RunnerInfo.State.RUNNING:
285
- return RunnerState.RUNNING
286
- elif proto is isolate_proto.RunnerInfo.State.PENDING:
287
- return RunnerState.PENDING
288
- elif proto is isolate_proto.RunnerInfo.State.SETUP:
289
- return RunnerState.SETUP
290
- elif proto is isolate_proto.RunnerInfo.State.DEAD:
291
- return RunnerState.DEAD
292
- elif proto is isolate_proto.RunnerInfo.State.DOCKER_PULL:
293
- return RunnerState.DOCKER_PULL
294
- else:
295
- return RunnerState.UNKNOWN
277
+ RUNNING = "RUNNING"
278
+ PENDING = "PENDING"
279
+ SETUP = "SETUP"
280
+ DOCKER_PULL = "DOCKER_PULL"
281
+ DEAD = "DEAD"
282
+ DRAINING = "DRAINING"
283
+ TERMINATING = "TERMINATING"
284
+ TERMINATED = "TERMINATED"
296
285
 
297
286
 
298
287
  @dataclass
@@ -414,6 +403,7 @@ def _from_grpc_application_info(
414
403
  min_concurrency=message.min_concurrency,
415
404
  concurrency_buffer=message.concurrency_buffer,
416
405
  concurrency_buffer_perc=message.concurrency_buffer_perc,
406
+ scaling_delay=message.scaling_delay_seconds,
417
407
  machine_types=list(message.machine_types),
418
408
  request_timeout=message.request_timeout,
419
409
  startup_timeout=message.startup_timeout,
@@ -444,6 +434,7 @@ def _from_grpc_alias_info(message: isolate_proto.AliasInfo) -> AliasInfo:
444
434
  min_concurrency=message.min_concurrency,
445
435
  concurrency_buffer=message.concurrency_buffer,
446
436
  concurrency_buffer_perc=message.concurrency_buffer_perc,
437
+ scaling_delay=message.scaling_delay_seconds,
447
438
  machine_types=list(message.machine_types),
448
439
  request_timeout=message.request_timeout,
449
440
  startup_timeout=message.startup_timeout,
@@ -468,7 +459,7 @@ def _from_grpc_runner_info(message: isolate_proto.RunnerInfo) -> RunnerInfo:
468
459
  external_metadata=external_metadata,
469
460
  revision=message.revision,
470
461
  alias=message.alias,
471
- state=RunnerState.from_proto(message.state),
462
+ state=RunnerState(isolate_proto.RunnerInfo.State.Name(message.state)),
472
463
  )
473
464
 
474
465
 
@@ -537,8 +528,10 @@ class MachineRequirements:
537
528
  min_concurrency: int | None = None
538
529
  concurrency_buffer: int | None = None
539
530
  concurrency_buffer_perc: int | None = None
531
+ scaling_delay: int | None = None
540
532
  request_timeout: int | None = None
541
533
  startup_timeout: int | None = None
534
+ valid_regions: list[str] | None = None
542
535
 
543
536
  def __post_init__(self):
544
537
  if isinstance(self.machine_types, str):
@@ -633,6 +626,7 @@ class FalServerlessConnection:
633
626
  auth_mode: Optional[AuthModeLiteral] = None,
634
627
  *,
635
628
  source_code: str | None = None,
629
+ health_check_path: str | None = None,
636
630
  serialization_method: str = _DEFAULT_SERIALIZATION_METHOD,
637
631
  machine_requirements: MachineRequirements | None = None,
638
632
  metadata: dict[str, Any] | None = None,
@@ -640,7 +634,7 @@ class FalServerlessConnection:
640
634
  scale: bool = True,
641
635
  private_logs: bool = False,
642
636
  files: list[File] | None = None,
643
- ) -> Iterator[isolate_proto.RegisterApplicationResult]:
637
+ ) -> Iterator[RegisterApplicationResult]:
644
638
  wrapped_function = to_serialized_object(function, serialization_method)
645
639
  if machine_requirements:
646
640
  wrapped_requirements = isolate_proto.MachineRequirements(
@@ -659,9 +653,11 @@ class FalServerlessConnection:
659
653
  min_concurrency=machine_requirements.min_concurrency,
660
654
  concurrency_buffer=machine_requirements.concurrency_buffer,
661
655
  concurrency_buffer_perc=machine_requirements.concurrency_buffer_perc,
656
+ scaling_delay_seconds=machine_requirements.scaling_delay,
662
657
  max_multiplexing=machine_requirements.max_multiplexing,
663
658
  request_timeout=machine_requirements.request_timeout,
664
659
  startup_timeout=machine_requirements.startup_timeout,
660
+ valid_regions=machine_requirements.valid_regions,
665
661
  )
666
662
  else:
667
663
  wrapped_requirements = None
@@ -702,6 +698,7 @@ class FalServerlessConnection:
702
698
  private_logs=private_logs,
703
699
  files=files,
704
700
  source_code=source_code,
701
+ health_check_path=health_check_path,
705
702
  )
706
703
  for partial_result in self.stub.RegisterApplication(request):
707
704
  yield from_grpc(partial_result)
@@ -718,6 +715,7 @@ class FalServerlessConnection:
718
715
  min_concurrency: int | None = None,
719
716
  concurrency_buffer: int | None = None,
720
717
  concurrency_buffer_perc: int | None = None,
718
+ scaling_delay: int | None = None,
721
719
  request_timeout: int | None = None,
722
720
  startup_timeout: int | None = None,
723
721
  valid_regions: list[str] | None = None,
@@ -731,6 +729,7 @@ class FalServerlessConnection:
731
729
  min_concurrency=min_concurrency,
732
730
  concurrency_buffer=concurrency_buffer,
733
731
  concurrency_buffer_perc=concurrency_buffer_perc,
732
+ scaling_delay_seconds=scaling_delay,
734
733
  request_timeout=request_timeout,
735
734
  startup_timeout=startup_timeout,
736
735
  valid_regions=valid_regions,
@@ -757,6 +756,17 @@ class FalServerlessConnection:
757
756
  request = isolate_proto.DeleteApplicationRequest(application_id=application_id)
758
757
  self.stub.DeleteApplication(request)
759
758
 
759
+ def rollout_application(
760
+ self,
761
+ application_name: str,
762
+ force: bool = False,
763
+ ) -> None:
764
+ request = isolate_proto.RolloutApplicationRequest(
765
+ application_name=application_name,
766
+ force=force,
767
+ )
768
+ self.stub.RolloutApplication(request)
769
+
760
770
  def run(
761
771
  self,
762
772
  function: Callable[..., ResultT],
@@ -786,8 +796,10 @@ class FalServerlessConnection:
786
796
  min_concurrency=machine_requirements.min_concurrency,
787
797
  concurrency_buffer=machine_requirements.concurrency_buffer,
788
798
  concurrency_buffer_perc=machine_requirements.concurrency_buffer_perc,
799
+ scaling_delay_seconds=machine_requirements.scaling_delay,
789
800
  request_timeout=machine_requirements.request_timeout,
790
801
  startup_timeout=machine_requirements.startup_timeout,
802
+ valid_regions=machine_requirements.valid_regions,
791
803
  )
792
804
  else:
793
805
  wrapped_requirements = None
fal/sync.py CHANGED
@@ -4,21 +4,21 @@ import hashlib
4
4
  import os
5
5
  import zipfile
6
6
  from pathlib import Path
7
+ from typing import TYPE_CHECKING
8
+
9
+ if TYPE_CHECKING:
10
+ from openapi_fal_rest.client import Client
7
11
 
8
- import openapi_fal_rest.api.files.check_dir_hash as check_dir_hash_api
9
- import openapi_fal_rest.api.files.upload_local_file as upload_local_file_api
10
- import openapi_fal_rest.models.body_upload_local_file as upload_file_model
11
- import openapi_fal_rest.models.hash_check as hash_check_model
12
- import openapi_fal_rest.types as rest_types
13
12
  from pathspec import PathSpec
14
13
 
15
- from fal.rest_client import REST_CLIENT
16
14
 
15
+ def _check_hash(client: Client, target_path: str, hash_string: str) -> bool:
16
+ import openapi_fal_rest.api.files.check_dir_hash as check_dir_hash_api
17
+ import openapi_fal_rest.models.hash_check as hash_check_model
17
18
 
18
- def _check_hash(target_path: str, hash_string: str) -> bool:
19
19
  response = check_dir_hash_api.sync_detailed(
20
20
  target_path,
21
- client=REST_CLIENT,
21
+ client=client,
22
22
  json_body=hash_check_model.HashCheck(hash_string),
23
23
  )
24
24
 
@@ -26,7 +26,13 @@ def _check_hash(target_path: str, hash_string: str) -> bool:
26
26
  return response.status_code == 200 and res
27
27
 
28
28
 
29
- def _upload_file(source_path: str, target_path: str, unzip: bool = False):
29
+ def _upload_file(
30
+ client: Client, source_path: str, target_path: str, unzip: bool = False
31
+ ):
32
+ import openapi_fal_rest.api.files.upload_local_file as upload_local_file_api
33
+ import openapi_fal_rest.models.body_upload_local_file as upload_file_model
34
+ import openapi_fal_rest.types as rest_types
35
+
30
36
  with open(source_path, "rb") as file_to_upload:
31
37
  body = upload_file_model.BodyUploadLocalFile(
32
38
  rest_types.File(
@@ -39,7 +45,7 @@ def _upload_file(source_path: str, target_path: str, unzip: bool = False):
39
45
 
40
46
  response = upload_local_file_api.sync_detailed(
41
47
  target_path,
42
- client=REST_CLIENT,
48
+ client=client,
43
49
  unzip=unzip,
44
50
  multipart_data=body,
45
51
  )
@@ -94,6 +100,8 @@ def _zip_directory(dir_path: str, zip_path: str) -> None:
94
100
 
95
101
 
96
102
  def sync_dir(local_dir: str | Path, remote_dir: str, force_upload=False) -> str:
103
+ from fal.api.client import SyncServerlessClient
104
+
97
105
  local_dir_abs = os.path.expanduser(local_dir)
98
106
  if not os.path.isabs(remote_dir) or not remote_dir.startswith("/data"):
99
107
  raise ValueError(
@@ -106,9 +114,11 @@ def sync_dir(local_dir: str | Path, remote_dir: str, force_upload=False) -> str:
106
114
  # Compute the local directory hash
107
115
  local_hash = _compute_directory_hash(local_dir_abs)
108
116
 
117
+ client = SyncServerlessClient()._create_rest_client()
118
+
109
119
  print(f"Syncing {local_dir} with {remote_dir}...")
110
120
 
111
- if _check_hash(remote_dir, local_hash) and not force_upload:
121
+ if _check_hash(client, remote_dir, local_hash) and not force_upload:
112
122
  print(f"{remote_dir} already uploaded and matches {local_dir}")
113
123
  return remote_dir
114
124
 
@@ -121,7 +131,7 @@ def sync_dir(local_dir: str | Path, remote_dir: str, force_upload=False) -> str:
121
131
  _zip_directory(local_dir_abs, zip_path)
122
132
 
123
133
  # Upload the zipped directory to the serverless environment
124
- _upload_file(zip_path, remote_dir, unzip=True)
134
+ _upload_file(client, zip_path, remote_dir, unzip=True)
125
135
 
126
136
  os.remove(zip_path)
127
137
 
fal/toolkit/__init__.py CHANGED
@@ -1,6 +1,12 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from fal.toolkit.audio.audio import Audio, AudioField
4
+ from fal.toolkit.compilation import (
5
+ get_gpu_type,
6
+ load_inductor_cache,
7
+ sync_inductor_cache,
8
+ synchronized_inductor_cache,
9
+ )
4
10
  from fal.toolkit.file import CompressedFile, File, FileField
5
11
  from fal.toolkit.image.image import Image, ImageField, ImageSizeInput, get_image_size
6
12
  from fal.toolkit.optimize import optimize
@@ -33,4 +39,8 @@ __all__ = [
33
39
  "clone_repository",
34
40
  "download_file",
35
41
  "download_model_weights",
42
+ "get_gpu_type",
43
+ "load_inductor_cache",
44
+ "sync_inductor_cache",
45
+ "synchronized_inductor_cache",
36
46
  ]