fal-1.2.1-py3-none-any.whl → fal-1.7.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of fal might be problematic; see the package registry's release page for more details.

Files changed (45)
  1. fal/__main__.py +3 -1
  2. fal/_fal_version.py +2 -2
  3. fal/api.py +88 -20
  4. fal/app.py +221 -27
  5. fal/apps.py +147 -3
  6. fal/auth/__init__.py +50 -2
  7. fal/cli/_utils.py +40 -0
  8. fal/cli/apps.py +5 -3
  9. fal/cli/create.py +26 -0
  10. fal/cli/deploy.py +97 -16
  11. fal/cli/main.py +2 -2
  12. fal/cli/parser.py +11 -7
  13. fal/cli/run.py +12 -1
  14. fal/cli/runners.py +44 -0
  15. fal/config.py +23 -0
  16. fal/container.py +1 -1
  17. fal/exceptions/__init__.py +7 -1
  18. fal/exceptions/_base.py +51 -0
  19. fal/exceptions/_cuda.py +44 -0
  20. fal/files.py +81 -0
  21. fal/sdk.py +67 -6
  22. fal/toolkit/file/file.py +103 -13
  23. fal/toolkit/file/providers/fal.py +572 -24
  24. fal/toolkit/file/providers/gcp.py +8 -1
  25. fal/toolkit/file/providers/r2.py +8 -1
  26. fal/toolkit/file/providers/s3.py +80 -0
  27. fal/toolkit/file/types.py +28 -3
  28. fal/toolkit/image/__init__.py +71 -0
  29. fal/toolkit/image/image.py +25 -2
  30. fal/toolkit/image/nsfw_filter/__init__.py +11 -0
  31. fal/toolkit/image/nsfw_filter/env.py +9 -0
  32. fal/toolkit/image/nsfw_filter/inference.py +77 -0
  33. fal/toolkit/image/nsfw_filter/model.py +18 -0
  34. fal/toolkit/image/nsfw_filter/requirements.txt +4 -0
  35. fal/toolkit/image/safety_checker.py +107 -0
  36. fal/toolkit/types.py +140 -0
  37. fal/toolkit/utils/download_utils.py +4 -0
  38. fal/toolkit/utils/retry.py +45 -0
  39. fal/utils.py +20 -4
  40. fal/workflows.py +10 -4
  41. {fal-1.2.1.dist-info → fal-1.7.2.dist-info}/METADATA +47 -40
  42. {fal-1.2.1.dist-info → fal-1.7.2.dist-info}/RECORD +45 -30
  43. {fal-1.2.1.dist-info → fal-1.7.2.dist-info}/WHEEL +1 -1
  44. {fal-1.2.1.dist-info → fal-1.7.2.dist-info}/entry_points.txt +0 -0
  45. {fal-1.2.1.dist-info → fal-1.7.2.dist-info}/top_level.txt +0 -0
@@ -6,8 +6,10 @@ import posixpath
6
6
  import uuid
7
7
  from dataclasses import dataclass
8
8
  from io import BytesIO
9
+ from typing import Optional
9
10
 
10
11
  from fal.toolkit.file.types import FileData, FileRepository
12
+ from fal.toolkit.utils.retry import retry
11
13
 
12
14
  DEFAULT_URL_TIMEOUT = 60 * 15 # 15 minutes
13
15
 
@@ -67,7 +69,12 @@ class R2Repository(FileRepository):
67
69
 
68
70
  return self._bucket
69
71
 
