fal 0.12.2__py3-none-any.whl → 0.12.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of fal might be problematic. Click here for more details.
- fal/__init__.py +11 -2
- fal/api.py +130 -50
- fal/app.py +81 -134
- fal/apps.py +24 -6
- fal/auth/__init__.py +14 -2
- fal/auth/auth0.py +34 -25
- fal/cli.py +9 -4
- fal/env.py +0 -4
- fal/flags.py +1 -0
- fal/logging/__init__.py +0 -2
- fal/logging/trace.py +8 -1
- fal/sdk.py +33 -6
- fal/toolkit/__init__.py +16 -0
- fal/workflows.py +481 -0
- {fal-0.12.2.dist-info → fal-0.12.4.dist-info}/METADATA +4 -7
- fal-0.12.4.dist-info/RECORD +88 -0
- openapi_fal_rest/__init__.py +1 -0
- openapi_fal_rest/api/workflows/__init__.py +0 -0
- openapi_fal_rest/api/workflows/create_or_update_workflow_workflows_post.py +172 -0
- openapi_fal_rest/api/workflows/delete_workflow_workflows_user_id_workflow_name_delete.py +175 -0
- openapi_fal_rest/api/workflows/execute_workflow_workflows_user_id_workflow_name_post.py +268 -0
- openapi_fal_rest/api/workflows/get_workflow_workflows_user_id_workflow_name_get.py +181 -0
- openapi_fal_rest/api/workflows/get_workflows_workflows_get.py +189 -0
- openapi_fal_rest/models/__init__.py +34 -0
- openapi_fal_rest/models/app_metadata_response_app_metadata.py +1 -0
- openapi_fal_rest/models/customer_details.py +15 -14
- openapi_fal_rest/models/execute_workflow_workflows_user_id_workflow_name_post_json_body_type_0.py +44 -0
- openapi_fal_rest/models/execute_workflow_workflows_user_id_workflow_name_post_response_200_type_0.py +44 -0
- openapi_fal_rest/models/page_workflow_item.py +107 -0
- openapi_fal_rest/models/typed_workflow.py +85 -0
- openapi_fal_rest/models/workflow_contents.py +98 -0
- openapi_fal_rest/models/workflow_contents_nodes.py +59 -0
- openapi_fal_rest/models/workflow_contents_output.py +44 -0
- openapi_fal_rest/models/workflow_detail.py +149 -0
- openapi_fal_rest/models/workflow_detail_contents_type_0.py +44 -0
- openapi_fal_rest/models/workflow_item.py +80 -0
- openapi_fal_rest/models/workflow_node.py +74 -0
- openapi_fal_rest/models/workflow_node_type.py +9 -0
- openapi_fal_rest/models/workflow_schema.py +73 -0
- openapi_fal_rest/models/workflow_schema_input.py +44 -0
- openapi_fal_rest/models/workflow_schema_output.py +44 -0
- openapi_fal_rest/types.py +1 -0
- fal/logging/datadog.py +0 -78
- fal-0.12.2.dist-info/RECORD +0 -67
- {fal-0.12.2.dist-info → fal-0.12.4.dist-info}/WHEEL +0 -0
- {fal-0.12.2.dist-info → fal-0.12.4.dist-info}/entry_points.txt +0 -0
fal/__init__.py
CHANGED
|
@@ -1,8 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
from fal import apps
|
|
4
|
-
|
|
5
|
-
# TODO: DEPRECATED - use function instead
|
|
6
4
|
from fal.api import FalServerlessHost, LocalHost, cached
|
|
7
5
|
from fal.api import function
|
|
8
6
|
from fal.api import function as isolated
|
|
@@ -16,6 +14,17 @@ serverless = FalServerlessHost()
|
|
|
16
14
|
# DEPRECATED - use serverless instead
|
|
17
15
|
cloud = FalServerlessHost()
|
|
18
16
|
|
|
17
|
+
__all__ = [
|
|
18
|
+
"function",
|
|
19
|
+
"cached",
|
|
20
|
+
"App",
|
|
21
|
+
"endpoint",
|
|
22
|
+
"realtime",
|
|
23
|
+
# "wrap_app",
|
|
24
|
+
"FalServerlessKeyCredentials",
|
|
25
|
+
"sync_dir",
|
|
26
|
+
]
|
|
27
|
+
|
|
19
28
|
|
|
20
29
|
# NOTE: This makes `import fal.dbt` import the `dbt-fal` module and `import fal` import the `fal` module
|
|
21
30
|
# NOTE: taken from dbt-core: https://github.com/dbt-labs/dbt-core/blob/ac539fd5cf325cfb5315339077d03399d575f570/core/dbt/adapters/__init__.py#L1-L7
|
fal/api.py
CHANGED
|
@@ -4,7 +4,7 @@ import inspect
|
|
|
4
4
|
import sys
|
|
5
5
|
from collections import defaultdict
|
|
6
6
|
from concurrent.futures import ThreadPoolExecutor
|
|
7
|
-
from contextlib import suppress
|
|
7
|
+
from contextlib import asynccontextmanager, suppress
|
|
8
8
|
from dataclasses import dataclass, field, replace
|
|
9
9
|
from functools import partial, wraps
|
|
10
10
|
from os import PathLike
|
|
@@ -16,6 +16,7 @@ from typing import (
|
|
|
16
16
|
Generic,
|
|
17
17
|
Iterator,
|
|
18
18
|
Literal,
|
|
19
|
+
NamedTuple,
|
|
19
20
|
TypeVar,
|
|
20
21
|
cast,
|
|
21
22
|
overload,
|
|
@@ -26,6 +27,7 @@ import dill.detect
|
|
|
26
27
|
import grpc
|
|
27
28
|
import isolate
|
|
28
29
|
import yaml
|
|
30
|
+
from fastapi import FastAPI
|
|
29
31
|
from isolate.backends.common import active_python
|
|
30
32
|
from isolate.backends.settings import DEFAULT_SETTINGS
|
|
31
33
|
from isolate.connections import PythonIPC
|
|
@@ -56,6 +58,8 @@ ReturnT = TypeVar("ReturnT", covariant=True)
|
|
|
56
58
|
BasicConfig = Dict[str, Any]
|
|
57
59
|
_UNSET = object()
|
|
58
60
|
|
|
61
|
+
SERVE_REQUIREMENTS = ["fastapi==0.99.1", "uvicorn"]
|
|
62
|
+
|
|
59
63
|
|
|
60
64
|
@dataclass
|
|
61
65
|
class FalServerlessError(Exception):
|
|
@@ -110,7 +114,7 @@ class Host(Generic[ArgsT, ReturnT]):
|
|
|
110
114
|
options.environment[key] = value
|
|
111
115
|
|
|
112
116
|
if options.gateway.get("serve"):
|
|
113
|
-
options.add_requirements(
|
|
117
|
+
options.add_requirements(SERVE_REQUIREMENTS)
|
|
114
118
|
|
|
115
119
|
return options
|
|
116
120
|
|
|
@@ -730,53 +734,17 @@ def function( # type: ignore
|
|
|
730
734
|
|
|
731
735
|
|
|
732
736
|
@mainify
|
|
733
|
-
class
|
|
734
|
-
|
|
735
|
-
|
|
736
|
-
|
|
737
|
-
self._func = func
|
|
738
|
-
|
|
739
|
-
def build_app(self):
|
|
740
|
-
from fastapi import FastAPI
|
|
741
|
-
from fastapi.middleware.cors import CORSMiddleware
|
|
742
|
-
|
|
743
|
-
_app = FastAPI()
|
|
744
|
-
|
|
745
|
-
_app.add_middleware(
|
|
746
|
-
CORSMiddleware,
|
|
747
|
-
allow_credentials=True,
|
|
748
|
-
allow_headers=("*"),
|
|
749
|
-
allow_methods=("*"),
|
|
750
|
-
allow_origins=("*"),
|
|
751
|
-
)
|
|
752
|
-
|
|
753
|
-
_app.add_api_route(
|
|
754
|
-
"/",
|
|
755
|
-
self._func, # type: ignore
|
|
756
|
-
name=self._func.__name__,
|
|
757
|
-
methods=["POST"],
|
|
758
|
-
)
|
|
759
|
-
|
|
760
|
-
return _app
|
|
761
|
-
|
|
762
|
-
def __call__(self, *args, **kwargs) -> None:
|
|
763
|
-
if len(args) != 0 or len(kwargs) != 0:
|
|
764
|
-
print(
|
|
765
|
-
f"[warning] {self._func.__name__} function is served with no arguments."
|
|
766
|
-
)
|
|
767
|
-
|
|
768
|
-
from uvicorn import run
|
|
769
|
-
|
|
770
|
-
app = self.build_app()
|
|
771
|
-
run(app, host="0.0.0.0", port=8080)
|
|
737
|
+
class FalFastAPI(FastAPI):
|
|
738
|
+
"""
|
|
739
|
+
A subclass of FastAPI that adds some fal-specific functionality.
|
|
740
|
+
"""
|
|
772
741
|
|
|
773
742
|
def openapi(self) -> dict[str, Any]:
|
|
774
743
|
"""
|
|
775
744
|
Build the OpenAPI specification for the served function.
|
|
776
745
|
Attach needed metadata for a better integration to fal.
|
|
777
746
|
"""
|
|
778
|
-
|
|
779
|
-
spec = app.openapi()
|
|
747
|
+
spec = super().openapi()
|
|
780
748
|
self._mark_order_openapi(spec)
|
|
781
749
|
return spec
|
|
782
750
|
|
|
@@ -788,7 +756,8 @@ class ServeWrapper:
|
|
|
788
756
|
"""
|
|
789
757
|
|
|
790
758
|
def mark_order(obj: dict[str, Any], key: str):
|
|
791
|
-
|
|
759
|
+
if key in obj:
|
|
760
|
+
obj[f"x-fal-order-{key}"] = list(obj[key].keys())
|
|
792
761
|
|
|
793
762
|
mark_order(spec, "paths")
|
|
794
763
|
|
|
@@ -797,18 +766,129 @@ class ServeWrapper:
|
|
|
797
766
|
Mark the order of properties in the schema object.
|
|
798
767
|
They can have 'allOf', 'properties' or '$ref' key.
|
|
799
768
|
"""
|
|
800
|
-
|
|
801
|
-
|
|
802
|
-
|
|
803
|
-
|
|
804
|
-
mark_order(schema, "properties")
|
|
769
|
+
for sub_schema in schema.get("allOf", []):
|
|
770
|
+
order_schema_object(sub_schema)
|
|
771
|
+
|
|
772
|
+
mark_order(schema, "properties")
|
|
805
773
|
|
|
806
|
-
for key in spec.get("components", {}).get("schemas"
|
|
774
|
+
for key in spec.get("components", {}).get("schemas", {}):
|
|
807
775
|
order_schema_object(spec["components"]["schemas"][key])
|
|
808
776
|
|
|
809
777
|
return spec
|
|
810
778
|
|
|
811
779
|
|
|
780
|
+
@mainify
|
|
781
|
+
class RouteSignature(NamedTuple):
|
|
782
|
+
path: str
|
|
783
|
+
is_websocket: bool = False
|
|
784
|
+
input_modal: type | None = None
|
|
785
|
+
buffering: int | None = None
|
|
786
|
+
session_timeout: float | None = None
|
|
787
|
+
max_batch_size: int = 1
|
|
788
|
+
emit_timings: bool = False
|
|
789
|
+
|
|
790
|
+
|
|
791
|
+
@mainify
|
|
792
|
+
class BaseServable:
|
|
793
|
+
def collect_routes(self) -> dict[RouteSignature, Callable[..., Any]]:
|
|
794
|
+
raise NotImplementedError
|
|
795
|
+
|
|
796
|
+
def _add_extra_middlewares(self, app: FastAPI):
|
|
797
|
+
"""
|
|
798
|
+
For subclasses to add extra middlewares to the app.
|
|
799
|
+
"""
|
|
800
|
+
pass
|
|
801
|
+
|
|
802
|
+
@asynccontextmanager
|
|
803
|
+
async def lifespan(self, app: FastAPI):
|
|
804
|
+
yield
|
|
805
|
+
|
|
806
|
+
def _build_app(self) -> FastAPI:
|
|
807
|
+
from fastapi import HTTPException, Request
|
|
808
|
+
from fastapi.middleware.cors import CORSMiddleware
|
|
809
|
+
from fastapi.responses import JSONResponse
|
|
810
|
+
|
|
811
|
+
_app = FalFastAPI(lifespan=self.lifespan)
|
|
812
|
+
|
|
813
|
+
_app.add_middleware(
|
|
814
|
+
CORSMiddleware,
|
|
815
|
+
allow_credentials=True,
|
|
816
|
+
allow_headers=("*"),
|
|
817
|
+
allow_methods=("*"),
|
|
818
|
+
allow_origins=("*"),
|
|
819
|
+
)
|
|
820
|
+
|
|
821
|
+
self._add_extra_middlewares(_app)
|
|
822
|
+
|
|
823
|
+
@_app.exception_handler(404)
|
|
824
|
+
async def not_found_exception_handler(request: Request, exc: HTTPException):
|
|
825
|
+
# Rewrite the message to include the path that was not found.
|
|
826
|
+
# This is supposed to make it easier to understand to the user
|
|
827
|
+
# that the error comes from the app and not our platform.
|
|
828
|
+
if exc.detail == "Not Found":
|
|
829
|
+
return JSONResponse(
|
|
830
|
+
{"detail": f"Path {request.url.path} not found"}, 404
|
|
831
|
+
)
|
|
832
|
+
else:
|
|
833
|
+
# If it's not a generic 404, just return the original message.
|
|
834
|
+
return JSONResponse({"detail": exc.detail}, 404)
|
|
835
|
+
|
|
836
|
+
routes = self.collect_routes()
|
|
837
|
+
if not routes:
|
|
838
|
+
raise ValueError("An application must have at least one route!")
|
|
839
|
+
|
|
840
|
+
for signature, endpoint in routes.items():
|
|
841
|
+
if signature.is_websocket:
|
|
842
|
+
_app.add_api_websocket_route(
|
|
843
|
+
signature.path,
|
|
844
|
+
endpoint,
|
|
845
|
+
name=endpoint.__name__,
|
|
846
|
+
)
|
|
847
|
+
else:
|
|
848
|
+
_app.add_api_route(
|
|
849
|
+
signature.path,
|
|
850
|
+
endpoint,
|
|
851
|
+
name=endpoint.__name__,
|
|
852
|
+
methods=["POST"],
|
|
853
|
+
)
|
|
854
|
+
|
|
855
|
+
return _app
|
|
856
|
+
|
|
857
|
+
def openapi(self) -> dict[str, Any]:
|
|
858
|
+
"""
|
|
859
|
+
Build the OpenAPI specification for the served function.
|
|
860
|
+
Attach needed metadata for a better integration to fal.
|
|
861
|
+
"""
|
|
862
|
+
return self._build_app().openapi()
|
|
863
|
+
|
|
864
|
+
def serve(self) -> None:
|
|
865
|
+
import uvicorn
|
|
866
|
+
|
|
867
|
+
app = self._build_app()
|
|
868
|
+
uvicorn.run(app, host="0.0.0.0", port=8080)
|
|
869
|
+
|
|
870
|
+
|
|
871
|
+
@mainify
|
|
872
|
+
class ServeWrapper(BaseServable):
|
|
873
|
+
_func: Callable
|
|
874
|
+
|
|
875
|
+
def __init__(self, func: Callable):
|
|
876
|
+
self._func = func
|
|
877
|
+
|
|
878
|
+
def collect_routes(self) -> dict[RouteSignature, Callable[..., Any]]:
|
|
879
|
+
return {
|
|
880
|
+
RouteSignature("/"): self._func,
|
|
881
|
+
}
|
|
882
|
+
|
|
883
|
+
def __call__(self, *args, **kwargs) -> None:
|
|
884
|
+
if len(args) != 0 or len(kwargs) != 0:
|
|
885
|
+
print(
|
|
886
|
+
f"[warning] {self._func.__name__} function is served with no arguments."
|
|
887
|
+
)
|
|
888
|
+
|
|
889
|
+
self.serve()
|
|
890
|
+
|
|
891
|
+
|
|
812
892
|
@dataclass
|
|
813
893
|
class IsolatedFunction(Generic[ArgsT, ReturnT]):
|
|
814
894
|
host: Host[ArgsT, ReturnT]
|
fal/app.py
CHANGED
|
@@ -1,15 +1,17 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import inspect
|
|
4
|
+
import json
|
|
4
5
|
import os
|
|
5
6
|
import typing
|
|
6
7
|
from contextlib import asynccontextmanager
|
|
7
|
-
from typing import Any, Callable, ClassVar,
|
|
8
|
+
from typing import Any, Callable, ClassVar, TypeVar
|
|
8
9
|
|
|
9
10
|
from fastapi import FastAPI
|
|
10
11
|
|
|
11
12
|
import fal.api
|
|
12
13
|
from fal._serialization import add_serialization_listeners_for
|
|
14
|
+
from fal.api import RouteSignature
|
|
13
15
|
from fal.logging import get_logger
|
|
14
16
|
from fal.toolkit import mainify
|
|
15
17
|
|
|
@@ -19,6 +21,13 @@ EndpointT = TypeVar("EndpointT", bound=Callable[..., Any])
|
|
|
19
21
|
logger = get_logger(__name__)
|
|
20
22
|
|
|
21
23
|
|
|
24
|
+
async def _call_any_fn(fn, *args, **kwargs):
|
|
25
|
+
if inspect.iscoroutinefunction(fn):
|
|
26
|
+
return await fn(*args, **kwargs)
|
|
27
|
+
else:
|
|
28
|
+
return fn(*args, **kwargs)
|
|
29
|
+
|
|
30
|
+
|
|
22
31
|
def wrap_app(cls: type[App], **kwargs) -> fal.api.IsolatedFunction:
|
|
23
32
|
add_serialization_listeners_for(cls)
|
|
24
33
|
|
|
@@ -44,25 +53,19 @@ def wrap_app(cls: type[App], **kwargs) -> fal.api.IsolatedFunction:
|
|
|
44
53
|
**cls.host_kwargs,
|
|
45
54
|
**kwargs,
|
|
46
55
|
metadata=metadata,
|
|
47
|
-
|
|
56
|
+
exposed_port=8080,
|
|
57
|
+
serve=False,
|
|
48
58
|
)
|
|
49
59
|
fn = wrapper(initialize_and_serve)
|
|
60
|
+
fn.options.add_requirements(fal.api.SERVE_REQUIREMENTS)
|
|
50
61
|
if realtime_app:
|
|
51
62
|
fn.options.add_requirements(REALTIME_APP_REQUIREMENTS)
|
|
52
|
-
return fn.on(
|
|
53
|
-
serve=False,
|
|
54
|
-
exposed_port=8080,
|
|
55
|
-
)
|
|
56
63
|
|
|
57
|
-
|
|
58
|
-
@mainify
|
|
59
|
-
class RouteSignature(NamedTuple):
|
|
60
|
-
path: str
|
|
61
|
-
is_websocket: bool = False
|
|
64
|
+
return fn
|
|
62
65
|
|
|
63
66
|
|
|
64
67
|
@mainify
|
|
65
|
-
class App:
|
|
68
|
+
class App(fal.api.BaseServable):
|
|
66
69
|
requirements: ClassVar[list[str]] = []
|
|
67
70
|
machine_type: ClassVar[str] = "S"
|
|
68
71
|
host_kwargs: ClassVar[dict[str, Any]] = {}
|
|
@@ -83,19 +86,6 @@ class App:
|
|
|
83
86
|
"Running apps through SDK is not implemented yet."
|
|
84
87
|
)
|
|
85
88
|
|
|
86
|
-
def setup(self):
|
|
87
|
-
"""Setup the application before serving."""
|
|
88
|
-
|
|
89
|
-
def provide_hints(self) -> list[str]:
|
|
90
|
-
"""Provide hints for routing the application."""
|
|
91
|
-
raise NotImplementedError
|
|
92
|
-
|
|
93
|
-
def serve(self) -> None:
|
|
94
|
-
import uvicorn
|
|
95
|
-
|
|
96
|
-
app = self._build_app()
|
|
97
|
-
uvicorn.run(app, host="0.0.0.0", port=8080)
|
|
98
|
-
|
|
99
89
|
def collect_routes(self) -> dict[RouteSignature, Callable[..., Any]]:
|
|
100
90
|
return {
|
|
101
91
|
signature: endpoint
|
|
@@ -103,22 +93,23 @@ class App:
|
|
|
103
93
|
if (signature := getattr(endpoint, "route_signature", None))
|
|
104
94
|
}
|
|
105
95
|
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
96
|
+
@asynccontextmanager
|
|
97
|
+
async def lifespan(self, app: FastAPI):
|
|
98
|
+
await _call_any_fn(self.setup)
|
|
99
|
+
try:
|
|
100
|
+
yield
|
|
101
|
+
finally:
|
|
102
|
+
await _call_any_fn(self.teardown)
|
|
109
103
|
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
self.setup()
|
|
113
|
-
try:
|
|
114
|
-
yield
|
|
115
|
-
finally:
|
|
116
|
-
self.teardown()
|
|
104
|
+
def setup(self):
|
|
105
|
+
"""Setup the application before serving."""
|
|
117
106
|
|
|
118
|
-
|
|
107
|
+
def teardown(self):
|
|
108
|
+
"""Teardown the application after serving."""
|
|
119
109
|
|
|
120
|
-
|
|
121
|
-
|
|
110
|
+
def _add_extra_middlewares(self, app: FastAPI):
|
|
111
|
+
@app.middleware("http")
|
|
112
|
+
async def provide_hints_headers(request, call_next):
|
|
122
113
|
response = await call_next(request)
|
|
123
114
|
try:
|
|
124
115
|
response.headers["X-Fal-Runner-Hints"] = ",".join(self.provide_hints())
|
|
@@ -126,57 +117,18 @@ class App:
|
|
|
126
117
|
# This lets us differentiate between apps that don't provide hints
|
|
127
118
|
# and apps that provide empty hints.
|
|
128
119
|
pass
|
|
129
|
-
except Exception
|
|
120
|
+
except Exception:
|
|
130
121
|
from fastapi.logger import logger
|
|
131
122
|
|
|
132
123
|
logger.exception(
|
|
133
124
|
"Failed to provide hints for %s",
|
|
134
125
|
self.__class__.__name__,
|
|
135
|
-
exc_info=exc,
|
|
136
126
|
)
|
|
137
127
|
return response
|
|
138
128
|
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
allow_headers=("*"),
|
|
143
|
-
allow_methods=("*"),
|
|
144
|
-
allow_origins=("*"),
|
|
145
|
-
)
|
|
146
|
-
|
|
147
|
-
routes = self.collect_routes()
|
|
148
|
-
if not routes:
|
|
149
|
-
raise ValueError("An application must have at least one route!")
|
|
150
|
-
|
|
151
|
-
for signature, endpoint in routes.items():
|
|
152
|
-
if signature.is_websocket:
|
|
153
|
-
_app.add_api_websocket_route(
|
|
154
|
-
signature.path,
|
|
155
|
-
endpoint,
|
|
156
|
-
name=endpoint.__name__,
|
|
157
|
-
)
|
|
158
|
-
else:
|
|
159
|
-
_app.add_api_route(
|
|
160
|
-
signature.path,
|
|
161
|
-
endpoint,
|
|
162
|
-
name=endpoint.__name__,
|
|
163
|
-
methods=["POST"],
|
|
164
|
-
)
|
|
165
|
-
|
|
166
|
-
return _app
|
|
167
|
-
|
|
168
|
-
def openapi(self) -> dict[str, Any]:
|
|
169
|
-
"""
|
|
170
|
-
Build the OpenAPI specification for the served function.
|
|
171
|
-
Attach needed metadata for a better integration to fal.
|
|
172
|
-
"""
|
|
173
|
-
app = self._build_app()
|
|
174
|
-
spec = app.openapi()
|
|
175
|
-
_mark_order_openapi(spec)
|
|
176
|
-
return spec
|
|
177
|
-
|
|
178
|
-
def teardown(self):
|
|
179
|
-
"""Teardown the application after serving."""
|
|
129
|
+
def provide_hints(self) -> list[str]:
|
|
130
|
+
"""Provide hints for routing the application."""
|
|
131
|
+
raise NotImplementedError
|
|
180
132
|
|
|
181
133
|
|
|
182
134
|
@mainify
|
|
@@ -199,10 +151,7 @@ def endpoint(
|
|
|
199
151
|
|
|
200
152
|
def _fal_websocket_template(
|
|
201
153
|
func: EndpointT,
|
|
202
|
-
|
|
203
|
-
session_timeout: float | None = None,
|
|
204
|
-
input_modal: Any | None = None,
|
|
205
|
-
max_batch_size: int = 1,
|
|
154
|
+
route_signature: RouteSignature,
|
|
206
155
|
) -> EndpointT:
|
|
207
156
|
# A template for fal's realtime websocket endpoints to basically
|
|
208
157
|
# be a boilerplate for the user to fill in their inference function
|
|
@@ -220,14 +169,14 @@ def _fal_websocket_template(
|
|
|
220
169
|
try:
|
|
221
170
|
raw_input = await asyncio.wait_for(
|
|
222
171
|
websocket.receive_bytes(),
|
|
223
|
-
timeout=session_timeout,
|
|
172
|
+
timeout=route_signature.session_timeout,
|
|
224
173
|
)
|
|
225
174
|
except asyncio.TimeoutError:
|
|
226
175
|
return
|
|
227
176
|
|
|
228
177
|
input = msgpack.unpackb(raw_input, raw=False)
|
|
229
|
-
if input_modal:
|
|
230
|
-
input = input_modal(**input)
|
|
178
|
+
if route_signature.input_modal:
|
|
179
|
+
input = route_signature.input_modal(**input)
|
|
231
180
|
|
|
232
181
|
queue.append(input)
|
|
233
182
|
|
|
@@ -237,10 +186,18 @@ def _fal_websocket_template(
|
|
|
237
186
|
websocket: WebSocket,
|
|
238
187
|
) -> None:
|
|
239
188
|
loop = asyncio.get_event_loop()
|
|
240
|
-
|
|
189
|
+
max_allowed_buffering = route_signature.buffering or 1
|
|
190
|
+
outgoing_messages: asyncio.Queue[bytes] = asyncio.Queue(
|
|
191
|
+
maxsize=max_allowed_buffering * 2 # x2 for outgoing timings
|
|
192
|
+
)
|
|
241
193
|
|
|
242
194
|
async def emit(message):
|
|
243
|
-
|
|
195
|
+
if isinstance(message, bytes):
|
|
196
|
+
await websocket.send_bytes(message)
|
|
197
|
+
elif isinstance(message, str):
|
|
198
|
+
await websocket.send_text(message)
|
|
199
|
+
else:
|
|
200
|
+
raise TypeError(f"Can't send message of type {type(message)}")
|
|
244
201
|
|
|
245
202
|
async def background_emitter():
|
|
246
203
|
while True:
|
|
@@ -266,7 +223,7 @@ def _fal_websocket_template(
|
|
|
266
223
|
return None # End of input
|
|
267
224
|
|
|
268
225
|
batch = [input]
|
|
269
|
-
while queue and len(batch) < max_batch_size:
|
|
226
|
+
while queue and len(batch) < route_signature.max_batch_size:
|
|
270
227
|
next_input = queue.popleft()
|
|
271
228
|
if hasattr(input, "can_batch") and not input.can_batch(
|
|
272
229
|
next_input, len(batch)
|
|
@@ -275,7 +232,9 @@ def _fal_websocket_template(
|
|
|
275
232
|
break
|
|
276
233
|
batch.append(next_input)
|
|
277
234
|
|
|
235
|
+
t0 = loop.time()
|
|
278
236
|
output = await loop.run_in_executor(None, func, self, *batch) # type: ignore
|
|
237
|
+
total_time = loop.time() - t0
|
|
279
238
|
if not isinstance(output, dict):
|
|
280
239
|
# Handle pydantic output modal
|
|
281
240
|
if hasattr(output, "dict"):
|
|
@@ -285,18 +244,30 @@ def _fal_websocket_template(
|
|
|
285
244
|
f"Expected a dict or pydantic model as output, got {type(output)}"
|
|
286
245
|
)
|
|
287
246
|
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
247
|
+
messages = [
|
|
248
|
+
msgpack.packb(output, use_bin_type=True),
|
|
249
|
+
]
|
|
250
|
+
if route_signature.emit_timings:
|
|
251
|
+
# We emit x-fal messages in JSON, no matter what the input/output format is.
|
|
252
|
+
timings = {
|
|
253
|
+
"type": "x-fal-message",
|
|
254
|
+
"action": "timings",
|
|
255
|
+
"timing": total_time,
|
|
256
|
+
}
|
|
257
|
+
messages.append(json.dumps(timings, separators=(",", ":")))
|
|
258
|
+
|
|
259
|
+
for message in messages:
|
|
260
|
+
try:
|
|
261
|
+
outgoing_messages.put_nowait(message)
|
|
262
|
+
except asyncio.QueueFull:
|
|
263
|
+
await emit(message)
|
|
293
264
|
|
|
294
265
|
async def websocket_template(self, websocket: WebSocket) -> None:
|
|
295
266
|
import asyncio
|
|
296
267
|
|
|
297
268
|
await websocket.accept()
|
|
298
269
|
|
|
299
|
-
queue: deque[Any] = deque(maxlen=buffering)
|
|
270
|
+
queue: deque[Any] = deque(maxlen=route_signature.buffering)
|
|
300
271
|
input_task = asyncio.create_task(mirror_input(queue, websocket))
|
|
301
272
|
input_task.add_done_callback(lambda _: queue.append(None))
|
|
302
273
|
output_task = asyncio.create_task(mirror_output(self, queue, websocket))
|
|
@@ -314,7 +285,9 @@ def _fal_websocket_template(
|
|
|
314
285
|
# so we can just close the connection after the
|
|
315
286
|
# processing of the last input is done.
|
|
316
287
|
input_task.result()
|
|
317
|
-
await asyncio.wait_for(
|
|
288
|
+
await asyncio.wait_for(
|
|
289
|
+
output_task, timeout=route_signature.session_timeout
|
|
290
|
+
)
|
|
318
291
|
else:
|
|
319
292
|
assert output_task.done()
|
|
320
293
|
|
|
@@ -362,7 +335,8 @@ def _fal_websocket_template(
|
|
|
362
335
|
"websocket": WebSocket,
|
|
363
336
|
"return": None,
|
|
364
337
|
}
|
|
365
|
-
|
|
338
|
+
websocket_template.route_signature = route_signature # type: ignore
|
|
339
|
+
websocket_template.original_func = func # type: ignore
|
|
366
340
|
return typing.cast(EndpointT, websocket_template)
|
|
367
341
|
|
|
368
342
|
|
|
@@ -395,44 +369,17 @@ def realtime(
|
|
|
395
369
|
else:
|
|
396
370
|
input_modal = None
|
|
397
371
|
|
|
398
|
-
|
|
399
|
-
|
|
372
|
+
route_signature = RouteSignature(
|
|
373
|
+
path=path,
|
|
374
|
+
is_websocket=True,
|
|
375
|
+
input_modal=input_modal,
|
|
400
376
|
buffering=buffering,
|
|
401
377
|
session_timeout=session_timeout,
|
|
402
|
-
input_modal=input_modal,
|
|
403
378
|
max_batch_size=max_batch_size,
|
|
404
379
|
)
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
380
|
+
return _fal_websocket_template(
|
|
381
|
+
original_func,
|
|
382
|
+
route_signature,
|
|
383
|
+
)
|
|
408
384
|
|
|
409
385
|
return marker_fn
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
def _mark_order_openapi(spec: dict[str, Any]):
|
|
413
|
-
"""
|
|
414
|
-
Add x-fal-order-* keys to the OpenAPI specification to help the rendering of UI.
|
|
415
|
-
|
|
416
|
-
NOTE: We rely on the fact that fastapi and Python dicts keep the order of properties.
|
|
417
|
-
"""
|
|
418
|
-
|
|
419
|
-
def mark_order(obj: dict[str, Any], key: str):
|
|
420
|
-
obj[f"x-fal-order-{key}"] = list(obj[key].keys())
|
|
421
|
-
|
|
422
|
-
mark_order(spec, "paths")
|
|
423
|
-
|
|
424
|
-
def order_schema_object(schema: dict[str, Any]):
|
|
425
|
-
"""
|
|
426
|
-
Mark the order of properties in the schema object.
|
|
427
|
-
They can have 'allOf', 'properties' or '$ref' key.
|
|
428
|
-
"""
|
|
429
|
-
if "allOf" in schema:
|
|
430
|
-
for sub_schema in schema["allOf"]:
|
|
431
|
-
order_schema_object(sub_schema)
|
|
432
|
-
if "properties" in schema:
|
|
433
|
-
mark_order(schema, "properties")
|
|
434
|
-
|
|
435
|
-
for key in spec["components"].get("schemas") or {}:
|
|
436
|
-
order_schema_object(spec["components"]["schemas"][key])
|
|
437
|
-
|
|
438
|
-
return spec
|
fal/apps.py
CHANGED
|
@@ -63,7 +63,10 @@ class RequestHandle:
|
|
|
63
63
|
_creds: Credentials = field(default_factory=get_default_credentials, repr=False)
|
|
64
64
|
|
|
65
65
|
def __post_init__(self):
|
|
66
|
-
|
|
66
|
+
app_id = _backwards_compatible_app_id(self.app_id)
|
|
67
|
+
# drop any extra path components
|
|
68
|
+
user_id, app_name = app_id.split("/")[:2]
|
|
69
|
+
self.app_id = f"{user_id}/{app_name}"
|
|
67
70
|
|
|
68
71
|
def status(self, *, logs: bool = False) -> _Status:
|
|
69
72
|
"""Check the status of an async inference request."""
|
|
@@ -116,7 +119,16 @@ class RequestHandle:
|
|
|
116
119
|
+ f"/requests/{self.request_id}/"
|
|
117
120
|
)
|
|
118
121
|
response = _HTTP_CLIENT.get(url, headers=self._creds.to_headers())
|
|
119
|
-
|
|
122
|
+
try:
|
|
123
|
+
response.raise_for_status()
|
|
124
|
+
except httpx.HTTPStatusError as e:
|
|
125
|
+
if response.headers["Content-Type"] != "application/json":
|
|
126
|
+
raise
|
|
127
|
+
raise httpx.HTTPStatusError(
|
|
128
|
+
f"{response.status_code}: {response.text}",
|
|
129
|
+
request=e.request,
|
|
130
|
+
response=e.response,
|
|
131
|
+
) from e
|
|
120
132
|
|
|
121
133
|
data = response.json()
|
|
122
134
|
return data
|
|
@@ -134,20 +146,23 @@ class RequestHandle:
|
|
|
134
146
|
_HTTP_CLIENT = httpx.Client(headers={"User-Agent": "Fal/Python"})
|
|
135
147
|
|
|
136
148
|
|
|
137
|
-
def run(app_id: str, arguments: dict[str, Any], *, path: str = "
|
|
149
|
+
def run(app_id: str, arguments: dict[str, Any], *, path: str = "") -> dict[str, Any]:
|
|
138
150
|
"""Run an inference task on a Fal app and return the result."""
|
|
139
151
|
|
|
140
152
|
handle = submit(app_id, arguments, path=path)
|
|
141
153
|
return handle.get()
|
|
142
154
|
|
|
143
155
|
|
|
144
|
-
def submit(app_id: str, arguments: dict[str, Any], *, path: str = "
|
|
156
|
+
def submit(app_id: str, arguments: dict[str, Any], *, path: str = "") -> RequestHandle:
|
|
145
157
|
"""Submit an async inference task to the app. Returns a request handle
|
|
146
158
|
which can be used to check the status of the request and retrieve the
|
|
147
159
|
result."""
|
|
148
160
|
|
|
149
161
|
app_id = _backwards_compatible_app_id(app_id)
|
|
150
|
-
url = _QUEUE_URL_FORMAT.format(app_id=app_id)
|
|
162
|
+
url = _QUEUE_URL_FORMAT.format(app_id=app_id)
|
|
163
|
+
if path:
|
|
164
|
+
url += "/" + path.removeprefix("/")
|
|
165
|
+
|
|
151
166
|
creds = get_default_credentials()
|
|
152
167
|
|
|
153
168
|
response = _HTTP_CLIENT.post(
|
|
@@ -206,7 +221,10 @@ def _connect(app_id: str, *, path: str = "/realtime") -> Iterator[_RealtimeConne
|
|
|
206
221
|
from websockets.sync import client
|
|
207
222
|
|
|
208
223
|
app_id = _backwards_compatible_app_id(app_id)
|
|
209
|
-
url = _REALTIME_URL_FORMAT.format(app_id=app_id)
|
|
224
|
+
url = _REALTIME_URL_FORMAT.format(app_id=app_id)
|
|
225
|
+
if path:
|
|
226
|
+
url += "/" + path.removeprefix("/")
|
|
227
|
+
|
|
210
228
|
creds = get_default_credentials()
|
|
211
229
|
|
|
212
230
|
with client.connect(
|