strands-diffusers 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,349 @@
1
+ """Dynamic pipeline/model/scheduler registry — 100% diffusers coverage, zero hardcoding.
2
+
3
+ The single source of truth is diffusers' own `_import_structure` — the lazy map of
4
+ every public symbol it exposes (307 pipelines, 87 models, 54 schedulers as of
5
+ 0.38). We read it at runtime, so when diffusers adds a new pipeline (e.g. a fresh
6
+ Cosmos world-foundation model), strands-diffusers supports it automatically — no
7
+ code change required.
8
+
9
+ Same philosophy as `use_aws` (wraps boto3 dynamically), `use_lerobot` (wraps
10
+ lerobot) and `use_transformers` (wraps the transformers task taxonomy): discover,
11
+ don't hardcode.
12
+
13
+ Diffusers has no single "task taxonomy" like transformers' SUPPORTED_TASKS, so we
14
+ derive structure from three places, all dynamic:
15
+
16
+ 1. `diffusers._import_structure` → every public class, grouped by submodule.
17
+ 2. The `AutoPipelineFor*` mappings → the canonical task → pipeline-class maps
18
+ (text2image / image2image / inpainting), diffusers' closest thing to tasks.
19
+ 3. Class-name heuristics → group the long tail of pipelines by the
20
+ modality their name implies (TextToImage / ImageToVideo / VideoToWorld / ...).
21
+ """
22
+
23
+ from __future__ import annotations
24
+
25
+ import importlib
26
+ import inspect
27
+ import re
28
+ from functools import lru_cache
29
+ from typing import Any, Dict, List, Optional
30
+
31
+
32
@lru_cache(maxsize=1)
def _import_structure() -> Dict[str, List[str]]:
    """Return a snapshot of ``diffusers._import_structure``.

    The mapping is read from the installed ``diffusers`` package at call time.
    Keys are kept exactly as diffusers declares them; every value is copied
    into a fresh list. An empty dict is returned when the attribute is missing
    or falsy. The result is cached after the first call.
    """
    import diffusers

    structure = getattr(diffusers, "_import_structure", None)
    if not structure:
        return {}
    # Keys (the group / submodule names) pass through untouched; only the
    # symbol sequences are copied.
    return {group: list(symbols) for group, symbols in structure.items()}
39
+
40
+
41
@lru_cache(maxsize=1)
def all_symbols() -> Dict[str, str]:
    """Map each symbol listed in the import structure to its kind.

    The kind is whatever ``_classify`` returns for the symbol name
    (pipeline / scheduler / output / model / other). Cached after first use.
    """
    return {
        symbol: _classify(symbol)
        for group in _import_structure().values()
        for symbol in group
    }
49
+
50
+
51
def _classify(name: str) -> str:
    """Classify a symbol purely from its name.

    Checks run in a fixed order and the first hit wins, so e.g. a name ending
    in "SchedulerOutput" is reported as "scheduler", not "output".
    """
    if name.endswith("Pipeline"):
        return "pipeline"
    if name.endswith(("Scheduler", "SchedulerOutput")):
        return "scheduler"
    if name.endswith("Output"):
        return "output"
    looks_like_model = (
        "Model" in name
        or name.startswith("Autoencoder")
        or name.endswith("Transformer")
    )
    return "model" if looks_like_model else "other"
61
+
62
+
63
@lru_cache(maxsize=1)
def pipelines() -> List[str]:
    """Sorted names of every symbol whose kind is "pipeline" (cached)."""
    names = [
        symbol for symbol, kind in all_symbols().items() if kind == "pipeline"
    ]
    names.sort()
    return names
66
+
67
+
68
@lru_cache(maxsize=1)
def models() -> List[str]:
    """Sorted names of every symbol whose kind is "model" (cached)."""
    names = [
        symbol for symbol, kind in all_symbols().items() if kind == "model"
    ]
    names.sort()
    return names
71
+
72
+
73
@lru_cache(maxsize=1)
def schedulers() -> List[str]:
    """Sorted names of scheduler symbols, without their ``*Output`` companions.

    ``_classify`` labels names ending in "SchedulerOutput" as schedulers too,
    so anything ending in "Output" is filtered out here. Cached.
    """
    names = [
        symbol
        for symbol, kind in all_symbols().items()
        if kind == "scheduler" and not symbol.endswith("Output")
    ]
    names.sort()
    return names
