fal 0.12.1__py3-none-any.whl → 0.12.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of fal might be problematic. Click here for more details.

Files changed (116) hide show
  1. fal/__init__.py +12 -3
  2. fal/_serialization.py +18 -0
  3. fal/api.py +140 -59
  4. fal/app.py +309 -86
  5. fal/apps.py +92 -8
  6. fal/auth/__init__.py +20 -1
  7. fal/auth/auth0.py +32 -22
  8. fal/cli.py +34 -52
  9. fal/env.py +0 -4
  10. fal/exceptions/handlers.py +3 -2
  11. fal/flags.py +5 -0
  12. fal/logging/__init__.py +0 -2
  13. fal/logging/trace.py +8 -1
  14. fal/logging/user.py +2 -1
  15. fal/rest_client.py +2 -2
  16. fal/sdk.py +46 -31
  17. fal/sync.py +3 -3
  18. fal/toolkit/__init__.py +18 -1
  19. fal/toolkit/file/file.py +98 -11
  20. fal/toolkit/file/providers/fal.py +43 -2
  21. fal/toolkit/file/types.py +1 -1
  22. fal/toolkit/image/image.py +26 -4
  23. fal/toolkit/optimize.py +50 -0
  24. fal/toolkit/utils/download_utils.py +59 -13
  25. {fal-0.12.1.dist-info → fal-0.12.3.dist-info}/METADATA +7 -7
  26. fal-0.12.3.dist-info/RECORD +66 -0
  27. openapi_fal_rest/models/__init__.py +2 -70
  28. openapi_fal_rest/models/customer_details.py +26 -0
  29. openapi_fal_rest/models/lock_reason.py +16 -0
  30. fal/logging/datadog.py +0 -77
  31. fal-0.12.1.dist-info/RECORD +0 -147
  32. openapi_fal_rest/api/admin/get_invoice_users.py +0 -142
  33. openapi_fal_rest/api/admin/get_usage_per_user.py +0 -199
  34. openapi_fal_rest/api/admin/handle_user_lock.py +0 -191
  35. openapi_fal_rest/api/admin/set_billing_type.py +0 -186
  36. openapi_fal_rest/api/applications/get_status_applications_app_user_id_app_alias_or_id_status_get.py +0 -179
  37. openapi_fal_rest/api/billing/delete_payment_method.py +0 -162
  38. openapi_fal_rest/api/billing/get_checkout_page.py +0 -198
  39. openapi_fal_rest/api/billing/get_setup_intent_key.py +0 -141
  40. openapi_fal_rest/api/billing/get_user_invoices.py +0 -152
  41. openapi_fal_rest/api/billing/get_user_payment_methods.py +0 -152
  42. openapi_fal_rest/api/billing/get_user_price.py +0 -186
  43. openapi_fal_rest/api/billing/get_user_spending.py +0 -192
  44. openapi_fal_rest/api/billing/handle_stripe_webhook.py +0 -173
  45. openapi_fal_rest/api/billing/upcoming_invoice.py +0 -143
  46. openapi_fal_rest/api/billing/update_customer_budget.py +0 -183
  47. openapi_fal_rest/api/files/delete.py +0 -162
  48. openapi_fal_rest/api/files/download.py +0 -162
  49. openapi_fal_rest/api/files/file_exists.py +0 -183
  50. openapi_fal_rest/api/files/list_directory.py +0 -173
  51. openapi_fal_rest/api/files/list_root.py +0 -152
  52. openapi_fal_rest/api/files/upload_from_url.py +0 -179
  53. openapi_fal_rest/api/health/__init__.py +0 -0
  54. openapi_fal_rest/api/health/check.py +0 -136
  55. openapi_fal_rest/api/keys/__init__.py +0 -0
  56. openapi_fal_rest/api/keys/create_key.py +0 -188
  57. openapi_fal_rest/api/keys/delete_key.py +0 -162
  58. openapi_fal_rest/api/keys/list_keys.py +0 -152
  59. openapi_fal_rest/api/logs/__init__.py +0 -0
  60. openapi_fal_rest/api/logs/list_since.py +0 -224
  61. openapi_fal_rest/api/requests/__init__.py +0 -0
  62. openapi_fal_rest/api/requests/requests.py +0 -247
  63. openapi_fal_rest/api/storage/__init__.py +0 -0
  64. openapi_fal_rest/api/storage/get_file_link.py +0 -200
  65. openapi_fal_rest/api/storage/initiate_upload.py +0 -172
  66. openapi_fal_rest/api/storage/upload_file.py +0 -172
  67. openapi_fal_rest/api/tokens/__init__.py +0 -0
  68. openapi_fal_rest/api/tokens/create_token.py +0 -166
  69. openapi_fal_rest/api/usage/__init__.py +0 -0
  70. openapi_fal_rest/api/usage/get_custom_usage_per_machine.py +0 -203
  71. openapi_fal_rest/api/usage/get_gateway_request_stats.py +0 -247
  72. openapi_fal_rest/api/usage/get_gateway_request_stats_by_time.py +0 -236
  73. openapi_fal_rest/api/usage/get_gateway_stats_for_yesterday.py +0 -152
  74. openapi_fal_rest/api/usage/get_shared_usage_per_app.py +0 -203
  75. openapi_fal_rest/api/usage/get_usage_records.py +0 -253
  76. openapi_fal_rest/api/usage/per_machine_usage.py +0 -218
  77. openapi_fal_rest/api/usage/per_machine_usage_details.py +0 -173
  78. openapi_fal_rest/api/users/__init__.py +0 -0
  79. openapi_fal_rest/api/users/handle_user_registration.py +0 -228
  80. openapi_fal_rest/models/billing_type.py +0 -9
  81. openapi_fal_rest/models/body_create_token.py +0 -68
  82. openapi_fal_rest/models/body_upload_file.py +0 -75
  83. openapi_fal_rest/models/file_spec.py +0 -110
  84. openapi_fal_rest/models/gateway_stats_by_time.py +0 -115
  85. openapi_fal_rest/models/gateway_usage_stats.py +0 -147
  86. openapi_fal_rest/models/get_gateway_request_stats_by_time_response_get_gateway_request_stats_by_time.py +0 -70
  87. openapi_fal_rest/models/grouped_usage_detail.py +0 -85
  88. openapi_fal_rest/models/handle_stripe_webhook_response_handle_stripe_webhook.py +0 -43
  89. openapi_fal_rest/models/initiate_upload_info.py +0 -64
  90. openapi_fal_rest/models/invoice.py +0 -129
  91. openapi_fal_rest/models/invoice_item.py +0 -85
  92. openapi_fal_rest/models/key_scope.py +0 -9
  93. openapi_fal_rest/models/log_entry.py +0 -104
  94. openapi_fal_rest/models/log_entry_labels.py +0 -43
  95. openapi_fal_rest/models/new_user_key.py +0 -64
  96. openapi_fal_rest/models/payment_method.py +0 -96
  97. openapi_fal_rest/models/per_app_usage_detail.py +0 -88
  98. openapi_fal_rest/models/persisted_usage_record.py +0 -118
  99. openapi_fal_rest/models/persisted_usage_record_meta.py +0 -43
  100. openapi_fal_rest/models/presigned_upload_url.py +0 -64
  101. openapi_fal_rest/models/request_io.py +0 -112
  102. openapi_fal_rest/models/request_io_json_input.py +0 -43
  103. openapi_fal_rest/models/request_io_json_output.py +0 -43
  104. openapi_fal_rest/models/run_type.py +0 -9
  105. openapi_fal_rest/models/stats_timeframe.py +0 -12
  106. openapi_fal_rest/models/status.py +0 -82
  107. openapi_fal_rest/models/status_health.py +0 -10
  108. openapi_fal_rest/models/uploaded_file_result.py +0 -64
  109. openapi_fal_rest/models/url_file_upload.py +0 -57
  110. openapi_fal_rest/models/usage_per_machine_type.py +0 -115
  111. openapi_fal_rest/models/usage_per_user.py +0 -71
  112. openapi_fal_rest/models/usage_run_detail.py +0 -73
  113. openapi_fal_rest/models/user_key_info.py +0 -84
  114. /openapi_fal_rest/api/admin/__init__.py → /fal/py.typed +0 -0
  115. {fal-0.12.1.dist-info → fal-0.12.3.dist-info}/WHEEL +0 -0
  116. {fal-0.12.1.dist-info → fal-0.12.3.dist-info}/entry_points.txt +0 -0
