fal 1.3.3__py3-none-any.whl → 1.7.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of fal might be problematic; see the registry's page for this release for more details.

@@ -1,34 +1,130 @@
1
1
  from __future__ import annotations
2
2
 
3
- import dataclasses
4
3
  import json
5
4
  import math
6
5
  import os
6
+ import threading
7
7
  from base64 import b64encode
8
8
  from dataclasses import dataclass
9
+ from datetime import datetime, timezone
9
10
  from pathlib import Path
11
+ from typing import Generic, TypeVar
10
12
  from urllib.error import HTTPError
13
+ from urllib.parse import urlparse, urlunparse
11
14
  from urllib.request import Request, urlopen
12
15
 
13
16
  from fal.auth import key_credentials
14
17
  from fal.toolkit.exceptions import FileUploadException
15
18
  from fal.toolkit.file.types import FileData, FileRepository
19
+ from fal.toolkit.utils.retry import retry
16
20
 
17
21
# Default CDN base URLs. The upload code reads the FAL_CDN_HOST and
# FAL_CDN_V3_HOST environment variables first and falls back to these.
_FAL_CDN = "https://fal.media"
_FAL_CDN_V3 = "https://v3.fal.media"
18
23
 
19
24
 
20
25
@dataclass
class FalV2Token:
    """Upload credentials issued by the fal storage auth endpoint."""

    token: str
    # Sent as the Authorization scheme: "<token_type> <token>".
    token_type: str
    base_upload_url: str
    expires_at: datetime

    def is_expired(self) -> bool:
        """Return True once the current UTC time has reached ``expires_at``."""
        now = datetime.now(timezone.utc)
        return self.expires_at <= now
24
34
 
25
- GLOBAL_LIFECYCLE_PREFERENCE = ObjectLifecyclePreference(
26
- expriation_duration_seconds=86400
27
- )
35
+
36
class FalV3Token(FalV2Token):
    """Token for the fal CDN v3 backend; same fields and behavior as FalV2Token."""
38
+
39
+
40
class FalV2TokenManager:
    """Fetch and cache the short-lived token used for fal CDN uploads.

    A token is requested from ``https://<rest host>/storage/auth/token`` on the
    first ``get_token()`` call and again whenever the cached one has expired.
    ``get_token`` holds a lock while checking and refreshing, so concurrent
    callers trigger at most one refresh and all receive the same token.
    """

    token_cls: type[FalV2Token] = FalV2Token
    storage_type: str = "fal-cdn"
    # Prefix added to the host of the returned ``base_url`` to form the upload host.
    upload_prefix = "upload."

    def __init__(self):
        # Placeholder that is already expired (expires at datetime.min), so the
        # first get_token() call always fetches a real token.
        self._token: FalV2Token = self.token_cls(
            token="",
            token_type="",
            base_upload_url="",
            expires_at=datetime.min.replace(tzinfo=timezone.utc),
        )
        self._lock: threading.Lock = threading.Lock()

    def get_token(self) -> FalV2Token:
        """Return the cached token, refreshing it first if it has expired."""
        with self._lock:
            if self._token.is_expired():
                self._refresh_token()
            return self._token

    @staticmethod
    def _parse_expires_at(value: str) -> datetime:
        """Parse the server's ``expires_at`` timestamp into an aware datetime.

        ``datetime.fromisoformat`` accepts a trailing ``Z`` only from Python
        3.11 on, and returns a naive datetime when the string has no offset.
        A naive value would make ``is_expired`` raise ``TypeError`` when it is
        compared with the aware ``datetime.now(timezone.utc)``.
        """
        if value.endswith("Z"):
            value = value[:-1] + "+00:00"
        expires_at = datetime.fromisoformat(value)
        if expires_at.tzinfo is None:
            # NOTE(review): assumes offset-less timestamps from the server are UTC.
            expires_at = expires_at.replace(tzinfo=timezone.utc)
        return expires_at

    def _refresh_token(self) -> None:
        """Request a new token for ``storage_type`` and cache it.

        Raises:
            FileUploadException: if no fal key credentials are configured.
            Errors raised by ``urlopen`` (e.g. ``HTTPError``) propagate unchanged.
        """
        key_creds = key_credentials()
        if not key_creds:
            raise FileUploadException("FAL_KEY must be set")

        key_id, key_secret = key_creds
        headers = {
            "Authorization": f"Key {key_id}:{key_secret}",
            "Accept": "application/json",
            "Content-Type": "application/json",
        }

        # The REST host is the gRPC host with its first "api" replaced by "rest".
        grpc_host = os.environ.get("FAL_HOST", "api.alpha.fal.ai")
        rest_host = grpc_host.replace("api", "rest", 1)
        url = f"https://{rest_host}/storage/auth/token?storage_type={self.storage_type}"

        req = Request(
            url,
            headers=headers,
            data=b"{}",
            method="POST",
        )
        with urlopen(req) as response:
            result = json.load(response)

        parsed_base_url = urlparse(result["base_url"])
        base_upload_url = urlunparse(
            parsed_base_url._replace(netloc=self.upload_prefix + parsed_base_url.netloc)
        )

        self._token = self.token_cls(
            token=result["token"],
            token_type=result["token_type"],
            base_upload_url=base_upload_url,
            expires_at=self._parse_expires_at(result["expires_at"]),
        )
