fal 1.2.1__py3-none-any.whl → 1.7.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of fal might be problematic. Click here for more details.

Files changed (45) hide show
  1. fal/__main__.py +3 -1
  2. fal/_fal_version.py +2 -2
  3. fal/api.py +88 -20
  4. fal/app.py +221 -27
  5. fal/apps.py +147 -3
  6. fal/auth/__init__.py +50 -2
  7. fal/cli/_utils.py +40 -0
  8. fal/cli/apps.py +5 -3
  9. fal/cli/create.py +26 -0
  10. fal/cli/deploy.py +97 -16
  11. fal/cli/main.py +2 -2
  12. fal/cli/parser.py +11 -7
  13. fal/cli/run.py +12 -1
  14. fal/cli/runners.py +44 -0
  15. fal/config.py +23 -0
  16. fal/container.py +1 -1
  17. fal/exceptions/__init__.py +7 -1
  18. fal/exceptions/_base.py +51 -0
  19. fal/exceptions/_cuda.py +44 -0
  20. fal/files.py +81 -0
  21. fal/sdk.py +67 -6
  22. fal/toolkit/file/file.py +103 -13
  23. fal/toolkit/file/providers/fal.py +572 -24
  24. fal/toolkit/file/providers/gcp.py +8 -1
  25. fal/toolkit/file/providers/r2.py +8 -1
  26. fal/toolkit/file/providers/s3.py +80 -0
  27. fal/toolkit/file/types.py +28 -3
  28. fal/toolkit/image/__init__.py +71 -0
  29. fal/toolkit/image/image.py +25 -2
  30. fal/toolkit/image/nsfw_filter/__init__.py +11 -0
  31. fal/toolkit/image/nsfw_filter/env.py +9 -0
  32. fal/toolkit/image/nsfw_filter/inference.py +77 -0
  33. fal/toolkit/image/nsfw_filter/model.py +18 -0
  34. fal/toolkit/image/nsfw_filter/requirements.txt +4 -0
  35. fal/toolkit/image/safety_checker.py +107 -0
  36. fal/toolkit/types.py +140 -0
  37. fal/toolkit/utils/download_utils.py +4 -0
  38. fal/toolkit/utils/retry.py +45 -0
  39. fal/utils.py +20 -4
  40. fal/workflows.py +10 -4
  41. {fal-1.2.1.dist-info → fal-1.7.2.dist-info}/METADATA +47 -40
  42. {fal-1.2.1.dist-info → fal-1.7.2.dist-info}/RECORD +45 -30
  43. {fal-1.2.1.dist-info → fal-1.7.2.dist-info}/WHEEL +1 -1
  44. {fal-1.2.1.dist-info → fal-1.7.2.dist-info}/entry_points.txt +0 -0
  45. {fal-1.2.1.dist-info → fal-1.7.2.dist-info}/top_level.txt +0 -0
fal/app.py CHANGED
@@ -1,24 +1,33 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import asyncio
3
4
  import inspect
4
5
  import json
5
6
  import os
7
+ import queue
6
8
  import re
9
+ import threading
7
10
  import time
8
11
  import typing
9
- from contextlib import asynccontextmanager, contextmanager
10
- from typing import Any, Callable, ClassVar, TypeVar
12
+ from contextlib import AsyncExitStack, asynccontextmanager, contextmanager
13
+ from dataclasses import dataclass
14
+ from typing import Any, Callable, ClassVar, Literal, TypeVar
11
15
 
16
+ import fastapi
17
+ import grpc.aio as async_grpc
12
18
  import httpx
13
- from fastapi import FastAPI
19
+ from isolate.server import definitions
14
20
 
15
21
  import fal.api
16
22
  from fal._serialization import include_modules_from
17
23
  from fal.api import RouteSignature
24
+ from fal.exceptions import FalServerlessException, RequestCancelledException
18
25
  from fal.logging import get_logger
19
- from fal.toolkit.file.providers import fal as fal_provider_module
26
+ from fal.toolkit.file import request_lifecycle_preference
27
+ from fal.toolkit.file.providers.fal import LIFECYCLE_PREFERENCE
20
28
 
21
29
  REALTIME_APP_REQUIREMENTS = ["websockets", "msgpack"]
30
+ REQUEST_ID_KEY = "x-fal-request-id"
22
31
 
23
32
  EndpointT = TypeVar("EndpointT", bound=Callable[..., Any])