77
+
78
+
79
+ # ───────────────────────── AutoPipeline task maps ─────────────────────────
80
+
81
@lru_cache(maxsize=1)
def auto_pipeline_tasks() -> Dict[str, Dict[str, str]]:
    """Return diffusers' AutoPipeline mappings as plain class-name strings.

    Three mappings are read from ``diffusers.pipelines.auto_pipeline`` and
    exposed under the task keys "text-to-image", "image-to-image" and
    "inpainting". Each inner dict maps the mapping's key (a model family) to
    the ``__name__`` of the pipeline class it points at. Cached.
    """
    from diffusers.pipelines.auto_pipeline import (
        AUTO_TEXT2IMAGE_PIPELINES_MAPPING,
        AUTO_IMAGE2IMAGE_PIPELINES_MAPPING,
        AUTO_INPAINT_PIPELINES_MAPPING,
    )

    task_mappings = (
        ("text-to-image", AUTO_TEXT2IMAGE_PIPELINES_MAPPING),
        ("image-to-image", AUTO_IMAGE2IMAGE_PIPELINES_MAPPING),
        ("inpainting", AUTO_INPAINT_PIPELINES_MAPPING),
    )
    return {
        task: {family: cls.__name__ for family, cls in mapping.items()}
        for task, mapping in task_mappings
    }
102
+
103
+
104
+ # ───────────────────────── modality grouping (derived) ─────────────────────────
105
+
106
# Ordered (pattern → modality) rules applied to a pipeline class name. First match
# wins. Purely name-derived so new pipelines slot in without code changes.
# Patterns are plain strings consumed with `re.search`, so every alternative is a
# case-sensitive substring match unless it says otherwise.
_MODALITY_RULES = (
    # Explicit transition names first (most specific).
    (r"TextToImage|Text2Image", "text-to-image"),
    (r"TextToVideo|Text2Video|TextToWorld", "text-to-video"),
    (r"ImageToVideo|Image2Video", "image-to-video"),
    (r"VideoToVideo|Video2Video", "video-to-video"),
    (r"VideoToWorld|World", "video-to-world"),
    (r"ImageToImage|Image2Image", "image-to-image"),
    (r"Inpaint", "inpainting"),
    (r"Upscale|SuperResolution", "upscaling"),
    (r"TextToAudio|Text2Audio|Audio|Music|Speech|TTS|Bark|MusicGen|Stable[Aa]udio", "audio"),
    # Image families that share a name-stem with a video family (e.g. HunyuanDiT,
    # HunyuanImage) — classify as image BEFORE the broad video catch-all below.
    (r"DiT|HunyuanImage|HunyuanDiT", "image"),
    # Broad video/world model families (named after the architecture, not a task).
    # `Hunyuan(?!3D)` keeps Hunyuan3D* out of this rule so it can reach the 3D
    # rule further down instead of being swallowed as video.
    (r"Video|Animate|Cosmos|Wan|Hunyuan(?!3D)|Mochi|Allegro|LTX|SkyReels|CogVideo|Latte|Genie", "video"),
    # `IP(?!ipeline)`: a bare "IP" would also match the seam between a name ending
    # in "I" and the "Pipeline" suffix (e.g. "...T2IPipeline"), mislabelling plain
    # image pipelines as controlled-image.
    (r"ControlNet|Adapter|IP(?!ipeline)", "controlled-image"),
    # Abbreviated transition tokens — MUST precede architecture-family rules so a
    # family-named video pipeline (Kandinsky5I2V, *T2V) isn't grabbed as image.
    (r"I2V", "image-to-video"),
    (r"T2V", "text-to-video"),
    (r"V2V", "video-to-video"),
    # Transition variants on family-named pipelines (task encoded as a suffix).
    (r"Img2Img|Image2Image", "image-to-image"),
    (r"Fill", "inpainting"),
    # NOTE: no bare "Edit" rule — editing can be image OR video (e.g. ChronoEdit
    # is image-to-video). Let such names fall through to their family/Video rule.
    # Architecture-named image-gen families (task implicit = text-to-image-class).
    (r"StableDiffusion|Flux|CogView|Bria|Chroma|AuraFlow|Amused|Kandinsky|"
     r"PixArt|Sana|Lumina|DeepFloyd|Wuerstchen|Kolors|HiDream|Janus|OmniGen|"
     r"Marigold|VisualCloze", "image"),
    # Architecture-named audio families.
    (r"AudioLDM|StableAudio|MusicGen|Musicgen|AceStep|Dance|Spectrogram", "audio"),
    # 3D / mesh generation (ShapE, etc.) — emits meshes, not images.
    (r"ShapE|Shap[_-]?E|Mesh|TripoSR|Hunyuan3D|TRELLIS", "3d"),
    (r"Image", "image"),
)
145
+
146
+
147
# Docstring phrase → modality lookup, consulted only as a fallback when the
# name rules come up empty. Entries are tried top to bottom and matched by
# substring, so the more specific phrases are listed first.
_DOC_MODALITY = (
    ("text-to-3d", "3d"),
    ("image-to-3d", "3d"),
    ("text-to-mesh", "3d"),
    ("image-to-image", "image-to-image"),
    ("image-to-video", "image-to-video"),
    ("video-to-video", "video-to-video"),
    ("text-to-video", "text-to-video"),
    ("text-to-image", "text-to-image"),
    ("text-to-audio", "audio"),
    ("text-to-speech", "audio"),
    ("super-resolution", "upscaling"),
    ("inpainting", "inpainting"),
    ("unconditional image", "image"),
    ("class-conditional image", "image"),
    ("image generation", "image"),
    ("image synthesis", "image"),
    ("video generation", "video"),
    ("audio generation", "audio"),
)

