fal 0.12.2__py3-none-any.whl → 0.12.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of fal might be problematic. Click here for more details.

Files changed (46):
  1. fal/__init__.py +11 -2
  2. fal/api.py +130 -50
  3. fal/app.py +81 -134
  4. fal/apps.py +24 -6
  5. fal/auth/__init__.py +14 -2
  6. fal/auth/auth0.py +34 -25
  7. fal/cli.py +9 -4
  8. fal/env.py +0 -4
  9. fal/flags.py +1 -0
  10. fal/logging/__init__.py +0 -2
  11. fal/logging/trace.py +8 -1
  12. fal/sdk.py +33 -6
  13. fal/toolkit/__init__.py +16 -0
  14. fal/workflows.py +481 -0
  15. {fal-0.12.2.dist-info → fal-0.12.4.dist-info}/METADATA +4 -7
  16. fal-0.12.4.dist-info/RECORD +88 -0
  17. openapi_fal_rest/__init__.py +1 -0
  18. openapi_fal_rest/api/workflows/__init__.py +0 -0
  19. openapi_fal_rest/api/workflows/create_or_update_workflow_workflows_post.py +172 -0
  20. openapi_fal_rest/api/workflows/delete_workflow_workflows_user_id_workflow_name_delete.py +175 -0
  21. openapi_fal_rest/api/workflows/execute_workflow_workflows_user_id_workflow_name_post.py +268 -0
  22. openapi_fal_rest/api/workflows/get_workflow_workflows_user_id_workflow_name_get.py +181 -0
  23. openapi_fal_rest/api/workflows/get_workflows_workflows_get.py +189 -0
  24. openapi_fal_rest/models/__init__.py +34 -0
  25. openapi_fal_rest/models/app_metadata_response_app_metadata.py +1 -0
  26. openapi_fal_rest/models/customer_details.py +15 -14
  27. openapi_fal_rest/models/execute_workflow_workflows_user_id_workflow_name_post_json_body_type_0.py +44 -0
  28. openapi_fal_rest/models/execute_workflow_workflows_user_id_workflow_name_post_response_200_type_0.py +44 -0
  29. openapi_fal_rest/models/page_workflow_item.py +107 -0
  30. openapi_fal_rest/models/typed_workflow.py +85 -0
  31. openapi_fal_rest/models/workflow_contents.py +98 -0
  32. openapi_fal_rest/models/workflow_contents_nodes.py +59 -0
  33. openapi_fal_rest/models/workflow_contents_output.py +44 -0
  34. openapi_fal_rest/models/workflow_detail.py +149 -0
  35. openapi_fal_rest/models/workflow_detail_contents_type_0.py +44 -0
  36. openapi_fal_rest/models/workflow_item.py +80 -0
  37. openapi_fal_rest/models/workflow_node.py +74 -0
  38. openapi_fal_rest/models/workflow_node_type.py +9 -0
  39. openapi_fal_rest/models/workflow_schema.py +73 -0
  40. openapi_fal_rest/models/workflow_schema_input.py +44 -0
  41. openapi_fal_rest/models/workflow_schema_output.py +44 -0
  42. openapi_fal_rest/types.py +1 -0
  43. fal/logging/datadog.py +0 -78
  44. fal-0.12.2.dist-info/RECORD +0 -67
  45. {fal-0.12.2.dist-info → fal-0.12.4.dist-info}/WHEEL +0 -0
  46. {fal-0.12.2.dist-info → fal-0.12.4.dist-info}/entry_points.txt +0 -0
fal/__init__.py CHANGED
@@ -1,8 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from fal import apps
4
-
5
- # TODO: DEPRECATED - use function instead
6
4
  from fal.api import FalServerlessHost, LocalHost, cached
7
5
  from fal.api import function
8
6
  from fal.api import function as isolated
@@ -16,6 +14,17 @@ serverless = FalServerlessHost()
16
14
  # DEPRECATED - use serverless instead
17
15
  cloud = FalServerlessHost()
18
16
 
17
+ __all__ = [
18
+ "function",
19
+ "cached",
20
+ "App",
21
+ "endpoint",
22
+ "realtime",
23
+ # "wrap_app",
24
+ "FalServerlessKeyCredentials",
25
+ "sync_dir",
26
+ ]
27
+
19
28
 
20
29
  # NOTE: This makes `import fal.dbt` import the `dbt-fal` module and `import fal` import the `fal` module
