fal 1.3.3__py3-none-any.whl → 1.7.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of fal has been flagged as potentially problematic; see the package's registry page for details.

fal/_fal_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '1.3.3'
16
- __version_tuple__ = version_tuple = (1, 3, 3)
15
+ __version__ = version = '1.7.2'
16
+ __version_tuple__ = version_tuple = (1, 7, 2)
fal/api.py CHANGED
@@ -76,6 +76,8 @@ SERVE_REQUIREMENTS = [
76
76
  f"pydantic=={pydantic_version}",
77
77
  "uvicorn",
78
78
  "starlette_exporter",
79
+ "structlog",
80
+ "tomli",
79
81
  ]
80
82
 
81
83
 
@@ -170,6 +172,7 @@ class Host(Generic[ArgsT, ReturnT]):
170
172
  application_name: str | None = None,
171
173
  application_auth_mode: Literal["public", "shared", "private"] | None = None,
172
174
  metadata: dict[str, Any] | None = None,
175
+ scale: bool = True,
173
176
  ) -> str | None:
174
177
  """Register the given function on the host for API call execution."""
175
178
  raise NotImplementedError
@@ -389,12 +392,15 @@ class FalServerlessHost(Host):
389
392
  _SUPPORTED_KEYS = frozenset(
390
393
  {
391
394
  "machine_type",
395
+ "machine_types",
396
+ "num_gpus",
392
397
  "keep_alive",
393
398
  "max_concurrency",
394
399
  "min_concurrency",
395
400
  "max_multiplexing",
396
401
  "setup_function",
397
402
  "metadata",
403
+ "request_timeout",
398
404
  "_base_image",
399
405
  "_scheduler",
400
406
  "_scheduler_options",
@@ -426,25 +432,27 @@ class FalServerlessHost(Host):
426
432
  application_auth_mode: Literal["public", "shared", "private"] | None = None,
427
433
  metadata: dict[str, Any] | None = None,
428
434
  deployment_strategy: Literal["recreate", "rolling"] = "recreate",
435
+ scale: bool = True,
429
436
  ) -> str | None:
430
437
  environment_options = options.environment.copy()
431
438
  environment_options.setdefault("python_version", active_python())
432
439
  environments = [self._connection.define_environment(**environment_options)]
433
440
 
434
- machine_type = options.host.get(
441
+ machine_type: list[str] | str = options.host.get(
435
442
  "machine_type", FAL_SERVERLESS_DEFAULT_MACHINE_TYPE
436
443
  )
437
444
  keep_alive = options.host.get("keep_alive", FAL_SERVERLESS_DEFAULT_KEEP_ALIVE)
438
- max_concurrency = options.host.get("max_concurrency")
439
- min_concurrency = options.host.get("min_concurrency")
440
- max_multiplexing = options.host.get("max_multiplexing")
441
445
  base_image = options.host.get("_base_image", None)
442
446
  scheduler = options.host.get("_scheduler", None)
443
447
  scheduler_options = options.host.get("_scheduler_options", None)
448
+ max_concurrency = options.host.get("max_concurrency")
449
+ min_concurrency = options.host.get("min_concurrency")
450
+ max_multiplexing = options.host.get("max_multiplexing")
444
451
  exposed_port = options.get_exposed_port()
445
-
452
+ request_timeout = options.host.get("request_timeout")
446
453
  machine_requirements = MachineRequirements(
447
- machine_type=machine_type,
454
+ machine_types=machine_type, # type: ignore
455
+ num_gpus=options.host.get("num_gpus"),
448
456
  keep_alive=keep_alive,
449
457
  base_image=base_image,
450
458
  exposed_port=exposed_port,
@@ -453,6 +461,7 @@ class FalServerlessHost(Host):
453
461
  max_multiplexing=max_multiplexing,
454
462
  max_concurrency=max_concurrency,
455
463
  min_concurrency=min_concurrency,
464
+ request_timeout=request_timeout,
456
465
  )
457
466
 
458
467
  partial_func = _prepare_partial_func(func)
@@ -479,6 +488,7 @@ class FalServerlessHost(Host):
479
488
  machine_requirements=machine_requirements,
480
489
  metadata=metadata,
481
490
  deployment_strategy=deployment_strategy,
491
+ scale=scale,
482
492
  ):
483
493
  for log in partial_result.logs:
484
494
  self._log_printer.print(log)
@@ -501,7 +511,7 @@ class FalServerlessHost(Host):
501
511
  environment_options.setdefault("python_version", active_python())
502
512
  environments = [self._connection.define_environment(**environment_options)]
503
513
 
504
- machine_type = options.host.get(
514
+ machine_type: list[str] | str = options.host.get(
505
515
  "machine_type", FAL_SERVERLESS_DEFAULT_MACHINE_TYPE
506
516
  )
507
517
  keep_alive = options.host.get("keep_alive", FAL_SERVERLESS_DEFAULT_KEEP_ALIVE)
@@ -513,9 +523,11 @@ class FalServerlessHost(Host):
513
523
  scheduler_options = options.host.get("_scheduler_options", None)
514
524
  exposed_port = options.get_exposed_port()
515
525
  setup_function = options.host.get("setup_function", None)
526
+ request_timeout = options.host.get("request_timeout")
516
527
 
517
528
  machine_requirements = MachineRequirements(
518
- machine_type=machine_type,
529
+ machine_types=machine_type, # type: ignore
530
+ num_gpus=options.host.get("num_gpus"),
519
531
  keep_alive=keep_alive,
520
532
  base_image=base_image,
521
533
  exposed_port=exposed_port,
@@ -524,6 +536,7 @@ class FalServerlessHost(Host):
524
536
  max_multiplexing=max_multiplexing,
525
537
  max_concurrency=max_concurrency,
526
538
  min_concurrency=min_concurrency,
539
+ request_timeout=request_timeout,
527
540
  )
528
541
 
529
542
  return_value = _UNSET
@@ -684,10 +697,12 @@ def function(
684
697
  max_concurrency: int | None = None,
685
698
  # FalServerlessHost options
686
699
  metadata: dict[str, Any] | None = None,
687
- machine_type: str = FAL_SERVERLESS_DEFAULT_MACHINE_TYPE,
700
+ machine_type: str | list[str] = FAL_SERVERLESS_DEFAULT_MACHINE_TYPE,
701
+ num_gpus: int | None = None,
688
702
  keep_alive: int = FAL_SERVERLESS_DEFAULT_KEEP_ALIVE,
689
703
  max_multiplexing: int = FAL_SERVERLESS_DEFAULT_MAX_MULTIPLEXING,
690
704
  min_concurrency: int = FAL_SERVERLESS_DEFAULT_MIN_CONCURRENCY,
705
+ request_timeout: int | None = None,
691
706
  setup_function: Callable[..., None] | None = None,
692
707
  _base_image: str | None = None,
693
708
  _scheduler: str | None = None,
@@ -709,10 +724,12 @@ def function(
709
724
  max_concurrency: int | None = None,
710
725
  # FalServerlessHost options
711
726
  metadata: dict[str, Any] | None = None,
712
- machine_type: str = FAL_SERVERLESS_DEFAULT_MACHINE_TYPE,
727
+ machine_type: str | list[str] = FAL_SERVERLESS_DEFAULT_MACHINE_TYPE,
728
+ num_gpus: int | None = None,
713
729
  keep_alive: int = FAL_SERVERLESS_DEFAULT_KEEP_ALIVE,
714
730
  max_multiplexing: int = FAL_SERVERLESS_DEFAULT_MAX_MULTIPLEXING,
715
731
  min_concurrency: int = FAL_SERVERLESS_DEFAULT_MIN_CONCURRENCY,
732
+ request_timeout: int | None = None,
716
733
  setup_function: Callable[..., None] | None = None,
717
734
  _base_image: str | None = None,
718
735
  _scheduler: str | None = None,
@@ -784,10 +801,12 @@ def function(
784
801
  max_concurrency: int | None = None,
785
802
  # FalServerlessHost options
786
803
  metadata: dict[str, Any] | None = None,
787
- machine_type: str = FAL_SERVERLESS_DEFAULT_MACHINE_TYPE,
804
+ machine_type: str | list[str] = FAL_SERVERLESS_DEFAULT_MACHINE_TYPE,
805
+ num_gpus: int | None = None,
788
806
  keep_alive: int = FAL_SERVERLESS_DEFAULT_KEEP_ALIVE,
789
807
  max_multiplexing: int = FAL_SERVERLESS_DEFAULT_MAX_MULTIPLEXING,
790
808
  min_concurrency: int = FAL_SERVERLESS_DEFAULT_MIN_CONCURRENCY,
809
+ request_timeout: int | None = None,
791
810
  setup_function: Callable[..., None] | None = None,
792
811
  _base_image: str | None = None,
793
812
  _scheduler: str | None = None,
@@ -814,10 +833,12 @@ def function(
814
833
  max_concurrency: int | None = None,
815
834
  # FalServerlessHost options
816
835
  metadata: dict[str, Any] | None = None,
817
- machine_type: str = FAL_SERVERLESS_DEFAULT_MACHINE_TYPE,
836
+ machine_type: str | list[str] = FAL_SERVERLESS_DEFAULT_MACHINE_TYPE,
837
+ num_gpus: int | None = None,
818
838
  keep_alive: int = FAL_SERVERLESS_DEFAULT_KEEP_ALIVE,
819
839
  max_multiplexing: int = FAL_SERVERLESS_DEFAULT_MAX_MULTIPLEXING,
820
840
  min_concurrency: int = FAL_SERVERLESS_DEFAULT_MIN_CONCURRENCY,
841
+ request_timeout: int | None = None,
821
842
  setup_function: Callable[..., None] | None = None,
822
843
  _base_image: str | None = None,
823
844
  _scheduler: str | None = None,
@@ -838,10 +859,12 @@ def function(
838
859
  max_concurrency: int | None = None,
839
860
  # FalServerlessHost options
840
861
  metadata: dict[str, Any] | None = None,
841
- machine_type: str = FAL_SERVERLESS_DEFAULT_MACHINE_TYPE,
862
+ machine_type: str | list[str] = FAL_SERVERLESS_DEFAULT_MACHINE_TYPE,
863
+ num_gpus: int | None = None,
842
864
  keep_alive: int = FAL_SERVERLESS_DEFAULT_KEEP_ALIVE,
843
865
  max_multiplexing: int = FAL_SERVERLESS_DEFAULT_MAX_MULTIPLEXING,
844
866
  min_concurrency: int = FAL_SERVERLESS_DEFAULT_MIN_CONCURRENCY,
867
+ request_timeout: int | None = None,
845
868
  setup_function: Callable[..., None] | None = None,
846
869
  _base_image: str | None = None,
847
870
  _scheduler: str | None = None,
@@ -862,10 +885,12 @@ def function(
862
885
  max_concurrency: int | None = None,
863
886
  # FalServerlessHost options
864
887
  metadata: dict[str, Any] | None = None,
865
- machine_type: str = FAL_SERVERLESS_DEFAULT_MACHINE_TYPE,
888
+ machine_type: str | list[str] = FAL_SERVERLESS_DEFAULT_MACHINE_TYPE,
889
+ num_gpus: int | None = None,
866
890
  keep_alive: int = FAL_SERVERLESS_DEFAULT_KEEP_ALIVE,
867
891
  max_multiplexing: int = FAL_SERVERLESS_DEFAULT_MAX_MULTIPLEXING,
868
892
  min_concurrency: int = FAL_SERVERLESS_DEFAULT_MIN_CONCURRENCY,
893
+ request_timeout: int | None = None,
869
894
  setup_function: Callable[..., None] | None = None,
870
895
  _base_image: str | None = None,
871
896
  _scheduler: str | None = None,
@@ -950,6 +975,8 @@ class RouteSignature(NamedTuple):
950
975
 
951
976
 
952
977
  class BaseServable:
978
+ version: ClassVar[str] = "unknown"
979
+
953
980
  def collect_routes(self) -> dict[RouteSignature, Callable[..., Any]]:
954
981
  raise NotImplementedError
955
982
 
@@ -1078,9 +1105,14 @@ class BaseServable:
1078
1105
  def serve(self) -> None:
1079
1106
  import asyncio
1080
1107
 
1108
+ from prometheus_client import Gauge
1081
1109
  from starlette_exporter import handle_metrics
1082
1110
  from uvicorn import Config
1083
1111
 
1112
+ # NOTE: this uses the global prometheus registry
1113
+ app_info = Gauge("fal_app_info", "Fal application information", ["version"])
1114
+ app_info.labels(version=self.version).set(1)
1115
+
1084
1116
  app = self._build_app()
1085
1117
  server = Server(
1086
1118
  config=Config(app, host="0.0.0.0", port=8080, timeout_keep_alive=300)
fal/app.py CHANGED
@@ -1,5 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import asyncio
3
4
  import inspect
4
5
  import json
5
6
  import os
@@ -8,20 +9,25 @@ import re
8
9
  import threading
9
10
  import time
10
11
  import typing
11
- from contextlib import asynccontextmanager, contextmanager
12
+ from contextlib import AsyncExitStack, asynccontextmanager, contextmanager
13
+ from dataclasses import dataclass
12
14
  from typing import Any, Callable, ClassVar, Literal, TypeVar
13
15
 
16
+ import fastapi
17
+ import grpc.aio as async_grpc
14
18
  import httpx
15
- from fastapi import FastAPI
19
+ from isolate.server import definitions
16
20
 
17
21
  import fal.api
18
22
  from fal._serialization import include_modules_from
19
23
  from fal.api import RouteSignature
20
- from fal.exceptions import RequestCancelledException
24
+ from fal.exceptions import FalServerlessException, RequestCancelledException
21
25
  from fal.logging import get_logger
22
- from fal.toolkit.file.providers import fal as fal_provider_module
26
+ from fal.toolkit.file import request_lifecycle_preference
27
+ from fal.toolkit.file.providers.fal import LIFECYCLE_PREFERENCE
23
28
 
24
29
  REALTIME_APP_REQUIREMENTS = ["websockets", "msgpack"]
30
+ REQUEST_ID_KEY = "x-fal-request-id"
25
31
 
26
32
  EndpointT = TypeVar("EndpointT", bound=Callable[..., Any])
27
33
  logger = get_logger(__name__)
@@ -34,6 +40,56 @@ async def _call_any_fn(fn, *args, **kwargs):
34
40
  return fn(*args, **kwargs)
35
41
 
36
42
 
43
+ async def open_isolate_channel(address: str) -> async_grpc.Channel:
44
+ _stack = AsyncExitStack()
45
+ channel = await _stack.enter_async_context(
46
+ async_grpc.insecure_channel(
47
+ address,
48
+ options=[
49
+ ("grpc.max_send_message_length", -1),
50
+ ("grpc.max_receive_message_length", -1),
51
+ ("grpc.min_reconnect_backoff_ms", 0),
52
+ ("grpc.max_reconnect_backoff_ms", 100),
53
+ ("grpc.dns_min_time_between_resolutions_ms", 100),
54
+ ],
55
+ )
56
+ )
57
+
58
+ channel_status = channel.channel_ready()
59
+ try:
60
+ await asyncio.wait_for(channel_status, timeout=1)
61
+ except asyncio.TimeoutError:
62
+ await _stack.aclose()
63
+ raise Exception("Timed out trying to connect to local isolate")
64
+
65
+ return channel
66
+
67
+
68
async def _set_logger_labels(
    logger_labels: dict[str, str], channel: async_grpc.Channel
) -> None:
    """Best-effort: attach ``logger_labels`` to the isolate task's log metadata.

    Sends a ``SetMetadata`` RPC (task id ``"RUN"``) to the isolate server over
    ``channel``. Callers pass ``{}`` to reset the labels.

    Ordinary errors (RPC failure, non-OK status code, errors while flushing
    stdio) are swallowed so that log labelling never breaks request handling.
    Cancellation and other non-``Exception`` ``BaseException``s propagate.
    """
    import sys

    try:
        # Flush any prints that were buffered before setting the logger labels
        sys.stderr.flush()
        sys.stdout.flush()

        isolate = definitions.IsolateStub(channel)
        isolate_request = definitions.SetMetadataRequest(
            # TODO: when submit is shipped, get task_id from an env var
            task_id="RUN",
            metadata=definitions.TaskMetadata(logger_labels=logger_labels),
        )
        code = await isolate.SetMetadata(isolate_request).code()
        if str(code) != "StatusCode.OK":
            # Explicit check rather than ``assert`` so it survives ``python -O``.
            raise RuntimeError(str(code))
    except Exception:
        # NOTE hiding this for now to not print on every request
        # logger.debug("Failed to set logger labels", exc_info=True)
        pass
91
+
92
+
37
93
  def wrap_app(cls: type[App], **kwargs) -> fal.api.IsolatedFunction:
38
94
  include_modules_from(cls)
39
95
 
@@ -60,6 +116,7 @@ def wrap_app(cls: type[App], **kwargs) -> fal.api.IsolatedFunction:
60
116
  kind,
61
117
  requirements=cls.requirements,
62
118
  machine_type=cls.machine_type,
119
+ num_gpus=cls.num_gpus,
63
120
  **cls.host_kwargs,
64
121
  **kwargs,
65
122
  metadata=metadata,
@@ -74,6 +131,12 @@ def wrap_app(cls: type[App], **kwargs) -> fal.api.IsolatedFunction:
74
131
  return fn
75
132
 
76
133
 
134
@dataclass
class AppClientError(FalServerlessException):
    """Raised when an HTTP request to a running app returns a non-success status.

    Attributes:
        message: Human-readable description (method, URL, status and body).
        status_code: HTTP status code returned by the app.
    """

    message: str
    status_code: int

    def __str__(self) -> str:
        # The dataclass-generated __init__ never calls Exception.__init__, so
        # ``self.args`` only contains what was passed positionally. Render the
        # message field directly instead of relying on ``args``.
        return self.message

    def __reduce__(self):
        # Default exception pickling rebuilds via ``cls(*self.args)``, which
        # omits ``status_code`` when it was passed by keyword and then fails to
        # unpickle. Rebuild from the dataclass fields instead.
        return (type(self), (self.message, self.status_code))
138
+
139
+
77
140
  class EndpointClient:
78
141
  def __init__(self, url, endpoint, signature, timeout: int | None = None):
79
142
  self.url = url
@@ -86,12 +149,19 @@ class EndpointClient:
86
149
 
87
150
  def __call__(self, data):
88
151
  with httpx.Client() as client:
152
+ url = self.url + self.signature.path
89
153
  resp = client.post(
90
154
  self.url + self.signature.path,
91
155
  json=data.dict() if hasattr(data, "dict") else dict(data),
92
156
  timeout=self.timeout,
93
157
  )
94
- resp.raise_for_status()
158
+ if not resp.is_success:
159
+ # allow logs to be printed before raising the exception
160
+ time.sleep(1)
161
+ raise AppClientError(
162
+ f"Failed to POST {url}: {resp.status_code} {resp.text}",
163
+ status_code=resp.status_code,
164
+ )
95
165
  resp_dict = resp.json()
96
166
 
97
167
  if not self.return_type:
@@ -144,12 +214,16 @@ class AppClient:
144
214
  with httpx.Client() as client:
145
215
  retries = 100
146
216
  for _ in range(retries):
147
- resp = client.get(info.url + "/health", timeout=60)
217
+ url = info.url + "/health"
218
+ resp = client.get(url, timeout=60)
148
219
 
149
220
  if resp.is_success:
150
221
  break
151
222
  elif resp.status_code not in (500, 404):
152
- resp.raise_for_status()
223
+ raise AppClientError(
224
+ f"Failed to GET {url}: {resp.status_code} {resp.text}",
225
+ status_code=resp.status_code,
226
+ )
153
227
  time.sleep(0.1)
154
228
 
155
229
  client = cls(app_cls, info.url)
@@ -174,9 +248,18 @@ def _to_fal_app_name(name: str) -> str:
174
248
  return "-".join(part.lower() for part in PART_FINDER_RE.findall(name))
175
249
 
176
250
 
251
+ def _print_python_packages() -> None:
252
+ from importlib.metadata import distributions
253
+
254
+ packages = [f"{dist.metadata['Name']}=={dist.version}" for dist in distributions()]
255
+
256
+ print("[debug] Python packages installed:", ", ".join(packages))
257
+
258
+
177
259
  class App(fal.api.BaseServable):
178
260
  requirements: ClassVar[list[str]] = []
179
261
  machine_type: ClassVar[str] = "S"
262
+ num_gpus: ClassVar[int | None] = None
180
263
  host_kwargs: ClassVar[dict[str, Any]] = {
181
264
  "_scheduler": "nomad",
182
265
  "_scheduler_options": {
@@ -187,11 +270,18 @@ class App(fal.api.BaseServable):
187
270
  }
188
271
  app_name: ClassVar[str]
189
272
  app_auth: ClassVar[Literal["private", "public", "shared"]] = "private"
273
+ request_timeout: ClassVar[int | None] = None
274
+
275
+ isolate_channel: async_grpc.Channel | None = None
190
276
 
191
277
  def __init_subclass__(cls, **kwargs):
192
278
  app_name = kwargs.pop("name", None) or _to_fal_app_name(cls.__name__)
193
279
  parent_settings = getattr(cls, "host_kwargs", {})
194
280
  cls.host_kwargs = {**parent_settings, **kwargs}
281
+
282
+ if cls.request_timeout is not None:
283
+ cls.host_kwargs["request_timeout"] = cls.request_timeout
284
+
195
285
  cls.app_name = getattr(cls, "app_name", app_name)
196
286
 
197
287
  if cls.__init__ is not App.__init__:
@@ -222,7 +312,8 @@ class App(fal.api.BaseServable):
222
312
  }
223
313
 
224
314
  @asynccontextmanager
225
- async def lifespan(self, app: FastAPI):
315
+ async def lifespan(self, app: fastapi.FastAPI):
316
+ _print_python_packages()
226
317
  await _call_any_fn(self.setup)
227
318
  try:
228
319
  yield
@@ -230,7 +321,7 @@ class App(fal.api.BaseServable):
230
321
  await _call_any_fn(self.teardown)
231
322
 
232
323
  def health(self):
233
- return {}
324
+ return {"version": self.version}
234
325
 
235
326
  def setup(self):
236
327
  """Setup the application before serving."""
@@ -238,7 +329,7 @@ class App(fal.api.BaseServable):
238
329
  def teardown(self):
239
330
  """Teardown the application after serving."""
240
331
 
241
- def _add_extra_middlewares(self, app: FastAPI):
332
+ def _add_extra_middlewares(self, app: fastapi.FastAPI):
242
333
  @app.middleware("http")
243
334
  async def provide_hints_headers(request, call_next):
244
335
  response = await call_next(request)
@@ -259,11 +350,12 @@ class App(fal.api.BaseServable):
259
350
 
260
351
  @app.middleware("http")
261
352
  async def set_global_object_preference(request, call_next):
262
- response = await call_next(request)
263
353
  try:
264
- fal_provider_module.GLOBAL_LIFECYCLE_PREFERENCE = request.headers.get(
265
- "X-Fal-Object-Lifecycle-Preference"
266
- )
354
+ preference_dict = request_lifecycle_preference(request)
355
+ if preference_dict is not None:
356
+ # This will not work properly for apps with multiplexing enabled
357
+ # we may mix up the preferences between requests
358
+ LIFECYCLE_PREFERENCE.set(preference_dict)
267
359
  except Exception:
268
360
  from fastapi.logger import logger
269
361
 
@@ -271,7 +363,52 @@ class App(fal.api.BaseServable):
271
363
  "Failed set a global lifecycle preference %s",
272
364
  self.__class__.__name__,
273
365
  )
274
- return response
366
+
367
+ try:
368
+ return await call_next(request)
369
+ finally:
370
+ # We may miss the global preference if there are operations
371
+ # being done in the background that go beyond the request
372
+ LIFECYCLE_PREFERENCE.set(None)
373
+
374
+ @app.middleware("http")
375
+ async def set_request_id(request, call_next):
376
+ # NOTE: Setting request_id is not supported for websocket/realtime endpoints
377
+
378
+ if self.isolate_channel is None:
379
+ grpc_port = os.environ.get("NOMAD_ALLOC_PORT_grpc")
380
+ self.isolate_channel = await open_isolate_channel(
381
+ f"localhost:{grpc_port}"
382
+ )
383
+
384
+ request_id = request.headers.get(REQUEST_ID_KEY)
385
+ if request_id is None:
386
+ # Cut it short
387
+ return await call_next(request)
388
+
389
+ await _set_logger_labels(
390
+ {"fal_request_id": request_id}, channel=self.isolate_channel
391
+ )
392
+
393
+ async def _unset_at_end():
394
+ await _set_logger_labels({}, channel=self.isolate_channel) # type: ignore
395
+
396
+ try:
397
+ response: fastapi.responses.Response = await call_next(request)
398
+ except BaseException:
399
+ await _unset_at_end()
400
+ raise
401
+ else:
402
+ # We need to wait for the entire response to be sent before
403
+ # we can set the logger labels back to the default.
404
+ background_tasks = fastapi.BackgroundTasks()
405
+ background_tasks.add_task(_unset_at_end)
406
+ if response.background:
407
+ # We normally have no background tasks, but we should handle it
408
+ background_tasks.add_task(response.background)
409
+ response.background = background_tasks
410
+
411
+ return response
275
412
 
276
413
  @app.exception_handler(RequestCancelledException)
277
414
  async def value_error_exception_handler(
@@ -284,7 +421,7 @@ class App(fal.api.BaseServable):
284
421
  # the connection without receiving a response
285
422
  return JSONResponse({"detail": str(exc)}, 499)
286
423
 
287
- def _add_extra_routes(self, app: FastAPI):
424
+ def _add_extra_routes(self, app: fastapi.FastAPI):
288
425
  @app.get("/health")
289
426
  def health():
290
427
  return self.health()
@@ -395,7 +532,10 @@ def _fal_websocket_template(
395
532
  batch.append(next_input)
396
533
 
397
534
  t0 = loop.time()
398
- output = await loop.run_in_executor(None, func, self, *batch) # type: ignore
535
+ if inspect.iscoroutinefunction(func):
536
+ output = await func(self, *batch)
537
+ else:
538
+ output = await loop.run_in_executor(None, func, self, *batch) # type: ignore
399
539
  total_time = loop.time() - t0
400
540
  if not isinstance(output, dict):
401
541
  # Handle pydantic output modal