# Captures the words between "pipeline for" and the first terminator word
# (generation / synthesis / using / based / with).
_DOC_PAT = re.compile(
    r"[Pp]ipeline for ([\w\- ]+?) (?:generation|synthesis|using|based|with)",
    re.DOTALL,
)
167
+
168
+
169
@lru_cache(maxsize=512)
def _modality_from_doc(name: str) -> str:
    """Best-effort modality from a pipeline class docstring. Cached; never raises.

    Args:
        name: symbol name, resolved through ``resolve_attr``.

    Returns:
        The modality of the first ``_DOC_MODALITY`` key found in the docstring's
        "pipeline for ..." phrase (or, when that phrase is absent, in the first
        80 characters of the lower-cased docstring); "other" when the class
        cannot be resolved, has no docstring, or no key matches.
    """
    try:
        cls = resolve_attr(name)
        doc = (cls.__doc__ or "").strip().lower()
    except Exception:
        return "other"
    if not doc:
        return "other"
    m = _DOC_PAT.search(doc)
    # Search from the captured phrase through the END of the match so the
    # terminator word is included. Using only group(1) cut off "generation" /
    # "synthesis", which made keys such as "image generation" or
    # "video generation" impossible to hit whenever the pattern matched.
    phrase = doc[m.start(1):m.end()] if m else doc[:80]
    for key, mod in _DOC_MODALITY:
        if key in phrase:
            return mod
    return "other"
185
+
186
+
187
def modality_of(pipeline_name: str, use_doc: bool = False) -> str:
    """Return the modality implied by a pipeline class name.

    The ordered ``_MODALITY_RULES`` are authoritative: the first pattern found
    in the name decides. When none matches, the result is "other" unless
    ``use_doc`` is true, in which case the class docstring is consulted via
    ``_modality_from_doc`` (which has to resolve the class).
    """
    by_name = next(
        (
            modality
            for pattern, modality in _MODALITY_RULES
            if re.search(pattern, pipeline_name)
        ),
        None,
    )
    if by_name is not None:
        return by_name
    return _modality_from_doc(pipeline_name) if use_doc else "other"
197
+
198
+
199
def tasks_by_modality(use_doc: bool = True) -> Dict[str, List[str]]:
    """Group every known pipeline name under its modality.

    Returns a dict of modality → sorted list of pipeline names; ``use_doc`` is
    forwarded to ``modality_of`` for each pipeline.
    """
    grouped: Dict[str, List[str]] = {}
    for name in pipelines():
        modality = modality_of(name, use_doc=use_doc)
        if modality not in grouped:
            grouped[modality] = []
        grouped[modality].append(name)
    return {modality: sorted(names) for modality, names in grouped.items()}
206
+
207
+
208
# World-foundation-model pipelines, picked out purely by name: any pipeline
# whose class name contains Cosmos, World, Wan, Hunyuan or Genie.
def world_foundation_models() -> List[str]:
    """Sorted pipeline names matching the world-foundation-model name pattern."""
    wfm_pattern = re.compile(r"Cosmos|World|Wan|Hunyuan|Genie")
    matches = [name for name in pipelines() if wfm_pattern.search(name)]
    matches.sort()
    return matches
