strands-diffusers 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- strands_diffusers/__init__.py +41 -0
- strands_diffusers/_version.py +24 -0
- strands_diffusers/core/__init__.py +4 -0
- strands_diffusers/core/engine.py +163 -0
- strands_diffusers/core/io.py +552 -0
- strands_diffusers/core/registry.py +349 -0
- strands_diffusers/core/viz.py +256 -0
- strands_diffusers/tools/__init__.py +4 -0
- strands_diffusers/tools/use_diffusers.py +420 -0
- strands_diffusers-0.1.0.dist-info/METADATA +199 -0
- strands_diffusers-0.1.0.dist-info/RECORD +13 -0
- strands_diffusers-0.1.0.dist-info/WHEEL +5 -0
- strands_diffusers-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,552 @@
|
|
|
1
|
+
"""Native multimodal I/O for diffusion — images / video / audio / actions out.
|
|
2
|
+
|
|
3
|
+
Inputs: file paths, URLs, base64 data URIs, PIL images, numpy arrays, video paths.
|
|
4
|
+
Outputs: diffusers pipeline results serialized to JSON-safe form, with binary
|
|
5
|
+
artifacts (generated images, videos, audio, and robot ACTION chunks) written to
|
|
6
|
+
disk and referenced by path.
|
|
7
|
+
|
|
8
|
+
The headline feature for Physical-AI / world-foundation models: a Cosmos-style
|
|
9
|
+
pipeline returns a `Cosmos3OmniPipelineOutput(video=..., sound=..., action=...)`.
|
|
10
|
+
We serialize the video to .mp4, the sound to .wav, and the **action** chunk to a
|
|
11
|
+
.json (the model-normalized action-space tensor) — so an agent gets back a path
|
|
12
|
+
to a playable world AND a usable robot action vector in one call.
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
import base64
|
|
18
|
+
import io as _io
|
|
19
|
+
import os
|
|
20
|
+
import tempfile
|
|
21
|
+
import time
|
|
22
|
+
from pathlib import Path
|
|
23
|
+
from typing import Any, Dict, List, Optional
|
|
24
|
+
|
|
25
|
+
# Root directory for every artifact this module writes (images, videos, audio,
# action chunks, meshes). The base can be overridden with the
# STRANDS_DIFFUSERS_ARTIFACTS environment variable. An unset *or empty* value
# falls back to the system temp dir (an empty string would otherwise resolve to
# a directory relative to the current working directory), and a leading "~" is
# expanded to the user's home directory. Created eagerly at import time.
ARTIFACT_DIR = Path(
    os.getenv("STRANDS_DIFFUSERS_ARTIFACTS") or tempfile.gettempdir()
).expanduser() / "strands_diffusers"
ARTIFACT_DIR.mkdir(parents=True, exist_ok=True)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
import itertools as _itertools
|
|
32
|
+
import threading as _threading
|
|
33
|
+
|
|
34
|
+
_STAMP_LOCK = _threading.Lock()
|
|
35
|
+
_STAMP_COUNTER = _itertools.count()
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def _stamp() -> str:
    """Return a unique artifact stamp of the form ``<ms>_<pid>_<counter>``.

    Millisecond timestamps alone collide in tight loops / batched generation
    (num_images > 1) and would silently overwrite artifacts, so a process-wide
    counter (read under a lock) is appended. That counter restarts at 0 in
    every process, and ARTIFACT_DIR defaults to a shared temp directory, so the
    process id is included too: two processes writing in the same millisecond
    would otherwise both produce ``<ms>_0``.

    Note: ``time.time()`` is wall-clock, not monotonic; uniqueness comes from
    the pid + counter, the timestamp only makes names sortable/readable.
    """
    with _STAMP_LOCK:
        n = next(_STAMP_COUNTER)
    return f"{int(time.time() * 1000)}_{os.getpid()}_{n}"
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
# ───────────────────────── INPUT COERCION ─────────────────────────
|
|
48
|
+
|
|
49
|
+
def coerce_input(value: Any) -> Any:
    """Coerce an input spec into something a diffusers pipeline accepts.

    - "data:..."             → PIL Image (image/* MIME) or raw bytes
    - "*.png|jpg|..."        → PIL Image (loaded via diffusers.utils.load_image)
    - "*.mp4|mov|..."        → list[PIL] frames (via diffusers.utils.load_video)
    - "http(s)://..."        → load_image / load_video by extension
    - lists / tuples / dicts → coerced recursively (plain tuples stay tuples)
    - everything else        → passed through (text prompt, ints, etc.)

    A local path is only loaded if it exists. A string whose load fails is
    returned unchanged.
    """
    if isinstance(value, str):
        if value.startswith("data:"):
            return _decode_data_uri(value)
        low = value.lower()
        # Scheme and extension are matched case-insensitively, but existence is
        # checked on the ORIGINAL string. Lower-casing it first would miss files
        # such as "Photo.PNG" on case-sensitive filesystems.
        is_url = low.startswith(("http://", "https://"))
        bare = low.split("?")[0]

        def _matches(exts) -> bool:
            return bare.endswith(exts) and (is_url or os.path.exists(value))

        if _matches((".png", ".jpg", ".jpeg", ".bmp", ".webp", ".gif")):
            img = _load_image(value)
            if img is not None:
                return img
        if _matches((".mp4", ".mov", ".avi", ".mkv", ".webm")):
            vid = _load_video(value)
            if vid is not None:
                return vid
        return value  # text prompt / repo id / path we don't special-case
    if isinstance(value, list):
        return [coerce_input(v) for v in value]
    if type(value) is tuple:  # plain tuples only: NamedTuples keep their type
        return tuple(coerce_input(v) for v in value)
    if isinstance(value, dict):
        return {k: coerce_input(v) for k, v in value.items()}
    return value
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def _looks_like(path: str, exts) -> bool:
|
|
80
|
+
is_url = path.startswith("http://") or path.startswith("https://")
|
|
81
|
+
return (is_url or os.path.exists(path)) and path.split("?")[0].endswith(tuple(exts))
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def _load_image(spec: str):
|
|
85
|
+
try:
|
|
86
|
+
from diffusers.utils import load_image
|
|
87
|
+
return load_image(spec)
|
|
88
|
+
except Exception:
|
|
89
|
+
try:
|
|
90
|
+
from PIL import Image
|
|
91
|
+
return Image.open(spec).convert("RGB")
|
|
92
|
+
except Exception:
|
|
93
|
+
return None
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def _load_video(spec: str):
|
|
97
|
+
try:
|
|
98
|
+
from diffusers.utils import load_video
|
|
99
|
+
return load_video(spec)
|
|
100
|
+
except Exception:
|
|
101
|
+
return None
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def _decode_data_uri(uri: str) -> Any:
|
|
105
|
+
header, _, b64 = uri.partition(",")
|
|
106
|
+
raw = base64.b64decode(b64)
|
|
107
|
+
mime = header.split(";")[0].removeprefix("data:")
|
|
108
|
+
if mime.startswith("image/"):
|
|
109
|
+
from PIL import Image
|
|
110
|
+
return Image.open(_io.BytesIO(raw)).convert("RGB")
|
|
111
|
+
return raw
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def load_array(spec: Any):
|
|
115
|
+
"""Load a numpy array from list, .npy path, or pass through ndarray.
|
|
116
|
+
|
|
117
|
+
Useful for raw robot action vectors fed to forward-dynamics WFM runs.
|
|
118
|
+
"""
|
|
119
|
+
import numpy as np
|
|
120
|
+
|
|
121
|
+
if isinstance(spec, np.ndarray):
|
|
122
|
+
return spec
|
|
123
|
+
if isinstance(spec, (list, tuple)):
|
|
124
|
+
return np.asarray(spec)
|
|
125
|
+
if isinstance(spec, str) and spec.endswith(".npy"):
|
|
126
|
+
return np.load(spec)
|
|
127
|
+
raise ValueError(f"Cannot load array from {type(spec).__name__}")
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
# ───────────────────────── OUTPUT SERIALIZATION ─────────────────────────
|
|
131
|
+
|
|
132
|
+
def serialize_output(result: Any, save_artifacts: bool = True,
                     fps: float = 24.0, audio_sample_rate: int = 16000) -> Dict[str, Any]:
    """Turn any diffusers pipeline/model output into a JSON-safe dictionary.

    When *save_artifacts* is true, binary payloads are written under
    ARTIFACT_DIR (video → .mp4, image → .png, audio → .wav, action chunk →
    .json) and referenced by path so the agent can hand them downstream.

    Returns:
        ``{"result": <json-safe payload>}``, plus an ``"artifacts"`` list of
        written file paths when at least one artifact was produced.
    """
    written: List[str] = []
    context = dict(fps=fps, audio_sample_rate=audio_sample_rate)
    safe = _ensure_json_safe(_serialize(result, written, save_artifacts, context))
    response: Dict[str, Any] = {"result": safe}
    if written:
        response["artifacts"] = written
    return response
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def _serialize(obj: Any, artifacts: List[str], save: bool, ctx: Dict[str, Any],
|
|
150
|
+
depth: int = 0, field: str = "") -> Any:
|
|
151
|
+
if depth > 6:
|
|
152
|
+
return str(obj)[:200]
|
|
153
|
+
|
|
154
|
+
if obj is None or isinstance(obj, (str, int, float, bool)):
|
|
155
|
+
return obj
|
|
156
|
+
|
|
157
|
+
# diffusers pipeline outputs are dataclass-like (ImagePipelineOutput,
|
|
158
|
+
# Cosmos3OmniPipelineOutput, ...) — handle their known fields explicitly so
|
|
159
|
+
# video/sound/action each go to the right serializer.
|
|
160
|
+
handled = _maybe_pipeline_output(obj, artifacts, save, ctx, depth)
|
|
161
|
+
if handled is not None:
|
|
162
|
+
return handled
|
|
163
|
+
|
|
164
|
+
# 3D mesh (ShapE / mesh-output pipelines) → .ply/.obj
|
|
165
|
+
mesh = _maybe_mesh(obj, artifacts, save)
|
|
166
|
+
if mesh is not None:
|
|
167
|
+
return mesh
|
|
168
|
+
|
|
169
|
+
# PIL image
|
|
170
|
+
pil = _maybe_pil(obj, artifacts, save)
|
|
171
|
+
if pil is not None:
|
|
172
|
+
return pil
|
|
173
|
+
|
|
174
|
+
# numpy / torch arrays
|
|
175
|
+
arr = _maybe_array(obj, artifacts, save, ctx, field)
|
|
176
|
+
if arr is not None:
|
|
177
|
+
return arr
|
|
178
|
+
|
|
179
|
+
if isinstance(obj, dict):
|
|
180
|
+
return {str(k): _serialize(v, artifacts, save, ctx, depth + 1, str(k))
|
|
181
|
+
for k, v in obj.items()}
|
|
182
|
+
if isinstance(obj, (list, tuple)):
|
|
183
|
+
# A list of PIL images is a video/frame-set → save as mp4 if many.
|
|
184
|
+
vid = _maybe_video_framelist(obj, artifacts, save, ctx, field)
|
|
185
|
+
if vid is not None:
|
|
186
|
+
return vid
|
|
187
|
+
return [_serialize(v, artifacts, save, ctx, depth + 1, field) for v in obj[:200]]
|
|
188
|
+
if isinstance(obj, (set, frozenset)):
|
|
189
|
+
return [_serialize(v, artifacts, save, ctx, depth + 1) for v in list(obj)[:200]]
|
|
190
|
+
|
|
191
|
+
if hasattr(obj, "to_dict"):
|
|
192
|
+
try:
|
|
193
|
+
return _serialize(obj.to_dict(), artifacts, save, ctx, depth + 1)
|
|
194
|
+
except Exception:
|
|
195
|
+
pass
|
|
196
|
+
if hasattr(obj, "keys"):
|
|
197
|
+
try:
|
|
198
|
+
return {str(k): _serialize(obj[k], artifacts, save, ctx, depth + 1, str(k))
|
|
199
|
+
for k in obj.keys()}
|
|
200
|
+
except Exception:
|
|
201
|
+
pass
|
|
202
|
+
|
|
203
|
+
return str(obj)[:50000]
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
def _maybe_pipeline_output(obj, artifacts, save, ctx, depth):
|
|
207
|
+
"""Serialize a diffusers *PipelineOutput dataclass field-by-field."""
|
|
208
|
+
cls = type(obj).__name__
|
|
209
|
+
if not cls.endswith("PipelineOutput") and not cls.endswith("Output"):
|
|
210
|
+
return None
|
|
211
|
+
# A mesh decoder output (verts/faces) is an *Output too — route to mesh export
|
|
212
|
+
# before generic field-by-field serialization (else 3D data is lost).
|
|
213
|
+
if hasattr(obj, "verts") and hasattr(obj, "faces"):
|
|
214
|
+
return _maybe_mesh(obj, artifacts, save)
|
|
215
|
+
# Gather public fields (dataclass-style or attrs).
|
|
216
|
+
fields = {}
|
|
217
|
+
if hasattr(obj, "__dataclass_fields__"):
|
|
218
|
+
names = list(obj.__dataclass_fields__.keys())
|
|
219
|
+
else:
|
|
220
|
+
names = [n for n in dir(obj) if not n.startswith("_")
|
|
221
|
+
and not callable(getattr(obj, n, None))]
|
|
222
|
+
for name in names:
|
|
223
|
+
try:
|
|
224
|
+
val = getattr(obj, name)
|
|
225
|
+
except Exception:
|
|
226
|
+
continue
|
|
227
|
+
if val is None:
|
|
228
|
+
fields[name] = None
|
|
229
|
+
continue
|
|
230
|
+
if name in ("video", "videos", "frames"):
|
|
231
|
+
fields[name] = _serialize_video(val, artifacts, save, ctx)
|
|
232
|
+
elif name in ("sound", "audio", "audios"):
|
|
233
|
+
fields[name] = _serialize_audio(val, artifacts, save, ctx)
|
|
234
|
+
elif name in ("action", "actions"):
|
|
235
|
+
fields[name] = _serialize_action(val, artifacts, save)
|
|
236
|
+
elif name in ("images", "image"):
|
|
237
|
+
fields[name] = _serialize(val, artifacts, save, ctx, depth + 1, name)
|
|
238
|
+
else:
|
|
239
|
+
fields[name] = _serialize(val, artifacts, save, ctx, depth + 1, name)
|
|
240
|
+
fields["__type__"] = cls
|
|
241
|
+
return fields
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
def _serialize_video(val, artifacts, save, ctx):
    """Serialize a video output (list[PIL] | ndarray[T,H,W,C] | tensor).

    If *val* can be normalised to frames and *save* is true, the frames are
    written with _save_video at ``ctx["fps"]`` (default 24.0), and the
    resulting path is recorded in *artifacts*. Values that are not recognisable
    as video fall back to generic serialization.
    """
    frames = _to_frame_list(val)
    if frames is None:
        return _serialize(val, artifacts, save, ctx, 5, "video")
    summary = {"type": "video", "num_frames": len(frames)}
    path = _save_video(frames, ctx.get("fps", 24.0)) if save else None
    if not path:
        return summary
    artifacts.append(path)
    return {"type": "video", "path": path, "num_frames": len(frames)}
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
def _serialize_audio(val, artifacts, save, ctx):
    """Serialize an audio output to a mono .wav artifact.

    Non-array values fall back to generic serialization. Singleton axes are
    squeezed. Anything still above two dimensions keeps only index 0 of its
    leading axes (mirroring how videos keep the first batch item), and a
    remaining 2-D array is down-mixed to mono. The wav uses
    ``ctx["audio_sample_rate"]`` (default 16000).
    """
    import numpy as np

    a = _tensor_to_numpy(val)
    if a is None:
        return _serialize(val, artifacts, save, ctx, 5, "audio")
    a = np.asarray(a)
    if a.ndim > 1:
        a = a.squeeze()
    # Averaging over a single axis of a >2-D array would still leave a 2-D
    # array, which the wav writer would then mishandle. Reduce to at most 2-D.
    while a.ndim > 2:
        a = a[0]
    if a.ndim > 1:
        # Down-mix multi-channel → mono. Audio time-axis is long, channel-axis
        # is small (1/2/~8), so average over the SHORTER axis regardless of
        # whether the layout is channels-first [C,N] or channels-last [N,C].
        # NOTE(review): a batched [B, N] output with small B is indistinguishable
        # from [C, N] here and gets averaged too — confirm per pipeline.
        a = a.mean(axis=int(np.argmin(a.shape)))
    if not save:
        return {"type": "audio", "samples": int(a.size)}
    rate = int(ctx.get("audio_sample_rate", 16000))
    path = _save_wav(a, rate)
    artifacts.append(path)
    return {"type": "audio", "path": path, "samples": int(a.size),
            "sampling_rate": rate}
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
def _serialize_action(val, artifacts, save):
    """Serialize a robot ACTION chunk — the WFM payload agents care about.

    Cosmos returns action as list[torch.Tensor]; each tensor is a normalized
    action chunk [T, action_dim]. We emit full nested lists (small) + write a
    .json artifact so the agent can feed it straight to a robot controller.

    Tensors and ndarrays are converted via _tensor_to_numpy. Plain numeric
    Python values (numbers, nested lists) are converted with ``np.asarray``
    rather than being silently dropped. Values that cannot be converted become
    None. For a list/tuple input, ``chunk_shape`` and ``num_chunks`` describe
    the first element and the element count.
    """
    import json

    import numpy as np

    def _as_array(t):
        a = _tensor_to_numpy(t)
        if a is not None:
            return np.asarray(a)
        # Actions already converted upstream to Python numbers / nested lists
        # used to be replaced by None here.
        if isinstance(t, (list, tuple, int, float)):
            try:
                arr = np.asarray(t)
            except (ValueError, TypeError):  # ragged / non-numeric nesting
                return None
            if arr.dtype.kind in "biuf":
                return arr
        return None

    def _one(t):
        arr = _as_array(t)
        return arr.tolist() if arr is not None else None

    is_seq = isinstance(val, (list, tuple))
    data = [_one(t) for t in val] if is_seq else _one(val)

    result = {"type": "action", "data": data}
    if is_seq and val:
        first = _as_array(val[0])
        if first is not None:
            result["chunk_shape"] = list(first.shape)
            result["num_chunks"] = len(val)
    if save and data is not None:
        path = ARTIFACT_DIR / f"action_{_stamp()}.json"
        with open(path, "w", encoding="utf-8") as f:
            json.dump(data, f)
        artifacts.append(str(path))
        result["path"] = str(path)
    return result
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
def _maybe_mesh(obj, artifacts, save):
|
|
315
|
+
"""A 3D mesh (ShapE MeshDecoderOutput / trimesh-like) → .ply (+ .obj).
|
|
316
|
+
|
|
317
|
+
diffusers mesh outputs expose `verts` and `faces`; diffusers.utils ships
|
|
318
|
+
export_to_ply / export_to_obj. Without this path a mesh would serialize to an
|
|
319
|
+
opaque repr string (silent 3D data loss).
|
|
320
|
+
"""
|
|
321
|
+
# NB: diffusers BaseOutput subclasses OrderedDict, so DON'T exclude dict here —
|
|
322
|
+
# detect by the verts/faces attributes instead.
|
|
323
|
+
if obj is None or isinstance(obj, (str, int, float, bool, list, tuple)):
|
|
324
|
+
return None
|
|
325
|
+
if not (hasattr(obj, "verts") and hasattr(obj, "faces")):
|
|
326
|
+
return None
|
|
327
|
+
nv = _safe_len(getattr(obj, "verts", None))
|
|
328
|
+
nf = _safe_len(getattr(obj, "faces", None))
|
|
329
|
+
info = {"type": "mesh", "num_verts": nv, "num_faces": nf}
|
|
330
|
+
if save:
|
|
331
|
+
path = _save_mesh(obj)
|
|
332
|
+
if path:
|
|
333
|
+
artifacts.append(path)
|
|
334
|
+
info["path"] = path
|
|
335
|
+
return info
|
|
336
|
+
|
|
337
|
+
|
|
338
|
+
def _safe_len(x):
|
|
339
|
+
try:
|
|
340
|
+
if x is None:
|
|
341
|
+
return None
|
|
342
|
+
if hasattr(x, "shape"):
|
|
343
|
+
return int(x.shape[0])
|
|
344
|
+
return len(x)
|
|
345
|
+
except Exception:
|
|
346
|
+
return None
|
|
347
|
+
|
|
348
|
+
|
|
349
|
+
def _save_mesh(mesh) -> Optional[str]:
    """Write *mesh* under ARTIFACT_DIR and return the path, or None on failure.

    Writers are tried in order, all sharing one stamp: diffusers
    ``export_to_ply`` (.ply), diffusers ``export_to_obj`` (.obj), then a raw
    ``np.savez`` of verts/faces (.npz) so the geometry is not silently lost.
    Each writer's failure falls through to the next.
    """
    ts = _stamp()

    def _ply():
        from diffusers.utils import export_to_ply
        target = ARTIFACT_DIR / f"mesh_{ts}.ply"
        export_to_ply(mesh, str(target))
        return str(target)

    def _obj():
        from diffusers.utils import export_to_obj
        target = ARTIFACT_DIR / f"mesh_{ts}.obj"
        export_to_obj(mesh, str(target))
        return str(target)

    def _npz():
        import numpy as np
        target = ARTIFACT_DIR / f"mesh_{ts}.npz"
        v = _tensor_to_numpy(mesh.verts)
        fa = _tensor_to_numpy(mesh.faces)
        np.savez(str(target), verts=v, faces=fa)
        return str(target)

    for writer in (_ply, _obj, _npz):
        try:
            return writer()
        except Exception:
            continue
    return None
|
|
376
|
+
|
|
377
|
+
|
|
378
|
+
def _maybe_pil(obj, artifacts, save):
|
|
379
|
+
try:
|
|
380
|
+
from PIL import Image
|
|
381
|
+
if isinstance(obj, Image.Image):
|
|
382
|
+
if save:
|
|
383
|
+
path = _save_image(obj)
|
|
384
|
+
artifacts.append(path)
|
|
385
|
+
return {"type": "image", "path": path, "size": list(obj.size)}
|
|
386
|
+
return {"type": "image", "size": list(obj.size)}
|
|
387
|
+
except ImportError:
|
|
388
|
+
pass
|
|
389
|
+
return None
|
|
390
|
+
|
|
391
|
+
|
|
392
|
+
def _maybe_video_framelist(obj, artifacts, save, ctx, field):
|
|
393
|
+
"""A long list of PIL frames (not in a known field) → treat as video."""
|
|
394
|
+
try:
|
|
395
|
+
from PIL import Image
|
|
396
|
+
except ImportError:
|
|
397
|
+
return None
|
|
398
|
+
if (isinstance(obj, list) and len(obj) >= 8
|
|
399
|
+
and all(isinstance(x, Image.Image) for x in obj[:8])):
|
|
400
|
+
return _serialize_video(obj, artifacts, save, ctx)
|
|
401
|
+
return None
|
|
402
|
+
|
|
403
|
+
|
|
404
|
+
def _maybe_array(obj, artifacts, save, ctx, field):
    """Serialize a numpy/torch array, else return None.

    - 4-D/5-D arrays under a video field (``video``/``videos``/``frames``) are
      written as video when *save* is true.
    - Small arrays (<= 256 elements) are inlined as nested lists.
    - Larger arrays become a summary: shape, dtype and the first 16 values.
    """
    a = _tensor_to_numpy(obj)
    if a is None:
        return None
    import numpy as np
    arr = np.asarray(a)
    # A 4D/5D array that looks like video frames → mp4
    if save and arr.ndim in (4, 5) and field in ("video", "videos", "frames"):
        return _serialize_video(arr, artifacts, save, ctx)
    if arr.size <= 256:
        return arr.tolist()
    return {
        "type": "ndarray",
        "shape": list(arr.shape),
        "dtype": str(arr.dtype),
        # arr.flat[:16] reads just the first 16 elements in C order (same as
        # flatten()[:16]); flatten() would first copy the entire, possibly
        # multi-GB, latent array only to keep 16 values.
        "preview": arr.flat[:16].tolist(),
    }
|
|
421
|
+
|
|
422
|
+
|
|
423
|
+
def _tensor_to_numpy(obj):
|
|
424
|
+
try:
|
|
425
|
+
import torch
|
|
426
|
+
if isinstance(obj, torch.Tensor):
|
|
427
|
+
obj = obj.detach().cpu()
|
|
428
|
+
if obj.dtype in (torch.bfloat16, torch.float16):
|
|
429
|
+
obj = obj.to(torch.float32)
|
|
430
|
+
return obj.numpy()
|
|
431
|
+
except ImportError:
|
|
432
|
+
pass
|
|
433
|
+
try:
|
|
434
|
+
import numpy as np
|
|
435
|
+
if isinstance(obj, (np.ndarray, np.generic)):
|
|
436
|
+
return np.asarray(obj)
|
|
437
|
+
except ImportError:
|
|
438
|
+
pass
|
|
439
|
+
return None
|
|
440
|
+
|
|
441
|
+
|
|
442
|
+
def _to_frame_list(val):
|
|
443
|
+
"""Normalize a video output to a list of HxWxC uint8 numpy frames."""
|
|
444
|
+
import numpy as np
|
|
445
|
+
|
|
446
|
+
try:
|
|
447
|
+
from PIL import Image
|
|
448
|
+
except ImportError:
|
|
449
|
+
Image = None
|
|
450
|
+
|
|
451
|
+
if isinstance(val, (list, tuple)):
|
|
452
|
+
if Image and val and all(isinstance(x, Image.Image) for x in val):
|
|
453
|
+
return [np.asarray(x.convert("RGB")) for x in val]
|
|
454
|
+
# list of per-frame numpy arrays (a common pipeline return shape) → stack
|
|
455
|
+
if val and all(isinstance(x, np.ndarray) for x in val):
|
|
456
|
+
return _to_frame_list(np.stack(val))
|
|
457
|
+
# list of per-frame torch tensors → stack via numpy
|
|
458
|
+
a0 = _tensor_to_numpy(val[0]) if val else None
|
|
459
|
+
if a0 is not None and not isinstance(val[0], np.ndarray):
|
|
460
|
+
return _to_frame_list(np.stack([_tensor_to_numpy(x) for x in val]))
|
|
461
|
+
# nested (batched) list of frames → take first sample
|
|
462
|
+
if val and isinstance(val[0], (list, tuple)):
|
|
463
|
+
return _to_frame_list(val[0])
|
|
464
|
+
a = _tensor_to_numpy(val)
|
|
465
|
+
if a is None:
|
|
466
|
+
return None
|
|
467
|
+
a = np.asarray(a)
|
|
468
|
+
if a.ndim == 5: # [B,T,H,W,C] or [B,T,C,H,W]
|
|
469
|
+
a = a[0]
|
|
470
|
+
if a.ndim == 4:
|
|
471
|
+
# detect channel position
|
|
472
|
+
if a.shape[1] in (1, 3, 4) and a.shape[-1] not in (1, 3, 4):
|
|
473
|
+
a = np.transpose(a, (0, 2, 3, 1)) # [T,C,H,W] → [T,H,W,C]
|
|
474
|
+
if a.dtype != np.uint8:
|
|
475
|
+
a = np.clip(a, 0, 1) if a.max() <= 1.0 else np.clip(a, 0, 255) / 255.0
|
|
476
|
+
a = (a * 255).astype(np.uint8) if a.max() <= 1.0 + 1e-6 else a.astype(np.uint8)
|
|
477
|
+
return [a[i] for i in range(a.shape[0])]
|
|
478
|
+
return None
|
|
479
|
+
|
|
480
|
+
|
|
481
|
+
# ───────────────────────── artifact writers ─────────────────────────
|
|
482
|
+
|
|
483
|
+
def _save_image(image) -> str:
    """Save a PIL image as a uniquely stamped PNG under ARTIFACT_DIR; return its path."""
    target = str(ARTIFACT_DIR / f"image_{_stamp()}.png")
    image.save(target)
    return target
|
|
487
|
+
|
|
488
|
+
|
|
489
|
+
def _save_video(frames, fps: float) -> Optional[str]:
    """Write HxWxC uint8 *frames* to disk and return the path, or None.

    Writers are tried in order: diffusers ``export_to_video`` (.mp4), imageio
    v3 with libx264 (same .mp4 path), then an animated GIF via Pillow (new
    stamp). Each failure falls through silently to the next writer.
    """
    mp4 = str(ARTIFACT_DIR / f"video_{_stamp()}.mp4")

    def _via_diffusers():
        from PIL import Image
        pil_frames = [Image.fromarray(f) for f in frames]
        from diffusers.utils import export_to_video
        export_to_video(pil_frames, mp4, fps=int(fps))
        return mp4

    def _via_imageio():
        import imageio.v3 as iio
        iio.imwrite(mp4, frames, fps=int(fps), codec="libx264")
        return mp4

    for writer in (_via_diffusers, _via_imageio):
        try:
            return writer()
        except Exception:
            pass

    # last resort: dump frames as a gif (always available via PIL)
    try:
        from PIL import Image
        gif = str(ARTIFACT_DIR / f"video_{_stamp()}.gif")
        imgs = [Image.fromarray(f) for f in frames]
        imgs[0].save(gif, save_all=True, append_images=imgs[1:],
                     duration=int(1000 / max(fps, 1)), loop=0)
        return gif
    except Exception:
        return None
|
|
516
|
+
|
|
517
|
+
|
|
518
|
+
def _save_wav(audio, sampling_rate: int) -> str:
    """Write *audio* to a uniquely stamped .wav under ARTIFACT_DIR; return the path.

    Singleton axes are squeezed. 1-D input is written as mono. 2-D input is
    treated as multi-channel and normalised to ``(frames, channels)``, the
    layout soundfile expects, assuming the channel axis is the shorter one.
    Uses soundfile when installed. Otherwise falls back to the stdlib
    ``wave`` module with 16-bit little-endian PCM; samples are assumed to lie
    in [-1, 1], and NaN/inf are zeroed or clipped.
    """
    import numpy as np

    a = np.asarray(audio, dtype=np.float32)
    if a.ndim > 1:
        a = a.squeeze()
    if a.ndim == 2 and a.shape[0] < a.shape[1]:
        a = np.ascontiguousarray(a.T)  # channels-first [C, N] → [N, C]
    path = ARTIFACT_DIR / f"audio_{_stamp()}.wav"
    try:
        import soundfile as sf
    except ImportError:
        sf = None
    if sf is not None:
        sf.write(str(path), a, int(sampling_rate))
        return str(path)

    import wave
    channels = 1 if a.ndim < 2 else int(a.shape[1])
    a = np.clip(np.nan_to_num(a), -1.0, 1.0)
    # WAV PCM is little-endian; '<i2' keeps that true on big-endian hosts.
    # Row-major (frames, channels) bytes are already interleaved per frame.
    pcm = (a * 32767).astype("<i2")
    with wave.open(str(path), "wb") as w:
        w.setnchannels(channels)
        w.setsampwidth(2)
        w.setframerate(int(sampling_rate))
        w.writeframes(pcm.tobytes())
    return str(path)
|
|
540
|
+
|
|
541
|
+
|
|
542
|
+
def _ensure_json_safe(obj: Any) -> Any:
|
|
543
|
+
import json as _json
|
|
544
|
+
try:
|
|
545
|
+
_json.dumps(obj)
|
|
546
|
+
return obj
|
|
547
|
+
except (TypeError, ValueError):
|
|
548
|
+
if isinstance(obj, dict):
|
|
549
|
+
return {str(k): _ensure_json_safe(v) for k, v in obj.items()}
|
|
550
|
+
if isinstance(obj, (list, tuple)):
|
|
551
|
+
return [_ensure_json_safe(v) for v in obj]
|
|
552
|
+
return str(obj)[:500]
|