fal 1.2.1__py3-none-any.whl → 1.7.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of fal might be problematic. Click here for more details.
- fal/__main__.py +3 -1
- fal/_fal_version.py +2 -2
- fal/api.py +88 -20
- fal/app.py +221 -27
- fal/apps.py +147 -3
- fal/auth/__init__.py +50 -2
- fal/cli/_utils.py +40 -0
- fal/cli/apps.py +5 -3
- fal/cli/create.py +26 -0
- fal/cli/deploy.py +97 -16
- fal/cli/main.py +2 -2
- fal/cli/parser.py +11 -7
- fal/cli/run.py +12 -1
- fal/cli/runners.py +44 -0
- fal/config.py +23 -0
- fal/container.py +1 -1
- fal/exceptions/__init__.py +7 -1
- fal/exceptions/_base.py +51 -0
- fal/exceptions/_cuda.py +44 -0
- fal/files.py +81 -0
- fal/sdk.py +67 -6
- fal/toolkit/file/file.py +103 -13
- fal/toolkit/file/providers/fal.py +572 -24
- fal/toolkit/file/providers/gcp.py +8 -1
- fal/toolkit/file/providers/r2.py +8 -1
- fal/toolkit/file/providers/s3.py +80 -0
- fal/toolkit/file/types.py +28 -3
- fal/toolkit/image/__init__.py +71 -0
- fal/toolkit/image/image.py +25 -2
- fal/toolkit/image/nsfw_filter/__init__.py +11 -0
- fal/toolkit/image/nsfw_filter/env.py +9 -0
- fal/toolkit/image/nsfw_filter/inference.py +77 -0
- fal/toolkit/image/nsfw_filter/model.py +18 -0
- fal/toolkit/image/nsfw_filter/requirements.txt +4 -0
- fal/toolkit/image/safety_checker.py +107 -0
- fal/toolkit/types.py +140 -0
- fal/toolkit/utils/download_utils.py +4 -0
- fal/toolkit/utils/retry.py +45 -0
- fal/utils.py +20 -4
- fal/workflows.py +10 -4
- {fal-1.2.1.dist-info → fal-1.7.2.dist-info}/METADATA +47 -40
- {fal-1.2.1.dist-info → fal-1.7.2.dist-info}/RECORD +45 -30
- {fal-1.2.1.dist-info → fal-1.7.2.dist-info}/WHEEL +1 -1
- {fal-1.2.1.dist-info → fal-1.7.2.dist-info}/entry_points.txt +0 -0
- {fal-1.2.1.dist-info → fal-1.7.2.dist-info}/top_level.txt +0 -0
fal/app.py
CHANGED
|
@@ -1,24 +1,33 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import asyncio
|
|
3
4
|
import inspect
|
|
4
5
|
import json
|
|
5
6
|
import os
|
|
7
|
+
import queue
|
|
6
8
|
import re
|
|
9
|
+
import threading
|
|
7
10
|
import time
|
|
8
11
|
import typing
|
|
9
|
-
from contextlib import asynccontextmanager, contextmanager
|
|
10
|
-
from
|
|
12
|
+
from contextlib import AsyncExitStack, asynccontextmanager, contextmanager
|
|
13
|
+
from dataclasses import dataclass
|
|
14
|
+
from typing import Any, Callable, ClassVar, Literal, TypeVar
|
|
11
15
|
|
|
16
|
+
import fastapi
|
|
17
|
+
import grpc.aio as async_grpc
|
|
12
18
|
import httpx
|
|
13
|
-
from
|
|
19
|
+
from isolate.server import definitions
|
|
14
20
|
|
|
15
21
|
import fal.api
|
|
16
22
|
from fal._serialization import include_modules_from
|
|
17
23
|
from fal.api import RouteSignature
|
|
24
|
+
from fal.exceptions import FalServerlessException, RequestCancelledException
|
|
18
25
|
from fal.logging import get_logger
|
|
19
|
-
from fal.toolkit.file
|
|
26
|
+
from fal.toolkit.file import request_lifecycle_preference
|
|
27
|
+
from fal.toolkit.file.providers.fal import LIFECYCLE_PREFERENCE
|
|
20
28
|
|
|
21
29
|
REALTIME_APP_REQUIREMENTS = ["websockets", "msgpack"]
|
|
30
|
+
REQUEST_ID_KEY = "x-fal-request-id"
|
|
22
31
|
|
|
23
32
|
EndpointT = TypeVar("EndpointT", bound=Callable[..., Any])
|
|
24
33
|
logger = get_logger(__name__)
|
|
@@ -31,6 +40,56 @@ async def _call_any_fn(fn, *args, **kwargs):
|
|
|
31
40
|
return fn(*args, **kwargs)
|
|
32
41
|
|
|
33
42
|
|
|
43
|
+
async def open_isolate_channel(address: str) -> async_grpc.Channel:
    """Open an insecure gRPC channel to ``address`` and wait until it is ready.

    The channel is created with unlimited send/receive message sizes and
    short reconnect/DNS-resolution backoffs. On success the channel is
    returned still open; this function does not close it.

    Args:
        address: Target handed to ``grpc.aio.insecure_channel``.

    Returns:
        The ready channel.

    Raises:
        Exception: If the channel does not become ready within one second.
    """
    _stack = AsyncExitStack()
    channel = await _stack.enter_async_context(
        async_grpc.insecure_channel(
            address,
            options=[
                ("grpc.max_send_message_length", -1),
                ("grpc.max_receive_message_length", -1),
                ("grpc.min_reconnect_backoff_ms", 0),
                ("grpc.max_reconnect_backoff_ms", 100),
                ("grpc.dns_min_time_between_resolutions_ms", 100),
            ],
        )
    )

    try:
        await asyncio.wait_for(channel.channel_ready(), timeout=1)
    except asyncio.TimeoutError as exc:
        await _stack.aclose()
        raise Exception("Timed out trying to connect to local isolate") from exc
    except BaseException:
        # Cancellation (or any other failure) while waiting must not leak the
        # half-open channel either; close it and let the error propagate.
        await _stack.aclose()
        raise

    return channel
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
async def _set_logger_labels(
|
|
69
|
+
logger_labels: dict[str, str], channel: async_grpc.Channel
|
|
70
|
+
):
|
|
71
|
+
try:
|
|
72
|
+
import sys
|
|
73
|
+
|
|
74
|
+
# Flush any prints that were buffered before setting the logger labels
|
|
75
|
+
sys.stderr.flush()
|
|
76
|
+
sys.stdout.flush()
|
|
77
|
+
|
|
78
|
+
isolate = definitions.IsolateStub(channel)
|
|
79
|
+
isolate_request = definitions.SetMetadataRequest(
|
|
80
|
+
# TODO: when submit is shipped, get task_id from an env var
|
|
81
|
+
task_id="RUN",
|
|
82
|
+
metadata=definitions.TaskMetadata(logger_labels=logger_labels),
|
|
83
|
+
)
|
|
84
|
+
res = isolate.SetMetadata(isolate_request)
|
|
85
|
+
code = await res.code()
|
|
86
|
+
assert str(code) == "StatusCode.OK", str(code)
|
|
87
|
+
except BaseException:
|
|
88
|
+
# NOTE hiding this for now to not print on every request
|
|
89
|
+
# logger.debug("Failed to set logger labels", exc_info=True)
|
|
90
|
+
pass
|
|
91
|
+
|
|
92
|
+
|
|
34
93
|
def wrap_app(cls: type[App], **kwargs) -> fal.api.IsolatedFunction:
|
|
35
94
|
include_modules_from(cls)
|
|
36
95
|
|
|
@@ -57,6 +116,7 @@ def wrap_app(cls: type[App], **kwargs) -> fal.api.IsolatedFunction:
|
|
|
57
116
|
kind,
|
|
58
117
|
requirements=cls.requirements,
|
|
59
118
|
machine_type=cls.machine_type,
|
|
119
|
+
num_gpus=cls.num_gpus,
|
|
60
120
|
**cls.host_kwargs,
|
|
61
121
|
**kwargs,
|
|
62
122
|
metadata=metadata,
|
|
@@ -71,19 +131,37 @@ def wrap_app(cls: type[App], **kwargs) -> fal.api.IsolatedFunction:
|
|
|
71
131
|
return fn
|
|
72
132
|
|
|
73
133
|
|
|
134
|
+
@dataclass
class AppClientError(FalServerlessException):
    """Error raised by the app clients in this module when an HTTP call to an
    app returns a response that is not successful."""

    # Human-readable description of the failed call (the raise sites in this
    # module include the HTTP verb, URL, status code and response text).
    message: str
    # HTTP status code of the failing response.
    status_code: int