24
33
  logger = get_logger(__name__)
@@ -31,6 +40,56 @@ async def _call_any_fn(fn, *args, **kwargs):
31
40
  return fn(*args, **kwargs)
32
41
 
33
42
 
43
+ async def open_isolate_channel(address: str) -> async_grpc.Channel:
44
+ _stack = AsyncExitStack()
45
+ channel = await _stack.enter_async_context(
46
+ async_grpc.insecure_channel(
47
+ address,
48
+ options=[
49
+ ("grpc.max_send_message_length", -1),
50
+ ("grpc.max_receive_message_length", -1),
51
+ ("grpc.min_reconnect_backoff_ms", 0),
52
+ ("grpc.max_reconnect_backoff_ms", 100),
53
+ ("grpc.dns_min_time_between_resolutions_ms", 100),
54
+ ],
55
+ )
56
+ )
57
+
58
+ channel_status = channel.channel_ready()
59
+ try:
60
+ await asyncio.wait_for(channel_status, timeout=1)
61
+ except asyncio.TimeoutError:
62
+ await _stack.aclose()
63
+ raise Exception("Timed out trying to connect to local isolate")
64
+
65
+ return channel
66
+
67
+
68
+ async def _set_logger_labels(
69
+ logger_labels: dict[str, str], channel: async_grpc.Channel
70
+ ):
71
+ try:
72
+ import sys
73
+
74
+ # Flush any prints that were buffered before setting the logger labels
75
+ sys.stderr.flush()
76
+ sys.stdout.flush()
77
+
78
+ isolate = definitions.IsolateStub(channel)
79
+ isolate_request = definitions.SetMetadataRequest(
80
+ # TODO: when submit is shipped, get task_id from an env var
81
+ task_id="RUN",
82
+ metadata=definitions.TaskMetadata(logger_labels=logger_labels),
83
+ )
84
+ res = isolate.SetMetadata(isolate_request)
85
+ code = await res.code()
86
+ assert str(code) == "StatusCode.OK", str(code)
87
+ except BaseException:
88
+ # NOTE hiding this for now to not print on every request
89
+ # logger.debug("Failed to set logger labels", exc_info=True)
90
+ pass
91
+
92
+
34
93
  def wrap_app(cls: type[App], **kwargs) -> fal.api.IsolatedFunction:
35
94
  include_modules_from(cls)
36
95
 
@@ -57,6 +116,7 @@ def wrap_app(cls: type[App], **kwargs) -> fal.api.IsolatedFunction:
57
116
  kind,
58
117
  requirements=cls.requirements,
59
118
  machine_type=cls.machine_type,
119
+ num_gpus=cls.num_gpus,
60
120
  **cls.host_kwargs,
61
121
  **kwargs,
62
122
  metadata=metadata,
@@ -71,19 +131,37 @@ def wrap_app(cls: type[App], **kwargs) -> fal.api.IsolatedFunction:
71
131
  return fn
72
132
 
73
133
 
134
+ @dataclass
135
+ class AppClientError(FalServerlessException):
136
+ message: str
137
+ status_code: int
138
+
139
+
74
140
  class EndpointClient:
75
- def __init__(self, url, endpoint, signature):
141
+ def __init__(self, url, endpoint, signature, timeout: int | None = None):
76
142
  self.url = url
77
143
  self.endpoint = endpoint
78
144
  self.signature = signature
145
+ self.timeout = timeout
79
146
 
80
147
  annotations = endpoint.__annotations__ or {}
81
148
  self.return_type = annotations.get("return") or None
82
149
 
83
150
  def __call__(self, data):
84
151
  with httpx.Client() as client:
85
- resp = client.post(self.url + self.signature.path, json=dict(data))
86
- resp.raise_for_status()
152
+ url = self.url + self.signature.path
153
+ resp = client.post(
154
+ self.url + self.signature.path,
155
+ json=data.dict() if hasattr(data, "dict") else dict(data),
156
+ timeout=self.timeout,
157
+ )
158
+ if not resp.is_success:
159
+ # allow logs to be printed before raising the exception
160
+ time.sleep(1)
161
+ raise AppClientError(
162
+ f"Failed to POST {url}: {resp.status_code} {resp.text}",
163
+ status_code=resp.status_code,
164
+ )
87
165
  resp_dict = resp.json()
88
166
 