213
+
214
+
215
+ # ───────────────────────── resolution & introspection ─────────────────────────
216
+
217
def resolve_attr(dotted: str, root_module: str = "diffusers") -> Any:
    """Resolve a dotted path against diffusers (or a submodule).

    The path may be given relative to ``root_module`` ("utils.export_to_video")
    or already prefixed with it ("diffusers.utils.export_to_video"); both forms
    resolve to the same object.

    Examples:
        resolve_attr("StableDiffusionPipeline")
        resolve_attr("DiffusionPipeline.from_pretrained")
        resolve_attr("Cosmos3OmniPipeline")  # from-source builds
        resolve_attr("utils.export_to_video")
        resolve_attr("schedulers.UniPCMultistepScheduler.from_config")

    Args:
        dotted: dotted module/attribute path.
        root_module: importable module the path is anchored at.

    Returns:
        The module, class, function or attribute the path names.

    Raises:
        ImportError: if ``root_module`` itself cannot be imported.
        AttributeError: if the path cannot be resolved; the message names the
            first attribute that is missing when walking from the root.
    """
    prefix = root_module + "."
    # Normalise to a root-relative path so the attribute walks below never look
    # for an attribute called like the root module itself.
    relative = dotted[len(prefix):] if dotted.startswith(prefix) else dotted
    full = prefix + relative
    attr_path = relative.split(".")

    # 1. The whole path may name an importable submodule.
    try:
        return importlib.import_module(full)
    except ImportError:
        pass

    # 2. Fast path: attribute(s) on the root module (diffusers uses lazy
    #    __getattr__ that raises AttributeError, not ImportError, for
    #    non-module attrs).
    root = importlib.import_module(root_module)
    try:
        obj = root
        for attr in attr_path:
            obj = getattr(obj, attr)
        return obj
    except AttributeError:
        pass

    # 3. Longest importable prefix + attribute walk for the remainder. The full
    #    path already failed to import in step 1, so start one segment shorter.
    segments = full.split(".")
    for i in range(len(segments) - 1, 0, -1):
        try:
            mod = importlib.import_module(".".join(segments[:i]))
        except Exception:
            continue
        obj = mod
        try:
            for attr in segments[i:]:
                obj = getattr(obj, attr)
            return obj
        except AttributeError:
            # The deepest importable module lacks the attribute; shorter
            # prefixes cannot do better.
            break

    # 4. Nothing resolved: repeat the root walk so the caller receives the
    #    natural AttributeError for the first missing attribute.
    obj = root
    for attr in attr_path:
        obj = getattr(obj, attr)
    return obj
264
+
265
+
266
def pipeline_info(name: str) -> Optional[Dict[str, Any]]:
    """Modality + __call__ signature for one pipeline class (lazily resolved).

    Args:
        name: symbol name to look up.

    Returns:
        * ``None`` when the name is unknown and does not look like a
          world-foundation-model pipeline.
        * A stub dict with ``"available": False`` and an install note when the
          name is unknown but ends in "Pipeline" and contains one of
          Cosmos / World / Wan / Hunyuan / Genie / Omni.
        * Otherwise a dict with "name", "kind" and "modality", plus whichever of
          "call_params", "from_pretrained_doc" (first 400 chars) and "doc"
          (first 600 chars) could be gathered. A "note" entry is added when the
          class cannot be resolved or part of the introspection fails.

    Never raises for resolution or introspection problems.
    """
    symbols = all_symbols()
    if name not in symbols:
        # Known from-source WFM/pipeline classes (e.g. Cosmos3OmniPipeline ships in
        # diffusers>0.38 from source). Degrade gracefully instead of erroring like a
        # typo would — the tool resolves these dynamically once the install has them.
        if re.search(r"Pipeline$", name) and re.search(
                r"Cosmos|World|Wan|Hunyuan|Genie|Omni", name):
            return {
                "name": name,
                "kind": "pipeline",
                "modality": modality_of(name),
                "available": False,
                "note": (f"'{name}' is not in this diffusers build "
                         "(likely a from-source >0.38 class). Install with: "
                         "pip install 'git+https://github.com/huggingface/diffusers' "
                         "— use_diffusers resolves it dynamically once present."),
            }
        return None

    info: Dict[str, Any] = {
        "name": name,
        "kind": symbols[name],
        "modality": modality_of(name),
    }

    # Resolution is handled separately from introspection so a signature
    # failure is no longer reported as "class not resolvable".
    try:
        cls = resolve_attr(name)
    except Exception as e:  # resolution may fail on from-source-only classes
        info["note"] = f"class not resolvable in this diffusers build: {e}"
        return info

    try:
        call = getattr(cls, "__call__", None)
        if call is not None:
            info["call_params"] = _sig_params(call)
    except Exception as e:  # e.g. objects without an inspectable signature
        info["note"] = f"__call__ signature not introspectable: {e}"

    # Docs are collected independently so they survive a signature failure.
    try:
        fp = getattr(cls, "from_pretrained", None)
        if fp is not None and getattr(fp, "__doc__", None):
            info["from_pretrained_doc"] = fp.__doc__[:400]
        if cls.__doc__:
            info["doc"] = cls.__doc__[:600]
    except Exception as e:
        info.setdefault("note", f"docstrings not introspectable: {e}")
    return info