21
30
  # NOTE: taken from dbt-core: https://github.com/dbt-labs/dbt-core/blob/ac539fd5cf325cfb5315339077d03399d575f570/core/dbt/adapters/__init__.py#L1-L7
fal/api.py CHANGED
@@ -4,7 +4,7 @@ import inspect
4
4
  import sys
5
5
  from collections import defaultdict
6
6
  from concurrent.futures import ThreadPoolExecutor
7
- from contextlib import suppress
7
+ from contextlib import asynccontextmanager, suppress
8
8
  from dataclasses import dataclass, field, replace
9
9
  from functools import partial, wraps
10
10
  from os import PathLike
@@ -16,6 +16,7 @@ from typing import (
16
16
  Generic,
17
17
  Iterator,
18
18
  Literal,
19
+ NamedTuple,
19
20
  TypeVar,
20
21
  cast,
21
22
  overload,
@@ -26,6 +27,7 @@ import dill.detect
26
27
  import grpc
27
28
  import isolate
28
29
  import yaml
30
+ from fastapi import FastAPI
29
31
  from isolate.backends.common import active_python
30
32
  from isolate.backends.settings import DEFAULT_SETTINGS
31
33
  from isolate.connections import PythonIPC
@@ -56,6 +58,8 @@ ReturnT = TypeVar("ReturnT", covariant=True)
56
58
  BasicConfig = Dict[str, Any]
57
59
  _UNSET = object()
58
60
 
61
+ SERVE_REQUIREMENTS = ["fastapi==0.99.1", "uvicorn"]
62
+
59
63
 
60
64
  @dataclass
61
65
  class FalServerlessError(Exception):
@@ -110,7 +114,7 @@ class Host(Generic[ArgsT, ReturnT]):
110
114
  options.environment[key] = value
111
115
 
112
116
  if options.gateway.get("serve"):
113
- options.add_requirements(["fastapi==0.99.1", "uvicorn"])
117
+ options.add_requirements(SERVE_REQUIREMENTS)
114
118
 
115
119
  return options
116
120
 
@@ -730,53 +734,17 @@ def function( # type: ignore
730
734
 
731
735
 
732
736
  @mainify
733
- class ServeWrapper:
734
- _func: Callable
735
-
736
- def __init__(self, func: Callable):
737
- self._func = func
738
-
739
- def build_app(self):
740
- from fastapi import FastAPI
741
- from fastapi.middleware.cors import CORSMiddleware
742
-
743
- _app = FastAPI()
744
-
745
- _app.add_middleware(
746
- CORSMiddleware,
747
- allow_credentials=True,
748
- allow_headers=("*"),
749
- allow_methods=("*"),
750
- allow_origins=("*"),
751
- )
752
-
753
- _app.add_api_route(
754
- "/",
755
- self._func, # type: ignore
756
- name=self._func.__name__,
757
- methods=["POST"],
758
- )
759
-
760
- return _app
761
-
762
- def __call__(self, *args, **kwargs) -> None:
763
- if len(args) != 0 or len(kwargs) != 0:
764
- print(
765
- f"[warning] {self._func.__name__} function is served with no arguments."
766
- )
767
-
768
- from uvicorn import run
769
-
770
- app = self.build_app()
771
- run(app, host="0.0.0.0", port=8080)
737
+ class FalFastAPI(FastAPI):
738
+ """
739
+ A subclass of FastAPI that adds some fal-specific functionality.
740
+ """
772
741
 
773
742
  def openapi(self) -> dict[str, Any]:
774
743
  """
775
744
  Build the OpenAPI specification for the served function.
776
745
  Attach needed metadata for a better integration to fal.
777
746
  """
778
- app = self.build_app()
779
- spec = app.openapi()
747
+ spec = super().openapi()
780
748
  self._mark_order_openapi(spec)
781
749
  return spec
782
750
 
@@ -788,7 +756,8 @@ class ServeWrapper:
788
756
  """
789
757
 
790
758
  def mark_order(obj: dict[str, Any], key: str):
791
- obj[f"x-fal-order-{key}"] = list(obj[key].keys())
759
+ if key in obj:
760
+ obj[f"x-fal-order-{key}"] = list(obj[key].keys())
792
761
 
793
762
  mark_order(spec, "paths")
794
763
 
@@ -797,18 +766,129 @@ class ServeWrapper:
797
766
  Mark the order of properties in the schema object.
798
767
  They can have 'allOf', 'properties' or '$ref' key.
799
768
  """
800
- if "allOf" in schema:
801
- for sub_schema in schema["allOf"]:
802
- order_schema_object(sub_schema)
803
- if "properties" in schema:
804
- mark_order(schema, "properties")
769
+ for sub_schema in schema.get("allOf", []):
770
+ order_schema_object(sub_schema)
771
+
772
+ mark_order(schema, "properties")
805
773
 
806
- for key in spec.get("components", {}).get("schemas") or {}:
774
+ for key in spec.get("components", {}).get("schemas", {}):
807
775
  order_schema_object(spec["components"]["schemas"][key])
808
776
 
809
777
  return spec
810
778
 
811
779
 
780
+ @mainify
781
+ class RouteSignature(NamedTuple):
782
+ path: str
783
+ is_websocket: bool = False
784
+ input_modal: type | None = None
785
+ buffering: int | None = None
786
+ session_timeout: float | None = None
787
+ max_batch_size: int = 1
788
+ emit_timings: bool = False
789
+
790
+
791
+ @mainify
792
+ class BaseServable:
793
+ def collect_routes(self) -> dict[RouteSignature, Callable[..., Any]]:
794
+ raise NotImplementedError
795
+
796
+ def _add_extra_middlewares(self, app: FastAPI):
797
+ """
798
+ For subclasses to add extra middlewares to the app.
799
+ """
800
+ pass
801
+
802
+ @asynccontextmanager
803
+ async def lifespan(self, app: FastAPI):
804
+ yield
805
+
806
+ def _build_app(self) -> FastAPI:
807
+ from fastapi import HTTPException, Request
808
+ from fastapi.middleware.cors import CORSMiddleware
809
+ from fastapi.responses import JSONResponse
810
+
811
+ _app = FalFastAPI(lifespan=self.lifespan)
812
+
813
+ _app.add_middleware(
814
+ CORSMiddleware,
815
+ allow_credentials=True,
816
+ allow_headers=("*"),
817
+ allow_methods=("*"),
818
+ allow_origins=("*"),
819
+ )
820
+
821
+ self._add_extra_middlewares(_app)
822
+
823
+ @_app.exception_handler(404)
824
+ async def not_found_exception_handler(request: Request, exc: HTTPException):
825
+ # Rewrite the message to include the path that was not found.
826
+ # This is supposed to make it easier to understand to the user
827
+ # that the error comes from the app and not our platform.
828
+ if exc.detail == "Not Found":
829
+ return JSONResponse(
830
+ {"detail": f"Path {request.url.path} not found"}, 404
831
+ )
832
+ else:
833
+ # If it's not a generic 404, just return the original message.
834
+ return JSONResponse({"detail": exc.detail}, 404)
835
+
836
+ routes = self.collect_routes()
837
+ if not routes:
838
+ raise ValueError("An application must have at least one route!")
839
+
840
+ for signature, endpoint in routes.items():
841
+ if signature.is_websocket:
842
+ _app.add_api_websocket_route(
843
+ signature.path,
844
+ endpoint,
845
+ name=endpoint.__name__,
846
+ )
847
+ else:
848
+ _app.add_api_route(
849
+ signature.path,
850
+ endpoint,
851
+ name=endpoint.__name__,
852
+ methods=["POST"],
853
+ )
854
+
855
+ return _app
856
+
857
+ def openapi(self) -> dict[str, Any]:
858
+ """
859
+ Build the OpenAPI specification for the served function.
860
+ Attach needed metadata for a better integration to fal.
861
+ """
862
+ return self._build_app().openapi()
863
+
864
+ def serve(self) -> None:
865
+ import uvicorn
866
+
867
+ app = self._build_app()
868
+ uvicorn.run(app, host="0.0.0.0", port=8080)
869
+
870
+
871
+ @mainify
872
+ class ServeWrapper(BaseServable):
873
+ _func: Callable
874
+
875
+ def __init__(self, func: Callable):
876
+ self._func = func
877
+
878
+ def collect_routes(self) -> dict[RouteSignature, Callable[..., Any]]:
879
+ return {
880
+ RouteSignature("/"): self._func,
881
+ }
882
+
883
+ def __call__(self, *args, **kwargs) -> None:
884
+ if len(args) != 0 or len(kwargs) != 0:
885
+ print(
886
+ f"[warning] {self._func.__name__} function is served with no arguments."
887
+ )
888
+
889
+ self.serve()
890
+
891
+
812
892
  @dataclass
813
893
  class IsolatedFunction(Generic[ArgsT, ReturnT]):
814
894
  host: Host[ArgsT, ReturnT]
fal/app.py CHANGED
@@ -1,15 +1,17 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import inspect
4
+ import json
4
5
  import os
5
6
  import typing
6
7
  from contextlib import asynccontextmanager
7
- from typing import Any, Callable, ClassVar, NamedTuple, TypeVar
8
+ from typing import Any, Callable, ClassVar, TypeVar
8
9
 
9
10
  from fastapi import FastAPI
10
11
 
11
12
  import fal.api
12
13
  from fal._serialization import add_serialization_listeners_for
14
+ from fal.api import RouteSignature
13
15
  from fal.logging import get_logger
14
16
  from fal.toolkit import mainify
15
17
 
@@ -19,6 +21,13 @@ EndpointT = TypeVar("EndpointT", bound=Callable[..., Any])
19
21
  logger = get_logger(__name__)
20
22
 
21
23
 
24
+ async def _call_any_fn(fn, *args, **kwargs):
25
+ if inspect.iscoroutinefunction(fn):
26
+ return await fn(*args, **kwargs)
27
+ else:
28
+ return fn(*args, **kwargs)
29
+
30
+
22
31
  def wrap_app(cls: type[App], **kwargs) -> fal.api.IsolatedFunction:
23
32
  add_serialization_listeners_for(cls)
24
33
 
@@ -44,25 +53,19 @@ def wrap_app(cls: type[App], **kwargs) -> fal.api.IsolatedFunction:
44
53
  **cls.host_kwargs,
45
54
  **kwargs,
46
55
  metadata=metadata,
47
- serve=True,
56
+ exposed_port=8080,
57
+ serve=False,
48
58
  )
49
59
  fn = wrapper(initialize_and_serve)
60
+ fn.options.add_requirements(fal.api.SERVE_REQUIREMENTS)
50
61
  if realtime_app:
51
62
  fn.options.add_requirements(REALTIME_APP_REQUIREMENTS)
52
- return fn.on(
53
- serve=False,
54
- exposed_port=8080,
55
- )
56
63
 
57
-
58
- @mainify
59
- class RouteSignature(NamedTuple):
60
- path: str
61
- is_websocket: bool = False
64
+ return fn
62
65
 
63
66
 
64
67
  @mainify
65
- class App:
68
+ class App(fal.api.BaseServable):
66
69
  requirements: ClassVar[list[str]] = []
67
70
  machine_type: ClassVar[str] = "S"
68
71
  host_kwargs: ClassVar[dict[str, Any]] = {}
@@ -83,19 +86,6 @@ class App:
83
86
  "Running apps through SDK is not implemented yet."
84
87
  )