89
167
  if not self.return_type:
@@ -93,7 +171,12 @@ class EndpointClient:
93
171
 
94
172
 
95
173
  class AppClient:
96
- def __init__(self, cls, url):
174
+ def __init__(
175
+ self,
176
+ cls,
177
+ url,
178
+ timeout: int | None = None,
179
+ ):
97
180
  self.url = url
98
181
  self.cls = cls
99
182
 
@@ -101,29 +184,54 @@ class AppClient:
101
184
  signature = getattr(endpoint, "route_signature", None)
102
185
  if signature is None:
103
186
  continue
104
-
105
- setattr(self, name, EndpointClient(self.url, endpoint, signature))
187
+ endpoint_client = EndpointClient(
188
+ self.url,
189
+ endpoint,
190
+ signature,
191
+ timeout=timeout,
192
+ )
193
+ setattr(self, name, endpoint_client)
106
194
 
107
195
  @classmethod
108
196
  @contextmanager
109
197
  def connect(cls, app_cls):
110
198
  app = wrap_app(app_cls)
111
199
  info = app.spawn()
200
+ _shutdown_event = threading.Event()
201
+
202
+ def _print_logs():
203
+ while not _shutdown_event.is_set():
204
+ try:
205
+ log = info.logs.get(timeout=0.1)
206
+ except queue.Empty:
207
+ continue
208
+ print(log)
209
+
210
+ _log_printer = threading.Thread(target=_print_logs, daemon=True)
211
+ _log_printer.start()
212
+
112
213
  try:
113
214
  with httpx.Client() as client:
114
215
  retries = 100
115
- while retries:
116
- resp = client.get(info.url + "/health")
216
+ for _ in range(retries):
217
+ url = info.url + "/health"
218
+ resp = client.get(url, timeout=60)
219
+
117
220
  if resp.is_success:
118
221
  break
119
- elif resp.status_code != 500:
120
- resp.raise_for_status()
222
+ elif resp.status_code not in (500, 404):
223
+ raise AppClientError(
224
+ f"Failed to GET {url}: {resp.status_code} {resp.text}",
225
+ status_code=resp.status_code,
226
+ )
121
227
  time.sleep(0.1)
122
- retries -= 1
123
228
 
124
- yield cls(app_cls, info.url)
229
+ client = cls(app_cls, info.url)
230
+ yield client
125
231
  finally:
126
232
  info.stream.cancel()
233
+ _shutdown_event.set()
234
+ _log_printer.join()
127
235
 
128
236
  def health(self):
129
237
  with httpx.Client() as client:
@@ -140,9 +248,18 @@ def _to_fal_app_name(name: str) -> str:
140
248
  return "-".join(part.lower() for part in PART_FINDER_RE.findall(name))
141
249
 
142
250
 
251
+ def _print_python_packages() -> None:
252
+ from importlib.metadata import distributions
253
+
254
+ packages = [f"{dist.metadata['Name']}=={dist.version}" for dist in distributions()]
255
+
256
+ print("[debug] Python packages installed:", ", ".join(packages))
257
+
258
+
143
259
  class App(fal.api.BaseServable):
144
260
  requirements: ClassVar[list[str]] = []
145
261
  machine_type: ClassVar[str] = "S"