70
- def save(self, data: FileData) -> str:
72
+ @retry(max_retries=3, base_delay=1, backoff_type="exponential", jitter=True)
73
+ def save(
74
+ self,
75
+ data: FileData,
76
+ object_lifecycle_preference: Optional[dict[str, str]] = None,
77
+ ) -> str:
71
78
  destination_path = posixpath.join(
72
79
  self.key,
73
80
  f"{uuid.uuid4().hex}_{data.file_name}",
@@ -0,0 +1,80 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import posixpath
5
+ import uuid
6
+ from dataclasses import dataclass
7
+ from io import BytesIO
8
+ from typing import Optional
9
+
10
+ from fal.toolkit.file.types import FileData, FileRepository
11
+ from fal.toolkit.utils.retry import retry
12
+
13
+ DEFAULT_URL_TIMEOUT = 60 * 15 # 15 minutes
14
+
15
+
16
@dataclass
class S3Repository(FileRepository):
    """File repository that uploads files to an Amazon S3 bucket.

    ``save`` uploads the payload and returns a presigned ``get_object`` URL
    that stays valid for ``url_expiration`` seconds.

    Credentials come from the constructor arguments or, when those are
    ``None``, from the ``AWS_ACCESS_KEY_ID`` / ``AWS_SECRET_ACCESS_KEY``
    environment variables. Empty environment values count as unset.
    """

    # NOTE(review): S3 bucket naming rules do not allow underscores, so this
    # default presumably only works for a legacy bucket; pass an explicit
    # ``bucket_name`` in practice.
    bucket_name: str = "fal_file_storage"
    # Lifetime of the presigned URL returned by ``save``, in seconds.
    url_expiration: int = DEFAULT_URL_TIMEOUT
    aws_access_key_id: str | None = None
    aws_secret_access_key: str | None = None
    # Optional AWS region for the client. ``None`` keeps boto3's default
    # region resolution. Appended last so positional construction still works.
    region_name: str | None = None

    # Deliberately unannotated: a plain class attribute rather than a dataclass
    # field, so it is not an ``__init__`` parameter. Set in ``__post_init__``.
    _s3_client = None

    def __post_init__(self):
        """Resolve credentials and create the boto3 S3 client.

        Raises:
            Exception: If boto3 is not importable, or a credential is missing
                from both the arguments and the environment.
        """
        try:
            import boto3
            from botocore.client import Config
        except ImportError as exc:
            # Chain explicitly so the underlying import failure is reported
            # as the cause.
            raise Exception("boto3 is not installed") from exc

        if self.aws_access_key_id is None:
            # ``or None`` makes an empty variable behave like an unset one
            # instead of handing boto3 an empty key.
            self.aws_access_key_id = os.environ.get("AWS_ACCESS_KEY_ID") or None
        if self.aws_access_key_id is None:
            raise Exception("AWS_ACCESS_KEY_ID environment variable is not set")

        if self.aws_secret_access_key is None:
            self.aws_secret_access_key = (
                os.environ.get("AWS_SECRET_ACCESS_KEY") or None
            )
        if self.aws_secret_access_key is None:
            raise Exception("AWS_SECRET_ACCESS_KEY environment variable is not set")

        self._s3_client = boto3.client(
            "s3",
            region_name=self.region_name,
            aws_access_key_id=self.aws_access_key_id,
            aws_secret_access_key=self.aws_secret_access_key,
            config=Config(signature_version="s3v4"),
        )

    @property
    def storage_client(self):
        """Return the boto3 S3 client created in ``__post_init__``.

        Raises:
            Exception: If the client has not been created.
        """
        if self._s3_client is None:
            raise Exception("S3 client is not initialized")

        return self._s3_client

    @retry(max_retries=3, base_delay=1, backoff_type="exponential", jitter=True)
    def save(
        self,
        data: FileData,
        object_lifecycle_preference: Optional[dict[str, str]] = None,
        key: Optional[str] = None,
    ) -> str:
        """Upload *data* and return a presigned download URL.

        Args:
            data: The file payload, content type and file name.
            object_lifecycle_preference: Accepted for interface compatibility
                with ``FileRepository.save``; not used by this repository.
            key: Optional key prefix ("directory") for the object.

        Returns:
            A presigned ``get_object`` URL valid for ``url_expiration`` seconds.
        """
        # A random hex prefix keeps uploads with the same file name from
        # overwriting each other.
        destination_path = posixpath.join(
            key or "",
            f"{uuid.uuid4().hex}_{data.file_name}",
        )

        self.storage_client.upload_fileobj(
            BytesIO(data.data),
            self.bucket_name,
            destination_path,
            ExtraArgs={"ContentType": data.content_type},
        )

        public_url = self.storage_client.generate_presigned_url(
            ClientMethod="get_object",
            Params={"Bucket": self.bucket_name, "Key": destination_path},
            ExpiresIn=self.url_expiration,
        )
        return public_url
fal/toolkit/file/types.py CHANGED
@@ -2,7 +2,8 @@ from __future__ import annotations
2
2
 
3
3
  from dataclasses import dataclass
4
4
  from mimetypes import guess_extension, guess_type
5
- from typing import Literal
5
+ from pathlib import Path
6
+ from typing import Literal, Optional
6
7
  from uuid import uuid4
7
8
 
8
9
 
@@ -28,10 +29,34 @@ class FileData:
28
29
  self.file_name = file_name
29
30
 
30
31
 
31
- RepositoryId = Literal["fal", "fal_v2", "in_memory", "gcp_storage", "r2", "cdn"]
32
+ RepositoryId = Literal[
33
+ "fal", "fal_v2", "fal_v3", "in_memory", "gcp_storage", "r2", "cdn"
34
+ ]
32
35
 
33
36
 
34
37
  @dataclass
35
38
  class FileRepository:
36
- def save(self, data: FileData) -> str:
39
+ def save(
40
+ self,
41
+ data: FileData,
42
+ object_lifecycle_preference: Optional[dict[str, str]] = None,
43
+ ) -> str:
37
44
  raise NotImplementedError()
45
+
46
+ def save_file(
47
+ self,
48
+ file_path: str | Path,
49
+ content_type: str,
50
+ multipart: bool | None = None,
51
+ multipart_threshold: int | None = None,
52
+ multipart_chunk_size: int | None = None,
53
+ multipart_max_concurrency: int | None = None,
54
+ object_lifecycle_preference: Optional[dict[str, str]] = None,
55
+ ) -> tuple[str, FileData | None]:
56
+ if multipart:
57
+ raise NotImplementedError()
58
+
59
+ with open(file_path, "rb") as fobj:
60
+ data = FileData(fobj.read(), content_type, Path(file_path).name)
61
+
62
+ return self.save(data, object_lifecycle_preference), data
@@ -1,3 +1,74 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import urllib.request
4
+ from functools import lru_cache
5
+ from typing import TYPE_CHECKING
6
+
3
7
  from .image import * # noqa: F403
8
+
9
+ if TYPE_CHECKING:
10
+ from PIL.Image import Image as PILImage
11
+
12
+
13
def filter_by(
    has_nsfw_concepts: list[bool],
    images: list[PILImage],
) -> list[PILImage]:
    """Replace every image flagged as NSFW with a black image of the same size.

    Args:
        has_nsfw_concepts: One flag per image; ``True`` marks an image to blank.
        images: The images to filter, in the same order as the flags.

    Returns:
        A new list in which flagged images are replaced by solid black RGB
        images of identical dimensions and unflagged images pass through
        unchanged.

    Raises:
        ValueError: If the two lists differ in length. A bare ``zip`` would
            silently drop the unmatched tail and lose images.
    """
    if len(has_nsfw_concepts) != len(images):
        raise ValueError(
            f"Got {len(has_nsfw_concepts)} NSFW flags for {len(images)} images"
        )

    # Nothing to blank: skip the PIL import entirely.
    if not any(has_nsfw_concepts):
        return list(images)

    # Imported as ``Image`` so it does not shadow the ``PILImage`` type name
    # used in the signature.
    from PIL import Image

    return [
        Image.new("RGB", image.size, (0, 0, 0)) if has_nsfw else image
        for image, has_nsfw in zip(images, has_nsfw_concepts)
    ]
27
+
28
+
29
def preprocess_image(image_pil, convert_to_rgb=True, fix_orientation=True):
    """Return the first frame of *image_pil*, optionally normalized.

    Multi-frame images (for example MPO, multi-picture-object files) are
    reduced to their first frame.

    Args:
        image_pil: A PIL image, possibly containing several frames.
        convert_to_rgb: If true, convert the frame to ``"RGB"`` mode.
        fix_orientation: If true, apply the EXIF orientation with
            ``ImageOps.exif_transpose``.

    Returns:
        The processed first frame.

    Raises:
        ValueError: If the image yields no frames at all.
    """
    from PIL import ImageOps, ImageSequence

    # Only the first frame is needed, so take it directly instead of looping
    # and breaking after the first iteration.
    img = next(ImageSequence.Iterator(image_pil), None)
    if img is None:
        raise ValueError("Image contains no frames")

    if convert_to_rgb:
        img = img.convert("RGB")

    if fix_orientation:
        img = ImageOps.exif_transpose(img)

    return img
48
+
49
+
50
@lru_cache(maxsize=64)
def read_image_from_url(
    url: str,
    convert_to_rgb: bool = True,
    fix_orientation: bool = True,
    timeout: float = 60.0,
):
    """Download the image at *url* and return it via ``preprocess_image``.

    Only ``http``, ``https`` and ``data`` URLs are accepted. Other schemes
    supported by urllib, notably ``file://``, would let a caller-supplied URL
    read local files.

    NOTE(review): results are memoized by ``lru_cache``, so repeated calls
    with the same arguments return the *same* PIL image object. Callers
    should copy it before mutating it in place.

    Args:
        url: The image URL.
        convert_to_rgb: Forwarded to ``preprocess_image``.
        fix_orientation: Forwarded to ``preprocess_image``.
        timeout: Socket timeout in seconds for the download. Without it a
            stalled server could block the caller forever.

    Raises:
        fastapi.HTTPException: Status 422 if the scheme is not allowed, the
            download fails, or the payload cannot be decoded as an image.
    """
    from io import BytesIO
    from urllib.parse import urlsplit

    from fastapi import HTTPException
    from PIL import Image

    TEMP_HEADERS = {
        "User-Agent": (
            "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.8; rv:21.0) "
            "Gecko/20100101 Firefox/21.0"
        ),
    }

    try:
        if urlsplit(url).scheme.lower() not in ("http", "https", "data"):
            raise ValueError(f"Unsupported URL scheme: {url!r}")

        request = urllib.request.Request(url, headers=TEMP_HEADERS)
        # Read the whole body and close the connection at once. PIL opens
        # images lazily, so handing it the live response would leave the
        # socket open and defer decode errors until after this ``try``.
        with urllib.request.urlopen(request, timeout=timeout) as response:
            payload = response.read()
        image_pil = Image.open(BytesIO(payload))
        # Force decoding now so corrupt or truncated images are reported as 422.
        image_pil.load()
    except Exception as exc:
        import traceback

        traceback.print_exc()
        raise HTTPException(422, f"Could not load image from url: {url}") from exc

    return preprocess_image(image_pil, convert_to_rgb, fix_orientation)
@@ -4,9 +4,10 @@ import io
4
4
  from tempfile import NamedTemporaryFile
5
5
  from typing import TYPE_CHECKING, Literal, Optional, Union
6
6
 
7
+ from fastapi import Request
7
8
  from pydantic import BaseModel, Field
8
9
 
9
- from fal.toolkit.file.file import DEFAULT_REPOSITORY, File
10
+ from fal.toolkit.file.file import DEFAULT_REPOSITORY, FALLBACK_REPOSITORY, File
10
11
  from fal.toolkit.file.types import FileRepository, RepositoryId
11
12
  from fal.toolkit.utils.download_utils import _download_file_python
12
13
 
@@ -79,12 +80,18 @@ class Image(File):
79
80
  size: ImageSize | None = None,
80
81
  file_name: str | None = None,
81
82
  repository: FileRepository | RepositoryId = DEFAULT_REPOSITORY,
83
+ fallback_repository: Optional[
84
+ FileRepository | RepositoryId
85
+ ] = FALLBACK_REPOSITORY,
86
+ request: Optional[Request] = None,
82
87
  ) -> Image:
83
88
  obj = super().from_bytes(
84
89
  data,
85
90
  content_type=f"image/{format}",
86
91
  file_name=file_name,
87
92
  repository=repository,
93
+ fallback_repository=fallback_repository,
94
+ request=request,
88
95
  )
89
96
  obj.width = size.width if size else None
90
97
  obj.height = size.height if size else None
@@ -97,6 +104,10 @@ class Image(File):
97
104
  format: ImageFormat | None = None,
98
105
  file_name: str | None = None,
99
106
  repository: FileRepository | RepositoryId = DEFAULT_REPOSITORY,
107
+ fallback_repository: Optional[
108
+ FileRepository | RepositoryId
109
+ ] = FALLBACK_REPOSITORY,
110
+ request: Optional[Request] = None,
100
111
  ) -> Image:
101
112
  size = ImageSize(width=pil_image.width, height=pil_image.height)
102
113
  if format is None:
@@ -110,12 +121,24 @@ class Image(File):
110
121
  # enough result quickly to utilize the underlying resources
111
122
  # efficiently.
112
123
  saving_options["compress_level"] = 1
124
+ elif format == "jpeg":
125
+ # JPEG quality is set to 95 by default, which is a good balance
126
+ # between file size and image quality.
127
+ saving_options["quality"] = 95
113
128
 