fal/app.py CHANGED
@@ -1,28 +1,50 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import inspect
4
+ import json
4
5
  import os
5
- import fal.api
6
- from fal.toolkit import mainify
6
+ import typing
7
+ from contextlib import asynccontextmanager
8
+ from typing import Any, Callable, ClassVar, TypeVar
9
+
7
10
  from fastapi import FastAPI
8
- from typing import Any, NamedTuple, Callable, TypeVar, ClassVar
11
+
12
+ import fal.api
13
+ from fal._serialization import add_serialization_listeners_for
14
+ from fal.api import RouteSignature
9
15
  from fal.logging import get_logger
16
+ from fal.toolkit import mainify
17
+
18
+ REALTIME_APP_REQUIREMENTS = ["websockets", "msgpack"]
10
19
 
11
20
  EndpointT = TypeVar("EndpointT", bound=Callable[..., Any])
12
21
  logger = get_logger(__name__)
13
22
 
14
23
 
24
+ async def _call_any_fn(fn, *args, **kwargs):
25
+ if inspect.iscoroutinefunction(fn):
26
+ return await fn(*args, **kwargs)
27
+ else:
28
+ return fn(*args, **kwargs)
29
+
30
+
15
31
  def wrap_app(cls: type[App], **kwargs) -> fal.api.IsolatedFunction:
