fal 1.45.2__py3-none-any.whl → 1.46.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of fal might be problematic. Click here for more details.

@@ -0,0 +1,420 @@
1
+ import base64
2
+ import datetime
3
+ import json
4
+ import os
5
+ import pickle
6
+ import threading
7
+ import warnings
8
+ from collections.abc import Callable
9
+ from io import BytesIO
10
+ from pathlib import Path
11
+ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
12
+
13
+ if TYPE_CHECKING:
14
+ import torch.multiprocessing as mp
15
+
16
+
17
def has_type_name(
    maybe_type: Any, type_name: str, module_prefix: Optional[str] = None
) -> bool:
    """
    Checks if the given object (or class) has a type in its MRO whose name
    matches ``type_name``.

    This is used to avoid importing torch or other libraries unnecessarily.

    :param maybe_type: The object or class to check. Non-class values are
        replaced by their type.
    :param type_name: The class ``__name__`` to match against.
    :param module_prefix: Optional package name. When given, the matching
        class must also be defined in that package (its ``__module__`` is
        equal to ``module_prefix`` or starts with ``module_prefix + "."``).
    :return: True if a matching class is found in the MRO, False otherwise.
    """
    cls = maybe_type if isinstance(maybe_type, type) else type(maybe_type)

    # Read ``__mro__`` instead of calling ``cls.mro()``: for metaclasses
    # (including ``type`` itself) ``mro`` is an unbound method and calling it
    # without an argument raises TypeError.
    for klass in cls.__mro__:
        if klass.__name__ != type_name:
            continue
        if module_prefix is None:
            return True
        module = getattr(klass, "__module__", None) or ""
        if module == module_prefix or module.startswith(module_prefix + "."):
            return True
    return False
30
+
31
+
32
def is_torch_tensor(obj: Any) -> bool:
    """
    Checks if the given object is a PyTorch tensor without importing torch.

    Accepts an instance or a class. Matches ``torch.Tensor`` and any subclass
    of it (e.g. ``torch.nn.Parameter``). Classes that are merely *named*
    ``Tensor`` in other packages (e.g. TensorFlow) are rejected. Otherwise
    they would be routed into the torch serialization path and fail there.
    """
    cls = obj if isinstance(obj, type) else type(obj)
    for klass in cls.__mro__:
        module = getattr(klass, "__module__", None) or ""
        if klass.__name__ == "Tensor" and (
            module == "torch" or module.startswith("torch.")
        ):
            return True
    return False
37
+
38
+
39
def is_numpy_array(obj: Any) -> bool:
    """
    Checks if the given object is a NumPy array without importing numpy.

    Accepts an instance or a class. Matches ``numpy.ndarray`` and any
    subclass of it (e.g. ``numpy.matrix``). Classes that are merely *named*
    ``ndarray`` in other packages (e.g. ``cupy.ndarray``) are rejected, so
    look-alike types are not routed into the NumPy serialization path.
    """
    cls = obj if isinstance(obj, type) else type(obj)
    for klass in cls.__mro__:
        module = getattr(klass, "__module__", None) or ""
        if klass.__name__ == "ndarray" and (
            module == "numpy" or module.startswith("numpy.")
        ):
            return True
    return False
44
+
45
+
46
def is_pil_image(obj: Any) -> bool:
    """
    Checks if the given object is a PIL Image without importing PIL.

    Accepts an instance or a class. Matches ``PIL.Image.Image`` and any
    subclass of it (e.g. the per-format ``*ImageFile`` classes). Classes that
    are merely *named* ``Image`` elsewhere are rejected, e.g. an application
    model or ``torchvision.tv_tensors.Image``. Those lack the PIL ``size`` /
    ``save`` API that the image serialization path relies on.
    """
    cls = obj if isinstance(obj, type) else type(obj)
    for klass in cls.__mro__:
        module = getattr(klass, "__module__", None) or ""
        if klass.__name__ == "Image" and (
            module == "PIL" or module.startswith("PIL.")
        ):
            return True
    return False
51
+
52
+
53
def _encode_payload(
    payload: bytes, content_type: str, as_data_url: bool
) -> Dict[str, Any]:
    """
    Wrap encoded bytes as the ``data``/``url`` pair of a serialized payload.

    :param payload: The encoded bytes.
    :param content_type: MIME type used in the data URL.
    :param as_data_url: If True, return ``{"data": None, "url": <data URL>}``;
        otherwise return ``{"data": payload, "url": None}``.
    :return: A dict with exactly the keys ``data`` and ``url`` (in that order).
    """
    if as_data_url:
        encoded = base64.b64encode(payload).decode("utf-8")
        return {"data": None, "url": f"data:{content_type};base64,{encoded}"}
    return {"data": payload, "url": None}