114
129
  with io.BytesIO() as f:
115
130
  pil_image.save(f, format=format, **saving_options)
116
131
  raw_image = f.getvalue()
117
132
 
118
- return cls.from_bytes(raw_image, format, size, file_name, repository)
133
+ return cls.from_bytes(
134
+ raw_image,
135
+ format,
136
+ size,
137
+ file_name,
138
+ repository,
139
+ fallback_repository=fallback_repository,
140
+ request=request,
141
+ )
119
142
 
120
143
  def to_pil(self, mode: str = "RGB") -> PILImage.Image:
121
144
  try:
@@ -0,0 +1,11 @@
1
+ from .inference import (
2
+ NSFWImageDetectionInput,
3
+ NSFWImageDetectionOutput,
4
+ run_nsfw_estimation,
5
+ )
6
+
7
+ __all__ = [
8
+ "NSFWImageDetectionInput",
9
+ "NSFWImageDetectionOutput",
10
+ "run_nsfw_estimation",
11
+ ]
@@ -0,0 +1,9 @@
1
+ from pathlib import Path
2
+
3
+ CURR_DIR = Path(__file__).resolve().parent
4
+
5
+
6
def get_requirements(path=None):
    """Return the pip requirement specifiers listed in a requirements file.

    Args:
        path: File to read (str or path-like). Defaults to the
            ``requirements.txt`` next to this module.

    Returns:
        One string per requirement, in file order. Blank lines and ``#``
        comments are skipped; a comment is either a whole line or text after
        whitespace, as pip allows. Each line is kept whole, so specifiers like
        ``torch >= 2.0`` are not split into separate tokens.
    """
    import re

    if path is None:
        path = CURR_DIR / "requirements.txt"

    requirements = []
    with open(path, encoding="utf-8") as fp:
        for raw_line in fp:
            # Requiring whitespace before ``#`` keeps URL fragments such as
            # ``...#egg=name`` intact.
            line = re.sub(r"(^|\s)#.*$", "", raw_line).strip()
            if line:
                requirements.append(line)
    return requirements