96
+
97
+
98
class FalV3TokenManager(FalV2TokenManager):
    """Token manager for fal CDN v3 uploads.

    Requests "fal-cdn-v3" tokens and, unlike the v2 manager, uploads to the
    returned base URL's host without adding a prefix.
    """

    token_cls: type[FalV2Token] = FalV3Token
    storage_type: str = "fal-cdn-v3"
    upload_prefix = ""
102
+
103
+
104
# Shared, module-level token managers. Each keeps its token until it expires
# and only then fetches a new one (see get_token).
fal_v2_token_manager: FalV2TokenManager = FalV2TokenManager()
fal_v3_token_manager: FalV3TokenManager = FalV3TokenManager()
106
+
107
+
108
VariableType = TypeVar("VariableType")


class VariableReference(Generic[VariableType]):
    """A mutable box holding a single value.

    Lets a module-level setting be shared by reference and replaced in place
    through :meth:`set`.
    """

    value: VariableType

    def __init__(self, value: VariableType) -> None:
        # Go through set() so a subclass that customizes assignment is honored.
        self.set(value)

    def get(self) -> VariableType:
        """Return the currently stored value."""
        return self.value

    def set(self, value: VariableType) -> None:
        """Replace the stored value."""
        self.value = value


# Shared object-lifecycle preference; starts unset (None).
LIFECYCLE_PREFERENCE: VariableReference[dict[str, str] | None] = VariableReference(None)
28
123
 
29
124
 
30
125
  @dataclass
31
126
  class FalFileRepositoryBase(FileRepository):
127
+ @retry(max_retries=3, base_delay=1, backoff_type="exponential", jitter=True)
32
128
  def _save(self, file: FileData, storage_type: str) -> str:
33
129
  key_creds = key_credentials()
34
130
  if not key_creds:
@@ -63,32 +159,45 @@ class FalFileRepositoryBase(FileRepository):
63
159
  result = json.load(response)
64
160
 
65
161
  upload_url = result["upload_url"]
66
- self._upload_file(upload_url, file)
67
-
68
- return result["file_url"]
69
162
  except HTTPError as e:
70
163
  raise FileUploadException(
71
164
  f"Error initiating upload. Status {e.status}: {e.reason}"
72
165
  )
73
166
 
74
- def _upload_file(self, upload_url: str, file: FileData):
75
- req = Request(
76
- upload_url,
77
- method="PUT",
78
- data=file.data,
79
- headers={"Content-Type": file.content_type},
80
- )
167
+ try:
168
+ req = Request(
169
+ upload_url,
170
+ method="PUT",
171
+ data=file.data,
172
+ headers={"Content-Type": file.content_type},
173
+ )
174
+
175
+ with urlopen(req):
176
+ pass
81
177
 
82
- with urlopen(req):
83
- return
178
+ return result["file_url"]
179
+ except HTTPError as e:
180
+ raise FileUploadException(
181
+ f"Error uploading file. Status {e.status}: {e.reason}"
182
+ )
84
183
 