85
88
 
86
- def setup(self):
87
- """Setup the application before serving."""
88
-
89
- def provide_hints(self) -> list[str]:
90
- """Provide hints for routing the application."""
91
- raise NotImplementedError
92
-
93
- def serve(self) -> None:
94
- import uvicorn
95
-
96
- app = self._build_app()
97
- uvicorn.run(app, host="0.0.0.0", port=8080)
98
-
99
89
  def collect_routes(self) -> dict[RouteSignature, Callable[..., Any]]:
100
90
  return {
101
91
  signature: endpoint
@@ -103,22 +93,23 @@ class App:
103
93
  if (signature := getattr(endpoint, "route_signature", None))
104
94
  }
105
95
 
106
- def _build_app(self) -> FastAPI:
107
- from fastapi import FastAPI
108
- from fastapi.middleware.cors import CORSMiddleware
96
+ @asynccontextmanager
97
+ async def lifespan(self, app: FastAPI):
98
+ await _call_any_fn(self.setup)
99
+ try:
100
+ yield
101
+ finally:
102
+ await _call_any_fn(self.teardown)
109
103
 
110
- @asynccontextmanager
111
- async def lifespan(app: FastAPI):
112
- self.setup()
113
- try:
114
- yield
115
- finally:
116
- self.teardown()
104
+ def setup(self):
105
+ """Setup the application before serving."""
117
106
 