@@ -0,0 +1,77 @@
1
+ from pydantic import BaseModel, Field
2
+
3
+ import fal
4
+ from fal.toolkit.image import read_image_from_url
5
+
6
+ from .env import get_requirements
7
+ from .model import get_model
8
+
9
+
10
class NSFWImageDetectionInput(BaseModel):
    # Request payload: a single image, referenced by URL.
    image_url: str = Field(
        examples=[
            "https://storage.googleapis.com/falserverless/model_tests/remove_background/elephant.jpg",
        ],
        description="Input image url.",
    )


class NSFWImageDetectionOutput(BaseModel):
    # Response payload: the classifier's score for the "nsfw" label.
    nsfw_probability: float = Field(description="The probability of the image being NSFW.")
23
+
24
+
25
def check_nsfw_content(pil_image: object):
    """Return the classifier's probability that *pil_image* is NSFW.

    Runs the model and processor returned by ``get_model()`` on one image and
    reads the softmax probability of the ``"nsfw"`` label.

    Raises:
        ValueError: If the model's ``label2id`` mapping has no ``"nsfw"`` label.
    """
    import torch

    model, processor = get_model()

    # The label index depends only on the model, so resolve it before the
    # forward pass and fail fast if the model is misconfigured.
    nsfw_class_index = model.config.label2id.get("nsfw")
    if nsfw_class_index is None:
        raise ValueError("NSFW class not found in model output.")

    with torch.no_grad():
        inputs = processor(images=pil_image, return_tensors="pt")
        # squeeze() drops the batch dimension (a single image) for simple indexing.
        logits = model(**inputs).logits.squeeze()
        probabilities = torch.softmax(logits, dim=0)

    return probabilities[int(nsfw_class_index)].item()
48
+
49
+
50
def run_nsfw_estimation(
    input: NSFWImageDetectionInput,
) -> NSFWImageDetectionOutput:
    """Fetch ``input.image_url`` and score the image with ``check_nsfw_content``."""
    return NSFWImageDetectionOutput(
        nsfw_probability=check_nsfw_content(read_image_from_url(input.image_url))
    )
57
+
58
+
59
@fal.function(
    requirements=get_requirements(),
    machine_type="GPU-A6000",
    serve=True,
)
def run_nsfw_estimation_on_fal(
    input: NSFWImageDetectionInput,
) -> NSFWImageDetectionOutput:
    # Thin fal entry point: runs ``run_nsfw_estimation`` on a GPU-A6000
    # machine with the packages listed in requirements.txt installed.
    estimation = run_nsfw_estimation(input)
    return estimation
68
+
69
+
70
if __name__ == "__main__":
    # Manual smoke test: invoke the fal function as a plain call
    # (``serve=False``) on a sample image and print the result.
    direct_call = run_nsfw_estimation_on_fal.on(serve=False)
    sample_request = NSFWImageDetectionInput(
        image_url="https://storage.googleapis.com/falserverless/model_tests/remove_background/elephant.jpg",
    )
    print(direct_call(sample_request))
@@ -0,0 +1,18 @@
1
+ import fal
2
+
3
+
4
@fal.cached
def get_model():
    """Load the ``Falconsai/nsfw_image_detection`` classifier and processor.

    Returns:
        A ``(model, processor)`` tuple. Presumably ``fal.cached`` memoizes it
        so the weights load once per process.
    """
    import os

    cache_dir = "/data/models"

    # huggingface_hub and transformers read these variables when they are
    # first imported, so they must be set *before* the import below. Setting
    # them afterwards has no effect on where the weights are cached.
    os.environ["TRANSFORMERS_CACHE"] = cache_dir
    os.environ["HF_HOME"] = cache_dir

    from transformers import AutoModelForImageClassification, ViTImageProcessor

    # Also pass the directory explicitly: if transformers was already imported
    # elsewhere in this process, the environment variables are ignored.
    model = AutoModelForImageClassification.from_pretrained(
        "Falconsai/nsfw_image_detection",
        cache_dir=cache_dir,
    )
    processor = ViTImageProcessor.from_pretrained(
        "Falconsai/nsfw_image_detection",
        cache_dir=cache_dir,
    )

    return model, processor
