fal 0.14.0__py3-none-any.whl → 0.15.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of fal might be problematic. Click here for more details.
- fal/__init__.py +1 -13
- fal/_serialization.py +151 -121
- fal/api.py +106 -61
- fal/app.py +25 -9
- fal/auth/__init__.py +2 -3
- fal/auth/auth0.py +4 -2
- fal/auth/local.py +2 -1
- fal/cli.py +10 -7
- fal/exceptions/__init__.py +3 -3
- fal/exceptions/_base.py +1 -12
- fal/exceptions/auth.py +2 -4
- fal/exceptions/handlers.py +8 -19
- fal/flags.py +0 -2
- fal/logging/isolate.py +4 -4
- fal/sdk.py +40 -5
- fal/sync.py +7 -3
- fal/toolkit/__init__.py +0 -2
- fal/toolkit/exceptions.py +0 -5
- fal/toolkit/file/file.py +61 -50
- fal/toolkit/file/providers/fal.py +20 -7
- fal/toolkit/file/providers/gcp.py +0 -2
- fal/toolkit/file/providers/r2.py +0 -2
- fal/toolkit/file/types.py +0 -4
- fal/toolkit/image/image.py +11 -15
- fal/toolkit/optimize.py +0 -3
- fal/toolkit/utils/download_utils.py +7 -17
- fal/workflows.py +9 -3
- fal-0.15.2.dist-info/METADATA +119 -0
- {fal-0.14.0.dist-info → fal-0.15.2.dist-info}/RECORD +50 -51
- {fal-0.14.0.dist-info → fal-0.15.2.dist-info}/WHEEL +2 -1
- fal-0.15.2.dist-info/entry_points.txt +2 -0
- fal-0.15.2.dist-info/top_level.txt +2 -0
- fal/env.py +0 -3
- fal/toolkit/mainify.py +0 -13
- fal-0.14.0.dist-info/METADATA +0 -89
- fal-0.14.0.dist-info/entry_points.txt +0 -4
fal/flags.py
CHANGED
|
@@ -18,8 +18,6 @@ GRPC_HOST = os.getenv("FAL_HOST", "api.alpha.fal.ai")
|
|
|
18
18
|
if not TEST_MODE:
|
|
19
19
|
assert GRPC_HOST.startswith("api"), "FAL_HOST must start with 'api'"
|
|
20
20
|
|
|
21
|
-
GATEWAY_HOST = GRPC_HOST.replace("api", "gateway", 1)
|
|
22
|
-
|
|
23
21
|
REST_HOST = GRPC_HOST.replace("api", "rest", 1)
|
|
24
22
|
REST_SCHEME = "http" if TEST_MODE or AUTH_DISABLED else "https"
|
|
25
23
|
REST_URL = f"{REST_SCHEME}://{REST_HOST}"
|
fal/logging/isolate.py
CHANGED
|
@@ -34,10 +34,10 @@ class IsolateLogPrinter:
|
|
|
34
34
|
timestamp = log.timestamp
|
|
35
35
|
else:
|
|
36
36
|
# Default value for timestamp if user has old `isolate` version.
|
|
37
|
-
# Even if the controller version is controller by us, which means that
|
|
38
|
-
# is being sent in the gRPC message.
|
|
39
|
-
# The `isolate` version users interpret that message with is out of our
|
|
40
|
-
# So we need to handle this case.
|
|
37
|
+
# Even if the controller version is controller by us, which means that
|
|
38
|
+
# the timestamp is being sent in the gRPC message.
|
|
39
|
+
# The `isolate` version users interpret that message with is out of our
|
|
40
|
+
# control. So we need to handle this case.
|
|
41
41
|
timestamp = datetime.now(timezone.utc)
|
|
42
42
|
|
|
43
43
|
event: EventDict = {
|
fal/sdk.py
CHANGED
|
@@ -8,30 +8,29 @@ from enum import Enum
|
|
|
8
8
|
from typing import Any, Callable, Generic, Iterator, Literal, TypeVar
|
|
9
9
|
|
|
10
10
|
import grpc
|
|
11
|
+
import isolate_proto
|
|
11
12
|
from isolate.connections.common import is_agent
|
|
12
13
|
from isolate.logs import Log
|
|
13
14
|
from isolate.server.interface import from_grpc, to_serialized_object, to_struct
|
|
15
|
+
from isolate_proto.configuration import GRPC_OPTIONS
|
|
14
16
|
|
|
15
|
-
import isolate_proto
|
|
16
17
|
from fal import flags
|
|
17
|
-
from fal._serialization import
|
|
18
|
+
from fal._serialization import patch_pickle
|
|
18
19
|
from fal.auth import USER, key_credentials
|
|
19
20
|
from fal.logging import get_logger
|
|
20
21
|
from fal.logging.trace import TraceContextInterceptor
|
|
21
|
-
from isolate_proto.configuration import GRPC_OPTIONS
|
|
22
22
|
|
|
23
23
|
ResultT = TypeVar("ResultT")
|
|
24
24
|
InputT = TypeVar("InputT")
|
|
25
25
|
UNSET = object()
|
|
26
26
|
|
|
27
|
-
_DEFAULT_SERIALIZATION_METHOD = "
|
|
27
|
+
_DEFAULT_SERIALIZATION_METHOD = "cloudpickle"
|
|
28
28
|
FAL_SERVERLESS_DEFAULT_KEEP_ALIVE = 10
|
|
29
29
|
FAL_SERVERLESS_DEFAULT_MAX_MULTIPLEXING = 1
|
|
30
30
|
FAL_SERVERLESS_DEFAULT_MIN_CONCURRENCY = 0
|
|
31
31
|
|
|
32
32
|
logger = get_logger(__name__)
|
|
33
33
|
|
|
34
|
-
patch_dill()
|
|
35
34
|
patch_pickle()
|
|
36
35
|
|
|
37
36
|
|
|
@@ -188,6 +187,16 @@ class HostedRunStatus:
|
|
|
188
187
|
state: HostedRunState
|
|
189
188
|
|
|
190
189
|
|
|
190
|
+
@dataclass
|
|
191
|
+
class ApplicationInfo:
|
|
192
|
+
application_id: str
|
|
193
|
+
keep_alive: int
|
|
194
|
+
max_concurrency: int
|
|
195
|
+
max_multiplexing: int
|
|
196
|
+
active_runners: int
|
|
197
|
+
min_concurrency: int
|
|
198
|
+
|
|
199
|
+
|
|
191
200
|
@dataclass
|
|
192
201
|
class AliasInfo:
|
|
193
202
|
alias: str
|
|
@@ -264,6 +273,20 @@ class KeyScope(enum.Enum):
|
|
|
264
273
|
raise ValueError(f"Unknown KeyScope: {proto}")
|
|
265
274
|
|
|
266
275
|
|
|
276
|
+
@from_grpc.register(isolate_proto.ApplicationInfo)
|
|
277
|
+
def _from_grpc_application_info(
|
|
278
|
+
message: isolate_proto.ApplicationInfo
|
|
279
|
+
) -> ApplicationInfo:
|
|
280
|
+
return ApplicationInfo(
|
|
281
|
+
application_id=message.application_id,
|
|
282
|
+
keep_alive=message.keep_alive,
|
|
283
|
+
max_concurrency=message.max_concurrency,
|
|
284
|
+
max_multiplexing=message.max_multiplexing,
|
|
285
|
+
active_runners=message.active_runners,
|
|
286
|
+
min_concurrency=message.min_concurrency,
|
|
287
|
+
)
|
|
288
|
+
|
|
289
|
+
|
|
267
290
|
@from_grpc.register(isolate_proto.AliasInfo)
|
|
268
291
|
def _from_grpc_alias_info(message: isolate_proto.AliasInfo) -> AliasInfo:
|
|
269
292
|
if message.auth_mode is isolate_proto.ApplicationAuthMode.PUBLIC:
|
|
@@ -497,6 +520,18 @@ class FalServerlessConnection:
|
|
|
497
520
|
)
|
|
498
521
|
return from_grpc(res.alias_info)
|
|
499
522
|
|
|
523
|
+
def list_applications(self) -> list[ApplicationInfo]:
|
|
524
|
+
request = isolate_proto.ListApplicationsRequest()
|
|
525
|
+
res: isolate_proto.ListApplicationsResult = self.stub.ListApplications(request)
|
|
526
|
+
return [from_grpc(app) for app in res.applications]
|
|
527
|
+
|
|
528
|
+
def delete_application(
|
|
529
|
+
self,
|
|
530
|
+
application_id: str,
|
|
531
|
+
) -> None:
|
|
532
|
+
request = isolate_proto.DeleteApplicationRequest(application_id=application_id)
|
|
533
|
+
self.stub.DeleteApplication(request)
|
|
534
|
+
|
|
500
535
|
def run(
|
|
501
536
|
self,
|
|
502
537
|
function: Callable[..., ResultT],
|
fal/sync.py
CHANGED
|
@@ -31,7 +31,8 @@ def _upload_file(source_path: str, target_path: str, unzip: bool = False):
|
|
|
31
31
|
body = upload_file_model.BodyUploadLocalFile(
|
|
32
32
|
rest_types.File(
|
|
33
33
|
payload=file_to_upload,
|
|
34
|
-
# We need to set a file_name, otherwise the server errors
|
|
34
|
+
# We need to set a file_name, otherwise the server errors
|
|
35
|
+
# processing the file
|
|
35
36
|
file_name=os.path.basename(source_path),
|
|
36
37
|
)
|
|
37
38
|
)
|
|
@@ -45,7 +46,9 @@ def _upload_file(source_path: str, target_path: str, unzip: bool = False):
|
|
|
45
46
|
|
|
46
47
|
if response.status_code != 200:
|
|
47
48
|
raise Exception(
|
|
48
|
-
|
|
49
|
+
"Failed to upload file. "
|
|
50
|
+
"Server returned status code "
|
|
51
|
+
f"{response.status_code} and message {response.parsed}"
|
|
49
52
|
)
|
|
50
53
|
|
|
51
54
|
|
|
@@ -94,7 +97,8 @@ def sync_dir(local_dir: str | Path, remote_dir: str, force_upload=False) -> str:
|
|
|
94
97
|
local_dir_abs = os.path.expanduser(local_dir)
|
|
95
98
|
if not os.path.isabs(remote_dir) or not remote_dir.startswith("/data"):
|
|
96
99
|
raise ValueError(
|
|
97
|
-
"'remote_dir' must be an absolute path starting with `/data`,
|
|
100
|
+
"'remote_dir' must be an absolute path starting with `/data`, "
|
|
101
|
+
"e.g. '/data/sync/my_dir'"
|
|
98
102
|
)
|
|
99
103
|
|
|
100
104
|
remote_dir = remote_dir.replace("/data/", "", 1)
|
fal/toolkit/__init__.py
CHANGED
|
@@ -2,7 +2,6 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
from fal.toolkit.file import CompressedFile, File
|
|
4
4
|
from fal.toolkit.image.image import Image, ImageSizeInput, get_image_size
|
|
5
|
-
from fal.toolkit.mainify import mainify
|
|
6
5
|
from fal.toolkit.optimize import optimize
|
|
7
6
|
from fal.toolkit.utils import (
|
|
8
7
|
FAL_MODEL_WEIGHTS_DIR,
|
|
@@ -19,7 +18,6 @@ __all__ = [
|
|
|
19
18
|
"Image",
|
|
20
19
|
"ImageSizeInput",
|
|
21
20
|
"get_image_size",
|
|
22
|
-
"mainify",
|
|
23
21
|
"optimize",
|
|
24
22
|
"FAL_MODEL_WEIGHTS_DIR",
|
|
25
23
|
"FAL_PERSISTENT_DIR",
|
fal/toolkit/exceptions.py
CHANGED
|
@@ -1,14 +1,9 @@
|
|
|
1
|
-
from fal.toolkit.mainify import mainify
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
@mainify
|
|
5
1
|
class FalTookitException(Exception):
|
|
6
2
|
"""Base exception for all toolkit exceptions"""
|
|
7
3
|
|
|
8
4
|
pass
|
|
9
5
|
|
|
10
6
|
|
|
11
|
-
@mainify
|
|
12
7
|
class FileUploadException(FalTookitException):
|
|
13
8
|
"""Raised when file upload fails"""
|
|
14
9
|
|
fal/toolkit/file/file.py
CHANGED
|
@@ -1,13 +1,23 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import shutil
|
|
3
4
|
from pathlib import Path
|
|
4
|
-
from tempfile import NamedTemporaryFile,
|
|
5
|
-
from typing import Any, Callable
|
|
5
|
+
from tempfile import NamedTemporaryFile, mkdtemp
|
|
6
|
+
from typing import Any, Callable, Optional
|
|
6
7
|
from urllib.parse import urlparse
|
|
7
8
|
from zipfile import ZipFile
|
|
8
9
|
|
|
9
|
-
|
|
10
|
-
|
|
10
|
+
import pydantic
|
|
11
|
+
|
|
12
|
+
# https://github.com/pydantic/pydantic/pull/2573
|
|
13
|
+
if not hasattr(pydantic, "__version__") or pydantic.__version__.startswith("1."):
|
|
14
|
+
IS_PYDANTIC_V2 = False
|
|
15
|
+
else:
|
|
16
|
+
from pydantic import GetCoreSchemaHandler
|
|
17
|
+
from pydantic_core import CoreSchema, core_schema
|
|
18
|
+
IS_PYDANTIC_V2 = True
|
|
19
|
+
|
|
20
|
+
from pydantic import BaseModel, Field
|
|
11
21
|
|
|
12
22
|
from fal.toolkit.file.providers.fal import (
|
|
13
23
|
FalCDNFileRepository,
|
|
@@ -17,7 +27,6 @@ from fal.toolkit.file.providers.fal import (
|
|
|
17
27
|
from fal.toolkit.file.providers.gcp import GoogleStorageRepository
|
|
18
28
|
from fal.toolkit.file.providers.r2 import R2Repository
|
|
19
29
|
from fal.toolkit.file.types import FileData, FileRepository, RepositoryId
|
|
20
|
-
from fal.toolkit.mainify import mainify
|
|
21
30
|
from fal.toolkit.utils.download_utils import download_file
|
|
22
31
|
|
|
23
32
|
FileRepositoryFactory = Callable[[], FileRepository]
|
|
@@ -42,52 +51,48 @@ get_builtin_repository.__module__ = "__main__"
|
|
|
42
51
|
DEFAULT_REPOSITORY: FileRepository | RepositoryId = "fal"
|
|
43
52
|
|
|
44
53
|
|
|
45
|
-
@mainify
|
|
46
54
|
class File(BaseModel):
|
|
47
55
|
# public properties
|
|
48
|
-
_file_data: FileData = PrivateAttr()
|
|
49
56
|
url: str = Field(
|
|
50
57
|
description="The URL where the file can be downloaded from.",
|
|
51
58
|
)
|
|
52
59
|
content_type: Optional[str] = Field(
|
|
60
|
+
None,
|
|
53
61
|
description="The mime type of the file.",
|
|
54
62
|
examples=["image/png"],
|
|
55
63
|
)
|
|
56
64
|
file_name: Optional[str] = Field(
|
|
65
|
+
None,
|
|
57
66
|
description="The name of the file. It will be auto-generated if not provided.",
|
|
58
67
|
examples=["z9RV14K95DvU.png"],
|
|
59
68
|
)
|
|
60
69
|
file_size: Optional[int] = Field(
|
|
61
|
-
description="The size of the file in bytes.", examples=[4404019]
|
|
70
|
+
None, description="The size of the file in bytes.", examples=[4404019]
|
|
71
|
+
)
|
|
72
|
+
file_data: Optional[bytes] = Field(
|
|
73
|
+
None,
|
|
74
|
+
description="File data",
|
|
75
|
+
exclude=True,
|
|
76
|
+
repr=False,
|
|
62
77
|
)
|
|
63
78
|
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
self._file_data = data
|
|
75
|
-
|
|
76
|
-
kwargs.update(
|
|
77
|
-
{
|
|
78
|
-
"url": repo.save(data),
|
|
79
|
-
"content_type": data.content_type,
|
|
80
|
-
"file_name": data.file_name,
|
|
81
|
-
"file_size": len(data.data),
|
|
82
|
-
}
|
|
79
|
+
# Pydantic custom validator for input type conversion
|
|
80
|
+
if IS_PYDANTIC_V2:
|
|
81
|
+
|
|
82
|
+
@classmethod
|
|
83
|
+
def __get_pydantic_core_schema__(
|
|
84
|
+
cls, source_type: Any, handler: GetCoreSchemaHandler
|
|
85
|
+
) -> CoreSchema:
|
|
86
|
+
return core_schema.no_info_before_validator_function(
|
|
87
|
+
cls.__convert_from_str,
|
|
88
|
+
handler(source_type),
|
|
83
89
|
)
|
|
84
90
|
|
|
85
|
-
|
|
91
|
+
else:
|
|
86
92
|
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
yield cls.__convert_from_str
|
|
93
|
+
@classmethod
|
|
94
|
+
def __get_validators__(cls):
|
|
95
|
+
yield cls.__convert_from_str
|
|
91
96
|
|
|
92
97
|
@classmethod
|
|
93
98
|
def __convert_from_str(cls, value: Any):
|
|
@@ -119,9 +124,20 @@ class File(BaseModel):
|
|
|
119
124
|
file_name: Optional[str] = None,
|
|
120
125
|
repository: FileRepository | RepositoryId = DEFAULT_REPOSITORY,
|
|
121
126
|
) -> File:
|
|
127
|
+
repo = (
|
|
128
|
+
repository
|
|
129
|
+
if isinstance(repository, FileRepository)
|
|
130
|
+
else get_builtin_repository(repository)
|
|
131
|
+
)
|
|
132
|
+
|
|
133
|
+
fdata = FileData(data, content_type, file_name)
|
|
134
|
+
|
|
122
135
|
return cls(
|
|
123
|
-
|
|
124
|
-
|
|
136
|
+
url=repo.save(fdata),
|
|
137
|
+
content_type=fdata.content_type,
|
|
138
|
+
file_name=fdata.file_name,
|
|
139
|
+
file_size=len(data),
|
|
140
|
+
file_data=data,
|
|
125
141
|
)
|
|
126
142
|
|
|
127
143
|
@classmethod
|
|
@@ -141,10 +157,10 @@ class File(BaseModel):
|
|
|
141
157
|
)
|
|
142
158
|
|
|
143
159
|
def as_bytes(self) -> bytes:
|
|
144
|
-
if
|
|
160
|
+
if self.file_data is None:
|
|
145
161
|
raise ValueError("File has not been downloaded")
|
|
146
162
|
|
|
147
|
-
return self.
|
|
163
|
+
return self.file_data
|
|
148
164
|
|
|
149
165
|
def save(self, path: str | Path, overwrite: bool = False) -> Path:
|
|
150
166
|
file_path = Path(path).resolve()
|
|
@@ -158,37 +174,32 @@ class File(BaseModel):
|
|
|
158
174
|
return file_path
|
|
159
175
|
|
|
160
176
|
|
|
161
|
-
@mainify
|
|
162
177
|
class CompressedFile(File):
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
def __init__(self, **kwargs):
|
|
166
|
-
super().__init__(**kwargs)
|
|
167
|
-
self._extract_dir = None
|
|
178
|
+
extract_dir: Optional[str] = Field(default=None, exclude=True, repr=False)
|
|
168
179
|
|
|
169
180
|
def __iter__(self):
|
|
170
|
-
if not self.
|
|
181
|
+
if not self.extract_dir:
|
|
171
182
|
self._extract_files()
|
|
172
183
|
|
|
173
|
-
files = Path(self.
|
|
184
|
+
files = Path(self.extract_dir).iterdir() # type: ignore
|
|
174
185
|
return iter(files)
|
|
175
186
|
|
|
176
187
|
def _extract_files(self):
|
|
177
|
-
self.
|
|
188
|
+
self.extract_dir = mkdtemp()
|
|
178
189
|
|
|
179
190
|
with NamedTemporaryFile() as temp_file:
|
|
180
191
|
file_path = temp_file.name
|
|
181
192
|
self.save(file_path, overwrite=True)
|
|
182
193
|
|
|
183
194
|
with ZipFile(file_path) as zip_file:
|
|
184
|
-
zip_file.extractall(self.
|
|
195
|
+
zip_file.extractall(self.extract_dir)
|
|
185
196
|
|
|
186
197
|
def glob(self, pattern: str):
|
|
187
|
-
if not self.
|
|
198
|
+
if not self.extract_dir:
|
|
188
199
|
self._extract_files()
|
|
189
200
|
|
|
190
|
-
return Path(self.
|
|
201
|
+
return Path(self.extract_dir).glob(pattern) # type: ignore
|
|
191
202
|
|
|
192
203
|
def __del__(self):
|
|
193
|
-
if self.
|
|
194
|
-
self.
|
|
204
|
+
if self.extract_dir:
|
|
205
|
+
shutil.rmtree(self.extract_dir)
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import dataclasses
|
|
3
4
|
import json
|
|
4
5
|
import os
|
|
5
6
|
from base64 import b64encode
|
|
@@ -10,12 +11,18 @@ from urllib.request import Request, urlopen
|
|
|
10
11
|
from fal.auth import key_credentials
|
|
11
12
|
from fal.toolkit.exceptions import FileUploadException
|
|
12
13
|
from fal.toolkit.file.types import FileData, FileRepository
|
|
13
|
-
from fal.toolkit.mainify import mainify
|
|
14
14
|
|
|
15
15
|
_FAL_CDN = "https://fal.media"
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
@
|
|
18
|
+
@dataclass
|
|
19
|
+
class ObjectLifecyclePreference:
|
|
20
|
+
expriation_duration_seconds: int
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
GLOBAL_LIFECYCLE_PREFERENCE = ObjectLifecyclePreference(expriation_duration_seconds=2)
|
|
24
|
+
|
|
25
|
+
|
|
19
26
|
@dataclass
|
|
20
27
|
class FalFileRepository(FileRepository):
|
|
21
28
|
def save(self, file: FileData) -> str:
|
|
@@ -70,23 +77,29 @@ class FalFileRepository(FileRepository):
|
|
|
70
77
|
return
|
|
71
78
|
|
|
72
79
|
|
|
73
|
-
@mainify
|
|
74
80
|
@dataclass
|
|
75
81
|
class InMemoryRepository(FileRepository):
|
|
76
|
-
def save(
|
|
82
|
+
def save(
|
|
83
|
+
self,
|
|
84
|
+
file: FileData,
|
|
85
|
+
) -> str:
|
|
77
86
|
return f'data:{file.content_type};base64,{b64encode(file.data).decode("utf-8")}'
|
|
78
87
|
|
|
79
88
|
|
|
80
|
-
@mainify
|
|
81
89
|
@dataclass
|
|
82
90
|
class FalCDNFileRepository(FileRepository):
|
|
83
|
-
def save(
|
|
91
|
+
def save(
|
|
92
|
+
self,
|
|
93
|
+
file: FileData,
|
|
94
|
+
) -> str:
|
|
84
95
|
headers = {
|
|
85
96
|
**self.auth_headers,
|
|
86
97
|
"Accept": "application/json",
|
|
87
98
|
"Content-Type": file.content_type,
|
|
99
|
+
"X-Fal-Object-Lifecycle-Preference": json.dumps(
|
|
100
|
+
dataclasses.asdict(GLOBAL_LIFECYCLE_PREFERENCE)
|
|
101
|
+
),
|
|
88
102
|
}
|
|
89
|
-
|
|
90
103
|
url = os.getenv("FAL_CDN_HOST", _FAL_CDN) + "/files/upload"
|
|
91
104
|
request = Request(url, headers=headers, method="POST", data=file.data)
|
|
92
105
|
try:
|
|
@@ -6,12 +6,10 @@ import os
|
|
|
6
6
|
from dataclasses import dataclass
|
|
7
7
|
|
|
8
8
|
from fal.toolkit.file.types import FileData, FileRepository
|
|
9
|
-
from fal.toolkit.mainify import mainify
|
|
10
9
|
|
|
11
10
|
DEFAULT_URL_TIMEOUT = 60 * 15 # 15 minutes
|
|
12
11
|
|
|
13
12
|
|
|
14
|
-
@mainify
|
|
15
13
|
@dataclass
|
|
16
14
|
class GoogleStorageRepository(FileRepository):
|
|
17
15
|
bucket_name: str = "fal_file_storage"
|
fal/toolkit/file/providers/r2.py
CHANGED
|
@@ -6,12 +6,10 @@ from dataclasses import dataclass
|
|
|
6
6
|
from io import BytesIO
|
|
7
7
|
|
|
8
8
|
from fal.toolkit.file.types import FileData, FileRepository
|
|
9
|
-
from fal.toolkit.mainify import mainify
|
|
10
9
|
|
|
11
10
|
DEFAULT_URL_TIMEOUT = 60 * 15 # 15 minutes
|
|
12
11
|
|
|
13
12
|
|
|
14
|
-
@mainify
|
|
15
13
|
@dataclass
|
|
16
14
|
class R2Repository(FileRepository):
|
|
17
15
|
bucket_name: str = "fal_file_storage"
|
fal/toolkit/file/types.py
CHANGED
|
@@ -5,10 +5,7 @@ from mimetypes import guess_extension, guess_type
|
|
|
5
5
|
from typing import Literal
|
|
6
6
|
from uuid import uuid4
|
|
7
7
|
|
|
8
|
-
from fal.toolkit.mainify import mainify
|
|
9
8
|
|
|
10
|
-
|
|
11
|
-
@mainify
|
|
12
9
|
class FileData:
|
|
13
10
|
data: bytes
|
|
14
11
|
content_type: str
|
|
@@ -34,7 +31,6 @@ class FileData:
|
|
|
34
31
|
RepositoryId = Literal["fal", "in_memory", "gcp_storage", "r2", "cdn"]
|
|
35
32
|
|
|
36
33
|
|
|
37
|
-
@mainify
|
|
38
34
|
@dataclass
|
|
39
35
|
class FileRepository:
|
|
40
36
|
def save(self, data: FileData) -> str:
|
fal/toolkit/image/image.py
CHANGED
|
@@ -2,13 +2,12 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
import io
|
|
4
4
|
from tempfile import NamedTemporaryFile
|
|
5
|
-
from typing import TYPE_CHECKING, Literal,
|
|
5
|
+
from typing import TYPE_CHECKING, Literal, Optional, Union
|
|
6
6
|
|
|
7
7
|
from pydantic import BaseModel, Field
|
|
8
8
|
|
|
9
9
|
from fal.toolkit.file.file import DEFAULT_REPOSITORY, File
|
|
10
|
-
from fal.toolkit.file.types import
|
|
11
|
-
from fal.toolkit.mainify import mainify
|
|
10
|
+
from fal.toolkit.file.types import FileRepository, RepositoryId
|
|
12
11
|
from fal.toolkit.utils.download_utils import _download_file_python
|
|
13
12
|
|
|
14
13
|
if TYPE_CHECKING:
|
|
@@ -25,7 +24,6 @@ ImageSizePreset = Literal[
|
|
|
25
24
|
]
|
|
26
25
|
|
|
27
26
|
|
|
28
|
-
@mainify
|
|
29
27
|
class ImageSize(BaseModel):
|
|
30
28
|
width: int = Field(
|
|
31
29
|
default=512, description="The width of the generated image.", gt=0, le=14142
|
|
@@ -46,7 +44,6 @@ IMAGE_SIZE_PRESETS: dict[ImageSizePreset, ImageSize] = {
|
|
|
46
44
|
|
|
47
45
|
ImageSizeInput = Union[ImageSize, ImageSizePreset]
|
|
48
46
|
|
|
49
|
-
@mainify
|
|
50
47
|
def get_image_size(source: ImageSizeInput) -> ImageSize:
|
|
51
48
|
if isinstance(source, ImageSize):
|
|
52
49
|
return source
|
|
@@ -59,18 +56,17 @@ def get_image_size(source: ImageSizeInput) -> ImageSize:
|
|
|
59
56
|
ImageFormat = Literal["png", "jpeg", "jpg", "webp", "gif"]
|
|
60
57
|
|
|
61
58
|
|
|
62
|
-
@mainify
|
|
63
59
|
class Image(File):
|
|
64
60
|
"""
|
|
65
61
|
Represents an image file.
|
|
66
62
|
"""
|
|
67
63
|
|
|
68
64
|
width: Optional[int] = Field(
|
|
69
|
-
description="The width of the image in pixels.",
|
|
65
|
+
None, description="The width of the image in pixels.",
|
|
70
66
|
examples=[1024],
|
|
71
67
|
)
|
|
72
68
|
height: Optional[int] = Field(
|
|
73
|
-
description="The height of the image in pixels.", examples=[1024]
|
|
69
|
+
None, description="The height of the image in pixels.", examples=[1024]
|
|
74
70
|
)
|
|
75
71
|
|
|
76
72
|
@classmethod
|
|
@@ -82,15 +78,15 @@ class Image(File):
|
|
|
82
78
|
file_name: str | None = None,
|
|
83
79
|
repository: FileRepository | RepositoryId = DEFAULT_REPOSITORY,
|
|
84
80
|
) -> Image:
|
|
85
|
-
|
|
86
|
-
data
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
file_data=file_data,
|
|
81
|
+
obj = super().from_bytes(
|
|
82
|
+
data,
|
|
83
|
+
content_type=f"image/{format}",
|
|
84
|
+
file_name=file_name,
|
|
90
85
|
repository=repository,
|
|
91
|
-
width=size.width if size else None,
|
|
92
|
-
height=size.height if size else None,
|
|
93
86
|
)
|
|
87
|
+
obj.width=size.width if size else None
|
|
88
|
+
obj.height=size.height if size else None
|
|
89
|
+
return obj
|
|
94
90
|
|
|
95
91
|
@classmethod
|
|
96
92
|
def from_pil(
|
fal/toolkit/optimize.py
CHANGED
|
@@ -4,13 +4,10 @@ import os
|
|
|
4
4
|
import traceback
|
|
5
5
|
from typing import TYPE_CHECKING, Any
|
|
6
6
|
|
|
7
|
-
from fal.toolkit.mainify import mainify
|
|
8
|
-
|
|
9
7
|
if TYPE_CHECKING:
|
|
10
8
|
import torch
|
|
11
9
|
|
|
12
10
|
|
|
13
|
-
@mainify
|
|
14
11
|
def optimize(
|
|
15
12
|
module: torch.nn.Module, *, optimization_config: dict[str, Any] | None = None
|
|
16
13
|
) -> torch.nn.Module:
|