fal-1.2.1-py3-none-any.whl → fal-1.7.2-py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.

Potentially problematic release.


This version of fal might be problematic; further details were linked from the original page.

Files changed (45)
  1. fal/__main__.py +3 -1
  2. fal/_fal_version.py +2 -2
  3. fal/api.py +88 -20
  4. fal/app.py +221 -27
  5. fal/apps.py +147 -3
  6. fal/auth/__init__.py +50 -2
  7. fal/cli/_utils.py +40 -0
  8. fal/cli/apps.py +5 -3
  9. fal/cli/create.py +26 -0
  10. fal/cli/deploy.py +97 -16
  11. fal/cli/main.py +2 -2
  12. fal/cli/parser.py +11 -7
  13. fal/cli/run.py +12 -1
  14. fal/cli/runners.py +44 -0
  15. fal/config.py +23 -0
  16. fal/container.py +1 -1
  17. fal/exceptions/__init__.py +7 -1
  18. fal/exceptions/_base.py +51 -0
  19. fal/exceptions/_cuda.py +44 -0
  20. fal/files.py +81 -0
  21. fal/sdk.py +67 -6
  22. fal/toolkit/file/file.py +103 -13
  23. fal/toolkit/file/providers/fal.py +572 -24
  24. fal/toolkit/file/providers/gcp.py +8 -1
  25. fal/toolkit/file/providers/r2.py +8 -1
  26. fal/toolkit/file/providers/s3.py +80 -0
  27. fal/toolkit/file/types.py +28 -3
  28. fal/toolkit/image/__init__.py +71 -0
  29. fal/toolkit/image/image.py +25 -2
  30. fal/toolkit/image/nsfw_filter/__init__.py +11 -0
  31. fal/toolkit/image/nsfw_filter/env.py +9 -0
  32. fal/toolkit/image/nsfw_filter/inference.py +77 -0
  33. fal/toolkit/image/nsfw_filter/model.py +18 -0
  34. fal/toolkit/image/nsfw_filter/requirements.txt +4 -0
  35. fal/toolkit/image/safety_checker.py +107 -0
  36. fal/toolkit/types.py +140 -0
  37. fal/toolkit/utils/download_utils.py +4 -0
  38. fal/toolkit/utils/retry.py +45 -0
  39. fal/utils.py +20 -4
  40. fal/workflows.py +10 -4
  41. {fal-1.2.1.dist-info → fal-1.7.2.dist-info}/METADATA +47 -40
  42. {fal-1.2.1.dist-info → fal-1.7.2.dist-info}/RECORD +45 -30
  43. {fal-1.2.1.dist-info → fal-1.7.2.dist-info}/WHEEL +1 -1
  44. {fal-1.2.1.dist-info → fal-1.7.2.dist-info}/entry_points.txt +0 -0
  45. {fal-1.2.1.dist-info → fal-1.7.2.dist-info}/top_level.txt +0 -0
@@ -1,32 +1,130 @@
1
1
  from __future__ import annotations
2
2
 
3
- import dataclasses
4
3
  import json
4
+ import math
5
5
  import os
6
+ import threading
6
7
  from base64 import b64encode
7
8
  from dataclasses import dataclass
9
+ from datetime import datetime, timezone
10
+ from pathlib import Path
11
+ from typing import Generic, TypeVar
8
12
  from urllib.error import HTTPError
13
+ from urllib.parse import urlparse, urlunparse
9
14
  from urllib.request import Request, urlopen
10
15
 
11
16
  from fal.auth import key_credentials
12
17
  from fal.toolkit.exceptions import FileUploadException
13
18
  from fal.toolkit.file.types import FileData, FileRepository
19
+ from fal.toolkit.utils.retry import retry
14
20
 
15
21
  _FAL_CDN = "https://fal.media"
22
+ _FAL_CDN_V3 = "https://v3.fal.media"
16
23
 
17
24
 
18
25
@dataclass
class FalV2Token:
    """Upload credential returned by the fal storage auth endpoint.

    Instances are built by ``FalV2TokenManager._refresh_token`` from the JSON
    response of ``/storage/auth/token``.
    """

    token: str
    # Scheme placed in front of ``token`` in the Authorization header.
    token_type: str
    base_upload_url: str
    # Compared against a timezone-aware UTC "now" in is_expired().
    expires_at: datetime

    def is_expired(self) -> bool:
        """Return True once the current UTC time has reached ``expires_at``."""
        now = datetime.now(timezone.utc)
        return self.expires_at <= now