303
+
304
+
305
def describe(obj: Any, max_doc: int = 600) -> Dict[str, Any]:
    """Summarise an arbitrary object as a plain dict.

    Always contains "kind" (the type name) and "name" (``__name__`` or the
    first 80 characters of ``str(obj)``). Depending on what ``obj`` is:

    * class    → "methods" (up to 40 public callables) and, for each of
                 from_pretrained / __call__ / __init__ that exists, its
                 "<name>_params" and "<name>_doc" (truncated to ``max_doc``).
    * callable → "params" when the signature is inspectable, and "doc".
    * module   → "public" (up to 50 names not starting with an underscore).
    * other    → "value" (first 200 characters of ``str(obj)``).
    """
    summary: Dict[str, Any] = {
        "kind": type(obj).__name__,
        "name": getattr(obj, "__name__", str(obj)[:80]),
    }

    if inspect.isclass(obj):
        public_callables = []
        for attr_name in dir(obj):
            if attr_name.startswith("_"):
                continue
            if callable(getattr(obj, attr_name, None)):
                public_callables.append(attr_name)
        summary["methods"] = public_callables[:40]

        for ctor in ("from_pretrained", "__call__", "__init__"):
            fn = getattr(obj, ctor, None)
            if fn is None:
                continue
            try:
                summary[f"{ctor}_params"] = _sig_params(fn)
                if fn.__doc__:
                    summary[f"{ctor}_doc"] = fn.__doc__[:max_doc]
            except (ValueError, TypeError):
                # Signature not inspectable: skip this entry point.
                continue
        return summary

    if callable(obj):
        try:
            summary["params"] = _sig_params(obj)
        except (ValueError, TypeError):
            pass
        if obj.__doc__:
            summary["doc"] = obj.__doc__[:max_doc]
        return summary

    if inspect.ismodule(obj):
        public_names = [n for n in dir(obj) if not n.startswith("_")]
        summary["public"] = public_names[:50]
        return summary

    summary["value"] = str(obj)[:200]
    return summary
336
+
337
+
338
def _sig_params(fn: Any) -> Dict[str, Dict[str, Any]]:
    """Describe a callable's parameters as ``{name: {"default", "annotation"}}``.

    Parameters named "self", "args" or "kwargs" are left out. A missing default
    is reported as the string "REQUIRED"; a missing annotation as ``None``.
    Present defaults and annotations are stringified with ``str``. Whatever
    ``inspect.signature`` raises for the callable propagates to the caller.
    """
    empty = inspect.Parameter.empty
    skipped = ("self", "args", "kwargs")
    described: Dict[str, Dict[str, Any]] = {}
    for param_name, param in inspect.signature(fn).parameters.items():
        if param_name in skipped:
            continue
        default = "REQUIRED" if param.default is empty else str(param.default)
        annotation = None if param.annotation is empty else str(param.annotation)
        described[param_name] = {"default": default, "annotation": annotation}
    return described
