fal 1.3.3__py3-none-any.whl → 1.7.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of fal might be problematic. Click here for more details.

fal/apps.py CHANGED
@@ -4,15 +4,19 @@ import json
4
4
  import time
5
5
  from contextlib import contextmanager
6
6
  from dataclasses import dataclass, field
7
- from typing import Any, Iterator
7
+ from typing import TYPE_CHECKING, Any, Iterator
8
8
 
9
9
  import httpx
10
10
 
11
11
  from fal import flags
12
12
  from fal.sdk import Credentials, get_default_credentials
13
13
 
14
+ if TYPE_CHECKING:
15
+ from websockets.sync.connection import Connection
16
+
14
17
  _QUEUE_URL_FORMAT = f"https://queue.{flags.FAL_RUN_HOST}/{{app_id}}"
15
18
  _REALTIME_URL_FORMAT = f"wss://{flags.FAL_RUN_HOST}/{{app_id}}"
19
+ _WS_URL_FORMAT = f"wss://ws.{flags.FAL_RUN_HOST}/{{app_id}}"
16
20
 
17
21
 
18
22
  def _backwards_compatible_app_id(app_id: str) -> str:
@@ -173,7 +177,8 @@ def submit(app_id: str, arguments: dict[str, Any], *, path: str = "") -> Request
173
177
  app_id = _backwards_compatible_app_id(app_id)
174
178
  url = _QUEUE_URL_FORMAT.format(app_id=app_id)
175
179
  if path:
176
- url += "/" + path.removeprefix("/")
180
+ _path = path[len("/") :] if path.startswith("/") else path
181
+ url += "/" + _path
177
182
 
178
183
  creds = get_default_credentials()
179
184
 