32
+ add_serialization_listeners_for(cls)
33
+
16
34
  def initialize_and_serve():
17
35
  app = cls()
18
36
  app.serve()
19
37
 
38
+ metadata = {}
20
39
  try:
21
40
  app = cls(_allow_init=True)
22
- metadata = app.openapi()
41
+ metadata["openapi"] = app.openapi()
23
42
  except Exception as exc:
24
43
  logger.warning("Failed to build OpenAPI specification for %s", cls.__name__)
25
- metadata = {}
44
+ realtime_app = False
45
+ else:
46
+ routes = app.collect_routes()
47
+ realtime_app = any(route.is_websocket for route in routes)
26
48
 
27
49
  wrapper = fal.api.function(
28
50
  "virtualenv",
@@ -31,27 +53,26 @@ def wrap_app(cls: type[App], **kwargs) -> fal.api.IsolatedFunction:
31
53
  **cls.host_kwargs,
32
54
  **kwargs,
33
55
  metadata=metadata,
34
- serve=True,
35
- )
36
- return wrapper(initialize_and_serve).on(
37
- serve=False,
38
56
  exposed_port=8080,
57
+ serve=False,
39
58
  )
59
+ fn = wrapper(initialize_and_serve)
60
+ fn.options.add_requirements(fal.api.SERVE_REQUIREMENTS)
61
+ if realtime_app:
62
+ fn.options.add_requirements(REALTIME_APP_REQUIREMENTS)
40
63
 
41
-
42
- @mainify
43
- class RouteSignature(NamedTuple):
44
- path: str
64
+ return fn
45
65
 
46
66
 
47
67
  @mainify
48
- class App:
68
+ class App(fal.api.BaseServable):
49
69
  requirements: ClassVar[list[str]] = []
50
70
  machine_type: ClassVar[str] = "S"
51
71
  host_kwargs: ClassVar[dict[str, Any]] = {}
52
72
 
53
73
  def __init_subclass__(cls, **kwargs):
54
- cls.host_kwargs = kwargs
74
+ parent_settings = getattr(cls, "host_kwargs", {})
75
+ cls.host_kwargs = {**parent_settings, **kwargs}
55
76
 
56
77
  if cls.__init__ is not App.__init__:
57
78
  raise ValueError(
@@ -65,98 +86,300 @@ class App:
65
86
  "Running apps through SDK is not implemented yet."
66
87
  )
67
88
 
68
- def setup(self):
69
- """Setup the application before serving."""
89
+ def collect_routes(self) -> dict[RouteSignature, Callable[..., Any]]:
90
+ return {
91
+ signature: endpoint
92
+ for _, endpoint in inspect.getmembers(self, inspect.ismethod)
93
+ if (signature := getattr(endpoint, "route_signature", None))
94
+ }
70
95
 
71
- def serve(self) -> None:
72
- import uvicorn
96
+ @asynccontextmanager
97
+ async def lifespan(self, app: FastAPI):
98
+ await _call_any_fn(self.setup)
99
+ try:
100
+ yield
101
+ finally:
102
+ await _call_any_fn(self.teardown)
73
103
 
74
- app = self._build_app()
75
- self.setup()
76
- uvicorn.run(app, host="0.0.0.0", port=8080)
104
+ def setup(self):
105
+ """Setup the application before serving."""
77
106
 
