fal 0.13.0__py3-none-any.whl → 0.15.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of fal might be problematic. Click here for more details.
- fal/__init__.py +4 -4
- fal/_serialization.py +159 -108
- fal/api.py +91 -36
- fal/app.py +3 -8
- fal/auth/__init__.py +1 -3
- fal/cli.py +3 -4
- fal/exceptions/__init__.py +1 -2
- fal/exceptions/_base.py +1 -12
- fal/exceptions/auth.py +2 -4
- fal/exceptions/handlers.py +8 -19
- fal/logging/isolate.py +8 -19
- fal/logging/user.py +1 -1
- fal/sdk.py +3 -3
- fal/toolkit/__init__.py +0 -2
- fal/toolkit/exceptions.py +0 -5
- fal/toolkit/file/__init__.py +1 -1
- fal/toolkit/file/file.py +58 -55
- fal/toolkit/file/providers/fal.py +2 -6
- fal/toolkit/file/providers/gcp.py +0 -2
- fal/toolkit/file/providers/r2.py +0 -2
- fal/toolkit/file/types.py +0 -4
- fal/toolkit/image/__init__.py +1 -1
- fal/toolkit/image/image.py +11 -15
- fal/toolkit/optimize.py +1 -3
- fal/toolkit/utils/__init__.py +1 -1
- fal/toolkit/utils/download_utils.py +2 -15
- fal/workflows.py +3 -2
- {fal-0.13.0.dist-info → fal-0.15.0.dist-info}/METADATA +40 -38
- {fal-0.13.0.dist-info → fal-0.15.0.dist-info}/RECORD +50 -51
- {fal-0.13.0.dist-info → fal-0.15.0.dist-info}/WHEEL +2 -1
- fal-0.15.0.dist-info/entry_points.txt +2 -0
- fal-0.15.0.dist-info/top_level.txt +2 -0
- fal/env.py +0 -3
- fal/toolkit/mainify.py +0 -13
- fal-0.13.0.dist-info/entry_points.txt +0 -4
fal/app.py
CHANGED
|
@@ -10,11 +10,9 @@ from typing import Any, Callable, ClassVar, TypeVar
|
|
|
10
10
|
from fastapi import FastAPI
|
|
11
11
|
|
|
12
12
|
import fal.api
|
|
13
|
-
from fal._serialization import add_serialization_listeners_for
|
|
14
13
|
from fal.api import RouteSignature
|
|
15
14
|
from fal.logging import get_logger
|
|
16
|
-
from fal.
|
|
17
|
-
|
|
15
|
+
from fal._serialization import include_modules_from
|
|
18
16
|
REALTIME_APP_REQUIREMENTS = ["websockets", "msgpack"]
|
|
19
17
|
|
|
20
18
|
EndpointT = TypeVar("EndpointT", bound=Callable[..., Any])
|
|
@@ -29,7 +27,7 @@ async def _call_any_fn(fn, *args, **kwargs):
|
|
|
29
27
|
|
|
30
28
|
|
|
31
29
|
def wrap_app(cls: type[App], **kwargs) -> fal.api.IsolatedFunction:
|
|
32
|
-
|
|
30
|
+
include_modules_from(cls)
|
|
33
31
|
|
|
34
32
|
def initialize_and_serve():
|
|
35
33
|
app = cls()
|
|
@@ -39,7 +37,7 @@ def wrap_app(cls: type[App], **kwargs) -> fal.api.IsolatedFunction:
|
|
|
39
37
|
try:
|
|
40
38
|
app = cls(_allow_init=True)
|
|
41
39
|
metadata["openapi"] = app.openapi()
|
|
42
|
-
except Exception
|
|
40
|
+
except Exception:
|
|
43
41
|
logger.warning("Failed to build OpenAPI specification for %s", cls.__name__)
|
|
44
42
|
realtime_app = False
|
|
45
43
|
else:
|
|
@@ -64,7 +62,6 @@ def wrap_app(cls: type[App], **kwargs) -> fal.api.IsolatedFunction:
|
|
|
64
62
|
return fn
|
|
65
63
|
|
|
66
64
|
|
|
67
|
-
@mainify
|
|
68
65
|
class App(fal.api.BaseServable):
|
|
69
66
|
requirements: ClassVar[list[str]] = []
|
|
70
67
|
machine_type: ClassVar[str] = "S"
|
|
@@ -131,7 +128,6 @@ class App(fal.api.BaseServable):
|
|
|
131
128
|
raise NotImplementedError
|
|
132
129
|
|
|
133
130
|
|
|
134
|
-
@mainify
|
|
135
131
|
def endpoint(
|
|
136
132
|
path: str, *, is_websocket: bool = False
|
|
137
133
|
) -> Callable[[EndpointT], EndpointT]:
|
|
@@ -343,7 +339,6 @@ def _fal_websocket_template(
|
|
|
343
339
|
_SENTINEL = object()
|
|
344
340
|
|
|
345
341
|
|
|
346
|
-
@mainify
|
|
347
342
|
def realtime(
|
|
348
343
|
path: str,
|
|
349
344
|
*,
|
fal/auth/__init__.py
CHANGED
|
@@ -9,10 +9,8 @@ from fal.auth import auth0, local
|
|
|
9
9
|
from fal.console import console
|
|
10
10
|
from fal.console.icons import CHECK_ICON
|
|
11
11
|
from fal.exceptions.auth import UnauthenticatedException
|
|
12
|
-
from fal.toolkit.mainify import mainify
|
|
13
12
|
|
|
14
13
|
|
|
15
|
-
@mainify
|
|
16
14
|
def key_credentials() -> tuple[str, str] | None:
|
|
17
15
|
# Ignore key credentials when the user forces auth by user.
|
|
18
16
|
if os.environ.get("FAL_FORCE_AUTH_BY_USER") == "1":
|
|
@@ -64,7 +62,7 @@ def _fetch_access_token() -> str:
|
|
|
64
62
|
try:
|
|
65
63
|
auth0.verify_access_token_expiration(access_token)
|
|
66
64
|
return access_token
|
|
67
|
-
except:
|
|
65
|
+
except Exception:
|
|
68
66
|
# access_token expired, will refresh
|
|
69
67
|
pass
|
|
70
68
|
|
fal/cli.py
CHANGED
|
@@ -113,7 +113,6 @@ class MainGroup(RichGroup):
|
|
|
113
113
|
|
|
114
114
|
if aliases:
|
|
115
115
|
# Add aliases to the help text
|
|
116
|
-
aliases_str = "Alias: " + ", ".join([name, *aliases])
|
|
117
116
|
cmd.help = (cmd.help or "") + "\n\nAlias: " + ", ".join([name, *aliases])
|
|
118
117
|
cmd.short_help = (
|
|
119
118
|
(cmd.short_help or "") + "(Alias: " + ", ".join(aliases) + ")"
|
|
@@ -268,8 +267,8 @@ def load_function_from(
|
|
|
268
267
|
raise api.FalServerlessError(f"Function '{function_name}' not found in module")
|
|
269
268
|
|
|
270
269
|
# The module for the function is set to <run_path> when runpy is used, in which
|
|
271
|
-
# case we want to manually include the
|
|
272
|
-
_serialization.
|
|
270
|
+
# case we want to manually include the package it is defined in.
|
|
271
|
+
_serialization.include_package_from_path(file_path)
|
|
273
272
|
|
|
274
273
|
target = module[function_name]
|
|
275
274
|
if isinstance(target, type) and issubclass(target, fal.App):
|
|
@@ -603,7 +602,7 @@ def _get_user_id() -> str:
|
|
|
603
602
|
if user_details_response.status_code != HTTPStatus.OK:
|
|
604
603
|
try:
|
|
605
604
|
content = json.loads(user_details_response.content.decode("utf8"))
|
|
606
|
-
except:
|
|
605
|
+
except Exception:
|
|
607
606
|
raise api.FalServerlessError(
|
|
608
607
|
f"Error fetching user details: {user_details_response}"
|
|
609
608
|
)
|
fal/exceptions/__init__.py
CHANGED
|
@@ -1,8 +1,8 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
from ._base import FalServerlessException # noqa: F401
|
|
3
4
|
from .handlers import (
|
|
4
5
|
BaseExceptionHandler,
|
|
5
|
-
FalServerlessExceptionHandler,
|
|
6
6
|
GrpcExceptionHandler,
|
|
7
7
|
UserFunctionExceptionHandler,
|
|
8
8
|
)
|
|
@@ -20,7 +20,6 @@ class ApplicationExceptionHandler:
|
|
|
20
20
|
|
|
21
21
|
_handlers: list[BaseExceptionHandler] = [
|
|
22
22
|
GrpcExceptionHandler(),
|
|
23
|
-
FalServerlessExceptionHandler(),
|
|
24
23
|
UserFunctionExceptionHandler(),
|
|
25
24
|
]
|
|
26
25
|
|
fal/exceptions/_base.py
CHANGED
|
@@ -3,15 +3,4 @@ from __future__ import annotations
|
|
|
3
3
|
|
|
4
4
|
class FalServerlessException(Exception):
|
|
5
5
|
"""Base exception type for fal Serverless related flows and APIs."""
|
|
6
|
-
|
|
7
|
-
message: str
|
|
8
|
-
|
|
9
|
-
hint: str | None
|
|
10
|
-
|
|
11
|
-
def __init__(self, message: str, hint: str | None = None) -> None:
|
|
12
|
-
self.message = message
|
|
13
|
-
self.hint = hint
|
|
14
|
-
super().__init__(message)
|
|
15
|
-
|
|
16
|
-
def __str__(self) -> str:
|
|
17
|
-
return self.message + (f"\nHint: {self.hint}" if self.hint else "")
|
|
6
|
+
pass
|
fal/exceptions/auth.py
CHANGED
|
@@ -4,10 +4,8 @@ from ._base import FalServerlessException
|
|
|
4
4
|
|
|
5
5
|
|
|
6
6
|
class UnauthenticatedException(FalServerlessException):
|
|
7
|
-
"""Exception that indicates that"""
|
|
8
|
-
|
|
9
7
|
def __init__(self) -> None:
|
|
10
8
|
super().__init__(
|
|
11
|
-
|
|
12
|
-
|
|
9
|
+
"You must be authenticated. "
|
|
10
|
+
"Login via `fal auth login` or make sure to setup fal keys correctly."
|
|
13
11
|
)
|
fal/exceptions/handlers.py
CHANGED
|
@@ -3,7 +3,6 @@ from __future__ import annotations
|
|
|
3
3
|
from typing import TYPE_CHECKING, Generic, TypeVar
|
|
4
4
|
|
|
5
5
|
from grpc import Call as RpcCall
|
|
6
|
-
from rich.markdown import Markdown
|
|
7
6
|
|
|
8
7
|
from fal.console import console
|
|
9
8
|
from fal.console.icons import CROSS_ICON
|
|
@@ -11,9 +10,8 @@ from fal.console.icons import CROSS_ICON
|
|
|
11
10
|
if TYPE_CHECKING:
|
|
12
11
|
from fal.api import UserFunctionException
|
|
13
12
|
|
|
14
|
-
from ._base import FalServerlessException
|
|
15
13
|
|
|
16
|
-
ExceptionType = TypeVar("ExceptionType")
|
|
14
|
+
ExceptionType = TypeVar("ExceptionType", bound=BaseException)
|
|
17
15
|
|
|
18
16
|
|
|
19
17
|
class BaseExceptionHandler(Generic[ExceptionType]):
|
|
@@ -23,20 +21,11 @@ class BaseExceptionHandler(Generic[ExceptionType]):
|
|
|
23
21
|
return True
|
|
24
22
|
|
|
25
23
|
def handle(self, exception: ExceptionType):
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
def should_handle(self, exception: Exception) -> bool:
|
|
33
|
-
return isinstance(exception, FalServerlessException)
|
|
34
|
-
|
|
35
|
-
def handle(self, exception: FalServerlessException):
|
|
36
|
-
console.print(f"{CROSS_ICON} {exception.message}")
|
|
37
|
-
if exception.hint is not None:
|
|
38
|
-
console.print(Markdown(f"**Hint:** {exception.hint}"))
|
|
39
|
-
console.print()
|
|
24
|
+
msg = f"{CROSS_ICON} {str(exception)}"
|
|
25
|
+
cause = exception.__cause__
|
|
26
|
+
if cause is not None:
|
|
27
|
+
msg += f": {str(cause)}"
|
|
28
|
+
console.print(msg)
|
|
40
29
|
|
|
41
30
|
|
|
42
31
|
class GrpcExceptionHandler(BaseExceptionHandler[RpcCall]):
|
|
@@ -51,9 +40,9 @@ class GrpcExceptionHandler(BaseExceptionHandler[RpcCall]):
|
|
|
51
40
|
|
|
52
41
|
class UserFunctionExceptionHandler(BaseExceptionHandler["UserFunctionException"]):
|
|
53
42
|
def should_handle(self, exception: Exception) -> bool:
|
|
54
|
-
from fal.api import UserFunctionException
|
|
43
|
+
from fal.api import UserFunctionException
|
|
55
44
|
|
|
56
|
-
return
|
|
45
|
+
return isinstance(exception, UserFunctionException)
|
|
57
46
|
|
|
58
47
|
def handle(self, exception: UserFunctionException):
|
|
59
48
|
import rich
|
fal/logging/isolate.py
CHANGED
|
@@ -1,8 +1,9 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import sys
|
|
3
4
|
from datetime import datetime, timezone
|
|
4
5
|
|
|
5
|
-
from isolate.logs import Log, LogLevel
|
|
6
|
+
from isolate.logs import Log, LogLevel, LogSource
|
|
6
7
|
from structlog.dev import ConsoleRenderer
|
|
7
8
|
from structlog.typing import EventDict
|
|
8
9
|
|
|
@@ -21,6 +22,12 @@ class IsolateLogPrinter:
|
|
|
21
22
|
def print(self, log: Log):
|
|
22
23
|
if log.level < LogLevel.INFO and not self.debug:
|
|
23
24
|
return
|
|
25
|
+
|
|
26
|
+
if log.source == LogSource.USER:
|
|
27
|
+
stream = sys.stderr if log.level == LogLevel.STDERR else sys.stdout
|
|
28
|
+
print(log.message, file=stream)
|
|
29
|
+
return
|
|
30
|
+
|
|
24
31
|
level = str(log.level)
|
|
25
32
|
|
|
26
33
|
if hasattr(log, "timestamp"):
|
|
@@ -44,21 +51,3 @@ class IsolateLogPrinter:
|
|
|
44
51
|
# Use structlog processors to get consistent output with local logs
|
|
45
52
|
message = _renderer.__call__(logger={}, name=level, event_dict=event)
|
|
46
53
|
print(message)
|
|
47
|
-
|
|
48
|
-
def print_dict(self, log: dict):
|
|
49
|
-
level = LogLevel[log["level"]]
|
|
50
|
-
if level < LogLevel.INFO and not self.debug:
|
|
51
|
-
return
|
|
52
|
-
if "timestamp" in log.keys():
|
|
53
|
-
timestamp = log["timestamp"]
|
|
54
|
-
else:
|
|
55
|
-
timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f")
|
|
56
|
-
|
|
57
|
-
event: EventDict = {
|
|
58
|
-
"event": log["message"],
|
|
59
|
-
"level": log["level"],
|
|
60
|
-
"timestamp": timestamp[:-3],
|
|
61
|
-
}
|
|
62
|
-
|
|
63
|
-
message = _renderer.__call__(logger={}, name=log["level"], event_dict=event)
|
|
64
|
-
print(message)
|
fal/logging/user.py
CHANGED
|
@@ -12,7 +12,7 @@ def add_user_id(
|
|
|
12
12
|
user_id: str | None = None
|
|
13
13
|
try:
|
|
14
14
|
user_id = USER.info.get("sub")
|
|
15
|
-
except:
|
|
15
|
+
except Exception:
|
|
16
16
|
# logs are fail-safe, so any exception is safe to ignore
|
|
17
17
|
# this is expected to happen only when user is logged out
|
|
18
18
|
# or there's no internet connection
|
fal/sdk.py
CHANGED
|
@@ -14,7 +14,7 @@ from isolate.server.interface import from_grpc, to_serialized_object, to_struct
|
|
|
14
14
|
|
|
15
15
|
import isolate_proto
|
|
16
16
|
from fal import flags
|
|
17
|
-
from fal._serialization import
|
|
17
|
+
from fal._serialization import patch_pickle
|
|
18
18
|
from fal.auth import USER, key_credentials
|
|
19
19
|
from fal.logging import get_logger
|
|
20
20
|
from fal.logging.trace import TraceContextInterceptor
|
|
@@ -24,14 +24,14 @@ ResultT = TypeVar("ResultT")
|
|
|
24
24
|
InputT = TypeVar("InputT")
|
|
25
25
|
UNSET = object()
|
|
26
26
|
|
|
27
|
-
_DEFAULT_SERIALIZATION_METHOD = "
|
|
27
|
+
_DEFAULT_SERIALIZATION_METHOD = "cloudpickle"
|
|
28
28
|
FAL_SERVERLESS_DEFAULT_KEEP_ALIVE = 10
|
|
29
29
|
FAL_SERVERLESS_DEFAULT_MAX_MULTIPLEXING = 1
|
|
30
30
|
FAL_SERVERLESS_DEFAULT_MIN_CONCURRENCY = 0
|
|
31
31
|
|
|
32
32
|
logger = get_logger(__name__)
|
|
33
33
|
|
|
34
|
-
|
|
34
|
+
patch_pickle()
|
|
35
35
|
|
|
36
36
|
|
|
37
37
|
class ServerCredentials:
|
fal/toolkit/__init__.py
CHANGED
|
@@ -2,7 +2,6 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
from fal.toolkit.file import CompressedFile, File
|
|
4
4
|
from fal.toolkit.image.image import Image, ImageSizeInput, get_image_size
|
|
5
|
-
from fal.toolkit.mainify import mainify
|
|
6
5
|
from fal.toolkit.optimize import optimize
|
|
7
6
|
from fal.toolkit.utils import (
|
|
8
7
|
FAL_MODEL_WEIGHTS_DIR,
|
|
@@ -19,7 +18,6 @@ __all__ = [
|
|
|
19
18
|
"Image",
|
|
20
19
|
"ImageSizeInput",
|
|
21
20
|
"get_image_size",
|
|
22
|
-
"mainify",
|
|
23
21
|
"optimize",
|
|
24
22
|
"FAL_MODEL_WEIGHTS_DIR",
|
|
25
23
|
"FAL_PERSISTENT_DIR",
|
fal/toolkit/exceptions.py
CHANGED
|
@@ -1,14 +1,9 @@
|
|
|
1
|
-
from fal.toolkit.mainify import mainify
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
@mainify
|
|
5
1
|
class FalTookitException(Exception):
|
|
6
2
|
"""Base exception for all toolkit exceptions"""
|
|
7
3
|
|
|
8
4
|
pass
|
|
9
5
|
|
|
10
6
|
|
|
11
|
-
@mainify
|
|
12
7
|
class FileUploadException(FalTookitException):
|
|
13
8
|
"""Raised when file upload fails"""
|
|
14
9
|
|
fal/toolkit/file/__init__.py
CHANGED
fal/toolkit/file/file.py
CHANGED
|
@@ -1,13 +1,23 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import shutil
|
|
3
4
|
from pathlib import Path
|
|
4
|
-
from tempfile import NamedTemporaryFile,
|
|
5
|
-
from typing import
|
|
5
|
+
from tempfile import NamedTemporaryFile, mkdtemp
|
|
6
|
+
from typing import Callable, Optional, Any
|
|
6
7
|
from urllib.parse import urlparse
|
|
7
8
|
from zipfile import ZipFile
|
|
8
9
|
|
|
9
|
-
|
|
10
|
-
|
|
10
|
+
import pydantic
|
|
11
|
+
|
|
12
|
+
# https://github.com/pydantic/pydantic/pull/2573
|
|
13
|
+
if not hasattr(pydantic, "__version__") or pydantic.__version__.startswith("1."):
|
|
14
|
+
IS_PYDANTIC_V2 = False
|
|
15
|
+
else:
|
|
16
|
+
from pydantic_core import core_schema, CoreSchema
|
|
17
|
+
from pydantic import GetCoreSchemaHandler
|
|
18
|
+
IS_PYDANTIC_V2 = True
|
|
19
|
+
|
|
20
|
+
from pydantic import BaseModel, Field
|
|
11
21
|
|
|
12
22
|
from fal.toolkit.file.providers.fal import (
|
|
13
23
|
FalCDNFileRepository,
|
|
@@ -17,7 +27,6 @@ from fal.toolkit.file.providers.fal import (
|
|
|
17
27
|
from fal.toolkit.file.providers.gcp import GoogleStorageRepository
|
|
18
28
|
from fal.toolkit.file.providers.r2 import R2Repository
|
|
19
29
|
from fal.toolkit.file.types import FileData, FileRepository, RepositoryId
|
|
20
|
-
from fal.toolkit.mainify import mainify
|
|
21
30
|
from fal.toolkit.utils.download_utils import download_file
|
|
22
31
|
|
|
23
32
|
FileRepositoryFactory = Callable[[], FileRepository]
|
|
@@ -42,59 +51,47 @@ get_builtin_repository.__module__ = "__main__"
|
|
|
42
51
|
DEFAULT_REPOSITORY: FileRepository | RepositoryId = "fal"
|
|
43
52
|
|
|
44
53
|
|
|
45
|
-
@mainify
|
|
46
54
|
class File(BaseModel):
|
|
47
55
|
# public properties
|
|
48
|
-
_file_data: FileData = PrivateAttr()
|
|
49
56
|
url: str = Field(
|
|
50
57
|
description="The URL where the file can be downloaded from.",
|
|
51
58
|
)
|
|
52
59
|
content_type: Optional[str] = Field(
|
|
53
|
-
description="The mime type of the file.",
|
|
60
|
+
None, description="The mime type of the file.",
|
|
54
61
|
examples=["image/png"],
|
|
55
62
|
)
|
|
56
63
|
file_name: Optional[str] = Field(
|
|
57
|
-
description="The name of the file. It will be auto-generated if not provided.",
|
|
64
|
+
None, description="The name of the file. It will be auto-generated if not provided.",
|
|
58
65
|
examples=["z9RV14K95DvU.png"],
|
|
59
66
|
)
|
|
60
67
|
file_size: Optional[int] = Field(
|
|
61
|
-
description="The size of the file in bytes.", examples=[4404019]
|
|
68
|
+
None, description="The size of the file in bytes.", examples=[4404019]
|
|
69
|
+
)
|
|
70
|
+
file_data: Optional[bytes] = Field(
|
|
71
|
+
None, description="File data", exclude=True, repr=False,
|
|
62
72
|
)
|
|
63
|
-
|
|
64
|
-
def __init__(self, **kwargs):
|
|
65
|
-
if "file_data" in kwargs:
|
|
66
|
-
data: FileData = kwargs.pop("file_data")
|
|
67
|
-
repository = kwargs.pop("repository", None)
|
|
68
|
-
|
|
69
|
-
repo = (
|
|
70
|
-
repository
|
|
71
|
-
if isinstance(repository, FileRepository)
|
|
72
|
-
else get_builtin_repository(repository)
|
|
73
|
-
)
|
|
74
|
-
self._file_data = data
|
|
75
|
-
|
|
76
|
-
kwargs.update(
|
|
77
|
-
{
|
|
78
|
-
"url": repo.save(data),
|
|
79
|
-
"content_type": data.content_type,
|
|
80
|
-
"file_name": data.file_name,
|
|
81
|
-
"file_size": len(data.data),
|
|
82
|
-
}
|
|
83
|
-
)
|
|
84
|
-
|
|
85
|
-
super().__init__(**kwargs)
|
|
86
73
|
|
|
87
74
|
# Pydantic custom validator for input type conversion
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
75
|
+
if IS_PYDANTIC_V2:
|
|
76
|
+
@classmethod
|
|
77
|
+
def __get_pydantic_core_schema__(
|
|
78
|
+
cls, source_type: Any, handler: GetCoreSchemaHandler
|
|
79
|
+
) -> CoreSchema:
|
|
80
|
+
return core_schema.no_info_before_validator_function(
|
|
81
|
+
cls.__convert_from_str,
|
|
82
|
+
handler(source_type),
|
|
83
|
+
)
|
|
84
|
+
else:
|
|
85
|
+
@classmethod
|
|
86
|
+
def __get_validators__(cls):
|
|
87
|
+
yield cls.__convert_from_str
|
|
91
88
|
|
|
92
89
|
@classmethod
|
|
93
90
|
def __convert_from_str(cls, value: Any):
|
|
94
91
|
if isinstance(value, str):
|
|
95
92
|
parsed_url = urlparse(value)
|
|
96
93
|
if parsed_url.scheme not in ["http", "https", "data"]:
|
|
97
|
-
raise ValueError(
|
|
94
|
+
raise ValueError("value must be a valid URL")
|
|
98
95
|
return cls._from_url(parsed_url.geturl())
|
|
99
96
|
|
|
100
97
|
return value
|
|
@@ -119,9 +116,20 @@ class File(BaseModel):
|
|
|
119
116
|
file_name: Optional[str] = None,
|
|
120
117
|
repository: FileRepository | RepositoryId = DEFAULT_REPOSITORY,
|
|
121
118
|
) -> File:
|
|
119
|
+
repo = (
|
|
120
|
+
repository
|
|
121
|
+
if isinstance(repository, FileRepository)
|
|
122
|
+
else get_builtin_repository(repository)
|
|
123
|
+
)
|
|
124
|
+
|
|
125
|
+
fdata = FileData(data, content_type, file_name)
|
|
126
|
+
|
|
122
127
|
return cls(
|
|
123
|
-
|
|
124
|
-
|
|
128
|
+
url=repo.save(fdata),
|
|
129
|
+
content_type=fdata.content_type,
|
|
130
|
+
file_name=fdata.file_name,
|
|
131
|
+
file_size=len(data),
|
|
132
|
+
file_data=data,
|
|
125
133
|
)
|
|
126
134
|
|
|
127
135
|
@classmethod
|
|
@@ -141,10 +149,10 @@ class File(BaseModel):
|
|
|
141
149
|
)
|
|
142
150
|
|
|
143
151
|
def as_bytes(self) -> bytes:
|
|
144
|
-
if
|
|
152
|
+
if self.file_data is None:
|
|
145
153
|
raise ValueError("File has not been downloaded")
|
|
146
154
|
|
|
147
|
-
return self.
|
|
155
|
+
return self.file_data
|
|
148
156
|
|
|
149
157
|
def save(self, path: str | Path, overwrite: bool = False) -> Path:
|
|
150
158
|
file_path = Path(path).resolve()
|
|
@@ -158,37 +166,32 @@ class File(BaseModel):
|
|
|
158
166
|
return file_path
|
|
159
167
|
|
|
160
168
|
|
|
161
|
-
@mainify
|
|
162
169
|
class CompressedFile(File):
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
def __init__(self, **kwargs):
|
|
166
|
-
super().__init__(**kwargs)
|
|
167
|
-
self._extract_dir = None
|
|
170
|
+
extract_dir: Optional[str] = Field(default=None, exclude=True, repr=False)
|
|
168
171
|
|
|
169
172
|
def __iter__(self):
|
|
170
|
-
if not self.
|
|
173
|
+
if not self.extract_dir:
|
|
171
174
|
self._extract_files()
|
|
172
175
|
|
|
173
|
-
files = Path(self.
|
|
176
|
+
files = Path(self.extract_dir).iterdir() # type: ignore
|
|
174
177
|
return iter(files)
|
|
175
178
|
|
|
176
179
|
def _extract_files(self):
|
|
177
|
-
self.
|
|
180
|
+
self.extract_dir = mkdtemp()
|
|
178
181
|
|
|
179
182
|
with NamedTemporaryFile() as temp_file:
|
|
180
183
|
file_path = temp_file.name
|
|
181
184
|
self.save(file_path, overwrite=True)
|
|
182
185
|
|
|
183
186
|
with ZipFile(file_path) as zip_file:
|
|
184
|
-
zip_file.extractall(self.
|
|
187
|
+
zip_file.extractall(self.extract_dir)
|
|
185
188
|
|
|
186
189
|
def glob(self, pattern: str):
|
|
187
|
-
if not self.
|
|
190
|
+
if not self.extract_dir:
|
|
188
191
|
self._extract_files()
|
|
189
192
|
|
|
190
|
-
return Path(self.
|
|
193
|
+
return Path(self.extract_dir).glob(pattern) # type: ignore
|
|
191
194
|
|
|
192
195
|
def __del__(self):
|
|
193
|
-
if self.
|
|
194
|
-
self.
|
|
196
|
+
if self.extract_dir:
|
|
197
|
+
shutil.rmtree(self.extract_dir)
|
fal/toolkit/file/providers/fal.py
CHANGED
|
@@ -10,12 +10,10 @@ from urllib.request import Request, urlopen
|
|
|
10
10
|
from fal.auth import key_credentials
|
|
11
11
|
from fal.toolkit.exceptions import FileUploadException
|
|
12
12
|
from fal.toolkit.file.types import FileData, FileRepository
|
|
13
|
-
from fal.toolkit.mainify import mainify
|
|
14
13
|
|
|
15
|
-
_FAL_CDN = "https://fal
|
|
14
|
+
_FAL_CDN = "https://fal.media"
|
|
16
15
|
|
|
17
16
|
|
|
18
|
-
@mainify
|
|
19
17
|
@dataclass
|
|
20
18
|
class FalFileRepository(FileRepository):
|
|
21
19
|
def save(self, file: FileData) -> str:
|
|
@@ -27,7 +25,7 @@ class FalFileRepository(FileRepository):
|
|
|
27
25
|
headers = {
|
|
28
26
|
"Authorization": f"Key {key_id}:{key_secret}",
|
|
29
27
|
"Accept": "application/json",
|
|
30
|
-
"Content-Type":
|
|
28
|
+
"Content-Type": "application/json",
|
|
31
29
|
}
|
|
32
30
|
|
|
33
31
|
grpc_host = os.environ.get("FAL_HOST", "api.alpha.fal.ai")
|
|
@@ -70,14 +68,12 @@ class FalFileRepository(FileRepository):
|
|
|
70
68
|
return
|
|
71
69
|
|
|
72
70
|
|
|
73
|
-
@mainify
|
|
74
71
|
@dataclass
|
|
75
72
|
class InMemoryRepository(FileRepository):
|
|
76
73
|
def save(self, file: FileData) -> str:
|
|
77
74
|
return f'data:{file.content_type};base64,{b64encode(file.data).decode("utf-8")}'
|
|
78
75
|
|
|
79
76
|
|
|
80
|
-
@mainify
|
|
81
77
|
@dataclass
|
|
82
78
|
class FalCDNFileRepository(FileRepository):
|
|
83
79
|
def save(self, file: FileData) -> str:
|
fal/toolkit/file/providers/gcp.py
CHANGED
|
@@ -6,12 +6,10 @@ import os
|
|
|
6
6
|
from dataclasses import dataclass
|
|
7
7
|
|
|
8
8
|
from fal.toolkit.file.types import FileData, FileRepository
|
|
9
|
-
from fal.toolkit.mainify import mainify
|
|
10
9
|
|
|
11
10
|
DEFAULT_URL_TIMEOUT = 60 * 15 # 15 minutes
|
|
12
11
|
|
|
13
12
|
|
|
14
|
-
@mainify
|
|
15
13
|
@dataclass
|
|
16
14
|
class GoogleStorageRepository(FileRepository):
|
|
17
15
|
bucket_name: str = "fal_file_storage"
|
fal/toolkit/file/providers/r2.py
CHANGED
|
@@ -6,12 +6,10 @@ from dataclasses import dataclass
|
|
|
6
6
|
from io import BytesIO
|
|
7
7
|
|
|
8
8
|
from fal.toolkit.file.types import FileData, FileRepository
|
|
9
|
-
from fal.toolkit.mainify import mainify
|
|
10
9
|
|
|
11
10
|
DEFAULT_URL_TIMEOUT = 60 * 15 # 15 minutes
|
|
12
11
|
|
|
13
12
|
|
|
14
|
-
@mainify
|
|
15
13
|
@dataclass
|
|
16
14
|
class R2Repository(FileRepository):
|
|
17
15
|
bucket_name: str = "fal_file_storage"
|
fal/toolkit/file/types.py
CHANGED
|
@@ -5,10 +5,7 @@ from mimetypes import guess_extension, guess_type
|
|
|
5
5
|
from typing import Literal
|
|
6
6
|
from uuid import uuid4
|
|
7
7
|
|
|
8
|
-
from fal.toolkit.mainify import mainify
|
|
9
8
|
|
|
10
|
-
|
|
11
|
-
@mainify
|
|
12
9
|
class FileData:
|
|
13
10
|
data: bytes
|
|
14
11
|
content_type: str
|
|
@@ -34,7 +31,6 @@ class FileData:
|
|
|
34
31
|
RepositoryId = Literal["fal", "in_memory", "gcp_storage", "r2", "cdn"]
|
|
35
32
|
|
|
36
33
|
|
|
37
|
-
@mainify
|
|
38
34
|
@dataclass
|
|
39
35
|
class FileRepository:
|
|
40
36
|
def save(self, data: FileData) -> str:
|
fal/toolkit/image/__init__.py
CHANGED