class FalV3Token(FalV2Token):
    """Token for the ``fal-cdn-v3`` storage type; same fields as FalV2Token."""
38
+
39
+
40
class FalV2TokenManager:
    """Thread-safe cache for the upload token of one fal storage type.

    The token is fetched lazily from the fal REST API and re-fetched whenever
    the cached one has expired. Subclasses override the class attributes to
    target a different storage backend.
    """

    # Token class instantiated from the auth response.
    token_cls: type[FalV2Token] = FalV2Token
    # Sent as the ``storage_type`` query parameter of the auth request.
    storage_type: str = "fal-cdn"
    # Prepended to the host of the returned ``base_url`` to form the upload host.
    upload_prefix = "upload."

    def __init__(self):
        # Start from an already-expired placeholder so the first get_token()
        # call always performs a refresh.
        self._token: FalV2Token = self.token_cls(
            token="",
            token_type="",
            base_upload_url="",
            expires_at=datetime.min.replace(tzinfo=timezone.utc),
        )
        self._lock: threading.Lock = threading.Lock()

    def get_token(self) -> FalV2Token:
        """Return a non-expired token, refreshing it first if necessary.

        The lock serialises refreshes, so concurrent callers (e.g. multipart
        worker threads) trigger at most one auth request at a time.
        """
        with self._lock:
            if self._token.is_expired():
                self._refresh_token()
            return self._token

    def _refresh_token(self) -> None:
        """Request a new token from ``/storage/auth/token`` and cache it.

        Raises:
            FileUploadException: if FAL_KEY credentials are missing or the auth
                request fails with an HTTP error.
        """
        key_creds = key_credentials()
        if not key_creds:
            raise FileUploadException("FAL_KEY must be set")

        key_id, key_secret = key_creds
        headers = {
            "Authorization": f"Key {key_id}:{key_secret}",
            "Accept": "application/json",
            "Content-Type": "application/json",
        }

        grpc_host = os.environ.get("FAL_HOST", "api.alpha.fal.ai")
        # The REST API is served from the same host with "api" swapped for "rest".
        rest_host = grpc_host.replace("api", "rest", 1)
        url = f"https://{rest_host}/storage/auth/token?storage_type={self.storage_type}"

        req = Request(
            url,
            headers=headers,
            data=b"{}",
            method="POST",
        )
        try:
            with urlopen(req) as response:
                result = json.load(response)
        except HTTPError as e:
            # Surface auth failures with the same exception type as every
            # other upload request in this module.
            raise FileUploadException(
                f"Error fetching storage token. Status {e.status}: {e.reason}"
            ) from e

        parsed_base_url = urlparse(result["base_url"])
        base_upload_url = urlunparse(
            parsed_base_url._replace(netloc=self.upload_prefix + parsed_base_url.netloc)
        )

        self._token = self.token_cls(
            token=result["token"],
            token_type=result["token_type"],
            base_upload_url=base_upload_url,
            expires_at=self._parse_expires_at(result["expires_at"]),
        )

    @staticmethod
    def _parse_expires_at(value: str) -> datetime:
        """Parse an ISO-8601 timestamp into a timezone-aware datetime.

        ``datetime.fromisoformat`` accepts a trailing "Z" only from Python 3.11,
        so it is rewritten to "+00:00" first. A naive result is assumed to be
        UTC (NOTE(review): confirm with the API), because is_expired() compares
        against an aware "now" and would raise TypeError otherwise.
        """
        if value.endswith("Z"):
            value = value[:-1] + "+00:00"
        expires_at = datetime.fromisoformat(value)
        if expires_at.tzinfo is None:
            expires_at = expires_at.replace(tzinfo=timezone.utc)
        return expires_at
96
+
97
+
98
class FalV3TokenManager(FalV2TokenManager):
    """Token manager for the ``fal-cdn-v3`` storage type.

    The upload host is the returned ``base_url`` host unchanged, hence the
    empty upload prefix.
    """

    storage_type: str = "fal-cdn-v3"
    token_cls: type[FalV2Token] = FalV3Token
    upload_prefix = ""


