fal 0.14.0__py3-none-any.whl → 0.15.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of fal might be problematic. Click here for more details.
- fal/_serialization.py +149 -125
- fal/api.py +88 -52
- fal/app.py +2 -7
- fal/auth/__init__.py +0 -2
- fal/cli.py +2 -2
- fal/exceptions/__init__.py +1 -2
- fal/exceptions/_base.py +1 -12
- fal/exceptions/auth.py +2 -4
- fal/exceptions/handlers.py +8 -19
- fal/sdk.py +2 -3
- fal/toolkit/__init__.py +0 -2
- fal/toolkit/exceptions.py +0 -5
- fal/toolkit/file/file.py +57 -54
- fal/toolkit/file/providers/fal.py +0 -4
- fal/toolkit/file/providers/gcp.py +0 -2
- fal/toolkit/file/providers/r2.py +0 -2
- fal/toolkit/file/types.py +0 -4
- fal/toolkit/image/image.py +10 -14
- fal/toolkit/optimize.py +0 -2
- fal/toolkit/utils/download_utils.py +1 -14
- fal/workflows.py +2 -1
- {fal-0.14.0.dist-info → fal-0.15.0.dist-info}/METADATA +40 -38
- {fal-0.14.0.dist-info → fal-0.15.0.dist-info}/RECORD +50 -51
- {fal-0.14.0.dist-info → fal-0.15.0.dist-info}/WHEEL +2 -1
- fal-0.15.0.dist-info/entry_points.txt +2 -0
- fal-0.15.0.dist-info/top_level.txt +2 -0
- fal/env.py +0 -3
- fal/toolkit/mainify.py +0 -13
- fal-0.14.0.dist-info/entry_points.txt +0 -4
fal/_serialization.py
CHANGED
|
@@ -1,102 +1,71 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
from
|
|
4
|
-
from pathlib import Path
|
|
3
|
+
from typing import Any, Callable
|
|
5
4
|
|
|
6
|
-
import
|
|
7
|
-
|
|
5
|
+
import pickle
|
|
6
|
+
import cloudpickle
|
|
8
7
|
|
|
9
|
-
from fal.toolkit import mainify
|
|
10
8
|
|
|
11
|
-
|
|
12
|
-
#
|
|
13
|
-
|
|
14
|
-
|
|
9
|
+
def _register_pickle_by_value(name) -> None:
|
|
10
|
+
# cloudpickle.register_pickle_by_value wants an imported module object,
|
|
11
|
+
# but there is really no reason to go through that complication, as
|
|
12
|
+
# it might be prone to errors.
|
|
13
|
+
cloudpickle.cloudpickle._PICKLE_BY_VALUE_MODULES.add(name)
|
|
15
14
|
|
|
16
15
|
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
from pydantic.fields import ModelField
|
|
16
|
+
def include_package_from_path(raw_path: str) -> None:
|
|
17
|
+
from pathlib import Path
|
|
20
18
|
|
|
21
|
-
return ModelField(**kwargs)
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
@mainify
|
|
25
|
-
def _pydantic_make_private_field(kwargs):
|
|
26
|
-
from pydantic.fields import ModelPrivateAttr
|
|
27
|
-
|
|
28
|
-
return ModelPrivateAttr(**kwargs)
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
# this allows us to record all the "isolated" function and then mainify everything in
|
|
32
|
-
# module they exist
|
|
33
|
-
@wraps(_dill._locate_function)
|
|
34
|
-
def by_value_locator(obj, pickler=None, og_locator=_dill._locate_function):
|
|
35
|
-
module_name = getattr(obj, "__module__", None)
|
|
36
|
-
if module_name is not None:
|
|
37
|
-
# If it is coming from the same module, directly allow
|
|
38
|
-
# it to be pickled.
|
|
39
|
-
if module_name in _MODULES:
|
|
40
|
-
return False
|
|
41
|
-
|
|
42
|
-
package_name, *_ = module_name.partition(".")
|
|
43
|
-
# If it is coming from the same package, then do the same.
|
|
44
|
-
if package_name in _PACKAGES:
|
|
45
|
-
return False
|
|
46
|
-
|
|
47
|
-
og_result = og_locator(obj, pickler)
|
|
48
|
-
return og_result
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
_dill._locate_function = by_value_locator
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
def include_packages_from_path(raw_path: str):
|
|
55
19
|
path = Path(raw_path).resolve()
|
|
56
20
|
parent = path
|
|
57
21
|
while (parent.parent / "__init__.py").exists():
|
|
58
22
|
parent = parent.parent
|
|
59
23
|
|
|
60
24
|
if parent != path:
|
|
61
|
-
|
|
25
|
+
_register_pickle_by_value(parent.name)
|
|
62
26
|
|
|
63
27
|
|
|
64
|
-
def
|
|
28
|
+
def include_modules_from(obj: Any) -> None:
|
|
65
29
|
module_name = getattr(obj, "__module__", None)
|
|
66
30
|
if not module_name:
|
|
67
|
-
return
|
|
31
|
+
return
|
|
32
|
+
|
|
33
|
+
if "." in module_name:
|
|
34
|
+
# Just include the whole package
|
|
35
|
+
package_name, *_ = module_name.partition(".")
|
|
36
|
+
_register_pickle_by_value(package_name)
|
|
37
|
+
return
|
|
68
38
|
|
|
69
|
-
_MODULES.add(module_name)
|
|
70
39
|
if module_name == "__main__":
|
|
71
40
|
# When the module is __main__, we need to recursively go up the
|
|
72
41
|
# tree to locate the actual package name.
|
|
73
42
|
import __main__
|
|
74
43
|
|
|
75
|
-
|
|
44
|
+
include_package_from_path(__main__.__file__)
|
|
45
|
+
return
|
|
76
46
|
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
47
|
+
_register_pickle_by_value(module_name)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def _register(cls: Any, func: Callable) -> None:
|
|
51
|
+
cloudpickle.Pickler.dispatch[cls] = func
|
|
80
52
|
|
|
81
53
|
|
|
82
|
-
|
|
83
|
-
def patch_pydantic_field_serialization():
|
|
54
|
+
def _patch_pydantic_field_serialization() -> None:
|
|
84
55
|
# Cythonized pydantic fields can't be serialized automatically, so we are
|
|
85
56
|
# have a special case handling for them that unpacks it to a dictionary
|
|
86
57
|
# and then reloads it on the other side.
|
|
87
|
-
|
|
88
|
-
|
|
58
|
+
# https://github.com/ray-project/ray/blob/842bbcf4236e41f58d25058b0482cd05bfe9e4da/python/ray/_private/pydantic_compat.py#L80
|
|
89
59
|
try:
|
|
90
|
-
|
|
60
|
+
from pydantic.fields import ModelField, ModelPrivateAttr
|
|
91
61
|
except ImportError:
|
|
92
62
|
return
|
|
93
63
|
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
args = {
|
|
64
|
+
def create_model_field(kwargs: dict) -> ModelField:
|
|
65
|
+
return ModelField(**kwargs)
|
|
66
|
+
|
|
67
|
+
def pickle_model_field(field: ModelField) -> tuple[Callable, tuple]:
|
|
68
|
+
kwargs = {
|
|
100
69
|
"name": field.name,
|
|
101
70
|
# outer_type_ is the original type for ModelFields,
|
|
102
71
|
# while type_ can be updated later with the nested type
|
|
@@ -110,92 +79,147 @@ def patch_pydantic_field_serialization():
|
|
|
110
79
|
"alias": field.alias,
|
|
111
80
|
"field_info": field.field_info,
|
|
112
81
|
}
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
def
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
) ->
|
|
120
|
-
|
|
82
|
+
|
|
83
|
+
return create_model_field, (kwargs,)
|
|
84
|
+
|
|
85
|
+
def create_private_attr(kwargs: dict) -> ModelPrivateAttr:
|
|
86
|
+
return ModelPrivateAttr(**kwargs)
|
|
87
|
+
|
|
88
|
+
def pickle_private_attr(field: ModelPrivateAttr) -> tuple[Callable, tuple]:
|
|
89
|
+
kwargs = {
|
|
121
90
|
"default": field.default,
|
|
122
91
|
"default_factory": field.default_factory,
|
|
123
92
|
}
|
|
124
|
-
|
|
93
|
+
|
|
94
|
+
return create_private_attr, (kwargs,)
|
|
95
|
+
|
|
96
|
+
_register(ModelField, pickle_model_field)
|
|
97
|
+
_register(ModelPrivateAttr, pickle_private_attr)
|
|
125
98
|
|
|
126
99
|
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
#
|
|
130
|
-
#
|
|
131
|
-
#
|
|
100
|
+
def _patch_pydantic_model_serialization() -> None:
|
|
101
|
+
# If user has created new pydantic models in his namespace, we will try to pickle those
|
|
102
|
+
# by value, which means recreating class skeleton, which will stumble upon
|
|
103
|
+
# __pydantic_parent_namespace__ in its __dict__ and it may contain modules that happened
|
|
104
|
+
# to be imported in the namespace but are not actually used, resulting in pickling errors.
|
|
105
|
+
# Unfortunately this also means that `model_rebuid()` might not work.
|
|
132
106
|
try:
|
|
133
|
-
import pydantic
|
|
107
|
+
import pydantic
|
|
134
108
|
except ImportError:
|
|
135
109
|
return
|
|
136
110
|
|
|
137
|
-
pydantic
|
|
111
|
+
# https://github.com/pydantic/pydantic/pull/2573
|
|
112
|
+
if not hasattr(pydantic, "__version__") or pydantic.__version__.startswith("1."):
|
|
113
|
+
return
|
|
114
|
+
|
|
115
|
+
backup = "_original_extract_class_dict"
|
|
116
|
+
if getattr(cloudpickle.cloudpickle, backup, None):
|
|
117
|
+
return
|
|
118
|
+
|
|
119
|
+
original = cloudpickle.cloudpickle._extract_class_dict
|
|
120
|
+
|
|
121
|
+
def patched(cls):
|
|
122
|
+
attr_name = "__pydantic_parent_namespace__"
|
|
123
|
+
if issubclass(cls, pydantic.BaseModel) and getattr(cls, attr_name, None):
|
|
124
|
+
setattr(cls, attr_name, None)
|
|
125
|
+
|
|
126
|
+
return original(cls)
|
|
127
|
+
|
|
128
|
+
cloudpickle.cloudpickle._extract_class_dict = patched
|
|
129
|
+
setattr(cloudpickle.cloudpickle, backup, original)
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def _patch_lru_cache() -> None:
|
|
133
|
+
# https://github.com/cloudpipe/cloudpickle/issues/178
|
|
134
|
+
# https://github.com/uqfoundation/dill/blob/70f569b0dd268d2b1e85c0f300951b11f53c5d53/dill/_dill.py#L1429
|
|
138
135
|
|
|
136
|
+
from functools import lru_cache, _lru_cache_wrapper as LRUCacheType
|
|
139
137
|
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
# Adapting tblib.pickling_support.install for dill.
|
|
143
|
-
from types import TracebackType
|
|
138
|
+
def create_lru_cache(func: Callable, kwargs: dict) -> LRUCacheType:
|
|
139
|
+
return lru_cache(**kwargs)(func)
|
|
144
140
|
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
141
|
+
def pickle_lru_cache(obj: LRUCacheType) -> tuple[Callable, tuple]:
|
|
142
|
+
if hasattr(obj, "cache_parameters"):
|
|
143
|
+
params = obj.cache_parameters()
|
|
144
|
+
kwargs = {
|
|
145
|
+
"maxsize": params["maxsize"],
|
|
146
|
+
"typed": params["typed"],
|
|
147
|
+
}
|
|
148
|
+
else:
|
|
149
|
+
kwargs = {"maxsize": obj.cache_info().maxsize}
|
|
151
150
|
|
|
152
|
-
|
|
153
|
-
def save_traceback(pickler, obj):
|
|
154
|
-
unpickle, args = pickle_traceback(obj)
|
|
155
|
-
pickler.save_reduce(unpickle, args, obj=obj)
|
|
151
|
+
return create_lru_cache, (obj.__wrapped__, kwargs)
|
|
156
152
|
|
|
157
|
-
|
|
158
|
-
def save_exception(pickler, obj):
|
|
159
|
-
unpickle, args = pickle_exception(obj)
|
|
160
|
-
pickler.save_reduce(unpickle, args, obj=obj)
|
|
153
|
+
_register(LRUCacheType, pickle_lru_cache)
|
|
161
154
|
|
|
162
|
-
|
|
163
|
-
|
|
155
|
+
|
|
156
|
+
def _patch_lock() -> None:
|
|
157
|
+
# https://github.com/uqfoundation/dill/blob/70f569b0dd268d2b1e85c0f300951b11f53c5d53/dill/_dill.py#L1310
|
|
158
|
+
from threading import Lock
|
|
159
|
+
from _thread import LockType
|
|
160
|
+
|
|
161
|
+
def create_lock(locked: bool) -> Lock:
|
|
162
|
+
lock = Lock()
|
|
163
|
+
if locked and not lock.acquire(False):
|
|
164
|
+
raise pickle.UnpicklingError("Cannot acquire lock")
|
|
165
|
+
return lock
|
|
166
|
+
|
|
167
|
+
def pickle_lock(obj: LockType) -> tuple[Callable, tuple]:
|
|
168
|
+
return create_lock, (obj.locked(),)
|
|
169
|
+
|
|
170
|
+
_register(LockType, pickle_lock)
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
def _patch_rlock() -> None:
|
|
174
|
+
# https://github.com/uqfoundation/dill/blob/70f569b0dd268d2b1e85c0f300951b11f53c5d53/dill/_dill.py#L1317
|
|
175
|
+
from _thread import RLock as RLockType # type: ignore[attr-defined]
|
|
176
|
+
|
|
177
|
+
def create_rlock(count: int, owner: int) -> RLockType:
|
|
178
|
+
lock = RLockType()
|
|
179
|
+
if owner is not None:
|
|
180
|
+
lock._acquire_restore((count, owner)) # type: ignore[attr-defined]
|
|
181
|
+
if owner and not lock._is_owned(): # type: ignore[attr-defined]
|
|
182
|
+
raise pickle.UnpicklingError("Cannot acquire lock")
|
|
183
|
+
return lock
|
|
184
|
+
|
|
185
|
+
def pickle_rlock(obj: RLockType) -> tuple[Callable, tuple]:
|
|
186
|
+
r = obj.__repr__()
|
|
187
|
+
count = int(r.split('count=')[1].split()[0].rstrip('>'))
|
|
188
|
+
owner = int(r.split('owner=')[1].split()[0])
|
|
189
|
+
|
|
190
|
+
return create_rlock, (count, owner)
|
|
191
|
+
|
|
192
|
+
_register(RLockType, pickle_rlock)
|
|
164
193
|
|
|
165
194
|
|
|
166
|
-
@mainify
|
|
167
195
|
def _patch_console_thread_locals() -> None:
|
|
168
|
-
# NOTE: we __sometimes__ might have to serialize these
|
|
169
196
|
from rich.console import ConsoleThreadLocals
|
|
170
197
|
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
args = {
|
|
174
|
-
"theme_stack": obj.theme_stack,
|
|
175
|
-
"buffer": obj.buffer,
|
|
176
|
-
"buffer_index": obj.buffer_index,
|
|
177
|
-
}
|
|
198
|
+
def create_locals(kwargs: dict) -> ConsoleThreadLocals:
|
|
199
|
+
return ConsoleThreadLocals(**kwargs)
|
|
178
200
|
|
|
179
|
-
|
|
180
|
-
|
|
201
|
+
def pickle_locals(obj: ConsoleThreadLocals) -> tuple[Callable, tuple]:
|
|
202
|
+
kwargs = {"theme_stack": obj.theme_stack, "buffer": obj.buffer, "buffer_index": obj.buffer_index}
|
|
203
|
+
return create_locals, (kwargs, )
|
|
181
204
|
|
|
182
|
-
|
|
205
|
+
_register(ConsoleThreadLocals, pickle_locals)
|
|
183
206
|
|
|
184
207
|
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
import
|
|
208
|
+
def _patch_exceptions() -> None:
|
|
209
|
+
# Support chained exceptions
|
|
210
|
+
from tblib.pickling_support import install
|
|
188
211
|
|
|
189
|
-
|
|
212
|
+
install()
|
|
190
213
|
|
|
191
|
-
patch_exceptions()
|
|
192
|
-
patch_pydantic_class_attributes()
|
|
193
|
-
patch_pydantic_field_serialization()
|
|
194
|
-
_patch_console_thread_locals()
|
|
195
214
|
|
|
215
|
+
def patch_pickle() -> None:
|
|
216
|
+
_patch_pydantic_field_serialization()
|
|
217
|
+
_patch_pydantic_model_serialization()
|
|
218
|
+
_patch_lru_cache()
|
|
219
|
+
_patch_lock()
|
|
220
|
+
_patch_rlock()
|
|
221
|
+
_patch_console_thread_locals()
|
|
222
|
+
_patch_exceptions()
|
|
196
223
|
|
|
197
|
-
|
|
198
|
-
def patch_pickle():
|
|
199
|
-
from tblib import pickling_support
|
|
224
|
+
_register_pickle_by_value("fal")
|
|
200
225
|
|
|
201
|
-
pickling_support.install()
|
fal/api.py
CHANGED
|
@@ -23,22 +23,25 @@ from typing import (
|
|
|
23
23
|
overload,
|
|
24
24
|
)
|
|
25
25
|
|
|
26
|
-
import
|
|
27
|
-
import dill.detect
|
|
26
|
+
import cloudpickle
|
|
28
27
|
import grpc
|
|
29
28
|
import isolate
|
|
30
29
|
import tblib
|
|
30
|
+
import uvicorn
|
|
31
31
|
import yaml
|
|
32
32
|
from fastapi import FastAPI
|
|
33
|
+
from fastapi import __version__ as fastapi_version
|
|
33
34
|
from isolate.backends.common import active_python
|
|
34
35
|
from isolate.backends.settings import DEFAULT_SETTINGS
|
|
35
36
|
from isolate.connections import PythonIPC
|
|
36
37
|
from packaging.requirements import Requirement
|
|
37
38
|
from packaging.utils import canonicalize_name
|
|
39
|
+
from pydantic import __version__ as pydantic_version
|
|
38
40
|
from typing_extensions import Concatenate, ParamSpec
|
|
39
41
|
|
|
40
42
|
import fal.flags as flags
|
|
41
|
-
from fal.
|
|
43
|
+
from fal.exceptions import FalServerlessException
|
|
44
|
+
from fal._serialization import include_modules_from, patch_pickle
|
|
42
45
|
from fal.logging.isolate import IsolateLogPrinter
|
|
43
46
|
from fal.sdk import (
|
|
44
47
|
FAL_SERVERLESS_DEFAULT_KEEP_ALIVE,
|
|
@@ -52,7 +55,6 @@ from fal.sdk import (
|
|
|
52
55
|
get_agent_credentials,
|
|
53
56
|
get_default_credentials,
|
|
54
57
|
)
|
|
55
|
-
from fal.toolkit import mainify
|
|
56
58
|
|
|
57
59
|
ArgsT = ParamSpec("ArgsT")
|
|
58
60
|
ReturnT = TypeVar("ReturnT", covariant=True)
|
|
@@ -60,16 +62,21 @@ ReturnT = TypeVar("ReturnT", covariant=True)
|
|
|
60
62
|
BasicConfig = Dict[str, Any]
|
|
61
63
|
_UNSET = object()
|
|
62
64
|
|
|
63
|
-
SERVE_REQUIREMENTS = [
|
|
65
|
+
SERVE_REQUIREMENTS = [
|
|
66
|
+
f"fastapi=={fastapi_version}",
|
|
67
|
+
f"pydantic=={pydantic_version}",
|
|
68
|
+
"uvicorn",
|
|
69
|
+
"starlette_exporter",
|
|
70
|
+
]
|
|
64
71
|
|
|
65
72
|
|
|
66
73
|
@dataclass
|
|
67
|
-
class FalServerlessError(
|
|
74
|
+
class FalServerlessError(FalServerlessException):
|
|
68
75
|
message: str
|
|
69
76
|
|
|
70
77
|
|
|
71
78
|
@dataclass
|
|
72
|
-
class InternalFalServerlessError(
|
|
79
|
+
class InternalFalServerlessError(FalServerlessException):
|
|
73
80
|
message: str
|
|
74
81
|
|
|
75
82
|
|
|
@@ -177,24 +184,12 @@ def cached(func: Callable[ArgsT, ReturnT]) -> Callable[ArgsT, ReturnT]:
|
|
|
177
184
|
return wrapper
|
|
178
185
|
|
|
179
186
|
|
|
180
|
-
|
|
181
|
-
class UserFunctionException(Exception):
|
|
187
|
+
class UserFunctionException(FalServerlessException):
|
|
182
188
|
pass
|
|
183
189
|
|
|
184
190
|
|
|
185
|
-
def match_class(obj, cls):
|
|
186
|
-
# NOTE: Can't use isinstance because we are not using dill's byref setting when
|
|
187
|
-
# loading/dumping objects in RPC, which means that our exceptions from remote
|
|
188
|
-
# server are created by value and are actually a separate class that only looks
|
|
189
|
-
# like original one.
|
|
190
|
-
#
|
|
191
|
-
# See https://github.com/fal-ai/fal/issues/142
|
|
192
|
-
return type(obj).__name__ == cls.__name__
|
|
193
|
-
|
|
194
|
-
|
|
195
191
|
def _prepare_partial_func(
|
|
196
192
|
func: Callable[ArgsT, ReturnT],
|
|
197
|
-
patch_func: Callable[[], None],
|
|
198
193
|
*args: ArgsT.args,
|
|
199
194
|
**kwargs: ArgsT.kwargs,
|
|
200
195
|
) -> Callable[ArgsT, ReturnT]:
|
|
@@ -204,6 +199,8 @@ def _prepare_partial_func(
|
|
|
204
199
|
def wrapper(*remote_args: ArgsT.args, **remote_kwargs: ArgsT.kwargs) -> ReturnT:
|
|
205
200
|
try:
|
|
206
201
|
result = func(*remote_args, *args, **remote_kwargs, **kwargs)
|
|
202
|
+
except FalServerlessException:
|
|
203
|
+
raise
|
|
207
204
|
except Exception as exc:
|
|
208
205
|
tb = exc.__traceback__
|
|
209
206
|
if tb is not None and tb.tb_next is not None:
|
|
@@ -214,37 +211,22 @@ def _prepare_partial_func(
|
|
|
214
211
|
) from exc.with_traceback(tb)
|
|
215
212
|
finally:
|
|
216
213
|
with suppress(Exception):
|
|
217
|
-
|
|
214
|
+
patch_pickle()
|
|
218
215
|
return result
|
|
219
216
|
|
|
220
217
|
return wrapper
|
|
221
218
|
|
|
222
219
|
|
|
223
|
-
def _prepare_local_partial_func(
|
|
224
|
-
func: Callable[ArgsT, ReturnT],
|
|
225
|
-
*args: ArgsT.args,
|
|
226
|
-
**kwargs: ArgsT.kwargs,
|
|
227
|
-
) -> Callable[ArgsT, ReturnT]:
|
|
228
|
-
|
|
229
|
-
return _prepare_partial_func(func, patch_pickle, *args, **kwargs)
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
def _prepare_remote_partial_func(
|
|
233
|
-
func: Callable[ArgsT, ReturnT],
|
|
234
|
-
*args: ArgsT.args,
|
|
235
|
-
**kwargs: ArgsT.kwargs,
|
|
236
|
-
) -> Callable[ArgsT, ReturnT]:
|
|
237
|
-
|
|
238
|
-
return _prepare_partial_func(func, patch_dill, *args, **kwargs)
|
|
239
|
-
|
|
240
|
-
|
|
241
220
|
@dataclass
|
|
242
221
|
class LocalHost(Host):
|
|
243
222
|
# The environment which provides the default set of
|
|
244
223
|
# packages for isolate agent to run.
|
|
245
224
|
_AGENT_ENVIRONMENT = isolate.prepare_environment(
|
|
246
225
|
"virtualenv",
|
|
247
|
-
requirements=[
|
|
226
|
+
requirements=[
|
|
227
|
+
f"cloudpickle=={cloudpickle.__version__}",
|
|
228
|
+
f"tblib=={tblib.__version__}",
|
|
229
|
+
],
|
|
248
230
|
)
|
|
249
231
|
_log_printer = IsolateLogPrinter(debug=flags.DEBUG)
|
|
250
232
|
|
|
@@ -255,7 +237,11 @@ class LocalHost(Host):
|
|
|
255
237
|
args: tuple[Any, ...],
|
|
256
238
|
kwargs: dict[str, Any],
|
|
257
239
|
) -> ReturnT:
|
|
258
|
-
settings = replace(
|
|
240
|
+
settings = replace(
|
|
241
|
+
DEFAULT_SETTINGS,
|
|
242
|
+
serialization_method="cloudpickle",
|
|
243
|
+
log_hook=self._log_printer.print,
|
|
244
|
+
)
|
|
259
245
|
environment = isolate.prepare_environment(
|
|
260
246
|
**options.environment,
|
|
261
247
|
context=settings,
|
|
@@ -265,7 +251,7 @@ class LocalHost(Host):
|
|
|
265
251
|
environment.create(),
|
|
266
252
|
extra_inheritance_paths=[self._AGENT_ENVIRONMENT.create()],
|
|
267
253
|
) as connection:
|
|
268
|
-
executable =
|
|
254
|
+
executable = _prepare_partial_func(func, *args, **kwargs)
|
|
269
255
|
return connection.run(executable)
|
|
270
256
|
|
|
271
257
|
|
|
@@ -311,6 +297,8 @@ def _handle_grpc_error():
|
|
|
311
297
|
def find_missing_dependencies(
|
|
312
298
|
func: Callable, env: dict
|
|
313
299
|
) -> Iterator[tuple[str, list[str]]]:
|
|
300
|
+
import dill
|
|
301
|
+
|
|
314
302
|
if env["kind"] != "virtualenv":
|
|
315
303
|
return
|
|
316
304
|
|
|
@@ -422,7 +410,7 @@ class FalServerlessHost(Host):
|
|
|
422
410
|
min_concurrency=min_concurrency,
|
|
423
411
|
)
|
|
424
412
|
|
|
425
|
-
partial_func =
|
|
413
|
+
partial_func = _prepare_partial_func(func)
|
|
426
414
|
|
|
427
415
|
if metadata is None:
|
|
428
416
|
metadata = {}
|
|
@@ -452,6 +440,8 @@ class FalServerlessHost(Host):
|
|
|
452
440
|
if partial_result.result:
|
|
453
441
|
return partial_result.result.application_id
|
|
454
442
|
|
|
443
|
+
return None
|
|
444
|
+
|
|
455
445
|
@_handle_grpc_error()
|
|
456
446
|
def run(
|
|
457
447
|
self,
|
|
@@ -492,7 +482,7 @@ class FalServerlessHost(Host):
|
|
|
492
482
|
return_value = _UNSET
|
|
493
483
|
# Allow isolate provided arguments (such as setup function) to take
|
|
494
484
|
# precedence over the ones provided by the user.
|
|
495
|
-
partial_func =
|
|
485
|
+
partial_func = _prepare_partial_func(func, *args, **kwargs)
|
|
496
486
|
for partial_result in self._connection.run(
|
|
497
487
|
partial_func,
|
|
498
488
|
environments,
|
|
@@ -766,7 +756,7 @@ def function( # type: ignore
|
|
|
766
756
|
options = host.parse_options(kind=kind, **config)
|
|
767
757
|
|
|
768
758
|
def wrapper(func: Callable[ArgsT, ReturnT]):
|
|
769
|
-
|
|
759
|
+
include_modules_from(func)
|
|
770
760
|
proxy = IsolatedFunction(
|
|
771
761
|
host=host,
|
|
772
762
|
raw_func=func, # type: ignore
|
|
@@ -777,7 +767,6 @@ def function( # type: ignore
|
|
|
777
767
|
return wrapper
|
|
778
768
|
|
|
779
769
|
|
|
780
|
-
@mainify
|
|
781
770
|
class FalFastAPI(FastAPI):
|
|
782
771
|
"""
|
|
783
772
|
A subclass of FastAPI that adds some fal-specific functionality.
|
|
@@ -821,7 +810,6 @@ class FalFastAPI(FastAPI):
|
|
|
821
810
|
return spec
|
|
822
811
|
|
|
823
812
|
|
|
824
|
-
@mainify
|
|
825
813
|
class RouteSignature(NamedTuple):
|
|
826
814
|
path: str
|
|
827
815
|
is_websocket: bool = False
|
|
@@ -832,7 +820,6 @@ class RouteSignature(NamedTuple):
|
|
|
832
820
|
emit_timings: bool = False
|
|
833
821
|
|
|
834
822
|
|
|
835
|
-
@mainify
|
|
836
823
|
class BaseServable:
|
|
837
824
|
def collect_routes(self) -> dict[RouteSignature, Callable[..., Any]]:
|
|
838
825
|
raise NotImplementedError
|
|
@@ -851,6 +838,7 @@ class BaseServable:
|
|
|
851
838
|
from fastapi import HTTPException, Request
|
|
852
839
|
from fastapi.middleware.cors import CORSMiddleware
|
|
853
840
|
from fastapi.responses import JSONResponse
|
|
841
|
+
from starlette_exporter import PrometheusMiddleware
|
|
854
842
|
|
|
855
843
|
_app = FalFastAPI(lifespan=self.lifespan)
|
|
856
844
|
|
|
@@ -861,6 +849,13 @@ class BaseServable:
|
|
|
861
849
|
allow_methods=("*"),
|
|
862
850
|
allow_origins=("*"),
|
|
863
851
|
)
|
|
852
|
+
_app.add_middleware(
|
|
853
|
+
PrometheusMiddleware,
|
|
854
|
+
prefix="http",
|
|
855
|
+
group_paths=True,
|
|
856
|
+
filter_unhandled_paths=True,
|
|
857
|
+
app_name="fal",
|
|
858
|
+
)
|
|
864
859
|
|
|
865
860
|
self._add_extra_middlewares(_app)
|
|
866
861
|
|
|
@@ -906,13 +901,43 @@ class BaseServable:
|
|
|
906
901
|
return self._build_app().openapi()
|
|
907
902
|
|
|
908
903
|
def serve(self) -> None:
|
|
909
|
-
import
|
|
904
|
+
import asyncio
|
|
905
|
+
|
|
906
|
+
from starlette_exporter import handle_metrics
|
|
907
|
+
from uvicorn import Config
|
|
910
908
|
|
|
911
909
|
app = self._build_app()
|
|
912
|
-
|
|
910
|
+
server = Server(config=Config(app, host="0.0.0.0", port=8080))
|
|
911
|
+
metrics_app = FastAPI()
|
|
912
|
+
metrics_app.add_route("/metrics", handle_metrics)
|
|
913
|
+
metrics_server = Server(config=Config(metrics_app, host="0.0.0.0", port=9090))
|
|
914
|
+
|
|
915
|
+
async def _serve() -> None:
|
|
916
|
+
tasks = {
|
|
917
|
+
asyncio.create_task(server.serve()): server,
|
|
918
|
+
asyncio.create_task(metrics_server.serve()): metrics_server,
|
|
919
|
+
}
|
|
920
|
+
|
|
921
|
+
_, pending = await asyncio.wait(tasks.keys(), return_when=asyncio.FIRST_COMPLETED)
|
|
922
|
+
if not pending:
|
|
923
|
+
return
|
|
924
|
+
|
|
925
|
+
# try graceful shutdown
|
|
926
|
+
for task in pending:
|
|
927
|
+
tasks[task].should_exit = True
|
|
928
|
+
_, pending = await asyncio.wait(pending, timeout=2)
|
|
929
|
+
if not pending:
|
|
930
|
+
return
|
|
931
|
+
|
|
932
|
+
for task in pending:
|
|
933
|
+
task.cancel()
|
|
934
|
+
await asyncio.wait(pending)
|
|
935
|
+
|
|
936
|
+
with suppress(asyncio.CancelledError):
|
|
937
|
+
asyncio.set_event_loop(asyncio.new_event_loop())
|
|
938
|
+
asyncio.run(_serve())
|
|
913
939
|
|
|
914
940
|
|
|
915
|
-
@mainify
|
|
916
941
|
class ServeWrapper(BaseServable):
|
|
917
942
|
_func: Callable
|
|
918
943
|
|
|
@@ -993,7 +1018,7 @@ class IsolatedFunction(Generic[ArgsT, ReturnT]):
|
|
|
993
1018
|
) from None
|
|
994
1019
|
except Exception as exc:
|
|
995
1020
|
cause = exc.__cause__
|
|
996
|
-
if self.reraise and
|
|
1021
|
+
if self.reraise and isinstance(exc, UserFunctionException) and cause:
|
|
997
1022
|
# re-raise original exception without our wrappers
|
|
998
1023
|
raise cause
|
|
999
1024
|
raise
|
|
@@ -1069,3 +1094,14 @@ class ServedIsolatedFunction(
|
|
|
1069
1094
|
self, host: Host | None = None, *, serve: Literal[False], **config: Any
|
|
1070
1095
|
) -> IsolatedFunction[ArgsT, ReturnT]:
|
|
1071
1096
|
...
|
|
1097
|
+
|
|
1098
|
+
|
|
1099
|
+
class Server(uvicorn.Server):
|
|
1100
|
+
"""Server is a uvicorn.Server that actually plays nicely with signals.
|
|
1101
|
+
By default, uvicorn's Server class overwrites the signal handler for SIGINT, swallowing the signal and preventing other tasks from cancelling.
|
|
1102
|
+
This class allows the task to be gracefully cancelled using asyncio's built-in task cancellation or with an event, like aiohttp.
|
|
1103
|
+
"""
|
|
1104
|
+
|
|
1105
|
+
def install_signal_handlers(self) -> None:
|
|
1106
|
+
pass
|
|
1107
|
+
|