@@ -0,0 +1,256 @@
1
+ """Visualize robot ACTION chunks — turn raw action tensors into something you can SEE.
2
+
3
+ A world-foundation model (Cosmos) emits an action chunk of shape
4
+ `[num_chunks, T, action_dim]` in normalized action space. Numbers alone are
5
+ opaque, so this renders them three ways:
6
+
7
+ 1. time-series — every action dimension plotted over the chunk's timesteps
8
+ (joint / delta curves), with the gripper channel highlighted.
9
+ 2. trajectory — if the first 3 dims look like an end-effector position/delta,
10
+ the cumulative 3D path of the gripper through space.
11
+ 3. animation — an mp4/gif that sweeps a playhead across the time-series so you
12
+ can watch the action unfold, optionally side-by-side with the
13
+ generated world video frames.
14
+
15
+ All outputs are written to ARTIFACT_DIR and returned as paths.
16
+
17
+ Design notes:
18
+ - We DON'T hardcode an embodiment. action_dim varies (7-DoF arm, 10-DoF, etc.).
19
+ We label dims generically and treat the LAST dim as the gripper by convention
20
+ (override with gripper_index=None to disable). The first 3 dims are *optionally*
21
+ interpreted as an end-effector position for the 3D path (interpret_xyz).
22
+ """
23
+
24
+ from __future__ import annotations
25
+
26
+ import time
27
+ from pathlib import Path
28
+ from typing import Any, List, Optional
29
+
30
+ from strands_diffusers.core.io import ARTIFACT_DIR
31
+
32
+
33
def _as_chunks(action: Any):
    """Normalize an action payload to a numpy array [num_chunks, T, action_dim].

    Accepted inputs: a JSON string, a dict (its "data" entry is used when
    present), a nested list, an ndarray, or a torch-like tensor. Rank-1 input is
    promoted to one chunk with one step, rank-2 input to one chunk.

    Returns:
        A float ndarray of rank 3.

    Raises:
        ValueError: if a string payload is not valid JSON, if the payload is
            empty, or if its rank is 0 or greater than 3.
    """
    import numpy as np

    # Accept: serialized dict {"data": [...]}, raw nested list, tensor, ndarray,
    # OR a JSON string (LLM tool-calls serialize list inputs to a string).
    if isinstance(action, str):
        import json
        try:
            action = json.loads(action)
        except (ValueError, TypeError) as exc:
            raise ValueError(
                "action string is not valid JSON; pass a nested list "
                "[num_chunks, T, action_dim], a .json path, or cached:key") from exc
    if isinstance(action, dict):
        action = action.get("data", action)
    # Torch-like tensors: np.asarray cannot convert tensors that live on an
    # accelerator or require grad, so bring them to host memory first.
    if hasattr(action, "detach") and hasattr(action, "cpu"):
        action = action.detach().cpu().float().numpy()
    a = np.asarray(action, dtype=float)
    if a.ndim == 1:      # [action_dim] → one chunk, one step
        a = a[None, None, :]
    elif a.ndim == 2:    # [T, action_dim] → one chunk
        a = a[None, :, :]
    if a.ndim != 3:
        # Fail here with a clear message instead of an opaque unpacking error
        # in the caller.
        raise ValueError(
            f"action must have rank 1-3 ([num_chunks, T, action_dim]); "
            f"got shape {a.shape}")
    if a.size == 0:
        raise ValueError(f"action is empty (shape {a.shape})")
    return a
