fal 0.14.0__py3-none-any.whl → 0.15.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of fal might be problematic.

fal/_serialization.py CHANGED
@@ -1,102 +1,71 @@
  from __future__ import annotations
 
- from functools import wraps
- from pathlib import Path
+ from typing import Any, Callable
 
- import dill
- from dill import _dill
+ import pickle
+ import cloudpickle
 
- from fal.toolkit import mainify
 
- # each @fal.function gets added to this set so that we can
- # mainify the module this function is in
- _MODULES: set[str] = set()
- _PACKAGES: set[str] = set()
+ def _register_pickle_by_value(name) -> None:
+     # cloudpickle.register_pickle_by_value wants an imported module object,
+     # but there is really no reason to go through that complication, as
+     # it might be prone to errors.
+     cloudpickle.cloudpickle._PICKLE_BY_VALUE_MODULES.add(name)
 
 
- @mainify
- def _pydantic_make_field(kwargs):
-     from pydantic.fields import ModelField
+ def include_package_from_path(raw_path: str) -> None:
+     from pathlib import Path
 
-     return ModelField(**kwargs)
-
-
- @mainify
- def _pydantic_make_private_field(kwargs):
-     from pydantic.fields import ModelPrivateAttr
-
-     return ModelPrivateAttr(**kwargs)
-
-
- # this allows us to record all the "isolated" function and then mainify everything in
- # module they exist
- @wraps(_dill._locate_function)
- def by_value_locator(obj, pickler=None, og_locator=_dill._locate_function):
-     module_name = getattr(obj, "__module__", None)
-     if module_name is not None:
-         # If it is coming from the same module, directly allow
-         # it to be pickled.
-         if module_name in _MODULES:
-             return False
-
-         package_name, *_ = module_name.partition(".")
-         # If it is coming from the same package, then do the same.
-         if package_name in _PACKAGES:
-             return False
-
-     og_result = og_locator(obj, pickler)
-     return og_result
-
-
- _dill._locate_function = by_value_locator
-
-
- def include_packages_from_path(raw_path: str):
      path = Path(raw_path).resolve()
      parent = path
      while (parent.parent / "__init__.py").exists():
          parent = parent.parent
 
      if parent != path:
-         _PACKAGES.add(parent.name)
+         _register_pickle_by_value(parent.name)
 
 
- def add_serialization_listeners_for(obj):
+ def include_modules_from(obj: Any) -> None:
      module_name = getattr(obj, "__module__", None)
      if not module_name:
-         return None
+         return
+
+     if "." in module_name:
+         # Just include the whole package
+         package_name, *_ = module_name.partition(".")
+         _register_pickle_by_value(package_name)
+         return
 
-     _MODULES.add(module_name)
      if module_name == "__main__":
          # When the module is __main__, we need to recursively go up the
          # tree to locate the actual package name.
          import __main__
 
-         include_packages_from_path(__main__.__file__)
+         include_package_from_path(__main__.__file__)
+         return
 
-     if "." in module_name:
-         package_name, *_ = module_name.partition(".")
-         _PACKAGES.add(package_name)
+     _register_pickle_by_value(module_name)
+
+
+ def _register(cls: Any, func: Callable) -> None:
+     cloudpickle.Pickler.dispatch[cls] = func
 
 
- @mainify
- def patch_pydantic_field_serialization():
+ def _patch_pydantic_field_serialization() -> None:
      # Cythonized pydantic fields can't be serialized automatically, so we are
      # have a special case handling for them that unpacks it to a dictionary
      # and then reloads it on the other side.
-     import dill
-
+     # https://github.com/ray-project/ray/blob/842bbcf4236e41f58d25058b0482cd05bfe9e4da/python/ray/_private/pydantic_compat.py#L80
      try:
-         import pydantic.fields
+         from pydantic.fields import ModelField, ModelPrivateAttr
      except ImportError:
          return
 
-     @dill.register(pydantic.fields.ModelField)
-     def _pickle_model_field(
-         pickler: dill.Pickler,
-         field: pydantic.fields.ModelField,
-     ) -> None:
-         args = {
+     def create_model_field(kwargs: dict) -> ModelField:
+         return ModelField(**kwargs)
+
+     def pickle_model_field(field: ModelField) -> tuple[Callable, tuple]:
+         kwargs = {
              "name": field.name,
              # outer_type_ is the original type for ModelFields,
              # while type_ can be updated later with the nested type
@@ -110,92 +79,147 @@ def patch_pydantic_field_serialization():
              "alias": field.alias,
              "field_info": field.field_info,
          }
-         pickler.save_reduce(_pydantic_make_field, (args,), obj=field)
-
-     @dill.register(pydantic.fields.ModelPrivateAttr)
-     def _pickle_model_private_attr(
-         pickler: dill.Pickler,
-         field: pydantic.fields.ModelPrivateAttr,
-     ) -> None:
-         args = {
+
+         return create_model_field, (kwargs,)
+
+     def create_private_attr(kwargs: dict) -> ModelPrivateAttr:
+         return ModelPrivateAttr(**kwargs)
+
+     def pickle_private_attr(field: ModelPrivateAttr) -> tuple[Callable, tuple]:
+         kwargs = {
              "default": field.default,
              "default_factory": field.default_factory,
          }
-         pickler.save_reduce(_pydantic_make_private_field, (args,), obj=field)
+
+         return create_private_attr, (kwargs,)
+
+     _register(ModelField, pickle_model_field)
+     _register(ModelPrivateAttr, pickle_private_attr)
 
 
- @mainify
- def patch_pydantic_class_attributes():
-     # Dill attempts to modify the __class__ of deserialized pydantic objects
-     # on this side but it meets with a rejection from pydantic's semantics since
-     # __class__ is not recognized as a proper dunder attribute.
+ def _patch_pydantic_model_serialization() -> None:
+     # If user has created new pydantic models in his namespace, we will try to pickle those
+     # by value, which means recreating class skeleton, which will stumble upon
+     # __pydantic_parent_namespace__ in its __dict__ and it may contain modules that happened
+     # to be imported in the namespace but are not actually used, resulting in pickling errors.
+     # Unfortunately this also means that `model_rebuid()` might not work.
      try:
-         import pydantic.utils
+         import pydantic
      except ImportError:
          return
 
-     pydantic.utils.DUNDER_ATTRIBUTES.add("__class__")
+     # https://github.com/pydantic/pydantic/pull/2573
+     if not hasattr(pydantic, "__version__") or pydantic.__version__.startswith("1."):
+         return
+
+     backup = "_original_extract_class_dict"
+     if getattr(cloudpickle.cloudpickle, backup, None):
+         return
+
+     original = cloudpickle.cloudpickle._extract_class_dict
+
+     def patched(cls):
+         attr_name = "__pydantic_parent_namespace__"
+         if issubclass(cls, pydantic.BaseModel) and getattr(cls, attr_name, None):
+             setattr(cls, attr_name, None)
+
+         return original(cls)
+
+     cloudpickle.cloudpickle._extract_class_dict = patched
+     setattr(cloudpickle.cloudpickle, backup, original)
+
+
+ def _patch_lru_cache() -> None:
+     # https://github.com/cloudpipe/cloudpickle/issues/178
+     # https://github.com/uqfoundation/dill/blob/70f569b0dd268d2b1e85c0f300951b11f53c5d53/dill/_dill.py#L1429
 
+     from functools import lru_cache, _lru_cache_wrapper as LRUCacheType
 
- @mainify
- def patch_exceptions():
-     # Adapting tblib.pickling_support.install for dill.
-     from types import TracebackType
+     def create_lru_cache(func: Callable, kwargs: dict) -> LRUCacheType:
+         return lru_cache(**kwargs)(func)
 
-     import dill
-     from tblib.pickling_support import (
-         _get_subclasses,
-         pickle_exception,
-         pickle_traceback,
-     )
+     def pickle_lru_cache(obj: LRUCacheType) -> tuple[Callable, tuple]:
+         if hasattr(obj, "cache_parameters"):
+             params = obj.cache_parameters()
+             kwargs = {
+                 "maxsize": params["maxsize"],
+                 "typed": params["typed"],
+             }
+         else:
+             kwargs = {"maxsize": obj.cache_info().maxsize}
 
-     @dill.register(TracebackType)
-     def save_traceback(pickler, obj):
-         unpickle, args = pickle_traceback(obj)
-         pickler.save_reduce(unpickle, args, obj=obj)
+         return create_lru_cache, (obj.__wrapped__, kwargs)
 
-     @dill.register(BaseException)
-     def save_exception(pickler, obj):
-         unpickle, args = pickle_exception(obj)
-         pickler.save_reduce(unpickle, args, obj=obj)
+     _register(LRUCacheType, pickle_lru_cache)
 
-     for exception_cls in _get_subclasses(BaseException):
-         dill.pickle(exception_cls, save_exception)
+
+ def _patch_lock() -> None:
+     # https://github.com/uqfoundation/dill/blob/70f569b0dd268d2b1e85c0f300951b11f53c5d53/dill/_dill.py#L1310
+     from threading import Lock
+     from _thread import LockType
+
+     def create_lock(locked: bool) -> Lock:
+         lock = Lock()
+         if locked and not lock.acquire(False):
+             raise pickle.UnpicklingError("Cannot acquire lock")
+         return lock
+
+     def pickle_lock(obj: LockType) -> tuple[Callable, tuple]:
+         return create_lock, (obj.locked(),)
+
+     _register(LockType, pickle_lock)
+
+
+ def _patch_rlock() -> None:
+     # https://github.com/uqfoundation/dill/blob/70f569b0dd268d2b1e85c0f300951b11f53c5d53/dill/_dill.py#L1317
+     from _thread import RLock as RLockType  # type: ignore[attr-defined]
+
+     def create_rlock(count: int, owner: int) -> RLockType:
+         lock = RLockType()
+         if owner is not None:
+             lock._acquire_restore((count, owner))  # type: ignore[attr-defined]
+         if owner and not lock._is_owned():  # type: ignore[attr-defined]
+             raise pickle.UnpicklingError("Cannot acquire lock")
+         return lock
+
+     def pickle_rlock(obj: RLockType) -> tuple[Callable, tuple]:
+         r = obj.__repr__()
+         count = int(r.split('count=')[1].split()[0].rstrip('>'))
+         owner = int(r.split('owner=')[1].split()[0])
+
+         return create_rlock, (count, owner)
+
+     _register(RLockType, pickle_rlock)
 
 
- @mainify
  def _patch_console_thread_locals() -> None:
-     # NOTE: we __sometimes__ might have to serialize these
      from rich.console import ConsoleThreadLocals
 
-     @dill.register(ConsoleThreadLocals)
-     def save_console_thread_locals(pickler, obj):
-         args = {
-             "theme_stack": obj.theme_stack,
-             "buffer": obj.buffer,
-             "buffer_index": obj.buffer_index,
-         }
+     def create_locals(kwargs: dict) -> ConsoleThreadLocals:
+         return ConsoleThreadLocals(**kwargs)
 
-     def unpickle(kwargs):
-         return ConsoleThreadLocals(**kwargs)
+     def pickle_locals(obj: ConsoleThreadLocals) -> tuple[Callable, tuple]:
+         kwargs = {"theme_stack": obj.theme_stack, "buffer": obj.buffer, "buffer_index": obj.buffer_index}
+         return create_locals, (kwargs, )
 
-         pickler.save_reduce(unpickle, (args,), obj=obj)
+     _register(ConsoleThreadLocals, pickle_locals)
 
 
- @mainify
- def patch_dill():
-     import dill
+ def _patch_exceptions() -> None:
+     # Support chained exceptions
+     from tblib.pickling_support import install
 
-     dill.settings["recurse"] = True
+     install()
 
-     patch_exceptions()
-     patch_pydantic_class_attributes()
-     patch_pydantic_field_serialization()
-     _patch_console_thread_locals()
 
+ def patch_pickle() -> None:
+     _patch_pydantic_field_serialization()
+     _patch_pydantic_model_serialization()
+     _patch_lru_cache()
+     _patch_lock()
+     _patch_rlock()
+     _patch_console_thread_locals()
+     _patch_exceptions()
 
- @mainify
- def patch_pickle():
-     from tblib import pickling_support
+     _register_pickle_by_value("fal")
 
-     pickling_support.install()
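
This file's rewrite swaps dill's patched `_locate_function` hook for cloudpickle's pickle-by-value registry plus reduce-style reducers, i.e. functions that return a `(constructor, args)` pair. Below is a minimal, self-contained sketch of both patterns using only the stdlib `copyreg` protocol and cloudpickle's public `register_pickle_by_value`, rather than the private `_PICKLE_BY_VALUE_MODULES` set and `Pickler.dispatch` table the diff reaches into; the `Point` class is a hypothetical stand-in.

import copyreg
import pickle

# Pickle-by-value ships a module's code with the payload instead of
# importing it by name on the remote side. The public API takes a module
# object; _register_pickle_by_value above adds the module *name* directly
# to cloudpickle's private registry to skip the import:
#
#     cloudpickle.register_pickle_by_value(some_module)


class Point:
    # Hypothetical stand-in for a class pickle can't handle automatically
    # (like the Cythonized pydantic ModelField in the diff).
    def __init__(self, x: int, y: int) -> None:
        self.x, self.y = x, y


def create_point(kwargs: dict) -> Point:
    # Reconstructor: runs on the unpickling side.
    return Point(**kwargs)


def pickle_point(obj: Point):
    # Reducer: returns (constructor, args), the same shape as
    # pickle_model_field / pickle_lock in the diff above.
    return create_point, ({"x": obj.x, "y": obj.y},)


# copyreg wires the reducer into stdlib pickle; the diff assigns into
# cloudpickle.Pickler.dispatch instead, keeping the override cloudpickle-only.
copyreg.pickle(Point, pickle_point)

restored = pickle.loads(pickle.dumps(Point(1, 2)))
assert (restored.x, restored.y) == (1, 2)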
fal/api.py CHANGED
@@ -23,22 +23,25 @@ from typing import (
      overload,
  )
 
- import dill
- import dill.detect
+ import cloudpickle
  import grpc
  import isolate
  import tblib
+ import uvicorn
  import yaml
  from fastapi import FastAPI
+ from fastapi import __version__ as fastapi_version
  from isolate.backends.common import active_python
  from isolate.backends.settings import DEFAULT_SETTINGS
  from isolate.connections import PythonIPC
  from packaging.requirements import Requirement
  from packaging.utils import canonicalize_name
+ from pydantic import __version__ as pydantic_version
  from typing_extensions import Concatenate, ParamSpec
 
  import fal.flags as flags
- from fal._serialization import add_serialization_listeners_for, patch_dill, patch_pickle
+ from fal.exceptions import FalServerlessException
+ from fal._serialization import include_modules_from, patch_pickle
  from fal.logging.isolate import IsolateLogPrinter
  from fal.sdk import (
      FAL_SERVERLESS_DEFAULT_KEEP_ALIVE,
@@ -52,7 +55,6 @@ from fal.sdk import (
      get_agent_credentials,
      get_default_credentials,
  )
- from fal.toolkit import mainify
 
  ArgsT = ParamSpec("ArgsT")
  ReturnT = TypeVar("ReturnT", covariant=True)
@@ -60,16 +62,21 @@ ReturnT = TypeVar("ReturnT", covariant=True)
  BasicConfig = Dict[str, Any]
  _UNSET = object()
 
- SERVE_REQUIREMENTS = ["fastapi==0.99.1", "uvicorn"]
+ SERVE_REQUIREMENTS = [
+     f"fastapi=={fastapi_version}",
+     f"pydantic=={pydantic_version}",
+     "uvicorn",
+     "starlette_exporter",
+ ]
 
 
  @dataclass
- class FalServerlessError(Exception):
+ class FalServerlessError(FalServerlessException):
      message: str
 
 
  @dataclass
- class InternalFalServerlessError(Exception):
+ class InternalFalServerlessError(FalServerlessException):
      message: str
 
 
@@ -177,24 +184,12 @@ def cached(func: Callable[ArgsT, ReturnT]) -> Callable[ArgsT, ReturnT]:
      return wrapper
 
 
- @mainify
- class UserFunctionException(Exception):
+ class UserFunctionException(FalServerlessException):
      pass
 
 
- def match_class(obj, cls):
-     # NOTE: Can't use isinstance because we are not using dill's byref setting when
-     # loading/dumping objects in RPC, which means that our exceptions from remote
-     # server are created by value and are actually a separate class that only looks
-     # like original one.
-     #
-     # See https://github.com/fal-ai/fal/issues/142
-     return type(obj).__name__ == cls.__name__
-
-
  def _prepare_partial_func(
      func: Callable[ArgsT, ReturnT],
-     patch_func: Callable[[], None],
      *args: ArgsT.args,
      **kwargs: ArgsT.kwargs,
  ) -> Callable[ArgsT, ReturnT]:
@@ -204,6 +199,8 @@ def _prepare_partial_func(
      def wrapper(*remote_args: ArgsT.args, **remote_kwargs: ArgsT.kwargs) -> ReturnT:
          try:
              result = func(*remote_args, *args, **remote_kwargs, **kwargs)
+         except FalServerlessException:
+             raise
          except Exception as exc:
              tb = exc.__traceback__
              if tb is not None and tb.tb_next is not None:
@@ -214,37 +211,22 @@ def _prepare_partial_func(
              ) from exc.with_traceback(tb)
          finally:
              with suppress(Exception):
-                 patch_func()
+                 patch_pickle()
          return result
 
      return wrapper
 
 
- def _prepare_local_partial_func(
-     func: Callable[ArgsT, ReturnT],
-     *args: ArgsT.args,
-     **kwargs: ArgsT.kwargs,
- ) -> Callable[ArgsT, ReturnT]:
-
-     return _prepare_partial_func(func, patch_pickle, *args, **kwargs)
-
-
- def _prepare_remote_partial_func(
-     func: Callable[ArgsT, ReturnT],
-     *args: ArgsT.args,
-     **kwargs: ArgsT.kwargs,
- ) -> Callable[ArgsT, ReturnT]:
-
-     return _prepare_partial_func(func, patch_dill, *args, **kwargs)
-
-
  @dataclass
  class LocalHost(Host):
      # The environment which provides the default set of
      # packages for isolate agent to run.
      _AGENT_ENVIRONMENT = isolate.prepare_environment(
          "virtualenv",
-         requirements=[f"dill=={dill.__version__}", f"tblib=={tblib.__version__}"],
+         requirements=[
+             f"cloudpickle=={cloudpickle.__version__}",
+             f"tblib=={tblib.__version__}",
+         ],
      )
      _log_printer = IsolateLogPrinter(debug=flags.DEBUG)
 
@@ -255,7 +237,11 @@ class LocalHost(Host):
          args: tuple[Any, ...],
          kwargs: dict[str, Any],
      ) -> ReturnT:
-         settings = replace(DEFAULT_SETTINGS, serialization_method="dill", log_hook=self._log_printer.print)
+         settings = replace(
+             DEFAULT_SETTINGS,
+             serialization_method="cloudpickle",
+             log_hook=self._log_printer.print,
+         )
          environment = isolate.prepare_environment(
              **options.environment,
              context=settings,
@@ -265,7 +251,7 @@ class LocalHost(Host):
              environment.create(),
              extra_inheritance_paths=[self._AGENT_ENVIRONMENT.create()],
          ) as connection:
-             executable = _prepare_local_partial_func(func, *args, **kwargs)
+             executable = _prepare_partial_func(func, *args, **kwargs)
              return connection.run(executable)
 
 
@@ -311,6 +297,8 @@ def _handle_grpc_error():
  def find_missing_dependencies(
      func: Callable, env: dict
  ) -> Iterator[tuple[str, list[str]]]:
+     import dill
+
      if env["kind"] != "virtualenv":
          return
 
@@ -422,7 +410,7 @@ class FalServerlessHost(Host):
              min_concurrency=min_concurrency,
          )
 
-         partial_func = _prepare_remote_partial_func(func)
+         partial_func = _prepare_partial_func(func)
 
          if metadata is None:
              metadata = {}
@@ -452,6 +440,8 @@ class FalServerlessHost(Host):
          if partial_result.result:
              return partial_result.result.application_id
 
+         return None
+
      @_handle_grpc_error()
      def run(
          self,
@@ -492,7 +482,7 @@ class FalServerlessHost(Host):
          return_value = _UNSET
          # Allow isolate provided arguments (such as setup function) to take
          # precedence over the ones provided by the user.
-         partial_func = _prepare_remote_partial_func(func, *args, **kwargs)
+         partial_func = _prepare_partial_func(func, *args, **kwargs)
          for partial_result in self._connection.run(
              partial_func,
              environments,
@@ -766,7 +756,7 @@ def function( # type: ignore
      options = host.parse_options(kind=kind, **config)
 
      def wrapper(func: Callable[ArgsT, ReturnT]):
-         add_serialization_listeners_for(func)
+         include_modules_from(func)
          proxy = IsolatedFunction(
              host=host,
              raw_func=func,  # type: ignore
@@ -777,7 +767,6 @@ def function( # type: ignore
      return wrapper
 
 
- @mainify
  class FalFastAPI(FastAPI):
      """
      A subclass of FastAPI that adds some fal-specific functionality.
@@ -821,7 +810,6 @@ class FalFastAPI(FastAPI):
          return spec
 
 
- @mainify
  class RouteSignature(NamedTuple):
      path: str
      is_websocket: bool = False
@@ -832,7 +820,6 @@ class RouteSignature(NamedTuple):
      emit_timings: bool = False
 
 
- @mainify
  class BaseServable:
      def collect_routes(self) -> dict[RouteSignature, Callable[..., Any]]:
          raise NotImplementedError
@@ -851,6 +838,7 @@ class BaseServable:
          from fastapi import HTTPException, Request
          from fastapi.middleware.cors import CORSMiddleware
          from fastapi.responses import JSONResponse
+         from starlette_exporter import PrometheusMiddleware
 
          _app = FalFastAPI(lifespan=self.lifespan)
 
@@ -861,6 +849,13 @@ class BaseServable:
              allow_methods=("*"),
              allow_origins=("*"),
          )
+         _app.add_middleware(
+             PrometheusMiddleware,
+             prefix="http",
+             group_paths=True,
+             filter_unhandled_paths=True,
+             app_name="fal",
+         )
 
          self._add_extra_middlewares(_app)
 
@@ -906,13 +901,43 @@ class BaseServable:
          return self._build_app().openapi()
 
      def serve(self) -> None:
-         import uvicorn
+         import asyncio
+
+         from starlette_exporter import handle_metrics
+         from uvicorn import Config
 
          app = self._build_app()
-         uvicorn.run(app, host="0.0.0.0", port=8080)
+         server = Server(config=Config(app, host="0.0.0.0", port=8080))
+         metrics_app = FastAPI()
+         metrics_app.add_route("/metrics", handle_metrics)
+         metrics_server = Server(config=Config(metrics_app, host="0.0.0.0", port=9090))
+
+         async def _serve() -> None:
+             tasks = {
+                 asyncio.create_task(server.serve()): server,
+                 asyncio.create_task(metrics_server.serve()): metrics_server,
+             }
+
+             _, pending = await asyncio.wait(tasks.keys(), return_when=asyncio.FIRST_COMPLETED)
+             if not pending:
+                 return
+
+             # try graceful shutdown
+             for task in pending:
+                 tasks[task].should_exit = True
+             _, pending = await asyncio.wait(pending, timeout=2)
+             if not pending:
+                 return
+
+             for task in pending:
+                 task.cancel()
+             await asyncio.wait(pending)
+
+         with suppress(asyncio.CancelledError):
+             asyncio.set_event_loop(asyncio.new_event_loop())
+             asyncio.run(_serve())
 
 
- @mainify
  class ServeWrapper(BaseServable):
      _func: Callable
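
The `_build_app`/`serve` changes above attach `starlette_exporter`'s `PrometheusMiddleware` to the user-facing app and serve the collected metrics from a second FastAPI app on port 9090. A minimal sketch of the same wiring in isolation; for brevity it mounts the metrics handler on the one app instead of a separate server, and the names are illustrative:

from fastapi import FastAPI
from starlette_exporter import PrometheusMiddleware, handle_metrics

app = FastAPI()

# Collect per-request metrics, grouping by route template so that
# /items/1 and /items/2 count toward the same series.
app.add_middleware(
    PrometheusMiddleware,
    app_name="fal",
    prefix="http",
    group_paths=True,
    filter_unhandled_paths=True,
)

# The diff mounts this handler on a dedicated app bound to :9090;
# mounting it on the same app also works.
app.add_route("/metrics", handle_metrics)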
 
@@ -993,7 +1018,7 @@ class IsolatedFunction(Generic[ArgsT, ReturnT]):
              ) from None
          except Exception as exc:
              cause = exc.__cause__
-             if self.reraise and match_class(exc, UserFunctionException) and cause:
+             if self.reraise and isinstance(exc, UserFunctionException) and cause:
                  # re-raise original exception without our wrappers
                  raise cause
              raise
@@ -1069,3 +1094,14 @@ class ServedIsolatedFunction(
          self, host: Host | None = None, *, serve: Literal[False], **config: Any
      ) -> IsolatedFunction[ArgsT, ReturnT]:
          ...
+
+
+ class Server(uvicorn.Server):
+     """Server is a uvicorn.Server that actually plays nicely with signals.
+     By default, uvicorn's Server class overwrites the signal handler for SIGINT, swallowing the signal and preventing other tasks from cancelling.
+     This class allows the task to be gracefully cancelled using asyncio's built-in task cancellation or with an event, like aiohttp.
+     """
+
+     def install_signal_handlers(self) -> None:
+         pass
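
The `Server` subclass above exists so the two uvicorn instances started in `serve()` can be stopped by asyncio rather than by uvicorn's own SIGINT handler. A condensed, standalone sketch of the same pattern, assuming uvicorn and fastapi are installed; ports and app contents are arbitrary:

import asyncio

import uvicorn
from fastapi import FastAPI


class QuietServer(uvicorn.Server):
    # Let asyncio manage cancellation instead of uvicorn's SIGINT handler.
    def install_signal_handlers(self) -> None:
        pass


app = FastAPI()
metrics = FastAPI()

servers = [
    QuietServer(uvicorn.Config(app, port=8080)),
    QuietServer(uvicorn.Config(metrics, port=9090)),
]


async def main() -> None:
    tasks = [asyncio.create_task(s.serve()) for s in servers]
    # When either server stops, ask the other to exit gracefully.
    await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
    for s in servers:
        s.should_exit = True
    await asyncio.wait(tasks, timeout=2)


if __name__ == "__main__":
    asyncio.run(main())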