118
- _app = FastAPI(lifespan=lifespan)
107
+ def teardown(self):
108
+ """Teardown the application after serving."""
119
109
 
120
- @_app.middleware("http")
121
- async def provide_hints(request, call_next):
110
+ def _add_extra_middlewares(self, app: FastAPI):
111
+ @app.middleware("http")
112
+ async def provide_hints_headers(request, call_next):
122
113
  response = await call_next(request)
123
114
  try:
124
115
  response.headers["X-Fal-Runner-Hints"] = ",".join(self.provide_hints())
@@ -126,57 +117,18 @@ class App:
126
117
  # This lets us differentiate between apps that don't provide hints
127
118
  # and apps that provide empty hints.
128
119
  pass
129
- except Exception as exc:
120
+ except Exception:
130
121
  from fastapi.logger import logger
131
122
 
132
123
  logger.exception(
133
124
  "Failed to provide hints for %s",
134
125
  self.__class__.__name__,
135
- exc_info=exc,
136
126
  )
137
127
  return response
138
128
 
139
- _app.add_middleware(
140
- CORSMiddleware,
141
- allow_credentials=True,
142
- allow_headers=("*"),
143
- allow_methods=("*"),
144
- allow_origins=("*"),
145
- )
146
-
147
- routes = self.collect_routes()
148
- if not routes:
149
- raise ValueError("An application must have at least one route!")
150
-
151
- for signature, endpoint in routes.items():
152
- if signature.is_websocket:
153
- _app.add_api_websocket_route(
154
- signature.path,
155
- endpoint,
156
- name=endpoint.__name__,
157
- )
158
- else:
159
- _app.add_api_route(
160
- signature.path,
161
- endpoint,
162
- name=endpoint.__name__,
163
- methods=["POST"],
164
- )
165
-
166
- return _app
167
-
168
- def openapi(self) -> dict[str, Any]:
169
- """
170
- Build the OpenAPI specification for the served function.
171
- Attach needed metadata for a better integration to fal.
172
- """
173
- app = self._build_app()
174
- spec = app.openapi()
175
- _mark_order_openapi(spec)
176
- return spec
177
-
178
- def teardown(self):
179
- """Teardown the application after serving."""
129
+ def provide_hints(self) -> list[str]:
130
+ """Provide hints for routing the application."""
131
+ raise NotImplementedError
180
132
 