85
184
 
86
185
@dataclass
class FalFileRepository(FalFileRepositoryBase):
    """Repository that uploads through the fal storage API with storage type "gcs"."""

    def save(
        self, file: FileData, object_lifecycle_preference: dict[str, str] | None = None
    ) -> str:
        """Upload *file* and return the URL reported by the storage API.

        ``object_lifecycle_preference`` is accepted for interface compatibility
        and is not forwarded.
        """
        storage_type = "gcs"
        return self._save(file, storage_type)
90
191
 
91
192
 
193
@dataclass
class FalFileRepositoryV3(FalFileRepositoryBase):
    """Repository that uploads through the fal storage API with storage type "fal-cdn-v3"."""

    def save(
        self, file: FileData, object_lifecycle_preference: dict[str, str] | None = None
    ) -> str:
        """Upload *file* and return the URL reported by the storage API.

        ``object_lifecycle_preference`` is accepted for interface compatibility
        and is not forwarded.
        """
        storage_type = "fal-cdn-v3"
        return self._save(file, storage_type)
199
+
200
+
92
201
  class MultipartUpload:
93
202
  MULTIPART_THRESHOLD = 100 * 1024 * 1024
94
203
  MULTIPART_CHUNK_SIZE = 100 * 1024 * 1024
@@ -108,26 +217,14 @@ class MultipartUpload:
108
217
 
109
218
  self._parts: list[dict] = []
110
219
 
111
- key_creds = key_credentials()
112
- if not key_creds:
113
- raise FileUploadException("FAL_KEY must be set")
114
-
115
- key_id, key_secret = key_creds
116
-
117
- self._auth_headers = {
118
- "Authorization": f"Key {key_id}:{key_secret}",
119
- }
120
- grpc_host = os.environ.get("FAL_HOST", "api.alpha.fal.ai")
121
- rest_host = grpc_host.replace("api", "rest", 1)
122
- self._storage_upload_url = f"https://{rest_host}/storage/upload"
123
-
124
220
  def create(self):
221
+ token = fal_v2_token_manager.get_token()
125
222
  try:
126
223
  req = Request(
127
- f"{self._storage_upload_url}/initiate-multipart",
224
+ f"{token.base_upload_url}/upload/initiate-multipart",
128
225
  method="POST",
129
226
  headers={
130
- **self._auth_headers,
227
+ "Authorization": f"{token.token_type} {token.token}",
131
228
  "Accept": "application/json",
132
229
  "Content-Type": "application/json",
133
230
  },
@@ -140,7 +237,7 @@ class MultipartUpload:
140
237
  )
141
238
  with urlopen(req) as response:
142
239
  result = json.load(response)
143
- self._upload_id = result["upload_id"]
240
+ self._upload_url = result["upload_url"]
144
241
  self._file_url = result["file_url"]
145
242
  except HTTPError as exc:
146
243
  raise FileUploadException(
@@ -180,10 +277,7 @@ class MultipartUpload:
180
277
  ) as executor:
181
278
  futures = []
182
279
  for part_number in range(1, parts + 1):
183
- upload_url = (
184
- f"{self._file_url}?upload_id={self._upload_id}"
185
- f"&part_number={part_number}"
186
- )
280
+ upload_url = f"{self._upload_url}&part_number={part_number}"
187
281
  futures.append(
188
282
  executor.submit(self._upload_part, upload_url, part_number)
189
283
  )
@@ -193,7 +287,7 @@ class MultipartUpload:
193
287
  self._parts.append(entry)
194
288
 
195
289
  def complete(self):
196
- url = f"{self._file_url}?upload_id={self._upload_id}"
290
+ url = self._upload_url
197
291
  try:
198
292
  req = Request(
199
293
  url,
@@ -214,10 +308,172 @@ class MultipartUpload:
214
308
  return self._file_url
215
309
 
216
310
 
311
class InternalMultipartUploadV3:
    """Upload a local file to the fal CDN v3 in concurrent chunks.

    Call ``create()`` to open the multipart session, ``upload()`` to PUT every
    chunk, then ``complete()`` to finalize the object and get its access URL.
    Requests are authenticated with the shared v3 token manager.
    """

    MULTIPART_THRESHOLD = 100 * 1024 * 1024
    MULTIPART_CHUNK_SIZE = 10 * 1024 * 1024
    MULTIPART_MAX_CONCURRENCY = 10

    def __init__(
        self,
        file_path: str | Path,
        chunk_size: int | None = None,
        content_type: str | None = None,
        max_concurrency: int | None = None,
    ) -> None:
        self.file_path = file_path
        self.chunk_size = chunk_size or self.MULTIPART_CHUNK_SIZE
        self.content_type = content_type or "application/octet-stream"
        self.max_concurrency = max_concurrency or self.MULTIPART_MAX_CONCURRENCY
        # Both are populated by create().
        self._access_url: str | None = None
        self._upload_id: str | None = None

        # {"partNumber": ..., "etag": ...} entries collected by upload().
        self._parts: list[dict] = []

    @property
    def access_url(self) -> str:
        """URL of the object being uploaded.

        Raises:
            FileUploadException: if ``create()`` has not been called.
        """
        if not self._access_url:
            raise FileUploadException("Upload not initiated")
        return self._access_url

    @property
    def upload_id(self) -> str:
        """Server-assigned multipart upload id.

        Raises:
            FileUploadException: if ``create()`` has not been called.
        """
        if not self._upload_id:
            raise FileUploadException("Upload not initiated")
        return self._upload_id

    @property
    def auth_headers(self) -> dict[str, str]:
        """Authorization and User-Agent headers built from the current v3 token."""
        token = fal_v3_token_manager.get_token()
        return {
            "Authorization": f"{token.token_type} {token.token}",
            "User-Agent": "fal/0.1.0",
        }

    def create(self):
        """Open a multipart upload session and record its access URL and id.

        Raises:
            FileUploadException: if the initiation request fails with an HTTP error.
        """
        token = fal_v3_token_manager.get_token()
        try:
            req = Request(
                f"{token.base_upload_url}/files/upload/multipart",
                method="POST",
                headers={
                    **self.auth_headers,
                    "Accept": "application/json",
                    "Content-Type": self.content_type,
                    "X-Fal-File-Name": os.path.basename(self.file_path),
                },
            )
            with urlopen(req) as response:
                result = json.load(response)
                self._access_url = result["access_url"]
                self._upload_id = result["uploadId"]

        except HTTPError as exc:
            raise FileUploadException(
                f"Error initiating upload. Status {exc.status}: {exc.reason}"
            ) from exc

    @retry(max_retries=5, base_delay=1, backoff_type="exponential", jitter=True)
    def _upload_part(self, url: str, part_number: int) -> dict:
        """PUT chunk ``part_number`` (1-based) of the file to *url*.

        Returns:
            ``{"partNumber": part_number, "etag": <ETag response header>}``.

        Raises:
            FileUploadException: if the server responds with an HTTP error.
        """
        # Each call opens its own handle, so parts can be read from the
        # executor's worker threads at the same time.
        with open(self.file_path, "rb") as f:
            start = (part_number - 1) * self.chunk_size
            f.seek(start)
            data = f.read(self.chunk_size)
        req = Request(
            url,
            method="PUT",
            headers={
                **self.auth_headers,
                "Content-Type": self.content_type,
            },
            data=data,
        )

        try:
            with urlopen(req) as resp:
                return {
                    "partNumber": part_number,
                    "etag": resp.headers["ETag"],
                }
        except HTTPError as exc:
            raise FileUploadException(
                f"Error uploading part {part_number} to {url}. "
                f"Status {exc.status}: {exc.reason}"
            ) from exc

    def upload(self) -> None:
        """Upload all chunks concurrently, recording each part's ETag.

        Parts are recorded in completion order; ``complete()`` sorts them.
        """
        import concurrent.futures

        parts = math.ceil(os.path.getsize(self.file_path) / self.chunk_size)
        with concurrent.futures.ThreadPoolExecutor(
            max_workers=self.max_concurrency
        ) as executor:
            futures = [
                executor.submit(
                    self._upload_part,
                    f"{self.access_url}/multipart/{self.upload_id}/{part_number}",
                    part_number,
                )
                for part_number in range(1, parts + 1)
            ]

            for future in concurrent.futures.as_completed(futures):
                self._parts.append(future.result())

    def complete(self) -> str:
        """Finalize the upload and return the object's access URL.

        Raises:
            FileUploadException: if the completion request fails with an HTTP error.
        """
        url = f"{self.access_url}/multipart/{self.upload_id}/complete"
        # upload() appends parts as they finish, i.e. in arbitrary order.
        # Multipart-completion APIs such as S3's CompleteMultipartUpload require
        # ascending part numbers. NOTE(review): assumes the v3 backend has the
        # same rule; sorting is harmless if it does not.
        parts = sorted(self._parts, key=lambda part: part["partNumber"])
        try:
            req = Request(
                url,
                method="POST",
                headers={
                    **self.auth_headers,
                    "Accept": "application/json",
                    "Content-Type": "application/json",
                },
                data=json.dumps({"parts": parts}).encode(),
            )
            with urlopen(req):
                pass
        except HTTPError as e:
            raise FileUploadException(
                f"Error completing upload {url}. Status {e.status}: {e.reason}"
            ) from e

        return self.access_url
444
+
445
+
217
446
  @dataclass
218
447
  class FalFileRepositoryV2(FalFileRepositoryBase):
219
- def save(self, file: FileData) -> str:
220
- return self._save(file, "fal-cdn")
448
@retry(max_retries=3, base_delay=1, backoff_type="exponential", jitter=True)
def save(
    self, file: FileData, object_lifecycle_preference: dict[str, str] | None = None
) -> str:
    """Upload *file* to the fal CDN (v2) with a single PUT and return its URL.

    ``object_lifecycle_preference`` is accepted for interface compatibility and
    is not sent with the request.

    Raises:
        FileUploadException: if the upload request fails with an HTTP error.
    """
    token = fal_v2_token_manager.get_token()
    request = Request(
        f"{token.base_upload_url}/upload",
        data=file.data,
        headers={
            "Authorization": f"{token.token_type} {token.token}",
            "Accept": "application/json",
            "X-Fal-File-Name": file.file_name,
            "Content-Type": file.content_type,
        },
        method="PUT",
    )
    try:
        with urlopen(request) as response:
            payload = json.load(response)
    except HTTPError as e:
        raise FileUploadException(
            f"Error initiating upload. Status {e.status}: {e.reason}"
        )
    return payload["file_url"]
221
477
 
222
478
  def _save_multipart(
223
479
  self,
@@ -244,6 +500,7 @@ class FalFileRepositoryV2(FalFileRepositoryBase):
244
500
  multipart_threshold: int | None = None,
245
501
  multipart_chunk_size: int | None = None,
246
502
  multipart_max_concurrency: int | None = None,
503
+ object_lifecycle_preference: dict[str, str] | None = None,
247
504
  ) -> tuple[str, FileData | None]:
248
505
  if multipart is None:
249
506
  threshold = multipart_threshold or MultipartUpload.MULTIPART_THRESHOLD
@@ -264,7 +521,7 @@ class FalFileRepositoryV2(FalFileRepositoryBase):
264
521
  content_type=content_type,
265
522
  file_name=os.path.basename(file_path),
266
523
  )
