strands-diffusers 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,552 @@
1
+ """Native multimodal I/O for diffusion — images / video / audio / actions out.
2
+
3
+ Inputs: file paths, URLs, base64 data URIs, PIL images, numpy arrays, video paths.
4
+ Outputs: diffusers pipeline results serialized to JSON-safe form, with binary
5
+ artifacts (generated images, videos, audio, and robot ACTION chunks) written to
6
+ disk and referenced by path.
7
+
8
+ The headline feature for Physical-AI / world-foundation models: a Cosmos-style
9
+ pipeline returns a `Cosmos3OmniPipelineOutput(video=..., sound=..., action=...)`.
10
+ We serialize the video to .mp4, the sound to .wav, and the **action** chunk to a
11
+ .json (the model-normalized action-space tensor) — so an agent gets back a path
12
+ to a playable world AND a usable robot action vector in one call.
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import base64
18
+ import io as _io
19
+ import os
20
+ import tempfile
21
+ import time
22
+ from pathlib import Path
23
+ from typing import Any, Dict, List, Optional
24
+
25
+ ARTIFACT_DIR = Path(
26
+ os.getenv("STRANDS_DIFFUSERS_ARTIFACTS", tempfile.gettempdir())
27
+ ) / "strands_diffusers"
28
+ ARTIFACT_DIR.mkdir(parents=True, exist_ok=True)
29
+
30
+
31
+ import itertools as _itertools
32
+ import threading as _threading
33
+
34
+ _STAMP_LOCK = _threading.Lock()
35
+ _STAMP_COUNTER = _itertools.count()
36
+
37
+
38
+ def _stamp() -> str:
39
+ """Monotonic, collision-free artifact stamp: <ms>_<counter>. Millisecond
40
+ timestamps alone collide in tight loops / batched generation (num_images>1),
41
+ silently overwriting artifacts — so append an atomic process-wide counter."""
42
+ with _STAMP_LOCK:
43
+ n = next(_STAMP_COUNTER)
44
+ return f"{int(time.time() * 1000)}_{n}"
45
+
46
+
47
+ # ───────────────────────── INPUT COERCION ─────────────────────────
48
+
49
def coerce_input(value: Any) -> Any:
    """Coerce an input spec into something a diffusers pipeline accepts.

    - "data:..."        → PIL Image / bytes
    - "*.png|jpg|..."   → PIL Image (loaded via diffusers.utils.load_image)
    - "*.mp4|mov|..."   → list[PIL] frames (via diffusers.utils.load_video)
    - "http(s)://..."   → load_image / load_video by extension
    - lists / dicts     → coerced recursively
    - everything else   → passed through (text prompt, ints, etc.)

    A string is only treated as media when it is a URL or names an existing
    local file; if loading fails the original string is returned unchanged.
    """
    image_exts = (".png", ".jpg", ".jpeg", ".bmp", ".webp", ".gif")
    video_exts = (".mp4", ".mov", ".avi", ".mkv", ".webm")

    if isinstance(value, str):
        if value.startswith("data:"):
            return _decode_data_uri(value)
        low = value.lower()
        # Extension is matched case-insensitively, ignoring any "?query" suffix.
        target = low.split("?")[0]
        is_url = low.startswith(("http://", "https://"))

        def _is_media(exts) -> bool:
            # Existence must be tested on the ORIGINAL string: lower-casing the
            # path first makes "/data/Frame.PNG" look missing on case-sensitive
            # filesystems, so the file would silently be passed on as text.
            return target.endswith(exts) and (is_url or os.path.exists(value))

        if _is_media(image_exts):
            img = _load_image(value)
            if img is not None:
                return img
        if _is_media(video_exts):
            vid = _load_video(value)
            if vid is not None:
                return vid
        return value  # text prompt / repo id / path we don't special-case
    if isinstance(value, list):
        return [coerce_input(v) for v in value]
    if isinstance(value, dict):
        return {k: coerce_input(v) for k, v in value.items()}
    return value
77
+
78
+
79
+ def _looks_like(path: str, exts) -> bool:
80
+ is_url = path.startswith("http://") or path.startswith("https://")
81
+ return (is_url or os.path.exists(path)) and path.split("?")[0].endswith(tuple(exts))
82
+
83
+
84
+ def _load_image(spec: str):
85
+ try:
86
+ from diffusers.utils import load_image
87
+ return load_image(spec)
88
+ except Exception:
89
+ try:
90
+ from PIL import Image
91
+ return Image.open(spec).convert("RGB")
92
+ except Exception:
93
+ return None
94
+
95
+
96
+ def _load_video(spec: str):
97
+ try:
98
+ from diffusers.utils import load_video
99
+ return load_video(spec)
100
+ except Exception:
101
+ return None
102
+
103
+
104
+ def _decode_data_uri(uri: str) -> Any:
105
+ header, _, b64 = uri.partition(",")
106
+ raw = base64.b64decode(b64)
107
+ mime = header.split(";")[0].removeprefix("data:")
108
+ if mime.startswith("image/"):
109
+ from PIL import Image
110
+ return Image.open(_io.BytesIO(raw)).convert("RGB")
111
+ return raw
112
+
113
+
114
def load_array(spec: Any):
    """Load a numpy array from a list/tuple, a ``.npy`` path, or an ndarray.

    Useful for raw robot action vectors fed to forward-dynamics WFM runs.

    Args:
        spec: An ``ndarray`` (returned as-is), a list/tuple (converted with
            ``np.asarray``), or a ``str`` / ``os.PathLike`` pointing at a
            ``.npy`` file (extension matched case-insensitively).

    Returns:
        The resulting ``numpy.ndarray``.

    Raises:
        ValueError: If *spec* is none of the supported forms.
    """
    import numpy as np

    if isinstance(spec, np.ndarray):
        return spec
    if isinstance(spec, (list, tuple)):
        return np.asarray(spec)
    if isinstance(spec, (str, os.PathLike)):
        # fspath() lets pathlib.Path (and other PathLike) inputs work too.
        location = os.fspath(spec)
        if isinstance(location, str) and location.lower().endswith(".npy"):
            return np.load(location)
    raise ValueError(f"Cannot load array from {type(spec).__name__}")
128
+
129
+
130
+ # ───────────────────────── OUTPUT SERIALIZATION ─────────────────────────
131
+
132
def serialize_output(result: Any, save_artifacts: bool = True,
                     fps: float = 24.0, audio_sample_rate: int = 16000) -> Dict[str, Any]:
    """Turn a diffusers pipeline/model output into a JSON-safe dict.

    Args:
        result: The object to serialize.
        save_artifacts: Forwarded to the serializers as their ``save`` flag.
        fps: Frame rate made available to the video serializer.
        audio_sample_rate: Sample rate made available to the audio serializer.

    Returns:
        ``{"result": <payload>}``, plus an ``"artifacts"`` list of written file
        paths when at least one artifact was recorded.
    """
    written: List[str] = []
    settings = dict(fps=fps, audio_sample_rate=audio_sample_rate)
    body = _ensure_json_safe(_serialize(result, written, save_artifacts, settings))
    if written:
        return {"result": body, "artifacts": written}
    return {"result": body}
147
+
148
+
149
+ def _serialize(obj: Any, artifacts: List[str], save: bool, ctx: Dict[str, Any],
150
+ depth: int = 0, field: str = "") -> Any:
151
+ if depth > 6:
152
+ return str(obj)[:200]
153
+
154
+ if obj is None or isinstance(obj, (str, int, float, bool)):
155
+ return obj
156
+
157
+ # diffusers pipeline outputs are dataclass-like (ImagePipelineOutput,
158
+ # Cosmos3OmniPipelineOutput, ...) — handle their known fields explicitly so
159
+ # video/sound/action each go to the right serializer.
160
+ handled = _maybe_pipeline_output(obj, artifacts, save, ctx, depth)
161
+ if handled is not None:
162
+ return handled
163
+
164
+ # 3D mesh (ShapE / mesh-output pipelines) → .ply/.obj
165
+ mesh = _maybe_mesh(obj, artifacts, save)
166
+ if mesh is not None:
167
+ return mesh
168
+
169
+ # PIL image
170
+ pil = _maybe_pil(obj, artifacts, save)
171
+ if pil is not None:
172
+ return pil
173
+
174
+ # numpy / torch arrays
175
+ arr = _maybe_array(obj, artifacts, save, ctx, field)
176
+ if arr is not None:
177
+ return arr
178
+
179
+ if isinstance(obj, dict):
180
+ return {str(k): _serialize(v, artifacts, save, ctx, depth + 1, str(k))
181
+ for k, v in obj.items()}
182
+ if isinstance(obj, (list, tuple)):
183
+ # A list of PIL images is a video/frame-set → save as mp4 if many.
184
+ vid = _maybe_video_framelist(obj, artifacts, save, ctx, field)
185
+ if vid is not None:
186
+ return vid
187
+ return [_serialize(v, artifacts, save, ctx, depth + 1, field) for v in obj[:200]]
188
+ if isinstance(obj, (set, frozenset)):
189
+ return [_serialize(v, artifacts, save, ctx, depth + 1) for v in list(obj)[:200]]
190
+
191
+ if hasattr(obj, "to_dict"):
192
+ try:
193
+ return _serialize(obj.to_dict(), artifacts, save, ctx, depth + 1)
194
+ except Exception:
195
+ pass
196
+ if hasattr(obj, "keys"):
197
+ try:
198
+ return {str(k): _serialize(obj[k], artifacts, save, ctx, depth + 1, str(k))
199
+ for k in obj.keys()}
200
+ except Exception:
201
+ pass
202
+
203
+ return str(obj)[:50000]
204
+
205
+
206
+ def _maybe_pipeline_output(obj, artifacts, save, ctx, depth):
207
+ """Serialize a diffusers *PipelineOutput dataclass field-by-field."""
208
+ cls = type(obj).__name__
209
+ if not cls.endswith("PipelineOutput") and not cls.endswith("Output"):
210
+ return None
211
+ # A mesh decoder output (verts/faces) is an *Output too — route to mesh export
212
+ # before generic field-by-field serialization (else 3D data is lost).
213
+ if hasattr(obj, "verts") and hasattr(obj, "faces"):
214
+ return _maybe_mesh(obj, artifacts, save)
215
+ # Gather public fields (dataclass-style or attrs).
216
+ fields = {}
217
+ if hasattr(obj, "__dataclass_fields__"):
218
+ names = list(obj.__dataclass_fields__.keys())
219
+ else:
220
+ names = [n for n in dir(obj) if not n.startswith("_")
221
+ and not callable(getattr(obj, n, None))]
222
+ for name in names:
223
+ try:
224
+ val = getattr(obj, name)
225
+ except Exception:
226
+ continue
227
+ if val is None:
228
+ fields[name] = None
229
+ continue
230
+ if name in ("video", "videos", "frames"):
231
+ fields[name] = _serialize_video(val, artifacts, save, ctx)
232
+ elif name in ("sound", "audio", "audios"):
233
+ fields[name] = _serialize_audio(val, artifacts, save, ctx)
234
+ elif name in ("action", "actions"):
235
+ fields[name] = _serialize_action(val, artifacts, save)
236
+ elif name in ("images", "image"):
237
+ fields[name] = _serialize(val, artifacts, save, ctx, depth + 1, name)
238
+ else:
239
+ fields[name] = _serialize(val, artifacts, save, ctx, depth + 1, name)
240
+ fields["__type__"] = cls
241
+ return fields
242
+
243
+
244
def _serialize_video(val, artifacts, save, ctx):
    """Describe a video output, writing it to disk when *save* is true.

    If *val* cannot be normalized to a frame list it is passed to the generic
    serializer instead. Otherwise the result is
    ``{"type": "video", "num_frames": N}`` with a ``"path"`` entry (also
    appended to *artifacts*) when a file was written successfully.
    """
    frames = _to_frame_list(val)
    if frames is None:
        return _serialize(val, artifacts, save, ctx, 5, "video")

    summary = {"type": "video"}
    saved_to = _save_video(frames, ctx.get("fps", 24.0)) if save else None
    if saved_to:
        artifacts.append(saved_to)
        summary["path"] = saved_to
    summary["num_frames"] = len(frames)
    return summary
256
+
257
+
258
def _serialize_audio(val, artifacts, save, ctx):
    """Describe an audio output as mono, writing a .wav when *save* is true.

    Args:
        val: Audio as a tensor/ndarray; anything else is passed to the generic
            serializer.
        artifacts: List that receives the written file path.
        save: Whether to write the audio to disk.
        ctx: Settings dict; ``audio_sample_rate`` (default 16000) is used.

    Returns:
        ``{"type": "audio", "samples": N}``, plus ``"path"`` and
        ``"sampling_rate"`` when the file was written.
    """
    import numpy as np
    a = _tensor_to_numpy(val)
    if a is None:
        return _serialize(val, artifacts, save, ctx, 5, "audio")
    a = np.asarray(a)
    if a.ndim > 1:
        a = a.squeeze()
    # Batched audio (3+ dims after squeezing): keep the first sample, as
    # _to_frame_list does for batched video. Averaging over the smallest axis
    # here would blend unrelated batch items and still leave a 2-D array for
    # the mono WAV writer.
    while a.ndim > 2:
        a = a[0]
    if a.ndim == 2:
        # Down-mix multi-channel → mono. Audio time-axis is long, channel-axis
        # is small (1/2/~8), so average over the SHORTER axis regardless of
        # whether the layout is channels-first [C,N] or channels-last [N,C].
        ch_axis = int(np.argmin(a.shape))
        a = a.mean(axis=ch_axis)
    if not save:
        return {"type": "audio", "samples": int(a.size)}
    rate = int(ctx.get("audio_sample_rate", 16000))
    path = _save_wav(a, rate)
    artifacts.append(path)
    return {"type": "audio", "path": path, "samples": int(a.size),
            "sampling_rate": rate}
278
+
279
+
280
def _serialize_action(val, artifacts, save):
    """Serialize a robot ACTION chunk — the WFM payload agents care about.

    Cosmos returns action as list[torch.Tensor]; each tensor is a normalized
    action chunk [T, action_dim]. We emit full nested lists (small) + write a
    .json artifact so the agent can feed it straight to a robot controller.

    Args:
        val: A tensor/ndarray, or a list/tuple of them. Plain Python numbers
            and nested numeric lists are accepted as well.
        artifacts: List that receives the written .json path.
        save: Whether to write the .json artifact.

    Returns:
        ``{"type": "action", "data": ...}``; for list/tuple input also
        ``"num_chunks"`` and, when the first chunk is convertible,
        ``"chunk_shape"``; ``"path"`` when the artifact was written.
    """
    import json

    import numpy as np

    def _one(t):
        a = _tensor_to_numpy(t)
        if a is None:
            if t is None:
                return None
            # Not a tensor/ndarray: plain numeric data used to be emitted as
            # null, silently dropping the action values. Keep it when it is
            # numeric; anything else still maps to None.
            try:
                a = np.asarray(t, dtype=float)
            except (TypeError, ValueError):
                return None
        return np.asarray(a).tolist()

    is_sequence = isinstance(val, (list, tuple))
    data = [_one(t) for t in val] if is_sequence else _one(val)

    result = {"type": "action", "data": data}
    if is_sequence and val:
        if data[0] is not None:
            result["chunk_shape"] = list(np.shape(data[0]))
        result["num_chunks"] = len(val)
    if save and data is not None:
        path = ARTIFACT_DIR / f"action_{_stamp()}.json"
        with open(path, "w", encoding="utf-8") as f:
            json.dump(data, f)
        artifacts.append(str(path))
        result["path"] = str(path)
    return result
312
+
313
+
314
+ def _maybe_mesh(obj, artifacts, save):
315
+ """A 3D mesh (ShapE MeshDecoderOutput / trimesh-like) → .ply (+ .obj).
316
+
317
+ diffusers mesh outputs expose `verts` and `faces`; diffusers.utils ships
318
+ export_to_ply / export_to_obj. Without this path a mesh would serialize to an
319
+ opaque repr string (silent 3D data loss).
320
+ """
321
+ # NB: diffusers BaseOutput subclasses OrderedDict, so DON'T exclude dict here —
322
+ # detect by the verts/faces attributes instead.
323
+ if obj is None or isinstance(obj, (str, int, float, bool, list, tuple)):
324
+ return None
325
+ if not (hasattr(obj, "verts") and hasattr(obj, "faces")):
326
+ return None
327
+ nv = _safe_len(getattr(obj, "verts", None))
328
+ nf = _safe_len(getattr(obj, "faces", None))
329
+ info = {"type": "mesh", "num_verts": nv, "num_faces": nf}
330
+ if save:
331
+ path = _save_mesh(obj)
332
+ if path:
333
+ artifacts.append(path)
334
+ info["path"] = path
335
+ return info
336
+
337
+
338
+ def _safe_len(x):
339
+ try:
340
+ if x is None:
341
+ return None
342
+ if hasattr(x, "shape"):
343
+ return int(x.shape[0])
344
+ return len(x)
345
+ except Exception:
346
+ return None
347
+
348
+
349
def _save_mesh(mesh) -> Optional[str]:
    """Write *mesh* to ARTIFACT_DIR and return the file path, or ``None``.

    Tries, in order: ``diffusers.utils.export_to_ply`` (.ply),
    ``diffusers.utils.export_to_obj`` (.obj), then a raw ``verts``/``faces``
    dump via ``numpy.savez`` (.npz). Each attempt is best-effort; a failure
    moves on to the next one. All attempts share the same stamp.
    """
    stamp = _stamp()

    ply_target = ARTIFACT_DIR / f"mesh_{stamp}.ply"
    try:
        from diffusers.utils import export_to_ply as _to_ply
        _to_ply(mesh, str(ply_target))
    except Exception:
        pass
    else:
        return str(ply_target)

    obj_target = ARTIFACT_DIR / f"mesh_{stamp}.obj"
    try:
        from diffusers.utils import export_to_obj as _to_obj
        _to_obj(mesh, str(obj_target))
    except Exception:
        pass
    else:
        return str(obj_target)

    # Last resort: keep the raw geometry so the data is not lost.
    try:
        import numpy as np
        npz_target = ARTIFACT_DIR / f"mesh_{stamp}.npz"
        vertices = _tensor_to_numpy(mesh.verts)
        triangles = _tensor_to_numpy(mesh.faces)
        np.savez(str(npz_target), verts=vertices, faces=triangles)
    except Exception:
        return None
    return str(npz_target)
376
+
377
+
378
+ def _maybe_pil(obj, artifacts, save):
379
+ try:
380
+ from PIL import Image
381
+ if isinstance(obj, Image.Image):
382
+ if save:
383
+ path = _save_image(obj)
384
+ artifacts.append(path)
385
+ return {"type": "image", "path": path, "size": list(obj.size)}
386
+ return {"type": "image", "size": list(obj.size)}
387
+ except ImportError:
388
+ pass
389
+ return None
390
+
391
+
392
+ def _maybe_video_framelist(obj, artifacts, save, ctx, field):
393
+ """A long list of PIL frames (not in a known field) → treat as video."""
394
+ try:
395
+ from PIL import Image
396
+ except ImportError:
397
+ return None
398
+ if (isinstance(obj, list) and len(obj) >= 8
399
+ and all(isinstance(x, Image.Image) for x in obj[:8])):
400
+ return _serialize_video(obj, artifacts, save, ctx)
401
+ return None
402
+
403
+
404
def _maybe_array(obj, artifacts, save, ctx, field):
    """Serialize a tensor/ndarray, or return ``None`` for anything else.

    - 4D/5D arrays coming from a video-named field are written as video when
      *save* is true.
    - Arrays of at most 256 elements are returned in full as nested lists.
    - Larger arrays are summarized by shape, dtype and their first 16 values.
    """
    converted = _tensor_to_numpy(obj)
    if converted is None:
        return None

    import numpy as np
    data = np.asarray(converted)

    from_video_field = field in ("video", "videos", "frames")
    if save and data.ndim in (4, 5) and from_video_field:
        return _serialize_video(data, artifacts, save, ctx)

    if data.size > 256:
        return {
            "type": "ndarray",
            "shape": list(data.shape),
            "dtype": str(data.dtype),
            "preview": data.flatten()[:16].tolist(),
        }
    return data.tolist()
421
+
422
+
423
+ def _tensor_to_numpy(obj):
424
+ try:
425
+ import torch
426
+ if isinstance(obj, torch.Tensor):
427
+ obj = obj.detach().cpu()
428
+ if obj.dtype in (torch.bfloat16, torch.float16):
429
+ obj = obj.to(torch.float32)
430
+ return obj.numpy()
431
+ except ImportError:
432
+ pass
433
+ try:
434
+ import numpy as np
435
+ if isinstance(obj, (np.ndarray, np.generic)):
436
+ return np.asarray(obj)
437
+ except ImportError:
438
+ pass
439
+ return None
440
+
441
+
442
def _to_frame_list(val):
    """Normalize a video output to a list of HxWxC uint8 numpy frames.

    Accepted inputs: a list/tuple of PIL images, of per-frame ndarrays, or of
    per-frame tensors; a nested (batched) list, of which the first sample is
    used; or a 4D/5D tensor/ndarray (for 5D the first batch item is used).
    Channels-first 4D arrays are transposed to channels-last.

    Non-uint8 data with a maximum of at most 1.0 is treated as [0, 1],
    otherwise as [0, 255]; values are clipped to that range and converted to
    uint8.

    Returns:
        The list of frames, or ``None`` when *val* cannot be interpreted as a
        video.
    """
    import numpy as np

    try:
        from PIL import Image
    except ImportError:
        Image = None

    if isinstance(val, (list, tuple)):
        if Image and val and all(isinstance(x, Image.Image) for x in val):
            return [np.asarray(x.convert("RGB")) for x in val]
        # list of per-frame numpy arrays (a common pipeline return shape) → stack
        if val and all(isinstance(x, np.ndarray) for x in val):
            return _to_frame_list(np.stack(val))
        # list of per-frame torch tensors → stack via numpy
        a0 = _tensor_to_numpy(val[0]) if val else None
        if a0 is not None and not isinstance(val[0], np.ndarray):
            return _to_frame_list(np.stack([_tensor_to_numpy(x) for x in val]))
        # nested (batched) list of frames → take first sample
        if val and isinstance(val[0], (list, tuple)):
            return _to_frame_list(val[0])

    a = _tensor_to_numpy(val)
    if a is None:
        return None
    a = np.asarray(a)
    if a.ndim == 5:  # [B,T,H,W,C] or [B,T,C,H,W]
        if a.shape[0] == 0:
            return None  # empty batch: nothing to take a first sample from
        a = a[0]
    if a.ndim != 4:
        return None

    # detect channel position
    if a.shape[1] in (1, 3, 4) and a.shape[-1] not in (1, 3, 4):
        a = np.transpose(a, (0, 2, 3, 1))  # [T,C,H,W] → [T,H,W,C]

    if a.dtype != np.uint8:
        if a.size == 0:
            # max() of an empty array raises; there is nothing to rescale.
            a = a.astype(np.uint8)
        else:
            if np.issubdtype(a.dtype, np.floating):
                finite = np.isfinite(a)
                if not finite.all():
                    # A single NaN/inf makes a.max() unusable for range
                    # detection and used to turn the whole clip black. Pick the
                    # range from the finite values; NaN/-inf → 0, +inf → peak.
                    peak = float(a[finite].max()) if finite.any() else 0.0
                    a = np.nan_to_num(a, nan=0.0, posinf=peak, neginf=0.0)
            if a.max() <= 1.0:
                unit = np.clip(a, 0, 1)
            else:
                unit = np.clip(a, 0, 255) / 255.0
            a = (unit * 255).astype(np.uint8)
    return [a[i] for i in range(a.shape[0])]
479
+
480
+
481
+ # ───────────────────────── artifact writers ─────────────────────────
482
+
483
def _save_image(image) -> str:
    """Save *image* as a stamped .png under ARTIFACT_DIR and return its path."""
    target = str(ARTIFACT_DIR / f"image_{_stamp()}.png")
    image.save(target)
    return target
487
+
488
+
489
def _save_video(frames, fps: float) -> Optional[str]:
    """Write *frames* to a video file under ARTIFACT_DIR.

    Tries ``diffusers.utils.export_to_video`` (.mp4), then ``imageio.v3`` with
    libx264 (.mp4), then an animated .gif via PIL. Each attempt is best-effort.

    Returns:
        Path of the written file, or ``None`` if every attempt failed.
    """
    mp4_target = str(ARTIFACT_DIR / f"video_{_stamp()}.mp4")

    try:
        from PIL import Image
        as_pil = [Image.fromarray(frame) for frame in frames]
        from diffusers.utils import export_to_video
        export_to_video(as_pil, mp4_target, fps=int(fps))
    except Exception:
        pass
    else:
        return mp4_target

    try:
        import imageio.v3 as iio
        iio.imwrite(mp4_target, frames, fps=int(fps), codec="libx264")
    except Exception:
        pass
    else:
        return mp4_target

    # last resort: dump frames as a gif via PIL
    try:
        from PIL import Image
        gif_target = str(ARTIFACT_DIR / f"video_{_stamp()}.gif")
        stills = [Image.fromarray(frame) for frame in frames]
        first, rest = stills[0], stills[1:]
        first.save(gif_target, save_all=True, append_images=rest,
                   duration=int(1000 / max(fps, 1)), loop=0)
    except Exception:
        return None
    return gif_target
516
+
517
+
518
def _save_wav(audio, sampling_rate: int) -> str:
    """Write *audio* to a stamped .wav under ARTIFACT_DIR and return its path.

    The samples are converted to float32 and squeezed. ``soundfile`` is used
    when it can be imported; otherwise the stdlib ``wave`` module writes a
    single-channel 16-bit PCM file after clipping the samples to [-1, 1].
    """
    import numpy as np

    samples = np.asarray(audio, dtype=np.float32)
    if samples.ndim > 1:
        samples = samples.squeeze()
    target = str(ARTIFACT_DIR / f"audio_{_stamp()}.wav")

    try:
        import soundfile as sf
        sf.write(target, samples, int(sampling_rate))
        return target
    except ImportError:
        pass

    # Fallback: stdlib writer, 16-bit PCM with a mono header.
    import wave
    pcm = (np.clip(samples, -1.0, 1.0) * 32767).astype(np.int16)
    with wave.open(target, "wb") as writer:
        writer.setnchannels(1)
        writer.setsampwidth(2)
        writer.setframerate(int(sampling_rate))
        writer.writeframes(pcm.tobytes())
    return target
540
+
541
+
542
+ def _ensure_json_safe(obj: Any) -> Any:
543
+ import json as _json
544
+ try:
545
+ _json.dumps(obj)
546
+ return obj
547
+ except (TypeError, ValueError):
548
+ if isinstance(obj, dict):
549
+ return {str(k): _ensure_json_safe(v) for k, v in obj.items()}
550
+ if isinstance(obj, (list, tuple)):
551
+ return [_ensure_json_safe(v) for v in obj]
552
+ return str(obj)[:500]