# Module-level singletons: every uploader in this module that calls
# get_token() on these shares the same cached token per storage type.
fal_v2_token_manager = FalV2TokenManager()
fal_v3_token_manager = FalV3TokenManager()
106
+
107
+
108
VariableType = TypeVar("VariableType")


class VariableReference(Generic[VariableType]):
    """Mutable box around a single value.

    Holders of the same instance observe replacements made through ``set``
    without any global name being rebound.
    """

    def __init__(self, value: VariableType) -> None:
        # Go through set() so subclasses that customise assignment are honoured.
        self.set(value)

    def set(self, value: VariableType) -> None:
        """Replace the stored value."""
        self.value = value

    def get(self) -> VariableType:
        """Return the currently stored value."""
        return self.value


# Shared, mutable object-lifecycle preference, initialised to None.
# NOTE(review): not read anywhere in the visible code; presumably consumed by
# upload callers elsewhere.
LIFECYCLE_PREFERENCE: VariableReference[dict[str, str] | None] = VariableReference(None)
26
123
 
27
124
 
28
125
  @dataclass
29
126
  class FalFileRepositoryBase(FileRepository):
127
+ @retry(max_retries=3, base_delay=1, backoff_type="exponential", jitter=True)
30
128
  def _save(self, file: FileData, storage_type: str) -> str:
31
129
  key_creds = key_credentials()
32
130
  if not key_creds:
@@ -61,36 +159,371 @@ class FalFileRepositoryBase(FileRepository):
61
159
  result = json.load(response)
62
160
 
63
161
  upload_url = result["upload_url"]
64
- self._upload_file(upload_url, file)
65
-
66
- return result["file_url"]
67
162
  except HTTPError as e:
68
163
  raise FileUploadException(
69
164
  f"Error initiating upload. Status {e.status}: {e.reason}"
70
165
  )
71
166
 
72
- def _upload_file(self, upload_url: str, file: FileData):
73
- req = Request(
74
- upload_url,
75
- method="PUT",
76
- data=file.data,
77
- headers={"Content-Type": file.content_type},
78
- )
167
+ try:
168
+ req = Request(
169
+ upload_url,
170
+ method="PUT",
171
+ data=file.data,
172
+ headers={"Content-Type": file.content_type},
173
+ )
79
174
 
80
- with urlopen(req):
81
- return
175
+ with urlopen(req):
176
+ pass
177
+
178
+ return result["file_url"]
179
+ except HTTPError as e:
180
+ raise FileUploadException(
181
+ f"Error uploading file. Status {e.status}: {e.reason}"
182
+ )
82
183
 
83
184
 
84
185
@dataclass
class FalFileRepository(FalFileRepositoryBase):
    """Repository that uploads through ``_save`` with the ``"gcs"`` storage type."""

    def save(
        self, file: FileData, object_lifecycle_preference: dict[str, str] | None = None
    ) -> str:
        """Upload *file* and return the file URL reported by the server.

        ``object_lifecycle_preference`` is accepted for interface compatibility
        with the other repositories but is not forwarded here.
        """
        storage_type = "gcs"
        return self._save(file, storage_type)
88
191
 
89
192
 
193
@dataclass
class FalFileRepositoryV3(FalFileRepositoryBase):
    """Repository that uploads through ``_save`` with the ``"fal-cdn-v3"`` storage type."""

    def save(
        self, file: FileData, object_lifecycle_preference: dict[str, str] | None = None
    ) -> str:
        """Upload *file* and return the file URL reported by the server.

        ``object_lifecycle_preference`` is accepted for interface compatibility
        with the other repositories but is not forwarded here.
        """
        storage_type = "fal-cdn-v3"
        return self._save(file, storage_type)
199
+
200
+
201
class MultipartUpload:
    """Chunked, concurrent upload of a local file through the fal CDN v2 API.

    Call ``create()``, then ``upload()``, then ``complete()``; ``complete``
    returns the URL of the uploaded file.
    """

    # Default size (100 MiB) above which FalFileRepositoryV2.save_file uses multipart.
    MULTIPART_THRESHOLD = 100 * 1024 * 1024
    # Size of every part except possibly the last (100 MiB).
    MULTIPART_CHUNK_SIZE = 100 * 1024 * 1024
    # Default number of parts uploaded in parallel.
    MULTIPART_MAX_CONCURRENCY = 10

    def __init__(
        self,
        file_path: str | Path,
        chunk_size: int | None = None,
        content_type: str | None = None,
        max_concurrency: int | None = None,
    ) -> None:
        self.file_path = file_path
        self.chunk_size = chunk_size or self.MULTIPART_CHUNK_SIZE
        self.content_type = content_type or "application/octet-stream"
        self.max_concurrency = max_concurrency or self.MULTIPART_MAX_CONCURRENCY
        # Filled in by create(); None until the upload has been initiated.
        self._upload_url: str | None = None
        self._file_url: str | None = None

        self._parts: list[dict] = []

    def _initiated_upload_url(self) -> str:
        """Return the upload URL from create(), or raise if create() was not called."""
        if not self._upload_url:
            raise FileUploadException("Upload not initiated")
        return self._upload_url

    def _part_count(self) -> int:
        """Number of ``chunk_size`` parts needed to cover the whole file."""
        return math.ceil(os.path.getsize(self.file_path) / self.chunk_size)

    def create(self) -> None:
        """Initiate the multipart upload and remember the upload and file URLs.

        Raises:
            FileUploadException: if the initiate request fails.
        """
        token = fal_v2_token_manager.get_token()
        try:
            req = Request(
                f"{token.base_upload_url}/upload/initiate-multipart",
                method="POST",
                headers={
                    "Authorization": f"{token.token_type} {token.token}",
                    "Accept": "application/json",
                    "Content-Type": "application/json",
                },
                data=json.dumps(
                    {
                        "file_name": os.path.basename(self.file_path),
                        "content_type": self.content_type,
                    }
                ).encode(),
            )
            with urlopen(req) as response:
                result = json.load(response)
                self._upload_url = result["upload_url"]
                self._file_url = result["file_url"]
        except HTTPError as exc:
            raise FileUploadException(
                f"Error initiating upload. Status {exc.status}: {exc.reason}"
            ) from exc

    def _upload_part(self, url: str, part_number: int) -> dict:
        """PUT one 1-based part of the file and return its part number and ETag."""
        with open(self.file_path, "rb") as f:
            start = (part_number - 1) * self.chunk_size
            f.seek(start)
            data = f.read(self.chunk_size)
            req = Request(
                url,
                method="PUT",
                headers={"Content-Type": self.content_type},
                data=data,
            )

            try:
                with urlopen(req) as resp:
                    return {
                        "part_number": part_number,
                        "etag": resp.headers["ETag"],
                    }
            except HTTPError as exc:
                raise FileUploadException(
                    f"Error uploading part {part_number} to {url}. "
                    f"Status {exc.status}: {exc.reason}"
                ) from exc

    def upload(self) -> None:
        """Upload all parts concurrently and record their ETags.

        Parts finish in arbitrary order, so the recorded list is sorted by part
        number afterwards; multipart-completion APIs (e.g. S3
        CompleteMultipartUpload) require parts in ascending order.

        Raises:
            FileUploadException: if create() was not called or a part fails.
        """
        import concurrent.futures

        upload_url = self._initiated_upload_url()
        with concurrent.futures.ThreadPoolExecutor(
            max_workers=self.max_concurrency
        ) as executor:
            futures = [
                executor.submit(
                    self._upload_part,
                    f"{upload_url}&part_number={part_number}",
                    part_number,
                )
                for part_number in range(1, self._part_count() + 1)
            ]

            for future in concurrent.futures.as_completed(futures):
                self._parts.append(future.result())

        self._parts.sort(key=lambda part: part["part_number"])

    def complete(self) -> str:
        """Finalize the upload with the recorded parts and return the file URL.

        Raises:
            FileUploadException: if create() was not called or completion fails.
        """
        url = self._initiated_upload_url()
        try:
            req = Request(
                url,
                method="POST",
                headers={
                    "Accept": "application/json",
                    "Content-Type": "application/json",
                },
                data=json.dumps({"parts": self._parts}).encode(),
            )
            with urlopen(req):
                pass
        except HTTPError as e:
            raise FileUploadException(
                f"Error completing upload {url}. Status {e.status}: {e.reason}"
            ) from e

        return self._file_url
309
+
310
+
311
class InternalMultipartUploadV3:
    """Chunked, concurrent upload of a local file through the fal CDN v3 API.

    Call ``create()``, then ``upload()``, then ``complete()``; ``complete``
    returns the file's access URL. Every request carries an Authorization
    header built from ``fal_v3_token_manager``'s current token.
    """

    # Default size (100 MiB) above which save_file() uploads in parts.
    MULTIPART_THRESHOLD = 100 * 1024 * 1024
    # Size of every part except possibly the last (10 MiB).
    MULTIPART_CHUNK_SIZE = 10 * 1024 * 1024
    # Default number of parts uploaded in parallel.
    MULTIPART_MAX_CONCURRENCY = 10

    def __init__(
        self,
        file_path: str | Path,
        chunk_size: int | None = None,
        content_type: str | None = None,
        max_concurrency: int | None = None,
    ) -> None:
        self.file_path = file_path
        self.chunk_size = chunk_size or self.MULTIPART_CHUNK_SIZE
        self.content_type = content_type or "application/octet-stream"
        self.max_concurrency = max_concurrency or self.MULTIPART_MAX_CONCURRENCY
        # Filled in by create().
        self._access_url: str | None = None
        self._upload_id: str | None = None

        self._parts: list[dict] = []

    @property
    def access_url(self) -> str:
        """URL of the file being uploaded; raises if create() was not called."""
        if not self._access_url:
            raise FileUploadException("Upload not initiated")
        return self._access_url

    @property
    def upload_id(self) -> str:
        """Server-side multipart upload id; raises if create() was not called."""
        if not self._upload_id:
            raise FileUploadException("Upload not initiated")
        return self._upload_id

    @property
    def auth_headers(self) -> dict[str, str]:
        """Authorization headers from the current (auto-refreshed) v3 token."""
        token = fal_v3_token_manager.get_token()
        return {
            "Authorization": f"{token.token_type} {token.token}",
            "User-Agent": "fal/0.1.0",
        }

    def _part_count(self) -> int:
        """Number of ``chunk_size`` parts needed to cover the whole file."""
        return math.ceil(os.path.getsize(self.file_path) / self.chunk_size)

    def create(self) -> None:
        """Initiate the multipart upload and remember its access URL and id.

        Raises:
            FileUploadException: if the initiate request fails.
        """
        token = fal_v3_token_manager.get_token()
        try:
            req = Request(
                f"{token.base_upload_url}/files/upload/multipart",
                method="POST",
                headers={
                    **self.auth_headers,
                    "Accept": "application/json",
                    "Content-Type": self.content_type,
                    "X-Fal-File-Name": os.path.basename(self.file_path),
                },
            )
            with urlopen(req) as response:
                result = json.load(response)
                self._access_url = result["access_url"]
                self._upload_id = result["uploadId"]
        except HTTPError as exc:
            raise FileUploadException(
                f"Error initiating upload. Status {exc.status}: {exc.reason}"
            ) from exc

    @retry(max_retries=5, base_delay=1, backoff_type="exponential", jitter=True)
    def _upload_part(self, url: str, part_number: int) -> dict:
        """PUT one 1-based part of the file and return its part number and ETag."""
        with open(self.file_path, "rb") as f:
            start = (part_number - 1) * self.chunk_size
            f.seek(start)
            data = f.read(self.chunk_size)
            req = Request(
                url,
                method="PUT",
                headers={
                    **self.auth_headers,
                    "Content-Type": self.content_type,
                },
                data=data,
            )

            try:
                with urlopen(req) as resp:
                    return {
                        "partNumber": part_number,
                        "etag": resp.headers["ETag"],
                    }
            except HTTPError as exc:
                raise FileUploadException(
                    f"Error uploading part {part_number} to {url}. "
                    f"Status {exc.status}: {exc.reason}"
                ) from exc

    def upload(self) -> None:
        """Upload all parts concurrently and record their ETags.

        Parts finish in arbitrary order, so the recorded list is sorted by part
        number afterwards; multipart-completion APIs (e.g. S3
        CompleteMultipartUpload) require parts in ascending order.

        Raises:
            FileUploadException: if create() was not called or a part fails.
        """
        import concurrent.futures

        parts_base_url = f"{self.access_url}/multipart/{self.upload_id}"
        with concurrent.futures.ThreadPoolExecutor(
            max_workers=self.max_concurrency
        ) as executor:
            futures = [
                executor.submit(
                    self._upload_part, f"{parts_base_url}/{part_number}", part_number
                )
                for part_number in range(1, self._part_count() + 1)
            ]

            for future in concurrent.futures.as_completed(futures):
                self._parts.append(future.result())

        self._parts.sort(key=lambda part: part["partNumber"])

    def complete(self) -> str:
        """Finalize the upload with the recorded parts and return the access URL.

        Raises:
            FileUploadException: if create() was not called or completion fails.
        """
        url = f"{self.access_url}/multipart/{self.upload_id}/complete"
        try:
            req = Request(
                url,
                method="POST",
                headers={
                    **self.auth_headers,
                    "Accept": "application/json",
                    "Content-Type": "application/json",
                },
                data=json.dumps({"parts": self._parts}).encode(),
            )
            with urlopen(req):
                pass
        except HTTPError as e:
            raise FileUploadException(
                f"Error completing upload {url}. Status {e.status}: {e.reason}"
            ) from e

        return self.access_url
444
+
445
+
90
446
@dataclass
class FalFileRepositoryV2(FalFileRepositoryBase):
    """Repository that uploads to the fal CDN using the v2 storage token flow.

    ``save`` sends an in-memory payload in a single PUT; ``save_file`` uploads
    a local file, switching to ``MultipartUpload`` for large files.
    """

    @retry(max_retries=3, base_delay=1, backoff_type="exponential", jitter=True)
    def save(
        self, file: FileData, object_lifecycle_preference: dict[str, str] | None = None
    ) -> str:
        """Upload *file* with a single PUT and return the server-reported file URL.

        ``object_lifecycle_preference`` is accepted for interface compatibility
        but is not sent by this implementation.

        Raises:
            FileUploadException: if the upload request fails with an HTTP error.
        """
        token = fal_v2_token_manager.get_token()
        headers = {
            "Authorization": f"{token.token_type} {token.token}",
            "Accept": "application/json",
            "X-Fal-File-Name": file.file_name,
            "Content-Type": file.content_type,
        }

        storage_url = f"{token.base_upload_url}/upload"

        try:
            req = Request(
                storage_url,
                data=file.data,
                headers=headers,
                method="PUT",
            )
            with urlopen(req) as response:
                result = json.load(response)

            return result["file_url"]
        except HTTPError as e:
            raise FileUploadException(
                f"Error initiating upload. Status {e.status}: {e.reason}"
            ) from e

    def _save_multipart(
        self,
        file_path: str | Path,
        chunk_size: int | None = None,
        content_type: str | None = None,
        max_concurrency: int | None = None,
    ) -> str:
        """Upload *file_path* with MultipartUpload and return the resulting file URL."""
        multipart = MultipartUpload(
            file_path,
            chunk_size=chunk_size,
            content_type=content_type,
            max_concurrency=max_concurrency,
        )
        multipart.create()
        multipart.upload()
        return multipart.complete()

    def save_file(
        self,
        file_path: str | Path,
        content_type: str,
        multipart: bool | None = None,
        multipart_threshold: int | None = None,
        multipart_chunk_size: int | None = None,
        multipart_max_concurrency: int | None = None,
        object_lifecycle_preference: dict[str, str] | None = None,
    ) -> tuple[str, FileData | None]:
        """Upload the local file at *file_path* and return ``(url, data)``.

        Args:
            file_path: path of the file to upload.
            content_type: MIME type sent with the upload.
            multipart: True forces multipart, False forces a single request,
                None decides by file size.
            multipart_threshold: size in bytes above which multipart is used
                when *multipart* is None; None means
                ``MultipartUpload.MULTIPART_THRESHOLD``. An explicit 0 is
                honoured rather than replaced by the default.
            multipart_chunk_size: part size for multipart uploads.
            multipart_max_concurrency: parallel parts for multipart uploads.
            object_lifecycle_preference: forwarded to ``save`` for
                single-request uploads only.

        Returns:
            The uploaded file URL, plus the in-memory FileData for
            single-request uploads or None for multipart uploads.
        """
        if multipart is None:
            threshold = (
                MultipartUpload.MULTIPART_THRESHOLD
                if multipart_threshold is None
                else multipart_threshold
            )
            multipart = os.path.getsize(file_path) > threshold

        if multipart:
            # NOTE(review): object_lifecycle_preference is not forwarded on the
            # multipart path; MultipartUpload has no way to send it.
            url = self._save_multipart(
                file_path,
                chunk_size=multipart_chunk_size,
                content_type=content_type,
                max_concurrency=multipart_max_concurrency,
            )
            return url, None

        # Read the file and close it before the (retried) network upload.
        with open(file_path, "rb") as f:
            content = f.read()
        data = FileData(
            content,
            content_type=content_type,
            file_name=os.path.basename(file_path),
        )
        url = self.save(data, object_lifecycle_preference)
        return url, data
94
527
 
95
528
 
96
529
  @dataclass
@@ -98,25 +531,38 @@ class InMemoryRepository(FileRepository):
98
531
  def save(
99
532
  self,
100
533
  file: FileData,
534
+ object_lifecycle_preference: dict[str, str] | None = None,
101
535
  ) -> str:
102
536
  return f'data:{file.content_type};base64,{b64encode(file.data).decode("utf-8")}'
103
537
 
104
538
 
105
539
  @dataclass
106
540
  class FalCDNFileRepository(FileRepository):
541
+ def _object_lifecycle_headers(
542
+ self,
543
+ headers: dict[str, str],
544
+ object_lifecycle_preference: dict[str, str] | None,
545
+ ):
546
+ if object_lifecycle_preference:
547
+ headers["X-Fal-Object-Lifecycle-Preference"] = json.dumps(
548
+ object_lifecycle_preference
549
+ )
550
+
551
+ @retry(max_retries=3, base_delay=1, backoff_type="exponential", jitter=True)
107
552
  def save(
108
553
  self,
109
554
  file: FileData,
555
+ object_lifecycle_preference: dict[str, str] | None = None,
110
556
  ) -> str:
111
557
  headers = {
112
558
  **self.auth_headers,
113
559
  "Accept": "application/json",
114
560
  "Content-Type": file.content_type,
115
561
  "X-Fal-File-Name": file.file_name,
116
- "X-Fal-Object-Lifecycle-Preference": json.dumps(
117
- dataclasses.asdict(GLOBAL_LIFECYCLE_PREFERENCE)
118
- ),
119
562
  }
563
+
564
+ self._object_lifecycle_headers(headers, object_lifecycle_preference)
565
+
120
566
  url = os.getenv("FAL_CDN_HOST", _FAL_CDN) + "/files/upload"
121
567
  request = Request(url, headers=headers, method="POST", data=file.data)
122
568
  try:
@@ -141,3 +587,105 @@ class FalCDNFileRepository(FileRepository):
141
587
  "Authorization": f"Bearer {key_id}:{key_secret}",
142
588
  "User-Agent": "fal/0.1.0",
143
589
  }
590
+
591
+
592
# Only available to internal users, who can obtain long-lived access tokens.
@dataclass
class InternalFalFileRepositoryV3(FileRepository):
    """File repository backed by the fal CDN v3.

    Tokens come from the shared ``fal_v3_token_manager``, which caches one token
    and refreshes it only after it expires, so uploads do not fetch a new token
    every time.
    """

    def _object_lifecycle_headers(
        self,
        headers: dict[str, str],
        object_lifecycle_preference: dict[str, str] | None,
    ):
        """Add the v3 lifecycle header to *headers* in place when a preference is set."""
        if object_lifecycle_preference:
            # NOTE(review): v3 uses "X-Fal-Object-Lifecycle" while
            # FalCDNFileRepository sends "X-Fal-Object-Lifecycle-Preference";
            # confirm this difference is intended by the v3 API.
            headers["X-Fal-Object-Lifecycle"] = json.dumps(object_lifecycle_preference)

    @retry(max_retries=3, base_delay=1, backoff_type="exponential", jitter=True)
    def save(
        self, file: FileData, object_lifecycle_preference: dict[str, str] | None = None
    ) -> str:
        """Upload *file* in a single POST and return its access URL.

        ``object_lifecycle_preference`` defaults to None, matching the other
        repositories' ``save`` signatures.

        Raises:
            FileUploadException: if the upload request fails with an HTTP error.
        """
        headers = {
            **self.auth_headers,
            "Accept": "application/json",
            "Content-Type": file.content_type,
            "X-Fal-File-Name": file.file_name,
        }

        self._object_lifecycle_headers(headers, object_lifecycle_preference)

        url = os.getenv("FAL_CDN_V3_HOST", _FAL_CDN_V3) + "/files/upload"
        request = Request(url, headers=headers, method="POST", data=file.data)
        try:
            with urlopen(request) as response:
                result = json.load(response)
        except HTTPError as e:
            raise FileUploadException(
                f"Error initiating upload. Status {e.status}: {e.reason}"
            ) from e

        return result["access_url"]

    @property
    def auth_headers(self) -> dict[str, str]:
        """Authorization headers from the current (auto-refreshed) v3 token."""
        token = fal_v3_token_manager.get_token()
        return {
            "Authorization": f"{token.token_type} {token.token}",
            "User-Agent": "fal/0.1.0",
        }

    def _save_multipart(
        self,
        file_path: str | Path,
        chunk_size: int | None = None,
        content_type: str | None = None,
        max_concurrency: int | None = None,
    ) -> str:
        """Upload *file_path* with InternalMultipartUploadV3 and return its access URL."""
        multipart = InternalMultipartUploadV3(
            file_path,
            chunk_size=chunk_size,
            content_type=content_type,
            max_concurrency=max_concurrency,
        )
        multipart.create()
        multipart.upload()
        return multipart.complete()

    def save_file(
        self,
        file_path: str | Path,
        content_type: str,
        multipart: bool | None = None,
        multipart_threshold: int | None = None,
        multipart_chunk_size: int | None = None,
        multipart_max_concurrency: int | None = None,
        object_lifecycle_preference: dict[str, str] | None = None,
    ) -> tuple[str, FileData | None]:
        """Upload the local file at *file_path* and return ``(url, data)``.

        Args:
            file_path: path of the file to upload.
            content_type: MIME type sent with the upload.
            multipart: True forces multipart, False forces a single request,
                None decides by file size.
            multipart_threshold: size in bytes above which multipart is used
                when *multipart* is None; None means the threshold of the
                uploader actually used, ``InternalMultipartUploadV3``. An
                explicit 0 is honoured rather than replaced by the default.
            multipart_chunk_size: part size for multipart uploads.
            multipart_max_concurrency: parallel parts for multipart uploads.
            object_lifecycle_preference: forwarded to ``save`` for
                single-request uploads only.

        Returns:
            The access URL, plus the in-memory FileData for single-request
            uploads or None for multipart uploads.
        """
        if multipart is None:
            threshold = (
                InternalMultipartUploadV3.MULTIPART_THRESHOLD
                if multipart_threshold is None
                else multipart_threshold
            )
            multipart = os.path.getsize(file_path) > threshold

        if multipart:
            # NOTE(review): object_lifecycle_preference is not forwarded on the
            # multipart path.
            url = self._save_multipart(
                file_path,
                chunk_size=multipart_chunk_size,
                content_type=content_type,
                max_concurrency=multipart_max_concurrency,
            )
            return url, None

        # Read the file and close it before the (retried) network upload.
        with open(file_path, "rb") as f:
            content = f.read()
        data = FileData(
            content,
            content_type=content_type,
            file_name=os.path.basename(file_path),
        )
        url = self.save(data, object_lifecycle_preference)
        return url, data
@@ -6,8 +6,10 @@ import os
6
6
  import posixpath
7
7
  import uuid
8
8
  from dataclasses import dataclass
9
+ from typing import Optional
9
10
 
10
11
  from fal.toolkit.file.types import FileData, FileRepository
12
+ from fal.toolkit.utils.retry import retry
11
13
 
12
14
  DEFAULT_URL_TIMEOUT = 60 * 15 # 15 minutes
13
15
 
@@ -50,7 +52,12 @@ class GoogleStorageRepository(FileRepository):
50
52
 
51
53
  return self._bucket
52
54
 
53
- def save(self, data: FileData) -> str:
55
+ @retry(max_retries=3, base_delay=1, backoff_type="exponential", jitter=True)
56
+ def save(
57
+ self,
58
+ data: FileData,
59
+ object_lifecycle_preference: Optional[dict[str, str]] = None,
60
+ ) -> str:
54
61
  destination_path = posixpath.join(
55
62
  self.folder,
56
63
  f"{uuid.uuid4().hex}_{data.file_name}",