@@ -235,7 +240,8 @@ def _connect(app_id: str, *, path: str = "/realtime") -> Iterator[_RealtimeConne
235
240
  app_id = _backwards_compatible_app_id(app_id)
236
241
  url = _REALTIME_URL_FORMAT.format(app_id=app_id)
237
242
  if path:
238
- url += "/" + path.removeprefix("/")
243
+ _path = path[len("/") :] if path.startswith("/") else path
244
+ url += "/" + _path
239
245
 
240
246
  creds = get_default_credentials()
241
247
 
@@ -243,3 +249,132 @@ def _connect(app_id: str, *, path: str = "/realtime") -> Iterator[_RealtimeConne
243
249
  url, additional_headers=creds.to_headers(), open_timeout=90
244
250
  ) as ws:
245
251
  yield _RealtimeConnection(ws)
252
+
253
+
254
+ class _MetaMessageFound(Exception): ...
255
+
256
+
257
+ @dataclass
258
+ class _WSConnection:
259
+ """A WS connection to an HTTP Fal app."""
260
+
261
+ _ws: Connection
262
+ _buffer: str | bytes | None = None
263
+
264
+ def run(self, arguments: dict[str, Any]) -> bytes:
265
+ """Run an inference task on the app and return the result."""
266
+ self.send(arguments)
267
+ return self.recv()
268
+
269
+ def send(self, arguments: dict[str, Any]) -> None:
270
+ import json
271
+
272
+ payload = json.dumps(arguments)
273
+ self._ws.send(payload)
274
+
275
+ def _peek(self) -> bytes | str:
276
+ if self._buffer is None:
277
+ self._buffer = self._ws.recv()
278
+
279
+ return self._buffer
280
+
281
+ def _consume(self) -> None:
282
+ if self._buffer is None:
283
+ raise ValueError("No data to consume")
284
+
285
+ self._buffer = None
286
+
287
+ @contextmanager
288
+ def _recv(self) -> Iterator[str | bytes]:
289
+ res = self._peek()
290
+
291
+ yield res
292
+
293
+ # Only consume if it went through the context manager without raising
294
+ self._consume()
295
+
296
+ def _is_meta(self, res: str | bytes) -> bool:
297
+ if not isinstance(res, str):
298
+ return False
299
+
300
+ try:
301
+ json_payload: Any = json.loads(res)
302
+ except json.JSONDecodeError:
303
+ return False
304
+
305
+ if not isinstance(json_payload, dict):
306
+ return False
307
+
308
+ return "type" in json_payload and "request_id" in json_payload
309
+
310
+ def _recv_meta(self, type: str) -> dict[str, Any]:
311
+ with self._recv() as res:
312
+ if not self._is_meta(res):
313
+ raise ValueError(f"Expected a {type} message")
314
+
315
+ json_payload: dict = json.loads(res)
316
+ if json_payload.get("type") != type:
317
+ raise ValueError(f"Expected a {type} message")
318
+
319
+ return json_payload
320
+
321
+ def _recv_response(self) -> Iterator[str | bytes]:
322
+ while True:
323
+ try:
324
+ with self._recv() as res:
325
+ if self._is_meta(res):
326
+ # Raise so we dont consume the message
327
+ raise _MetaMessageFound()
328
+
329
+ yield res
330
+ except _MetaMessageFound:
331
+ break
332
+
333
+ def recv(self) -> bytes:
334
+ start = self._recv_meta("start")
335
+ request_id = start["request_id"]
336
+
337
+ response = b""
338
+ for part in self._recv_response():
339
+ if isinstance(part, str):
340
+ response += part.encode()
341
+ else:
342
+ response += part
343
+
344
+ end = self._recv_meta("end")
345
+ if end["request_id"] != request_id:
346
+ raise ValueError("Mismatched request_id in end message")
347
+
348
+ return response
349
+
350
+ def stream(self) -> Iterator[str | bytes]:
351
+ start = self._recv_meta("start")
352
+ request_id = start["request_id"]
353
+
354
+ yield from self._recv_response()
355
+
356
+ # Make sure we consume the end message
357
+ end = self._recv_meta("end")
358
+ if end["request_id"] != request_id:
359
+ raise ValueError("Mismatched request_id in end message")
360
+
361
+
362
+ @contextmanager
363
+ def ws(app_id: str, *, path: str = "") -> Iterator[_WSConnection]:
364
+ """Connect to a HTTP endpoint but with websocket protocol. This is an internal and
365
+ experimental API, use it at your own risk."""
366
+
367
+ from websockets.sync import client
368
+
369
+ app_id = _backwards_compatible_app_id(app_id)
370
+ url = _WS_URL_FORMAT.format(app_id=app_id)
371
+ if path:
372
+ _path = path[len("/") :] if path.startswith("/") else path
373
+ url += "/" + _path
374
+
375
+ creds = get_default_credentials()
376
+
377
+ with client.connect(
378
+ url, additional_headers=creds.to_headers(), open_timeout=90
379
+ ) as ws:
380
+ yield _WSConnection(ws)
fal/auth/__init__.py CHANGED
@@ -2,22 +2,70 @@ from __future__ import annotations
2
2
 
3
3
  import os
4
4
  from dataclasses import dataclass, field
5
+ from threading import Lock
6
+ from typing import Optional
5
7
 
6
8
  import click
7
9
 
8
10
  from fal.auth import auth0, local
11
+ from fal.config import Config
9
12
  from fal.console import console
10
13
  from fal.console.icons import CHECK_ICON
11
14
  from fal.exceptions.auth import UnauthenticatedException
12
15
 
13
16
 
17
+ class GoogleColabState:
18
+ def __init__(self):
19
+ self.is_checked = False
20
+ self.lock = Lock()
21
+ self.secret: Optional[str] = None
22
+
23
+
24
+ _colab_state = GoogleColabState()
25
+
26
+
27
+ def is_google_colab() -> bool:
28
+ try:
29
+ from IPython import get_ipython
30
+
31
+ return "google.colab" in str(get_ipython())
32
+ except ModuleNotFoundError:
33
+ return False
34
+ except NameError:
35
+ return False
36
+
37
+
38
+ def get_colab_token() -> Optional[str]:
39
+ if not is_google_colab():
40
+ return None
41
+ with _colab_state.lock:
42
+ if _colab_state.is_checked: # request access only once
43
+ return _colab_state.secret
44
+
45
+ try:
46
+ from google.colab import userdata # noqa: I001
47
+ except ImportError:
48
+ return None
49
+
50
+ try:
51
+ token = userdata.get("FAL_KEY")
52
+ _colab_state.secret = token.strip()
53
+ except Exception:
54
+ _colab_state.secret = None
55
+
56
+ _colab_state.is_checked = True
57
+ return _colab_state.secret
58
+
59
+
14
60
  def key_credentials() -> tuple[str, str] | None:
15
61
  # Ignore key credentials when the user forces auth by user.
16
62
  if os.environ.get("FAL_FORCE_AUTH_BY_USER") == "1":
17
63
  return None
18
64
 
19
- if "FAL_KEY" in os.environ:
20
- key = os.environ["FAL_KEY"]
65
+ config = Config()
66
+
67
+ key = os.environ.get("FAL_KEY") or config.get("key") or get_colab_token()
68
+ if key:
21
69
  key_id, key_secret = key.split(":", 1)
22
70
  return (key_id, key_secret)
23
71
  elif "FAL_KEY_ID" in os.environ and "FAL_KEY_SECRET" in os.environ:
fal/cli/_utils.py CHANGED
@@ -1,6 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
- from fal.files import find_pyproject_toml, parse_pyproject_toml
3
+ from fal.files import find_project_root, find_pyproject_toml, parse_pyproject_toml
4
4
 
5
5
 
6
6
  def is_app_name(app_ref: tuple[str, str | None]) -> bool:
@@ -29,6 +29,12 @@ def get_app_data_from_toml(app_name):
29
29
  except KeyError:
30
30
  raise ValueError(f"App {app_name} does not have a ref key in pyproject.toml")
31
31
 
32
+ # Convert the app_ref to a path relative to the project root
33
+ project_root, _ = find_project_root(None)
34
+ app_ref = str(project_root / app_ref)
35
+
32
36
  app_auth = app_data.get("auth", "private")
37
+ app_deployment_strategy = app_data.get("deployment_strategy", "recreate")
38
+ app_no_scale = app_data.get("no_scale", False)
33
39
 
34
- return app_ref, app_auth
40
+ return app_ref, app_auth, app_deployment_strategy, app_no_scale
fal/cli/apps.py CHANGED
@@ -221,7 +221,7 @@ def _runners(args):
221
221
  str(runner.in_flight_requests),
222
222
  (
223
223
  "N/A (active)"
224
- if not runner.expiration_countdown
224
+ if runner.expiration_countdown is None
225
225
  else f"{runner.expiration_countdown}s"
226
226
  ),
227
227
  f"{runner.uptime} ({runner.uptime.total_seconds()}s)",
fal/cli/deploy.py CHANGED
@@ -1,7 +1,7 @@
1
1
  import argparse
2
2
  from collections import namedtuple
3
3
  from pathlib import Path
4
- from typing import Optional, Union
4
+ from typing import Literal, Optional, Tuple, Union
5
5
 
6
6
  from ._utils import get_app_data_from_toml, is_app_name
7
7
  from .parser import FalClientParser, RefAction
@@ -63,7 +63,12 @@ def _get_user() -> User:
63
63
 
64
64
 
65
65
  def _deploy_from_reference(
66
- app_ref: tuple[Optional[Union[Path, str]], ...], app_name: str, auth: str, args
66
+ app_ref: Tuple[Optional[Union[Path, str]], ...],
67
+ app_name: str,
68
+ args,
69
+ auth: Optional[Literal["public", "shared", "private"]] = None,
70
+ deployment_strategy: Optional[Literal["recreate", "rolling"]] = None,
71
+ no_scale: bool = False,
67
72
  ):
68
73
  from fal.api import FalServerlessError, FalServerlessHost
69
74
  from fal.utils import load_function_from
@@ -93,7 +98,7 @@ def _deploy_from_reference(
93
98
  isolated_function = loaded.function
94
99
  app_name = app_name or loaded.app_name # type: ignore
95
100
  app_auth = auth or loaded.app_auth or "private"
96
- deployment_strategy = args.strategy or "default"
101
+ deployment_strategy = deployment_strategy or "recreate"
97
102
 
98
103
  app_id = host.register(
99
104
  func=isolated_function.func,
@@ -102,6 +107,7 @@ def _deploy_from_reference(
102
107
  application_auth_mode=app_auth,
103
108
  metadata=isolated_function.options.host.get("metadata", {}),
104
109
  deployment_strategy=deployment_strategy,
110
+ scale=not no_scale,
105
111
  )
106
112
 
107
113
  if app_id:
@@ -134,7 +140,9 @@ def _deploy(args):
134
140
  raise ValueError("Cannot use --app-name or --auth with app name reference.")
135
141
 
136
142
  app_name = args.app_ref[0]
137
- app_ref, app_auth = get_app_data_from_toml(app_name)
143
+ app_ref, app_auth, app_deployment_strategy, app_no_scale = (
144
+ get_app_data_from_toml(app_name)
145
+ )
138
146
  file_path, func_name = RefAction.split_ref(app_ref)
139
147
 
140
148
  # path/to/myfile.py::MyApp
@@ -142,8 +150,17 @@ def _deploy(args):
142
150
  file_path, func_name = args.app_ref
143
151
  app_name = args.app_name
144
152
  app_auth = args.auth
145
-
146
- _deploy_from_reference((file_path, func_name), app_name, app_auth, args)
153
+ app_deployment_strategy = args.strategy
154
+ app_no_scale = args.no_scale
155
+
156
+ _deploy_from_reference(
157
+ (file_path, func_name),
158
+ app_name,
159
+ args,
160
+ app_auth,
161
+ app_deployment_strategy,
162
+ app_no_scale,
163
+ )
147
164
 
148
165
 
149
166
  def add_parser(main_subparsers, parents):
@@ -204,9 +221,18 @@ def add_parser(main_subparsers, parents):
204
221
  )
205
222
  parser.add_argument(
206
223
  "--strategy",
207
- choices=["default", "rolling"],
224
+ choices=["recreate", "rolling"],
208
225
  help="Deployment strategy.",
209
- default="default",
226
+ default="recreate",
227
+ )
228
+ parser.add_argument(
229
+ "--no-scale",
230
+ action="store_true",
231
+ help=(
232
+ "Use min_concurrency/max_concurrency/max_multiplexing from previous "
233
+ "deployment of application with this name, if exists. Otherwise will "
234
+ "use the values from the application code."
235
+ ),
210
236
  )
211
237
 
212
238
  parser.set_defaults(func=_deploy)
fal/cli/main.py CHANGED
@@ -6,7 +6,7 @@ from fal import __version__
6
6
  from fal.console import console
7
7
  from fal.console.icons import CROSS_ICON
8
8
 
9
- from . import apps, auth, create, deploy, doctor, keys, run, secrets
9
+ from . import apps, auth, create, deploy, doctor, keys, run, runners, secrets
10
10
  from .debug import debugtools, get_debug_parser
11
11
  from .parser import FalParser, FalParserExit
12
12
 
@@ -31,7 +31,7 @@ def _get_main_parser() -> argparse.ArgumentParser:
31
31
  required=True,
32
32
  )
33
33
 
34
- for cmd in [auth, apps, deploy, run, keys, secrets, doctor, create]:
34
+ for cmd in [auth, apps, deploy, run, keys, secrets, doctor, create, runners]:
35
35
  cmd.add_parser(subparsers, parents)
36
36
 
37
37
  return parser
fal/cli/run.py CHANGED
@@ -10,7 +10,7 @@ def _run(args):
10
10
 
11
11
  if is_app_name(args.func_ref):
12
12
  app_name = args.func_ref[0]
13
- app_ref, _ = get_app_data_from_toml(app_name)
13
+ app_ref, *_ = get_app_data_from_toml(app_name)
14
14
  file_path, func_name = RefAction.split_ref(app_ref)
15
15
  else:
16
16
  file_path, func_name = args.func_ref
fal/cli/runners.py ADDED
@@ -0,0 +1,44 @@
1
+ from .parser import FalClientParser
2
+
3
+
4
+ def _kill(args):
5
+ from fal.sdk import FalServerlessClient
6
+
7
+ client = FalServerlessClient(args.host)
8
+ with client.connect() as connection:
9
+ connection.kill_runner(args.id)
10
+
11
+
12
+ def _add_kill_parser(subparsers, parents):
13
+ kill_help = "Kill a runner."
14
+ parser = subparsers.add_parser(
15
+ "kill",
16
+ description=kill_help,
17
+ help=kill_help,
18
+ parents=parents,
19
+ )
20
+ parser.add_argument(
21
+ "id",
22
+ help="Runner ID.",
23
+ )
24
+ parser.set_defaults(func=_kill)
25
+
26
+
27
+ def add_parser(main_subparsers, parents):
28
+ runners_help = "Manage fal runners."
29
+ parser = main_subparsers.add_parser(
30
+ "runners",
31
+ description=runners_help,
32
+ help=runners_help,
33
+ parents=parents,
34
+ aliases=["machine"], # backwards compatibility
35
+ )
36
+
37
+ subparsers = parser.add_subparsers(
38
+ title="Commands",
39
+ metavar="command",
40
+ required=True,
41
+ parser_class=FalClientParser,
42
+ )
43
+
44
+ _add_kill_parser(subparsers, parents)
fal/config.py ADDED
@@ -0,0 +1,23 @@
1
+ import os
2
+
3
+ import tomli
4
+
5
+
6
+ class Config:
7
+ DEFAULT_CONFIG_PATH = "~/.fal/config.toml"
8
+ DEFAULT_PROFILE = "default"
9
+
10
+ def __init__(self):
11
+ self.config_path = os.path.expanduser(
12
+ os.getenv("FAL_CONFIG_PATH", self.DEFAULT_CONFIG_PATH)
13
+ )
14
+ self.profile = os.getenv("FAL_PROFILE", self.DEFAULT_PROFILE)
15
+
16
+ try:
17
+ with open(self.config_path, "rb") as file:
18
+ self.config = tomli.load(file)
19
+ except FileNotFoundError:
20
+ self.config = {}
21
+
22
+ def get(self, key):
23
+ return self.config.get(self.profile, {}).get(key)
fal/container.py CHANGED
@@ -3,7 +3,7 @@ class ContainerImage:
3
3
  from a Dockerfile.
4
4
  """
5
5
 
6
- _known_keys = {"dockerfile_str", "build_env", "build_args"}
6
+ _known_keys = {"dockerfile_str", "build_args", "registries", "builder"}
7
7
 
8
8
  @classmethod
9
9
  def from_dockerfile_str(cls, text: str, **kwargs):
fal/sdk.py CHANGED
@@ -5,7 +5,7 @@ from contextlib import ExitStack
5
5
  from dataclasses import dataclass, field
6
6
  from datetime import datetime, timedelta
7
7
  from enum import Enum
8
- from typing import Any, Callable, Generic, Iterator, Literal, TypeVar
8
+ from typing import Any, Callable, Generic, Iterator, Literal, Optional, TypeVar
9
9
 
10
10
  import grpc
11
11
  import isolate_proto
@@ -214,7 +214,7 @@ class AliasInfo:
214
214
  class RunnerInfo:
215
215
  runner_id: str
216
216
  in_flight_requests: int
217
- expiration_countdown: int
217
+ expiration_countdown: Optional[int]
218
218
  uptime: timedelta
219
219
 
220
220
 
@@ -344,7 +344,9 @@ def _from_grpc_runner_info(message: isolate_proto.RunnerInfo) -> RunnerInfo:
344
344
  return RunnerInfo(
345
345
  runner_id=message.runner_id,
346
346
  in_flight_requests=message.in_flight_requests,
347
- expiration_countdown=message.expiration_countdown,
347
+ expiration_countdown=message.expiration_countdown
348
+ if message.HasField("expiration_countdown")
349
+ else None,
348
350
  uptime=timedelta(seconds=message.uptime),
349
351
  )
350
352
 
@@ -389,7 +391,8 @@ def _from_grpc_hosted_run_result(
389
391
 
390
392
  @dataclass
391
393
  class MachineRequirements:
392
- machine_type: str
394
+ machine_types: list[str]
395
+ num_gpus: int | None = field(default=None)
393
396
  keep_alive: int = FAL_SERVERLESS_DEFAULT_KEEP_ALIVE
394
397
  base_image: str | None = None
395
398
  exposed_port: int | None = None
@@ -398,6 +401,17 @@ class MachineRequirements:
398
401
  max_concurrency: int | None = None
399
402
  max_multiplexing: int | None = None
400
403
  min_concurrency: int | None = None
404
+ request_timeout: int | None = None
405
+
406
+ def __post_init__(self):
407
+ if isinstance(self.machine_types, str):
408
+ self.machine_types = [self.machine_types]
409
+
410
+ if not isinstance(self.machine_types, list):
411
+ raise ValueError("machine_types must be a list of strings.")
412
+
413
+ if not self.machine_types:
414
+ raise ValueError("No machine type provided.")
401
415
 
402
416
 
403
417
  @dataclass
@@ -485,11 +499,15 @@ class FalServerlessConnection:
485
499
  machine_requirements: MachineRequirements | None = None,
486
500
  metadata: dict[str, Any] | None = None,
487
501
  deployment_strategy: Literal["recreate", "rolling"] = "recreate",
502
+ scale: bool = True,
488
503
  ) -> Iterator[isolate_proto.RegisterApplicationResult]:
489
504
  wrapped_function = to_serialized_object(function, serialization_method)
490
505
  if machine_requirements:
491
506
  wrapped_requirements = isolate_proto.MachineRequirements(
492
- machine_type=machine_requirements.machine_type,
507
+ # NOTE: backwards compatibility with old API
508
+ machine_type=machine_requirements.machine_types[0],
509
+ machine_types=machine_requirements.machine_types,
510
+ num_gpus=machine_requirements.num_gpus,
493
511
  keep_alive=machine_requirements.keep_alive,
494
512
  base_image=machine_requirements.base_image,
495
513
  exposed_port=machine_requirements.exposed_port,
@@ -500,6 +518,7 @@ class FalServerlessConnection:
500
518
  max_concurrency=machine_requirements.max_concurrency,
501
519
  min_concurrency=machine_requirements.min_concurrency,
502
520
  max_multiplexing=machine_requirements.max_multiplexing,
521
+ request_timeout=machine_requirements.request_timeout,
503
522
  )
504
523
  else:
505
524
  wrapped_requirements = None
@@ -516,9 +535,6 @@ class FalServerlessConnection:
516
535
  struct_metadata = isolate_proto.Struct()
517
536
  struct_metadata.update(metadata)
518
537
 
519
- if deployment_strategy == "default":
520
- deployment_strategy = "recreate"
521
-
522
538
  deployment_strategy_proto = DeploymentStrategy[
523
539
  deployment_strategy.upper()
524
540
  ].to_proto()
@@ -531,6 +547,7 @@ class FalServerlessConnection:
531
547
  auth_mode=auth_mode,
532
548
  metadata=struct_metadata,
533
549
  deployment_strategy=deployment_strategy_proto,
550
+ scale=scale,
534
551
  )
535
552
  for partial_result in self.stub.RegisterApplication(request):
536
553
  yield from_grpc(partial_result)
@@ -582,7 +599,10 @@ class FalServerlessConnection:
582
599
  wrapped_function = to_serialized_object(function, serialization_method)
583
600
  if machine_requirements:
584
601
  wrapped_requirements = isolate_proto.MachineRequirements(
585
- machine_type=machine_requirements.machine_type,
602
+ # NOTE: backwards compatibility with old API
603
+ machine_type=machine_requirements.machine_types[0],
604
+ machine_types=machine_requirements.machine_types,
605
+ num_gpus=machine_requirements.num_gpus,
586
606
  keep_alive=machine_requirements.keep_alive,
587
607
  base_image=machine_requirements.base_image,
588
608
  exposed_port=machine_requirements.exposed_port,
@@ -593,6 +613,7 @@ class FalServerlessConnection:
593
613
  max_concurrency=machine_requirements.max_concurrency,
594
614
  max_multiplexing=machine_requirements.max_multiplexing,
595
615
  min_concurrency=machine_requirements.min_concurrency,
616
+ request_timeout=machine_requirements.request_timeout,
596
617
  )
597
618
  else:
598
619
  wrapped_requirements = None
@@ -665,3 +686,7 @@ class FalServerlessConnection:
665
686
  )
666
687
  for secret in response.secrets
667
688
  ]
689
+
690
+ def kill_runner(self, runner_id: str) -> None:
691
+ request = isolate_proto.KillRunnerRequest(runner_id=runner_id)
692
+ self.stub.KillRunner(request)