@@ -0,0 +1,4 @@
1
+ accelerate
2
+ Pillow
3
+ torch
4
+ transformers
@@ -0,0 +1,107 @@
1
+ from typing import Any
2
+
3
+ import fal
4
+
5
+ from . import filter_by
6
+ from .nsfw_filter.model import get_model
7
+
8
+
9
@fal.cached
def load_safety_checker():
    """Load the CompVis Stable Diffusion safety checker onto CUDA in float16.

    Returns:
        ``(feature_extractor, safety_checker)``. The checker weights are
        float16 on the ``"cuda"`` device. The feature extractor produces the
        CLIP pixel values, which the caller must cast to float16 itself.
    """
    import torch
    from diffusers.pipelines.stable_diffusion.safety_checker import (
        StableDiffusionSafetyChecker,
    )
    from transformers import AutoFeatureExtractor

    # No ``torch_dtype`` here: feature extractors ignore it (they have no
    # weights), and passing it suggested fp16 preprocessing that never happened.
    feature_extractor = AutoFeatureExtractor.from_pretrained(
        "CompVis/stable-diffusion-safety-checker",
    )
    safety_checker = StableDiffusionSafetyChecker.from_pretrained(
        "CompVis/stable-diffusion-safety-checker",
        torch_dtype=torch.float16,
    ).to("cuda")

    return feature_extractor, safety_checker
27
+
28
+
29
def run_safety_checker(
    pil_images: list[object],
) -> list[bool]:
    """Flag images using the CompVis Stable Diffusion safety checker.

    The images are featurized by the CLIP feature extractor, moved to CUDA
    and cast to float16. The ``has_nsfw_concept`` flags reported by the
    checker are returned and the checker's image output is discarded.
    """
    import numpy as np
    import torch

    feature_extractor, safety_checker = load_safety_checker()

    clip_batch = feature_extractor(pil_images, return_tensors="pt").to("cuda")
    arrays = list(map(np.array, pil_images))
    clip_pixels = clip_batch.pixel_values.to(torch.float16)

    _checked_images, flags = safety_checker(images=arrays, clip_input=clip_pixels)
    return flags
47
+
48
+
49
def run_safety_checker_v2(pil_images: list, nsfw_threshold: float = 0.5) -> list[bool]:
    """Flag images using the Falconsai ViT NSFW classifier.

    Args:
        pil_images: PIL images; each is converted to RGB before classification.
        nsfw_threshold: An image is flagged when its NSFW probability is
            strictly greater than this value.

    Returns:
        One boolean per input image, in input order.

    Raises:
        ValueError: If the model's ``label2id`` mapping has no ``"nsfw"`` label.
    """
    import logging

    import torch

    logger = logging.getLogger(__name__)
    model, processor = get_model()

    # The label index is a property of the model, not of each image: resolve
    # it once, and fail before running any inference if it is missing.
    nsfw_class_index = model.config.label2id.get("nsfw")
    if nsfw_class_index is None:
        raise ValueError("NSFW class not found in model output.")
    nsfw_class_index = int(nsfw_class_index)

    has_nsfw_concept = []

    with torch.no_grad():
        for pil_image in pil_images:
            inputs = processor(
                images=pil_image.convert("RGB"),
                return_tensors="pt",
            )
            # squeeze() drops the batch dimension (a single image) for indexing.
            logits = model(**inputs).logits.squeeze()
            probabilities = torch.softmax(logits, dim=0)

            nsfw_probability = probabilities[nsfw_class_index].item()
            # Debug log instead of an unconditional print to stdout.
            logger.debug("NSFW probability: %s", nsfw_probability)
            has_nsfw_concept.append(nsfw_probability > nsfw_threshold)

    return has_nsfw_concept
