fal 1.3.3 (py3-none-any wheel) → 1.7.3 (py3-none-any wheel)

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.

Potentially problematic release.


This version of fal might be problematic; further details were linked from the original page.

@@ -1,34 +1,130 @@
1
1
  from __future__ import annotations
2
2
 
3
- import dataclasses
4
3
  import json
5
4
  import math
6
5
  import os
6
+ import threading
7
7
  from base64 import b64encode
8
8
  from dataclasses import dataclass
9
+ from datetime import datetime, timezone
9
10
  from pathlib import Path
11
+ from typing import Generic, TypeVar
10
12
  from urllib.error import HTTPError
13
+ from urllib.parse import urlparse, urlunparse
11
14
  from urllib.request import Request, urlopen
12
15
 
13
16
  from fal.auth import key_credentials
14
17
  from fal.toolkit.exceptions import FileUploadException
15
18
  from fal.toolkit.file.types import FileData, FileRepository
19
+ from fal.toolkit.utils.retry import retry
16
20
 
17
21
  _FAL_CDN = "https://fal.media"
22
+ _FAL_CDN_V3 = "https://v3.fal.media"
18
23
 
19
24
 
20
25
@dataclass
class FalV2Token:
    """Storage credential issued by the fal storage auth endpoint.

    Attributes are documented as comments below.
    """

    # Opaque token value sent after ``token_type`` in the Authorization header.
    token: str
    # Scheme placed before the token in the Authorization header.
    token_type: str
    # Base URL that uploads authenticated with this token are sent to.
    base_upload_url: str
    # Moment after which the token must no longer be used.
    expires_at: datetime

    def is_expired(self) -> bool:
        """Return True once the current UTC time has reached ``expires_at``.

        A naive ``expires_at`` (no tzinfo) is interpreted as UTC; comparing an
        aware "now" against a naive datetime would otherwise raise TypeError.
        """
        expires_at = self.expires_at
        if expires_at.tzinfo is None:
            expires_at = expires_at.replace(tzinfo=timezone.utc)
        return datetime.now(timezone.utc) >= expires_at
24
34
 
25
- GLOBAL_LIFECYCLE_PREFERENCE = ObjectLifecyclePreference(
26
- expriation_duration_seconds=86400
27
- )
35
+
36
class FalV3Token(FalV2Token):
    """Token for the ``fal-cdn-v3`` storage type.

    Carries exactly the same fields as FalV2Token; it is the token class
    produced by FalV3TokenManager.
    """
38
+
39
+
40
+ class FalV2TokenManager:
41
+ token_cls: type[FalV2Token] = FalV2Token
42
+ storage_type: str = "fal-cdn"
43
+ upload_prefix = "upload."
44
+
45
+ def __init__(self):
46
+ self._token: FalV2Token = self.token_cls(
47
+ token="",
48
+ token_type="",
49
+ base_upload_url="",
50
+ expires_at=datetime.min.replace(tzinfo=timezone.utc),
51
+ )
52
+ self._lock: threading.Lock = threading.Lock()
53
+
54
+ def get_token(self) -> FalV2Token:
55
+ with self._lock:
56
+ if self._token.is_expired():
57
+ self._refresh_token()
58
+ return self._token
59
+
60
+ def _refresh_token(self) -> None:
61
+ key_creds = key_credentials()
62
+ if not key_creds:
63
+ raise FileUploadException("FAL_KEY must be set")
64
+
65
+ key_id, key_secret = key_creds
66
+ headers = {
67
+ "Authorization": f"Key {key_id}:{key_secret}",
68
+ "Accept": "application/json",
69
+ "Content-Type": "application/json",
70
+ }
71
+
72
+ grpc_host = os.environ.get("FAL_HOST", "api.alpha.fal.ai")
73
+ rest_host = grpc_host.replace("api", "rest", 1)
74
+ url = f"https://{rest_host}/storage/auth/token?storage_type={self.storage_type}"
75
+
76
+ req = Request(
77
+ url,
78
+ headers=headers,
79
+ data=b"{}",
80
+ method="POST",
81
+ )
82
+ with urlopen(req) as response:
83
+ result = json.load(response)
84
+
85
+ parsed_base_url = urlparse(result["base_url"])
86
+ base_upload_url = urlunparse(
87
+ parsed_base_url._replace(netloc=self.upload_prefix + parsed_base_url.netloc)
88
+ )
89
+
90
+ self._token = self.token_cls(
91
+ token=result["token"],
92
+ token_type=result["token_type"],
93
+ base_upload_url=base_upload_url,
94
+ expires_at=datetime.fromisoformat(result["expires_at"]),
95
+ )
96
+
97
+
98
class FalV3TokenManager(FalV2TokenManager):
    """Token manager for the ``fal-cdn-v3`` storage type.

    Produces FalV3Token instances and, with an empty ``upload_prefix``, uses
    the returned base URL unchanged as the upload URL.
    """

    token_cls: type[FalV2Token] = FalV3Token
    storage_type: str = "fal-cdn-v3"
    upload_prefix = ""