55
+
56
+
57
def visualize_action(
    action: Any,
    save_prefix: str = "action",
    interpret_xyz: bool = True,
    gripper_index: Optional[int] = -1,
    cumulative_xyz: bool = True,
    world_video: Optional[str] = None,
    fps: int = 5,
    dim_labels: Optional[List[str]] = None,
) -> dict:
    """Render an action chunk to plots + an animation. Returns artifact paths.

    Args:
        action: action payload (serialized dict, nested list, ndarray, or tensor),
            shape [num_chunks, T, action_dim] (lower ranks are promoted).
        save_prefix: filename prefix for artifacts.
        interpret_xyz: if action_dim >= 3, draw a 3D path from dims 0-2.
        gripper_index: dim treated as gripper (highlighted); None to disable.
            The index is taken modulo action_dim, so negative values count
            from the end.
        cumulative_xyz: treat xyz dims as deltas and integrate into a path; if
            False, plot them as absolute positions.
        world_video: optional path to the generated world .mp4 to show beside the
            action animation (frame-synced as best as frame counts allow).
        fps: animation frames-per-second.
        dim_labels: optional human labels per dimension. A list shorter than
            action_dim is padded with generic "dim{i}" names; extra entries
            are ignored.

    Returns:
        {"artifacts": [...], "summary": {...}} — artifact paths as strings
        (time-series PNG, optional trajectory PNG, optional animation) and a
        summary with num_chunks, timesteps_per_chunk, action_dim,
        gripper_index, value_range and has_3d_path.
    """
    import matplotlib
    matplotlib.use("Agg")  # headless rendering: figures are only written to files
    import matplotlib.pyplot as plt
    import numpy as np

    chunks = _as_chunks(action)
    num_chunks, T, dim = chunks.shape
    # Flatten chunks end-to-end into a single timeline for plotting.
    flat = chunks.reshape(num_chunks * T, dim)
    steps = np.arange(flat.shape[0])

    # Pad (or trim) caller labels to exactly `dim` entries; a too-short list
    # used to raise IndexError in the plotting loop below.
    labels = list(dim_labels or [])[:dim]
    labels.extend(f"dim{i}" for i in range(len(labels), dim))
    g_idx = (gripper_index % dim) if (gripper_index is not None) else None

    artifacts: List[str] = []
    from strands_diffusers.core.io import _stamp
    ts = _stamp()  # shared suffix so all artifacts of one call sort together

    # ── 1. time-series ──────────────────────────────────────────────
    fig, ax = plt.subplots(figsize=(10, 5))
    for i in range(dim):
        is_grip = (i == g_idx)
        ax.plot(steps, flat[:, i],
                label=("gripper" if is_grip else labels[i]),
                lw=2.4 if is_grip else 1.3,
                color="black" if is_grip else None,
                ls="--" if is_grip else "-",
                alpha=1.0 if is_grip else 0.85)
    # Dotted separators at the boundaries between consecutive chunks.
    for c in range(1, num_chunks):
        ax.axvline(c * T - 0.5, color="gray", ls=":", alpha=0.4)
    ax.set_xlabel("timestep")
    ax.set_ylabel("normalized action value")
    ax.set_title(f"Action chunk — {num_chunks}×{T} steps × {dim} dims")
    ax.legend(loc="upper right", fontsize=8, ncol=2)
    ax.grid(alpha=0.25)
    p_ts = ARTIFACT_DIR / f"{save_prefix}_timeseries_{ts}.png"
    fig.tight_layout()
    fig.savefig(p_ts, dpi=110)
    plt.close(fig)
    artifacts.append(str(p_ts))

    # ── 2. 3D end-effector path ─────────────────────────────────────
    path_xyz = None
    if interpret_xyz and dim >= 3:
        xyz = flat[:, :3].copy()
        if cumulative_xyz:
            xyz = np.cumsum(xyz, axis=0)  # treat as deltas → integrated trajectory
        path_xyz = xyz
        fig = plt.figure(figsize=(6, 5.5))
        ax = fig.add_subplot(111, projection="3d")
        ax.plot(xyz[:, 0], xyz[:, 1], xyz[:, 2], "-o", ms=3, lw=1.5)
        ax.scatter(*xyz[0], c="green", s=60, label="start")
        ax.scatter(*xyz[-1], c="red", s=60, label="end")
        # mark gripper close/open events along the path if we have a gripper channel
        if g_idx is not None:
            g = flat[:, g_idx]
            thr = (g.max() + g.min()) / 2.0  # midpoint of the observed range
            closed = g > thr
            if closed.any():
                ax.scatter(xyz[closed, 0], xyz[closed, 1], xyz[closed, 2],
                           c="orange", s=18, alpha=0.7, label="gripper engaged")
        ax.set_xlabel("x"); ax.set_ylabel("y"); ax.set_zlabel("z")
        ax.set_title("End-effector path"
                     + (" (∫ deltas)" if cumulative_xyz else " (absolute)"))
        ax.legend(fontsize=8)
        p_traj = ARTIFACT_DIR / f"{save_prefix}_trajectory_{ts}.png"
        fig.tight_layout()
        fig.savefig(p_traj, dpi=110)
        plt.close(fig)
        artifacts.append(str(p_traj))

    # ── 3. animation (playhead sweep, optional world video beside it) ──
    anim_path = _animate(flat, labels, g_idx, path_xyz, world_video, fps,
                         save_prefix, ts)
    if anim_path:
        artifacts.append(anim_path)

    summary = {
        "num_chunks": int(num_chunks),
        "timesteps_per_chunk": int(T),
        "action_dim": int(dim),
        "gripper_index": g_idx,
        "value_range": [float(flat.min()), float(flat.max())],
        "has_3d_path": path_xyz is not None,
    }
    return {"artifacts": artifacts, "summary": summary}