262
+ num_gpus: ClassVar[int | None] = None
146
263
  host_kwargs: ClassVar[dict[str, Any]] = {
147
264
  "_scheduler": "nomad",
148
265
  "_scheduler_options": {
@@ -152,12 +269,20 @@ class App(fal.api.BaseServable):
152
269
  "keep_alive": 60,
153
270
  }
154
271
  app_name: ClassVar[str]
272
+ app_auth: ClassVar[Literal["private", "public", "shared"]] = "private"
273
+ request_timeout: ClassVar[int | None] = None
274
+
275
+ isolate_channel: async_grpc.Channel | None = None
155
276
 
156
277
  def __init_subclass__(cls, **kwargs):
157
278
  app_name = kwargs.pop("name", None) or _to_fal_app_name(cls.__name__)
158
279
  parent_settings = getattr(cls, "host_kwargs", {})
159
280
  cls.host_kwargs = {**parent_settings, **kwargs}
160
- cls.app_name = app_name
281
+
282
+ if cls.request_timeout is not None:
283
+ cls.host_kwargs["request_timeout"] = cls.request_timeout
284
+
285
+ cls.app_name = getattr(cls, "app_name", app_name)
161
286
 
162
287
  if cls.__init__ is not App.__init__:
163
288
  raise ValueError(
@@ -171,6 +296,14 @@ class App(fal.api.BaseServable):
171
296
  "Running apps through SDK is not implemented yet."
172
297
  )
173
298
 
299
+ @classmethod
300
+ def get_endpoints(cls) -> list[str]:
301
+ return [
302
+ signature.path
303
+ for _, endpoint in inspect.getmembers(cls, inspect.isfunction)
304
+ if (signature := getattr(endpoint, "route_signature", None))
305
+ ]
306
+
174
307
  def collect_routes(self) -> dict[RouteSignature, Callable[..., Any]]:
175
308
  return {
176
309
  signature: endpoint
@@ -179,7 +312,8 @@ class App(fal.api.BaseServable):
179
312
  }
180
313
 
181
314
  @asynccontextmanager
182
- async def lifespan(self, app: FastAPI):
315
+ async def lifespan(self, app: fastapi.FastAPI):
316
+ _print_python_packages()
183
317
  await _call_any_fn(self.setup)
184
318
  try:
185
319
  yield
@@ -187,7 +321,7 @@ class App(fal.api.BaseServable):
187
321
  await _call_any_fn(self.teardown)
188
322
 
189
323
  def health(self):
190
- return {}
324
+ return {"version": self.version}
191
325
 
192
326
  def setup(self):
193
327
  """Setup the application before serving."""
@@ -195,7 +329,7 @@ class App(fal.api.BaseServable):
195
329
  def teardown(self):
196
330
  """Teardown the application after serving."""
197
331
 
198
- def _add_extra_middlewares(self, app: FastAPI):
332
+ def _add_extra_middlewares(self, app: fastapi.FastAPI):
199
333
  @app.middleware("http")
200
334
  async def provide_hints_headers(request, call_next):
201
335
  response = await call_next(request)
@@ -216,11 +350,12 @@ class App(fal.api.BaseServable):
216
350
 
217
351
  @app.middleware("http")
218
352
  async def set_global_object_preference(request, call_next):
219
- response = await call_next(request)
220
353
  try:
221
- fal_provider_module.GLOBAL_LIFECYCLE_PREFERENCE = request.headers.get(
222
- "X-Fal-Object-Lifecycle-Preference"
223
- )
354
+ preference_dict = request_lifecycle_preference(request)
355
+ if preference_dict is not None:
356
+ # This will not work properly for apps with multiplexing enabled
357
+ # we may mix up the preferences between requests
358
+ LIFECYCLE_PREFERENCE.set(preference_dict)
224
359
  except Exception:
225
360
  from fastapi.logger import logger
226
361
 
@@ -228,9 +363,65 @@ class App(fal.api.BaseServable):
228
363
  "Failed set a global lifecycle preference %s",
229
364
  self.__class__.__name__,
230
365
  )
231
- return response
232
366
 
233
- def _add_extra_routes(self, app: FastAPI):
367
+ try:
368
+ return await call_next(request)
369
+ finally:
370
+ # We may miss the global preference if there are operations
371
+ # being done in the background that go beyond the request
372
+ LIFECYCLE_PREFERENCE.set(None)
373
+
374
+ @app.middleware("http")
375
+ async def set_request_id(request, call_next):
376
+ # NOTE: Setting request_id is not supported for websocket/realtime endpoints
377
+
378
+ if self.isolate_channel is None:
379
+ grpc_port = os.environ.get("NOMAD_ALLOC_PORT_grpc")
380
+ self.isolate_channel = await open_isolate_channel(
381
+ f"localhost:{grpc_port}"
382
+ )
383
+
384
+ request_id = request.headers.get(REQUEST_ID_KEY)
385
+ if request_id is None:
386
+ # Cut it short
387
+ return await call_next(request)
388
+
389
+ await _set_logger_labels(
390
+ {"fal_request_id": request_id}, channel=self.isolate_channel
391
+ )
392
+
393
+ async def _unset_at_end():
394
+ await _set_logger_labels({}, channel=self.isolate_channel) # type: ignore
395
+
396
+ try:
397
+ response: fastapi.responses.Response = await call_next(request)
398
+ except BaseException:
399
+ await _unset_at_end()
400
+ raise
401
+ else:
402
+ # We need to wait for the entire response to be sent before
403
+ # we can set the logger labels back to the default.
404
+ background_tasks = fastapi.BackgroundTasks()
405
+ background_tasks.add_task(_unset_at_end)
406
+ if response.background:
407
+ # We normally have no background tasks, but we should handle it
408
+ background_tasks.add_task(response.background)
409
+ response.background = background_tasks
410
+
411
+ return response
412
+
413
+ @app.exception_handler(RequestCancelledException)
414
+ async def value_error_exception_handler(
415
+ request, exc: RequestCancelledException
416
+ ):
417
+ from fastapi.responses import JSONResponse
418
+
419
+ # A 499 status code is not an officially recognized HTTP status code,
420
+ # but it is sometimes used by servers to indicate that a client has closed
421
+ # the connection without receiving a response
422
+ return JSONResponse({"detail": str(exc)}, 499)
423
+
424
+ def _add_extra_routes(self, app: fastapi.FastAPI):
234
425
  @app.get("/health")
