fal-0.14.0-py3-none-any.whl → fal-0.15.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of fal might be problematic. Click here for more details.
- fal/_serialization.py +149 -125
- fal/api.py +88 -52
- fal/app.py +2 -7
- fal/auth/__init__.py +0 -2
- fal/cli.py +2 -2
- fal/exceptions/__init__.py +1 -2
- fal/exceptions/_base.py +1 -12
- fal/exceptions/auth.py +2 -4
- fal/exceptions/handlers.py +8 -19
- fal/sdk.py +2 -3
- fal/toolkit/__init__.py +0 -2
- fal/toolkit/exceptions.py +0 -5
- fal/toolkit/file/file.py +57 -54
- fal/toolkit/file/providers/fal.py +0 -4
- fal/toolkit/file/providers/gcp.py +0 -2
- fal/toolkit/file/providers/r2.py +0 -2
- fal/toolkit/file/types.py +0 -4
- fal/toolkit/image/image.py +10 -14
- fal/toolkit/optimize.py +0 -2
- fal/toolkit/utils/download_utils.py +1 -14
- fal/workflows.py +2 -1
- {fal-0.14.0.dist-info → fal-0.15.0.dist-info}/METADATA +40 -38
- {fal-0.14.0.dist-info → fal-0.15.0.dist-info}/RECORD +50 -51
- {fal-0.14.0.dist-info → fal-0.15.0.dist-info}/WHEEL +2 -1
- fal-0.15.0.dist-info/entry_points.txt +2 -0
- fal-0.15.0.dist-info/top_level.txt +2 -0
- fal/env.py +0 -3
- fal/toolkit/mainify.py +0 -13
- fal-0.14.0.dist-info/entry_points.txt +0 -4
fal/app.py
CHANGED
|
@@ -10,11 +10,9 @@ from typing import Any, Callable, ClassVar, TypeVar
|
|
|
10
10
|
from fastapi import FastAPI
|
|
11
11
|
|
|
12
12
|
import fal.api
|
|
13
|
-
from fal._serialization import add_serialization_listeners_for
|
|
14
13
|
from fal.api import RouteSignature
|
|
15
14
|
from fal.logging import get_logger
|
|
16
|
-
from fal.
|
|
17
|
-
|
|
15
|
+
from fal._serialization import include_modules_from
|
|
18
16
|
REALTIME_APP_REQUIREMENTS = ["websockets", "msgpack"]
|
|
19
17
|
|
|
20
18
|
EndpointT = TypeVar("EndpointT", bound=Callable[..., Any])
|
|
@@ -29,7 +27,7 @@ async def _call_any_fn(fn, *args, **kwargs):
|
|
|
29
27
|
|
|
30
28
|
|
|
31
29
|
def wrap_app(cls: type[App], **kwargs) -> fal.api.IsolatedFunction:
|
|
32
|
-
|
|
30
|
+
include_modules_from(cls)
|
|
33
31
|
|
|
34
32
|
def initialize_and_serve():
|
|
35
33
|
app = cls()
|
|
@@ -64,7 +62,6 @@ def wrap_app(cls: type[App], **kwargs) -> fal.api.IsolatedFunction:
|
|
|
64
62
|
return fn
|
|
65
63
|
|
|
66
64
|
|
|
67
|
-
@mainify
|
|
68
65
|
class App(fal.api.BaseServable):
|
|
69
66
|
requirements: ClassVar[list[str]] = []
|
|
70
67
|
machine_type: ClassVar[str] = "S"
|
|
@@ -131,7 +128,6 @@ class App(fal.api.BaseServable):
|
|
|
131
128
|
raise NotImplementedError
|
|
132
129
|
|
|
133
130
|
|
|
134
|
-
@mainify
|
|
135
131
|
def endpoint(
|
|
136
132
|
path: str, *, is_websocket: bool = False
|
|
137
133
|
) -> Callable[[EndpointT], EndpointT]:
|
|
@@ -343,7 +339,6 @@ def _fal_websocket_template(
|
|
|
343
339
|
_SENTINEL = object()
|
|
344
340
|
|
|
345
341
|
|
|
346
|
-
@mainify
|
|
347
342
|
def realtime(
|
|
348
343
|
path: str,
|
|
349
344
|
*,
|
fal/auth/__init__.py
CHANGED
|
@@ -9,10 +9,8 @@ from fal.auth import auth0, local
|
|
|
9
9
|
from fal.console import console
|
|
10
10
|
from fal.console.icons import CHECK_ICON
|
|
11
11
|
from fal.exceptions.auth import UnauthenticatedException
|
|
12
|
-
from fal.toolkit.mainify import mainify
|
|
13
12
|
|
|
14
13
|
|
|
15
|
-
@mainify
|
|
16
14
|
def key_credentials() -> tuple[str, str] | None:
|
|
17
15
|
# Ignore key credentials when the user forces auth by user.
|
|
18
16
|
if os.environ.get("FAL_FORCE_AUTH_BY_USER") == "1":
|
fal/cli.py
CHANGED
|
@@ -267,8 +267,8 @@ def load_function_from(
|
|
|
267
267
|
raise api.FalServerlessError(f"Function '{function_name}' not found in module")
|
|
268
268
|
|
|
269
269
|
# The module for the function is set to <run_path> when runpy is used, in which
|
|
270
|
-
# case we want to manually include the
|
|
271
|
-
_serialization.
|
|
270
|
+
# case we want to manually include the package it is defined in.
|
|
271
|
+
_serialization.include_package_from_path(file_path)
|
|
272
272
|
|
|
273
273
|
target = module[function_name]
|
|
274
274
|
if isinstance(target, type) and issubclass(target, fal.App):
|
fal/exceptions/__init__.py
CHANGED
|
@@ -1,8 +1,8 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
from ._base import FalServerlessException # noqa: F401
|
|
3
4
|
from .handlers import (
|
|
4
5
|
BaseExceptionHandler,
|
|
5
|
-
FalServerlessExceptionHandler,
|
|
6
6
|
GrpcExceptionHandler,
|
|
7
7
|
UserFunctionExceptionHandler,
|
|
8
8
|
)
|
|
@@ -20,7 +20,6 @@ class ApplicationExceptionHandler:
|
|
|
20
20
|
|
|
21
21
|
_handlers: list[BaseExceptionHandler] = [
|
|
22
22
|
GrpcExceptionHandler(),
|
|
23
|
-
FalServerlessExceptionHandler(),
|
|
24
23
|
UserFunctionExceptionHandler(),
|
|
25
24
|
]
|
|
26
25
|
|
fal/exceptions/_base.py
CHANGED
|
@@ -3,15 +3,4 @@ from __future__ import annotations
|
|
|
3
3
|
|
|
4
4
|
class FalServerlessException(Exception):
|
|
5
5
|
"""Base exception type for fal Serverless related flows and APIs."""
|
|
6
|
-
|
|
7
|
-
message: str
|
|
8
|
-
|
|
9
|
-
hint: str | None
|
|
10
|
-
|
|
11
|
-
def __init__(self, message: str, hint: str | None = None) -> None:
|
|
12
|
-
self.message = message
|
|
13
|
-
self.hint = hint
|
|
14
|
-
super().__init__(message)
|
|
15
|
-
|
|
16
|
-
def __str__(self) -> str:
|
|
17
|
-
return self.message + (f"\nHint: {self.hint}" if self.hint else "")
|
|
6
|
+
pass
|
fal/exceptions/auth.py
CHANGED
|
@@ -4,10 +4,8 @@ from ._base import FalServerlessException
|
|
|
4
4
|
|
|
5
5
|
|
|
6
6
|
class UnauthenticatedException(FalServerlessException):
|
|
7
|
-
"""Exception that indicates that"""
|
|
8
|
-
|
|
9
7
|
def __init__(self) -> None:
|
|
10
8
|
super().__init__(
|
|
11
|
-
|
|
12
|
-
|
|
9
|
+
"You must be authenticated. "
|
|
10
|
+
"Login via `fal auth login` or make sure to setup fal keys correctly."
|
|
13
11
|
)
|
fal/exceptions/handlers.py
CHANGED
|
@@ -3,7 +3,6 @@ from __future__ import annotations
|
|
|
3
3
|
from typing import TYPE_CHECKING, Generic, TypeVar
|
|
4
4
|
|
|
5
5
|
from grpc import Call as RpcCall
|
|
6
|
-
from rich.markdown import Markdown
|
|
7
6
|
|
|
8
7
|
from fal.console import console
|
|
9
8
|
from fal.console.icons import CROSS_ICON
|
|
@@ -11,9 +10,8 @@ from fal.console.icons import CROSS_ICON
|
|
|
11
10
|
if TYPE_CHECKING:
|
|
12
11
|
from fal.api import UserFunctionException
|
|
13
12
|
|
|
14
|
-
from ._base import FalServerlessException
|
|
15
13
|
|
|
16
|
-
ExceptionType = TypeVar("ExceptionType")
|
|
14
|
+
ExceptionType = TypeVar("ExceptionType", bound=BaseException)
|
|
17
15
|
|
|
18
16
|
|
|
19
17
|
class BaseExceptionHandler(Generic[ExceptionType]):
|
|
@@ -23,20 +21,11 @@ class BaseExceptionHandler(Generic[ExceptionType]):
|
|
|
23
21
|
return True
|
|
24
22
|
|
|
25
23
|
def handle(self, exception: ExceptionType):
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
def should_handle(self, exception: Exception) -> bool:
|
|
33
|
-
return isinstance(exception, FalServerlessException)
|
|
34
|
-
|
|
35
|
-
def handle(self, exception: FalServerlessException):
|
|
36
|
-
console.print(f"{CROSS_ICON} {exception.message}")
|
|
37
|
-
if exception.hint is not None:
|
|
38
|
-
console.print(Markdown(f"**Hint:** {exception.hint}"))
|
|
39
|
-
console.print()
|
|
24
|
+
msg = f"{CROSS_ICON} {str(exception)}"
|
|
25
|
+
cause = exception.__cause__
|
|
26
|
+
if cause is not None:
|
|
27
|
+
msg += f": {str(cause)}"
|
|
28
|
+
console.print(msg)
|
|
40
29
|
|
|
41
30
|
|
|
42
31
|
class GrpcExceptionHandler(BaseExceptionHandler[RpcCall]):
|
|
@@ -51,9 +40,9 @@ class GrpcExceptionHandler(BaseExceptionHandler[RpcCall]):
|
|
|
51
40
|
|
|
52
41
|
class UserFunctionExceptionHandler(BaseExceptionHandler["UserFunctionException"]):
|
|
53
42
|
def should_handle(self, exception: Exception) -> bool:
|
|
54
|
-
from fal.api import UserFunctionException
|
|
43
|
+
from fal.api import UserFunctionException
|
|
55
44
|
|
|
56
|
-
return
|
|
45
|
+
return isinstance(exception, UserFunctionException)
|
|
57
46
|
|
|
58
47
|
def handle(self, exception: UserFunctionException):
|
|
59
48
|
import rich
|
fal/sdk.py
CHANGED
|
@@ -14,7 +14,7 @@ from isolate.server.interface import from_grpc, to_serialized_object, to_struct
|
|
|
14
14
|
|
|
15
15
|
import isolate_proto
|
|
16
16
|
from fal import flags
|
|
17
|
-
from fal._serialization import
|
|
17
|
+
from fal._serialization import patch_pickle
|
|
18
18
|
from fal.auth import USER, key_credentials
|
|
19
19
|
from fal.logging import get_logger
|
|
20
20
|
from fal.logging.trace import TraceContextInterceptor
|
|
@@ -24,14 +24,13 @@ ResultT = TypeVar("ResultT")
|
|
|
24
24
|
InputT = TypeVar("InputT")
|
|
25
25
|
UNSET = object()
|
|
26
26
|
|
|
27
|
-
_DEFAULT_SERIALIZATION_METHOD = "
|
|
27
|
+
_DEFAULT_SERIALIZATION_METHOD = "cloudpickle"
|
|
28
28
|
FAL_SERVERLESS_DEFAULT_KEEP_ALIVE = 10
|
|
29
29
|
FAL_SERVERLESS_DEFAULT_MAX_MULTIPLEXING = 1
|
|
30
30
|
FAL_SERVERLESS_DEFAULT_MIN_CONCURRENCY = 0
|
|
31
31
|
|
|
32
32
|
logger = get_logger(__name__)
|
|
33
33
|
|
|
34
|
-
patch_dill()
|
|
35
34
|
patch_pickle()
|
|
36
35
|
|
|
37
36
|
|
fal/toolkit/__init__.py
CHANGED
|
@@ -2,7 +2,6 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
from fal.toolkit.file import CompressedFile, File
|
|
4
4
|
from fal.toolkit.image.image import Image, ImageSizeInput, get_image_size
|
|
5
|
-
from fal.toolkit.mainify import mainify
|
|
6
5
|
from fal.toolkit.optimize import optimize
|
|
7
6
|
from fal.toolkit.utils import (
|
|
8
7
|
FAL_MODEL_WEIGHTS_DIR,
|
|
@@ -19,7 +18,6 @@ __all__ = [
|
|
|
19
18
|
"Image",
|
|
20
19
|
"ImageSizeInput",
|
|
21
20
|
"get_image_size",
|
|
22
|
-
"mainify",
|
|
23
21
|
"optimize",
|
|
24
22
|
"FAL_MODEL_WEIGHTS_DIR",
|
|
25
23
|
"FAL_PERSISTENT_DIR",
|
fal/toolkit/exceptions.py
CHANGED
|
@@ -1,14 +1,9 @@
|
|
|
1
|
-
from fal.toolkit.mainify import mainify
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
@mainify
|
|
5
1
|
class FalTookitException(Exception):
|
|
6
2
|
"""Base exception for all toolkit exceptions"""
|
|
7
3
|
|
|
8
4
|
pass
|
|
9
5
|
|
|
10
6
|
|
|
11
|
-
@mainify
|
|
12
7
|
class FileUploadException(FalTookitException):
|
|
13
8
|
"""Raised when file upload fails"""
|
|
14
9
|
|
fal/toolkit/file/file.py
CHANGED
|
@@ -1,13 +1,23 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import shutil
|
|
3
4
|
from pathlib import Path
|
|
4
|
-
from tempfile import NamedTemporaryFile,
|
|
5
|
-
from typing import
|
|
5
|
+
from tempfile import NamedTemporaryFile, mkdtemp
|
|
6
|
+
from typing import Callable, Optional, Any
|
|
6
7
|
from urllib.parse import urlparse
|
|
7
8
|
from zipfile import ZipFile
|
|
8
9
|
|
|
9
|
-
|
|
10
|
-
|
|
10
|
+
import pydantic
|
|
11
|
+
|
|
12
|
+
# https://github.com/pydantic/pydantic/pull/2573
|
|
13
|
+
if not hasattr(pydantic, "__version__") or pydantic.__version__.startswith("1."):
|
|
14
|
+
IS_PYDANTIC_V2 = False
|
|
15
|
+
else:
|
|
16
|
+
from pydantic_core import core_schema, CoreSchema
|
|
17
|
+
from pydantic import GetCoreSchemaHandler
|
|
18
|
+
IS_PYDANTIC_V2 = True
|
|
19
|
+
|
|
20
|
+
from pydantic import BaseModel, Field
|
|
11
21
|
|
|
12
22
|
from fal.toolkit.file.providers.fal import (
|
|
13
23
|
FalCDNFileRepository,
|
|
@@ -17,7 +27,6 @@ from fal.toolkit.file.providers.fal import (
|
|
|
17
27
|
from fal.toolkit.file.providers.gcp import GoogleStorageRepository
|
|
18
28
|
from fal.toolkit.file.providers.r2 import R2Repository
|
|
19
29
|
from fal.toolkit.file.types import FileData, FileRepository, RepositoryId
|
|
20
|
-
from fal.toolkit.mainify import mainify
|
|
21
30
|
from fal.toolkit.utils.download_utils import download_file
|
|
22
31
|
|
|
23
32
|
FileRepositoryFactory = Callable[[], FileRepository]
|
|
@@ -42,52 +51,40 @@ get_builtin_repository.__module__ = "__main__"
|
|
|
42
51
|
DEFAULT_REPOSITORY: FileRepository | RepositoryId = "fal"
|
|
43
52
|
|
|
44
53
|
|
|
45
|
-
@mainify
|
|
46
54
|
class File(BaseModel):
|
|
47
55
|
# public properties
|
|
48
|
-
_file_data: FileData = PrivateAttr()
|
|
49
56
|
url: str = Field(
|
|
50
57
|
description="The URL where the file can be downloaded from.",
|
|
51
58
|
)
|
|
52
59
|
content_type: Optional[str] = Field(
|
|
53
|
-
description="The mime type of the file.",
|
|
60
|
+
None, description="The mime type of the file.",
|
|
54
61
|
examples=["image/png"],
|
|
55
62
|
)
|
|
56
63
|
file_name: Optional[str] = Field(
|
|
57
|
-
description="The name of the file. It will be auto-generated if not provided.",
|
|
64
|
+
None, description="The name of the file. It will be auto-generated if not provided.",
|
|
58
65
|
examples=["z9RV14K95DvU.png"],
|
|
59
66
|
)
|
|
60
67
|
file_size: Optional[int] = Field(
|
|
61
|
-
description="The size of the file in bytes.", examples=[4404019]
|
|
68
|
+
None, description="The size of the file in bytes.", examples=[4404019]
|
|
69
|
+
)
|
|
70
|
+
file_data: Optional[bytes] = Field(
|
|
71
|
+
None, description="File data", exclude=True, repr=False,
|
|
62
72
|
)
|
|
63
|
-
|
|
64
|
-
def __init__(self, **kwargs):
|
|
65
|
-
if "file_data" in kwargs:
|
|
66
|
-
data: FileData = kwargs.pop("file_data")
|
|
67
|
-
repository = kwargs.pop("repository", None)
|
|
68
|
-
|
|
69
|
-
repo = (
|
|
70
|
-
repository
|
|
71
|
-
if isinstance(repository, FileRepository)
|
|
72
|
-
else get_builtin_repository(repository)
|
|
73
|
-
)
|
|
74
|
-
self._file_data = data
|
|
75
|
-
|
|
76
|
-
kwargs.update(
|
|
77
|
-
{
|
|
78
|
-
"url": repo.save(data),
|
|
79
|
-
"content_type": data.content_type,
|
|
80
|
-
"file_name": data.file_name,
|
|
81
|
-
"file_size": len(data.data),
|
|
82
|
-
}
|
|
83
|
-
)
|
|
84
|
-
|
|
85
|
-
super().__init__(**kwargs)
|
|
86
73
|
|
|
87
74
|
# Pydantic custom validator for input type conversion
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
75
|
+
if IS_PYDANTIC_V2:
|
|
76
|
+
@classmethod
|
|
77
|
+
def __get_pydantic_core_schema__(
|
|
78
|
+
cls, source_type: Any, handler: GetCoreSchemaHandler
|
|
79
|
+
) -> CoreSchema:
|
|
80
|
+
return core_schema.no_info_before_validator_function(
|
|
81
|
+
cls.__convert_from_str,
|
|
82
|
+
handler(source_type),
|
|
83
|
+
)
|
|
84
|
+
else:
|
|
85
|
+
@classmethod
|
|
86
|
+
def __get_validators__(cls):
|
|
87
|
+
yield cls.__convert_from_str
|
|
91
88
|
|
|
92
89
|
@classmethod
|
|
93
90
|
def __convert_from_str(cls, value: Any):
|
|
@@ -119,9 +116,20 @@ class File(BaseModel):
|
|
|
119
116
|
file_name: Optional[str] = None,
|
|
120
117
|
repository: FileRepository | RepositoryId = DEFAULT_REPOSITORY,
|
|
121
118
|
) -> File:
|
|
119
|
+
repo = (
|
|
120
|
+
repository
|
|
121
|
+
if isinstance(repository, FileRepository)
|
|
122
|
+
else get_builtin_repository(repository)
|
|
123
|
+
)
|
|
124
|
+
|
|
125
|
+
fdata = FileData(data, content_type, file_name)
|
|
126
|
+
|
|
122
127
|
return cls(
|
|
123
|
-
|
|
124
|
-
|
|
128
|
+
url=repo.save(fdata),
|
|
129
|
+
content_type=fdata.content_type,
|
|
130
|
+
file_name=fdata.file_name,
|
|
131
|
+
file_size=len(data),
|
|
132
|
+
file_data=data,
|
|
125
133
|
)
|
|
126
134
|
|
|
127
135
|
@classmethod
|
|
@@ -141,10 +149,10 @@ class File(BaseModel):
|
|
|
141
149
|
)
|
|
142
150
|
|
|
143
151
|
def as_bytes(self) -> bytes:
|
|
144
|
-
if
|
|
152
|
+
if self.file_data is None:
|
|
145
153
|
raise ValueError("File has not been downloaded")
|
|
146
154
|
|
|
147
|
-
return self.
|
|
155
|
+
return self.file_data
|
|
148
156
|
|
|
149
157
|
def save(self, path: str | Path, overwrite: bool = False) -> Path:
|
|
150
158
|
file_path = Path(path).resolve()
|
|
@@ -158,37 +166,32 @@ class File(BaseModel):
|
|
|
158
166
|
return file_path
|
|
159
167
|
|
|
160
168
|
|
|
161
|
-
@mainify
|
|
162
169
|
class CompressedFile(File):
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
def __init__(self, **kwargs):
|
|
166
|
-
super().__init__(**kwargs)
|
|
167
|
-
self._extract_dir = None
|
|
170
|
+
extract_dir: Optional[str] = Field(default=None, exclude=True, repr=False)
|
|
168
171
|
|
|
169
172
|
def __iter__(self):
|
|
170
|
-
if not self.
|
|
173
|
+
if not self.extract_dir:
|
|
171
174
|
self._extract_files()
|
|
172
175
|
|
|
173
|
-
files = Path(self.
|
|
176
|
+
files = Path(self.extract_dir).iterdir() # type: ignore
|
|
174
177
|
return iter(files)
|
|
175
178
|
|
|
176
179
|
def _extract_files(self):
|
|
177
|
-
self.
|
|
180
|
+
self.extract_dir = mkdtemp()
|
|
178
181
|
|
|
179
182
|
with NamedTemporaryFile() as temp_file:
|
|
180
183
|
file_path = temp_file.name
|
|
181
184
|
self.save(file_path, overwrite=True)
|
|
182
185
|
|
|
183
186
|
with ZipFile(file_path) as zip_file:
|
|
184
|
-
zip_file.extractall(self.
|
|
187
|
+
zip_file.extractall(self.extract_dir)
|
|
185
188
|
|
|
186
189
|
def glob(self, pattern: str):
|
|
187
|
-
if not self.
|
|
190
|
+
if not self.extract_dir:
|
|
188
191
|
self._extract_files()
|
|
189
192
|
|
|
190
|
-
return Path(self.
|
|
193
|
+
return Path(self.extract_dir).glob(pattern) # type: ignore
|
|
191
194
|
|
|
192
195
|
def __del__(self):
|
|
193
|
-
if self.
|
|
194
|
-
self.
|
|
196
|
+
if self.extract_dir:
|
|
197
|
+
shutil.rmtree(self.extract_dir)
|
fal/toolkit/file/providers/fal.py
CHANGED
|
@@ -10,12 +10,10 @@ from urllib.request import Request, urlopen
|
|
|
10
10
|
from fal.auth import key_credentials
|
|
11
11
|
from fal.toolkit.exceptions import FileUploadException
|
|
12
12
|
from fal.toolkit.file.types import FileData, FileRepository
|
|
13
|
-
from fal.toolkit.mainify import mainify
|
|
14
13
|
|
|
15
14
|
_FAL_CDN = "https://fal.media"
|
|
16
15
|
|
|
17
16
|
|
|
18
|
-
@mainify
|
|
19
17
|
@dataclass
|
|
20
18
|
class FalFileRepository(FileRepository):
|
|
21
19
|
def save(self, file: FileData) -> str:
|
|
@@ -70,14 +68,12 @@ class FalFileRepository(FileRepository):
|
|
|
70
68
|
return
|
|
71
69
|
|
|
72
70
|
|
|
73
|
-
@mainify
|
|
74
71
|
@dataclass
|
|
75
72
|
class InMemoryRepository(FileRepository):
|
|
76
73
|
def save(self, file: FileData) -> str:
|
|
77
74
|
return f'data:{file.content_type};base64,{b64encode(file.data).decode("utf-8")}'
|
|
78
75
|
|
|
79
76
|
|
|
80
|
-
@mainify
|
|
81
77
|
@dataclass
|
|
82
78
|
class FalCDNFileRepository(FileRepository):
|
|
83
79
|
def save(self, file: FileData) -> str:
|
fal/toolkit/file/providers/gcp.py
CHANGED
|
@@ -6,12 +6,10 @@ import os
|
|
|
6
6
|
from dataclasses import dataclass
|
|
7
7
|
|
|
8
8
|
from fal.toolkit.file.types import FileData, FileRepository
|
|
9
|
-
from fal.toolkit.mainify import mainify
|
|
10
9
|
|
|
11
10
|
DEFAULT_URL_TIMEOUT = 60 * 15 # 15 minutes
|
|
12
11
|
|
|
13
12
|
|
|
14
|
-
@mainify
|
|
15
13
|
@dataclass
|
|
16
14
|
class GoogleStorageRepository(FileRepository):
|
|
17
15
|
bucket_name: str = "fal_file_storage"
|
fal/toolkit/file/providers/r2.py
CHANGED
|
@@ -6,12 +6,10 @@ from dataclasses import dataclass
|
|
|
6
6
|
from io import BytesIO
|
|
7
7
|
|
|
8
8
|
from fal.toolkit.file.types import FileData, FileRepository
|
|
9
|
-
from fal.toolkit.mainify import mainify
|
|
10
9
|
|
|
11
10
|
DEFAULT_URL_TIMEOUT = 60 * 15 # 15 minutes
|
|
12
11
|
|
|
13
12
|
|
|
14
|
-
@mainify
|
|
15
13
|
@dataclass
|
|
16
14
|
class R2Repository(FileRepository):
|
|
17
15
|
bucket_name: str = "fal_file_storage"
|
fal/toolkit/file/types.py
CHANGED
|
@@ -5,10 +5,7 @@ from mimetypes import guess_extension, guess_type
|
|
|
5
5
|
from typing import Literal
|
|
6
6
|
from uuid import uuid4
|
|
7
7
|
|
|
8
|
-
from fal.toolkit.mainify import mainify
|
|
9
8
|
|
|
10
|
-
|
|
11
|
-
@mainify
|
|
12
9
|
class FileData:
|
|
13
10
|
data: bytes
|
|
14
11
|
content_type: str
|
|
@@ -34,7 +31,6 @@ class FileData:
|
|
|
34
31
|
RepositoryId = Literal["fal", "in_memory", "gcp_storage", "r2", "cdn"]
|
|
35
32
|
|
|
36
33
|
|
|
37
|
-
@mainify
|
|
38
34
|
@dataclass
|
|
39
35
|
class FileRepository:
|
|
40
36
|
def save(self, data: FileData) -> str:
|
fal/toolkit/image/image.py
CHANGED
|
@@ -7,8 +7,7 @@ from typing import TYPE_CHECKING, Literal, Union, Optional
|
|
|
7
7
|
from pydantic import BaseModel, Field
|
|
8
8
|
|
|
9
9
|
from fal.toolkit.file.file import DEFAULT_REPOSITORY, File
|
|
10
|
-
from fal.toolkit.file.types import
|
|
11
|
-
from fal.toolkit.mainify import mainify
|
|
10
|
+
from fal.toolkit.file.types import FileRepository, RepositoryId
|
|
12
11
|
from fal.toolkit.utils.download_utils import _download_file_python
|
|
13
12
|
|
|
14
13
|
if TYPE_CHECKING:
|
|
@@ -25,7 +24,6 @@ ImageSizePreset = Literal[
|
|
|
25
24
|
]
|
|
26
25
|
|
|
27
26
|
|
|
28
|
-
@mainify
|
|
29
27
|
class ImageSize(BaseModel):
|
|
30
28
|
width: int = Field(
|
|
31
29
|
default=512, description="The width of the generated image.", gt=0, le=14142
|
|
@@ -46,7 +44,6 @@ IMAGE_SIZE_PRESETS: dict[ImageSizePreset, ImageSize] = {
|
|
|
46
44
|
|
|
47
45
|
ImageSizeInput = Union[ImageSize, ImageSizePreset]
|
|
48
46
|
|
|
49
|
-
@mainify
|
|
50
47
|
def get_image_size(source: ImageSizeInput) -> ImageSize:
|
|
51
48
|
if isinstance(source, ImageSize):
|
|
52
49
|
return source
|
|
@@ -59,18 +56,17 @@ def get_image_size(source: ImageSizeInput) -> ImageSize:
|
|
|
59
56
|
ImageFormat = Literal["png", "jpeg", "jpg", "webp", "gif"]
|
|
60
57
|
|
|
61
58
|
|
|
62
|
-
@mainify
|
|
63
59
|
class Image(File):
|
|
64
60
|
"""
|
|
65
61
|
Represents an image file.
|
|
66
62
|
"""
|
|
67
63
|
|
|
68
64
|
width: Optional[int] = Field(
|
|
69
|
-
description="The width of the image in pixels.",
|
|
65
|
+
None, description="The width of the image in pixels.",
|
|
70
66
|
examples=[1024],
|
|
71
67
|
)
|
|
72
68
|
height: Optional[int] = Field(
|
|
73
|
-
description="The height of the image in pixels.", examples=[1024]
|
|
69
|
+
None, description="The height of the image in pixels.", examples=[1024]
|
|
74
70
|
)
|
|
75
71
|
|
|
76
72
|
@classmethod
|
|
@@ -82,15 +78,15 @@ class Image(File):
|
|
|
82
78
|
file_name: str | None = None,
|
|
83
79
|
repository: FileRepository | RepositoryId = DEFAULT_REPOSITORY,
|
|
84
80
|
) -> Image:
|
|
85
|
-
|
|
86
|
-
data
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
file_data=file_data,
|
|
81
|
+
obj = super().from_bytes(
|
|
82
|
+
data,
|
|
83
|
+
content_type=f"image/{format}",
|
|
84
|
+
file_name=file_name,
|
|
90
85
|
repository=repository,
|
|
91
|
-
width=size.width if size else None,
|
|
92
|
-
height=size.height if size else None,
|
|
93
86
|
)
|
|
87
|
+
obj.width=size.width if size else None
|
|
88
|
+
obj.height=size.height if size else None
|
|
89
|
+
return obj
|
|
94
90
|
|
|
95
91
|
@classmethod
|
|
96
92
|
def from_pil(
|
fal/toolkit/optimize.py
CHANGED
|
@@ -4,13 +4,11 @@ import os
|
|
|
4
4
|
import traceback
|
|
5
5
|
from typing import TYPE_CHECKING, Any
|
|
6
6
|
|
|
7
|
-
from fal.toolkit.mainify import mainify
|
|
8
7
|
|
|
9
8
|
if TYPE_CHECKING:
|
|
10
9
|
import torch
|
|
11
10
|
|
|
12
11
|
|
|
13
|
-
@mainify
|
|
14
12
|
def optimize(
|
|
15
13
|
module: torch.nn.Module, *, optimization_config: dict[str, Any] | None = None
|
|
16
14
|
) -> torch.nn.Module:
|
fal/toolkit/utils/download_utils.py
CHANGED
|
@@ -10,8 +10,6 @@ from tempfile import TemporaryDirectory
|
|
|
10
10
|
from urllib.parse import urlparse
|
|
11
11
|
from urllib.request import Request, urlopen
|
|
12
12
|
|
|
13
|
-
from fal.toolkit.mainify import mainify
|
|
14
|
-
|
|
15
13
|
FAL_PERSISTENT_DIR = PurePath("/data")
|
|
16
14
|
FAL_REPOSITORY_DIR = FAL_PERSISTENT_DIR / ".fal" / "repos"
|
|
17
15
|
FAL_MODEL_WEIGHTS_DIR = FAL_PERSISTENT_DIR / ".fal" / "model_weights"
|
|
@@ -23,12 +21,10 @@ TEMP_HEADERS = {
|
|
|
23
21
|
}
|
|
24
22
|
|
|
25
23
|
|
|
26
|
-
@mainify
|
|
27
24
|
class DownloadError(Exception):
|
|
28
25
|
pass
|
|
29
26
|
|
|
30
27
|
|
|
31
|
-
@mainify
|
|
32
28
|
def _hash_url(url: str) -> str:
|
|
33
29
|
"""Hashes a URL using SHA-256.
|
|
34
30
|
|
|
@@ -41,7 +37,6 @@ def _hash_url(url: str) -> str:
|
|
|
41
37
|
return hashlib.sha256(url.encode("utf-8")).hexdigest()
|
|
42
38
|
|
|
43
39
|
|
|
44
|
-
@mainify
|
|
45
40
|
@lru_cache
|
|
46
41
|
def _get_remote_file_properties(url: str) -> tuple[str, int]:
|
|
47
42
|
"""Retrieves the file name and content length of a remote file.
|
|
@@ -83,7 +78,6 @@ def _get_remote_file_properties(url: str) -> tuple[str, int]:
|
|
|
83
78
|
return file_name, content_length
|
|
84
79
|
|
|
85
80
|
|
|
86
|
-
@mainify
|
|
87
81
|
def _file_content_length_matches(url: str, file_path: Path) -> bool:
|
|
88
82
|
"""Check if the local file's content length matches the expected remote
|
|
89
83
|
file's content length.
|
|
@@ -109,7 +103,6 @@ def _file_content_length_matches(url: str, file_path: Path) -> bool:
|
|
|
109
103
|
return local_file_content_length == remote_file_content_length
|
|
110
104
|
|
|
111
105
|
|
|
112
|
-
@mainify
|
|
113
106
|
def download_file(
|
|
114
107
|
url: str,
|
|
115
108
|
target_dir: str | Path,
|
|
@@ -154,7 +147,7 @@ def download_file(
|
|
|
154
147
|
|
|
155
148
|
# If target_dir is not an absolute path, use "/data" as the relative directory
|
|
156
149
|
if not target_dir_path.is_absolute():
|
|
157
|
-
target_dir_path = FAL_PERSISTENT_DIR / target_dir_path # type: ignore[assignment]
|
|
150
|
+
target_dir_path = Path(FAL_PERSISTENT_DIR / target_dir_path) # type: ignore[assignment]
|
|
158
151
|
|
|
159
152
|
target_path = target_dir_path.resolve() / file_name
|
|
160
153
|
|
|
@@ -185,7 +178,6 @@ def download_file(
|
|
|
185
178
|
return target_path
|
|
186
179
|
|
|
187
180
|
|
|
188
|
-
@mainify
|
|
189
181
|
def _download_file_python(url: str, target_path: Path | str) -> Path:
|
|
190
182
|
"""Download a file from a given URL and save it to a specified path using a
|
|
191
183
|
Python interface.
|
|
@@ -224,7 +216,6 @@ def _download_file_python(url: str, target_path: Path | str) -> Path:
|
|
|
224
216
|
return Path(target_path)
|
|
225
217
|
|
|
226
218
|
|
|
227
|
-
@mainify
|
|
228
219
|
def _stream_url_data_to_file(url: str, file_path: str, chunk_size_in_mb: int = 64):
|
|
229
220
|
"""Download data from a URL and stream it to a file.
|
|
230
221
|
|
|
@@ -273,7 +264,6 @@ def _stream_url_data_to_file(url: str, file_path: str, chunk_size_in_mb: int = 6
|
|
|
273
264
|
raise DownloadError("Received less data than expected from the server.")
|
|
274
265
|
|
|
275
266
|
|
|
276
|
-
@mainify
|
|
277
267
|
def download_model_weights(url: str, force: bool = False):
|
|
278
268
|
"""Downloads model weights from the specified URL and saves them to a
|
|
279
269
|
predefined directory.
|
|
@@ -313,7 +303,6 @@ def download_model_weights(url: str, force: bool = False):
|
|
|
313
303
|
)
|
|
314
304
|
|
|
315
305
|
|
|
316
|
-
@mainify
|
|
317
306
|
def clone_repository(
|
|
318
307
|
https_url: str,
|
|
319
308
|
*,
|
|
@@ -408,7 +397,6 @@ def clone_repository(
|
|
|
408
397
|
return local_repo_path
|
|
409
398
|
|
|
410
399
|
|
|
411
|
-
@mainify
|
|
412
400
|
def __add_local_path_to_sys_path(local_path: Path | str):
|
|
413
401
|
local_path_str = str(local_path)
|
|
414
402
|
|
|
@@ -416,7 +404,6 @@ def __add_local_path_to_sys_path(local_path: Path | str):
|
|
|
416
404
|
sys.path.insert(0, local_path_str)
|
|
417
405
|
|
|
418
406
|
|
|
419
|
-
@mainify
|
|
420
407
|
def _get_git_revision_hash(repo_path: Path) -> str:
|
|
421
408
|
import subprocess
|
|
422
409
|
|