fal 0.12.1__py3-none-any.whl → 0.12.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of fal might be problematic. Click here for more details.
- fal/__init__.py +12 -3
- fal/_serialization.py +18 -0
- fal/api.py +140 -59
- fal/app.py +309 -86
- fal/apps.py +92 -8
- fal/auth/__init__.py +20 -1
- fal/auth/auth0.py +32 -22
- fal/cli.py +34 -52
- fal/env.py +0 -4
- fal/exceptions/handlers.py +3 -2
- fal/flags.py +5 -0
- fal/logging/__init__.py +0 -2
- fal/logging/trace.py +8 -1
- fal/logging/user.py +2 -1
- fal/rest_client.py +2 -2
- fal/sdk.py +46 -31
- fal/sync.py +3 -3
- fal/toolkit/__init__.py +18 -1
- fal/toolkit/file/file.py +98 -11
- fal/toolkit/file/providers/fal.py +43 -2
- fal/toolkit/file/types.py +1 -1
- fal/toolkit/image/image.py +26 -4
- fal/toolkit/optimize.py +50 -0
- fal/toolkit/utils/download_utils.py +59 -13
- {fal-0.12.1.dist-info → fal-0.12.3.dist-info}/METADATA +7 -7
- fal-0.12.3.dist-info/RECORD +66 -0
- openapi_fal_rest/models/__init__.py +2 -70
- openapi_fal_rest/models/customer_details.py +26 -0
- openapi_fal_rest/models/lock_reason.py +16 -0
- fal/logging/datadog.py +0 -77
- fal-0.12.1.dist-info/RECORD +0 -147
- openapi_fal_rest/api/admin/get_invoice_users.py +0 -142
- openapi_fal_rest/api/admin/get_usage_per_user.py +0 -199
- openapi_fal_rest/api/admin/handle_user_lock.py +0 -191
- openapi_fal_rest/api/admin/set_billing_type.py +0 -186
- openapi_fal_rest/api/applications/get_status_applications_app_user_id_app_alias_or_id_status_get.py +0 -179
- openapi_fal_rest/api/billing/delete_payment_method.py +0 -162
- openapi_fal_rest/api/billing/get_checkout_page.py +0 -198
- openapi_fal_rest/api/billing/get_setup_intent_key.py +0 -141
- openapi_fal_rest/api/billing/get_user_invoices.py +0 -152
- openapi_fal_rest/api/billing/get_user_payment_methods.py +0 -152
- openapi_fal_rest/api/billing/get_user_price.py +0 -186
- openapi_fal_rest/api/billing/get_user_spending.py +0 -192
- openapi_fal_rest/api/billing/handle_stripe_webhook.py +0 -173
- openapi_fal_rest/api/billing/upcoming_invoice.py +0 -143
- openapi_fal_rest/api/billing/update_customer_budget.py +0 -183
- openapi_fal_rest/api/files/delete.py +0 -162
- openapi_fal_rest/api/files/download.py +0 -162
- openapi_fal_rest/api/files/file_exists.py +0 -183
- openapi_fal_rest/api/files/list_directory.py +0 -173
- openapi_fal_rest/api/files/list_root.py +0 -152
- openapi_fal_rest/api/files/upload_from_url.py +0 -179
- openapi_fal_rest/api/health/__init__.py +0 -0
- openapi_fal_rest/api/health/check.py +0 -136
- openapi_fal_rest/api/keys/__init__.py +0 -0
- openapi_fal_rest/api/keys/create_key.py +0 -188
- openapi_fal_rest/api/keys/delete_key.py +0 -162
- openapi_fal_rest/api/keys/list_keys.py +0 -152
- openapi_fal_rest/api/logs/__init__.py +0 -0
- openapi_fal_rest/api/logs/list_since.py +0 -224
- openapi_fal_rest/api/requests/__init__.py +0 -0
- openapi_fal_rest/api/requests/requests.py +0 -247
- openapi_fal_rest/api/storage/__init__.py +0 -0
- openapi_fal_rest/api/storage/get_file_link.py +0 -200
- openapi_fal_rest/api/storage/initiate_upload.py +0 -172
- openapi_fal_rest/api/storage/upload_file.py +0 -172
- openapi_fal_rest/api/tokens/__init__.py +0 -0
- openapi_fal_rest/api/tokens/create_token.py +0 -166
- openapi_fal_rest/api/usage/__init__.py +0 -0
- openapi_fal_rest/api/usage/get_custom_usage_per_machine.py +0 -203
- openapi_fal_rest/api/usage/get_gateway_request_stats.py +0 -247
- openapi_fal_rest/api/usage/get_gateway_request_stats_by_time.py +0 -236
- openapi_fal_rest/api/usage/get_gateway_stats_for_yesterday.py +0 -152
- openapi_fal_rest/api/usage/get_shared_usage_per_app.py +0 -203
- openapi_fal_rest/api/usage/get_usage_records.py +0 -253
- openapi_fal_rest/api/usage/per_machine_usage.py +0 -218
- openapi_fal_rest/api/usage/per_machine_usage_details.py +0 -173
- openapi_fal_rest/api/users/__init__.py +0 -0
- openapi_fal_rest/api/users/handle_user_registration.py +0 -228
- openapi_fal_rest/models/billing_type.py +0 -9
- openapi_fal_rest/models/body_create_token.py +0 -68
- openapi_fal_rest/models/body_upload_file.py +0 -75
- openapi_fal_rest/models/file_spec.py +0 -110
- openapi_fal_rest/models/gateway_stats_by_time.py +0 -115
- openapi_fal_rest/models/gateway_usage_stats.py +0 -147
- openapi_fal_rest/models/get_gateway_request_stats_by_time_response_get_gateway_request_stats_by_time.py +0 -70
- openapi_fal_rest/models/grouped_usage_detail.py +0 -85
- openapi_fal_rest/models/handle_stripe_webhook_response_handle_stripe_webhook.py +0 -43
- openapi_fal_rest/models/initiate_upload_info.py +0 -64
- openapi_fal_rest/models/invoice.py +0 -129
- openapi_fal_rest/models/invoice_item.py +0 -85
- openapi_fal_rest/models/key_scope.py +0 -9
- openapi_fal_rest/models/log_entry.py +0 -104
- openapi_fal_rest/models/log_entry_labels.py +0 -43
- openapi_fal_rest/models/new_user_key.py +0 -64
- openapi_fal_rest/models/payment_method.py +0 -96
- openapi_fal_rest/models/per_app_usage_detail.py +0 -88
- openapi_fal_rest/models/persisted_usage_record.py +0 -118
- openapi_fal_rest/models/persisted_usage_record_meta.py +0 -43
- openapi_fal_rest/models/presigned_upload_url.py +0 -64
- openapi_fal_rest/models/request_io.py +0 -112
- openapi_fal_rest/models/request_io_json_input.py +0 -43
- openapi_fal_rest/models/request_io_json_output.py +0 -43
- openapi_fal_rest/models/run_type.py +0 -9
- openapi_fal_rest/models/stats_timeframe.py +0 -12
- openapi_fal_rest/models/status.py +0 -82
- openapi_fal_rest/models/status_health.py +0 -10
- openapi_fal_rest/models/uploaded_file_result.py +0 -64
- openapi_fal_rest/models/url_file_upload.py +0 -57
- openapi_fal_rest/models/usage_per_machine_type.py +0 -115
- openapi_fal_rest/models/usage_per_user.py +0 -71
- openapi_fal_rest/models/usage_run_detail.py +0 -73
- openapi_fal_rest/models/user_key_info.py +0 -84
- /openapi_fal_rest/api/admin/__init__.py → /fal/py.typed +0 -0
- {fal-0.12.1.dist-info → fal-0.12.3.dist-info}/WHEEL +0 -0
- {fal-0.12.1.dist-info → fal-0.12.3.dist-info}/entry_points.txt +0 -0
fal/app.py
CHANGED
|
@@ -1,28 +1,50 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import inspect
|
|
4
|
+
import json
|
|
4
5
|
import os
|
|
5
|
-
import
|
|
6
|
-
from
|
|
6
|
+
import typing
|
|
7
|
+
from contextlib import asynccontextmanager
|
|
8
|
+
from typing import Any, Callable, ClassVar, TypeVar
|
|
9
|
+
|
|
7
10
|
from fastapi import FastAPI
|
|
8
|
-
|
|
11
|
+
|
|
12
|
+
import fal.api
|
|
13
|
+
from fal._serialization import add_serialization_listeners_for
|
|
14
|
+
from fal.api import RouteSignature
|
|
9
15
|
from fal.logging import get_logger
|
|
16
|
+
from fal.toolkit import mainify
|
|
17
|
+
|
|
18
|
+
REALTIME_APP_REQUIREMENTS = ["websockets", "msgpack"]
|
|
10
19
|
|
|
11
20
|
EndpointT = TypeVar("EndpointT", bound=Callable[..., Any])
|
|
12
21
|
logger = get_logger(__name__)
|
|
13
22
|
|
|
14
23
|
|
|
24
|
+
async def _call_any_fn(fn, *args, **kwargs):
|
|
25
|
+
if inspect.iscoroutinefunction(fn):
|
|
26
|
+
return await fn(*args, **kwargs)
|
|
27
|
+
else:
|
|
28
|
+
return fn(*args, **kwargs)
|
|
29
|
+
|
|
30
|
+
|
|
15
31
|
def wrap_app(cls: type[App], **kwargs) -> fal.api.IsolatedFunction:
|
|
32
|
+
add_serialization_listeners_for(cls)
|
|
33
|
+
|
|
16
34
|
def initialize_and_serve():
|
|
17
35
|
app = cls()
|
|
18
36
|
app.serve()
|
|
19
37
|
|
|
38
|
+
metadata = {}
|
|
20
39
|
try:
|
|
21
40
|
app = cls(_allow_init=True)
|
|
22
|
-
metadata = app.openapi()
|
|
41
|
+
metadata["openapi"] = app.openapi()
|
|
23
42
|
except Exception as exc:
|
|
24
43
|
logger.warning("Failed to build OpenAPI specification for %s", cls.__name__)
|
|
25
|
-
|
|
44
|
+
realtime_app = False
|
|
45
|
+
else:
|
|
46
|
+
routes = app.collect_routes()
|
|
47
|
+
realtime_app = any(route.is_websocket for route in routes)
|
|
26
48
|
|
|
27
49
|
wrapper = fal.api.function(
|
|
28
50
|
"virtualenv",
|
|
@@ -31,27 +53,26 @@ def wrap_app(cls: type[App], **kwargs) -> fal.api.IsolatedFunction:
|
|
|
31
53
|
**cls.host_kwargs,
|
|
32
54
|
**kwargs,
|
|
33
55
|
metadata=metadata,
|
|
34
|
-
serve=True,
|
|
35
|
-
)
|
|
36
|
-
return wrapper(initialize_and_serve).on(
|
|
37
|
-
serve=False,
|
|
38
56
|
exposed_port=8080,
|
|
57
|
+
serve=False,
|
|
39
58
|
)
|
|
59
|
+
fn = wrapper(initialize_and_serve)
|
|
60
|
+
fn.options.add_requirements(fal.api.SERVE_REQUIREMENTS)
|
|
61
|
+
if realtime_app:
|
|
62
|
+
fn.options.add_requirements(REALTIME_APP_REQUIREMENTS)
|
|
40
63
|
|
|
41
|
-
|
|
42
|
-
@mainify
|
|
43
|
-
class RouteSignature(NamedTuple):
|
|
44
|
-
path: str
|
|
64
|
+
return fn
|
|
45
65
|
|
|
46
66
|
|
|
47
67
|
@mainify
|
|
48
|
-
class App:
|
|
68
|
+
class App(fal.api.BaseServable):
|
|
49
69
|
requirements: ClassVar[list[str]] = []
|
|
50
70
|
machine_type: ClassVar[str] = "S"
|
|
51
71
|
host_kwargs: ClassVar[dict[str, Any]] = {}
|
|
52
72
|
|
|
53
73
|
def __init_subclass__(cls, **kwargs):
|
|
54
|
-
cls
|
|
74
|
+
parent_settings = getattr(cls, "host_kwargs", {})
|
|
75
|
+
cls.host_kwargs = {**parent_settings, **kwargs}
|
|
55
76
|
|
|
56
77
|
if cls.__init__ is not App.__init__:
|
|
57
78
|
raise ValueError(
|
|
@@ -65,98 +86,300 @@ class App:
|
|
|
65
86
|
"Running apps through SDK is not implemented yet."
|
|
66
87
|
)
|
|
67
88
|
|
|
68
|
-
def
|
|
69
|
-
|
|
89
|
+
def collect_routes(self) -> dict[RouteSignature, Callable[..., Any]]:
|
|
90
|
+
return {
|
|
91
|
+
signature: endpoint
|
|
92
|
+
for _, endpoint in inspect.getmembers(self, inspect.ismethod)
|
|
93
|
+
if (signature := getattr(endpoint, "route_signature", None))
|
|
94
|
+
}
|
|
70
95
|
|
|
71
|
-
|
|
72
|
-
|
|
96
|
+
@asynccontextmanager
|
|
97
|
+
async def lifespan(self, app: FastAPI):
|
|
98
|
+
await _call_any_fn(self.setup)
|
|
99
|
+
try:
|
|
100
|
+
yield
|
|
101
|
+
finally:
|
|
102
|
+
await _call_any_fn(self.teardown)
|
|
73
103
|
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
uvicorn.run(app, host="0.0.0.0", port=8080)
|
|
104
|
+
def setup(self):
|
|
105
|
+
"""Setup the application before serving."""
|
|
77
106
|
|
|
78
|
-
def
|
|
79
|
-
|
|
80
|
-
|
|
107
|
+
def teardown(self):
|
|
108
|
+
"""Teardown the application after serving."""
|
|
109
|
+
|
|
110
|
+
def _add_extra_middlewares(self, app: FastAPI):
|
|
111
|
+
@app.middleware("http")
|
|
112
|
+
async def provide_hints_headers(request, call_next):
|
|
113
|
+
response = await call_next(request)
|
|
114
|
+
try:
|
|
115
|
+
response.headers["X-Fal-Runner-Hints"] = ",".join(self.provide_hints())
|
|
116
|
+
except NotImplementedError:
|
|
117
|
+
# This lets us differentiate between apps that don't provide hints
|
|
118
|
+
# and apps that provide empty hints.
|
|
119
|
+
pass
|
|
120
|
+
except Exception:
|
|
121
|
+
from fastapi.logger import logger
|
|
122
|
+
|
|
123
|
+
logger.exception(
|
|
124
|
+
"Failed to provide hints for %s",
|
|
125
|
+
self.__class__.__name__,
|
|
126
|
+
)
|
|
127
|
+
return response
|
|
128
|
+
|
|
129
|
+
def provide_hints(self) -> list[str]:
|
|
130
|
+
"""Provide hints for routing the application."""
|
|
131
|
+
raise NotImplementedError
|
|
81
132
|
|
|
82
|
-
_app = FastAPI()
|
|
83
133
|
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
allow_origins=("*"),
|
|
90
|
-
)
|
|
134
|
+
@mainify
|
|
135
|
+
def endpoint(
|
|
136
|
+
path: str, *, is_websocket: bool = False
|
|
137
|
+
) -> Callable[[EndpointT], EndpointT]:
|
|
138
|
+
"""Designate the decorated function as an application endpoint."""
|
|
91
139
|
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
}
|
|
97
|
-
if not routes:
|
|
98
|
-
raise ValueError("An application must have at least one route!")
|
|
99
|
-
|
|
100
|
-
for signature, endpoint in routes.items():
|
|
101
|
-
_app.add_api_route(
|
|
102
|
-
signature.path,
|
|
103
|
-
endpoint,
|
|
104
|
-
name=endpoint.__name__,
|
|
105
|
-
methods=["POST"],
|
|
140
|
+
def marker_fn(callable: EndpointT) -> EndpointT:
|
|
141
|
+
if hasattr(callable, "route_signature"):
|
|
142
|
+
raise ValueError(
|
|
143
|
+
f"Can't set multiple routes for the same function: {callable.__name__}"
|
|
106
144
|
)
|
|
107
145
|
|
|
108
|
-
|
|
146
|
+
callable.route_signature = RouteSignature(path=path, is_websocket=is_websocket) # type: ignore
|
|
147
|
+
return callable
|
|
109
148
|
|
|
110
|
-
|
|
111
|
-
"""
|
|
112
|
-
Build the OpenAPI specification for the served function.
|
|
113
|
-
Attach needed metadata for a better integration to fal.
|
|
114
|
-
"""
|
|
115
|
-
app = self._build_app()
|
|
116
|
-
spec = app.openapi()
|
|
117
|
-
self._mark_order_openapi(spec)
|
|
118
|
-
return spec
|
|
149
|
+
return marker_fn
|
|
119
150
|
|
|
120
|
-
def _mark_order_openapi(self, spec: dict[str, Any]):
|
|
121
|
-
"""
|
|
122
|
-
Add x-fal-order-* keys to the OpenAPI specification to help the rendering of UI.
|
|
123
151
|
|
|
124
|
-
|
|
125
|
-
|
|
152
|
+
def _fal_websocket_template(
|
|
153
|
+
func: EndpointT,
|
|
154
|
+
route_signature: RouteSignature,
|
|
155
|
+
) -> EndpointT:
|
|
156
|
+
# A template for fal's realtime websocket endpoints to basically
|
|
157
|
+
# be a boilerplate for the user to fill in their inference function
|
|
158
|
+
# and start using it.
|
|
159
|
+
|
|
160
|
+
import asyncio
|
|
161
|
+
from collections import deque
|
|
162
|
+
from contextlib import suppress
|
|
163
|
+
|
|
164
|
+
import msgpack
|
|
165
|
+
from fastapi import WebSocket, WebSocketDisconnect
|
|
166
|
+
|
|
167
|
+
async def mirror_input(queue: deque[Any], websocket: WebSocket) -> None:
|
|
168
|
+
while True:
|
|
169
|
+
try:
|
|
170
|
+
raw_input = await asyncio.wait_for(
|
|
171
|
+
websocket.receive_bytes(),
|
|
172
|
+
timeout=route_signature.session_timeout,
|
|
173
|
+
)
|
|
174
|
+
except asyncio.TimeoutError:
|
|
175
|
+
return
|
|
176
|
+
|
|
177
|
+
input = msgpack.unpackb(raw_input, raw=False)
|
|
178
|
+
if route_signature.input_modal:
|
|
179
|
+
input = route_signature.input_modal(**input)
|
|
180
|
+
|
|
181
|
+
queue.append(input)
|
|
182
|
+
|
|
183
|
+
async def mirror_output(
|
|
184
|
+
self,
|
|
185
|
+
queue: deque[Any],
|
|
186
|
+
websocket: WebSocket,
|
|
187
|
+
) -> None:
|
|
188
|
+
loop = asyncio.get_event_loop()
|
|
189
|
+
max_allowed_buffering = route_signature.buffering or 1
|
|
190
|
+
outgoing_messages: asyncio.Queue[bytes] = asyncio.Queue(
|
|
191
|
+
maxsize=max_allowed_buffering * 2 # x2 for outgoing timings
|
|
192
|
+
)
|
|
126
193
|
|
|
127
|
-
def
|
|
128
|
-
|
|
194
|
+
async def emit(message):
|
|
195
|
+
if isinstance(message, bytes):
|
|
196
|
+
await websocket.send_bytes(message)
|
|
197
|
+
elif isinstance(message, str):
|
|
198
|
+
await websocket.send_text(message)
|
|
199
|
+
else:
|
|
200
|
+
raise TypeError(f"Can't send message of type {type(message)}")
|
|
201
|
+
|
|
202
|
+
async def background_emitter():
|
|
203
|
+
while True:
|
|
204
|
+
output = await outgoing_messages.get()
|
|
205
|
+
await emit(output)
|
|
206
|
+
outgoing_messages.task_done()
|
|
207
|
+
|
|
208
|
+
emitter = asyncio.create_task(background_emitter())
|
|
209
|
+
|
|
210
|
+
while True:
|
|
211
|
+
if not queue:
|
|
212
|
+
await asyncio.sleep(0.05)
|
|
213
|
+
continue
|
|
214
|
+
|
|
215
|
+
input = queue.popleft()
|
|
216
|
+
if input is None or emitter.done():
|
|
217
|
+
if not emitter.done():
|
|
218
|
+
await outgoing_messages.join()
|
|
219
|
+
emitter.cancel()
|
|
220
|
+
|
|
221
|
+
with suppress(asyncio.CancelledError):
|
|
222
|
+
await emitter
|
|
223
|
+
return None # End of input
|
|
224
|
+
|
|
225
|
+
batch = [input]
|
|
226
|
+
while queue and len(batch) < route_signature.max_batch_size:
|
|
227
|
+
next_input = queue.popleft()
|
|
228
|
+
if hasattr(input, "can_batch") and not input.can_batch(
|
|
229
|
+
next_input, len(batch)
|
|
230
|
+
):
|
|
231
|
+
queue.appendleft(next_input)
|
|
232
|
+
break
|
|
233
|
+
batch.append(next_input)
|
|
234
|
+
|
|
235
|
+
t0 = loop.time()
|
|
236
|
+
output = await loop.run_in_executor(None, func, self, *batch) # type: ignore
|
|
237
|
+
total_time = loop.time() - t0
|
|
238
|
+
if not isinstance(output, dict):
|
|
239
|
+
# Handle pydantic output modal
|
|
240
|
+
if hasattr(output, "dict"):
|
|
241
|
+
output = output.dict()
|
|
242
|
+
else:
|
|
243
|
+
raise TypeError(
|
|
244
|
+
f"Expected a dict or pydantic model as output, got {type(output)}"
|
|
245
|
+
)
|
|
246
|
+
|
|
247
|
+
messages = [
|
|
248
|
+
msgpack.packb(output, use_bin_type=True),
|
|
249
|
+
]
|
|
250
|
+
if route_signature.emit_timings:
|
|
251
|
+
# We emit x-fal messages in JSON, no matter what the input/output format is.
|
|
252
|
+
timings = {
|
|
253
|
+
"type": "x-fal-message",
|
|
254
|
+
"action": "timings",
|
|
255
|
+
"timing": total_time,
|
|
256
|
+
}
|
|
257
|
+
messages.append(json.dumps(timings, separators=(",", ":")))
|
|
258
|
+
|
|
259
|
+
for message in messages:
|
|
260
|
+
try:
|
|
261
|
+
outgoing_messages.put_nowait(message)
|
|
262
|
+
except asyncio.QueueFull:
|
|
263
|
+
await emit(message)
|
|
264
|
+
|
|
265
|
+
async def websocket_template(self, websocket: WebSocket) -> None:
|
|
266
|
+
import asyncio
|
|
267
|
+
|
|
268
|
+
await websocket.accept()
|
|
269
|
+
|
|
270
|
+
queue: deque[Any] = deque(maxlen=route_signature.buffering)
|
|
271
|
+
input_task = asyncio.create_task(mirror_input(queue, websocket))
|
|
272
|
+
input_task.add_done_callback(lambda _: queue.append(None))
|
|
273
|
+
output_task = asyncio.create_task(mirror_output(self, queue, websocket))
|
|
274
|
+
|
|
275
|
+
try:
|
|
276
|
+
await asyncio.wait(
|
|
277
|
+
{
|
|
278
|
+
input_task,
|
|
279
|
+
output_task,
|
|
280
|
+
},
|
|
281
|
+
return_when=asyncio.FIRST_COMPLETED,
|
|
282
|
+
)
|
|
283
|
+
if input_task.done():
|
|
284
|
+
# User didn't send any input within the timeout
|
|
285
|
+
# so we can just close the connection after the
|
|
286
|
+
# processing of the last input is done.
|
|
287
|
+
input_task.result()
|
|
288
|
+
await asyncio.wait_for(
|
|
289
|
+
output_task, timeout=route_signature.session_timeout
|
|
290
|
+
)
|
|
291
|
+
else:
|
|
292
|
+
assert output_task.done()
|
|
293
|
+
|
|
294
|
+
# The execution of the inference function failed or exitted,
|
|
295
|
+
# so just propagate the result.
|
|
296
|
+
input_task.cancel()
|
|
297
|
+
with suppress(asyncio.CancelledError):
|
|
298
|
+
await input_task
|
|
299
|
+
|
|
300
|
+
output_task.result()
|
|
301
|
+
except WebSocketDisconnect:
|
|
302
|
+
input_task.cancel()
|
|
303
|
+
output_task.cancel()
|
|
304
|
+
with suppress(asyncio.CancelledError):
|
|
305
|
+
await input_task
|
|
306
|
+
|
|
307
|
+
with suppress(asyncio.CancelledError):
|
|
308
|
+
await output_task
|
|
309
|
+
except Exception as exc:
|
|
310
|
+
import traceback
|
|
311
|
+
|
|
312
|
+
traceback.print_exc()
|
|
313
|
+
|
|
314
|
+
await websocket.send_json(
|
|
315
|
+
{
|
|
316
|
+
"type": "x-fal-error",
|
|
317
|
+
"error": "INTERNAL_ERROR",
|
|
318
|
+
"reason": str(exc),
|
|
319
|
+
}
|
|
320
|
+
)
|
|
321
|
+
else:
|
|
322
|
+
await websocket.send_json(
|
|
323
|
+
{
|
|
324
|
+
"type": "x-fal-error",
|
|
325
|
+
"error": "TIMEOUT",
|
|
326
|
+
"reason": "no inputs, reconnect when needed!",
|
|
327
|
+
}
|
|
328
|
+
)
|
|
129
329
|
|
|
130
|
-
|
|
330
|
+
await websocket.close()
|
|
131
331
|
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
mark_order(schema, "properties")
|
|
332
|
+
# Seems like templating + stringified annotations don't play well,
|
|
333
|
+
# so we have to set them manually.
|
|
334
|
+
websocket_template.__annotations__ = {
|
|
335
|
+
"websocket": WebSocket,
|
|
336
|
+
"return": None,
|
|
337
|
+
}
|
|
338
|
+
websocket_template.route_signature = route_signature # type: ignore
|
|
339
|
+
websocket_template.original_func = func # type: ignore
|
|
340
|
+
return typing.cast(EndpointT, websocket_template)
|
|
142
341
|
|
|
143
|
-
for key in spec["components"].get("schemas") or {}:
|
|
144
|
-
order_schema_object(spec["components"]["schemas"][key])
|
|
145
342
|
|
|
146
|
-
|
|
343
|
+
_SENTINEL = object()
|
|
147
344
|
|
|
148
345
|
|
|
149
346
|
@mainify
|
|
150
|
-
def
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
347
|
+
def realtime(
|
|
348
|
+
path: str,
|
|
349
|
+
*,
|
|
350
|
+
buffering: int | None = None,
|
|
351
|
+
session_timeout: float | None = None,
|
|
352
|
+
input_modal: Any | None = _SENTINEL,
|
|
353
|
+
max_batch_size: int = 1,
|
|
354
|
+
) -> Callable[[EndpointT], EndpointT]:
|
|
355
|
+
"""Designate the decorated function as a realtime application endpoint."""
|
|
356
|
+
|
|
357
|
+
def marker_fn(original_func: EndpointT) -> EndpointT:
|
|
358
|
+
nonlocal input_modal
|
|
359
|
+
|
|
360
|
+
if hasattr(original_func, "route_signature"):
|
|
155
361
|
raise ValueError(
|
|
156
|
-
f"Can't set multiple routes for the same function: {
|
|
362
|
+
f"Can't set multiple routes for the same function: {original_func.__name__}"
|
|
157
363
|
)
|
|
158
364
|
|
|
159
|
-
|
|
160
|
-
|
|
365
|
+
if input_modal is _SENTINEL:
|
|
366
|
+
type_hints = typing.get_type_hints(original_func)
|
|
367
|
+
if len(type_hints) >= 1:
|
|
368
|
+
input_modal = type_hints[list(type_hints.keys())[0]]
|
|
369
|
+
else:
|
|
370
|
+
input_modal = None
|
|
371
|
+
|
|
372
|
+
route_signature = RouteSignature(
|
|
373
|
+
path=path,
|
|
374
|
+
is_websocket=True,
|
|
375
|
+
input_modal=input_modal,
|
|
376
|
+
buffering=buffering,
|
|
377
|
+
session_timeout=session_timeout,
|
|
378
|
+
max_batch_size=max_batch_size,
|
|
379
|
+
)
|
|
380
|
+
return _fal_websocket_template(
|
|
381
|
+
original_func,
|
|
382
|
+
route_signature,
|
|
383
|
+
)
|
|
161
384
|
|
|
162
385
|
return marker_fn
|
fal/apps.py
CHANGED
|
@@ -1,14 +1,26 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import json
|
|
3
4
|
import time
|
|
5
|
+
from contextlib import contextmanager
|
|
4
6
|
from dataclasses import dataclass, field
|
|
5
7
|
from typing import Any, Iterator
|
|
6
8
|
|
|
7
9
|
import httpx
|
|
10
|
+
|
|
8
11
|
from fal import flags
|
|
9
12
|
from fal.sdk import Credentials, get_default_credentials
|
|
10
13
|
|
|
11
|
-
|
|
14
|
+
_QUEUE_URL_FORMAT = f"https://queue.{flags.FAL_RUN_HOST}/{{app_id}}"
|
|
15
|
+
_REALTIME_URL_FORMAT = f"wss://{flags.FAL_RUN_HOST}/{{app_id}}"
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def _backwards_compatible_app_id(app_id: str) -> str:
|
|
19
|
+
if "/" not in app_id:
|
|
20
|
+
# Convert the app_id to the format used in the URL.
|
|
21
|
+
return app_id.replace("-", "/", 1)
|
|
22
|
+
|
|
23
|
+
return app_id
|
|
12
24
|
|
|
13
25
|
|
|
14
26
|
@dataclass
|
|
@@ -50,11 +62,17 @@ class RequestHandle:
|
|
|
50
62
|
# Use the credentials that were used to submit the request by default.
|
|
51
63
|
_creds: Credentials = field(default_factory=get_default_credentials, repr=False)
|
|
52
64
|
|
|
65
|
+
def __post_init__(self):
|
|
66
|
+
app_id = _backwards_compatible_app_id(self.app_id)
|
|
67
|
+
# drop any extra path components
|
|
68
|
+
user_id, app_name = app_id.split("/")[:2]
|
|
69
|
+
self.app_id = f"{user_id}/{app_name}"
|
|
70
|
+
|
|
53
71
|
def status(self, *, logs: bool = False) -> _Status:
|
|
54
72
|
"""Check the status of an async inference request."""
|
|
55
73
|
|
|
56
74
|
url = (
|
|
57
|
-
|
|
75
|
+
_QUEUE_URL_FORMAT.format(app_id=self.app_id)
|
|
58
76
|
+ f"/requests/{self.request_id}/status/"
|
|
59
77
|
)
|
|
60
78
|
response = _HTTP_CLIENT.get(
|
|
@@ -97,11 +115,20 @@ class RequestHandle:
|
|
|
97
115
|
"""Retrieve the result of an async inference request, raises an exception
|
|
98
116
|
if the request is not completed yet."""
|
|
99
117
|
url = (
|
|
100
|
-
|
|
101
|
-
+ f"/requests/{self.request_id}/
|
|
118
|
+
_QUEUE_URL_FORMAT.format(app_id=self.app_id)
|
|
119
|
+
+ f"/requests/{self.request_id}/"
|
|
102
120
|
)
|
|
103
121
|
response = _HTTP_CLIENT.get(url, headers=self._creds.to_headers())
|
|
104
|
-
|
|
122
|
+
try:
|
|
123
|
+
response.raise_for_status()
|
|
124
|
+
except httpx.HTTPStatusError as e:
|
|
125
|
+
if response.headers["Content-Type"] != "application/json":
|
|
126
|
+
raise
|
|
127
|
+
raise httpx.HTTPStatusError(
|
|
128
|
+
f"{response.status_code}: {response.text}",
|
|
129
|
+
request=e.request,
|
|
130
|
+
response=e.response,
|
|
131
|
+
) from e
|
|
105
132
|
|
|
106
133
|
data = response.json()
|
|
107
134
|
return data
|
|
@@ -119,19 +146,23 @@ class RequestHandle:
|
|
|
119
146
|
_HTTP_CLIENT = httpx.Client(headers={"User-Agent": "Fal/Python"})
|
|
120
147
|
|
|
121
148
|
|
|
122
|
-
def run(app_id: str, arguments: dict[str, Any], *, path: str = "
|
|
149
|
+
def run(app_id: str, arguments: dict[str, Any], *, path: str = "") -> dict[str, Any]:
|
|
123
150
|
"""Run an inference task on a Fal app and return the result."""
|
|
124
151
|
|
|
125
152
|
handle = submit(app_id, arguments, path=path)
|
|
126
153
|
return handle.get()
|
|
127
154
|
|
|
128
155
|
|
|
129
|
-
def submit(app_id: str, arguments: dict[str, Any], *, path: str = "
|
|
156
|
+
def submit(app_id: str, arguments: dict[str, Any], *, path: str = "") -> RequestHandle:
|
|
130
157
|
"""Submit an async inference task to the app. Returns a request handle
|
|
131
158
|
which can be used to check the status of the request and retrieve the
|
|
132
159
|
result."""
|
|
133
160
|
|
|
134
|
-
|
|
161
|
+
app_id = _backwards_compatible_app_id(app_id)
|
|
162
|
+
url = _QUEUE_URL_FORMAT.format(app_id=app_id)
|
|
163
|
+
if path:
|
|
164
|
+
url += "/" + path.removeprefix("/")
|
|
165
|
+
|
|
135
166
|
creds = get_default_credentials()
|
|
136
167
|
|
|
137
168
|
response = _HTTP_CLIENT.post(
|
|
@@ -147,3 +178,56 @@ def submit(app_id: str, arguments: dict[str, Any], *, path: str = "/") -> Reques
|
|
|
147
178
|
request_id=data["request_id"],
|
|
148
179
|
_creds=creds,
|
|
149
180
|
)
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
@dataclass
|
|
184
|
+
class _RealtimeConnection:
|
|
185
|
+
"""A realtime connection to a Fal app."""
|
|
186
|
+
|
|
187
|
+
_ws: Any
|
|
188
|
+
|
|
189
|
+
def run(self, arguments: dict[str, Any]) -> dict[str, Any]:
|
|
190
|
+
"""Run an inference task on the app and return the result."""
|
|
191
|
+
self.send(arguments)
|
|
192
|
+
return self.recv()
|
|
193
|
+
|
|
194
|
+
def send(self, arguments: dict[str, Any]) -> None:
|
|
195
|
+
import msgpack
|
|
196
|
+
|
|
197
|
+
"""Send an inference task to the app."""
|
|
198
|
+
payload = msgpack.packb(arguments)
|
|
199
|
+
self._ws.send(payload)
|
|
200
|
+
|
|
201
|
+
def recv(self) -> dict[str, Any]:
|
|
202
|
+
import msgpack
|
|
203
|
+
|
|
204
|
+
"""Receive the result of an inference task."""
|
|
205
|
+
while True:
|
|
206
|
+
response = self._ws.recv()
|
|
207
|
+
if isinstance(response, str):
|
|
208
|
+
print(response)
|
|
209
|
+
json_payload = json.loads(response)
|
|
210
|
+
if json_payload.get("type") == "x-fal-error":
|
|
211
|
+
raise ValueError(json_payload["reason"])
|
|
212
|
+
continue
|
|
213
|
+
return msgpack.unpackb(response)
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
@contextmanager
|
|
217
|
+
def _connect(app_id: str, *, path: str = "/realtime") -> Iterator[_RealtimeConnection]:
|
|
218
|
+
"""Connect to a realtime endpoint. This is an internal and experimental API, use it
|
|
219
|
+
at your own risk."""
|
|
220
|
+
|
|
221
|
+
from websockets.sync import client
|
|
222
|
+
|
|
223
|
+
app_id = _backwards_compatible_app_id(app_id)
|
|
224
|
+
url = _REALTIME_URL_FORMAT.format(app_id=app_id)
|
|
225
|
+
if path:
|
|
226
|
+
url += "/" + path.removeprefix("/")
|
|
227
|
+
|
|
228
|
+
creds = get_default_credentials()
|
|
229
|
+
|
|
230
|
+
with client.connect(
|
|
231
|
+
url, additional_headers=creds.to_headers(), open_timeout=90
|
|
232
|
+
) as ws:
|
|
233
|
+
yield _RealtimeConnection(ws)
|
fal/auth/__init__.py
CHANGED
|
@@ -1,12 +1,31 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import os
|
|
3
4
|
from dataclasses import dataclass, field
|
|
4
5
|
|
|
5
6
|
import click
|
|
7
|
+
|
|
6
8
|
from fal.auth import auth0, local
|
|
7
9
|
from fal.console import console
|
|
8
10
|
from fal.console.icons import CHECK_ICON
|
|
9
11
|
from fal.exceptions.auth import UnauthenticatedException
|
|
12
|
+
from fal.toolkit.mainify import mainify
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@mainify
|
|
16
|
+
def key_credentials() -> tuple[str, str] | None:
|
|
17
|
+
# Ignore key credentials when the user forces auth by user.
|
|
18
|
+
if os.environ.get("FAL_FORCE_AUTH_BY_USER") == "1":
|
|
19
|
+
return None
|
|
20
|
+
|
|
21
|
+
if "FAL_KEY" in os.environ:
|
|
22
|
+
key = os.environ["FAL_KEY"]
|
|
23
|
+
key_id, key_secret = key.split(":", 1)
|
|
24
|
+
return (key_id, key_secret)
|
|
25
|
+
elif "FAL_KEY_ID" in os.environ and "FAL_KEY_SECRET" in os.environ:
|
|
26
|
+
return (os.environ["FAL_KEY_ID"], os.environ["FAL_KEY_SECRET"])
|
|
27
|
+
else:
|
|
28
|
+
return None
|
|
10
29
|
|
|
11
30
|
|
|
12
31
|
def login():
|
|
@@ -43,7 +62,7 @@ def _fetch_access_token() -> str:
|
|
|
43
62
|
|
|
44
63
|
if access_token is not None:
|
|
45
64
|
try:
|
|
46
|
-
auth0.
|
|
65
|
+
auth0.verify_access_token_expiration(access_token)
|
|
47
66
|
return access_token
|
|
48
67
|
except:
|
|
49
68
|
# access_token expired, will refresh
|