def _jpeg_compatible(image: Any) -> Any:
    """
    Return *image* unchanged if JPEG can store its mode, else an RGB copy.

    Pillow refuses to write modes such as RGBA, LA or P as JPEG. Dropping to
    RGB keeps such images serializable (the alpha channel is discarded).
    """
    if image.mode in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
        return image
    return image.convert("RGB")


def format_for_serialization(
    response: Any,
    image_format: str = "jpeg",
    is_final: bool = False,
    as_data_urls: bool = False,
) -> Any:
    """
    Formats the response for serialization.

    Torch tensors, NumPy arrays and PIL images are replaced by dicts that
    describe the payload and carry its encoded bytes:

    - tensors / arrays: ``content_type``, ``shape``, ``dtype``, ``data``, ``url``
    - images: ``content_type``, ``width``, ``height``, ``data``, ``url``

    Lists and dicts are processed recursively. Any other value is returned
    unchanged.

    :param response: The response to format.
    :param image_format: Image format used for final images (e.g. ``"jpeg"``,
        ``"png"``). Case-insensitive; ``"jpg"`` is treated as ``"jpeg"``.
    :param is_final: Whether this is the final response. Non-final images are
        always encoded as quality-60 JPEG previews. Final JPEGs use quality 95.
    :param as_data_urls: If True, put a base64 data URL in ``url`` and set
        ``data`` to None. Otherwise put the raw bytes in ``data`` and set
        ``url`` to None.
    :return: The formatted response.
    """
    if is_torch_tensor(response):
        import torch

        with BytesIO() as buffer:
            # detach()/cpu() so the saved payload has no autograd history and
            # no reference to the source device.
            torch.save(response.detach().cpu(), buffer)
            tensor_bytes = buffer.getvalue()

        return {
            "content_type": "application/tensor",
            "shape": response.shape,
            "dtype": str(response.dtype),
            **_encode_payload(tensor_bytes, "application/tensor", as_data_urls),
        }

    if is_numpy_array(response):
        import numpy as np

        with BytesIO() as buffer:
            np.save(buffer, response)
            array_bytes = buffer.getvalue()

        return {
            "content_type": "application/ndarray",
            "shape": response.shape,
            "dtype": str(response.dtype),
            **_encode_payload(array_bytes, "application/ndarray", as_data_urls),
        }

    if is_pil_image(response):
        width, height = response.size

        if is_final:
            fmt = image_format.lower()
            if fmt == "jpg":  # common alias; Pillow's format name is "JPEG"
                fmt = "jpeg"
        else:
            # Intermediate results are always sent as cheap JPEG previews.
            fmt = "jpeg"

        with BytesIO() as buffer:
            if fmt == "jpeg":
                _jpeg_compatible(response).save(
                    buffer, format="jpeg", quality=95 if is_final else 60
                )
            else:
                response.save(buffer, format=fmt)
            image_bytes = buffer.getvalue()

        content_type = f"image/{fmt}"
        return {
            "content_type": content_type,
            "width": width,
            "height": height,
            **_encode_payload(image_bytes, content_type, as_data_urls),
        }

    if isinstance(response, list):
        return [
            format_for_serialization(
                item,
                image_format=image_format,
                is_final=is_final,
                as_data_urls=as_data_urls,
            )
            for item in response
        ]

    if isinstance(response, dict):
        return {
            key: format_for_serialization(
                value,
                image_format=image_format,
                is_final=is_final,
                as_data_urls=as_data_urls,
            )
            for key, value in response.items()
        }

    return response
162
+
163
+
164
def _inline_payload_bytes(data: Dict[str, Any], label: str) -> Optional[bytes]:
    """
    Return the bytes carried inline by a serialized payload dict.

    :param data: A dict in the shape produced by ``format_for_serialization``.
    :param label: Payload kind used in the error message (e.g. ``"Image"``).
    :return: The decoded bytes, or None when ``url`` is an ordinary
        (non-``data:``) URL, i.e. the dict only references remote content.
    :raises ValueError: If the dict carries neither ``url`` nor ``data``.
    """
    url = data.get("url")
    if isinstance(url, str) and url.startswith("data:"):
        # "data:<mime>;base64,<payload>" -> "<payload>"
        return base64.b64decode(url.partition(",")[2])
    if data.get("data"):
        return data["data"]
    if url:
        # e.g. {"url": "https://...", "content_type": "image/png"}: plain data.
        return None
    raise ValueError(f"{label} data must contain either 'url' or 'data'.")