78
- def _build_app(self) -> FastAPI:
79
- from fastapi import FastAPI
80
- from fastapi.middleware.cors import CORSMiddleware
107
+ def teardown(self):
108
+ """Teardown the application after serving."""
109
+
110
+ def _add_extra_middlewares(self, app: FastAPI):
111
+ @app.middleware("http")
112
+ async def provide_hints_headers(request, call_next):
113
+ response = await call_next(request)
114
+ try:
115
+ response.headers["X-Fal-Runner-Hints"] = ",".join(self.provide_hints())
116
+ except NotImplementedError:
117
+ # This lets us differentiate between apps that don't provide hints
118
+ # and apps that provide empty hints.
119
+ pass
120
+ except Exception:
121
+ from fastapi.logger import logger
122
+
123
+ logger.exception(
124
+ "Failed to provide hints for %s",
125
+ self.__class__.__name__,
126
+ )
127
+ return response
128
+
129
+ def provide_hints(self) -> list[str]:
130
+ """Provide hints for routing the application."""
131
+ raise NotImplementedError
81
132
 
82
- _app = FastAPI()
83
133
 
84
- _app.add_middleware(
85
- CORSMiddleware,
86
- allow_credentials=True,
87
- allow_headers=("*"),
88
- allow_methods=("*"),
89
- allow_origins=("*"),
90
- )
134
+ @mainify
135
+ def endpoint(
136
+ path: str, *, is_websocket: bool = False
137
+ ) -> Callable[[EndpointT], EndpointT]:
138
+ """Designate the decorated function as an application endpoint."""
91
139
 
