fal 0.13.0__py3-none-any.whl → 0.15.0__py3-none-any.whl

This diff shows the contents of two publicly released package versions as they appear in their public registry. It is provided for informational purposes only.

Note: this release of fal was flagged as potentially problematic.

fal/__init__.py CHANGED
@@ -1,10 +1,10 @@
 from __future__ import annotations

-from fal import apps
+from fal import apps  # noqa: F401
 from fal.api import FalServerlessHost, LocalHost, cached
 from fal.api import function
-from fal.api import function as isolated
-from fal.app import App, endpoint, realtime, wrap_app
+from fal.api import function as isolated  # noqa: F401
+from fal.app import App, endpoint, realtime, wrap_app  # noqa: F401
 from fal.sdk import FalServerlessKeyCredentials
 from fal.sync import sync_dir

@@ -32,6 +32,6 @@ __all__ = [
 # This will add to the package’s __path__ all subdirectories of directories on sys.path named after the package which effectively combines both modules into a single namespace (dbt.adapters)
 # The matching statement is in plugins/postgres/dbt/adapters/__init__.py

-from pkgutil import extend_path
+from pkgutil import extend_path  # noqa: E402

 __path__ = extend_path(__path__, __name__)
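
A note on the extend_path idiom referenced in the comment above (editor's illustration, not part of the diff; the two-distribution layout is hypothetical): pkgutil.extend_path lets several installed distributions contribute submodules to one package name.

    # Hypothetical layout: two distributions both ship a dbt/adapters/ package,
    # and each dbt/adapters/__init__.py ends with the two lines below.
    #
    #   .../dist_a/dbt/adapters/__init__.py
    #   .../dist_b/dbt/adapters/__init__.py
    #
    # With both parent directories on sys.path, `import dbt.adapters` then
    # searches every sys.path entry for a dbt/adapters subdirectory, merging
    # the two distributions into a single dbt.adapters namespace.
    from pkgutil import extend_path

    __path__ = extend_path(__path__, __name__)
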
fal/_serialization.py CHANGED
@@ -1,102 +1,71 @@
 from __future__ import annotations

-from functools import wraps
-from pathlib import Path
+from typing import Any, Callable

-import dill
-from dill import _dill
+import pickle
+import cloudpickle

-from fal.toolkit import mainify

-# each @fal.function gets added to this set so that we can
-# mainify the module this function is in
-_MODULES: set[str] = set()
-_PACKAGES: set[str] = set()
+def _register_pickle_by_value(name) -> None:
+    # cloudpickle.register_pickle_by_value wants an imported module object,
+    # but there is really no reason to go through that complication, as
+    # it might be prone to errors.
+    cloudpickle.cloudpickle._PICKLE_BY_VALUE_MODULES.add(name)


-@mainify
-def _pydantic_make_field(kwargs):
-    from pydantic.fields import ModelField
+def include_package_from_path(raw_path: str) -> None:
+    from pathlib import Path

-    return ModelField(**kwargs)
-
-
-@mainify
-def _pydantic_make_private_field(kwargs):
-    from pydantic.fields import ModelPrivateAttr
-
-    return ModelPrivateAttr(**kwargs)
-
-
-# this allows us to record all the "isolated" function and then mainify everything in
-# module they exist
-@wraps(_dill._locate_function)
-def by_value_locator(obj, pickler=None, og_locator=_dill._locate_function):
-    module_name = getattr(obj, "__module__", None)
-    if module_name is not None:
-        # If it is coming from the same module, directly allow
-        # it to be pickled.
-        if module_name in _MODULES:
-            return False
-
-        package_name, *_ = module_name.partition(".")
-        # If it is coming from the same package, then do the same.
-        if package_name in _PACKAGES:
-            return False
-
-    og_result = og_locator(obj, pickler)
-    return og_result
-
-
-_dill._locate_function = by_value_locator
-
-
-def include_packages_from_path(raw_path: str):
     path = Path(raw_path).resolve()
     parent = path
     while (parent.parent / "__init__.py").exists():
         parent = parent.parent

     if parent != path:
-        _PACKAGES.add(parent.name)
+        _register_pickle_by_value(parent.name)


-def add_serialization_listeners_for(obj):
+def include_modules_from(obj: Any) -> None:
     module_name = getattr(obj, "__module__", None)
     if not module_name:
-        return None
+        return
+
+    if "." in module_name:
+        # Just include the whole package
+        package_name, *_ = module_name.partition(".")
+        _register_pickle_by_value(package_name)
+        return

-    _MODULES.add(module_name)
     if module_name == "__main__":
         # When the module is __main__, we need to recursively go up the
         # tree to locate the actual package name.
         import __main__

-        include_packages_from_path(__main__.__file__)
+        include_package_from_path(__main__.__file__)
+        return
+
+    _register_pickle_by_value(module_name)

-    if "." in module_name:
-        package_name, *_ = module_name.partition(".")
-        _PACKAGES.add(package_name)

+def _register(cls: Any, func: Callable) -> None:
+    cloudpickle.Pickler.dispatch[cls] = func

-@mainify
-def patch_pydantic_field_serialization():
+
+def _patch_pydantic_field_serialization() -> None:
     # Cythonized pydantic fields can't be serialized automatically, so we are
     # have a special case handling for them that unpacks it to a dictionary
     # and then reloads it on the other side.
-    import dill
-
+    # https://github.com/ray-project/ray/blob/842bbcf4236e41f58d25058b0482cd05bfe9e4da/python/ray/_private/pydantic_compat.py#L80
     try:
-        import pydantic.fields
+        from pydantic.fields import ModelField, ModelPrivateAttr
     except ImportError:
         return

-    @dill.register(pydantic.fields.ModelField)
-    def _pickle_model_field(
-        pickler: dill.Pickler,
-        field: pydantic.fields.ModelField,
-    ) -> None:
-        args = {
+    def create_model_field(kwargs: dict) -> ModelField:
+        return ModelField(**kwargs)
+
+    def pickle_model_field(field: ModelField) -> tuple[Callable, tuple]:
+        kwargs = {
             "name": field.name,
             # outer_type_ is the original type for ModelFields,
             # while type_ can be updated later with the nested type
@@ -110,65 +79,147 @@ def patch_pydantic_field_serialization():
             "alias": field.alias,
             "field_info": field.field_info,
         }
-        pickler.save_reduce(_pydantic_make_field, (args,), obj=field)
-
-    @dill.register(pydantic.fields.ModelPrivateAttr)
-    def _pickle_model_private_attr(
-        pickler: dill.Pickler,
-        field: pydantic.fields.ModelPrivateAttr,
-    ) -> None:
-        args = {
+
+        return create_model_field, (kwargs,)
+
+    def create_private_attr(kwargs: dict) -> ModelPrivateAttr:
+        return ModelPrivateAttr(**kwargs)
+
+    def pickle_private_attr(field: ModelPrivateAttr) -> tuple[Callable, tuple]:
+        kwargs = {
             "default": field.default,
             "default_factory": field.default_factory,
         }
-        pickler.save_reduce(_pydantic_make_private_field, (args,), obj=field)

+        return create_private_attr, (kwargs,)
+
+    _register(ModelField, pickle_model_field)
+    _register(ModelPrivateAttr, pickle_private_attr)

-@mainify
-def patch_pydantic_class_attributes():
-    # Dill attempts to modify the __class__ of deserialized pydantic objects
-    # on this side but it meets with a rejection from pydantic's semantics since
-    # __class__ is not recognized as a proper dunder attribute.
+
+def _patch_pydantic_model_serialization() -> None:
+    # If user has created new pydantic models in his namespace, we will try to pickle those
+    # by value, which means recreating class skeleton, which will stumble upon
+    # __pydantic_parent_namespace__ in its __dict__ and it may contain modules that happened
+    # to be imported in the namespace but are not actually used, resulting in pickling errors.
+    # Unfortunately this also means that `model_rebuid()` might not work.
     try:
-        import pydantic.utils
+        import pydantic
     except ImportError:
         return

-    pydantic.utils.DUNDER_ATTRIBUTES.add("__class__")
+    # https://github.com/pydantic/pydantic/pull/2573
+    if not hasattr(pydantic, "__version__") or pydantic.__version__.startswith("1."):
+        return
+
+    backup = "_original_extract_class_dict"
+    if getattr(cloudpickle.cloudpickle, backup, None):
+        return
+
+    original = cloudpickle.cloudpickle._extract_class_dict
+
+    def patched(cls):
+        attr_name = "__pydantic_parent_namespace__"
+        if issubclass(cls, pydantic.BaseModel) and getattr(cls, attr_name, None):
+            setattr(cls, attr_name, None)
+
+        return original(cls)
+
+    cloudpickle.cloudpickle._extract_class_dict = patched
+    setattr(cloudpickle.cloudpickle, backup, original)
+
+
+def _patch_lru_cache() -> None:
+    # https://github.com/cloudpipe/cloudpickle/issues/178
+    # https://github.com/uqfoundation/dill/blob/70f569b0dd268d2b1e85c0f300951b11f53c5d53/dill/_dill.py#L1429
+
+    from functools import lru_cache, _lru_cache_wrapper as LRUCacheType
+
+    def create_lru_cache(func: Callable, kwargs: dict) -> LRUCacheType:
+        return lru_cache(**kwargs)(func)
+
+    def pickle_lru_cache(obj: LRUCacheType) -> tuple[Callable, tuple]:
+        if hasattr(obj, "cache_parameters"):
+            params = obj.cache_parameters()
+            kwargs = {
+                "maxsize": params["maxsize"],
+                "typed": params["typed"],
+            }
+        else:
+            kwargs = {"maxsize": obj.cache_info().maxsize}
+
+        return create_lru_cache, (obj.__wrapped__, kwargs)
+
+    _register(LRUCacheType, pickle_lru_cache)
+
+
+def _patch_lock() -> None:
+    # https://github.com/uqfoundation/dill/blob/70f569b0dd268d2b1e85c0f300951b11f53c5d53/dill/_dill.py#L1310
+    from threading import Lock
+    from _thread import LockType
+
+    def create_lock(locked: bool) -> Lock:
+        lock = Lock()
+        if locked and not lock.acquire(False):
+            raise pickle.UnpicklingError("Cannot acquire lock")
+        return lock
+
+    def pickle_lock(obj: LockType) -> tuple[Callable, tuple]:
+        return create_lock, (obj.locked(),)
+
+    _register(LockType, pickle_lock)
+
+
+def _patch_rlock() -> None:
+    # https://github.com/uqfoundation/dill/blob/70f569b0dd268d2b1e85c0f300951b11f53c5d53/dill/_dill.py#L1317
+    from _thread import RLock as RLockType  # type: ignore[attr-defined]
+
+    def create_rlock(count: int, owner: int) -> RLockType:
+        lock = RLockType()
+        if owner is not None:
+            lock._acquire_restore((count, owner))  # type: ignore[attr-defined]
+        if owner and not lock._is_owned():  # type: ignore[attr-defined]
+            raise pickle.UnpicklingError("Cannot acquire lock")
+        return lock
+
+    def pickle_rlock(obj: RLockType) -> tuple[Callable, tuple]:
+        r = obj.__repr__()
+        count = int(r.split('count=')[1].split()[0].rstrip('>'))
+        owner = int(r.split('owner=')[1].split()[0])
+
+        return create_rlock, (count, owner)
+
+    _register(RLockType, pickle_rlock)
+
+
+def _patch_console_thread_locals() -> None:
+    from rich.console import ConsoleThreadLocals

+    def create_locals(kwargs: dict) -> ConsoleThreadLocals:
+        return ConsoleThreadLocals(**kwargs)

-@mainify
-def patch_exceptions():
-    # Adapting tblib.pickling_support.install for dill.
-    from types import TracebackType
+    def pickle_locals(obj: ConsoleThreadLocals) -> tuple[Callable, tuple]:
+        kwargs = {"theme_stack": obj.theme_stack, "buffer": obj.buffer, "buffer_index": obj.buffer_index}
+        return create_locals, (kwargs, )

-    import dill
-    from tblib.pickling_support import (
-        _get_subclasses,
-        pickle_exception,
-        pickle_traceback,
-    )
+    _register(ConsoleThreadLocals, pickle_locals)

-    @dill.register(TracebackType)
-    def save_traceback(pickler, obj):
-        unpickle, args = pickle_traceback(obj)
-        pickler.save_reduce(unpickle, args, obj=obj)

-    @dill.register(BaseException)
-    def save_exception(pickler, obj):
-        unpickle, args = pickle_exception(obj)
-        pickler.save_reduce(unpickle, args, obj=obj)
+def _patch_exceptions() -> None:
+    # Support chained exceptions
+    from tblib.pickling_support import install

-    for exception_cls in _get_subclasses(BaseException):
-        dill.pickle(exception_cls, save_exception)
+    install()


-@mainify
-def patch_dill():
-    import dill
+def patch_pickle() -> None:
+    _patch_pydantic_field_serialization()
+    _patch_pydantic_model_serialization()
+    _patch_lru_cache()
+    _patch_lock()
+    _patch_rlock()
+    _patch_console_thread_locals()
+    _patch_exceptions()

-    dill.settings["recurse"] = True
+    _register_pickle_by_value("fal")

-    patch_exceptions()
-    patch_pydantic_class_attributes()
-    patch_pydantic_field_serialization()
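
A note on the reducer registration above (editor's sketch, not from the package): each _patch_* helper registers a function that returns a (reconstructor, args) pair, the same contract as the standard __reduce__/copyreg protocol, and _register installs it into cloudpickle's Pickler.dispatch table. The same pattern, shown self-contained through the public copyreg hook and plain pickle (Python 3.9+ for cache_parameters()):

    import copyreg
    import pickle
    from functools import lru_cache, _lru_cache_wrapper as LRUCacheType

    def create_lru_cache(func, kwargs):
        # Reconstructor: rebuild a fresh (empty) cache around the original function.
        return lru_cache(**kwargs)(func)

    def pickle_lru_cache(obj):
        # Reducer: return (reconstructor, args), mirroring _patch_lru_cache above.
        params = obj.cache_parameters()
        kwargs = {"maxsize": params["maxsize"], "typed": params["typed"]}
        return create_lru_cache, (obj.__wrapped__, kwargs)

    copyreg.pickle(LRUCacheType, pickle_lru_cache)

    def double(x: int) -> int:
        return x * 2

    cached_double = lru_cache(maxsize=16)(double)
    restored = pickle.loads(pickle.dumps(cached_double))
    assert restored(21) == 42
    assert restored.cache_info().maxsize == 16

Plain pickle still serializes the inner function by reference, which is why double is kept importable at module level here; cloudpickle can serialize it by value, which is what lets fal apply the reducer to arbitrary user functions.
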
fal/api.py CHANGED
@@ -2,11 +2,12 @@ from __future__ import annotations

 import inspect
 import sys
+import threading
 from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import asynccontextmanager, suppress
 from dataclasses import dataclass, field, replace
-from functools import partial, wraps
+from functools import wraps
 from os import PathLike
 from typing import (
     Any,
@@ -22,21 +23,25 @@ from typing import (
     overload,
 )

-import dill
-import dill.detect
+import cloudpickle
 import grpc
 import isolate
+import tblib
+import uvicorn
 import yaml
 from fastapi import FastAPI
+from fastapi import __version__ as fastapi_version
 from isolate.backends.common import active_python
 from isolate.backends.settings import DEFAULT_SETTINGS
 from isolate.connections import PythonIPC
 from packaging.requirements import Requirement
 from packaging.utils import canonicalize_name
+from pydantic import __version__ as pydantic_version
 from typing_extensions import Concatenate, ParamSpec

 import fal.flags as flags
-from fal._serialization import add_serialization_listeners_for, patch_dill
+from fal.exceptions import FalServerlessException
+from fal._serialization import include_modules_from, patch_pickle
 from fal.logging.isolate import IsolateLogPrinter
 from fal.sdk import (
     FAL_SERVERLESS_DEFAULT_KEEP_ALIVE,
@@ -50,7 +55,6 @@ from fal.sdk import (
     get_agent_credentials,
     get_default_credentials,
 )
-from fal.toolkit import mainify

 ArgsT = ParamSpec("ArgsT")
 ReturnT = TypeVar("ReturnT", covariant=True)
@@ -58,16 +62,21 @@ ReturnT = TypeVar("ReturnT", covariant=True)
 BasicConfig = Dict[str, Any]
 _UNSET = object()

-SERVE_REQUIREMENTS = ["fastapi==0.99.1", "uvicorn"]
+SERVE_REQUIREMENTS = [
+    f"fastapi=={fastapi_version}",
+    f"pydantic=={pydantic_version}",
+    "uvicorn",
+    "starlette_exporter",
+]


 @dataclass
-class FalServerlessError(Exception):
+class FalServerlessError(FalServerlessException):
     message: str


 @dataclass
-class InternalFalServerlessError(Exception):
+class InternalFalServerlessError(FalServerlessException):
     message: str


@@ -175,32 +184,23 @@ def cached(func: Callable[ArgsT, ReturnT]) -> Callable[ArgsT, ReturnT]:
     return wrapper


-@mainify
-class UserFunctionException(Exception):
+class UserFunctionException(FalServerlessException):
     pass


-def match_class(obj, cls):
-    # NOTE: Can't use isinstance because we are not using dill's byref setting when
-    # loading/dumping objects in RPC, which means that our exceptions from remote
-    # server are created by value and are actually a separate class that only looks
-    # like original one.
-    #
-    # See https://github.com/fal-ai/fal/issues/142
-    return type(obj).__name__ == cls.__name__
-
-
 def _prepare_partial_func(
     func: Callable[ArgsT, ReturnT],
     *args: ArgsT.args,
     **kwargs: ArgsT.kwargs,
 ) -> Callable[ArgsT, ReturnT]:
-    """Prepare the given function for execution on the remote isolate workers."""
+    """Prepare the given function for execution on isolate workers."""

     @wraps(func)
     def wrapper(*remote_args: ArgsT.args, **remote_kwargs: ArgsT.kwargs) -> ReturnT:
         try:
             result = func(*remote_args, *args, **remote_kwargs, **kwargs)
+        except FalServerlessException:
+            raise
         except Exception as exc:
             tb = exc.__traceback__
             if tb is not None and tb.tb_next is not None:
@@ -211,7 +211,7 @@ def _prepare_partial_func(
             ) from exc.with_traceback(tb)
         finally:
             with suppress(Exception):
-                patch_dill()
+                patch_pickle()
         return result

     return wrapper
@@ -223,8 +223,12 @@ class LocalHost(Host):
     # packages for isolate agent to run.
     _AGENT_ENVIRONMENT = isolate.prepare_environment(
         "virtualenv",
-        requirements=[f"dill=={dill.__version__}"],
+        requirements=[
+            f"cloudpickle=={cloudpickle.__version__}",
+            f"tblib=={tblib.__version__}",
+        ],
     )
+    _log_printer = IsolateLogPrinter(debug=flags.DEBUG)

     def run(
         self,
@@ -233,7 +237,11 @@ class LocalHost(Host):
         args: tuple[Any, ...],
         kwargs: dict[str, Any],
     ) -> ReturnT:
-        settings = replace(DEFAULT_SETTINGS, serialization_method="dill")
+        settings = replace(
+            DEFAULT_SETTINGS,
+            serialization_method="cloudpickle",
+            log_hook=self._log_printer.print,
+        )
         environment = isolate.prepare_environment(
             **options.environment,
             context=settings,
@@ -243,7 +251,7 @@ class LocalHost(Host):
             environment.create(),
             extra_inheritance_paths=[self._AGENT_ENVIRONMENT.create()],
         ) as connection:
-            executable = partial(func, *args, **kwargs)
+            executable = _prepare_partial_func(func, *args, **kwargs)
             return connection.run(executable)


@@ -251,9 +259,6 @@ FAL_SERVERLESS_DEFAULT_URL = flags.GRPC_HOST
 FAL_SERVERLESS_DEFAULT_MACHINE_TYPE = "XS"


-import threading
-
-
 def _handle_grpc_error():
     def decorator(fn):
         @wraps(fn)
@@ -292,6 +297,8 @@ def _handle_grpc_error():
 def find_missing_dependencies(
     func: Callable, env: dict
 ) -> Iterator[tuple[str, list[str]]]:
+    import dill
+
     if env["kind"] != "virtualenv":
         return

@@ -433,6 +440,8 @@ class FalServerlessHost(Host):
         if partial_result.result:
             return partial_result.result.application_id

+        return None
+
     @_handle_grpc_error()
     def run(
         self,
@@ -747,7 +756,7 @@ def function(  # type: ignore
     options = host.parse_options(kind=kind, **config)

     def wrapper(func: Callable[ArgsT, ReturnT]):
-        add_serialization_listeners_for(func)
+        include_modules_from(func)
         proxy = IsolatedFunction(
             host=host,
             raw_func=func,  # type: ignore
@@ -758,7 +767,6 @@ def function(  # type: ignore
     return wrapper


-@mainify
 class FalFastAPI(FastAPI):
     """
     A subclass of FastAPI that adds some fal-specific functionality.
@@ -802,7 +810,6 @@ class FalFastAPI(FastAPI):
         return spec


-@mainify
 class RouteSignature(NamedTuple):
     path: str
     is_websocket: bool = False
@@ -813,7 +820,6 @@ class RouteSignature(NamedTuple):
     emit_timings: bool = False


-@mainify
 class BaseServable:
     def collect_routes(self) -> dict[RouteSignature, Callable[..., Any]]:
         raise NotImplementedError
@@ -832,6 +838,7 @@ class BaseServable:
         from fastapi import HTTPException, Request
         from fastapi.middleware.cors import CORSMiddleware
         from fastapi.responses import JSONResponse
+        from starlette_exporter import PrometheusMiddleware

         _app = FalFastAPI(lifespan=self.lifespan)

@@ -842,6 +849,13 @@ class BaseServable:
             allow_methods=("*"),
             allow_origins=("*"),
         )
+        _app.add_middleware(
+            PrometheusMiddleware,
+            prefix="http",
+            group_paths=True,
+            filter_unhandled_paths=True,
+            app_name="fal",
+        )

         self._add_extra_middlewares(_app)

@@ -887,13 +901,43 @@ class BaseServable:
         return self._build_app().openapi()

     def serve(self) -> None:
-        import uvicorn
+        import asyncio
+
+        from starlette_exporter import handle_metrics
+        from uvicorn import Config

         app = self._build_app()
-        uvicorn.run(app, host="0.0.0.0", port=8080)
+        server = Server(config=Config(app, host="0.0.0.0", port=8080))
+        metrics_app = FastAPI()
+        metrics_app.add_route("/metrics", handle_metrics)
+        metrics_server = Server(config=Config(metrics_app, host="0.0.0.0", port=9090))
+
+        async def _serve() -> None:
+            tasks = {
+                asyncio.create_task(server.serve()): server,
+                asyncio.create_task(metrics_server.serve()): metrics_server,
+            }
+
+            _, pending = await asyncio.wait(tasks.keys(), return_when=asyncio.FIRST_COMPLETED)
+            if not pending:
+                return
+
+            # try graceful shutdown
+            for task in pending:
+                tasks[task].should_exit = True
+            _, pending = await asyncio.wait(pending, timeout=2)
+            if not pending:
+                return
+
+            for task in pending:
+                task.cancel()
+            await asyncio.wait(pending)
+
+        with suppress(asyncio.CancelledError):
+            asyncio.set_event_loop(asyncio.new_event_loop())
+            asyncio.run(_serve())


-@mainify
 class ServeWrapper(BaseServable):
     _func: Callable

@@ -974,7 +1018,7 @@ class IsolatedFunction(Generic[ArgsT, ReturnT]):
                 ) from None
         except Exception as exc:
             cause = exc.__cause__
-            if self.reraise and isinstance(exc, UserFunctionException) and cause:
+            if self.reraise and isinstance(exc, UserFunctionException) and cause:
                 # re-raise original exception without our wrappers
                 raise cause
             raise
@@ -1050,3 +1094,14 @@ class ServedIsolatedFunction(
         self, host: Host | None = None, *, serve: Literal[False], **config: Any
     ) -> IsolatedFunction[ArgsT, ReturnT]:
         ...
+
+
+class Server(uvicorn.Server):
+    """Server is a uvicorn.Server that actually plays nicely with signals.
+    By default, uvicorn's Server class overwrites the signal handler for SIGINT, swallowing the signal and preventing other tasks from cancelling.
+    This class allows the task to be gracefully cancelled using asyncio's built-in task cancellation or with an event, like aiohttp.
+    """
+
+    def install_signal_handlers(self) -> None:
+        pass
+
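
A note on the serve() rewrite above (editor's sketch, not from the package): uvicorn installs its own SIGINT/SIGTERM handlers during serve(), which would swallow Ctrl-C and keep the second server running; overriding install_signal_handlers() with a no-op leaves signals to the event loop so both servers can be supervised as ordinary asyncio tasks. A minimal standalone version of the same pattern (names here are illustrative; the tiny ASGI app is a stand-in for the real apps):

    import asyncio
    import uvicorn

    class QuietServer(uvicorn.Server):
        def install_signal_handlers(self) -> None:
            pass  # leave signal handling to asyncio / the caller

    async def app(scope, receive, send):
        # Placeholder ASGI app used for both servers.
        if scope["type"] != "http":
            return
        await send({"type": "http.response.start", "status": 200, "headers": []})
        await send({"type": "http.response.body", "body": b"ok"})

    async def main() -> None:
        servers = [
            QuietServer(uvicorn.Config(app, host="0.0.0.0", port=8080)),
            QuietServer(uvicorn.Config(app, host="0.0.0.0", port=9090)),
        ]
        tasks = [asyncio.create_task(s.serve()) for s in servers]
        # When either server exits, ask the other to shut down gracefully,
        # then cancel anything that ignores the request.
        _, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        for s in servers:
            s.should_exit = True
        if pending:
            _, pending = await asyncio.wait(pending, timeout=2)
        for t in pending:
            t.cancel()
        if pending:
            await asyncio.wait(pending)

    asyncio.run(main())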