83
+
84
+
85
def postprocess_images(
    pil_images: list[object],
    enable_safety_checker: bool = True,
    safety_checker_version: int = 2,
) -> dict[str, Any]:
    """Optionally run a safety checker and blank out flagged images.

    Args:
        pil_images: Images to check.
        enable_safety_checker: When ``False``, every image is reported as safe
            and returned unchanged.
        safety_checker_version: ``2`` selects ``run_safety_checker_v2`` (ViT
            classifier); ``1`` selects ``run_safety_checker`` (CompVis checker).

    Returns:
        ``{"images": [...], "has_nsfw_concepts": [...]}``. Flagged images are
        replaced by ``filter_by``.

    Raises:
        ValueError: If the checker is enabled and ``safety_checker_version`` is
            neither 1 nor 2. Previously any other value silently selected v1.
    """
    if not enable_safety_checker:
        has_nsfw_concepts = [False] * len(pil_images)
    elif safety_checker_version == 2:
        has_nsfw_concepts = run_safety_checker_v2(pil_images)  # type: ignore
    elif safety_checker_version == 1:
        has_nsfw_concepts = run_safety_checker(pil_images)
    else:
        raise ValueError(
            f"Unsupported safety_checker_version: {safety_checker_version!r} "
            "(expected 1 or 2)"
        )

    return {
        "images": filter_by(has_nsfw_concepts, pil_images),
        "has_nsfw_concepts": has_nsfw_concepts,
    }
fal/toolkit/types.py ADDED
@@ -0,0 +1,140 @@
1
+ import re
2
+ import tempfile
3
+ from contextlib import contextmanager
4
+ from pathlib import Path
5
+ from typing import Any, Dict, Generator, Union
6
+
7
+ import pydantic
8
+ from pydantic.utils import update_not_none
9
+
10
+ from fal.toolkit.image import read_image_from_url
11
+ from fal.toolkit.utils.download_utils import download_file
12
+
13
# https://github.com/pydantic/pydantic/pull/2573
# A missing ``__version__`` or a "1." version string means pydantic 1;
# anything else is treated as pydantic 2.
IS_PYDANTIC_V2 = hasattr(pydantic, "__version__") and not pydantic.__version__.startswith(
    "1."
)

# Upper bound on accepted data-URI length: 10 MiB.
MAX_DATA_URI_LENGTH = 10 * 1024 * 1024
# Upper bound on accepted HTTPS URL length, in characters.
MAX_HTTPS_URL_LENGTH = 2048

# ``https://``, one or more dot-terminated hostname labels, an alphabetic TLD
# of 2+ letters, an optional ``:port`` and an optional whitespace-free path.
HTTP_URL_REGEX = (
    r"^https:\/\/(?:[a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}(?::\d{1,5})?(?:\/[^\s]*)?$"
)
25
+
26
+
27
class DownloadFileMixin:
    """Mixin for string types whose value is a downloadable URI."""

    @contextmanager
    def as_temp_file(self) -> Generator[Path, None, None]:
        """Download ``str(self)`` into a temporary directory and yield its path.

        The temporary directory, and so the downloaded file, is removed when
        the ``with`` block exits.
        """
        with tempfile.TemporaryDirectory() as scratch_dir:
            downloaded_path = download_file(str(self), scratch_dir)
            yield downloaded_path
32
+
33
+
34
class DownloadImageMixin:
    """Mixin adding PIL decoding to string types whose value is an image URI."""

    def to_pil(self):
        """Fetch and decode ``str(self)`` with ``read_image_from_url``."""
        image_url = str(self)
        return read_image_from_url(image_url)
37
+
38
+
39
class DataUri(DownloadFileMixin, str):
    """A ``data:`` URI string type for pydantic models (v1 and v2).

    Values are stripped of surrounding whitespace, must start with ``data:``
    and may be at most ``MAX_DATA_URI_LENGTH`` characters long.
    """

    if IS_PYDANTIC_V2:

        @classmethod
        def __get_pydantic_core_schema__(cls, source_type: Any, handler) -> Any:
            from pydantic_core import core_schema

            # Validate as a constrained string, then wrap the result in ``cls``
            # so validated values are instances of this class and keep its
            # helpers (``as_temp_file``; ``to_pil`` on image subclasses). A bare
            # string schema would hand back plain ``str`` objects instead.
            return core_schema.no_info_after_validator_function(
                cls,
                core_schema.str_schema(
                    pattern="^data:",
                    max_length=MAX_DATA_URI_LENGTH,
                    strip_whitespace=True,
                ),
            )

        # pydantic calls this hook on the class as ``hook(core_schema, handler)``,
        # so it must be a classmethod; as a plain function the call fails with
        # a missing-argument TypeError during JSON-schema generation.
        @classmethod
        def __get_pydantic_json_schema__(cls, core_schema, handler) -> Dict[str, Any]:
            json_schema = handler(core_schema)
            json_schema.update(format="data-uri")
            return json_schema

    else:

        @classmethod
        def __get_validators__(cls):
            yield cls.validate

        @classmethod
        def validate(cls, value: Any) -> "DataUri":
            """Validate *value* under pydantic v1 and return it as ``cls``.

            Raises:
                ValueError: If the value does not start with ``data:`` or is
                    longer than ``MAX_DATA_URI_LENGTH``.
            """
            from pydantic.validators import str_validator

            value = str_validator(value)
            value = value.strip()

            if not value.startswith("data:"):
                raise ValueError("Data URI must start with 'data:'")

            if len(value) > MAX_DATA_URI_LENGTH:
                raise ValueError(
                    f"Data URI is too long. Max length is {MAX_DATA_URI_LENGTH} bytes."
                )

            return cls(value)

        @classmethod
        def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
            # Equivalent to ``update_not_none`` for this non-None value.
            field_schema["format"] = "data-uri"