267
- url = self.save(data)
524
+ url = self.save(data, object_lifecycle_preference)
268
525
 
269
526
  return url, data
270
527
 
@@ -274,25 +531,38 @@ class InMemoryRepository(FileRepository):
274
531
def save(
    self,
    file: FileData,
    object_lifecycle_preference: dict[str, str] | None = None,
) -> str:
    """Return *file* inlined as a base64 ``data:`` URL instead of uploading it.

    ``object_lifecycle_preference`` is accepted for interface compatibility and
    ignored.
    """
    mime = file.content_type
    payload = b64encode(file.data).decode("utf-8")
    return f"data:{mime};base64,{payload}"
279
537
 
280
538
 
281
539
  @dataclass
282
540
  class FalCDNFileRepository(FileRepository):
541
def _object_lifecycle_headers(
    self,
    headers: dict[str, str],
    object_lifecycle_preference: dict[str, str] | None,
):
    """Add the JSON-encoded lifecycle preference to *headers* in place.

    Nothing is added when the preference is falsy (None or an empty dict).
    """
    if not object_lifecycle_preference:
        return
    encoded = json.dumps(object_lifecycle_preference)
    headers["X-Fal-Object-Lifecycle-Preference"] = encoded
550
+
551
+ @retry(max_retries=3, base_delay=1, backoff_type="exponential", jitter=True)
283
552
  def save(
284
553
  self,
285
554
  file: FileData,
555
+ object_lifecycle_preference: dict[str, str] | None = None,
286
556
  ) -> str:
287
557
  headers = {
288
558
  **self.auth_headers,
289
559
  "Accept": "application/json",
290
560
  "Content-Type": file.content_type,
291
561
  "X-Fal-File-Name": file.file_name,
292
- "X-Fal-Object-Lifecycle-Preference": json.dumps(
293
- dataclasses.asdict(GLOBAL_LIFECYCLE_PREFERENCE)
294
- ),
295
562
  }
563
+
564
+ self._object_lifecycle_headers(headers, object_lifecycle_preference)
565
+
296
566
  url = os.getenv("FAL_CDN_HOST", _FAL_CDN) + "/files/upload"
297
567
  request = Request(url, headers=headers, method="POST", data=file.data)
298
568
  try:
@@ -317,3 +587,105 @@ class FalCDNFileRepository(FileRepository):
317
587
  "Authorization": f"Bearer {key_id}:{key_secret}",
318
588
  "User-Agent": "fal/0.1.0",
319
589
  }
590
+
591
+
592
# Only available to internal users, who have long-lived access tokens.
@dataclass
class InternalFalFileRepositoryV3(FileRepository):
    """File repository that uploads to the fal CDN v3.

    Uses long-lived access tokens, cached by the shared v3 token manager, so a
    token does not have to be fetched for every upload.
    """

    def _object_lifecycle_headers(
        self,
        headers: dict[str, str],
        object_lifecycle_preference: dict[str, str] | None,
    ):
        """Add the JSON-encoded lifecycle preference to *headers* in place, if given."""
        # NOTE(review): FalCDNFileRepository sends "X-Fal-Object-Lifecycle-Preference";
        # confirm the v3 endpoint really expects this shorter header name.
        if object_lifecycle_preference:
            headers["X-Fal-Object-Lifecycle"] = json.dumps(object_lifecycle_preference)

    @retry(max_retries=3, base_delay=1, backoff_type="exponential", jitter=True)
    def save(
        self,
        file: FileData,
        object_lifecycle_preference: dict[str, str] | None = None,
    ) -> str:
        """Upload *file* with a single POST and return its access URL.

        Args:
            file: Data, content type and file name to upload.
            object_lifecycle_preference: Optional lifecycle settings, sent
                JSON-encoded in a request header. Defaults to None (no header),
                matching the other repositories' ``save`` signatures so that
                ``save(file)`` works.

        Raises:
            FileUploadException: if the upload request fails with an HTTP error.
        """
        headers = {
            **self.auth_headers,
            "Accept": "application/json",
            "Content-Type": file.content_type,
            "X-Fal-File-Name": file.file_name,
        }

        self._object_lifecycle_headers(headers, object_lifecycle_preference)

        url = os.getenv("FAL_CDN_V3_HOST", _FAL_CDN_V3) + "/files/upload"
        request = Request(url, headers=headers, method="POST", data=file.data)
        try:
            with urlopen(request) as response:
                result = json.load(response)
        except HTTPError as e:
            raise FileUploadException(
                f"Error initiating upload. Status {e.status}: {e.reason}"
            ) from e

        return result["access_url"]

    @property
    def auth_headers(self) -> dict[str, str]:
        """Authorization and User-Agent headers built from the current v3 token."""
        token = fal_v3_token_manager.get_token()
        return {
            "Authorization": f"{token.token_type} {token.token}",
            "User-Agent": "fal/0.1.0",
        }

    def _save_multipart(
        self,
        file_path: str | Path,
        chunk_size: int | None = None,
        content_type: str | None = None,
        max_concurrency: int | None = None,
    ) -> str:
        """Upload *file_path* with InternalMultipartUploadV3 and return its access URL."""
        multipart = InternalMultipartUploadV3(
            file_path,
            chunk_size=chunk_size,
            content_type=content_type,
            max_concurrency=max_concurrency,
        )
        multipart.create()
        multipart.upload()
        return multipart.complete()

    def save_file(
        self,
        file_path: str | Path,
        content_type: str,
        multipart: bool | None = None,
        multipart_threshold: int | None = None,
        multipart_chunk_size: int | None = None,
        multipart_max_concurrency: int | None = None,
        object_lifecycle_preference: dict[str, str] | None = None,
    ) -> tuple[str, FileData | None]:
        """Upload a local file, using a multipart upload for large files.

        When ``multipart`` is None, multipart is chosen if the file is larger
        than ``multipart_threshold`` (default:
        ``InternalMultipartUploadV3.MULTIPART_THRESHOLD``).

        Returns:
            ``(url, data)``, where ``data`` is the in-memory FileData for a
            single-request upload, or None for a multipart upload.
        """
        if multipart is None:
            # Use the threshold of the uploader this repository actually uses.
            threshold = (
                multipart_threshold or InternalMultipartUploadV3.MULTIPART_THRESHOLD
            )
            multipart = os.path.getsize(file_path) > threshold

        if multipart:
            # NOTE: object_lifecycle_preference is not applied on this path;
            # InternalMultipartUploadV3 has no parameter to carry it.
            url = self._save_multipart(
                file_path,
                chunk_size=multipart_chunk_size,
                content_type=content_type,
                max_concurrency=multipart_max_concurrency,
            )
            data = None
        else:
            with open(file_path, "rb") as f:
                data = FileData(
                    f.read(),
                    content_type=content_type,
                    file_name=os.path.basename(file_path),
                )
            url = self.save(data, object_lifecycle_preference)

        return url, data
