fal 1.2.2__py3-none-any.whl → 1.2.4__py3-none-any.whl

This diff shows the contents of two publicly available versions of this package as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions exactly as they appear in the public registry.

Potentially problematic release.


This version of fal might be problematic; see the registry's advisory page for this release for more details.

fal/_fal_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '1.2.2'
16
- __version_tuple__ = version_tuple = (1, 2, 2)
15
+ __version__ = version = '1.2.4'
16
+ __version_tuple__ = version_tuple = (1, 2, 4)
fal/app.py CHANGED
@@ -3,7 +3,9 @@ from __future__ import annotations
3
3
  import inspect
4
4
  import json
5
5
  import os
6
+ import queue
6
7
  import re
8
+ import threading
7
9
  import time
8
10
  import typing
9
11
  from contextlib import asynccontextmanager, contextmanager
@@ -72,17 +74,22 @@ def wrap_app(cls: type[App], **kwargs) -> fal.api.IsolatedFunction:
72
74
 
73
75
 
74
76
  class EndpointClient:
75
- def __init__(self, url, endpoint, signature):
77
+ def __init__(self, url, endpoint, signature, timeout: int | None = None):
76
78
  self.url = url
77
79
  self.endpoint = endpoint
78
80
  self.signature = signature
81
+ self.timeout = timeout
79
82
 
80
83
  annotations = endpoint.__annotations__ or {}
81
84
  self.return_type = annotations.get("return") or None
82
85
 
83
86
  def __call__(self, data):
84
87
  with httpx.Client() as client:
85
- resp = client.post(self.url + self.signature.path, json=dict(data))
88
+ resp = client.post(
89
+ self.url + self.signature.path,
90
+ json=data.dict() if hasattr(data, "dict") else dict(data),
91
+ timeout=self.timeout,
92
+ )
86
93
  resp.raise_for_status()
87
94
  resp_dict = resp.json()
88
95
 
@@ -93,7 +100,12 @@ class EndpointClient:
93
100
 
94
101
 
95
102
  class AppClient:
96
- def __init__(self, cls, url):
103
+ def __init__(
104
+ self,
105
+ cls,
106
+ url,
107
+ timeout: int | None = None,
108
+ ):
97
109
  self.url = url
98
110
  self.cls = cls
99
111
 
@@ -101,19 +113,38 @@ class AppClient:
101
113
  signature = getattr(endpoint, "route_signature", None)
102
114
  if signature is None:
103
115
  continue
104
-
105
- setattr(self, name, EndpointClient(self.url, endpoint, signature))
116
+ endpoint_client = EndpointClient(
117
+ self.url,
118
+ endpoint,
119
+ signature,
120
+ timeout=timeout,
121
+ )
122
+ setattr(self, name, endpoint_client)
106
123
 
107
124
  @classmethod
108
125
  @contextmanager
109
126
  def connect(cls, app_cls):
110
127
  app = wrap_app(app_cls)
111
128
  info = app.spawn()
129
+ _shutdown_event = threading.Event()
130
+
131
+ def _print_logs():
132
+ while not _shutdown_event.is_set():
133
+ try:
134
+ log = info.logs.get(timeout=0.1)
135
+ except queue.Empty:
136
+ continue
137
+ print(log)
138
+
139
+ _log_printer = threading.Thread(target=_print_logs, daemon=True)
140
+ _log_printer.start()
141
+
112
142
  try:
113
143
  with httpx.Client() as client:
114
144
  retries = 100
115
145
  while retries:
116
146
  resp = client.get(info.url + "/health")
147
+
117
148
  if resp.is_success:
118
149
  break
119
150
  elif resp.status_code != 500:
@@ -121,9 +152,12 @@ class AppClient:
121
152
  time.sleep(0.1)
122
153
  retries -= 1
123
154
 
124
- yield cls(app_cls, info.url)
155
+ client = cls(app_cls, info.url)
156
+ yield client
125
157
  finally:
126
158
  info.stream.cancel()
159
+ _shutdown_event.set()
160
+ _log_printer.join()
127
161
 
128
162
  def health(self):
129
163
  with httpx.Client() as client:
@@ -158,7 +192,7 @@ class App(fal.api.BaseServable):
158
192
  app_name = kwargs.pop("name", None) or _to_fal_app_name(cls.__name__)
159
193
  parent_settings = getattr(cls, "host_kwargs", {})
160
194
  cls.host_kwargs = {**parent_settings, **kwargs}
161
- cls.app_name = app_name
195
+ cls.app_name = getattr(cls, "app_name", app_name)
162
196
 
163
197
  if cls.__init__ is not App.__init__:
164
198
  raise ValueError(
@@ -1,3 +1,74 @@
1
1
  from __future__ import annotations
2
2
 
3
+ from functools import lru_cache
4
+ from typing import TYPE_CHECKING
5
+ from urllib.request import Request, urlopen
6
+
3
7
  from .image import * # noqa: F403
8
+
9
+ if TYPE_CHECKING:
10
+ from PIL.Image import Image as PILImage
11
+
12
+
13
def filter_by(
    has_nsfw_concepts: list[bool],
    images: list[PILImage],
) -> list[PILImage]:
    """Replace every image flagged as NSFW with a black image of the same size.

    Args:
        has_nsfw_concepts: One flag per image; ``True`` marks the image at the
            same position as NSFW.
        images: The images to filter.

    Returns:
        A new list parallel to ``images``. Flagged images are replaced by an
        all-black RGB image with the same width and height. Unflagged images
        are passed through unchanged (the same objects).

    Raises:
        ValueError: If the two lists have different lengths. ``zip`` would
            otherwise silently drop the unmatched tail and lose images.
    """
    if len(has_nsfw_concepts) != len(images):
        raise ValueError(
            f"Got {len(has_nsfw_concepts)} NSFW flags for {len(images)} images; "
            "expected exactly one flag per image."
        )

    # Nothing to blank out, so there is no need to import Pillow at all.
    if not any(has_nsfw_concepts):
        return list(images)

    # Imported as ``Image`` so it does not shadow the ``PILImage`` type alias
    # used in the annotations above.
    from PIL import Image

    return [
        Image.new("RGB", (image.width, image.height), (0, 0, 0))
        if has_nsfw
        else image
        for image, has_nsfw in zip(images, has_nsfw_concepts)
    ]
27
+
28
+
29
def preprocess_image(image_pil, convert_to_rgb=True, fix_orientation=True):
    """Return the first frame of ``image_pil``, optionally normalized.

    Multi-frame inputs such as MPO (multi picture object) files are reduced to
    their first frame. Single-frame images go through the same path.

    Args:
        image_pil: A Pillow image, possibly multi-frame.
        convert_to_rgb: If true, convert the frame to RGB mode.
        fix_orientation: If true, apply the EXIF orientation via
            ``ImageOps.exif_transpose``.

    Returns:
        The processed first frame.
    """
    from PIL import ImageOps, ImageSequence

    # Only the first frame yielded by the sequence iterator is needed.
    frame = next(iter(ImageSequence.Iterator(image_pil)))

    if convert_to_rgb:
        frame = frame.convert("RGB")

    if fix_orientation:
        frame = ImageOps.exif_transpose(frame)

    return frame
48
+
49
+
50
@lru_cache(maxsize=64)
def read_image_from_url(
    url: str,
    convert_to_rgb: bool = True,
    fix_orientation: bool = True,
    timeout: float | None = None,
):
    """Download an image from ``url`` and return its preprocessed first frame.

    Results are memoized with ``lru_cache(maxsize=64)``. Repeated calls with the
    same arguments return the *same* Pillow image object, so callers must not
    modify the returned image in place.

    Args:
        url: Address of the image to fetch.
        convert_to_rgb: Forwarded to ``preprocess_image``.
        fix_orientation: Forwarded to ``preprocess_image``.
        timeout: Optional socket timeout in seconds for the download. When
            ``None`` (the default), ``urlopen``'s own default applies, as
            before this parameter existed.

    Raises:
        fastapi.HTTPException: Status 422 if the image cannot be downloaded
            or opened.
    """
    import io
    import traceback

    from fastapi import HTTPException
    from PIL import Image

    TEMP_HEADERS = {
        "User-Agent": (
            "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.8; rv:21.0) "
            "Gecko/20100101 Firefox/21.0"
        ),
    }

    # Pass ``timeout`` only when given, so urlopen keeps its default otherwise.
    urlopen_kwargs = {} if timeout is None else {"timeout": timeout}

    try:
        request = Request(url, headers=TEMP_HEADERS)
        # Read the body inside ``with`` so the HTTP response is closed
        # deterministically, then decode the image from memory.
        with urlopen(request, **urlopen_kwargs) as response:
            payload = response.read()
        image_pil = Image.open(io.BytesIO(payload))
    except Exception as exc:
        traceback.print_exc()
        raise HTTPException(422, f"Could not load image from url: {url}") from exc

    return preprocess_image(image_pil, convert_to_rgb, fix_orientation)
@@ -0,0 +1,11 @@
1
+ from .inference import (
2
+ NSFWImageDetectionInput,
3
+ NSFWImageDetectionOutput,
4
+ run_nsfw_estimation,
5
+ )
6
+
7
# Public API of this package: the names re-exported from ``.inference``.
__all__ = ["NSFWImageDetectionInput", "NSFWImageDetectionOutput", "run_nsfw_estimation"]
@@ -0,0 +1,9 @@
1
+ from pathlib import Path
2
+
3
CURR_DIR = Path(__file__).resolve().parent


def get_requirements(path=None):
    """Return the requirement specifiers listed in a requirements file.

    Each non-empty line is one specifier, so specifiers containing spaces
    (e.g. ``numpy >= 1.0``) stay intact. Full-line ``#`` comments and inline
    `` #`` comments are dropped.

    Args:
        path: Requirements file to read. Defaults to ``requirements.txt`` in
            the same directory as this module.

    Returns:
        The list of specifier strings, in file order.
    """
    if path is None:
        path = CURR_DIR / "requirements.txt"

    with open(path, encoding="utf-8") as fp:
        lines = fp.read().splitlines()

    requirements = []
    for line in lines:
        # pip treats "#" preceded by whitespace as the start of a comment.
        spec = line.split(" #", 1)[0].strip()
        if spec and not spec.startswith("#"):
            requirements.append(spec)
    return requirements
@@ -0,0 +1,77 @@
1
+ from pydantic import BaseModel, Field
2
+
3
+ import fal
4
+ from fal.toolkit.image import read_image_from_url
5
+
6
+ from .env import get_requirements
7
+ from .model import get_model
8
+
9
+
10
# Sample value advertised for ``image_url`` in the generated schema.
_EXAMPLE_IMAGE_URL = "https://storage.googleapis.com/falserverless/model_tests/remove_background/elephant.jpg"


class NSFWImageDetectionInput(BaseModel):
    # Input model: one image, referenced by URL.
    # (Comments rather than a docstring, so the model's schema is unchanged.)
    image_url: str = Field(
        description="Input image url.",
        examples=[_EXAMPLE_IMAGE_URL],
    )


class NSFWImageDetectionOutput(BaseModel):
    # Output model: a single score produced by run_nsfw_estimation.
    nsfw_probability: float = Field(
        description="The probability of the image being NSFW.",
    )
23
+
24
+
25
def check_nsfw_content(pil_image: object):
    """Return the probability that ``pil_image`` is NSFW.

    Runs the classifier returned by ``get_model`` on the image. It then reads
    the softmax probability at the index of the ``"nsfw"`` label.

    Raises:
        ValueError: If the model config has no ``"nsfw"`` entry in
            ``label2id``.
    """
    import torch

    model, processor = get_model()

    with torch.no_grad():
        model_inputs = processor(images=pil_image, return_tensors="pt")
        # squeeze() drops the size-1 batch axis so the logits index by class.
        class_logits = model(**model_inputs).logits.squeeze()
        class_probs = torch.softmax(class_logits, dim=0)

        nsfw_index = model.config.label2id.get("nsfw", None)
        if nsfw_index is None:
            raise ValueError("NSFW class not found in model output.")
        return class_probs[int(nsfw_index)].item()
48
+
49
+
50
def run_nsfw_estimation(
    input: NSFWImageDetectionInput,
) -> NSFWImageDetectionOutput:
    """Download ``input.image_url`` and score it with the NSFW classifier."""
    image = read_image_from_url(input.image_url)
    return NSFWImageDetectionOutput(nsfw_probability=check_nsfw_content(image))
57
+
58
+
59
# Deployment settings: a GPU-A6000 machine, served as an endpoint, with the
# dependencies listed in this package's requirements.txt.
_NSFW_FAL_OPTIONS = {
    "requirements": get_requirements(),
    "machine_type": "GPU-A6000",
    "serve": True,
}


@fal.function(**_NSFW_FAL_OPTIONS)
def run_nsfw_estimation_on_fal(
    input: NSFWImageDetectionInput,
) -> NSFWImageDetectionOutput:
    # Thin remote entry point; all of the work happens in run_nsfw_estimation.
    return run_nsfw_estimation(input)
68
+
69
+
70
if __name__ == "__main__":
    # Manual smoke test using the serve=False variant of the fal function.
    run_locally = run_nsfw_estimation_on_fal.on(serve=False)
    sample_input = NSFWImageDetectionInput(
        image_url="https://storage.googleapis.com/falserverless/model_tests/remove_background/elephant.jpg",
    )
    print(run_locally(sample_input))
@@ -0,0 +1,18 @@
1
+ import fal
2
+
3
+
4
@fal.cached
def get_model():
    """Load the Falconsai NSFW image classifier and its image processor.

    Model files are cached under ``/data/models``.

    Returns:
        A ``(model, processor)`` tuple.
    """
    import os

    cache_dir = "/data/models"
    model_id = "Falconsai/nsfw_image_detection"

    # transformers / huggingface_hub read these variables when they are first
    # imported. They must therefore be set *before* the import below; setting
    # them after the import (as before) did not affect the cache location.
    os.environ["TRANSFORMERS_CACHE"] = cache_dir
    os.environ["HF_HOME"] = cache_dir

    from transformers import AutoModelForImageClassification, ViTImageProcessor

    # Also pass the cache location explicitly. transformers may already have
    # been imported elsewhere in this process, in which case the environment
    # variables above come too late.
    model = AutoModelForImageClassification.from_pretrained(
        model_id, cache_dir=cache_dir
    )
    processor = ViTImageProcessor.from_pretrained(model_id, cache_dir=cache_dir)

    return model, processor
@@ -0,0 +1,4 @@
1
+ accelerate
2
+ Pillow
3
+ torch
4
+ transformers
@@ -0,0 +1,107 @@
1
+ from typing import Any
2
+
3
+ import fal
4
+
5
+ from . import filter_by
6
+ from .nsfw_filter.model import get_model
7
+
8
+
9
@fal.cached
def load_safety_checker():
    """Load the CompVis Stable Diffusion safety checker and its feature extractor.

    The checker weights are loaded as float16 and moved to ``"cuda"``.

    Returns:
        A ``(feature_extractor, safety_checker)`` tuple.
    """
    import torch
    from diffusers.pipelines.stable_diffusion.safety_checker import (
        StableDiffusionSafetyChecker,
    )
    from transformers import AutoFeatureExtractor

    repo_id = "CompVis/stable-diffusion-safety-checker"

    extractor = AutoFeatureExtractor.from_pretrained(repo_id, torch_dtype="float16")
    checker = StableDiffusionSafetyChecker.from_pretrained(
        repo_id,
        torch_dtype=torch.float16,
    )
    checker = checker.to("cuda")

    return extractor, checker
27
+
28
+
29
def run_safety_checker(
    pil_images: "list[object]",
) -> "list[bool]":
    """Flag NSFW images with the Stable Diffusion safety checker (version 1).

    The annotations are quoted because this module does not use
    ``from __future__ import annotations``. The package still declares
    Python >= 3.8, and there ``list[...]`` raises TypeError when the ``def``
    is executed.

    Args:
        pil_images: Pillow images to check.

    Returns:
        The ``has_nsfw_concept`` value returned by the safety checker.
        Presumably this is one bool per input image; confirm against the
        diffusers version in use.
    """
    import numpy as np
    import torch

    feature_extractor, safety_checker = load_safety_checker()

    clip_features = feature_extractor(pil_images, return_tensors="pt").to("cuda")
    # The checker takes the images as numpy arrays alongside the CLIP input.
    image_arrays = [np.array(pil_image) for pil_image in pil_images]

    _, has_nsfw_concept = safety_checker(
        images=image_arrays,
        clip_input=clip_features.pixel_values.to(torch.float16),
    )

    return has_nsfw_concept
47
+
48
+
49
def run_safety_checker_v2(
    pil_images: list, nsfw_threshold: float = 0.5
) -> "list[bool]":
    """Flag NSFW images with the Falconsai classifier (version 2).

    Each image is converted to RGB and classified. It is flagged when the
    softmax probability of the ``"nsfw"`` label is strictly greater than
    ``nsfw_threshold``. The return annotation is quoted so the module still
    imports on Python 3.8, where ``list[bool]`` raises TypeError at def time.

    Args:
        pil_images: Pillow images to check.
        nsfw_threshold: Probability above which an image is flagged.

    Returns:
        One bool per input image, in input order.

    Raises:
        ValueError: If the model config has no ``"nsfw"`` label and at least
            one image was given.
    """
    import logging

    import torch

    logger = logging.getLogger(__name__)
    model, processor = get_model()

    # The label index does not depend on the image, so look it up once.
    nsfw_class_index = model.config.label2id.get("nsfw")
    if pil_images and nsfw_class_index is None:
        raise ValueError("NSFW class not found in model output.")

    has_nsfw_concept = []
    with torch.no_grad():
        for pil_image in pil_images:
            inputs = processor(images=pil_image.convert("RGB"), return_tensors="pt")
            # squeeze() drops the size-1 batch axis so the logits index by class.
            logits = model(**inputs).logits.squeeze()
            probabilities = torch.softmax(logits, dim=0)
            nsfw_probability = probabilities[int(nsfw_class_index)].item()
            # Diagnostics go to logging instead of printing to stdout per image.
            logger.debug("NSFW probability: %s", nsfw_probability)
            has_nsfw_concept.append(nsfw_probability > nsfw_threshold)

    return has_nsfw_concept
83
+
84
+
85
def postprocess_images(
    pil_images: "list[object]",
    enable_safety_checker: bool = True,
    safety_checker_version: int = 2,
) -> "dict[str, Any]":
    """Run the NSFW safety check on images and black out the flagged ones.

    The annotations are quoted so this module still imports on Python 3.8,
    where ``list[...]`` / ``dict[...]`` raise TypeError at def time.

    Args:
        pil_images: Pillow images to post-process.
        enable_safety_checker: When false, no checker runs and every image is
            reported as not NSFW.
        safety_checker_version: ``2`` selects ``run_safety_checker_v2``. Any
            other value selects ``run_safety_checker``.

    Returns:
        A dict with two keys:
        ``"images"``: the inputs, with flagged images replaced via ``filter_by``.
        ``"has_nsfw_concepts"``: the per-image flags used for the filtering.
    """
    if enable_safety_checker:
        checker = (
            run_safety_checker_v2 if safety_checker_version == 2 else run_safety_checker
        )
        has_nsfw_concepts = checker(pil_images)
    else:
        has_nsfw_concepts = [False] * len(pil_images)

    return {
        "images": filter_by(has_nsfw_concepts, pil_images),
        "has_nsfw_concepts": has_nsfw_concepts,
    }
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: fal
3
- Version: 1.2.2
3
+ Version: 1.2.4
4
4
  Summary: fal is an easy-to-use Serverless Python Framework
5
5
  Author: Features & Labels <support@fal.ai>
6
6
  Requires-Python: >=3.8
@@ -1,10 +1,10 @@
1
1
  fal/__init__.py,sha256=wXs1G0gSc7ZK60-bHe-B2m0l_sA6TrFk4BxY0tMoLe8,784
2
2
  fal/__main__.py,sha256=MSmt_5Xg84uHqzTN38JwgseJK8rsJn_11A8WD99VtEo,61
3
- fal/_fal_version.py,sha256=XEVwqOPlIChKtEnSO5v_SvghWXnn9WeQSoJ436w3v9Y,411
3
+ fal/_fal_version.py,sha256=DFpsAdSrahcTSWRccxC8nEJpgcmby0LdmRoAddZy2zA,411
4
4
  fal/_serialization.py,sha256=rD2YiSa8iuzCaZohZwN_MPEB-PpSKbWRDeaIDpTEjyY,7653
5
5
  fal/_version.py,sha256=EBGqrknaf1WygENX-H4fBefLvHryvJBBGtVJetaB0NY,266
6
6
  fal/api.py,sha256=LAPl5Hf6ZWzEjv4lFUtsisWgrnXH_qNUHdJrEHT_A5Y,40602
7
- fal/app.py,sha256=oyN4PNULFJtjOwHYrR5lh4Ks_zBi-dEPzvFYRUXe0sI,15877
7
+ fal/app.py,sha256=BM4lk6741Z0DKr3bYLLhEvRBuDkqZqo883LvPXjgOPQ,16812
8
8
  fal/apps.py,sha256=FrKmaAUo8U9vE_fcva0GQvk4sCrzaTEr62lGtu3Ld5M,6825
9
9
  fal/container.py,sha256=V7riyyq8AZGwEX9QaqRQDZyDN_bUKeRKV1OOZArXjL0,622
10
10
  fal/flags.py,sha256=oWN_eidSUOcE9wdPK_77si3A1fpgOC0UEERPsvNLIMc,842
@@ -49,8 +49,14 @@ fal/toolkit/file/types.py,sha256=bJCeV5NPcpJYJoglailiRgFsuNAfcextYA8Et5-XUag,106
49
49
  fal/toolkit/file/providers/fal.py,sha256=65-BkK9jhGBwYI_OjhHJsL2DthyKxBBRrqXPI_ZN4-k,4115
50
50
  fal/toolkit/file/providers/gcp.py,sha256=pUVH2qNcnO_VrDQQU8MmfYOQZMGaKQIqE4yGnYdQhAc,2003
51
51
  fal/toolkit/file/providers/r2.py,sha256=WxmOHF5WxHt6tKMcFjWj7ZWO8a1EXysO9lfYv_tB3MI,2627
52
- fal/toolkit/image/__init__.py,sha256=qNLyXsBWysionUjbeWbohLqWlw3G_UpzunamkZd_JLQ,71
52
+ fal/toolkit/image/__init__.py,sha256=aLcU8HzD7HyOxx-C-Bbx9kYCMHdBhy9tR98FSVJ6gSA,1830
53
53
  fal/toolkit/image/image.py,sha256=UDIHgkxae8LzmCvWBM9GayMnK8c0JMMfsrVlLnW5rto,4234
54
+ fal/toolkit/image/safety_checker.py,sha256=S7ow-HuoVxC6ixHWWcBrAUm2dIlgq3sTAIull6xIbAg,3105
55
+ fal/toolkit/image/nsfw_filter/__init__.py,sha256=0d9D51EhcnJg8cZLYJjgvQJDZT74CfQu6mpvinRYRpA,216
56
+ fal/toolkit/image/nsfw_filter/env.py,sha256=iAP2Q3vzIl--DD8nr8o3o0goAwhExN2v0feYE0nIQjs,212
57
+ fal/toolkit/image/nsfw_filter/inference.py,sha256=BhIPF_zxRLetThQYxDDF0sdx9VRwvu74M5ye6Povi40,2167
58
+ fal/toolkit/image/nsfw_filter/model.py,sha256=63mu8D15z_IosoRUagRLGHy6VbLqFmrG-yZqnu2vVm4,457
59
+ fal/toolkit/image/nsfw_filter/requirements.txt,sha256=3Pmrd0Ny6QAeBqUNHCgffRyfaCARAPJcfSCX5cRYpbM,37
54
60
  fal/toolkit/utils/__init__.py,sha256=CrmM9DyCz5-SmcTzRSm5RaLgxy3kf0ZsSEN9uhnX2Xo,97
55
61
  fal/toolkit/utils/download_utils.py,sha256=9WMpn0mFIhkFelQpPj5KG-pC7RMyyOzGHbNRDSyz07o,17664
56
62
  openapi_fal_rest/__init__.py,sha256=ziculmF_i6trw63LzZGFX-6W3Lwq9mCR8_UpkpvpaHI,152
@@ -116,8 +122,8 @@ openapi_fal_rest/models/workflow_node_type.py,sha256=-FzyeY2bxcNmizKbJI8joG7byRi
116
122
  openapi_fal_rest/models/workflow_schema.py,sha256=4K5gsv9u9pxx2ItkffoyHeNjBBYf6ur5bN4m_zePZNY,2019
117
123
  openapi_fal_rest/models/workflow_schema_input.py,sha256=2OkOXWHTNsCXHWS6EGDFzcJKkW5FIap-2gfO233EvZQ,1191
118
124
  openapi_fal_rest/models/workflow_schema_output.py,sha256=EblwSPAGfWfYVWw_WSSaBzQVju296is9o28rMBAd0mc,1196
119
- fal-1.2.2.dist-info/METADATA,sha256=ip9ta9NozUhujfW9CmFlgVrVmrvZztXkFFOCDCboQi0,3805
120
- fal-1.2.2.dist-info/WHEEL,sha256=Wyh-_nZ0DJYolHNn1_hMa4lM7uDedD_RGVwbmTjyItk,91
121
- fal-1.2.2.dist-info/entry_points.txt,sha256=32zwTUC1U1E7nSTIGCoANQOQ3I7-qHG5wI6gsVz5pNU,37
122
- fal-1.2.2.dist-info/top_level.txt,sha256=r257X1L57oJL8_lM0tRrfGuXFwm66i1huwQygbpLmHw,21
123
- fal-1.2.2.dist-info/RECORD,,
125
+ fal-1.2.4.dist-info/METADATA,sha256=yCcxwr4BEx0DlAY2xi7FQzd_Qi-IChNr-aQ9QstwzUg,3805
126
+ fal-1.2.4.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
127
+ fal-1.2.4.dist-info/entry_points.txt,sha256=32zwTUC1U1E7nSTIGCoANQOQ3I7-qHG5wI6gsVz5pNU,37
128
+ fal-1.2.4.dist-info/top_level.txt,sha256=r257X1L57oJL8_lM0tRrfGuXFwm66i1huwQygbpLmHw,21
129
+ fal-1.2.4.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (71.1.0)
2
+ Generator: setuptools (72.1.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5