181
133
 
182
134
  @mainify
@@ -199,10 +151,7 @@ def endpoint(
199
151
 
200
152
  def _fal_websocket_template(
201
153
  func: EndpointT,
202
- buffering: int | None = None,
203
- session_timeout: float | None = None,
204
- input_modal: Any | None = None,
205
- max_batch_size: int = 1,
154
+ route_signature: RouteSignature,
206
155
  ) -> EndpointT:
207
156
  # A template for fal's realtime websocket endpoints to basically
208
157
  # be a boilerplate for the user to fill in their inference function
@@ -220,14 +169,14 @@ def _fal_websocket_template(
220
169
  try:
221
170
  raw_input = await asyncio.wait_for(
222
171
  websocket.receive_bytes(),
223
- timeout=session_timeout,
172
+ timeout=route_signature.session_timeout,
224
173
  )
225
174
  except asyncio.TimeoutError:
226
175
  return
227
176
 
228
177
  input = msgpack.unpackb(raw_input, raw=False)
229
- if input_modal:
230
- input = input_modal(**input)
178
+ if route_signature.input_modal:
179
+ input = route_signature.input_modal(**input)
231
180
 
232
181
  queue.append(input)
233
182
 
@@ -237,10 +186,18 @@ def _fal_websocket_template(
237
186
  websocket: WebSocket,
238
187
  ) -> None:
239
188
  loop = asyncio.get_event_loop()
240
- outgoing_messages: asyncio.Queue[bytes] = asyncio.Queue(maxsize=buffering or 1)
189
+ max_allowed_buffering = route_signature.buffering or 1
190
+ outgoing_messages: asyncio.Queue[bytes] = asyncio.Queue(
191
+ maxsize=max_allowed_buffering * 2 # x2 for outgoing timings
192
+ )
241
193
 
242
194
  async def emit(message):
243
- await websocket.send_bytes(message)
195
+ if isinstance(message, bytes):
196
+ await websocket.send_bytes(message)
197
+ elif isinstance(message, str):
198
+ await websocket.send_text(message)
199
+ else:
200
+ raise TypeError(f"Can't send message of type {type(message)}")
244
201
 
245
202
  async def background_emitter():
246
203
  while True:
@@ -266,7 +223,7 @@ def _fal_websocket_template(
266
223
  return None # End of input
267
224
 
268
225
  batch = [input]
269
- while queue and len(batch) < max_batch_size:
226
+ while queue and len(batch) < route_signature.max_batch_size:
270
227
  next_input = queue.popleft()
271
228
  if hasattr(input, "can_batch") and not input.can_batch(
272
229
  next_input, len(batch)
@@ -275,7 +232,9 @@ def _fal_websocket_template(
275
232
  break
276
233
  batch.append(next_input)
277
234
 
235
+ t0 = loop.time()
278
236
  output = await loop.run_in_executor(None, func, self, *batch) # type: ignore
237
+ total_time = loop.time() - t0
279
238
  if not isinstance(output, dict):
280
239
  # Handle pydantic output modal
281
240
  if hasattr(output, "dict"):
@@ -285,18 +244,30 @@ def _fal_websocket_template(
285
244
  f"Expected a dict or pydantic model as output, got {type(output)}"
286
245
  )
287
246
 
288
- message = msgpack.packb(output, use_bin_type=True)
289
- try:
290
- outgoing_messages.put_nowait(message)
291
- except asyncio.QueueFull:
292
- await emit(message)
247
+ messages = [
248
+ msgpack.packb(output, use_bin_type=True),
249
+ ]
250
+ if route_signature.emit_timings:
251
+ # We emit x-fal messages in JSON, no matter what the input/output format is.
252
+ timings = {
253
+ "type": "x-fal-message",
254
+ "action": "timings",
255
+ "timing": total_time,
256
+ }
257
+ messages.append(json.dumps(timings, separators=(",", ":")))
258
+
259
+ for message in messages:
260
+ try:
261
+ outgoing_messages.put_nowait(message)
262
+ except asyncio.QueueFull:
263
+ await emit(message)
293
264
 
294
265
  async def websocket_template(self, websocket: WebSocket) -> None:
295
266
  import asyncio
296
267
 
297
268
  await websocket.accept()
298
269
 
299
- queue: deque[Any] = deque(maxlen=buffering)
270
+ queue: deque[Any] = deque(maxlen=route_signature.buffering)
300
271
  input_task = asyncio.create_task(mirror_input(queue, websocket))
301
272
  input_task.add_done_callback(lambda _: queue.append(None))
302
273
  output_task = asyncio.create_task(mirror_output(self, queue, websocket))
@@ -314,7 +285,9 @@ def _fal_websocket_template(
314
285
  # so we can just close the connection after the
315
286
  # processing of the last input is done.
316
287
  input_task.result()
317
- await asyncio.wait_for(output_task, timeout=session_timeout)
288
+ await asyncio.wait_for(
289
+ output_task, timeout=route_signature.session_timeout
290
+ )
318
291
  else:
319
292
  assert output_task.done()
320
293
 
@@ -362,7 +335,8 @@ def _fal_websocket_template(
362
335
  "websocket": WebSocket,
363
336
  "return": None,
364
337
  }
365
-
338
+ websocket_template.route_signature = route_signature # type: ignore
339
+ websocket_template.original_func = func # type: ignore
366
340
  return typing.cast(EndpointT, websocket_template)
367
341
 
368
342
 
@@ -395,44 +369,17 @@ def realtime(
395
369
  else:
396
370
  input_modal = None
397
371
 
398
- callable = _fal_websocket_template(
399
- original_func,
372
+ route_signature = RouteSignature(
373
+ path=path,
374
+ is_websocket=True,
375
+ input_modal=input_modal,
400
376
  buffering=buffering,
401
377
  session_timeout=session_timeout,
402
- input_modal=input_modal,
403
378
  max_batch_size=max_batch_size,
404
379
  )
405
- callable.route_signature = RouteSignature(path=path, is_websocket=True) # type: ignore
406
- callable.original_func = original_func # type: ignore
407
- return callable
380
+ return _fal_websocket_template(
381
+ original_func,
382
+ route_signature,
383
+ )
408
384
 
409
385
  return marker_fn
410
-
411
-
412
- def _mark_order_openapi(spec: dict[str, Any]):
413
- """
414
- Add x-fal-order-* keys to the OpenAPI specification to help the rendering of UI.
415
-
416
- NOTE: We rely on the fact that fastapi and Python dicts keep the order of properties.
417
- """
418
-
419
- def mark_order(obj: dict[str, Any], key: str):
420
- obj[f"x-fal-order-{key}"] = list(obj[key].keys())
421
-
422
- mark_order(spec, "paths")
423
-
424
- def order_schema_object(schema: dict[str, Any]):
425
- """
426
- Mark the order of properties in the schema object.
427
- They can have 'allOf', 'properties' or '$ref' key.
428
- """
429
- if "allOf" in schema:
430
- for sub_schema in schema["allOf"]:
431
- order_schema_object(sub_schema)
432
- if "properties" in schema:
433
- mark_order(schema, "properties")
434
-
435
- for key in spec["components"].get("schemas") or {}:
436
- order_schema_object(spec["components"]["schemas"][key])
437
-
438
- return spec
fal/apps.py CHANGED
@@ -63,7 +63,10 @@ class RequestHandle:
63
63
  _creds: Credentials = field(default_factory=get_default_credentials, repr=False)
64
64
 
65
65
  def __post_init__(self):
66
- self.app_id = _backwards_compatible_app_id(self.app_id)
66
+ app_id = _backwards_compatible_app_id(self.app_id)
67
+ # drop any extra path components
68
+ user_id, app_name = app_id.split("/")[:2]
69
+ self.app_id = f"{user_id}/{app_name}"
67
70
 
68
71
  def status(self, *, logs: bool = False) -> _Status:
69
72
  """Check the status of an async inference request."""
@@ -116,7 +119,16 @@ class RequestHandle:
116
119
  + f"/requests/{self.request_id}/"
117
120
  )
118
121
  response = _HTTP_CLIENT.get(url, headers=self._creds.to_headers())
119
- response.raise_for_status()
122
+ try:
123
+ response.raise_for_status()
124
+ except httpx.HTTPStatusError as e:
125
+ if response.headers["Content-Type"] != "application/json":
126
+ raise
127
+ raise httpx.HTTPStatusError(
128
+ f"{response.status_code}: {response.text}",
129
+ request=e.request,
130
+ response=e.response,
131
+ ) from e
120
132
 
121
133
  data = response.json()
122
134
  return data
@@ -134,20 +146,23 @@ class RequestHandle:
134
146
  _HTTP_CLIENT = httpx.Client(headers={"User-Agent": "Fal/Python"})
135
147
 
136
148
 
137
- def run(app_id: str, arguments: dict[str, Any], *, path: str = "/") -> dict[str, Any]:
149
+ def run(app_id: str, arguments: dict[str, Any], *, path: str = "") -> dict[str, Any]:
138
150
  """Run an inference task on a Fal app and return the result."""
139
151
 
140
152
  handle = submit(app_id, arguments, path=path)
141
153
  return handle.get()
142
154
 
143
155
 
144
- def submit(app_id: str, arguments: dict[str, Any], *, path: str = "/") -> RequestHandle:
156
+ def submit(app_id: str, arguments: dict[str, Any], *, path: str = "") -> RequestHandle:
145
157
  """Submit an async inference task to the app. Returns a request handle
146
158
  which can be used to check the status of the request and retrieve the
147
159
  result."""
148
160
 
149
161
  app_id = _backwards_compatible_app_id(app_id)
150
- url = _QUEUE_URL_FORMAT.format(app_id=app_id) + path
162
+ url = _QUEUE_URL_FORMAT.format(app_id=app_id)
163
+ if path:
164
+ url += "/" + path.removeprefix("/")
165
+
151
166
  creds = get_default_credentials()
152
167
 
153
168
  response = _HTTP_CLIENT.post(
@@ -206,7 +221,10 @@ def _connect(app_id: str, *, path: str = "/realtime") -> Iterator[_RealtimeConne
206
221
  from websockets.sync import client
207
222
 
208
223
  app_id = _backwards_compatible_app_id(app_id)
209
- url = _REALTIME_URL_FORMAT.format(app_id=app_id) + path
224
+ url = _REALTIME_URL_FORMAT.format(app_id=app_id)
225
+ if path:
226
+ url += "/" + path.removeprefix("/")
227
+
210
228
  creds = get_default_credentials()
211
229
 
212
230
  with client.connect(