|
|
138
|
+
|
|
139
|
+
|
|
74
140
|
class EndpointClient:
|
|
75
|
-
def __init__(self, url, endpoint, signature):
|
|
141
|
+
def __init__(self, url, endpoint, signature, timeout: int | None = None):
|
|
76
142
|
self.url = url
|
|
77
143
|
self.endpoint = endpoint
|
|
78
144
|
self.signature = signature
|
|
145
|
+
self.timeout = timeout
|
|
79
146
|
|
|
80
147
|
annotations = endpoint.__annotations__ or {}
|
|
81
148
|
self.return_type = annotations.get("return") or None
|
|
82
149
|
|
|
83
150
|
def __call__(self, data):
|
|
84
151
|
with httpx.Client() as client:
|
|
85
|
-
|
|
86
|
-
resp.
|
|
152
|
+
url = self.url + self.signature.path
|
|
153
|
+
resp = client.post(
|
|
154
|
+
self.url + self.signature.path,
|
|
155
|
+
json=data.dict() if hasattr(data, "dict") else dict(data),
|
|
156
|
+
timeout=self.timeout,
|
|
157
|
+
)
|
|
158
|
+
if not resp.is_success:
|
|
159
|
+
# allow logs to be printed before raising the exception
|
|
160
|
+
time.sleep(1)
|
|
161
|
+
raise AppClientError(
|
|
162
|
+
f"Failed to POST {url}: {resp.status_code} {resp.text}",
|
|
163
|
+
status_code=resp.status_code,
|
|
164
|
+
)
|
|
87
165
|
resp_dict = resp.json()
|
|
88
166
|
|
|
89
167
|
if not self.return_type:
|
|
@@ -93,7 +171,12 @@ class EndpointClient:
|
|
|
93
171
|
|
|
94
172
|
|
|
95
173
|
class AppClient:
|
|
96
|
-
def __init__(
|
|
174
|
+
def __init__(
|
|
175
|
+
self,
|
|
176
|
+
cls,
|
|
177
|
+
url,
|
|
178
|
+
timeout: int | None = None,
|
|
179
|
+
):
|
|
97
180
|
self.url = url
|
|
98
181
|
self.cls = cls
|
|
99
182
|
|
|
@@ -101,29 +184,54 @@ class AppClient:
|
|
|
101
184
|
signature = getattr(endpoint, "route_signature", None)
|
|
102
185
|
if signature is None:
|
|
103
186
|
continue
|
|
104
|
-
|
|
105
|
-
|
|
187
|
+
endpoint_client = EndpointClient(
|
|
188
|
+
self.url,
|
|
189
|
+
endpoint,
|
|
190
|
+
signature,
|
|
191
|
+
timeout=timeout,
|
|
192
|
+
)
|
|
193
|
+
setattr(self, name, endpoint_client)
|
|
106
194
|
|
|
107
195
|
@classmethod
|
|
108
196
|
@contextmanager
|
|
109
197
|
def connect(cls, app_cls):
|
|
110
198
|
app = wrap_app(app_cls)
|
|
111
199
|
info = app.spawn()
|
|
200
|
+
_shutdown_event = threading.Event()
|
|
201
|
+
|
|
202
|
+
def _print_logs():
|
|
203
|
+
while not _shutdown_event.is_set():
|
|
204
|
+
try:
|
|
205
|
+
log = info.logs.get(timeout=0.1)
|
|
206
|
+
except queue.Empty:
|
|
207
|
+
continue
|
|
208
|
+
print(log)
|
|
209
|
+
|
|
210
|
+
_log_printer = threading.Thread(target=_print_logs, daemon=True)
|
|
211
|
+
_log_printer.start()
|
|
212
|
+
|
|
112
213
|
try:
|
|
113
214
|
with httpx.Client() as client:
|
|
114
215
|
retries = 100
|
|
115
|
-
|
|
116
|
-
|
|
216
|
+
for _ in range(retries):
|
|
217
|
+
url = info.url + "/health"
|
|
218
|
+
resp = client.get(url, timeout=60)
|
|
219
|
+
|
|
117
220
|
if resp.is_success:
|
|
118
221
|
break
|
|
119
|
-
elif resp.status_code
|
|
120
|
-
|
|
222
|
+
elif resp.status_code not in (500, 404):
|
|
223
|
+
raise AppClientError(
|
|
224
|
+
f"Failed to GET {url}: {resp.status_code} {resp.text}",
|
|
225
|
+
status_code=resp.status_code,
|
|
226
|
+
)
|
|
121
227
|
time.sleep(0.1)
|
|
122
|
-
retries -= 1
|
|
123
228
|
|
|
124
|
-
|
|
229
|
+
client = cls(app_cls, info.url)
|
|
230
|
+
yield client
|
|
125
231
|
finally:
|
|
126
232
|
info.stream.cancel()
|
|
233
|
+
_shutdown_event.set()
|
|
234
|
+
_log_printer.join()
|
|
127
235
|
|
|
128
236
|
def health(self):
|
|
129
237
|
with httpx.Client() as client:
|
|
@@ -140,9 +248,18 @@ def _to_fal_app_name(name: str) -> str:
|
|
|
140
248
|
return "-".join(part.lower() for part in PART_FINDER_RE.findall(name))
|
|
141
249
|
|
|
142
250
|
|
|
251
|
+
def _print_python_packages() -> None:
|
|
252
|
+
from importlib.metadata import distributions
|
|
253
|
+
|
|
254
|
+
packages = [f"{dist.metadata['Name']}=={dist.version}" for dist in distributions()]
|
|
255
|
+
|
|
256
|
+
print("[debug] Python packages installed:", ", ".join(packages))
|
|
257
|
+
|
|
258
|
+
|
|
143
259
|
class App(fal.api.BaseServable):
|
|
144
260
|
requirements: ClassVar[list[str]] = []
|
|
145
261
|
machine_type: ClassVar[str] = "S"
|
|
262
|
+
num_gpus: ClassVar[int | None] = None
|
|
146
263
|
host_kwargs: ClassVar[dict[str, Any]] = {
|
|
147
264
|
"_scheduler": "nomad",
|
|
148
265
|
"_scheduler_options": {
|
|
@@ -152,12 +269,20 @@ class App(fal.api.BaseServable):
|
|
|
152
269
|
"keep_alive": 60,
|
|
153
270
|
}
|
|
154
271
|
app_name: ClassVar[str]
|
|
272
|
+
app_auth: ClassVar[Literal["private", "public", "shared"]] = "private"
|
|
273
|
+
request_timeout: ClassVar[int | None] = None
|
|
274
|
+
|
|
275
|
+
isolate_channel: async_grpc.Channel | None = None
|
|
155
276
|
|
|
156
277
|
def __init_subclass__(cls, **kwargs):
|
|
157
278
|
app_name = kwargs.pop("name", None) or _to_fal_app_name(cls.__name__)
|
|
158
279
|
parent_settings = getattr(cls, "host_kwargs", {})
|
|
159
280
|
cls.host_kwargs = {**parent_settings, **kwargs}
|
|
160
|
-
|
|
281
|
+
|
|
282
|
+
if cls.request_timeout is not None:
|
|
283
|
+
cls.host_kwargs["request_timeout"] = cls.request_timeout
|
|
284
|
+
|
|
285
|
+
cls.app_name = getattr(cls, "app_name", app_name)
|
|
161
286
|
|
|
162
287
|
if cls.__init__ is not App.__init__:
|
|
163
288
|
raise ValueError(
|
|
@@ -171,6 +296,14 @@ class App(fal.api.BaseServable):
|
|
|
171
296
|
"Running apps through SDK is not implemented yet."
|
|
172
297
|
)
|
|
173
298
|
|
|
299
|
+
@classmethod
def get_endpoints(cls) -> list[str]:
    """Return the route path of each function on the class that has a truthy
    ``route_signature`` attribute, in ``inspect.getmembers`` order."""
    paths = []
    for _name, member in inspect.getmembers(cls, inspect.isfunction):
        route = getattr(member, "route_signature", None)
        if route:
            paths.append(route.path)
    return paths
|
|
306
|
+
|
|
174
307
|
def collect_routes(self) -> dict[RouteSignature, Callable[..., Any]]:
|
|
175
308
|
return {
|
|
176
309
|
signature: endpoint
|
|
@@ -179,7 +312,8 @@ class App(fal.api.BaseServable):
|
|
|
179
312
|
}
|
|
180
313
|
|
|
181
314
|
@asynccontextmanager
|
|
182
|
-
async def lifespan(self, app: FastAPI):
|
|
315
|
+
async def lifespan(self, app: fastapi.FastAPI):
|
|
316
|
+
_print_python_packages()
|
|
183
317
|
await _call_any_fn(self.setup)
|
|
184
318
|
try:
|
|
185
319
|
yield
|
|
@@ -187,7 +321,7 @@ class App(fal.api.BaseServable):
|
|
|
187
321
|
await _call_any_fn(self.teardown)
|
|
188
322
|
|
|
189
323
|
def health(self):
|
|
190
|
-
return {}
|
|
324
|
+
return {"version": self.version}
|
|
191
325
|
|
|
192
326
|
def setup(self):
|
|
193
327
|
"""Setup the application before serving."""
|
|
@@ -195,7 +329,7 @@ class App(fal.api.BaseServable):
|
|
|
195
329
|
def teardown(self):
|
|
196
330
|
"""Teardown the application after serving."""
|
|
197
331
|
|
|
198
|
-
def _add_extra_middlewares(self, app: FastAPI):
|
|
332
|
+
def _add_extra_middlewares(self, app: fastapi.FastAPI):
|
|
199
333
|
@app.middleware("http")
|
|
200
334
|
async def provide_hints_headers(request, call_next):
|
|
201
335
|
response = await call_next(request)
|
|
@@ -216,11 +350,12 @@ class App(fal.api.BaseServable):
|
|
|
216
350
|
|
|
217
351
|
@app.middleware("http")
|
|
218
352
|
async def set_global_object_preference(request, call_next):
|
|
219
|
-
response = await call_next(request)
|
|
220
353
|
try:
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
354
|
+
preference_dict = request_lifecycle_preference(request)
|
|
355
|
+
if preference_dict is not None:
|
|
356
|
+
# This will not work properly for apps with multiplexing enabled
|
|
357
|
+
# we may mix up the preferences between requests
|
|
358
|
+
LIFECYCLE_PREFERENCE.set(preference_dict)
|
|
224
359
|
except Exception:
|
|
225
360
|
from fastapi.logger import logger
|
|
226
361
|
|
|
@@ -228,9 +363,65 @@ class App(fal.api.BaseServable):
|
|
|
228
363
|
"Failed set a global lifecycle preference %s",
|
|
229
364
|
self.__class__.__name__,
|
|
230
365
|
)
|
|
231
|
-
return response
|
|
232
366
|
|
|
233
|
-
|
|
367
|
+
try:
|
|
368
|
+
return await call_next(request)
|
|
369
|
+
finally:
|
|
370
|
+
# We may miss the global preference if there are operations
|
|
371
|
+
# being done in the background that go beyond the request
|
|
372
|
+
LIFECYCLE_PREFERENCE.set(None)
|
|
373
|
+
|
|
374
|
+
@app.middleware("http")
|
|
375
|
+
async def set_request_id(request, call_next):
|
|
376
|
+
# NOTE: Setting request_id is not supported for websocket/realtime endpoints
|
|
377
|
+
|
|
378
|
+
if self.isolate_channel is None:
|
|
379
|
+
grpc_port = os.environ.get("NOMAD_ALLOC_PORT_grpc")
|
|
380
|
+
self.isolate_channel = await open_isolate_channel(
|
|
381
|
+
f"localhost:{grpc_port}"
|
|
382
|
+
)
|
|
383
|
+
|
|
384
|
+
request_id = request.headers.get(REQUEST_ID_KEY)
|
|
385
|
+
if request_id is None:
|
|
386
|
+
# Cut it short
|
|
387
|
+
return await call_next(request)
|
|
388
|
+
|
|
389
|
+
await _set_logger_labels(
|
|
390
|
+
{"fal_request_id": request_id}, channel=self.isolate_channel
|
|
391
|
+
)
|
|
392
|
+
|
|
393
|
+
async def _unset_at_end():
|
|
394
|
+
await _set_logger_labels({}, channel=self.isolate_channel) # type: ignore
|
|
395
|
+
|
|
396
|
+
try:
|
|
397
|
+
response: fastapi.responses.Response = await call_next(request)
|
|
398
|
+
except BaseException:
|
|
399
|
+
await _unset_at_end()
|
|
400
|
+
raise
|
|
401
|
+
else:
|
|
402
|
+
# We need to wait for the entire response to be sent before
|
|
403
|
+
# we can set the logger labels back to the default.
|
|
404
|
+
background_tasks = fastapi.BackgroundTasks()
|
|
405
|
+
background_tasks.add_task(_unset_at_end)
|
|
406
|
+
if response.background:
|
|
407
|
+
# We normally have no background tasks, but we should handle it
|
|
408
|
+
background_tasks.add_task(response.background)
|
|
409
|
+
response.background = background_tasks
|
|
410
|
+
|
|
411
|
+
return response
|
|
412
|
+
|
|
413
|
+
@app.exception_handler(RequestCancelledException)
|
|
414
|
+
async def value_error_exception_handler(
|
|
415
|
+
request, exc: RequestCancelledException
|
|
416
|
+
):
|
|
417
|
+
from fastapi.responses import JSONResponse
|
|
418
|
+
|
|
419
|
+
# A 499 status code is not an officially recognized HTTP status code,
|
|
420
|
+
# but it is sometimes used by servers to indicate that a client has closed
|
|
421
|
+
# the connection without receiving a response
|
|
422
|
+
return JSONResponse({"detail": str(exc)}, 499)
|
|
423
|
+
|
|
424
|
+
def _add_extra_routes(self, app: fastapi.FastAPI):
|
|
234
425
|
@app.get("/health")
|
|
235
426
|
def health():
|
|
236
427
|
return self.health()
|
|
@@ -341,7 +532,10 @@ def _fal_websocket_template(
|
|
|
341
532
|
batch.append(next_input)
|
|
342
533
|
|
|
343
534
|
t0 = loop.time()
|
|
344
|
-
|
|
535
|
+
if inspect.iscoroutinefunction(func):
|
|
536
|
+
output = await func(self, *batch)
|
|
537
|
+
else:
|
|
538
|
+
output = await loop.run_in_executor(None, func, self, *batch) # type: ignore
|
|
345
539
|
total_time = loop.time() - t0
|
|
346
540
|
if not isinstance(output, dict):
|
|
347
541
|
# Handle pydantic output modal
|
fal/apps.py
CHANGED
|
@@ -4,15 +4,19 @@ import json
|
|
|
4
4
|
import time
|
|
5
5
|
from contextlib import contextmanager
|
|
6
6
|
from dataclasses import dataclass, field
|
|
7
|
-
from typing import Any, Iterator
|
|
7
|
+
from typing import TYPE_CHECKING, Any, Iterator
|
|
8
8
|
|
|
9
9
|
import httpx
|
|
10
10
|
|
|
11
11
|
from fal import flags
|
|
12
12
|
from fal.sdk import Credentials, get_default_credentials
|
|
13
13
|
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
from websockets.sync.connection import Connection
|
|
16
|
+
|
|
14
17
|
_QUEUE_URL_FORMAT = f"https://queue.{flags.FAL_RUN_HOST}/{{app_id}}"
|
|
15
18
|
_REALTIME_URL_FORMAT = f"wss://{flags.FAL_RUN_HOST}/{{app_id}}"
|
|
19
|
+
_WS_URL_FORMAT = f"wss://ws.{flags.FAL_RUN_HOST}/{{app_id}}"
|
|
16
20
|
|
|
17
21
|
|
|
18
22
|
def _backwards_compatible_app_id(app_id: str) -> str:
|
|
@@ -97,6 +101,15 @@ class RequestHandle:
|
|
|
97
101
|
else:
|
|
98
102
|
raise ValueError(f"Unknown status: {data['status']}")
|
|
99
103
|
|
|
104
|
+
def cancel(self) -> None:
    """Cancel an async inference request.

    Sends a PUT to this request's ``/cancel`` URL on the queue host and
    raises (via ``raise_for_status``) if the response is an HTTP error.
    """
    queue_base = _QUEUE_URL_FORMAT.format(app_id=self.app_id)
    cancel_url = queue_base + f"/requests/{self.request_id}/cancel"
    reply = _HTTP_CLIENT.put(cancel_url, headers=self._creds.to_headers())
    reply.raise_for_status()
|
|
112
|
+
|
|
100
113
|
def iter_events(
|
|
101
114
|
self,
|
|
102
115
|
*,
|
|
@@ -164,7 +177,8 @@ def submit(app_id: str, arguments: dict[str, Any], *, path: str = "") -> Request
|
|
|
164
177
|
app_id = _backwards_compatible_app_id(app_id)
|
|
165
178
|
url = _QUEUE_URL_FORMAT.format(app_id=app_id)
|
|
166
179
|
if path:
|
|
167
|
-
|
|
180
|
+
_path = path[len("/") :] if path.startswith("/") else path
|
|
181
|
+
url += "/" + _path
|
|
168
182
|
|
|
169
183
|
creds = get_default_credentials()
|
|
170
184
|
|
|
@@ -226,7 +240,8 @@ def _connect(app_id: str, *, path: str = "/realtime") -> Iterator[_RealtimeConne
|
|
|
226
240
|
app_id = _backwards_compatible_app_id(app_id)
|
|
227
241
|
url = _REALTIME_URL_FORMAT.format(app_id=app_id)
|
|
228
242
|
if path:
|
|
229
|
-
|
|
243
|
+
_path = path[len("/") :] if path.startswith("/") else path
|
|
244
|
+
url += "/" + _path
|
|
230
245
|
|
|
231
246
|
creds = get_default_credentials()
|
|
232
247
|
|
|
@@ -234,3 +249,132 @@ def _connect(app_id: str, *, path: str = "/realtime") -> Iterator[_RealtimeConne
|
|
|
234
249
|
url, additional_headers=creds.to_headers(), open_timeout=90
|
|
235
250
|
) as ws:
|
|
236
251
|
yield _RealtimeConnection(ws)
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
class _MetaMessageFound(Exception): ...
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
@dataclass
|
|
258
|
+
class _WSConnection:
|
|
259
|
+
"""A WS connection to an HTTP Fal app."""
|
|
260
|
+
|
|
261
|
+
_ws: Connection
|
|
262
|
+
_buffer: str | bytes | None = None
|
|
263
|
+
|
|
264
|
+
def run(self, arguments: dict[str, Any]) -> bytes:
|
|
265
|
+
"""Run an inference task on the app and return the result."""
|
|
266
|
+
self.send(arguments)
|
|
267
|
+
return self.recv()
|
|
268
|
+
|
|
269
|
+
def send(self, arguments: dict[str, Any]) -> None:
|
|
270
|
+
import json
|
|
271
|
+
|
|
272
|
+
payload = json.dumps(arguments)
|
|
273
|
+
self._ws.send(payload)
|
|
274
|
+
|
|
275
|
+
def _peek(self) -> bytes | str:
|
|
276
|
+
if self._buffer is None:
|
|
277
|
+
self._buffer = self._ws.recv()
|
|
278
|
+
|
|
279
|
+
return self._buffer
|
|
280
|
+
|
|
281
|
+
def _consume(self) -> None:
|
|
282
|
+
if self._buffer is None:
|
|
283
|
+
raise ValueError("No data to consume")
|
|
284
|
+
|
|
285
|
+
self._buffer = None
|
|
286
|
+
|
|
287
|
+
@contextmanager
|
|
288
|
+
def _recv(self) -> Iterator[str | bytes]:
|
|
289
|
+
res = self._peek()
|
|
290
|
+
|
|
291
|
+
yield res
|
|
292
|
+
|
|
293
|
+
# Only consume if it went through the context manager without raising
|
|
294
|
+
self._consume()
|
|
295
|
+
|
|
296
|
+
def _is_meta(self, res: str | bytes) -> bool:
|
|
297
|
+
if not isinstance(res, str):
|
|
298
|
+
return False
|
|
299
|
+
|
|
300
|
+
try:
|
|
301
|
+
json_payload: Any = json.loads(res)
|
|
302
|
+
except json.JSONDecodeError:
|
|
303
|
+
return False
|
|
304
|
+
|
|
305
|
+
if not isinstance(json_payload, dict):
|
|
306
|
+
return False
|
|
307
|
+
|
|
308
|
+
return "type" in json_payload and "request_id" in json_payload
|
|
309
|
+
|
|
310
|
+
def _recv_meta(self, type: str) -> dict[str, Any]:
|
|
311
|
+
with self._recv() as res:
|
|
312
|
+
if not self._is_meta(res):
|
|
313
|
+
raise ValueError(f"Expected a {type} message")
|
|
314
|
+
|
|
315
|
+
json_payload: dict = json.loads(res)
|
|
316
|
+
if json_payload.get("type") != type:
|
|
317
|
+
raise ValueError(f"Expected a {type} message")
|
|
318
|
+
|
|
319
|
+
return json_payload
|
|
320
|
+
|
|
321
|
+
def _recv_response(self) -> Iterator[str | bytes]:
|
|
322
|
+
while True:
|
|
323
|
+
try:
|
|
324
|
+
with self._recv() as res:
|
|
325
|
+
if self._is_meta(res):
|
|
326
|
+
# Raise so we dont consume the message
|
|
327
|
+
raise _MetaMessageFound()
|
|
328
|
+
|
|
329
|
+
yield res
|
|
330
|
+
except _MetaMessageFound:
|
|
331
|
+
break
|
|
332
|
+
|
|
333
|
+
def recv(self) -> bytes:
|
|
334
|
+
start = self._recv_meta("start")
|
|
335
|
+
request_id = start["request_id"]
|
|
336
|
+
|
|
337
|
+
response = b""
|
|
338
|
+
for part in self._recv_response():
|
|
339
|
+
if isinstance(part, str):
|
|
340
|
+
response += part.encode()
|
|
341
|
+
else:
|
|
342
|
+
response += part
|
|
343
|
+
|
|
344
|
+
end = self._recv_meta("end")
|
|
345
|
+
if end["request_id"] != request_id:
|
|
346
|
+
raise ValueError("Mismatched request_id in end message")
|
|
347
|
+
|
|
348
|
+
return response
|
|
349
|
+
|
|
350
|
+
def stream(self) -> Iterator[str | bytes]:
|
|
351
|
+
start = self._recv_meta("start")
|
|
352
|
+
request_id = start["request_id"]
|
|
353
|
+
|
|
354
|
+
yield from self._recv_response()
|
|
355
|
+
|
|
356
|
+
# Make sure we consume the end message
|
|
357
|
+
end = self._recv_meta("end")
|
|
358
|
+
if end["request_id"] != request_id:
|
|
359
|
+
raise ValueError("Mismatched request_id in end message")
|
|
360
|
+
|
|
361
|
+
|
|
362
|
+
@contextmanager
def ws(app_id: str, *, path: str = "") -> Iterator[_WSConnection]:
    """Connect to a HTTP endpoint but with websocket protocol. This is an internal and
    experimental API, use it at your own risk."""

    from websockets.sync import client

    target = _WS_URL_FORMAT.format(app_id=_backwards_compatible_app_id(app_id))
    if path:
        # Drop a single leading slash so exactly one separator is added.
        suffix = path[1:] if path.startswith("/") else path
        target = target + "/" + suffix

    headers = get_default_credentials().to_headers()

    with client.connect(
        target, additional_headers=headers, open_timeout=90
    ) as connection:
        yield _WSConnection(connection)
|
fal/auth/__init__.py
CHANGED
|
@@ -2,22 +2,70 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
import os
|
|
4
4
|
from dataclasses import dataclass, field
|
|
5
|
+
from threading import Lock
|
|
6
|
+
from typing import Optional
|
|
5
7
|
|
|
6
8
|
import click
|
|
7
9
|
|
|
8
10
|
from fal.auth import auth0, local
|
|
11
|
+
from fal.config import Config
|
|
9
12
|
from fal.console import console
|
|
10
13
|
from fal.console.icons import CHECK_ICON
|
|
11
14
|
from fal.exceptions.auth import UnauthenticatedException
|
|
12
15
|
|
|
13
16
|
|
|
17
|
+
class GoogleColabState:
    """Holder for the cached outcome of the Google Colab ``FAL_KEY`` lookup."""

    def __init__(self):
        # Held by ``get_colab_token`` around its check-then-fetch sequence.
        self.lock = Lock()
        # Cached secret value; None until fetched or when the fetch failed.
        self.secret: Optional[str] = None
        # Set to True after the first lookup attempt, so it is not repeated.
        self.is_checked = False
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
_colab_state = GoogleColabState()
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def is_google_colab() -> bool:
    """Return True when the string form of the active IPython shell contains
    ``google.colab``; False when IPython is missing or the lookup raises
    ``NameError``."""
    try:
        from IPython import get_ipython

        shell_description = str(get_ipython())
    except (ModuleNotFoundError, NameError):
        return False

    return "google.colab" in shell_description
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def get_colab_token() -> Optional[str]:
    """Return the ``FAL_KEY`` secret from Google Colab user data, or None.

    Outside Colab, or when ``google.colab`` cannot be imported, None is
    returned. Otherwise the secret is read at most once under the shared
    lock and the (stripped) value, or None if reading it failed, is cached
    on the module-level state and returned on later calls.
    """
    if not is_google_colab():
        return None

    state = _colab_state
    with state.lock:
        if not state.is_checked:  # request access only once
            try:
                from google.colab import userdata  # noqa: I001
            except ImportError:
                return None

            try:
                fetched = userdata.get("FAL_KEY").strip()
            except Exception:
                fetched = None

            state.secret = fetched
            state.is_checked = True

        return state.secret
|
|
58
|
+
|
|
59
|
+
|
|
14
60
|
def key_credentials() -> tuple[str, str] | None:
|
|
15
61
|
# Ignore key credentials when the user forces auth by user.
|
|
16
62
|
if os.environ.get("FAL_FORCE_AUTH_BY_USER") == "1":
|
|
17
63
|
return None
|
|
18
64
|
|
|
19
|
-
|
|
20
|
-
|
|
65
|
+
config = Config()
|
|
66
|
+
|
|
67
|
+
key = os.environ.get("FAL_KEY") or config.get("key") or get_colab_token()
|
|
68
|
+
if key:
|
|
21
69
|
key_id, key_secret = key.split(":", 1)
|
|
22
70
|
return (key_id, key_secret)
|
|
23
71
|
elif "FAL_KEY_ID" in os.environ and "FAL_KEY_SECRET" in os.environ:
|
fal/cli/_utils.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from fal.files import find_project_root, find_pyproject_toml, parse_pyproject_toml
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def is_app_name(app_ref: tuple[str, str | None]) -> bool:
    """Return True when ``app_ref`` has no second component and its first
    component does not end in ``.py``."""
    second = app_ref[1]
    ends_with_py = app_ref[0].endswith(".py")

    if second is not None:
        return False
    return not ends_with_py
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def get_app_data_from_toml(app_name):
    """Look up ``app_name`` under the ``apps`` table parsed from pyproject.toml.

    Returns:
        A 4-tuple ``(ref, auth, deployment_strategy, no_scale)`` where
        ``ref`` is the app's ``ref`` joined onto the project root (as a
        string) and the other three default to ``"private"``,
        ``"recreate"`` and ``False`` when absent.

    Raises:
        ValueError: If no pyproject.toml is found, the app is not listed,
            or the app entry has no ``ref`` key.
    """
    pyproject_path = find_pyproject_toml()
    if pyproject_path is None:
        raise ValueError("No pyproject.toml file found.")

    configured_apps = parse_pyproject_toml(pyproject_path).get("apps", {})

    try:
        settings = configured_apps[app_name]
    except KeyError:
        raise ValueError(f"App {app_name} not found in pyproject.toml")

    try:
        relative_ref = settings["ref"]
    except KeyError:
        raise ValueError(f"App {app_name} does not have a ref key in pyproject.toml")

    # Convert the app_ref to a path relative to the project root
    root_dir, _ = find_project_root(None)

    return (
        str(root_dir / relative_ref),
        settings.get("auth", "private"),
        settings.get("deployment_strategy", "recreate"),
        settings.get("no_scale", False),
    )
|