def format_deserialized_data(data: Any) -> Any:
    """
    Formats the deserialized data for further processing.

    Reverses ``format_for_serialization``. Dicts with an image, tensor or
    ndarray ``content_type`` whose bytes are inline are turned back into a
    PIL image, torch tensor or NumPy array. Inline means raw ``data`` or a
    ``data:`` URL. Other dicts and lists are processed recursively. Any
    other value is returned unchanged.

    :param data: The data to format.
    :return: The formatted data.
    :raises ValueError: If a payload dict has neither ``url`` nor ``data``.
    """
    if isinstance(data, list):
        return [format_deserialized_data(item) for item in data]
    if not isinstance(data, dict):
        return data

    content_type = data.get("content_type")
    # Only a string content type can mark a payload. None or other values
    # belong to ordinary user dicts.
    if isinstance(content_type, str):
        if content_type.startswith("image/"):
            image_bytes = _inline_payload_bytes(data, "Image")
            if image_bytes is not None:
                from PIL import Image

                # Image.open is lazy, so the BytesIO is intentionally left
                # open for the returned image to read from.
                return Image.open(BytesIO(image_bytes))
        elif content_type == "application/tensor":
            tensor_bytes = _inline_payload_bytes(data, "Tensor")
            if tensor_bytes is not None:
                import torch

                # NOTE(security): torch.load is pickle-based; only pass it
                # trusted data.
                with BytesIO(tensor_bytes) as buffer:
                    return torch.load(buffer)
        elif content_type == "application/ndarray":
            array_bytes = _inline_payload_bytes(data, "Numpy array")
            if array_bytes is not None:
                import numpy as np

                # NOTE(security): allow_pickle=True lets object arrays
                # round-trip but unpickles; only pass it trusted data.
                with BytesIO(array_bytes) as buffer:
                    return np.load(buffer, allow_pickle=True)

    return {key: format_deserialized_data(value) for key, value in data.items()}
224
+
225
+
226
def distributed_serialize(
    obj: Any, is_final: bool = False, image_format: str = "jpeg"
) -> bytes:
    """
    Serializes an object into pickle bytes for hand-off between processes.

    The object is first passed through ``format_for_serialization`` (without
    data URLs), so tensors, arrays and images become dicts carrying their
    encoded bytes raw in ``data``.

    :param obj: The object to serialize.
    :param is_final: Whether this is the final response (affects image encoding).
    :param image_format: Image format used for final images.
    :return: The pickled bytes (highest pickle protocol).
    """
    formatted = format_for_serialization(
        obj, image_format=image_format, is_final=is_final
    )
    return pickle.dumps(formatted, protocol=pickle.HIGHEST_PROTOCOL)
236
+
237
+
238
def encode_text_event(
    obj: Any, is_final: bool = False, image_format: str = "jpeg"
) -> bytes:
    """
    Encodes an object as a ``data: <json>`` event frame terminated by a
    blank line.

    Tensors, arrays and images are embedded as base64 data URLs so the whole
    object is JSON-serializable.

    :param obj: The object to encode.
    :param is_final: Whether this is the final response (affects image encoding).
    :param image_format: Image format used for final images.
    :return: The encoded event as UTF-8 bytes.
    """
    body = json.dumps(
        format_for_serialization(
            obj,
            image_format=image_format,
            is_final=is_final,
            as_data_urls=True,
        )
    )
    event = f"data: {body}\n\n"
    return event.encode()
251
+
252
+
253
def distributed_deserialize(serialized: Union[bytes, str]) -> Any:
    """
    Deserializes a payload and rebuilds tensors, arrays and images in it.

    A ``str`` is parsed as JSON; anything else is unpickled. The result is
    then passed through ``format_deserialized_data``.

    NOTE(security): the bytes path uses ``pickle.loads``, which can execute
    arbitrary code. Only pass payloads from trusted peers (e.g. the output
    of ``distributed_serialize`` in a sibling worker).

    :param serialized: The JSON string or pickled bytes.
    :return: The deserialized object.
    """
    decoded = (
        json.loads(serialized)
        if isinstance(serialized, str)
        else pickle.loads(serialized)
    )
    return format_deserialized_data(decoded)