102
+
103
+
104
# Module-level token caches shared by the upload helpers in this module; each
# refreshes its token lazily on first use and after expiry.
fal_v2_token_manager = FalV2TokenManager()
fal_v3_token_manager = FalV3TokenManager()
106
+
107
+
108
VariableType = TypeVar("VariableType")


class VariableReference(Generic[VariableType]):
    """Mutable holder for a single value that can be read and replaced."""

    def __init__(self, value: VariableType) -> None:
        # Go through set() so a subclass overriding it also sees the
        # initial value.
        self.set(value)

    def get(self) -> VariableType:
        """Return the value currently held."""
        return self.value

    def set(self, value: VariableType) -> None:
        """Replace the value currently held."""
        self.value = value


# Module-level object lifecycle preference; starts unset (None).
LIFECYCLE_PREFERENCE: VariableReference[dict[str, str] | None] = VariableReference(
    None
)
28
123
 
29
124
 
30
125
  @dataclass
31
126
  class FalFileRepositoryBase(FileRepository):
127
+ @retry(max_retries=3, base_delay=1, backoff_type="exponential", jitter=True)
32
128
  def _save(self, file: FileData, storage_type: str) -> str:
33
129
  key_creds = key_credentials()
34
130
  if not key_creds:
@@ -63,32 +159,45 @@ class FalFileRepositoryBase(FileRepository):
63
159
  result = json.load(response)
64
160
 
65
161
  upload_url = result["upload_url"]
66
- self._upload_file(upload_url, file)
67
-
68
- return result["file_url"]
69
162
  except HTTPError as e:
70
163
  raise FileUploadException(
71
164
  f"Error initiating upload. Status {e.status}: {e.reason}"
72
165
  )
73
166
 
74
- def _upload_file(self, upload_url: str, file: FileData):
75
- req = Request(
76
- upload_url,
77
- method="PUT",
78
- data=file.data,
79
- headers={"Content-Type": file.content_type},
80
- )
167
+ try:
168
+ req = Request(
169
+ upload_url,
170
+ method="PUT",
171
+ data=file.data,
172
+ headers={"Content-Type": file.content_type},
173
+ )
174
+
175
+ with urlopen(req):
176
+ pass
81
177
 
82
- with urlopen(req):
83
- return
178
+ return result["file_url"]
179
+ except HTTPError as e:
180
+ raise FileUploadException(
181
+ f"Error uploading file. Status {e.status}: {e.reason}"
182
+ )
84
183
 
85
184
 
86
185
@dataclass
class FalFileRepository(FalFileRepositoryBase):
    """Repository that uploads through the base ``_save`` with storage type ``gcs``."""

    def save(
        self, file: FileData, object_lifecycle_preference: dict[str, str] | None = None
    ) -> str:
        # object_lifecycle_preference is accepted for interface compatibility
        # only; _save() has no parameter to forward it to.
        storage_type = "gcs"
        return self._save(file, storage_type)
90
191
 
91
192
 
193
@dataclass
class FalFileRepositoryV3(FalFileRepositoryBase):
    """Repository that uploads through the base ``_save`` with storage type ``fal-cdn-v3``."""

    def save(
        self, file: FileData, object_lifecycle_preference: dict[str, str] | None = None
    ) -> str:
        # object_lifecycle_preference is accepted for interface compatibility
        # only; _save() has no parameter to forward it to.
        storage_type = "fal-cdn-v3"
        return self._save(file, storage_type)