235
426
  def health():
236
427
  return self.health()
@@ -341,7 +532,10 @@ def _fal_websocket_template(
341
532
  batch.append(next_input)
342
533
 
343
534
  t0 = loop.time()
344
- output = await loop.run_in_executor(None, func, self, *batch) # type: ignore
535
+ if inspect.iscoroutinefunction(func):
536
+ output = await func(self, *batch)
537
+ else:
538
+ output = await loop.run_in_executor(None, func, self, *batch) # type: ignore
345
539
  total_time = loop.time() - t0
346
540
  if not isinstance(output, dict):
347
541
  # Handle pydantic output modal
fal/apps.py CHANGED
@@ -4,15 +4,19 @@ import json
4
4
  import time
5
5
  from contextlib import contextmanager
6
6
  from dataclasses import dataclass, field
7
- from typing import Any, Iterator
7
+ from typing import TYPE_CHECKING, Any, Iterator
8
8
 
9
9
  import httpx
10
10
 
11
11
  from fal import flags
12
12
  from fal.sdk import Credentials, get_default_credentials
13
13
 
14
+ if TYPE_CHECKING:
15
+ from websockets.sync.connection import Connection
16
+
14
17
  _QUEUE_URL_FORMAT = f"https://queue.{flags.FAL_RUN_HOST}/{{app_id}}"
15
18
  _REALTIME_URL_FORMAT = f"wss://{flags.FAL_RUN_HOST}/{{app_id}}"
19
+ _WS_URL_FORMAT = f"wss://ws.{flags.FAL_RUN_HOST}/{{app_id}}"
16
20
 
17
21
 
18
22
  def _backwards_compatible_app_id(app_id: str) -> str:
@@ -97,6 +101,15 @@ class RequestHandle:
97
101
  else:
98
102
  raise ValueError(f"Unknown status: {data['status']}")
99
103
 
104
+ def cancel(self) -> None:
105
+ """Cancel an async inference request."""
106
+ url = (
107
+ _QUEUE_URL_FORMAT.format(app_id=self.app_id)
108
+ + f"/requests/{self.request_id}/cancel"
109
+ )
110
+ response = _HTTP_CLIENT.put(url, headers=self._creds.to_headers())
111
+ response.raise_for_status()
112
+
100
113
  def iter_events(
101
114
  self,
102
115
  *,
@@ -164,7 +177,8 @@ def submit(app_id: str, arguments: dict[str, Any], *, path: str = "") -> Request
164
177
  app_id = _backwards_compatible_app_id(app_id)
165
178
  url = _QUEUE_URL_FORMAT.format(app_id=app_id)
166
179
  if path:
167
- url += "/" + path.removeprefix("/")
180
+ _path = path[len("/") :] if path.startswith("/") else path
181
+ url += "/" + _path
168
182
 
169
183
  creds = get_default_credentials()
170
184
 
@@ -226,7 +240,8 @@ def _connect(app_id: str, *, path: str = "/realtime") -> Iterator[_RealtimeConne
226
240
  app_id = _backwards_compatible_app_id(app_id)
227
241
  url = _REALTIME_URL_FORMAT.format(app_id=app_id)
228
242
  if path:
229
- url += "/" + path.removeprefix("/")
243
+ _path = path[len("/") :] if path.startswith("/") else path
244
+ url += "/" + _path
230
245
 
231
246
  creds = get_default_credentials()
232
247
 
@@ -234,3 +249,132 @@ def _connect(app_id: str, *, path: str = "/realtime") -> Iterator[_RealtimeConne
234
249
  url, additional_headers=creds.to_headers(), open_timeout=90
235
250
  ) as ws:
236
251
  yield _RealtimeConnection(ws)
252
+
253
+
254
+ class _MetaMessageFound(Exception): ...
255
+
256
+
257
+ @dataclass
258
+ class _WSConnection:
259
+ """A WS connection to an HTTP Fal app."""
260
+
261
+ _ws: Connection
262
+ _buffer: str | bytes | None = None
263
+
264
+ def run(self, arguments: dict[str, Any]) -> bytes:
265
+ """Run an inference task on the app and return the result."""
266
+ self.send(arguments)
267
+ return self.recv()
268
+
269
+ def send(self, arguments: dict[str, Any]) -> None:
270
+ import json
271
+
272
+ payload = json.dumps(arguments)
273
+ self._ws.send(payload)
274
+
275
+ def _peek(self) -> bytes | str:
276
+ if self._buffer is None:
277
+ self._buffer = self._ws.recv()
278
+
279
+ return self._buffer
280
+
281
+ def _consume(self) -> None:
282
+ if self._buffer is None:
283
+ raise ValueError("No data to consume")
284
+
285
+ self._buffer = None
286
+
287
+ @contextmanager
288
+ def _recv(self) -> Iterator[str | bytes]:
289
+ res = self._peek()
290
+
291
+ yield res
292
+
293
+ # Only consume if it went through the context manager without raising
294
+ self._consume()
295
+
296
+ def _is_meta(self, res: str | bytes) -> bool:
297
+ if not isinstance(res, str):
298
+ return False
299
+
300
+ try:
301
+ json_payload: Any = json.loads(res)
302
+ except json.JSONDecodeError:
303
+ return False
304
+
305
+ if not isinstance(json_payload, dict):
306
+ return False
307
+
308
+ return "type" in json_payload and "request_id" in json_payload
309
+
310
+ def _recv_meta(self, type: str) -> dict[str, Any]:
311
+ with self._recv() as res:
312
+ if not self._is_meta(res):
313
+ raise ValueError(f"Expected a {type} message")
314
+
315
+ json_payload: dict = json.loads(res)
316
+ if json_payload.get("type") != type:
317
+ raise ValueError(f"Expected a {type} message")
318
+
319
+ return json_payload
320
+
321
+ def _recv_response(self) -> Iterator[str | bytes]:
322
+ while True:
323
+ try:
324
+ with self._recv() as res:
325
+ if self._is_meta(res):
326
+ # Raise so we dont consume the message
327
+ raise _MetaMessageFound()
328
+
329
+ yield res
330
+ except _MetaMessageFound:
331
+ break
332
+
333
+ def recv(self) -> bytes:
334
+ start = self._recv_meta("start")
335
+ request_id = start["request_id"]
336
+
337
+ response = b""
338
+ for part in self._recv_response():
339
+ if isinstance(part, str):
340
+ response += part.encode()
341
+ else:
342
+ response += part
343
+
344
+ end = self._recv_meta("end")
345
+ if end["request_id"] != request_id:
346
+ raise ValueError("Mismatched request_id in end message")
347
+
348
+ return response
349
+
350
+ def stream(self) -> Iterator[str | bytes]:
351
+ start = self._recv_meta("start")
352
+ request_id = start["request_id"]
353
+
354
+ yield from self._recv_response()
355
+
356
+ # Make sure we consume the end message
357
+ end = self._recv_meta("end")
358
+ if end["request_id"] != request_id:
359
+ raise ValueError("Mismatched request_id in end message")
360
+
361
+
362
+ @contextmanager
363
+ def ws(app_id: str, *, path: str = "") -> Iterator[_WSConnection]:
364
+ """Connect to a HTTP endpoint but with websocket protocol. This is an internal and
365
+ experimental API, use it at your own risk."""
366
+
367
+ from websockets.sync import client
368
+
369
+ app_id = _backwards_compatible_app_id(app_id)
370
+ url = _WS_URL_FORMAT.format(app_id=app_id)
371
+ if path:
372
+ _path = path[len("/") :] if path.startswith("/") else path
373
+ url += "/" + _path
374
+
375
+ creds = get_default_credentials()
376
+
377
+ with client.connect(
378
+ url, additional_headers=creds.to_headers(), open_timeout=90
379
+ ) as ws:
380
+ yield _WSConnection(ws)
fal/auth/__init__.py CHANGED
@@ -2,22 +2,70 @@ from __future__ import annotations
2
2
 
3
3
  import os
4
4
  from dataclasses import dataclass, field
5
+ from threading import Lock
6
+ from typing import Optional
5
7
 
6
8
  import click
7
9
 
8
10
  from fal.auth import auth0, local
11
+ from fal.config import Config
9
12
  from fal.console import console
10
13
  from fal.console.icons import CHECK_ICON
11
14
  from fal.exceptions.auth import UnauthenticatedException
12
15
 
13
16
 
17
+ class GoogleColabState:
18
+ def __init__(self):
19
+ self.is_checked = False
20
+ self.lock = Lock()
21
+ self.secret: Optional[str] = None
22
+
23
+
24
+ _colab_state = GoogleColabState()
25
+
26
+
27
+ def is_google_colab() -> bool:
28
+ try:
29
+ from IPython import get_ipython
30
+
31
+ return "google.colab" in str(get_ipython())
32
+ except ModuleNotFoundError:
33
+ return False
34
+ except NameError:
35
+ return False
36
+
37
+
38
+ def get_colab_token() -> Optional[str]:
39
+ if not is_google_colab():
40
+ return None
41
+ with _colab_state.lock:
42
+ if _colab_state.is_checked: # request access only once
43
+ return _colab_state.secret
44
+
45
+ try:
46
+ from google.colab import userdata # noqa: I001
47
+ except ImportError:
48
+ return None
49
+
50
+ try:
51
+ token = userdata.get("FAL_KEY")
52
+ _colab_state.secret = token.strip()
53
+ except Exception:
54
+ _colab_state.secret = None
55
+
56
+ _colab_state.is_checked = True
57
+ return _colab_state.secret
58
+
59
+
14
60
  def key_credentials() -> tuple[str, str] | None:
15
61
  # Ignore key credentials when the user forces auth by user.
16
62
  if os.environ.get("FAL_FORCE_AUTH_BY_USER") == "1":
17
63
  return None
18
64
 
19
- if "FAL_KEY" in os.environ:
20
- key = os.environ["FAL_KEY"]
65
+ config = Config()
66
+
67
+ key = os.environ.get("FAL_KEY") or config.get("key") or get_colab_token()
68
+ if key:
21
69
  key_id, key_secret = key.split(":", 1)
22
70
  return (key_id, key_secret)
23
71
  elif "FAL_KEY_ID" in os.environ and "FAL_KEY_SECRET" in os.environ:
fal/cli/_utils.py ADDED
@@ -0,0 +1,40 @@
1
+ from __future__ import annotations
2
+
3
+ from fal.files import find_project_root, find_pyproject_toml, parse_pyproject_toml
4
+
5
+
6
+ def is_app_name(app_ref: tuple[str, str | None]) -> bool:
7
+ is_single_file = app_ref[1] is None
8
+ is_python_file = app_ref[0].endswith(".py")
9
+
10
+ return is_single_file and not is_python_file
11
+
12
+
13
+ def get_app_data_from_toml(app_name):
14
+ toml_path = find_pyproject_toml()
15
+
16
+ if toml_path is None:
17
+ raise ValueError("No pyproject.toml file found.")
18
+
19
+ fal_data = parse_pyproject_toml(toml_path)
20
+ apps = fal_data.get("apps", {})
21
+
22
+ try:
23
+ app_data = apps[app_name]
24
+ except KeyError:
25
+ raise ValueError(f"App {app_name} not found in pyproject.toml")
26
+
27
+ try:
28
+ app_ref = app_data["ref"]
29
+ except KeyError:
30
+ raise ValueError(f"App {app_name} does not have a ref key in pyproject.toml")
31
+
32
+ # Convert the app_ref to a path relative to the project root
33
+ project_root, _ = find_project_root(None)
34
+ app_ref = str(project_root / app_ref)
35
+
36
+ app_auth = app_data.get("auth", "private")
37
+ app_deployment_strategy = app_data.get("deployment_strategy", "recreate")
38
+ app_no_scale = app_data.get("no_scale", False)
39
+
40
+ return app_ref, app_auth, app_deployment_strategy, app_no_scale