92
- routes: dict[RouteSignature, Callable[..., Any]] = {
93
- signature: endpoint
94
- for _, endpoint in inspect.getmembers(self, inspect.ismethod)
95
- if (signature := getattr(endpoint, "route_signature", None))
96
- }
97
- if not routes:
98
- raise ValueError("An application must have at least one route!")
99
-
100
- for signature, endpoint in routes.items():
101
- _app.add_api_route(
102
- signature.path,
103
- endpoint,
104
- name=endpoint.__name__,
105
- methods=["POST"],
140
+ def marker_fn(callable: EndpointT) -> EndpointT:
141
+ if hasattr(callable, "route_signature"):
142
+ raise ValueError(
143
+ f"Can't set multiple routes for the same function: {callable.__name__}"
106
144
  )
107
145
 
108
- return _app
146
+ callable.route_signature = RouteSignature(path=path, is_websocket=is_websocket) # type: ignore
147
+ return callable
109
148
 
110
- def openapi(self) -> dict[str, Any]:
111
- """
112
- Build the OpenAPI specification for the served function.
113
- Attach needed metadata for a better integration to fal.
114
- """
115
- app = self._build_app()
116
- spec = app.openapi()
117
- self._mark_order_openapi(spec)
118
- return spec
149
+ return marker_fn
119
150
 
120
- def _mark_order_openapi(self, spec: dict[str, Any]):
121
- """
122
- Add x-fal-order-* keys to the OpenAPI specification to help the rendering of UI.
123
151
 
124
- NOTE: We rely on the fact that fastapi and Python dicts keep the order of properties.
125
- """
152
+ def _fal_websocket_template(
153
+ func: EndpointT,
154
+ route_signature: RouteSignature,
155
+ ) -> EndpointT:
156
+ # A template for fal's realtime websocket endpoints to basically
157
+ # be a boilerplate for the user to fill in their inference function
158
+ # and start using it.
159
+
160
+ import asyncio
161
+ from collections import deque
162
+ from contextlib import suppress
163
+
164
+ import msgpack
165
+ from fastapi import WebSocket, WebSocketDisconnect
166
+
167
+ async def mirror_input(queue: deque[Any], websocket: WebSocket) -> None:
168
+ while True:
169
+ try:
170
+ raw_input = await asyncio.wait_for(
171
+ websocket.receive_bytes(),
172
+ timeout=route_signature.session_timeout,
173
+ )
174
+ except asyncio.TimeoutError:
175
+ return
176
+
177
+ input = msgpack.unpackb(raw_input, raw=False)
178
+ if route_signature.input_modal:
179
+ input = route_signature.input_modal(**input)
180
+
181
+ queue.append(input)
182
+
183
+ async def mirror_output(
184
+ self,
185
+ queue: deque[Any],
186
+ websocket: WebSocket,
187
+ ) -> None:
188
+ loop = asyncio.get_event_loop()
189
+ max_allowed_buffering = route_signature.buffering or 1
190
+ outgoing_messages: asyncio.Queue[bytes] = asyncio.Queue(
191
+ maxsize=max_allowed_buffering * 2 # x2 for outgoing timings
192
+ )
126
193
 
127
- def mark_order(obj: dict[str, Any], key: str):
128
- obj[f"x-fal-order-{key}"] = list(obj[key].keys())
194
+ async def emit(message):
195
+ if isinstance(message, bytes):
196
+ await websocket.send_bytes(message)
197
+ elif isinstance(message, str):
198
+ await websocket.send_text(message)
199
+ else:
200
+ raise TypeError(f"Can't send message of type {type(message)}")
201
+
202
+ async def background_emitter():
203
+ while True:
204
+ output = await outgoing_messages.get()
205
+ await emit(output)
206
+ outgoing_messages.task_done()
207
+
208
+ emitter = asyncio.create_task(background_emitter())
209
+
210
+ while True:
211
+ if not queue:
212
+ await asyncio.sleep(0.05)
213
+ continue
214
+
215
+ input = queue.popleft()
216
+ if input is None or emitter.done():
217
+ if not emitter.done():
218
+ await outgoing_messages.join()
219
+ emitter.cancel()
220
+
221
+ with suppress(asyncio.CancelledError):
222
+ await emitter
223
+ return None # End of input
224
+
225
+ batch = [input]
226
+ while queue and len(batch) < route_signature.max_batch_size:
227
+ next_input = queue.popleft()
228
+ if hasattr(input, "can_batch") and not input.can_batch(
229
+ next_input, len(batch)
230
+ ):
231
+ queue.appendleft(next_input)
232
+ break
233
+ batch.append(next_input)
234
+
235
+ t0 = loop.time()
236
+ output = await loop.run_in_executor(None, func, self, *batch) # type: ignore
237
+ total_time = loop.time() - t0
238
+ if not isinstance(output, dict):
239
+ # Handle pydantic output model
240
+ if hasattr(output, "dict"):
241
+ output = output.dict()
242
+ else:
243
+ raise TypeError(
244
+ f"Expected a dict or pydantic model as output, got {type(output)}"
245
+ )
246
+
247
+ messages = [
248
+ msgpack.packb(output, use_bin_type=True),
249
+ ]
250
+ if route_signature.emit_timings:
251
+ # We emit x-fal messages in JSON, no matter what the input/output format is.
252
+ timings = {
253
+ "type": "x-fal-message",
254
+ "action": "timings",
255
+ "timing": total_time,
256
+ }
257
+ messages.append(json.dumps(timings, separators=(",", ":")))
258
+
259
+ for message in messages:
260
+ try:
261
+ outgoing_messages.put_nowait(message)
262
+ except asyncio.QueueFull:
263
+ await emit(message)
264
+
265
+ async def websocket_template(self, websocket: WebSocket) -> None:
266
+ import asyncio
267
+
268
+ await websocket.accept()
269
+
270
+ queue: deque[Any] = deque(maxlen=route_signature.buffering)
271
+ input_task = asyncio.create_task(mirror_input(queue, websocket))
272
+ input_task.add_done_callback(lambda _: queue.append(None))
273
+ output_task = asyncio.create_task(mirror_output(self, queue, websocket))
274
+
275
+ try:
276
+ await asyncio.wait(
277
+ {
278
+ input_task,
279
+ output_task,
280
+ },
281
+ return_when=asyncio.FIRST_COMPLETED,
282
+ )
283
+ if input_task.done():
284
+ # User didn't send any input within the timeout
285
+ # so we can just close the connection after the
286
+ # processing of the last input is done.
287
+ input_task.result()
288
+ await asyncio.wait_for(
289
+ output_task, timeout=route_signature.session_timeout
290
+ )
291
+ else:
292
+ assert output_task.done()
293
+
294
+ # The execution of the inference function failed or exited,
295
+ # so just propagate the result.
296
+ input_task.cancel()
297
+ with suppress(asyncio.CancelledError):
298
+ await input_task
299
+
300
+ output_task.result()
301
+ except WebSocketDisconnect:
302
+ input_task.cancel()
303
+ output_task.cancel()
304
+ with suppress(asyncio.CancelledError):
305
+ await input_task
306
+
307
+ with suppress(asyncio.CancelledError):
308
+ await output_task
309
+ except Exception as exc:
310
+ import traceback
311
+
312
+ traceback.print_exc()
313
+
314
+ await websocket.send_json(
315
+ {
316
+ "type": "x-fal-error",
317
+ "error": "INTERNAL_ERROR",
318
+ "reason": str(exc),
319
+ }
320
+ )
321
+ else:
322
+ await websocket.send_json(
323
+ {
324
+ "type": "x-fal-error",
325
+ "error": "TIMEOUT",
326
+ "reason": "no inputs, reconnect when needed!",
327
+ }
328
+ )
129
329
 
130
- mark_order(spec, "paths")
330
+ await websocket.close()
131
331
 
132
- def order_schema_object(schema: dict[str, Any]):
133
- """
134
- Mark the order of properties in the schema object.
135
- They can have 'allOf', 'properties' or '$ref' key.
136
- """
137
- if "allOf" in schema:
138
- for sub_schema in schema["allOf"]:
139
- order_schema_object(sub_schema)
140
- if "properties" in schema:
141
- mark_order(schema, "properties")
332
+ # Seems like templating + stringified annotations don't play well,
333
+ # so we have to set them manually.
334
+ websocket_template.__annotations__ = {
335
+ "websocket": WebSocket,
336
+ "return": None,
337
+ }
338
+ websocket_template.route_signature = route_signature # type: ignore
339
+ websocket_template.original_func = func # type: ignore
340
+ return typing.cast(EndpointT, websocket_template)
142
341
 
143
- for key in spec["components"].get("schemas") or {}:
144
- order_schema_object(spec["components"]["schemas"][key])
145
342
 
146
- return spec
343
+ _SENTINEL = object()
147
344
 
148
345
 
149
346
  @mainify
150
- def endpoint(path: str) -> Callable[[EndpointT], EndpointT]:
151
- """Designate the decorated function as an application endpoint."""
152
-
153
- def marker_fn(callable: EndpointT) -> EndpointT:
154
- if hasattr(callable, "route_signature"):
347
+ def realtime(
348
+ path: str,
349
+ *,
350
+ buffering: int | None = None,
351
+ session_timeout: float | None = None,
352
+ input_modal: Any | None = _SENTINEL,
353
+ max_batch_size: int = 1,
354
+ ) -> Callable[[EndpointT], EndpointT]:
355
+ """Designate the decorated function as a realtime application endpoint."""
356
+
357
+ def marker_fn(original_func: EndpointT) -> EndpointT:
358
+ nonlocal input_modal
359
+
360
+ if hasattr(original_func, "route_signature"):
155
361
  raise ValueError(
156
- f"Can't set multiple routes for the same function: {callable.__name__}"
362
+ f"Can't set multiple routes for the same function: {original_func.__name__}"
157
363
  )
158
364
 
159
- callable.route_signature = RouteSignature(path=path) # type: ignore
160
- return callable
365
+ if input_modal is _SENTINEL:
366
+ type_hints = typing.get_type_hints(original_func)
367
+ if len(type_hints) >= 1:
368
+ input_modal = type_hints[list(type_hints.keys())[0]]
369
+ else:
370
+ input_modal = None
371
+
372
+ route_signature = RouteSignature(
373
+ path=path,
374
+ is_websocket=True,
375
+ input_modal=input_modal,
376
+ buffering=buffering,
377
+ session_timeout=session_timeout,
378
+ max_batch_size=max_batch_size,
379
+ )
380
+ return _fal_websocket_template(
381
+ original_func,
382
+ route_signature,
383
+ )
161
384
 
162
385
  return marker_fn
fal/apps.py CHANGED
@@ -1,14 +1,26 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import json
3
4
  import time
5
+ from contextlib import contextmanager
4
6
  from dataclasses import dataclass, field
5
7
  from typing import Any, Iterator
6
8
 
7
9
  import httpx
10
+
8
11
  from fal import flags
9
12
  from fal.sdk import Credentials, get_default_credentials
10
13
 
11
- _URL_FORMAT = f"https://{{app_id}}.{flags.GATEWAY_HOST}/fal/queue"
14
+ _QUEUE_URL_FORMAT = f"https://queue.{flags.FAL_RUN_HOST}/{{app_id}}"
15
+ _REALTIME_URL_FORMAT = f"wss://{flags.FAL_RUN_HOST}/{{app_id}}"
16
+
17
+
18
+ def _backwards_compatible_app_id(app_id: str) -> str:
19
+ if "/" not in app_id:
20
+ # Convert the app_id to the format used in the URL.
21
+ return app_id.replace("-", "/", 1)
22
+
23
+ return app_id
12
24
 
13
25
 
14
26
  @dataclass
@@ -50,11 +62,17 @@ class RequestHandle:
50
62
  # Use the credentials that were used to submit the request by default.
51
63
  _creds: Credentials = field(default_factory=get_default_credentials, repr=False)
52
64
 
65
def __post_init__(self):
    """Normalize ``self.app_id`` to the ``<user_id>/<app_name>`` form.

    The id is first passed through ``_backwards_compatible_app_id`` and then
    truncated to its first two ``/``-separated components.

    Raises:
        ValueError: If the id does not contain both a user id and an app
            name. Previously this surfaced as an opaque tuple-unpacking
            ``ValueError``; the exception type is unchanged.
    """
    app_id = _backwards_compatible_app_id(self.app_id)
    components = app_id.split("/")
    if len(components) < 2:
        raise ValueError(
            f"Invalid app id {self.app_id!r}: expected '<user_id>/<app_name>'"
        )

    # Drop any extra path components.
    user_id, app_name = components[:2]
    self.app_id = f"{user_id}/{app_name}"
70
+
53
71
  def status(self, *, logs: bool = False) -> _Status:
54
72
  """Check the status of an async inference request."""
55
73
 
56
74
  url = (
57
- _URL_FORMAT.format(app_id=self.app_id)
75
+ _QUEUE_URL_FORMAT.format(app_id=self.app_id)
58
76
  + f"/requests/{self.request_id}/status/"
59
77
  )
60
78
  response = _HTTP_CLIENT.get(
@@ -97,11 +115,20 @@ class RequestHandle:
97
115
  """Retrieve the result of an async inference request, raises an exception
98
116
  if the request is not completed yet."""
99
117
  url = (
100
- _URL_FORMAT.format(app_id=self.app_id)
101
- + f"/requests/{self.request_id}/response/"
118
+ _QUEUE_URL_FORMAT.format(app_id=self.app_id)
119
+ + f"/requests/{self.request_id}/"
102
120
  )
103
121
  response = _HTTP_CLIENT.get(url, headers=self._creds.to_headers())
104
- response.raise_for_status()
122
+ try:
123
+ response.raise_for_status()
124
+ except httpx.HTTPStatusError as e:
125
+ if response.headers["Content-Type"] != "application/json":
126
+ raise
127
+ raise httpx.HTTPStatusError(
128
+ f"{response.status_code}: {response.text}",
129
+ request=e.request,
130
+ response=e.response,
131
+ ) from e
105
132
 
106
133
  data = response.json()
107
134
  return data
@@ -119,19 +146,23 @@ class RequestHandle:
119
146
  _HTTP_CLIENT = httpx.Client(headers={"User-Agent": "Fal/Python"})
120
147
 
121
148
 
122
- def run(app_id: str, arguments: dict[str, Any], *, path: str = "/") -> dict[str, Any]:
149
+ def run(app_id: str, arguments: dict[str, Any], *, path: str = "") -> dict[str, Any]:
123
150
  """Run an inference task on a Fal app and return the result."""
124
151
 
125
152
  handle = submit(app_id, arguments, path=path)
126
153
  return handle.get()
127
154
 
128
155
 
129
- def submit(app_id: str, arguments: dict[str, Any], *, path: str = "/") -> RequestHandle:
156
+ def submit(app_id: str, arguments: dict[str, Any], *, path: str = "") -> RequestHandle:
130
157
  """Submit an async inference task to the app. Returns a request handle
131
158
  which can be used to check the status of the request and retrieve the
132
159
  result."""
133
160
 
134
- url = _URL_FORMAT.format(app_id=app_id) + "/submit" + path
161
+ app_id = _backwards_compatible_app_id(app_id)
162
+ url = _QUEUE_URL_FORMAT.format(app_id=app_id)
163
+ if path:
164
+ url += "/" + path.removeprefix("/")
165
+
135
166
  creds = get_default_credentials()
136
167
 
137
168
  response = _HTTP_CLIENT.post(
@@ -147,3 +178,56 @@ def submit(app_id: str, arguments: dict[str, Any], *, path: str = "/") -> Reques
147
178
  request_id=data["request_id"],
148
179
  _creds=creds,
149
180
  )
181
+
182
+
183
+ @dataclass
184
+ class _RealtimeConnection:
185
+ """A realtime connection to a Fal app."""
186
+
187
+ _ws: Any
188
+
189
+ def run(self, arguments: dict[str, Any]) -> dict[str, Any]:
190
+ """Run an inference task on the app and return the result."""
191
+ self.send(arguments)
192
+ return self.recv()
193
+
194
+ def send(self, arguments: dict[str, Any]) -> None:
195
+ import msgpack
196
+
197
+ """Send an inference task to the app."""
198
+ payload = msgpack.packb(arguments)
199
+ self._ws.send(payload)
200
+
201
+ def recv(self) -> dict[str, Any]:
202
+ import msgpack
203
+
204
+ """Receive the result of an inference task."""
205
+ while True:
206
+ response = self._ws.recv()
207
+ if isinstance(response, str):
208
+ print(response)
209
+ json_payload = json.loads(response)
210
+ if json_payload.get("type") == "x-fal-error":
211
+ raise ValueError(json_payload["reason"])
212
+ continue
213
+ return msgpack.unpackb(response)
214
+
215
+
216
@contextmanager
def _connect(app_id: str, *, path: str = "/realtime") -> Iterator[_RealtimeConnection]:
    """Open a realtime connection to a Fal app.

    This is an internal and experimental API, use it at your own risk.

    The websocket URL is built from ``_REALTIME_URL_FORMAT`` and the
    normalized app id. When ``path`` is non-empty it is appended after a
    single ``/`` (one leading ``/`` of ``path`` is dropped first). The
    connection is opened with the default credentials' headers and yields a
    ``_RealtimeConnection`` wrapping the socket; the socket is closed when the
    ``with`` block exits.
    """
    from websockets.sync import client as ws_client

    endpoint = _REALTIME_URL_FORMAT.format(
        app_id=_backwards_compatible_app_id(app_id)
    )
    if path:
        suffix = path.removeprefix("/")
        endpoint = f"{endpoint}/{suffix}"

    auth_headers = get_default_credentials().to_headers()
    with ws_client.connect(
        endpoint, additional_headers=auth_headers, open_timeout=90
    ) as socket:
        yield _RealtimeConnection(socket)
fal/auth/__init__.py CHANGED
@@ -1,12 +1,31 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import os
3
4
  from dataclasses import dataclass, field
4
5
 
5
6
  import click
7
+
6
8
  from fal.auth import auth0, local
7
9
  from fal.console import console
8
10
  from fal.console.icons import CHECK_ICON
9
11
  from fal.exceptions.auth import UnauthenticatedException
12
+ from fal.toolkit.mainify import mainify
13
+
14
+
15
+ @mainify
16
def key_credentials() -> tuple[str, str] | None:
    """Return the API key credentials configured in the environment, if any.

    Lookup order:

    1. If ``FAL_FORCE_AUTH_BY_USER`` is ``"1"``, key credentials are ignored.
    2. ``FAL_KEY`` in the form ``<key_id>:<key_secret>``.
    3. The ``FAL_KEY_ID`` / ``FAL_KEY_SECRET`` pair (both must be set).

    Returns:
        A ``(key_id, key_secret)`` tuple, or ``None`` when no key credentials
        are configured or they are being ignored.

    Raises:
        ValueError: If ``FAL_KEY`` is set but contains no ``:`` separator.
            Previously this surfaced as an opaque tuple-unpacking
            ``ValueError``; the exception type is unchanged.
    """
    # Ignore key credentials when the user forces auth by user.
    if os.environ.get("FAL_FORCE_AUTH_BY_USER") == "1":
        return None

    if "FAL_KEY" in os.environ:
        # Split on the first colon only, so the secret may contain colons.
        key_id, separator, key_secret = os.environ["FAL_KEY"].partition(":")
        if not separator:
            # Do not echo the value back: it is a credential.
            raise ValueError(
                "FAL_KEY must be in the form '<key_id>:<key_secret>'"
            )
        return (key_id, key_secret)

    if "FAL_KEY_ID" in os.environ and "FAL_KEY_SECRET" in os.environ:
        return (os.environ["FAL_KEY_ID"], os.environ["FAL_KEY_SECRET"])

    return None
10
29
 
11
30
 
12
31
  def login():
@@ -43,7 +62,7 @@ def _fetch_access_token() -> str:
43
62
 
44
63
  if access_token is not None:
45
64
  try:
46
- auth0.validate_access_token(access_token)
65
+ auth0.verify_access_token_expiration(access_token)
47
66
  return access_token
48
67
  except:
49
68
  # access_token expired, will refresh