@@ -6,8 +6,10 @@ import os
6
6
  import posixpath
7
7
  import uuid
8
8
  from dataclasses import dataclass
9
+ from typing import Optional
9
10
 
10
11
  from fal.toolkit.file.types import FileData, FileRepository
12
+ from fal.toolkit.utils.retry import retry
11
13
 
12
14
  DEFAULT_URL_TIMEOUT = 60 * 15 # 15 minutes
13
15
 
@@ -50,7 +52,12 @@ class GoogleStorageRepository(FileRepository):
50
52
 
51
53
  return self._bucket
52
54
 
53
- def save(self, data: FileData) -> str:
55
+ @retry(max_retries=3, base_delay=1, backoff_type="exponential", jitter=True)
56
+ def save(
57
+ self,
58
+ data: FileData,
59
+ object_lifecycle_preference: Optional[dict[str, str]] = None,
60
+ ) -> str:
54
61
  destination_path = posixpath.join(
55
62
  self.folder,
56
63
  f"{uuid.uuid4().hex}_{data.file_name}",
@@ -6,8 +6,10 @@ import posixpath
6
6
  import uuid
7
7
  from dataclasses import dataclass
8
8
  from io import BytesIO
9
+ from typing import Optional
9
10
 
10
11
  from fal.toolkit.file.types import FileData, FileRepository
12
+ from fal.toolkit.utils.retry import retry
11
13
 
12
14
  DEFAULT_URL_TIMEOUT = 60 * 15 # 15 minutes
13
15
 
@@ -67,7 +69,12 @@ class R2Repository(FileRepository):
67
69
 
68
70
  return self._bucket
69
71
 
70
- def save(self, data: FileData) -> str:
72
+ @retry(max_retries=3, base_delay=1, backoff_type="exponential", jitter=True)
73
+ def save(
74
+ self,
75
+ data: FileData,
76
+ object_lifecycle_preference: Optional[dict[str, str]] = None,
77
+ ) -> str:
71
78
  destination_path = posixpath.join(
72
79
  self.key,
73
80
  f"{uuid.uuid4().hex}_{data.file_name}",