fal 0.13.0__py3-none-any.whl → 0.15.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of fal might be problematic. Click here for more details.
- fal/__init__.py +4 -4
- fal/_serialization.py +159 -108
- fal/api.py +91 -36
- fal/app.py +3 -8
- fal/auth/__init__.py +1 -3
- fal/cli.py +3 -4
- fal/exceptions/__init__.py +1 -2
- fal/exceptions/_base.py +1 -12
- fal/exceptions/auth.py +2 -4
- fal/exceptions/handlers.py +8 -19
- fal/logging/isolate.py +8 -19
- fal/logging/user.py +1 -1
- fal/sdk.py +3 -3
- fal/toolkit/__init__.py +0 -2
- fal/toolkit/exceptions.py +0 -5
- fal/toolkit/file/__init__.py +1 -1
- fal/toolkit/file/file.py +58 -55
- fal/toolkit/file/providers/fal.py +2 -6
- fal/toolkit/file/providers/gcp.py +0 -2
- fal/toolkit/file/providers/r2.py +0 -2
- fal/toolkit/file/types.py +0 -4
- fal/toolkit/image/__init__.py +1 -1
- fal/toolkit/image/image.py +11 -15
- fal/toolkit/optimize.py +1 -3
- fal/toolkit/utils/__init__.py +1 -1
- fal/toolkit/utils/download_utils.py +2 -15
- fal/workflows.py +3 -2
- {fal-0.13.0.dist-info → fal-0.15.0.dist-info}/METADATA +40 -38
- {fal-0.13.0.dist-info → fal-0.15.0.dist-info}/RECORD +50 -51
- {fal-0.13.0.dist-info → fal-0.15.0.dist-info}/WHEEL +2 -1
- fal-0.15.0.dist-info/entry_points.txt +2 -0
- fal-0.15.0.dist-info/top_level.txt +2 -0
- fal/env.py +0 -3
- fal/toolkit/mainify.py +0 -13
- fal-0.13.0.dist-info/entry_points.txt +0 -4
fal/__init__.py
CHANGED
|
@@ -1,10 +1,10 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
from fal import apps
|
|
3
|
+
from fal import apps # noqa: F401
|
|
4
4
|
from fal.api import FalServerlessHost, LocalHost, cached
|
|
5
5
|
from fal.api import function
|
|
6
|
-
from fal.api import function as isolated
|
|
7
|
-
from fal.app import App, endpoint, realtime, wrap_app
|
|
6
|
+
from fal.api import function as isolated # noqa: F401
|
|
7
|
+
from fal.app import App, endpoint, realtime, wrap_app # noqa: F401
|
|
8
8
|
from fal.sdk import FalServerlessKeyCredentials
|
|
9
9
|
from fal.sync import sync_dir
|
|
10
10
|
|
|
@@ -32,6 +32,6 @@ __all__ = [
|
|
|
32
32
|
# This will add to the package’s __path__ all subdirectories of directories on sys.path named after the package which effectively combines both modules into a single namespace (dbt.adapters)
|
|
33
33
|
# The matching statement is in plugins/postgres/dbt/adapters/__init__.py
|
|
34
34
|
|
|
35
|
-
from pkgutil import extend_path
|
|
35
|
+
from pkgutil import extend_path # noqa: E402
|
|
36
36
|
|
|
37
37
|
__path__ = extend_path(__path__, __name__)
|
fal/_serialization.py
CHANGED
|
@@ -1,102 +1,71 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
from
|
|
4
|
-
from pathlib import Path
|
|
3
|
+
from typing import Any, Callable
|
|
5
4
|
|
|
6
|
-
import
|
|
7
|
-
|
|
5
|
+
import pickle
|
|
6
|
+
import cloudpickle
|
|
8
7
|
|
|
9
|
-
from fal.toolkit import mainify
|
|
10
8
|
|
|
11
|
-
|
|
12
|
-
#
|
|
13
|
-
|
|
14
|
-
|
|
9
|
+
def _register_pickle_by_value(name) -> None:
|
|
10
|
+
# cloudpickle.register_pickle_by_value wants an imported module object,
|
|
11
|
+
# but there is really no reason to go through that complication, as
|
|
12
|
+
# it might be prone to errors.
|
|
13
|
+
cloudpickle.cloudpickle._PICKLE_BY_VALUE_MODULES.add(name)
|
|
15
14
|
|
|
16
15
|
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
from pydantic.fields import ModelField
|
|
16
|
+
def include_package_from_path(raw_path: str) -> None:
|
|
17
|
+
from pathlib import Path
|
|
20
18
|
|
|
21
|
-
return ModelField(**kwargs)
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
@mainify
|
|
25
|
-
def _pydantic_make_private_field(kwargs):
|
|
26
|
-
from pydantic.fields import ModelPrivateAttr
|
|
27
|
-
|
|
28
|
-
return ModelPrivateAttr(**kwargs)
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
# this allows us to record all the "isolated" function and then mainify everything in
|
|
32
|
-
# module they exist
|
|
33
|
-
@wraps(_dill._locate_function)
|
|
34
|
-
def by_value_locator(obj, pickler=None, og_locator=_dill._locate_function):
|
|
35
|
-
module_name = getattr(obj, "__module__", None)
|
|
36
|
-
if module_name is not None:
|
|
37
|
-
# If it is coming from the same module, directly allow
|
|
38
|
-
# it to be pickled.
|
|
39
|
-
if module_name in _MODULES:
|
|
40
|
-
return False
|
|
41
|
-
|
|
42
|
-
package_name, *_ = module_name.partition(".")
|
|
43
|
-
# If it is coming from the same package, then do the same.
|
|
44
|
-
if package_name in _PACKAGES:
|
|
45
|
-
return False
|
|
46
|
-
|
|
47
|
-
og_result = og_locator(obj, pickler)
|
|
48
|
-
return og_result
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
_dill._locate_function = by_value_locator
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
def include_packages_from_path(raw_path: str):
|
|
55
19
|
path = Path(raw_path).resolve()
|
|
56
20
|
parent = path
|
|
57
21
|
while (parent.parent / "__init__.py").exists():
|
|
58
22
|
parent = parent.parent
|
|
59
23
|
|
|
60
24
|
if parent != path:
|
|
61
|
-
|
|
25
|
+
_register_pickle_by_value(parent.name)
|
|
62
26
|
|
|
63
27
|
|
|
64
|
-
def
|
|
28
|
+
def include_modules_from(obj: Any) -> None:
|
|
65
29
|
module_name = getattr(obj, "__module__", None)
|
|
66
30
|
if not module_name:
|
|
67
|
-
return
|
|
31
|
+
return
|
|
32
|
+
|
|
33
|
+
if "." in module_name:
|
|
34
|
+
# Just include the whole package
|
|
35
|
+
package_name, *_ = module_name.partition(".")
|
|
36
|
+
_register_pickle_by_value(package_name)
|
|
37
|
+
return
|
|
68
38
|
|
|
69
|
-
_MODULES.add(module_name)
|
|
70
39
|
if module_name == "__main__":
|
|
71
40
|
# When the module is __main__, we need to recursively go up the
|
|
72
41
|
# tree to locate the actual package name.
|
|
73
42
|
import __main__
|
|
74
43
|
|
|
75
|
-
|
|
44
|
+
include_package_from_path(__main__.__file__)
|
|
45
|
+
return
|
|
46
|
+
|
|
47
|
+
_register_pickle_by_value(module_name)
|
|
76
48
|
|
|
77
|
-
if "." in module_name:
|
|
78
|
-
package_name, *_ = module_name.partition(".")
|
|
79
|
-
_PACKAGES.add(package_name)
|
|
80
49
|
|
|
50
|
+
def _register(cls: Any, func: Callable) -> None:
|
|
51
|
+
cloudpickle.Pickler.dispatch[cls] = func
|
|
81
52
|
|
|
82
|
-
|
|
83
|
-
def
|
|
53
|
+
|
|
54
|
+
def _patch_pydantic_field_serialization() -> None:
|
|
84
55
|
# Cythonized pydantic fields can't be serialized automatically, so we are
|
|
85
56
|
# have a special case handling for them that unpacks it to a dictionary
|
|
86
57
|
# and then reloads it on the other side.
|
|
87
|
-
|
|
88
|
-
|
|
58
|
+
# https://github.com/ray-project/ray/blob/842bbcf4236e41f58d25058b0482cd05bfe9e4da/python/ray/_private/pydantic_compat.py#L80
|
|
89
59
|
try:
|
|
90
|
-
|
|
60
|
+
from pydantic.fields import ModelField, ModelPrivateAttr
|
|
91
61
|
except ImportError:
|
|
92
62
|
return
|
|
93
63
|
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
args = {
|
|
64
|
+
def create_model_field(kwargs: dict) -> ModelField:
|
|
65
|
+
return ModelField(**kwargs)
|
|
66
|
+
|
|
67
|
+
def pickle_model_field(field: ModelField) -> tuple[Callable, tuple]:
|
|
68
|
+
kwargs = {
|
|
100
69
|
"name": field.name,
|
|
101
70
|
# outer_type_ is the original type for ModelFields,
|
|
102
71
|
# while type_ can be updated later with the nested type
|
|
@@ -110,65 +79,147 @@ def patch_pydantic_field_serialization():
|
|
|
110
79
|
"alias": field.alias,
|
|
111
80
|
"field_info": field.field_info,
|
|
112
81
|
}
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
def
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
) ->
|
|
120
|
-
|
|
82
|
+
|
|
83
|
+
return create_model_field, (kwargs,)
|
|
84
|
+
|
|
85
|
+
def create_private_attr(kwargs: dict) -> ModelPrivateAttr:
|
|
86
|
+
return ModelPrivateAttr(**kwargs)
|
|
87
|
+
|
|
88
|
+
def pickle_private_attr(field: ModelPrivateAttr) -> tuple[Callable, tuple]:
|
|
89
|
+
kwargs = {
|
|
121
90
|
"default": field.default,
|
|
122
91
|
"default_factory": field.default_factory,
|
|
123
92
|
}
|
|
124
|
-
pickler.save_reduce(_pydantic_make_private_field, (args,), obj=field)
|
|
125
93
|
|
|
94
|
+
return create_private_attr, (kwargs,)
|
|
95
|
+
|
|
96
|
+
_register(ModelField, pickle_model_field)
|
|
97
|
+
_register(ModelPrivateAttr, pickle_private_attr)
|
|
126
98
|
|
|
127
|
-
|
|
128
|
-
def
|
|
129
|
-
#
|
|
130
|
-
#
|
|
131
|
-
#
|
|
99
|
+
|
|
100
|
+
def _patch_pydantic_model_serialization() -> None:
|
|
101
|
+
# If user has created new pydantic models in his namespace, we will try to pickle those
|
|
102
|
+
# by value, which means recreating class skeleton, which will stumble upon
|
|
103
|
+
# __pydantic_parent_namespace__ in its __dict__ and it may contain modules that happened
|
|
104
|
+
# to be imported in the namespace but are not actually used, resulting in pickling errors.
|
|
105
|
+
# Unfortunately this also means that `model_rebuid()` might not work.
|
|
132
106
|
try:
|
|
133
|
-
import pydantic
|
|
107
|
+
import pydantic
|
|
134
108
|
except ImportError:
|
|
135
109
|
return
|
|
136
110
|
|
|
137
|
-
pydantic
|
|
111
|
+
# https://github.com/pydantic/pydantic/pull/2573
|
|
112
|
+
if not hasattr(pydantic, "__version__") or pydantic.__version__.startswith("1."):
|
|
113
|
+
return
|
|
114
|
+
|
|
115
|
+
backup = "_original_extract_class_dict"
|
|
116
|
+
if getattr(cloudpickle.cloudpickle, backup, None):
|
|
117
|
+
return
|
|
118
|
+
|
|
119
|
+
original = cloudpickle.cloudpickle._extract_class_dict
|
|
120
|
+
|
|
121
|
+
def patched(cls):
|
|
122
|
+
attr_name = "__pydantic_parent_namespace__"
|
|
123
|
+
if issubclass(cls, pydantic.BaseModel) and getattr(cls, attr_name, None):
|
|
124
|
+
setattr(cls, attr_name, None)
|
|
125
|
+
|
|
126
|
+
return original(cls)
|
|
127
|
+
|
|
128
|
+
cloudpickle.cloudpickle._extract_class_dict = patched
|
|
129
|
+
setattr(cloudpickle.cloudpickle, backup, original)
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def _patch_lru_cache() -> None:
|
|
133
|
+
# https://github.com/cloudpipe/cloudpickle/issues/178
|
|
134
|
+
# https://github.com/uqfoundation/dill/blob/70f569b0dd268d2b1e85c0f300951b11f53c5d53/dill/_dill.py#L1429
|
|
135
|
+
|
|
136
|
+
from functools import lru_cache, _lru_cache_wrapper as LRUCacheType
|
|
137
|
+
|
|
138
|
+
def create_lru_cache(func: Callable, kwargs: dict) -> LRUCacheType:
|
|
139
|
+
return lru_cache(**kwargs)(func)
|
|
140
|
+
|
|
141
|
+
def pickle_lru_cache(obj: LRUCacheType) -> tuple[Callable, tuple]:
|
|
142
|
+
if hasattr(obj, "cache_parameters"):
|
|
143
|
+
params = obj.cache_parameters()
|
|
144
|
+
kwargs = {
|
|
145
|
+
"maxsize": params["maxsize"],
|
|
146
|
+
"typed": params["typed"],
|
|
147
|
+
}
|
|
148
|
+
else:
|
|
149
|
+
kwargs = {"maxsize": obj.cache_info().maxsize}
|
|
150
|
+
|
|
151
|
+
return create_lru_cache, (obj.__wrapped__, kwargs)
|
|
152
|
+
|
|
153
|
+
_register(LRUCacheType, pickle_lru_cache)
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def _patch_lock() -> None:
|
|
157
|
+
# https://github.com/uqfoundation/dill/blob/70f569b0dd268d2b1e85c0f300951b11f53c5d53/dill/_dill.py#L1310
|
|
158
|
+
from threading import Lock
|
|
159
|
+
from _thread import LockType
|
|
160
|
+
|
|
161
|
+
def create_lock(locked: bool) -> Lock:
|
|
162
|
+
lock = Lock()
|
|
163
|
+
if locked and not lock.acquire(False):
|
|
164
|
+
raise pickle.UnpicklingError("Cannot acquire lock")
|
|
165
|
+
return lock
|
|
166
|
+
|
|
167
|
+
def pickle_lock(obj: LockType) -> tuple[Callable, tuple]:
|
|
168
|
+
return create_lock, (obj.locked(),)
|
|
169
|
+
|
|
170
|
+
_register(LockType, pickle_lock)
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
def _patch_rlock() -> None:
|
|
174
|
+
# https://github.com/uqfoundation/dill/blob/70f569b0dd268d2b1e85c0f300951b11f53c5d53/dill/_dill.py#L1317
|
|
175
|
+
from _thread import RLock as RLockType # type: ignore[attr-defined]
|
|
176
|
+
|
|
177
|
+
def create_rlock(count: int, owner: int) -> RLockType:
|
|
178
|
+
lock = RLockType()
|
|
179
|
+
if owner is not None:
|
|
180
|
+
lock._acquire_restore((count, owner)) # type: ignore[attr-defined]
|
|
181
|
+
if owner and not lock._is_owned(): # type: ignore[attr-defined]
|
|
182
|
+
raise pickle.UnpicklingError("Cannot acquire lock")
|
|
183
|
+
return lock
|
|
184
|
+
|
|
185
|
+
def pickle_rlock(obj: RLockType) -> tuple[Callable, tuple]:
|
|
186
|
+
r = obj.__repr__()
|
|
187
|
+
count = int(r.split('count=')[1].split()[0].rstrip('>'))
|
|
188
|
+
owner = int(r.split('owner=')[1].split()[0])
|
|
189
|
+
|
|
190
|
+
return create_rlock, (count, owner)
|
|
191
|
+
|
|
192
|
+
_register(RLockType, pickle_rlock)
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def _patch_console_thread_locals() -> None:
|
|
196
|
+
from rich.console import ConsoleThreadLocals
|
|
138
197
|
|
|
198
|
+
def create_locals(kwargs: dict) -> ConsoleThreadLocals:
|
|
199
|
+
return ConsoleThreadLocals(**kwargs)
|
|
139
200
|
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
from types import TracebackType
|
|
201
|
+
def pickle_locals(obj: ConsoleThreadLocals) -> tuple[Callable, tuple]:
|
|
202
|
+
kwargs = {"theme_stack": obj.theme_stack, "buffer": obj.buffer, "buffer_index": obj.buffer_index}
|
|
203
|
+
return create_locals, (kwargs, )
|
|
144
204
|
|
|
145
|
-
|
|
146
|
-
from tblib.pickling_support import (
|
|
147
|
-
_get_subclasses,
|
|
148
|
-
pickle_exception,
|
|
149
|
-
pickle_traceback,
|
|
150
|
-
)
|
|
205
|
+
_register(ConsoleThreadLocals, pickle_locals)
|
|
151
206
|
|
|
152
|
-
@dill.register(TracebackType)
|
|
153
|
-
def save_traceback(pickler, obj):
|
|
154
|
-
unpickle, args = pickle_traceback(obj)
|
|
155
|
-
pickler.save_reduce(unpickle, args, obj=obj)
|
|
156
207
|
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
pickler.save_reduce(unpickle, args, obj=obj)
|
|
208
|
+
def _patch_exceptions() -> None:
|
|
209
|
+
# Support chained exceptions
|
|
210
|
+
from tblib.pickling_support import install
|
|
161
211
|
|
|
162
|
-
|
|
163
|
-
dill.pickle(exception_cls, save_exception)
|
|
212
|
+
install()
|
|
164
213
|
|
|
165
214
|
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
215
|
+
def patch_pickle() -> None:
|
|
216
|
+
_patch_pydantic_field_serialization()
|
|
217
|
+
_patch_pydantic_model_serialization()
|
|
218
|
+
_patch_lru_cache()
|
|
219
|
+
_patch_lock()
|
|
220
|
+
_patch_rlock()
|
|
221
|
+
_patch_console_thread_locals()
|
|
222
|
+
_patch_exceptions()
|
|
169
223
|
|
|
170
|
-
|
|
224
|
+
_register_pickle_by_value("fal")
|
|
171
225
|
|
|
172
|
-
patch_exceptions()
|
|
173
|
-
patch_pydantic_class_attributes()
|
|
174
|
-
patch_pydantic_field_serialization()
|
fal/api.py
CHANGED
|
@@ -2,11 +2,12 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
import inspect
|
|
4
4
|
import sys
|
|
5
|
+
import threading
|
|
5
6
|
from collections import defaultdict
|
|
6
7
|
from concurrent.futures import ThreadPoolExecutor
|
|
7
8
|
from contextlib import asynccontextmanager, suppress
|
|
8
9
|
from dataclasses import dataclass, field, replace
|
|
9
|
-
from functools import
|
|
10
|
+
from functools import wraps
|
|
10
11
|
from os import PathLike
|
|
11
12
|
from typing import (
|
|
12
13
|
Any,
|
|
@@ -22,21 +23,25 @@ from typing import (
|
|
|
22
23
|
overload,
|
|
23
24
|
)
|
|
24
25
|
|
|
25
|
-
import
|
|
26
|
-
import dill.detect
|
|
26
|
+
import cloudpickle
|
|
27
27
|
import grpc
|
|
28
28
|
import isolate
|
|
29
|
+
import tblib
|
|
30
|
+
import uvicorn
|
|
29
31
|
import yaml
|
|
30
32
|
from fastapi import FastAPI
|
|
33
|
+
from fastapi import __version__ as fastapi_version
|
|
31
34
|
from isolate.backends.common import active_python
|
|
32
35
|
from isolate.backends.settings import DEFAULT_SETTINGS
|
|
33
36
|
from isolate.connections import PythonIPC
|
|
34
37
|
from packaging.requirements import Requirement
|
|
35
38
|
from packaging.utils import canonicalize_name
|
|
39
|
+
from pydantic import __version__ as pydantic_version
|
|
36
40
|
from typing_extensions import Concatenate, ParamSpec
|
|
37
41
|
|
|
38
42
|
import fal.flags as flags
|
|
39
|
-
from fal.
|
|
43
|
+
from fal.exceptions import FalServerlessException
|
|
44
|
+
from fal._serialization import include_modules_from, patch_pickle
|
|
40
45
|
from fal.logging.isolate import IsolateLogPrinter
|
|
41
46
|
from fal.sdk import (
|
|
42
47
|
FAL_SERVERLESS_DEFAULT_KEEP_ALIVE,
|
|
@@ -50,7 +55,6 @@ from fal.sdk import (
|
|
|
50
55
|
get_agent_credentials,
|
|
51
56
|
get_default_credentials,
|
|
52
57
|
)
|
|
53
|
-
from fal.toolkit import mainify
|
|
54
58
|
|
|
55
59
|
ArgsT = ParamSpec("ArgsT")
|
|
56
60
|
ReturnT = TypeVar("ReturnT", covariant=True)
|
|
@@ -58,16 +62,21 @@ ReturnT = TypeVar("ReturnT", covariant=True)
|
|
|
58
62
|
BasicConfig = Dict[str, Any]
|
|
59
63
|
_UNSET = object()
|
|
60
64
|
|
|
61
|
-
SERVE_REQUIREMENTS = [
|
|
65
|
+
SERVE_REQUIREMENTS = [
|
|
66
|
+
f"fastapi=={fastapi_version}",
|
|
67
|
+
f"pydantic=={pydantic_version}",
|
|
68
|
+
"uvicorn",
|
|
69
|
+
"starlette_exporter",
|
|
70
|
+
]
|
|
62
71
|
|
|
63
72
|
|
|
64
73
|
@dataclass
|
|
65
|
-
class FalServerlessError(
|
|
74
|
+
class FalServerlessError(FalServerlessException):
|
|
66
75
|
message: str
|
|
67
76
|
|
|
68
77
|
|
|
69
78
|
@dataclass
|
|
70
|
-
class InternalFalServerlessError(
|
|
79
|
+
class InternalFalServerlessError(FalServerlessException):
|
|
71
80
|
message: str
|
|
72
81
|
|
|
73
82
|
|
|
@@ -175,32 +184,23 @@ def cached(func: Callable[ArgsT, ReturnT]) -> Callable[ArgsT, ReturnT]:
|
|
|
175
184
|
return wrapper
|
|
176
185
|
|
|
177
186
|
|
|
178
|
-
|
|
179
|
-
class UserFunctionException(Exception):
|
|
187
|
+
class UserFunctionException(FalServerlessException):
|
|
180
188
|
pass
|
|
181
189
|
|
|
182
190
|
|
|
183
|
-
def match_class(obj, cls):
|
|
184
|
-
# NOTE: Can't use isinstance because we are not using dill's byref setting when
|
|
185
|
-
# loading/dumping objects in RPC, which means that our exceptions from remote
|
|
186
|
-
# server are created by value and are actually a separate class that only looks
|
|
187
|
-
# like original one.
|
|
188
|
-
#
|
|
189
|
-
# See https://github.com/fal-ai/fal/issues/142
|
|
190
|
-
return type(obj).__name__ == cls.__name__
|
|
191
|
-
|
|
192
|
-
|
|
193
191
|
def _prepare_partial_func(
|
|
194
192
|
func: Callable[ArgsT, ReturnT],
|
|
195
193
|
*args: ArgsT.args,
|
|
196
194
|
**kwargs: ArgsT.kwargs,
|
|
197
195
|
) -> Callable[ArgsT, ReturnT]:
|
|
198
|
-
"""Prepare the given function for execution on
|
|
196
|
+
"""Prepare the given function for execution on isolate workers."""
|
|
199
197
|
|
|
200
198
|
@wraps(func)
|
|
201
199
|
def wrapper(*remote_args: ArgsT.args, **remote_kwargs: ArgsT.kwargs) -> ReturnT:
|
|
202
200
|
try:
|
|
203
201
|
result = func(*remote_args, *args, **remote_kwargs, **kwargs)
|
|
202
|
+
except FalServerlessException:
|
|
203
|
+
raise
|
|
204
204
|
except Exception as exc:
|
|
205
205
|
tb = exc.__traceback__
|
|
206
206
|
if tb is not None and tb.tb_next is not None:
|
|
@@ -211,7 +211,7 @@ def _prepare_partial_func(
|
|
|
211
211
|
) from exc.with_traceback(tb)
|
|
212
212
|
finally:
|
|
213
213
|
with suppress(Exception):
|
|
214
|
-
|
|
214
|
+
patch_pickle()
|
|
215
215
|
return result
|
|
216
216
|
|
|
217
217
|
return wrapper
|
|
@@ -223,8 +223,12 @@ class LocalHost(Host):
|
|
|
223
223
|
# packages for isolate agent to run.
|
|
224
224
|
_AGENT_ENVIRONMENT = isolate.prepare_environment(
|
|
225
225
|
"virtualenv",
|
|
226
|
-
requirements=[
|
|
226
|
+
requirements=[
|
|
227
|
+
f"cloudpickle=={cloudpickle.__version__}",
|
|
228
|
+
f"tblib=={tblib.__version__}",
|
|
229
|
+
],
|
|
227
230
|
)
|
|
231
|
+
_log_printer = IsolateLogPrinter(debug=flags.DEBUG)
|
|
228
232
|
|
|
229
233
|
def run(
|
|
230
234
|
self,
|
|
@@ -233,7 +237,11 @@ class LocalHost(Host):
|
|
|
233
237
|
args: tuple[Any, ...],
|
|
234
238
|
kwargs: dict[str, Any],
|
|
235
239
|
) -> ReturnT:
|
|
236
|
-
settings = replace(
|
|
240
|
+
settings = replace(
|
|
241
|
+
DEFAULT_SETTINGS,
|
|
242
|
+
serialization_method="cloudpickle",
|
|
243
|
+
log_hook=self._log_printer.print,
|
|
244
|
+
)
|
|
237
245
|
environment = isolate.prepare_environment(
|
|
238
246
|
**options.environment,
|
|
239
247
|
context=settings,
|
|
@@ -243,7 +251,7 @@ class LocalHost(Host):
|
|
|
243
251
|
environment.create(),
|
|
244
252
|
extra_inheritance_paths=[self._AGENT_ENVIRONMENT.create()],
|
|
245
253
|
) as connection:
|
|
246
|
-
executable =
|
|
254
|
+
executable = _prepare_partial_func(func, *args, **kwargs)
|
|
247
255
|
return connection.run(executable)
|
|
248
256
|
|
|
249
257
|
|
|
@@ -251,9 +259,6 @@ FAL_SERVERLESS_DEFAULT_URL = flags.GRPC_HOST
|
|
|
251
259
|
FAL_SERVERLESS_DEFAULT_MACHINE_TYPE = "XS"
|
|
252
260
|
|
|
253
261
|
|
|
254
|
-
import threading
|
|
255
|
-
|
|
256
|
-
|
|
257
262
|
def _handle_grpc_error():
|
|
258
263
|
def decorator(fn):
|
|
259
264
|
@wraps(fn)
|
|
@@ -292,6 +297,8 @@ def _handle_grpc_error():
|
|
|
292
297
|
def find_missing_dependencies(
|
|
293
298
|
func: Callable, env: dict
|
|
294
299
|
) -> Iterator[tuple[str, list[str]]]:
|
|
300
|
+
import dill
|
|
301
|
+
|
|
295
302
|
if env["kind"] != "virtualenv":
|
|
296
303
|
return
|
|
297
304
|
|
|
@@ -433,6 +440,8 @@ class FalServerlessHost(Host):
|
|
|
433
440
|
if partial_result.result:
|
|
434
441
|
return partial_result.result.application_id
|
|
435
442
|
|
|
443
|
+
return None
|
|
444
|
+
|
|
436
445
|
@_handle_grpc_error()
|
|
437
446
|
def run(
|
|
438
447
|
self,
|
|
@@ -747,7 +756,7 @@ def function( # type: ignore
|
|
|
747
756
|
options = host.parse_options(kind=kind, **config)
|
|
748
757
|
|
|
749
758
|
def wrapper(func: Callable[ArgsT, ReturnT]):
|
|
750
|
-
|
|
759
|
+
include_modules_from(func)
|
|
751
760
|
proxy = IsolatedFunction(
|
|
752
761
|
host=host,
|
|
753
762
|
raw_func=func, # type: ignore
|
|
@@ -758,7 +767,6 @@ def function( # type: ignore
|
|
|
758
767
|
return wrapper
|
|
759
768
|
|
|
760
769
|
|
|
761
|
-
@mainify
|
|
762
770
|
class FalFastAPI(FastAPI):
|
|
763
771
|
"""
|
|
764
772
|
A subclass of FastAPI that adds some fal-specific functionality.
|
|
@@ -802,7 +810,6 @@ class FalFastAPI(FastAPI):
|
|
|
802
810
|
return spec
|
|
803
811
|
|
|
804
812
|
|
|
805
|
-
@mainify
|
|
806
813
|
class RouteSignature(NamedTuple):
|
|
807
814
|
path: str
|
|
808
815
|
is_websocket: bool = False
|
|
@@ -813,7 +820,6 @@ class RouteSignature(NamedTuple):
|
|
|
813
820
|
emit_timings: bool = False
|
|
814
821
|
|
|
815
822
|
|
|
816
|
-
@mainify
|
|
817
823
|
class BaseServable:
|
|
818
824
|
def collect_routes(self) -> dict[RouteSignature, Callable[..., Any]]:
|
|
819
825
|
raise NotImplementedError
|
|
@@ -832,6 +838,7 @@ class BaseServable:
|
|
|
832
838
|
from fastapi import HTTPException, Request
|
|
833
839
|
from fastapi.middleware.cors import CORSMiddleware
|
|
834
840
|
from fastapi.responses import JSONResponse
|
|
841
|
+
from starlette_exporter import PrometheusMiddleware
|
|
835
842
|
|
|
836
843
|
_app = FalFastAPI(lifespan=self.lifespan)
|
|
837
844
|
|
|
@@ -842,6 +849,13 @@ class BaseServable:
|
|
|
842
849
|
allow_methods=("*"),
|
|
843
850
|
allow_origins=("*"),
|
|
844
851
|
)
|
|
852
|
+
_app.add_middleware(
|
|
853
|
+
PrometheusMiddleware,
|
|
854
|
+
prefix="http",
|
|
855
|
+
group_paths=True,
|
|
856
|
+
filter_unhandled_paths=True,
|
|
857
|
+
app_name="fal",
|
|
858
|
+
)
|
|
845
859
|
|
|
846
860
|
self._add_extra_middlewares(_app)
|
|
847
861
|
|
|
@@ -887,13 +901,43 @@ class BaseServable:
|
|
|
887
901
|
return self._build_app().openapi()
|
|
888
902
|
|
|
889
903
|
def serve(self) -> None:
|
|
890
|
-
import
|
|
904
|
+
import asyncio
|
|
905
|
+
|
|
906
|
+
from starlette_exporter import handle_metrics
|
|
907
|
+
from uvicorn import Config
|
|
891
908
|
|
|
892
909
|
app = self._build_app()
|
|
893
|
-
|
|
910
|
+
server = Server(config=Config(app, host="0.0.0.0", port=8080))
|
|
911
|
+
metrics_app = FastAPI()
|
|
912
|
+
metrics_app.add_route("/metrics", handle_metrics)
|
|
913
|
+
metrics_server = Server(config=Config(metrics_app, host="0.0.0.0", port=9090))
|
|
914
|
+
|
|
915
|
+
async def _serve() -> None:
|
|
916
|
+
tasks = {
|
|
917
|
+
asyncio.create_task(server.serve()): server,
|
|
918
|
+
asyncio.create_task(metrics_server.serve()): metrics_server,
|
|
919
|
+
}
|
|
920
|
+
|
|
921
|
+
_, pending = await asyncio.wait(tasks.keys(), return_when=asyncio.FIRST_COMPLETED)
|
|
922
|
+
if not pending:
|
|
923
|
+
return
|
|
924
|
+
|
|
925
|
+
# try graceful shutdown
|
|
926
|
+
for task in pending:
|
|
927
|
+
tasks[task].should_exit = True
|
|
928
|
+
_, pending = await asyncio.wait(pending, timeout=2)
|
|
929
|
+
if not pending:
|
|
930
|
+
return
|
|
931
|
+
|
|
932
|
+
for task in pending:
|
|
933
|
+
task.cancel()
|
|
934
|
+
await asyncio.wait(pending)
|
|
935
|
+
|
|
936
|
+
with suppress(asyncio.CancelledError):
|
|
937
|
+
asyncio.set_event_loop(asyncio.new_event_loop())
|
|
938
|
+
asyncio.run(_serve())
|
|
894
939
|
|
|
895
940
|
|
|
896
|
-
@mainify
|
|
897
941
|
class ServeWrapper(BaseServable):
|
|
898
942
|
_func: Callable
|
|
899
943
|
|
|
@@ -974,7 +1018,7 @@ class IsolatedFunction(Generic[ArgsT, ReturnT]):
|
|
|
974
1018
|
) from None
|
|
975
1019
|
except Exception as exc:
|
|
976
1020
|
cause = exc.__cause__
|
|
977
|
-
if self.reraise and
|
|
1021
|
+
if self.reraise and isinstance(exc, UserFunctionException) and cause:
|
|
978
1022
|
# re-raise original exception without our wrappers
|
|
979
1023
|
raise cause
|
|
980
1024
|
raise
|
|
@@ -1050,3 +1094,14 @@ class ServedIsolatedFunction(
|
|
|
1050
1094
|
self, host: Host | None = None, *, serve: Literal[False], **config: Any
|
|
1051
1095
|
) -> IsolatedFunction[ArgsT, ReturnT]:
|
|
1052
1096
|
...
|
|
1097
|
+
|
|
1098
|
+
|
|
1099
|
+
class Server(uvicorn.Server):
|
|
1100
|
+
"""Server is a uvicorn.Server that actually plays nicely with signals.
|
|
1101
|
+
By default, uvicorn's Server class overwrites the signal handler for SIGINT, swallowing the signal and preventing other tasks from cancelling.
|
|
1102
|
+
This class allows the task to be gracefully cancelled using asyncio's built-in task cancellation or with an event, like aiohttp.
|
|
1103
|
+
"""
|
|
1104
|
+
|
|
1105
|
+
def install_signal_handlers(self) -> None:
|
|
1106
|
+
pass
|
|
1107
|
+
|