171
+
172
+
173
def _animate(flat, labels, g_idx, path_xyz, world_video, fps, prefix, ts):
    """Build an mp4/gif sweeping a playhead across the action curves.

    One figure is rendered per row of ``flat``; each figure has up to three
    panels side by side: the action curves with a red playhead, the matching
    world-video frame (when ``world_video`` could be read), and the 3D path
    with the current point highlighted (when ``path_xyz`` is given).

    Args:
        flat: 2-D array indexed as [timestep, dim].
        labels: per-dimension labels (accepted but not used in this function).
        g_idx: index of the gripper dimension to emphasise, or None.
        path_xyz: array indexed as [timestep, 0..2] for the 3D panel, or None.
        world_video: path of a video to show alongside, or a falsy value.
        fps: frames per second of the output.
        prefix: filename prefix of the artifact.
        ts: stamp appended to the filename.

    Returns:
        Path of the written mp4 as a string; if writing the mp4 fails, the
        path of a GIF written via PIL instead; None if there are no frames
        or both writers fail.
    """
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    import numpy as np

    n = flat.shape[0]    # number of timesteps == number of output frames
    dim = flat.shape[1]

    # Optionally load world video frames to show alongside.
    world_frames = None
    if world_video:
        try:
            import imageio.v3 as iio
            world_frames = iio.imread(world_video)  # [F,H,W,C]
        except Exception:
            # Best effort: an unreadable video just drops the world panel.
            world_frames = None

    have_world = world_frames is not None and len(world_frames) > 0
    have_3d = path_xyz is not None

    # The action panel is always present; the other two are optional.
    ncols = 1 + int(have_world) + int(have_3d)
    frames_out = []

    for k in range(n):
        fig = plt.figure(figsize=(5 * ncols, 4.2))
        col = 1  # 1-based subplot slot, advanced after each panel

        # panel A: action curves with playhead
        axc = fig.add_subplot(1, ncols, col); col += 1
        for i in range(dim):
            is_grip = (i == g_idx)
            axc.plot(flat[:, i], lw=2.2 if is_grip else 1.0,
                     color="black" if is_grip else None,
                     ls="--" if is_grip else "-", alpha=0.9 if is_grip else 0.7)
        axc.axvline(k, color="red", lw=2)
        axc.set_title("action")
        axc.set_xlabel("timestep"); axc.grid(alpha=0.2)

        # panel B: world frame (synced by proportion)
        if have_world:
            axw = fig.add_subplot(1, ncols, col); col += 1
            # Map step k in [0, n-1] onto a frame index in [0, F-1];
            # max(n - 1, 1) avoids dividing by zero for a single step.
            fi = min(int(k / max(n - 1, 1) * (len(world_frames) - 1)),
                     len(world_frames) - 1)
            axw.imshow(world_frames[fi])
            axw.set_title(f"world frame {fi}")
            axw.axis("off")

        # panel C: 3D path with current point
        if have_3d:
            ax3 = fig.add_subplot(1, ncols, col, projection="3d"); col += 1
            ax3.plot(path_xyz[:, 0], path_xyz[:, 1], path_xyz[:, 2],
                     "-", lw=1, alpha=0.5)
            ax3.scatter(*path_xyz[k], c="red", s=50)
            ax3.set_title("end-effector"); ax3.set_xticks([]); ax3.set_yticks([])
            ax3.set_zticks([])

        # Rasterise the figure and keep an RGB copy (alpha channel dropped).
        fig.tight_layout()
        fig.canvas.draw()
        buf = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)
        w, h = fig.canvas.get_width_height()
        frames_out.append(buf.reshape(h, w, 4)[..., :3].copy())
        plt.close(fig)

    if not frames_out:
        return None

    # Preferred output: mp4 written through imageio.
    out_mp4 = ARTIFACT_DIR / f"{prefix}_animation_{ts}.mp4"
    try:
        import imageio.v3 as iio
        iio.imwrite(str(out_mp4), np.stack(frames_out), fps=fps, codec="libx264")
        return str(out_mp4)
    except Exception:
        pass
    # Fallback: animated GIF through PIL; per-frame duration is in ms.
    try:
        from PIL import Image
        out_gif = ARTIFACT_DIR / f"{prefix}_animation_{ts}.gif"
        imgs = [Image.fromarray(f) for f in frames_out]
        imgs[0].save(str(out_gif), save_all=True, append_images=imgs[1:],
                     duration=int(1000 / max(fps, 1)), loop=0)
        return str(out_gif)
    except Exception:
        return None
@@ -0,0 +1,4 @@
1
+ """Tools for strands-diffusers."""
2
+ from strands_diffusers.tools.use_diffusers import use_diffusers
3
+
4
+ __all__ = ["use_diffusers"]