199
+
200
+
92
201
  class MultipartUpload:
93
202
  MULTIPART_THRESHOLD = 100 * 1024 * 1024
94
203
  MULTIPART_CHUNK_SIZE = 100 * 1024 * 1024
@@ -96,109 +205,267 @@ class MultipartUpload:
96
205
 
97
206
  def __init__(
98
207
  self,
99
- file_path: str | Path,
208
+ file_name: str,
100
209
  chunk_size: int | None = None,
101
210
  content_type: str | None = None,
102
211
  max_concurrency: int | None = None,
103
212
  ) -> None:
104
- self.file_path = file_path
213
+ self.file_name = file_name
105
214
  self.chunk_size = chunk_size or self.MULTIPART_CHUNK_SIZE
106
215
  self.content_type = content_type or "application/octet-stream"
107
216
  self.max_concurrency = max_concurrency or self.MULTIPART_MAX_CONCURRENCY
108
217
 
109
218
  self._parts: list[dict] = []
110
219
 
111
- key_creds = key_credentials()
112
- if not key_creds:
113
- raise FileUploadException("FAL_KEY must be set")
114
-
115
- key_id, key_secret = key_creds
116
-
117
- self._auth_headers = {
118
- "Authorization": f"Key {key_id}:{key_secret}",
119
- }
120
- grpc_host = os.environ.get("FAL_HOST", "api.alpha.fal.ai")
121
- rest_host = grpc_host.replace("api", "rest", 1)
122
- self._storage_upload_url = f"https://{rest_host}/storage/upload"
123
-
124
220
  def create(self):
221
+ token = fal_v2_token_manager.get_token()
125
222
  try:
126
223
  req = Request(
127
- f"{self._storage_upload_url}/initiate-multipart",
224
+ f"{token.base_upload_url}/upload/initiate-multipart",
128
225
  method="POST",
129
226
  headers={
130
- **self._auth_headers,
227
+ "Authorization": f"{token.token_type} {token.token}",
131
228
  "Accept": "application/json",
132
229
  "Content-Type": "application/json",
133
230
  },
134
231
  data=json.dumps(
135
232
  {
136
- "file_name": os.path.basename(self.file_path),
233
+ "file_name": self.file_name,
137
234
  "content_type": self.content_type,
138
235
  }
139
236
  ).encode(),
140
237
  )
141
238
  with urlopen(req) as response:
142
239
  result = json.load(response)
143
- self._upload_id = result["upload_id"]
240
+ self._upload_url = result["upload_url"]
144
241
  self._file_url = result["file_url"]
145
242
  except HTTPError as exc:
146
243
  raise FileUploadException(
147
244
  f"Error initiating upload. Status {exc.status}: {exc.reason}"
148
245
  )
149
246
 
150
- def _upload_part(self, url: str, part_number: int) -> dict:
151
- with open(self.file_path, "rb") as f:
152
- start = (part_number - 1) * self.chunk_size
153
- f.seek(start)
154
- data = f.read(self.chunk_size)
155
- req = Request(
156
- url,
157
- method="PUT",
158
- headers={"Content-Type": self.content_type},
159
- data=data,
160
- )
247
+ def upload_part(self, part_number: int, data: bytes) -> None:
248
+ url = f"{self._upload_url}&part_number={part_number}"
249
+
250
+ req = Request(
251
+ url,
252
+ method="PUT",
253
+ headers={"Content-Type": self.content_type},
254
+ data=data,
255
+ )
161
256
 
162
- try:
163
- with urlopen(req) as resp:
164
- return {
257
+ try:
258
+ with urlopen(req) as resp:
259
+ self._parts.append(
260
+ {
165
261
  "part_number": part_number,
166
262
  "etag": resp.headers["ETag"],
167
263
  }
168
- except HTTPError as exc:
169
- raise FileUploadException(
170
- f"Error uploading part {part_number} to {url}. "
171
- f"Status {exc.status}: {exc.reason}"
172
264
  )
265
+ except HTTPError as exc:
266
+ raise FileUploadException(
267
+ f"Error uploading part {part_number} to {url}. "
268
+ f"Status {exc.status}: {exc.reason}"
269
+ )
173
270
 
174
- def upload(self) -> None:
271
+ def complete(self):
272
+ url = self._upload_url
273
+ try:
274
+ req = Request(
275
+ url,
276
+ method="POST",
277
+ headers={
278
+ "Accept": "application/json",
279
+ "Content-Type": "application/json",
280
+ },
281
+ data=json.dumps({"parts": self._parts}).encode(),
282
+ )
283
+ with urlopen(req):
284
+ pass
285
+ except HTTPError as e:
286
+ raise FileUploadException(
287
+ f"Error completing upload {url}. Status {e.status}: {e.reason}"
288
+ )
289
+
290
+ return self._file_url
291
+
292
+ @classmethod
293
+ def save(
294
+ cls,
295
+ file: FileData,
296
+ chunk_size: int | None = None,
297
+ max_concurrency: int | None = None,
298
+ ):
175
299
  import concurrent.futures
176
300
 
177
- parts = math.ceil(os.path.getsize(self.file_path) / self.chunk_size)
301
+ multipart = cls(
302
+ file.file_name,
303
+ chunk_size=chunk_size,
304
+ content_type=file.content_type,
305
+ max_concurrency=max_concurrency,
306
+ )
307
+ multipart.create()
308
+
309
+ parts = math.ceil(len(file.data) / multipart.chunk_size)
178
310
  with concurrent.futures.ThreadPoolExecutor(
179
- max_workers=self.max_concurrency
311
+ max_workers=multipart.max_concurrency
180
312
  ) as executor:
181
313
  futures = []
182
314
  for part_number in range(1, parts + 1):
183
- upload_url = (
184
- f"{self._file_url}?upload_id={self._upload_id}"
185
- f"&part_number={part_number}"
186
- )
315
+ start = (part_number - 1) * multipart.chunk_size
316
+ data = file.data[start : start + multipart.chunk_size]
187
317
  futures.append(
188
- executor.submit(self._upload_part, upload_url, part_number)
318
+ executor.submit(multipart.upload_part, part_number, data)
189
319
  )
190
320
 
191
321
  for future in concurrent.futures.as_completed(futures):
192
- entry = future.result()
193
- self._parts.append(entry)
322
+ future.result()
323
+
324
+ return multipart.complete()
325
+
326
+ @classmethod
327
+ def save_file(
328
+ cls,
329
+ file_path: str | Path,
330
+ chunk_size: int | None = None,
331
+ content_type: str | None = None,
332
+ max_concurrency: int | None = None,
333
+ ) -> str:
334
+ import concurrent.futures
335
+
336
+ file_name = os.path.basename(file_path)
337
+ size = os.path.getsize(file_path)
338
+
339
+ multipart = cls(
340
+ file_name,
341
+ chunk_size=chunk_size,
342
+ content_type=content_type,
343
+ max_concurrency=max_concurrency,
344
+ )
345
+ multipart.create()
346
+
347
+ parts = math.ceil(size / multipart.chunk_size)
348
+ with concurrent.futures.ThreadPoolExecutor(
349
+ max_workers=multipart.max_concurrency
350
+ ) as executor:
351
+ futures = []
352
+ for part_number in range(1, parts + 1):
353
+
354
+ def _upload_part(pn: int) -> None:
355
+ with open(file_path, "rb") as f:
356
+ start = (pn - 1) * multipart.chunk_size
357
+ f.seek(start)
358
+ data = f.read(multipart.chunk_size)
359
+ multipart.upload_part(pn, data)
360
+
361
+ futures.append(executor.submit(_upload_part, part_number))
362
+
363
+ for future in concurrent.futures.as_completed(futures):
364
+ future.result()
365
+
366
+ return multipart.complete()
367
+
368
+
369
+ class InternalMultipartUploadV3:
370
+ MULTIPART_THRESHOLD = 100 * 1024 * 1024
371
+ MULTIPART_CHUNK_SIZE = 10 * 1024 * 1024
372
+ MULTIPART_MAX_CONCURRENCY = 10
373
+
374
+ def __init__(
375
+ self,
376
+ file_name: str,
377
+ chunk_size: int | None = None,
378
+ content_type: str | None = None,
379
+ max_concurrency: int | None = None,
380
+ ) -> None:
381
+ self.file_name = file_name
382
+ self.chunk_size = chunk_size or self.MULTIPART_CHUNK_SIZE
383
+ self.content_type = content_type or "application/octet-stream"
384
+ self.max_concurrency = max_concurrency or self.MULTIPART_MAX_CONCURRENCY
385
+ self._access_url: str | None = None
386
+ self._upload_id: str | None = None
387
+
388
+ self._parts: list[dict] = []
389
+
390
+ @property
391
+ def access_url(self) -> str:
392
+ if not self._access_url:
393
+ raise FileUploadException("Upload not initiated")
394
+ return self._access_url
395
+
396
+ @property
397
+ def upload_id(self) -> str:
398
+ if not self._upload_id:
399
+ raise FileUploadException("Upload not initiated")
400
+ return self._upload_id
401
+
402
+ @property
403
+ def auth_headers(self) -> dict[str, str]:
404
+ token = fal_v3_token_manager.get_token()
405
+ return {
406
+ "Authorization": f"{token.token_type} {token.token}",
407
+ "User-Agent": "fal/0.1.0",
408
+ }
409
+
410
+ def create(self):
411
+ token = fal_v3_token_manager.get_token()
412
+ try:
413
+ req = Request(
414
+ f"{token.base_upload_url}/files/upload/multipart",
415
+ method="POST",
416
+ headers={
417
+ **self.auth_headers,
418
+ "Accept": "application/json",
419
+ "Content-Type": self.content_type,
420
+ "X-Fal-File-Name": self.file_name,
421
+ },
422
+ )
423
+ with urlopen(req) as response:
424
+ result = json.load(response)
425
+ self._access_url = result["access_url"]
426
+ self._upload_id = result["uploadId"]
427
+
428
+ except HTTPError as exc:
429
+ raise FileUploadException(
430
+ f"Error initiating upload. Status {exc.status}: {exc.reason}"
431
+ )
432
+
433
+ @retry(max_retries=5, base_delay=1, backoff_type="exponential", jitter=True)
434
+ def upload_part(self, part_number: int, data: bytes) -> None:
435
+ url = f"{self.access_url}/multipart/{self.upload_id}/{part_number}"
436
+
437
+ req = Request(
438
+ url,
439
+ method="PUT",
440
+ headers={
441
+ **self.auth_headers,
442
+ "Content-Type": self.content_type,
443
+ },
444
+ data=data,
445
+ )
194
446
 
195
- def complete(self):
196
- url = f"{self._file_url}?upload_id={self._upload_id}"
447
+ try:
448
+ with urlopen(req) as resp:
449
+ self._parts.append(
450
+ {
451
+ "partNumber": part_number,
452
+ "etag": resp.headers["ETag"],
453
+ }
454
+ )
455
+ except HTTPError as exc:
456
+ raise FileUploadException(
457
+ f"Error uploading part {part_number} to {url}. "
458
+ f"Status {exc.status}: {exc.reason}"
459
+ )
460
+
461
+ def complete(self) -> str:
462
+ url = f"{self.access_url}/multipart/{self.upload_id}/complete"
197
463
  try:
198
464
  req = Request(
199
465
  url,
200
466
  method="POST",
201
467
  headers={
468
+ **self.auth_headers,
202
469
  "Accept": "application/json",
203
470
  "Content-Type": "application/json",
204
471
  },
@@ -211,31 +478,134 @@ class MultipartUpload:
211
478
  f"Error completing upload {url}. Status {e.status}: {e.reason}"
212
479
  )
213
480
 
214
- return self._file_url
481
+ return self.access_url
482
+
483
+ @classmethod
484
+ def save(
485
+ cls,
486
+ file: FileData,
487
+ chunk_size: int | None = None,
488
+ max_concurrency: int | None = None,
489
+ ):
490
+ import concurrent.futures
215
491
 
492
+ multipart = cls(
493
+ file.file_name,
494
+ chunk_size=chunk_size,
495
+ content_type=file.content_type,
496
+ max_concurrency=max_concurrency,
497
+ )
498
+ multipart.create()
216
499
 
217
- @dataclass
218
- class FalFileRepositoryV2(FalFileRepositoryBase):
219
- def save(self, file: FileData) -> str:
220
- return self._save(file, "fal-cdn")
500
+ parts = math.ceil(len(file.data) / multipart.chunk_size)
501
+ with concurrent.futures.ThreadPoolExecutor(
502
+ max_workers=multipart.max_concurrency
503
+ ) as executor:
504
+ futures = []
505
+ for part_number in range(1, parts + 1):
506
+ start = (part_number - 1) * multipart.chunk_size
507
+ data = file.data[start : start + multipart.chunk_size]
508
+ futures.append(
509
+ executor.submit(multipart.upload_part, part_number, data)
510
+ )
221
511
 
222
- def _save_multipart(
223
- self,
512
+ for future in concurrent.futures.as_completed(futures):
513
+ future.result()
514
+
515
+ return multipart.complete()
516
+
517
+ @classmethod
518
+ def save_file(
519
+ cls,
224
520
  file_path: str | Path,
225
521
  chunk_size: int | None = None,
226
522
  content_type: str | None = None,
227
523
  max_concurrency: int | None = None,
228
524
  ) -> str:
229
- multipart = MultipartUpload(
230
- file_path,
525
+ import concurrent.futures
526
+
527
+ file_name = os.path.basename(file_path)
528
+ size = os.path.getsize(file_path)
529
+
530
+ multipart = cls(
531
+ file_name,
231
532
  chunk_size=chunk_size,
232
533
  content_type=content_type,
233
534
  max_concurrency=max_concurrency,
234
535
  )
235
536
  multipart.create()
236
- multipart.upload()
537
+
538
+ parts = math.ceil(size / multipart.chunk_size)
539
+ with concurrent.futures.ThreadPoolExecutor(
540
+ max_workers=multipart.max_concurrency
541
+ ) as executor:
542
+ futures = []
543
+ for part_number in range(1, parts + 1):
544
+
545
+ def _upload_part(pn: int) -> None:
546
+ with open(file_path, "rb") as f:
547
+ start = (pn - 1) * multipart.chunk_size
548
+ f.seek(start)
549
+ data = f.read(multipart.chunk_size)
550
+ multipart.upload_part(pn, data)
551
+
552
+ futures.append(executor.submit(_upload_part, part_number))
553
+
554
+ for future in concurrent.futures.as_completed(futures):
555
+ future.result()
556
+
237
557
  return multipart.complete()
238
558
 
559
+
560
+ @dataclass
561
+ class FalFileRepositoryV2(FalFileRepositoryBase):
562
+ @retry(max_retries=3, base_delay=1, backoff_type="exponential", jitter=True)
563
+ def save(
564
+ self,
565
+ file: FileData,
566
+ multipart: bool | None = None,
567
+ multipart_threshold: int | None = None,
568
+ multipart_chunk_size: int | None = None,
569
+ multipart_max_concurrency: int | None = None,
570
+ object_lifecycle_preference: dict[str, str] | None = None,
571
+ ) -> str:
572
+ if multipart is None:
573
+ threshold = multipart_threshold or MultipartUpload.MULTIPART_THRESHOLD
574
+ multipart = len(file.data) > threshold
575
+
576
+ if multipart:
577
+ return MultipartUpload.save(
578
+ file,
579
+ chunk_size=multipart_chunk_size,
580
+ max_concurrency=multipart_max_concurrency,
581
+ )
582
+
583
+ token = fal_v2_token_manager.get_token()
584
+ headers = {
585
+ "Authorization": f"{token.token_type} {token.token}",
586
+ "Accept": "application/json",
587
+ "X-Fal-File-Name": file.file_name,
588
+ "Content-Type": file.content_type,
589
+ }
590
+
591
+ storage_url = f"{token.base_upload_url}/upload"
592
+
593
+ try:
594
+ req = Request(
595
+ storage_url,
596
+ data=file.data,
597
+ headers=headers,
598
+ method="PUT",
599
+ )
600
+ with urlopen(req) as response:
601
+ result = json.load(response)
602
+
603
+ return result["file_url"]
604
+ except HTTPError as e:
605
+ raise FileUploadException(
606
+ f"Error initiating upload. Status {e.status}: {e.reason}"
607
+ )
608
+
239
609
  def save_file(
240
610
  self,
241
611
  file_path: str | Path,
@@ -244,13 +614,14 @@ class FalFileRepositoryV2(FalFileRepositoryBase):
244
614
  multipart_threshold: int | None = None,
245
615
  multipart_chunk_size: int | None = None,
246
616
  multipart_max_concurrency: int | None = None,
617
+ object_lifecycle_preference: dict[str, str] | None = None,
247
618
  ) -> tuple[str, FileData | None]:
248
619
  if multipart is None:
249
620
  threshold = multipart_threshold or MultipartUpload.MULTIPART_THRESHOLD
250
621
  multipart = os.path.getsize(file_path) > threshold
251
622
 
252
623
  if multipart:
253
- url = self._save_multipart(
624
+ url = MultipartUpload.save_file(
254
625
  file_path,
255
626
  chunk_size=multipart_chunk_size,
256
627
  content_type=content_type,
@@ -264,7 +635,7 @@ class FalFileRepositoryV2(FalFileRepositoryBase):
264
635
  content_type=content_type,
265
636
  file_name=os.path.basename(file_path),
266
637
  )
267
- url = self.save(data)
638
+ url = self.save(data, object_lifecycle_preference)
268
639
 
269
640
  return url, data
270
641
 
@@ -274,25 +645,38 @@ class InMemoryRepository(FileRepository):
274
645
  def save(
275
646
  self,
276
647
  file: FileData,
648
+ object_lifecycle_preference: dict[str, str] | None = None,
277
649
  ) -> str:
278
650
  return f'data:{file.content_type};base64,{b64encode(file.data).decode("utf-8")}'
279
651
 
280
652
 
281
653
  @dataclass
282
654
  class FalCDNFileRepository(FileRepository):
655
+ def _object_lifecycle_headers(
656
+ self,
657
+ headers: dict[str, str],
658
+ object_lifecycle_preference: dict[str, str] | None,
659
+ ):
660
+ if object_lifecycle_preference:
661
+ headers["X-Fal-Object-Lifecycle-Preference"] = json.dumps(
662
+ object_lifecycle_preference
663
+ )
664
+
665
+ @retry(max_retries=3, base_delay=1, backoff_type="exponential", jitter=True)
283
666
  def save(
284
667
  self,
285
668
  file: FileData,
669
+ object_lifecycle_preference: dict[str, str] | None = None,
286
670
  ) -> str:
287
671
  headers = {
288
672
  **self.auth_headers,
289
673
  "Accept": "application/json",
290
674
  "Content-Type": file.content_type,
291
675
  "X-Fal-File-Name": file.file_name,
292
- "X-Fal-Object-Lifecycle-Preference": json.dumps(
293
- dataclasses.asdict(GLOBAL_LIFECYCLE_PREFERENCE)
294
- ),
295
676
  }
677
+
678
+ self._object_lifecycle_headers(headers, object_lifecycle_preference)
679
+
296
680
  url = os.getenv("FAL_CDN_HOST", _FAL_CDN) + "/files/upload"
297
681
  request = Request(url, headers=headers, method="POST", data=file.data)
298
682
  try:
@@ -317,3 +701,107 @@ class FalCDNFileRepository(FileRepository):
317
701
  "Authorization": f"Bearer {key_id}:{key_secret}",
318
702
  "User-Agent": "fal/0.1.0",
319
703
  }
704
+
705
+
706
# Intended only for internal users, who are issued long-lived access tokens.
@dataclass
class InternalFalFileRepositoryV3(FileRepository):
    """File repository that uploads to the fal CDN v3.

    Intended for internal users with long-lived access tokens, so the token
    cached by ``fal_v3_token_manager`` rarely needs refreshing between uploads.
    Both the in-memory (``save``) and on-disk (``save_file``) paths use the V3
    endpoints, including for multipart uploads.
    """

    def _object_lifecycle_headers(
        self,
        headers: dict[str, str],
        object_lifecycle_preference: dict[str, str] | None,
    ):
        """Add the ``X-Fal-Object-Lifecycle`` header in place when a preference is given."""
        if object_lifecycle_preference:
            headers["X-Fal-Object-Lifecycle"] = json.dumps(object_lifecycle_preference)

    @retry(max_retries=3, base_delay=1, backoff_type="exponential", jitter=True)
    def save(
        self,
        file: FileData,
        multipart: bool | None = None,
        multipart_threshold: int | None = None,
        multipart_chunk_size: int | None = None,
        multipart_max_concurrency: int | None = None,
        object_lifecycle_preference: dict[str, str] | None = None,
    ) -> str:
        """Upload in-memory file data to the v3 CDN and return its access URL.

        When ``multipart`` is None, data larger than ``multipart_threshold``
        (default InternalMultipartUploadV3.MULTIPART_THRESHOLD) goes through
        InternalMultipartUploadV3. Note that the multipart path does not
        forward ``object_lifecycle_preference``.

        Raises:
            FileUploadException: if the single-request upload fails with an HTTP error.
        """
        if multipart is None:
            threshold = (
                multipart_threshold or InternalMultipartUploadV3.MULTIPART_THRESHOLD
            )
            multipart = len(file.data) > threshold

        if multipart:
            return InternalMultipartUploadV3.save(
                file,
                chunk_size=multipart_chunk_size,
                max_concurrency=multipart_max_concurrency,
            )

        headers = {
            **self.auth_headers,
            "Accept": "application/json",
            "Content-Type": file.content_type,
            "X-Fal-File-Name": file.file_name,
        }

        self._object_lifecycle_headers(headers, object_lifecycle_preference)

        url = os.getenv("FAL_CDN_V3_HOST", _FAL_CDN_V3) + "/files/upload"
        request = Request(url, headers=headers, method="POST", data=file.data)
        try:
            with urlopen(request) as response:
                result = json.load(response)
        except HTTPError as e:
            raise FileUploadException(
                f"Error initiating upload. Status {e.status}: {e.reason}"
            )

        access_url = result["access_url"]
        return access_url

    @property
    def auth_headers(self) -> dict[str, str]:
        """Headers authenticating with the current V3 token (refreshed if expired)."""
        token = fal_v3_token_manager.get_token()
        return {
            "Authorization": f"{token.token_type} {token.token}",
            "User-Agent": "fal/0.1.0",
        }

    def save_file(
        self,
        file_path: str | Path,
        content_type: str,
        multipart: bool | None = None,
        multipart_threshold: int | None = None,
        multipart_chunk_size: int | None = None,
        multipart_max_concurrency: int | None = None,
        object_lifecycle_preference: dict[str, str] | None = None,
    ) -> tuple[str, FileData | None]:
        """Upload a file from disk to the v3 CDN.

        Returns:
            ``(url, data)``; ``data`` is the FileData that was uploaded, or
            None when the file went through a multipart upload and was never
            read into memory whole.
        """
        if multipart is None:
            threshold = (
                multipart_threshold or InternalMultipartUploadV3.MULTIPART_THRESHOLD
            )
            multipart = os.path.getsize(file_path) > threshold

        if multipart:
            # Use the V3 multipart uploader so large files land on the same
            # CDN as small ones (the V2 MultipartUpload targets fal-cdn).
            url = InternalMultipartUploadV3.save_file(
                file_path,
                chunk_size=multipart_chunk_size,
                content_type=content_type,
                max_concurrency=multipart_max_concurrency,
            )
            data = None
        else:
            with open(file_path, "rb") as f:
                data = FileData(
                    f.read(),
                    content_type=content_type,
                    file_name=os.path.basename(file_path),
                )
            # Pass by keyword: save()'s second positional parameter is
            # `multipart`, not the lifecycle preference.
            url = self.save(
                data, object_lifecycle_preference=object_lifecycle_preference
            )

        return url, data