264
+
265
+
266
def wrap_distributed_worker(
    rank: int,
    func: Callable,
    world_size: int,
    master_addr: str,
    master_port: int,
    timeout: int,
    cwd: Optional[Union[str, Path]],
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
) -> None:
    """
    Worker function for distributed training or inference.

    This function is called by each worker process spawned by
    `torch.multiprocessing.spawn`. It sets up the ``env://`` rendezvous
    variables, binds the process to GPU ``cuda:<rank>``, and joins an NCCL
    process group. It then runs ``func(*args, **kwargs)`` and always
    destroys the process group afterwards.

    :param rank: The rank of the current process (also used as local rank
        and CUDA device index, i.e. a single-node launch).
    :param func: The function to run in each worker process.
    :param world_size: The total number of processes.
    :param master_addr: The address of the master node.
    :param master_port: The port on which the master node will listen.
    :param timeout: Process-group timeout in seconds.
    :param cwd: Optional working directory to switch to before starting.
    :param args: Positional arguments for ``func``.
    :param kwargs: Keyword arguments for ``func``.
    """
    import torch
    import torch.distributed as dist

    if cwd:
        os.chdir(str(cwd))

    os.environ["RANK"] = str(rank)
    os.environ["LOCAL_RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = str(master_port)

    device = torch.device(f"cuda:{rank}")
    # Make this rank's GPU the current CUDA device. Otherwise every rank's
    # bare "cuda" allocations and NCCL communicators default to cuda:0.
    torch.cuda.set_device(device)

    print(f"[debug] Worker {rank} started with PID {os.getpid()}.")
    dist.init_process_group(
        backend="nccl",
        init_method="env://",
        world_size=world_size,
        rank=rank,
        timeout=datetime.timedelta(seconds=timeout),
        device_id=device,
    )

    try:
        func(*args, **kwargs)
    finally:
        dist.destroy_process_group()
315
+
316
+
317
def launch_distributed_processes(
    func: Callable,
    world_size: int = 1,
    master_addr: str = "127.0.0.1",
    master_port: int = 29500,
    timeout: int = 1800,
    cwd: Optional[Union[str, Path]] = None,
    *args: Any,
    **kwargs: Any,
) -> "mp.ProcessContext":
    """
    Spawns ``world_size`` workers, each running ``func`` via
    ``wrap_distributed_worker`` inside its own distributed process group.

    Must be called from the main process, and only while no process group is
    initialized in it. The workers are started without being joined.

    If cloudpickle is importable, it replaces multiprocessing's pickler for
    sending the worker arguments. This assignment mutates a module attribute,
    so it affects the whole process, not just this call.

    :param func: The function to run in each worker process.
    :param world_size: The total number of processes to spawn.
    :param master_addr: The address of the master node.
    :param master_port: The port on which the master node will listen.
    :param timeout: Process-group timeout in seconds, forwarded to each worker.
    :param cwd: Optional working directory each worker switches to.
    :param args: Positional arguments forwarded to ``func``.
    :param kwargs: Keyword arguments forwarded to ``func``.
    :raises RuntimeError: If a process group is already initialized.
    :return: The process context for the spawned processes.
    """
    import torch.distributed as dist

    if dist.is_initialized():
        raise RuntimeError(
            "Distributed process group is already initialized. "
            "Please ensure that you are not trying to initialize it multiple times."
        )

    import torch.multiprocessing as mp

    try:
        import cloudpickle  # type: ignore[import-untyped]
    except ImportError:
        warnings.warn(
            "Could not import cloudpickle, using default pickler. "
            "If you encounter pickling errors, please install cloudpickle.",
        )
    else:
        mp.reducer.ForkingPickler = cloudpickle.Pickler  # type: ignore[misc]

    worker_args = (
        func,
        world_size,
        master_addr,
        master_port,
        timeout,
        cwd,
        args,
        kwargs,
    )

    print(f"[debug] Launching distributed processes with world size {world_size}.")
    return mp.spawn(  # type: ignore[no-untyped-call]
        wrap_distributed_worker,
        args=worker_args,
        nprocs=world_size,
        join=False,
    )
374
+
375
+
376
class KeepAliveTimer:
    """
    Call a function after a certain amount of time to keep the worker alive.

    A one-shot, restartable wrapper around ``threading.Timer``. After
    ``timeout`` seconds it calls ``func(*args, **kwargs)`` once. Calling
    :meth:`reset` before then restarts the countdown instead.
    """

    timer: Optional[threading.Timer]

    def __init__(
        self,
        func: Callable,
        timeout: Union[int, float],
        start: bool = False,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """
        :param func: Callable to invoke when the timer fires.
        :param timeout: Delay in seconds before ``func`` is called.
        :param start: If True, start the timer immediately.
        :param args: Positional arguments passed to ``func``.
        :param kwargs: Keyword arguments passed to ``func``.
        """
        self.func = func
        self.timeout = timeout
        self.args = args
        self.kwargs = kwargs
        self.timer = None
        if start:
            self.start()

    def start(self) -> None:
        """
        Start the timer, cancelling any timer that is still pending.
        """
        # Without this, a second start() would drop the reference to a
        # still-running timer, which would then fire and call func again.
        self.cancel()
        self.timer = threading.Timer(self.timeout, self.func, self.args, self.kwargs)
        self.timer.start()

    def cancel(self) -> None:
        """
        Cancel the timer, if one is pending.
        """
        if self.timer is not None:
            self.timer.cancel()
            self.timer = None

    def reset(self) -> None:
        """
        Reset the timer (cancel any pending call and start counting again).
        """
        self.cancel()
        self.start()