fal 1.2.1__py3-none-any.whl → 1.7.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of fal might be problematic. Click here for more details.
- fal/__main__.py +3 -1
- fal/_fal_version.py +2 -2
- fal/api.py +88 -20
- fal/app.py +221 -27
- fal/apps.py +147 -3
- fal/auth/__init__.py +50 -2
- fal/cli/_utils.py +40 -0
- fal/cli/apps.py +5 -3
- fal/cli/create.py +26 -0
- fal/cli/deploy.py +97 -16
- fal/cli/main.py +2 -2
- fal/cli/parser.py +11 -7
- fal/cli/run.py +12 -1
- fal/cli/runners.py +44 -0
- fal/config.py +23 -0
- fal/container.py +1 -1
- fal/exceptions/__init__.py +7 -1
- fal/exceptions/_base.py +51 -0
- fal/exceptions/_cuda.py +44 -0
- fal/files.py +81 -0
- fal/sdk.py +67 -6
- fal/toolkit/file/file.py +103 -13
- fal/toolkit/file/providers/fal.py +572 -24
- fal/toolkit/file/providers/gcp.py +8 -1
- fal/toolkit/file/providers/r2.py +8 -1
- fal/toolkit/file/providers/s3.py +80 -0
- fal/toolkit/file/types.py +28 -3
- fal/toolkit/image/__init__.py +71 -0
- fal/toolkit/image/image.py +25 -2
- fal/toolkit/image/nsfw_filter/__init__.py +11 -0
- fal/toolkit/image/nsfw_filter/env.py +9 -0
- fal/toolkit/image/nsfw_filter/inference.py +77 -0
- fal/toolkit/image/nsfw_filter/model.py +18 -0
- fal/toolkit/image/nsfw_filter/requirements.txt +4 -0
- fal/toolkit/image/safety_checker.py +107 -0
- fal/toolkit/types.py +140 -0
- fal/toolkit/utils/download_utils.py +4 -0
- fal/toolkit/utils/retry.py +45 -0
- fal/utils.py +20 -4
- fal/workflows.py +10 -4
- {fal-1.2.1.dist-info → fal-1.7.2.dist-info}/METADATA +47 -40
- {fal-1.2.1.dist-info → fal-1.7.2.dist-info}/RECORD +45 -30
- {fal-1.2.1.dist-info → fal-1.7.2.dist-info}/WHEEL +1 -1
- {fal-1.2.1.dist-info → fal-1.7.2.dist-info}/entry_points.txt +0 -0
- {fal-1.2.1.dist-info → fal-1.7.2.dist-info}/top_level.txt +0 -0
fal/toolkit/file/providers/r2.py
CHANGED
|
@@ -6,8 +6,10 @@ import posixpath
|
|
|
6
6
|
import uuid
|
|
7
7
|
from dataclasses import dataclass
|
|
8
8
|
from io import BytesIO
|
|
9
|
+
from typing import Optional
|
|
9
10
|
|
|
10
11
|
from fal.toolkit.file.types import FileData, FileRepository
|
|
12
|
+
from fal.toolkit.utils.retry import retry
|
|
11
13
|
|
|
12
14
|
DEFAULT_URL_TIMEOUT = 60 * 15 # 15 minutes
|
|
13
15
|
|
|
@@ -67,7 +69,12 @@ class R2Repository(FileRepository):
|
|
|
67
69
|
|
|
68
70
|
return self._bucket
|
|
69
71
|
|
|
70
|
-
|
|
72
|
+
@retry(max_retries=3, base_delay=1, backoff_type="exponential", jitter=True)
|
|
73
|
+
def save(
|
|
74
|
+
self,
|
|
75
|
+
data: FileData,
|
|
76
|
+
object_lifecycle_preference: Optional[dict[str, str]] = None,
|
|
77
|
+
) -> str:
|
|
71
78
|
destination_path = posixpath.join(
|
|
72
79
|
self.key,
|
|
73
80
|
f"{uuid.uuid4().hex}_{data.file_name}",
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import posixpath
|
|
5
|
+
import uuid
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from io import BytesIO
|
|
8
|
+
from typing import Optional
|
|
9
|
+
|
|
10
|
+
from fal.toolkit.file.types import FileData, FileRepository
|
|
11
|
+
from fal.toolkit.utils.retry import retry
|
|
12
|
+
|
|
13
|
+
DEFAULT_URL_TIMEOUT = 60 * 15 # 15 minutes
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclass
|
|
17
|
+
class S3Repository(FileRepository):
|
|
18
|
+
bucket_name: str = "fal_file_storage"
|
|
19
|
+
url_expiration: int = DEFAULT_URL_TIMEOUT
|
|
20
|
+
aws_access_key_id: str | None = None
|
|
21
|
+
aws_secret_access_key: str | None = None
|
|
22
|
+
|
|
23
|
+
_s3_client = None
|
|
24
|
+
|
|
25
|
+
def __post_init__(self):
|
|
26
|
+
try:
|
|
27
|
+
import boto3
|
|
28
|
+
from botocore.client import Config
|
|
29
|
+
except ImportError:
|
|
30
|
+
raise Exception("boto3 is not installed")
|
|
31
|
+
|
|
32
|
+
if self.aws_access_key_id is None:
|
|
33
|
+
self.aws_access_key_id = os.environ.get("AWS_ACCESS_KEY_ID")
|
|
34
|
+
if self.aws_access_key_id is None:
|
|
35
|
+
raise Exception("AWS_ACCESS_KEY_ID environment variable is not set")
|
|
36
|
+
|
|
37
|
+
if self.aws_secret_access_key is None:
|
|
38
|
+
self.aws_secret_access_key = os.environ.get("AWS_SECRET_ACCESS_KEY")
|
|
39
|
+
if self.aws_secret_access_key is None:
|
|
40
|
+
raise Exception("AWS_SECRET_ACCESS_KEY environment variable is not set")
|
|
41
|
+
|
|
42
|
+
self._s3_client = boto3.client(
|
|
43
|
+
"s3",
|
|
44
|
+
aws_access_key_id=self.aws_access_key_id,
|
|
45
|
+
aws_secret_access_key=self.aws_secret_access_key,
|
|
46
|
+
config=Config(signature_version="s3v4"),
|
|
47
|
+
)
|
|
48
|
+
|
|
49
|
+
@property
|
|
50
|
+
def storage_client(self):
|
|
51
|
+
if self._s3_client is None:
|
|
52
|
+
raise Exception("S3 client is not initialized")
|
|
53
|
+
|
|
54
|
+
return self._s3_client
|
|
55
|
+
|
|
56
|
+
@retry(max_retries=3, base_delay=1, backoff_type="exponential", jitter=True)
|
|
57
|
+
def save(
|
|
58
|
+
self,
|
|
59
|
+
data: FileData,
|
|
60
|
+
object_lifecycle_preference: Optional[dict[str, str]] = None,
|
|
61
|
+
key: Optional[str] = None,
|
|
62
|
+
) -> str:
|
|
63
|
+
destination_path = posixpath.join(
|
|
64
|
+
key or "",
|
|
65
|
+
f"{uuid.uuid4().hex}_{data.file_name}",
|
|
66
|
+
)
|
|
67
|
+
|
|
68
|
+
self.storage_client.upload_fileobj(
|
|
69
|
+
BytesIO(data.data),
|
|
70
|
+
self.bucket_name,
|
|
71
|
+
destination_path,
|
|
72
|
+
ExtraArgs={"ContentType": data.content_type},
|
|
73
|
+
)
|
|
74
|
+
|
|
75
|
+
public_url = self.storage_client.generate_presigned_url(
|
|
76
|
+
ClientMethod="get_object",
|
|
77
|
+
Params={"Bucket": self.bucket_name, "Key": destination_path},
|
|
78
|
+
ExpiresIn=self.url_expiration,
|
|
79
|
+
)
|
|
80
|
+
return public_url
|
fal/toolkit/file/types.py
CHANGED
|
@@ -2,7 +2,8 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
from dataclasses import dataclass
|
|
4
4
|
from mimetypes import guess_extension, guess_type
|
|
5
|
-
from
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Literal, Optional
|
|
6
7
|
from uuid import uuid4
|
|
7
8
|
|
|
8
9
|
|
|
@@ -28,10 +29,34 @@ class FileData:
|
|
|
28
29
|
self.file_name = file_name
|
|
29
30
|
|
|
30
31
|
|
|
31
|
-
RepositoryId = Literal[
|
|
32
|
+
RepositoryId = Literal[
|
|
33
|
+
"fal", "fal_v2", "fal_v3", "in_memory", "gcp_storage", "r2", "cdn"
|
|
34
|
+
]
|
|
32
35
|
|
|
33
36
|
|
|
34
37
|
@dataclass
|
|
35
38
|
class FileRepository:
|
|
36
|
-
def save(
|
|
39
|
+
def save(
|
|
40
|
+
self,
|
|
41
|
+
data: FileData,
|
|
42
|
+
object_lifecycle_preference: Optional[dict[str, str]] = None,
|
|
43
|
+
) -> str:
|
|
37
44
|
raise NotImplementedError()
|
|
45
|
+
|
|
46
|
+
def save_file(
|
|
47
|
+
self,
|
|
48
|
+
file_path: str | Path,
|
|
49
|
+
content_type: str,
|
|
50
|
+
multipart: bool | None = None,
|
|
51
|
+
multipart_threshold: int | None = None,
|
|
52
|
+
multipart_chunk_size: int | None = None,
|
|
53
|
+
multipart_max_concurrency: int | None = None,
|
|
54
|
+
object_lifecycle_preference: Optional[dict[str, str]] = None,
|
|
55
|
+
) -> tuple[str, FileData | None]:
|
|
56
|
+
if multipart:
|
|
57
|
+
raise NotImplementedError()
|
|
58
|
+
|
|
59
|
+
with open(file_path, "rb") as fobj:
|
|
60
|
+
data = FileData(fobj.read(), content_type, Path(file_path).name)
|
|
61
|
+
|
|
62
|
+
return self.save(data, object_lifecycle_preference), data
|
fal/toolkit/image/__init__.py
CHANGED
|
@@ -1,3 +1,74 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import urllib.request
|
|
4
|
+
from functools import lru_cache
|
|
5
|
+
from typing import TYPE_CHECKING
|
|
6
|
+
|
|
3
7
|
from .image import * # noqa: F403
|
|
8
|
+
|
|
9
|
+
if TYPE_CHECKING:
|
|
10
|
+
from PIL.Image import Image as PILImage
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def filter_by(
|
|
14
|
+
has_nsfw_concepts: list[bool],
|
|
15
|
+
images: list[PILImage],
|
|
16
|
+
) -> list[PILImage]:
|
|
17
|
+
from PIL import Image as PILImage
|
|
18
|
+
|
|
19
|
+
return [
|
|
20
|
+
(
|
|
21
|
+
PILImage.new("RGB", (image.width, image.height), (0, 0, 0))
|
|
22
|
+
if has_nsfw
|
|
23
|
+
else image
|
|
24
|
+
)
|
|
25
|
+
for image, has_nsfw in zip(images, has_nsfw_concepts)
|
|
26
|
+
]
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def preprocess_image(image_pil, convert_to_rgb=True, fix_orientation=True):
|
|
30
|
+
from PIL import ImageOps, ImageSequence
|
|
31
|
+
|
|
32
|
+
# For MPO (multi picture object) format images, we only need the first image
|
|
33
|
+
images = []
|
|
34
|
+
for image in ImageSequence.Iterator(image_pil):
|
|
35
|
+
img = image
|
|
36
|
+
|
|
37
|
+
if convert_to_rgb:
|
|
38
|
+
img = img.convert("RGB")
|
|
39
|
+
|
|
40
|
+
if fix_orientation:
|
|
41
|
+
img = ImageOps.exif_transpose(img)
|
|
42
|
+
|
|
43
|
+
images.append(img)
|
|
44
|
+
|
|
45
|
+
break
|
|
46
|
+
|
|
47
|
+
return images[0]
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
@lru_cache(maxsize=64)
|
|
51
|
+
def read_image_from_url(
|
|
52
|
+
url: str, convert_to_rgb: bool = True, fix_orientation: bool = True
|
|
53
|
+
):
|
|
54
|
+
from fastapi import HTTPException
|
|
55
|
+
from PIL import Image
|
|
56
|
+
|
|
57
|
+
TEMP_HEADERS = {
|
|
58
|
+
"User-Agent": (
|
|
59
|
+
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10.8; rv:21.0) "
|
|
60
|
+
"Gecko/20100101 Firefox/21.0"
|
|
61
|
+
),
|
|
62
|
+
}
|
|
63
|
+
|
|
64
|
+
try:
|
|
65
|
+
request = urllib.request.Request(url, headers=TEMP_HEADERS)
|
|
66
|
+
response = urllib.request.urlopen(request)
|
|
67
|
+
image_pil = Image.open(response)
|
|
68
|
+
except Exception:
|
|
69
|
+
import traceback
|
|
70
|
+
|
|
71
|
+
traceback.print_exc()
|
|
72
|
+
raise HTTPException(422, f"Could not load image from url: {url}")
|
|
73
|
+
|
|
74
|
+
return preprocess_image(image_pil, convert_to_rgb, fix_orientation)
|
fal/toolkit/image/image.py
CHANGED
|
@@ -4,9 +4,10 @@ import io
|
|
|
4
4
|
from tempfile import NamedTemporaryFile
|
|
5
5
|
from typing import TYPE_CHECKING, Literal, Optional, Union
|
|
6
6
|
|
|
7
|
+
from fastapi import Request
|
|
7
8
|
from pydantic import BaseModel, Field
|
|
8
9
|
|
|
9
|
-
from fal.toolkit.file.file import DEFAULT_REPOSITORY, File
|
|
10
|
+
from fal.toolkit.file.file import DEFAULT_REPOSITORY, FALLBACK_REPOSITORY, File
|
|
10
11
|
from fal.toolkit.file.types import FileRepository, RepositoryId
|
|
11
12
|
from fal.toolkit.utils.download_utils import _download_file_python
|
|
12
13
|
|
|
@@ -79,12 +80,18 @@ class Image(File):
|
|
|
79
80
|
size: ImageSize | None = None,
|
|
80
81
|
file_name: str | None = None,
|
|
81
82
|
repository: FileRepository | RepositoryId = DEFAULT_REPOSITORY,
|
|
83
|
+
fallback_repository: Optional[
|
|
84
|
+
FileRepository | RepositoryId
|
|
85
|
+
] = FALLBACK_REPOSITORY,
|
|
86
|
+
request: Optional[Request] = None,
|
|
82
87
|
) -> Image:
|
|
83
88
|
obj = super().from_bytes(
|
|
84
89
|
data,
|
|
85
90
|
content_type=f"image/{format}",
|
|
86
91
|
file_name=file_name,
|
|
87
92
|
repository=repository,
|
|
93
|
+
fallback_repository=fallback_repository,
|
|
94
|
+
request=request,
|
|
88
95
|
)
|
|
89
96
|
obj.width = size.width if size else None
|
|
90
97
|
obj.height = size.height if size else None
|
|
@@ -97,6 +104,10 @@ class Image(File):
|
|
|
97
104
|
format: ImageFormat | None = None,
|
|
98
105
|
file_name: str | None = None,
|
|
99
106
|
repository: FileRepository | RepositoryId = DEFAULT_REPOSITORY,
|
|
107
|
+
fallback_repository: Optional[
|
|
108
|
+
FileRepository | RepositoryId
|
|
109
|
+
] = FALLBACK_REPOSITORY,
|
|
110
|
+
request: Optional[Request] = None,
|
|
100
111
|
) -> Image:
|
|
101
112
|
size = ImageSize(width=pil_image.width, height=pil_image.height)
|
|
102
113
|
if format is None:
|
|
@@ -110,12 +121,24 @@ class Image(File):
|
|
|
110
121
|
# enough result quickly to utilize the underlying resources
|
|
111
122
|
# efficiently.
|
|
112
123
|
saving_options["compress_level"] = 1
|
|
124
|
+
elif format == "jpeg":
|
|
125
|
+
# JPEG quality is set to 95 by default, which is a good balance
|
|
126
|
+
# between file size and image quality.
|
|
127
|
+
saving_options["quality"] = 95
|
|
113
128
|
|
|
114
129
|
with io.BytesIO() as f:
|
|
115
130
|
pil_image.save(f, format=format, **saving_options)
|
|
116
131
|
raw_image = f.getvalue()
|
|
117
132
|
|
|
118
|
-
return cls.from_bytes(
|
|
133
|
+
return cls.from_bytes(
|
|
134
|
+
raw_image,
|
|
135
|
+
format,
|
|
136
|
+
size,
|
|
137
|
+
file_name,
|
|
138
|
+
repository,
|
|
139
|
+
fallback_repository=fallback_repository,
|
|
140
|
+
request=request,
|
|
141
|
+
)
|
|
119
142
|
|
|
120
143
|
def to_pil(self, mode: str = "RGB") -> PILImage.Image:
|
|
121
144
|
try:
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
from pydantic import BaseModel, Field
|
|
2
|
+
|
|
3
|
+
import fal
|
|
4
|
+
from fal.toolkit.image import read_image_from_url
|
|
5
|
+
|
|
6
|
+
from .env import get_requirements
|
|
7
|
+
from .model import get_model
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class NSFWImageDetectionInput(BaseModel):
|
|
11
|
+
image_url: str = Field(
|
|
12
|
+
description="Input image url.",
|
|
13
|
+
examples=[
|
|
14
|
+
"https://storage.googleapis.com/falserverless/model_tests/remove_background/elephant.jpg",
|
|
15
|
+
],
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class NSFWImageDetectionOutput(BaseModel):
|
|
20
|
+
nsfw_probability: float = Field(
|
|
21
|
+
description="The probability of the image being NSFW.",
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def check_nsfw_content(pil_image: object):
|
|
26
|
+
import torch
|
|
27
|
+
|
|
28
|
+
model, processor = get_model()
|
|
29
|
+
|
|
30
|
+
with torch.no_grad():
|
|
31
|
+
inputs = processor(images=pil_image, return_tensors="pt")
|
|
32
|
+
outputs = model(**inputs)
|
|
33
|
+
logits = outputs.logits.squeeze() # Remove batch dimension to simplify indexing
|
|
34
|
+
|
|
35
|
+
# Apply softmax to convert logits to probabilities
|
|
36
|
+
probabilities = torch.softmax(logits, dim=0)
|
|
37
|
+
|
|
38
|
+
nsfw_class_index = model.config.label2id.get(
|
|
39
|
+
"nsfw", None
|
|
40
|
+
) # Replace "NSFW" with the exact class name if different
|
|
41
|
+
|
|
42
|
+
# Validate that NSFW class index is found
|
|
43
|
+
if nsfw_class_index is not None:
|
|
44
|
+
nsfw_probability = probabilities[int(nsfw_class_index)].item()
|
|
45
|
+
return nsfw_probability
|
|
46
|
+
else:
|
|
47
|
+
raise ValueError("NSFW class not found in model output.")
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def run_nsfw_estimation(
|
|
51
|
+
input: NSFWImageDetectionInput,
|
|
52
|
+
) -> NSFWImageDetectionOutput:
|
|
53
|
+
img = read_image_from_url(input.image_url)
|
|
54
|
+
nsfw_probability = check_nsfw_content(img)
|
|
55
|
+
|
|
56
|
+
return NSFWImageDetectionOutput(nsfw_probability=nsfw_probability)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
@fal.function(
|
|
60
|
+
requirements=get_requirements(),
|
|
61
|
+
machine_type="GPU-A6000",
|
|
62
|
+
serve=True,
|
|
63
|
+
)
|
|
64
|
+
def run_nsfw_estimation_on_fal(
|
|
65
|
+
input: NSFWImageDetectionInput,
|
|
66
|
+
) -> NSFWImageDetectionOutput:
|
|
67
|
+
return run_nsfw_estimation(input)
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
if __name__ == "__main__":
|
|
71
|
+
local = run_nsfw_estimation_on_fal.on(serve=False)
|
|
72
|
+
result = local(
|
|
73
|
+
NSFWImageDetectionInput(
|
|
74
|
+
image_url="https://storage.googleapis.com/falserverless/model_tests/remove_background/elephant.jpg",
|
|
75
|
+
)
|
|
76
|
+
)
|
|
77
|
+
print(result)
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
import fal
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
@fal.cached
|
|
5
|
+
def get_model():
|
|
6
|
+
import os
|
|
7
|
+
|
|
8
|
+
from transformers import AutoModelForImageClassification, ViTImageProcessor
|
|
9
|
+
|
|
10
|
+
os.environ["TRANSFORMERS_CACHE"] = "/data/models"
|
|
11
|
+
os.environ["HF_HOME"] = "/data/models"
|
|
12
|
+
|
|
13
|
+
model = AutoModelForImageClassification.from_pretrained(
|
|
14
|
+
"Falconsai/nsfw_image_detection"
|
|
15
|
+
)
|
|
16
|
+
processor = ViTImageProcessor.from_pretrained("Falconsai/nsfw_image_detection")
|
|
17
|
+
|
|
18
|
+
return model, processor
|
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
import fal
|
|
4
|
+
|
|
5
|
+
from . import filter_by
|
|
6
|
+
from .nsfw_filter.model import get_model
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@fal.cached
|
|
10
|
+
def load_safety_checker():
|
|
11
|
+
import torch
|
|
12
|
+
from diffusers.pipelines.stable_diffusion.safety_checker import (
|
|
13
|
+
StableDiffusionSafetyChecker,
|
|
14
|
+
)
|
|
15
|
+
from transformers import AutoFeatureExtractor
|
|
16
|
+
|
|
17
|
+
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
|
18
|
+
"CompVis/stable-diffusion-safety-checker",
|
|
19
|
+
torch_dtype="float16",
|
|
20
|
+
)
|
|
21
|
+
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
|
|
22
|
+
"CompVis/stable-diffusion-safety-checker",
|
|
23
|
+
torch_dtype=torch.float16,
|
|
24
|
+
).to("cuda")
|
|
25
|
+
|
|
26
|
+
return feature_extractor, safety_checker
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def run_safety_checker(
|
|
30
|
+
pil_images: list[object],
|
|
31
|
+
) -> list[bool]:
|
|
32
|
+
import numpy as np
|
|
33
|
+
import torch
|
|
34
|
+
|
|
35
|
+
feature_extractor, safety_checker = load_safety_checker()
|
|
36
|
+
|
|
37
|
+
safety_checker_input = feature_extractor(pil_images, return_tensors="pt").to("cuda")
|
|
38
|
+
|
|
39
|
+
np_image = [np.array(val) for val in pil_images]
|
|
40
|
+
|
|
41
|
+
_, has_nsfw_concept = safety_checker(
|
|
42
|
+
images=np_image,
|
|
43
|
+
clip_input=safety_checker_input.pixel_values.to(torch.float16),
|
|
44
|
+
)
|
|
45
|
+
|
|
46
|
+
return has_nsfw_concept
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def run_safety_checker_v2(pil_images: list, nsfw_threshold: float = 0.5) -> list[bool]:
|
|
50
|
+
import torch
|
|
51
|
+
|
|
52
|
+
model, processor = get_model()
|
|
53
|
+
|
|
54
|
+
has_nsfw_concept = []
|
|
55
|
+
|
|
56
|
+
with torch.no_grad():
|
|
57
|
+
for pil_image in pil_images:
|
|
58
|
+
inputs = processor(
|
|
59
|
+
images=pil_image.convert("RGB"),
|
|
60
|
+
return_tensors="pt",
|
|
61
|
+
)
|
|
62
|
+
outputs = model(**inputs)
|
|
63
|
+
logits = (
|
|
64
|
+
outputs.logits.squeeze()
|
|
65
|
+
) # Remove batch dimension to simplify indexing
|
|
66
|
+
|
|
67
|
+
# Apply softmax to convert logits to probabilities
|
|
68
|
+
probabilities = torch.softmax(logits, dim=0)
|
|
69
|
+
|
|
70
|
+
nsfw_class_index = model.config.label2id.get(
|
|
71
|
+
"nsfw", None
|
|
72
|
+
) # Replace "NSFW" with the exact class name if different
|
|
73
|
+
|
|
74
|
+
# Validate that NSFW class index is found
|
|
75
|
+
if nsfw_class_index is not None:
|
|
76
|
+
nsfw_probability = probabilities[int(nsfw_class_index)].item()
|
|
77
|
+
print("NSFW probability:", nsfw_probability)
|
|
78
|
+
has_nsfw_concept.append(nsfw_probability > nsfw_threshold)
|
|
79
|
+
else:
|
|
80
|
+
raise ValueError("NSFW class not found in model output.")
|
|
81
|
+
|
|
82
|
+
return has_nsfw_concept
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def postprocess_images(
|
|
86
|
+
pil_images: list[object],
|
|
87
|
+
enable_safety_checker: bool = True,
|
|
88
|
+
safety_checker_version: int = 2,
|
|
89
|
+
) -> dict[str, Any]:
|
|
90
|
+
outputs: dict[str, list[Any]] = {
|
|
91
|
+
"images": pil_images,
|
|
92
|
+
}
|
|
93
|
+
|
|
94
|
+
if enable_safety_checker:
|
|
95
|
+
safety_checker_fn = (
|
|
96
|
+
run_safety_checker_v2 if safety_checker_version == 2 else run_safety_checker
|
|
97
|
+
)
|
|
98
|
+
outputs["has_nsfw_concepts"] = safety_checker_fn(pil_images) # type: ignore
|
|
99
|
+
else:
|
|
100
|
+
outputs["has_nsfw_concepts"] = [False] * len(pil_images)
|
|
101
|
+
|
|
102
|
+
outputs["images"] = filter_by(
|
|
103
|
+
outputs["has_nsfw_concepts"],
|
|
104
|
+
outputs["images"],
|
|
105
|
+
)
|
|
106
|
+
|
|
107
|
+
return outputs
|
fal/toolkit/types.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
1
|
+
import re
|
|
2
|
+
import tempfile
|
|
3
|
+
from contextlib import contextmanager
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Any, Dict, Generator, Union
|
|
6
|
+
|
|
7
|
+
import pydantic
|
|
8
|
+
from pydantic.utils import update_not_none
|
|
9
|
+
|
|
10
|
+
from fal.toolkit.image import read_image_from_url
|
|
11
|
+
from fal.toolkit.utils.download_utils import download_file
|
|
12
|
+
|
|
13
|
+
# https://github.com/pydantic/pydantic/pull/2573
|
|
14
|
+
if not hasattr(pydantic, "__version__") or pydantic.__version__.startswith("1."):
|
|
15
|
+
IS_PYDANTIC_V2 = False
|
|
16
|
+
else:
|
|
17
|
+
IS_PYDANTIC_V2 = True
|
|
18
|
+
|
|
19
|
+
MAX_DATA_URI_LENGTH = 10 * 1024 * 1024
|
|
20
|
+
MAX_HTTPS_URL_LENGTH = 2048
|
|
21
|
+
|
|
22
|
+
HTTP_URL_REGEX = (
|
|
23
|
+
r"^https:\/\/(?:[a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}(?::\d{1,5})?(?:\/[^\s]*)?$"
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class DownloadFileMixin:
|
|
28
|
+
@contextmanager
|
|
29
|
+
def as_temp_file(self) -> Generator[Path, None, None]:
|
|
30
|
+
with tempfile.TemporaryDirectory() as temp_dir:
|
|
31
|
+
yield download_file(str(self), temp_dir)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class DownloadImageMixin:
|
|
35
|
+
def to_pil(self):
|
|
36
|
+
return read_image_from_url(str(self))
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class DataUri(DownloadFileMixin, str):
|
|
40
|
+
if IS_PYDANTIC_V2:
|
|
41
|
+
|
|
42
|
+
@classmethod
|
|
43
|
+
def __get_pydantic_core_schema__(cls, source_type: Any, handler) -> Any:
|
|
44
|
+
return {
|
|
45
|
+
"type": "str",
|
|
46
|
+
"pattern": "^data:",
|
|
47
|
+
"max_length": MAX_DATA_URI_LENGTH,
|
|
48
|
+
"strip_whitespace": True,
|
|
49
|
+
}
|
|
50
|
+
|
|
51
|
+
def __get_pydantic_json_schema__(cls, core_schema, handler) -> Dict[str, Any]:
|
|
52
|
+
json_schema = handler(core_schema)
|
|
53
|
+
json_schema.update(format="data-uri")
|
|
54
|
+
return json_schema
|
|
55
|
+
else:
|
|
56
|
+
|
|
57
|
+
@classmethod
|
|
58
|
+
def __get_validators__(cls):
|
|
59
|
+
yield cls.validate
|
|
60
|
+
|
|
61
|
+
@classmethod
|
|
62
|
+
def validate(cls, value: Any) -> "DataUri":
|
|
63
|
+
from pydantic.validators import str_validator
|
|
64
|
+
|
|
65
|
+
value = str_validator(value)
|
|
66
|
+
value = value.strip()
|
|
67
|
+
|
|
68
|
+
if not value.startswith("data:"):
|
|
69
|
+
raise ValueError("Data URI must start with 'data:'")
|
|
70
|
+
|
|
71
|
+
if len(value) > MAX_DATA_URI_LENGTH:
|
|
72
|
+
raise ValueError(
|
|
73
|
+
f"Data URI is too long. Max length is {MAX_DATA_URI_LENGTH} bytes."
|
|
74
|
+
)
|
|
75
|
+
|
|
76
|
+
return cls(value)
|
|
77
|
+
|
|
78
|
+
@classmethod
|
|
79
|
+
def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
|
|
80
|
+
update_not_none(field_schema, format="data-uri")
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
class HttpsUrl(DownloadFileMixin, str):
|
|
84
|
+
if IS_PYDANTIC_V2:
|
|
85
|
+
|
|
86
|
+
@classmethod
|
|
87
|
+
def __get_pydantic_core_schema__(cls, source_type: Any, handler) -> Any:
|
|
88
|
+
return {
|
|
89
|
+
"type": "str",
|
|
90
|
+
"pattern": HTTP_URL_REGEX,
|
|
91
|
+
"max_length": MAX_HTTPS_URL_LENGTH,
|
|
92
|
+
"strip_whitespace": True,
|
|
93
|
+
}
|
|
94
|
+
|
|
95
|
+
def __get_pydantic_json_schema__(cls, core_schema, handler) -> Dict[str, Any]:
|
|
96
|
+
json_schema = handler(core_schema)
|
|
97
|
+
json_schema.update(format="https-url")
|
|
98
|
+
return json_schema
|
|
99
|
+
|
|
100
|
+
else:
|
|
101
|
+
|
|
102
|
+
@classmethod
|
|
103
|
+
def __get_validators__(cls):
|
|
104
|
+
yield cls.validate
|
|
105
|
+
|
|
106
|
+
@classmethod
|
|
107
|
+
def validate(cls, value: Any) -> "HttpsUrl":
|
|
108
|
+
from pydantic.validators import str_validator
|
|
109
|
+
|
|
110
|
+
value = str_validator(value)
|
|
111
|
+
value = value.strip()
|
|
112
|
+
|
|
113
|
+
if not re.match(HTTP_URL_REGEX, value):
|
|
114
|
+
raise ValueError(
|
|
115
|
+
"URL must start with 'https://' and follow the correct format."
|
|
116
|
+
)
|
|
117
|
+
|
|
118
|
+
if len(value) > MAX_HTTPS_URL_LENGTH:
|
|
119
|
+
raise ValueError(
|
|
120
|
+
f"HTTPS URL is too long. Max length is "
|
|
121
|
+
f"{MAX_HTTPS_URL_LENGTH} characters."
|
|
122
|
+
)
|
|
123
|
+
|
|
124
|
+
return cls(value)
|
|
125
|
+
|
|
126
|
+
@classmethod
|
|
127
|
+
def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
|
|
128
|
+
update_not_none(field_schema, format="https-url")
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
class ImageHttpsUrl(DownloadImageMixin, HttpsUrl):
|
|
132
|
+
pass
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
class ImageDataUri(DownloadImageMixin, DataUri):
|
|
136
|
+
pass
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
FileInput = Union[HttpsUrl, DataUri]
|
|
140
|
+
ImageInput = Union[ImageHttpsUrl, ImageDataUri]
|
|
@@ -84,6 +84,9 @@ def _get_remote_file_properties(
|
|
|
84
84
|
url_path = parsed_url.path
|
|
85
85
|
file_name = Path(url_path).name or _hash_url(url)
|
|
86
86
|
|
|
87
|
+
# file name can still contain a forward slash if the server returns a relative path
|
|
88
|
+
file_name = Path(file_name).name
|
|
89
|
+
|
|
87
90
|
return file_name, content_length
|
|
88
91
|
|
|
89
92
|
|
|
@@ -159,6 +162,7 @@ def download_file(
|
|
|
159
162
|
try:
|
|
160
163
|
file_name = _get_remote_file_properties(url, request_headers)[0]
|
|
161
164
|
except Exception as e:
|
|
165
|
+
print(f"GOt error: {e}")
|
|
162
166
|
raise DownloadError(f"Failed to get remote file properties for {url}") from e
|
|
163
167
|
|
|
164
168
|
if "/" in file_name:
|