81
+
82
+
83
class HttpsUrl(DownloadFileMixin, str):
    """An ``https://`` URL string type for pydantic models (v1 and v2).

    Values are stripped of surrounding whitespace, must match
    ``HTTP_URL_REGEX`` and may be at most ``MAX_HTTPS_URL_LENGTH`` characters.
    """

    if IS_PYDANTIC_V2:

        @classmethod
        def __get_pydantic_core_schema__(cls, source_type: Any, handler) -> Any:
            from pydantic_core import core_schema

            # Validate as a constrained string, then wrap the result in ``cls``
            # so validated values are instances of this class and keep its
            # helpers (``as_temp_file``; ``to_pil`` on image subclasses). A bare
            # string schema would hand back plain ``str`` objects instead.
            return core_schema.no_info_after_validator_function(
                cls,
                core_schema.str_schema(
                    pattern=HTTP_URL_REGEX,
                    max_length=MAX_HTTPS_URL_LENGTH,
                    strip_whitespace=True,
                ),
            )

        # pydantic calls this hook on the class as ``hook(core_schema, handler)``,
        # so it must be a classmethod; as a plain function the call fails with
        # a missing-argument TypeError during JSON-schema generation.
        @classmethod
        def __get_pydantic_json_schema__(cls, core_schema, handler) -> Dict[str, Any]:
            json_schema = handler(core_schema)
            json_schema.update(format="https-url")
            return json_schema

    else:

        @classmethod
        def __get_validators__(cls):
            yield cls.validate

        @classmethod
        def validate(cls, value: Any) -> "HttpsUrl":
            """Validate *value* under pydantic v1 and return it as ``cls``.

            Raises:
                ValueError: If the value does not match ``HTTP_URL_REGEX`` or is
                    longer than ``MAX_HTTPS_URL_LENGTH`` characters.
            """
            from pydantic.validators import str_validator

            value = str_validator(value)
            value = value.strip()

            if not re.match(HTTP_URL_REGEX, value):
                raise ValueError(
                    "URL must start with 'https://' and follow the correct format."
                )

            if len(value) > MAX_HTTPS_URL_LENGTH:
                raise ValueError(
                    f"HTTPS URL is too long. Max length is "
                    f"{MAX_HTTPS_URL_LENGTH} characters."
                )

            return cls(value)

        @classmethod
        def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
            # Equivalent to ``update_not_none`` for this non-None value.
            field_schema["format"] = "https-url"
129
+
130
+
131
class ImageHttpsUrl(DownloadImageMixin, HttpsUrl):
    """An ``HttpsUrl`` that points at an image; adds ``to_pil()``."""


class ImageDataUri(DownloadImageMixin, DataUri):
    """A ``DataUri`` that carries an image; adds ``to_pil()``."""


# Field types accepting either a remote HTTPS URL or an inline data URI.
FileInput = Union[HttpsUrl, DataUri]
ImageInput = Union[ImageHttpsUrl, ImageDataUri]
@@ -84,6 +84,9 @@ def _get_remote_file_properties(
84
84
  url_path = parsed_url.path
85
85
  file_name = Path(url_path).name or _hash_url(url)
86
86
 
87
+ # file name can still contain a forward slash if the server returns a relative path
88
+ file_name = Path(file_name).name
89
+
87
90
  return file_name, content_length
88
91
 
89
92
 
@@ -159,6 +162,7 @@ def download_file(
159
162
  try:
160
163
  file_name = _get_remote_file_properties(url, request_headers)[0]
161
164
  except Exception as e:
165
+ print(f"GOt error: {e}")
162
166
  raise DownloadError(f"Failed to get remote file properties for {url}") from e
163
167
 
164
168
  if "/" in file_name: