fal 1.3.3__py3-none-any.whl → 1.7.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of fal might be problematic. Click here for more details.
- fal/_fal_version.py +2 -2
- fal/api.py +46 -14
- fal/app.py +157 -17
- fal/apps.py +138 -3
- fal/auth/__init__.py +50 -2
- fal/cli/_utils.py +8 -2
- fal/cli/apps.py +1 -1
- fal/cli/deploy.py +34 -8
- fal/cli/main.py +2 -2
- fal/cli/run.py +1 -1
- fal/cli/runners.py +44 -0
- fal/config.py +23 -0
- fal/container.py +1 -1
- fal/sdk.py +34 -9
- fal/toolkit/file/file.py +92 -19
- fal/toolkit/file/providers/fal.py +571 -83
- fal/toolkit/file/providers/gcp.py +8 -1
- fal/toolkit/file/providers/r2.py +8 -1
- fal/toolkit/file/providers/s3.py +80 -0
- fal/toolkit/file/types.py +11 -4
- fal/toolkit/image/__init__.py +3 -3
- fal/toolkit/image/image.py +25 -2
- fal/toolkit/types.py +140 -0
- fal/toolkit/utils/download_utils.py +4 -0
- fal/toolkit/utils/retry.py +45 -0
- fal/workflows.py +10 -4
- {fal-1.3.3.dist-info → fal-1.7.3.dist-info}/METADATA +14 -9
- {fal-1.3.3.dist-info → fal-1.7.3.dist-info}/RECORD +31 -26
- {fal-1.3.3.dist-info → fal-1.7.3.dist-info}/WHEEL +1 -1
- {fal-1.3.3.dist-info → fal-1.7.3.dist-info}/entry_points.txt +0 -0
- {fal-1.3.3.dist-info → fal-1.7.3.dist-info}/top_level.txt +0 -0
fal/apps.py
CHANGED
|
@@ -4,15 +4,19 @@ import json
|
|
|
4
4
|
import time
|
|
5
5
|
from contextlib import contextmanager
|
|
6
6
|
from dataclasses import dataclass, field
|
|
7
|
-
from typing import Any, Iterator
|
|
7
|
+
from typing import TYPE_CHECKING, Any, Iterator
|
|
8
8
|
|
|
9
9
|
import httpx
|
|
10
10
|
|
|
11
11
|
from fal import flags
|
|
12
12
|
from fal.sdk import Credentials, get_default_credentials
|
|
13
13
|
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
from websockets.sync.connection import Connection
|
|
16
|
+
|
|
14
17
|
_QUEUE_URL_FORMAT = f"https://queue.{flags.FAL_RUN_HOST}/{{app_id}}"
|
|
15
18
|
_REALTIME_URL_FORMAT = f"wss://{flags.FAL_RUN_HOST}/{{app_id}}"
|
|
19
|
+
_WS_URL_FORMAT = f"wss://ws.{flags.FAL_RUN_HOST}/{{app_id}}"
|
|
16
20
|
|
|
17
21
|
|
|
18
22
|
def _backwards_compatible_app_id(app_id: str) -> str:
|
|
@@ -173,7 +177,8 @@ def submit(app_id: str, arguments: dict[str, Any], *, path: str = "") -> Request
|
|
|
173
177
|
app_id = _backwards_compatible_app_id(app_id)
|
|
174
178
|
url = _QUEUE_URL_FORMAT.format(app_id=app_id)
|
|
175
179
|
if path:
|
|
176
|
-
|
|
180
|
+
_path = path[len("/") :] if path.startswith("/") else path
|
|
181
|
+
url += "/" + _path
|
|
177
182
|
|
|
178
183
|
creds = get_default_credentials()
|
|
179
184
|
|
|
@@ -235,7 +240,8 @@ def _connect(app_id: str, *, path: str = "/realtime") -> Iterator[_RealtimeConne
|
|
|
235
240
|
app_id = _backwards_compatible_app_id(app_id)
|
|
236
241
|
url = _REALTIME_URL_FORMAT.format(app_id=app_id)
|
|
237
242
|
if path:
|
|
238
|
-
|
|
243
|
+
_path = path[len("/") :] if path.startswith("/") else path
|
|
244
|
+
url += "/" + _path
|
|
239
245
|
|
|
240
246
|
creds = get_default_credentials()
|
|
241
247
|
|
|
@@ -243,3 +249,132 @@ def _connect(app_id: str, *, path: str = "/realtime") -> Iterator[_RealtimeConne
|
|
|
243
249
|
url, additional_headers=creds.to_headers(), open_timeout=90
|
|
244
250
|
) as ws:
|
|
245
251
|
yield _RealtimeConnection(ws)
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
class _MetaMessageFound(Exception):
    """Raised inside ``_WSConnection._recv`` to leave a meta message buffered."""


@dataclass
class _WSConnection:
    """A WS connection to an HTTP Fal app.

    A response is read as a ``"start"`` meta message, zero or more payload
    messages, and an ``"end"`` meta message whose ``request_id`` must match
    the one from ``"start"``. A meta message is a text message containing a
    JSON object with both a ``"type"`` and a ``"request_id"`` key.
    """

    _ws: Connection
    # One-message lookahead: holds a received message until it is consumed.
    _buffer: str | bytes | None = None

    def run(self, arguments: dict[str, Any]) -> bytes:
        """Run an inference task on the app and return the result."""
        self.send(arguments)
        return self.recv()

    def send(self, arguments: dict[str, Any]) -> None:
        """Serialize ``arguments`` as JSON and send it over the socket."""
        # Uses the module-level ``json`` import, like ``_is_meta`` does.
        self._ws.send(json.dumps(arguments))

    def _peek(self) -> bytes | str:
        """Return the next message without consuming it."""
        if self._buffer is None:
            self._buffer = self._ws.recv()

        return self._buffer

    def _consume(self) -> None:
        """Drop the buffered message.

        Raises:
            ValueError: if nothing is buffered.
        """
        if self._buffer is None:
            raise ValueError("No data to consume")

        self._buffer = None

    @contextmanager
    def _recv(self) -> Iterator[str | bytes]:
        """Yield the next message; consume it only if the body does not raise."""
        res = self._peek()

        yield res

        # Only consume if it went through the context manager without raising
        self._consume()

    def _is_meta(self, res: str | bytes) -> bool:
        """Tell whether ``res`` is a meta message (see the class docstring)."""
        if not isinstance(res, str):
            return False

        try:
            json_payload: Any = json.loads(res)
        except json.JSONDecodeError:
            return False

        if not isinstance(json_payload, dict):
            return False

        return "type" in json_payload and "request_id" in json_payload

    def _recv_meta(self, type: str) -> dict[str, Any]:
        """Consume and return the next message, which must be a ``type`` meta message.

        Raises:
            ValueError: if the next message is not a meta message of that
                type; the message then stays buffered.
        """
        with self._recv() as res:
            if not self._is_meta(res):
                raise ValueError(f"Expected a {type} message")

            json_payload: dict = json.loads(res)
            if json_payload.get("type") != type:
                raise ValueError(f"Expected a {type} message")

            return json_payload

    def _recv_response(self) -> Iterator[str | bytes]:
        """Yield payload messages until a meta message is next in line."""
        while True:
            try:
                with self._recv() as res:
                    if self._is_meta(res):
                        # Raise so we don't consume the message
                        raise _MetaMessageFound()

                    yield res
            except _MetaMessageFound:
                break

    def _recv_start(self) -> Any:
        """Consume the ``"start"`` message and return its ``request_id``."""
        return self._recv_meta("start")["request_id"]

    def _recv_end(self, request_id: Any) -> None:
        """Consume the ``"end"`` message and check it closes ``request_id``.

        Raises:
            ValueError: if the ``"end"`` message carries another request id.
        """
        end = self._recv_meta("end")
        if end["request_id"] != request_id:
            raise ValueError("Mismatched request_id in end message")

    def recv(self) -> bytes:
        """Read one complete response and return its payload as bytes.

        Text payload messages are encoded with ``str.encode()`` before being
        concatenated with the binary ones.
        """
        request_id = self._recv_start()

        # Collect the chunks and join once; repeated ``bytes +=`` copies the
        # accumulated response on every chunk.
        chunks = [
            part.encode() if isinstance(part, str) else part
            for part in self._recv_response()
        ]

        self._recv_end(request_id)

        return b"".join(chunks)

    def stream(self) -> Iterator[str | bytes]:
        """Yield the payload messages of one response as they arrive."""
        request_id = self._recv_start()

        yield from self._recv_response()

        # Make sure we consume the end message
        self._recv_end(request_id)
|
|
360
|
+
|
|
361
|
+
|
|
362
|
+
@contextmanager
def ws(app_id: str, *, path: str = "") -> Iterator[_WSConnection]:
    """Connect to a HTTP endpoint but with websocket protocol. This is an internal and
    experimental API, use it at your own risk."""

    from websockets.sync import client

    endpoint = _WS_URL_FORMAT.format(app_id=_backwards_compatible_app_id(app_id))
    if path:
        # Drop a single leading slash so the join below never doubles it.
        suffix = path[1:] if path.startswith("/") else path
        endpoint = endpoint + "/" + suffix

    credentials = get_default_credentials()

    with client.connect(
        endpoint,
        additional_headers=credentials.to_headers(),
        open_timeout=90,
    ) as socket:
        yield _WSConnection(socket)
|
fal/auth/__init__.py
CHANGED
|
@@ -2,22 +2,70 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
import os
|
|
4
4
|
from dataclasses import dataclass, field
|
|
5
|
+
from threading import Lock
|
|
6
|
+
from typing import Optional
|
|
5
7
|
|
|
6
8
|
import click
|
|
7
9
|
|
|
8
10
|
from fal.auth import auth0, local
|
|
11
|
+
from fal.config import Config
|
|
9
12
|
from fal.console import console
|
|
10
13
|
from fal.console.icons import CHECK_ICON
|
|
11
14
|
from fal.exceptions.auth import UnauthenticatedException
|
|
12
15
|
|
|
13
16
|
|
|
17
|
+
class GoogleColabState:
    """Mutable holder for the result of the Google Colab secret lookup."""

    def __init__(self):
        # Held by get_colab_token() while it reads or updates the fields below.
        self.lock = Lock()
        # Value stored by the lookup; stays None until one succeeds.
        self.secret: Optional[str] = None
        # Set to True once a lookup has been attempted.
        self.is_checked = False


# Module-wide instance shared by every get_colab_token() call.
_colab_state = GoogleColabState()
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def is_google_colab() -> bool:
    """Return True when the active IPython shell's string form mentions ``google.colab``.

    Returns False when IPython cannot be imported or the lookup raises
    ``NameError``.
    """
    try:
        from IPython import get_ipython

        shell_repr = str(get_ipython())
    except (ModuleNotFoundError, NameError):
        return False

    return "google.colab" in shell_repr
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def get_colab_token() -> Optional[str]:
    """Return the ``FAL_KEY`` secret from Google Colab user data, if available.

    Returns None outside Colab or when ``google.colab`` cannot be imported.
    The lookup outcome (the stripped secret, or None when the lookup raised)
    is stored in the module-level state and reused by later calls.
    """
    if not is_google_colab():
        return None

    state = _colab_state
    with state.lock:
        # request access only once
        if not state.is_checked:
            try:
                from google.colab import userdata  # noqa: I001
            except ImportError:
                # Not recorded as checked, so a later call tries the import again.
                return None

            try:
                state.secret = userdata.get("FAL_KEY").strip()
            except Exception:
                # Best effort: any failure while reading the secret means "no key".
                state.secret = None

            state.is_checked = True

        return state.secret
|
|
58
|
+
|
|
59
|
+
|
|
14
60
|
def key_credentials() -> tuple[str, str] | None:
|
|
15
61
|
# Ignore key credentials when the user forces auth by user.
|
|
16
62
|
if os.environ.get("FAL_FORCE_AUTH_BY_USER") == "1":
|
|
17
63
|
return None
|
|
18
64
|
|
|
19
|
-
|
|
20
|
-
|
|
65
|
+
config = Config()
|
|
66
|
+
|
|
67
|
+
key = os.environ.get("FAL_KEY") or config.get("key") or get_colab_token()
|
|
68
|
+
if key:
|
|
21
69
|
key_id, key_secret = key.split(":", 1)
|
|
22
70
|
return (key_id, key_secret)
|
|
23
71
|
elif "FAL_KEY_ID" in os.environ and "FAL_KEY_SECRET" in os.environ:
|
fal/cli/_utils.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
from fal.files import find_pyproject_toml, parse_pyproject_toml
|
|
3
|
+
from fal.files import find_project_root, find_pyproject_toml, parse_pyproject_toml
|
|
4
4
|
|
|
5
5
|
|
|
6
6
|
def is_app_name(app_ref: tuple[str, str | None]) -> bool:
|
|
@@ -29,6 +29,12 @@ def get_app_data_from_toml(app_name):
|
|
|
29
29
|
except KeyError:
|
|
30
30
|
raise ValueError(f"App {app_name} does not have a ref key in pyproject.toml")
|
|
31
31
|
|
|
32
|
+
# Convert the app_ref to a path relative to the project root
|
|
33
|
+
project_root, _ = find_project_root(None)
|
|
34
|
+
app_ref = str(project_root / app_ref)
|
|
35
|
+
|
|
32
36
|
app_auth = app_data.get("auth", "private")
|
|
37
|
+
app_deployment_strategy = app_data.get("deployment_strategy", "recreate")
|
|
38
|
+
app_no_scale = app_data.get("no_scale", False)
|
|
33
39
|
|
|
34
|
-
return app_ref, app_auth
|
|
40
|
+
return app_ref, app_auth, app_deployment_strategy, app_no_scale
|
fal/cli/apps.py
CHANGED
|
@@ -221,7 +221,7 @@ def _runners(args):
|
|
|
221
221
|
str(runner.in_flight_requests),
|
|
222
222
|
(
|
|
223
223
|
"N/A (active)"
|
|
224
|
-
if
|
|
224
|
+
if runner.expiration_countdown is None
|
|
225
225
|
else f"{runner.expiration_countdown}s"
|
|
226
226
|
),
|
|
227
227
|
f"{runner.uptime} ({runner.uptime.total_seconds()}s)",
|
fal/cli/deploy.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import argparse
|
|
2
2
|
from collections import namedtuple
|
|
3
3
|
from pathlib import Path
|
|
4
|
-
from typing import Optional, Union
|
|
4
|
+
from typing import Literal, Optional, Tuple, Union
|
|
5
5
|
|
|
6
6
|
from ._utils import get_app_data_from_toml, is_app_name
|
|
7
7
|
from .parser import FalClientParser, RefAction
|
|
@@ -63,7 +63,12 @@ def _get_user() -> User:
|
|
|
63
63
|
|
|
64
64
|
|
|
65
65
|
def _deploy_from_reference(
|
|
66
|
-
app_ref:
|
|
66
|
+
app_ref: Tuple[Optional[Union[Path, str]], ...],
|
|
67
|
+
app_name: str,
|
|
68
|
+
args,
|
|
69
|
+
auth: Optional[Literal["public", "shared", "private"]] = None,
|
|
70
|
+
deployment_strategy: Optional[Literal["recreate", "rolling"]] = None,
|
|
71
|
+
no_scale: bool = False,
|
|
67
72
|
):
|
|
68
73
|
from fal.api import FalServerlessError, FalServerlessHost
|
|
69
74
|
from fal.utils import load_function_from
|
|
@@ -93,7 +98,7 @@ def _deploy_from_reference(
|
|
|
93
98
|
isolated_function = loaded.function
|
|
94
99
|
app_name = app_name or loaded.app_name # type: ignore
|
|
95
100
|
app_auth = auth or loaded.app_auth or "private"
|
|
96
|
-
deployment_strategy =
|
|
101
|
+
deployment_strategy = deployment_strategy or "recreate"
|
|
97
102
|
|
|
98
103
|
app_id = host.register(
|
|
99
104
|
func=isolated_function.func,
|
|
@@ -102,6 +107,7 @@ def _deploy_from_reference(
|
|
|
102
107
|
application_auth_mode=app_auth,
|
|
103
108
|
metadata=isolated_function.options.host.get("metadata", {}),
|
|
104
109
|
deployment_strategy=deployment_strategy,
|
|
110
|
+
scale=not no_scale,
|
|
105
111
|
)
|
|
106
112
|
|
|
107
113
|
if app_id:
|
|
@@ -134,7 +140,9 @@ def _deploy(args):
|
|
|
134
140
|
raise ValueError("Cannot use --app-name or --auth with app name reference.")
|
|
135
141
|
|
|
136
142
|
app_name = args.app_ref[0]
|
|
137
|
-
app_ref, app_auth =
|
|
143
|
+
app_ref, app_auth, app_deployment_strategy, app_no_scale = (
|
|
144
|
+
get_app_data_from_toml(app_name)
|
|
145
|
+
)
|
|
138
146
|
file_path, func_name = RefAction.split_ref(app_ref)
|
|
139
147
|
|
|
140
148
|
# path/to/myfile.py::MyApp
|
|
@@ -142,8 +150,17 @@ def _deploy(args):
|
|
|
142
150
|
file_path, func_name = args.app_ref
|
|
143
151
|
app_name = args.app_name
|
|
144
152
|
app_auth = args.auth
|
|
145
|
-
|
|
146
|
-
|
|
153
|
+
app_deployment_strategy = args.strategy
|
|
154
|
+
app_no_scale = args.no_scale
|
|
155
|
+
|
|
156
|
+
_deploy_from_reference(
|
|
157
|
+
(file_path, func_name),
|
|
158
|
+
app_name,
|
|
159
|
+
args,
|
|
160
|
+
app_auth,
|
|
161
|
+
app_deployment_strategy,
|
|
162
|
+
app_no_scale,
|
|
163
|
+
)
|
|
147
164
|
|
|
148
165
|
|
|
149
166
|
def add_parser(main_subparsers, parents):
|
|
@@ -204,9 +221,18 @@ def add_parser(main_subparsers, parents):
|
|
|
204
221
|
)
|
|
205
222
|
parser.add_argument(
|
|
206
223
|
"--strategy",
|
|
207
|
-
choices=["
|
|
224
|
+
choices=["recreate", "rolling"],
|
|
208
225
|
help="Deployment strategy.",
|
|
209
|
-
default="
|
|
226
|
+
default="recreate",
|
|
227
|
+
)
|
|
228
|
+
parser.add_argument(
|
|
229
|
+
"--no-scale",
|
|
230
|
+
action="store_true",
|
|
231
|
+
help=(
|
|
232
|
+
"Use min_concurrency/max_concurrency/max_multiplexing from previous "
|
|
233
|
+
"deployment of application with this name, if exists. Otherwise will "
|
|
234
|
+
"use the values from the application code."
|
|
235
|
+
),
|
|
210
236
|
)
|
|
211
237
|
|
|
212
238
|
parser.set_defaults(func=_deploy)
|
fal/cli/main.py
CHANGED
|
@@ -6,7 +6,7 @@ from fal import __version__
|
|
|
6
6
|
from fal.console import console
|
|
7
7
|
from fal.console.icons import CROSS_ICON
|
|
8
8
|
|
|
9
|
-
from . import apps, auth, create, deploy, doctor, keys, run, secrets
|
|
9
|
+
from . import apps, auth, create, deploy, doctor, keys, run, runners, secrets
|
|
10
10
|
from .debug import debugtools, get_debug_parser
|
|
11
11
|
from .parser import FalParser, FalParserExit
|
|
12
12
|
|
|
@@ -31,7 +31,7 @@ def _get_main_parser() -> argparse.ArgumentParser:
|
|
|
31
31
|
required=True,
|
|
32
32
|
)
|
|
33
33
|
|
|
34
|
-
for cmd in [auth, apps, deploy, run, keys, secrets, doctor, create]:
|
|
34
|
+
for cmd in [auth, apps, deploy, run, keys, secrets, doctor, create, runners]:
|
|
35
35
|
cmd.add_parser(subparsers, parents)
|
|
36
36
|
|
|
37
37
|
return parser
|
fal/cli/run.py
CHANGED
|
@@ -10,7 +10,7 @@ def _run(args):
|
|
|
10
10
|
|
|
11
11
|
if is_app_name(args.func_ref):
|
|
12
12
|
app_name = args.func_ref[0]
|
|
13
|
-
app_ref, _ = get_app_data_from_toml(app_name)
|
|
13
|
+
app_ref, *_ = get_app_data_from_toml(app_name)
|
|
14
14
|
file_path, func_name = RefAction.split_ref(app_ref)
|
|
15
15
|
else:
|
|
16
16
|
file_path, func_name = args.func_ref
|
fal/cli/runners.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
from .parser import FalClientParser
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def _kill(args):
    """Handle ``runners kill``: ask the host at ``args.host`` to kill runner ``args.id``."""
    from fal.sdk import FalServerlessClient

    with FalServerlessClient(args.host).connect() as connection:
        connection.kill_runner(args.id)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def _add_kill_parser(subparsers, parents):
    """Register the ``kill`` subcommand, which takes a positional runner ``id``."""
    summary = "Kill a runner."
    kill_parser = subparsers.add_parser(
        "kill", description=summary, help=summary, parents=parents
    )
    kill_parser.add_argument("id", help="Runner ID.")
    kill_parser.set_defaults(func=_kill)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def add_parser(main_subparsers, parents):
    """Register the ``runners`` command group (alias ``machine``) and its subcommands."""
    summary = "Manage fal runners."
    runners_parser = main_subparsers.add_parser(
        "runners",
        aliases=["machine"],  # backwards compatibility
        description=summary,
        help=summary,
        parents=parents,
    )

    command_parsers = runners_parser.add_subparsers(
        title="Commands",
        metavar="command",
        required=True,
        parser_class=FalClientParser,
    )

    _add_kill_parser(command_parsers, parents)
|
fal/config.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
import os
|
|
2
|
+
|
|
3
|
+
import tomli
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class Config:
    """Read-only view of one profile of the fal TOML configuration file.

    The file location comes from ``FAL_CONFIG_PATH`` (default
    ``~/.fal/config.toml``) and the profile name from ``FAL_PROFILE``
    (default ``"default"``).
    """

    DEFAULT_CONFIG_PATH = "~/.fal/config.toml"
    DEFAULT_PROFILE = "default"

    def __init__(self):
        configured_path = os.getenv("FAL_CONFIG_PATH", self.DEFAULT_CONFIG_PATH)
        self.config_path = os.path.expanduser(configured_path)
        self.profile = os.getenv("FAL_PROFILE", self.DEFAULT_PROFILE)

        try:
            with open(self.config_path, "rb") as handle:
                loaded = tomli.load(handle)
        except FileNotFoundError:
            # A missing config file is treated as an empty configuration.
            loaded = {}

        self.config = loaded

    def get(self, key):
        """Return ``key`` from the active profile, or None when either is absent."""
        profile_section = self.config.get(self.profile, {})
        return profile_section.get(key)
|
fal/container.py
CHANGED
fal/sdk.py
CHANGED
|
@@ -5,7 +5,7 @@ from contextlib import ExitStack
|
|
|
5
5
|
from dataclasses import dataclass, field
|
|
6
6
|
from datetime import datetime, timedelta
|
|
7
7
|
from enum import Enum
|
|
8
|
-
from typing import Any, Callable, Generic, Iterator, Literal, TypeVar
|
|
8
|
+
from typing import Any, Callable, Generic, Iterator, Literal, Optional, TypeVar
|
|
9
9
|
|
|
10
10
|
import grpc
|
|
11
11
|
import isolate_proto
|
|
@@ -214,7 +214,7 @@ class AliasInfo:
|
|
|
214
214
|
class RunnerInfo:
|
|
215
215
|
runner_id: str
|
|
216
216
|
in_flight_requests: int
|
|
217
|
-
expiration_countdown: int
|
|
217
|
+
expiration_countdown: Optional[int]
|
|
218
218
|
uptime: timedelta
|
|
219
219
|
|
|
220
220
|
|
|
@@ -344,7 +344,9 @@ def _from_grpc_runner_info(message: isolate_proto.RunnerInfo) -> RunnerInfo:
|
|
|
344
344
|
return RunnerInfo(
|
|
345
345
|
runner_id=message.runner_id,
|
|
346
346
|
in_flight_requests=message.in_flight_requests,
|
|
347
|
-
expiration_countdown=message.expiration_countdown
|
|
347
|
+
expiration_countdown=message.expiration_countdown
|
|
348
|
+
if message.HasField("expiration_countdown")
|
|
349
|
+
else None,
|
|
348
350
|
uptime=timedelta(seconds=message.uptime),
|
|
349
351
|
)
|
|
350
352
|
|
|
@@ -389,7 +391,8 @@ def _from_grpc_hosted_run_result(
|
|
|
389
391
|
|
|
390
392
|
@dataclass
|
|
391
393
|
class MachineRequirements:
|
|
392
|
-
|
|
394
|
+
machine_types: list[str]
|
|
395
|
+
num_gpus: int | None = field(default=None)
|
|
393
396
|
keep_alive: int = FAL_SERVERLESS_DEFAULT_KEEP_ALIVE
|
|
394
397
|
base_image: str | None = None
|
|
395
398
|
exposed_port: int | None = None
|
|
@@ -398,6 +401,17 @@ class MachineRequirements:
|
|
|
398
401
|
max_concurrency: int | None = None
|
|
399
402
|
max_multiplexing: int | None = None
|
|
400
403
|
min_concurrency: int | None = None
|
|
404
|
+
request_timeout: int | None = None
|
|
405
|
+
|
|
406
|
+
def __post_init__(self):
|
|
407
|
+
if isinstance(self.machine_types, str):
|
|
408
|
+
self.machine_types = [self.machine_types]
|
|
409
|
+
|
|
410
|
+
if not isinstance(self.machine_types, list):
|
|
411
|
+
raise ValueError("machine_types must be a list of strings.")
|
|
412
|
+
|
|
413
|
+
if not self.machine_types:
|
|
414
|
+
raise ValueError("No machine type provided.")
|
|
401
415
|
|
|
402
416
|
|
|
403
417
|
@dataclass
|
|
@@ -485,11 +499,15 @@ class FalServerlessConnection:
|
|
|
485
499
|
machine_requirements: MachineRequirements | None = None,
|
|
486
500
|
metadata: dict[str, Any] | None = None,
|
|
487
501
|
deployment_strategy: Literal["recreate", "rolling"] = "recreate",
|
|
502
|
+
scale: bool = True,
|
|
488
503
|
) -> Iterator[isolate_proto.RegisterApplicationResult]:
|
|
489
504
|
wrapped_function = to_serialized_object(function, serialization_method)
|
|
490
505
|
if machine_requirements:
|
|
491
506
|
wrapped_requirements = isolate_proto.MachineRequirements(
|
|
492
|
-
|
|
507
|
+
# NOTE: backwards compatibility with old API
|
|
508
|
+
machine_type=machine_requirements.machine_types[0],
|
|
509
|
+
machine_types=machine_requirements.machine_types,
|
|
510
|
+
num_gpus=machine_requirements.num_gpus,
|
|
493
511
|
keep_alive=machine_requirements.keep_alive,
|
|
494
512
|
base_image=machine_requirements.base_image,
|
|
495
513
|
exposed_port=machine_requirements.exposed_port,
|
|
@@ -500,6 +518,7 @@ class FalServerlessConnection:
|
|
|
500
518
|
max_concurrency=machine_requirements.max_concurrency,
|
|
501
519
|
min_concurrency=machine_requirements.min_concurrency,
|
|
502
520
|
max_multiplexing=machine_requirements.max_multiplexing,
|
|
521
|
+
request_timeout=machine_requirements.request_timeout,
|
|
503
522
|
)
|
|
504
523
|
else:
|
|
505
524
|
wrapped_requirements = None
|
|
@@ -516,9 +535,6 @@ class FalServerlessConnection:
|
|
|
516
535
|
struct_metadata = isolate_proto.Struct()
|
|
517
536
|
struct_metadata.update(metadata)
|
|
518
537
|
|
|
519
|
-
if deployment_strategy == "default":
|
|
520
|
-
deployment_strategy = "recreate"
|
|
521
|
-
|
|
522
538
|
deployment_strategy_proto = DeploymentStrategy[
|
|
523
539
|
deployment_strategy.upper()
|
|
524
540
|
].to_proto()
|
|
@@ -531,6 +547,7 @@ class FalServerlessConnection:
|
|
|
531
547
|
auth_mode=auth_mode,
|
|
532
548
|
metadata=struct_metadata,
|
|
533
549
|
deployment_strategy=deployment_strategy_proto,
|
|
550
|
+
scale=scale,
|
|
534
551
|
)
|
|
535
552
|
for partial_result in self.stub.RegisterApplication(request):
|
|
536
553
|
yield from_grpc(partial_result)
|
|
@@ -582,7 +599,10 @@ class FalServerlessConnection:
|
|
|
582
599
|
wrapped_function = to_serialized_object(function, serialization_method)
|
|
583
600
|
if machine_requirements:
|
|
584
601
|
wrapped_requirements = isolate_proto.MachineRequirements(
|
|
585
|
-
|
|
602
|
+
# NOTE: backwards compatibility with old API
|
|
603
|
+
machine_type=machine_requirements.machine_types[0],
|
|
604
|
+
machine_types=machine_requirements.machine_types,
|
|
605
|
+
num_gpus=machine_requirements.num_gpus,
|
|
586
606
|
keep_alive=machine_requirements.keep_alive,
|
|
587
607
|
base_image=machine_requirements.base_image,
|
|
588
608
|
exposed_port=machine_requirements.exposed_port,
|
|
@@ -593,6 +613,7 @@ class FalServerlessConnection:
|
|
|
593
613
|
max_concurrency=machine_requirements.max_concurrency,
|
|
594
614
|
max_multiplexing=machine_requirements.max_multiplexing,
|
|
595
615
|
min_concurrency=machine_requirements.min_concurrency,
|
|
616
|
+
request_timeout=machine_requirements.request_timeout,
|
|
596
617
|
)
|
|
597
618
|
else:
|
|
598
619
|
wrapped_requirements = None
|
|
@@ -665,3 +686,7 @@ class FalServerlessConnection:
|
|
|
665
686
|
)
|
|
666
687
|
for secret in response.secrets
|
|
667
688
|
]
|
|
689
|
+
|
|
690
|
+
def kill_runner(self, runner_id: str) -> None:
    """Send a ``KillRunner`` request for ``runner_id`` through the gRPC stub."""
    self.stub.KillRunner(isolate_proto.KillRunnerRequest(runner_id=runner_id))
|