fal 1.45.2__py3-none-any.whl → 1.46.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fal/_fal_version.py +2 -2
- fal/api/__init__.py +1 -0
- fal/api/apps.py +69 -0
- fal/api/client.py +116 -0
- fal/api/deploy.py +211 -0
- fal/api/runners.py +16 -0
- fal/cli/apps.py +51 -60
- fal/cli/deploy.py +29 -181
- fal/cli/queue.py +2 -2
- fal/cli/runners.py +45 -47
- fal/distributed/__init__.py +3 -0
- fal/distributed/utils.py +420 -0
- fal/distributed/worker.py +791 -0
- {fal-1.45.2.dist-info → fal-1.46.1.dist-info}/METADATA +2 -1
- {fal-1.45.2.dist-info → fal-1.46.1.dist-info}/RECORD +19 -11
- /fal/{api.py → api/api.py} +0 -0
- {fal-1.45.2.dist-info → fal-1.46.1.dist-info}/WHEEL +0 -0
- {fal-1.45.2.dist-info → fal-1.46.1.dist-info}/entry_points.txt +0 -0
- {fal-1.45.2.dist-info → fal-1.46.1.dist-info}/top_level.txt +0 -0
fal/distributed/utils.py
ADDED
@@ -0,0 +1,420 @@
+import base64
+import datetime
+import json
+import os
+import pickle
+import threading
+import warnings
+from collections.abc import Callable
+from io import BytesIO
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
+
+if TYPE_CHECKING:
+    import torch.multiprocessing as mp
+
+
+def has_type_name(maybe_type: Any, type_name: str) -> bool:
+    """
+    Checks if the given object has a type name that matches the provided type name.
+    This is used to avoid importing torch or other libraries unnecessarily.
+    :param maybe_type: The object to check.
+    :param type_name: The type name to match against.
+    :return: True if the object's type name matches, False otherwise.
+    """
+    if not isinstance(maybe_type, type):
+        maybe_type = type(maybe_type)
+
+    mro_types = [t.__name__ for t in maybe_type.mro()]
+    return type_name in mro_types
+
+
+def is_torch_tensor(obj: Any) -> bool:
+    """
+    Checks if the given object is a PyTorch tensor without importing torch.
+    """
+    return has_type_name(obj, "Tensor")
+
+
+def is_numpy_array(obj: Any) -> bool:
+    """
+    Checks if the given object is a NumPy array without importing numpy.
+    """
+    return has_type_name(obj, "ndarray")
+
+
+def is_pil_image(obj: Any) -> bool:
+    """
+    Checks if the given object is a PIL Image without importing PIL.
+    """
+    return has_type_name(obj, "Image")
+
+
+def format_for_serialization(
+    response: Any,
+    image_format: str = "jpeg",
+    is_final: bool = False,
+    as_data_urls: bool = False,
+) -> Any:
+    """
+    Formats the response for serialization.
+    Most importantly, it encodes images to base64 and returns the image format and size.
+    :param response: The response to format.
+    :param is_final: Whether this is the final response.
+    :return: The formatted response.
+    """
+    if is_torch_tensor(response):
+        import torch
+
+        with BytesIO() as buffer:
+            torch.save(response.detach().cpu(), buffer)
+            tensor_bytes = buffer.getvalue()
+
+        if as_data_urls:
+            base64_tensor = base64.b64encode(tensor_bytes).decode("utf-8")
+            data = None
+            url = f"data:application/tensor;base64,{base64_tensor}"
+        else:
+            data = tensor_bytes
+            url = None
+
+        return {
+            "content_type": "application/tensor",
+            "shape": response.shape,
+            "dtype": str(response.dtype),
+            "data": data,
+            "url": url,
+        }
+    if is_numpy_array(response):
+        import numpy as np
+
+        with BytesIO() as buffer:
+            np.save(buffer, response)
+            array_bytes = buffer.getvalue()
+
+        if as_data_urls:
+            base64_array = base64.b64encode(array_bytes).decode("utf-8")
+            data = None
+            url = f"data:application/ndarray;base64,{base64_array}"
+        else:
+            data = array_bytes
+            url = None
+
+        return {
+            "content_type": "application/ndarray",
+            "shape": response.shape,
+            "dtype": str(response.dtype),
+            "data": data,
+            "url": url,
+        }
+    elif is_pil_image(response):
+        width, height = response.size
+
+        with BytesIO() as buffer:
+            if is_final:
+                if image_format == "jpeg":
+                    response.save(buffer, format="jpeg", quality=95)
+                else:
+                    response.save(buffer, format=image_format)
+            else:
+                image_format = "jpeg"
+                response.save(buffer, format="jpeg", quality=60)
+
+            image_bytes = buffer.getvalue()
+
+        if as_data_urls:
+            base64_image = base64.b64encode(image_bytes).decode("utf-8")
+            url = f"data:image/{image_format};base64,{base64_image}"
+            data = None
+        else:
+            url = None
+            data = image_bytes
+
+        return {
+            "content_type": f"image/{image_format}",
+            "width": width,
+            "height": height,
+            "data": data,
+            "url": url,
+        }
+    elif isinstance(response, list):
+        return [
+            format_for_serialization(
+                item,
+                image_format=image_format,
+                is_final=is_final,
+                as_data_urls=as_data_urls,
+            )
+            for item in response
+        ]
+    elif isinstance(response, dict):
+        return {
+            key: format_for_serialization(
+                value,
+                image_format=image_format,
+                is_final=is_final,
+                as_data_urls=as_data_urls,
+            )
+            for key, value in response.items()
+        }
+
+    return response
+
+
+def format_deserialized_data(data: Any) -> Any:
+    """
+    Formats the deserialized data for further processing.
+    :param data: The data to format.
+    :return: The formatted data.
+    """
+    if isinstance(data, dict):
+        if data.get("content_type", "").startswith("image/"):
+            from PIL import Image
+
+            # Deserialize image data
+            if data.get("url"):
+                # If the data is a data URL, decode it
+                base64_image = data["url"].split(",")[1]
+                image_bytes = base64.b64decode(base64_image)
+            elif data.get("data"):
+                image_bytes = data["data"]
+            else:
+                raise ValueError("Image data must contain either 'url' or 'data'.")
+
+            fp = BytesIO(image_bytes)  # Don't close the BytesIO object
+            return Image.open(fp)
+        elif data.get("content_type", "") == "application/tensor":
+            import torch
+
+            # Deserialize tensor data
+            if data.get("url"):
+                # If the data is a data URL, decode it
+                base64_tensor = data["url"].split(",")[1]
+                tensor_bytes = base64.b64decode(base64_tensor)
+            elif data.get("data"):
+                tensor_bytes = data["data"]
+            else:
+                raise ValueError("Tensor data must contain either 'url' or 'data'.")
+
+            with BytesIO(tensor_bytes) as buffer:
+                return torch.load(buffer)
+
+        elif data.get("content_type", "") == "application/ndarray":
+            import numpy as np
+
+            # Deserialize numpy array data
+            if data.get("url"):
+                # If the data is a data URL, decode it
+                base64_array = data["url"].split(",")[1]
+                array_bytes = base64.b64decode(base64_array)
+            elif data.get("data"):
+                array_bytes = data["data"]
+            else:
+                raise ValueError(
+                    "Numpy array data must contain either 'url' or 'data'."
+                )
+
+            with BytesIO(array_bytes) as buffer:
+                return np.load(buffer, allow_pickle=True)
+
+        return {key: format_deserialized_data(value) for key, value in data.items()}
+    elif isinstance(data, list):
+        return [format_deserialized_data(item) for item in data]
+    return data
+
+
+def distributed_serialize(
+    obj: Any, is_final: bool = False, image_format: str = "jpeg"
+) -> bytes:
+    """
+    Serializes an object to pickled bytes.
+    :param obj: The object to serialize.
+    :return: The pickled payload.
+    """
+    data = format_for_serialization(obj, is_final=is_final, image_format=image_format)
+    return pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL)
+
+
+def encode_text_event(
+    obj: Any, is_final: bool = False, image_format: str = "jpeg"
+) -> bytes:
+    """
+    Encodes a response as a server-sent events frame with a JSON payload.
+    :param obj: The object to encode.
+    :param is_final: Whether this is the final response.
+    :return: The encoded event as bytes.
+    """
+    formatted = format_for_serialization(
+        obj, image_format=image_format, is_final=is_final, as_data_urls=True
+    )
+    return f"data: {json.dumps(formatted)}\n\n".encode()
+
+
+def distributed_deserialize(serialized: Union[bytes, str]) -> Any:
+    """
+    Deserializes a pickled payload (or a JSON string) back to an object.
+    :param serialized: The serialized payload.
+    :return: The deserialized object.
+    """
+    if isinstance(serialized, str):
+        data = json.loads(serialized)
+    else:
+        data = pickle.loads(serialized)
+    return format_deserialized_data(data)
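A round-trip sketch (editor's illustration, not part of the diff), assuming Pillow is installed: distributed_serialize pickles the formatted payload, and distributed_deserialize restores nested media objects:

from PIL import Image

from fal.distributed.utils import distributed_deserialize, distributed_serialize

payload = distributed_serialize({"image": Image.new("RGB", (64, 64)), "seed": 42})
restored = distributed_deserialize(payload)  # payload is pickled bytes
print(restored["seed"])        # 42
print(restored["image"].size)  # (64, 64) -- re-opened as a PIL image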
+
+
+def wrap_distributed_worker(
+    rank: int,
+    func: Callable,
+    world_size: int,
+    master_addr: str,
+    master_port: int,
+    timeout: int,
+    cwd: Optional[Union[str, Path]],
+    args: Tuple[Any],
+    kwargs: Dict[str, Any],
+) -> None:
+    """
+    Worker function for distributed training or inference.
+
+    This function is called by each worker process spawned by
+    `torch.multiprocessing.spawn`.
+
+    :param func: The function to run in each worker process.
+    :param world_size: The total number of processes.
+    :param rank: The rank of the current process.
+    :param master_addr: The address of the master node.
+    :param master_port: The port on which the master node will listen.
+    """
+    import torch
+    import torch.distributed as dist
+
+    if cwd:
+        os.chdir(str(cwd))
+
+    os.environ["RANK"] = str(rank)
+    os.environ["LOCAL_RANK"] = str(rank)
+    os.environ["WORLD_SIZE"] = str(world_size)
+    os.environ["MASTER_ADDR"] = master_addr
+    os.environ["MASTER_PORT"] = str(master_port)
+
+    print(f"[debug] Worker {rank} started with PID {os.getpid()}.")
+    dist.init_process_group(
+        backend="nccl",
+        init_method="env://",
+        world_size=world_size,
+        rank=rank,
+        timeout=datetime.timedelta(seconds=timeout),
+        device_id=torch.device(f"cuda:{rank}"),
+    )
+
+    try:
+        func(*args, **kwargs)
+    finally:
+        dist.destroy_process_group()
+
+
+def launch_distributed_processes(
+    func: Callable,
+    world_size: int = 1,
+    master_addr: str = "127.0.0.1",
+    master_port: int = 29500,
+    timeout: int = 1800,
+    cwd: Optional[Union[str, Path]] = None,
+    *args: Any,
+    **kwargs: Any,
+) -> "mp.ProcessContext":
+    """
+    Launches a distributed process group using torch.multiprocessing.spawn.
+    This function is designed to be called from the main process and will
+    spawn multiple worker processes for distributed training or inference.
+    :param func: The function to run in each worker process.
+    :param world_size: The total number of processes to spawn.
+    :param master_addr: The address of the master node.
+    :param master_port: The port on which the master node will listen.
+    :return: The process context for the spawned processes.
+    """
+
+    import torch.distributed as dist
+
+    if dist.is_initialized():
+        raise RuntimeError(
+            "Distributed process group is already initialized. "
+            "Please ensure that you are not trying to initialize it multiple times."
+        )
+
+    import torch.multiprocessing as mp
+
+    try:
+        import cloudpickle  # type: ignore[import-untyped]
+
+        mp.reducer.ForkingPickler = cloudpickle.Pickler  # type: ignore[misc]
+    except ImportError:
+        warnings.warn(
+            "Could not import cloudpickle, using default pickler. "
+            "If you encounter pickling errors, please install cloudpickle.",
+        )
+
+    print(f"[debug] Launching distributed processes with world size {world_size}.")
+    return mp.spawn(  # type: ignore[no-untyped-call]
+        wrap_distributed_worker,
+        args=(
+            func,
+            world_size,
+            master_addr,
+            master_port,
+            timeout,
+            cwd,
+            args,
+            kwargs,
+        ),
+        nprocs=world_size,
+        join=False,
+    )
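A launch sketch (editor's illustration, not part of the diff); run_inference is a hypothetical worker function, and the example assumes a host with two CUDA GPUs since the group is initialized with the NCCL backend and one device per rank:

import torch.distributed as dist

from fal.distributed.utils import launch_distributed_processes

def run_inference(prompt: str) -> None:
    # Runs once in each spawned worker; the process group is already set up.
    print(f"rank {dist.get_rank()}/{dist.get_world_size()}: {prompt}")

ctx = launch_distributed_processes(run_inference, world_size=2, prompt="a red cube")
ctx.join()  # mp.spawn is called with join=False, so wait for workers explicitly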
+
+
+class KeepAliveTimer:
+    """
+    Call a function after a certain amount of time to keep the worker alive.
+    """
+
+    timer: Optional[threading.Timer]
+
+    def __init__(
+        self,
+        func: Callable,
+        timeout: Union[int, float],
+        start: bool = False,
+        *args: Any,
+        **kwargs: Any,
+    ) -> None:
+        self.func = func
+        self.timeout = timeout
+        self.timer = None
+        self.args = args
+        self.kwargs = kwargs
+        if start:
+            self.start()
+
+    def start(self) -> None:
+        """
+        Start the timer.
+        """
+        self.timer = threading.Timer(self.timeout, self.func, self.args, self.kwargs)
+        self.timer.start()
+
+    def cancel(self) -> None:
+        """
+        Cancel the timer.
+        """
+        if self.timer:
+            self.timer.cancel()
+            self.timer = None
+
+    def reset(self) -> None:
+        """
+        Reset the timer.
+        """
+        self.cancel()
+        self.start()
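A keep-alive sketch (editor's illustration, not part of the diff): the timer is a one-shot threading.Timer, so callers reset() it on activity and the callback fires only after timeout seconds of idleness:

from fal.distributed.utils import KeepAliveTimer

def ping() -> None:
    print("worker still alive")

keepalive = KeepAliveTimer(ping, 30, start=True)  # arm a 30-second one-shot timer
keepalive.reset()   # on each request: cancel and re-arm, pushing the deadline back
keepalive.cancel()  # on shutdown: stop the pending timer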