strands-diffusers 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- strands_diffusers/__init__.py +41 -0
- strands_diffusers/_version.py +24 -0
- strands_diffusers/core/__init__.py +4 -0
- strands_diffusers/core/engine.py +163 -0
- strands_diffusers/core/io.py +552 -0
- strands_diffusers/core/registry.py +349 -0
- strands_diffusers/core/viz.py +256 -0
- strands_diffusers/tools/__init__.py +4 -0
- strands_diffusers/tools/use_diffusers.py +420 -0
- strands_diffusers-0.1.0.dist-info/METADATA +199 -0
- strands_diffusers-0.1.0.dist-info/RECORD +13 -0
- strands_diffusers-0.1.0.dist-info/WHEEL +5 -0
- strands_diffusers-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,349 @@
|
|
|
1
|
+
"""Dynamic pipeline/model/scheduler registry — 100% diffusers coverage, zero hardcoding.
|
|
2
|
+
|
|
3
|
+
The single source of truth is diffusers' own `_import_structure` — the lazy map of
|
|
4
|
+
every public symbol it exposes (307 pipelines, 87 models, 54 schedulers as of
|
|
5
|
+
0.38). We read it at runtime, so when diffusers adds a new pipeline (e.g. a fresh
|
|
6
|
+
Cosmos world-foundation model), strands-diffusers supports it automatically — no
|
|
7
|
+
code change required.
|
|
8
|
+
|
|
9
|
+
Same philosophy as `use_aws` (wraps boto3 dynamically), `use_lerobot` (wraps
|
|
10
|
+
lerobot) and `use_transformers` (wraps the transformers task taxonomy): discover,
|
|
11
|
+
don't hardcode.
|
|
12
|
+
|
|
13
|
+
Diffusers has no single "task taxonomy" like transformers' SUPPORTED_TASKS, so we
|
|
14
|
+
derive structure from three places, all dynamic:
|
|
15
|
+
|
|
16
|
+
1. `diffusers._import_structure` → every public class, grouped by submodule.
|
|
17
|
+
2. The `AutoPipelineFor*` mappings → the canonical task → pipeline-class maps
|
|
18
|
+
(text2image / image2image / inpainting), diffusers' closest thing to tasks.
|
|
19
|
+
3. Class-name heuristics → group the long tail of pipelines by the
|
|
20
|
+
modality their name implies (TextToImage / ImageToVideo / VideoToWorld / ...).
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
from __future__ import annotations
|
|
24
|
+
|
|
25
|
+
import importlib
|
|
26
|
+
import inspect
|
|
27
|
+
import re
|
|
28
|
+
from functools import lru_cache
|
|
29
|
+
from typing import Any, Dict, List, Optional
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@lru_cache(maxsize=1)
def _import_structure() -> Dict[str, List[str]]:
    """Snapshot diffusers' lazy-import map as ``{group: [symbol, ...]}``.

    Reads ``diffusers._import_structure`` and copies every value into a new
    list; the keys are kept exactly as diffusers provides them. Returns an
    empty dict when the attribute is missing or empty. The result is
    memoized, so every caller receives the same dict object.
    """
    import diffusers

    structure = getattr(diffusers, "_import_structure", None)
    if not structure:
        return {}
    # Copy each symbol collection so the snapshot owns its lists.
    return {group: list(symbols) for group, symbols in structure.items()}
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@lru_cache(maxsize=1)
def all_symbols() -> Dict[str, str]:
    """Map every symbol in the import structure to its kind.

    The kind is whatever ``_classify`` returns for the symbol name. A symbol
    listed under several groups appears once. The result is memoized.
    """
    return {
        symbol: _classify(symbol)
        for group in _import_structure().values()
        for symbol in group
    }
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def _classify(name: str) -> str:
|
|
52
|
+
if name.endswith("Pipeline"):
|
|
53
|
+
return "pipeline"
|
|
54
|
+
if name.endswith("Scheduler") or name.endswith("SchedulerOutput"):
|
|
55
|
+
return "scheduler"
|
|
56
|
+
if name.endswith("Output"):
|
|
57
|
+
return "output"
|
|
58
|
+
if "Model" in name or name.startswith("Autoencoder") or name.endswith("Transformer"):
|
|
59
|
+
return "model"
|
|
60
|
+
return "other"
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
@lru_cache(maxsize=1)
def pipelines() -> List[str]:
    """Return the sorted names of all symbols whose kind is ``"pipeline"``."""
    names = [
        symbol for symbol, kind in all_symbols().items() if kind == "pipeline"
    ]
    names.sort()
    return names
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
@lru_cache(maxsize=1)
def models() -> List[str]:
    """Return the sorted names of all symbols whose kind is ``"model"``."""
    names = [
        symbol for symbol, kind in all_symbols().items() if kind == "model"
    ]
    names.sort()
    return names
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
@lru_cache(maxsize=1)
def schedulers() -> List[str]:
    """Return the sorted names of scheduler symbols, excluding ``*Output`` names.

    ``_classify`` labels ``*SchedulerOutput`` names as ``"scheduler"`` too, so
    they are filtered out here by their ``Output`` suffix.
    """
    names = []
    for symbol, kind in all_symbols().items():
        if kind != "scheduler" or symbol.endswith("Output"):
            continue
        names.append(symbol)
    names.sort()
    return names
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
# ───────────────────────── AutoPipeline task maps ─────────────────────────
|
|
80
|
+
|
|
81
|
+
@lru_cache(maxsize=1)
def auto_pipeline_tasks() -> Dict[str, Dict[str, str]]:
    """Expose the three AutoPipeline mappings as plain strings.

    Returns ``{task: {model_family: pipeline_class_name}}`` for the tasks
    ``"text-to-image"``, ``"image-to-image"`` and ``"inpainting"``. Each class
    in the diffusers mappings is replaced by its ``__name__``. The result is
    memoized.
    """
    from diffusers.pipelines.auto_pipeline import (
        AUTO_TEXT2IMAGE_PIPELINES_MAPPING,
        AUTO_IMAGE2IMAGE_PIPELINES_MAPPING,
        AUTO_INPAINT_PIPELINES_MAPPING,
    )

    task_mappings = (
        ("text-to-image", AUTO_TEXT2IMAGE_PIPELINES_MAPPING),
        ("image-to-image", AUTO_IMAGE2IMAGE_PIPELINES_MAPPING),
        ("inpainting", AUTO_INPAINT_PIPELINES_MAPPING),
    )
    return {
        task: {
            family: pipeline_cls.__name__
            for family, pipeline_cls in mapping.items()
        }
        for task, mapping in task_mappings
    }
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
# ───────────────────────── modality grouping (derived) ─────────────────────────
|
|
105
|
+
|
|
106
|
+
# Ordered (pattern → modality) rules applied to a pipeline class name. First match
|
|
107
|
+
# wins. Purely name-derived so new pipelines slot in without code changes.
|
|
108
|
+
_MODALITY_RULES = (
|
|
109
|
+
# Explicit transition names first (most specific).
|
|
110
|
+
(r"TextToImage|Text2Image", "text-to-image"),
|
|
111
|
+
(r"TextToVideo|Text2Video|TextToWorld", "text-to-video"),
|
|
112
|
+
(r"ImageToVideo|Image2Video", "image-to-video"),
|
|
113
|
+
(r"VideoToVideo|Video2Video", "video-to-video"),
|
|
114
|
+
(r"VideoToWorld|World", "video-to-world"),
|
|
115
|
+
(r"ImageToImage|Image2Image", "image-to-image"),
|
|
116
|
+
(r"Inpaint", "inpainting"),
|
|
117
|
+
(r"Upscale|SuperResolution", "upscaling"),
|
|
118
|
+
(r"TextToAudio|Text2Audio|Audio|Music|Speech|TTS|Bark|MusicGen|Stable[Aa]udio", "audio"),
|
|
119
|
+
# Image families that share a name-stem with a video family (e.g. HunyuanDiT,
|
|
120
|
+
# HunyuanImage) — classify as image BEFORE the broad video catch-all below.
|
|
121
|
+
(r"DiT|HunyuanImage|HunyuanDiT", "image"),
|
|
122
|
+
# Broad video/world model families (named after the architecture, not a task).
|
|
123
|
+
(r"Video|Animate|Cosmos|Wan|Hunyuan|Mochi|Allegro|LTX|SkyReels|CogVideo|Latte|Genie", "video"),
|
|
124
|
+
(r"ControlNet|Adapter|IP", "controlled-image"),
|
|
125
|
+
# Abbreviated transition tokens — MUST precede architecture-family rules so a
|
|
126
|
+
# family-named video pipeline (Kandinsky5I2V, *T2V) isn't grabbed as image.
|
|
127
|
+
(r"I2V", "image-to-video"),
|
|
128
|
+
(r"T2V", "text-to-video"),
|
|
129
|
+
(r"V2V", "video-to-video"),
|
|
130
|
+
# Transition variants on family-named pipelines (task encoded as a suffix).
|
|
131
|
+
(r"Img2Img|Image2Image", "image-to-image"),
|
|
132
|
+
(r"Fill", "inpainting"),
|
|
133
|
+
# NOTE: no bare "Edit" rule — editing can be image OR video (e.g. ChronoEdit
|
|
134
|
+
# is image-to-video). Let such names fall through to their family/Video rule.
|
|
135
|
+
# Architecture-named image-gen families (task implicit = text-to-image-class).
|
|
136
|
+
(r"StableDiffusion|Flux|CogView|Bria|Chroma|AuraFlow|Amused|Kandinsky|"
|
|
137
|
+
r"PixArt|Sana|Lumina|DeepFloyd|Wuerstchen|Kolors|HiDream|Janus|OmniGen|"
|
|
138
|
+
r"Marigold|VisualCloze", "image"),
|
|
139
|
+
# Architecture-named audio families.
|
|
140
|
+
(r"AudioLDM|StableAudio|MusicGen|Musicgen|AceStep|Dance|Spectrogram", "audio"),
|
|
141
|
+
# 3D / mesh generation (ShapE, etc.) — emits meshes, not images.
|
|
142
|
+
(r"ShapE|Shap[_-]?E|Mesh|TripoSR|Hunyuan3D|TRELLIS", "3d"),
|
|
143
|
+
(r"Image", "image"),
|
|
144
|
+
)
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
# Docstring phrasings → modality. Used ONLY as a fallback when name-rules yield
|
|
148
|
+
# "other". Rules stay authoritative (they encode WFM/inpaint precedence the docs
|
|
149
|
+
# lack); the doc parser just rescues the long tail. Ordered: most-specific first.
|
|
150
|
+
_DOC_MODALITY = (
|
|
151
|
+
("text-to-3d", "3d"), ("image-to-3d", "3d"), ("text-to-mesh", "3d"),
|
|
152
|
+
("image-to-image", "image-to-image"),
|
|
153
|
+
("image-to-video", "image-to-video"),
|
|
154
|
+
("video-to-video", "video-to-video"),
|
|
155
|
+
("text-to-video", "text-to-video"),
|
|
156
|
+
("text-to-image", "text-to-image"),
|
|
157
|
+
("text-to-audio", "audio"), ("text-to-speech", "audio"),
|
|
158
|
+
("super-resolution", "upscaling"),
|
|
159
|
+
("inpainting", "inpainting"),
|
|
160
|
+
("unconditional image", "image"), ("class-conditional image", "image"),
|
|
161
|
+
("image generation", "image"), ("image synthesis", "image"),
|
|
162
|
+
("video generation", "video"), ("audio generation", "audio"),
|
|
163
|
+
)
|
|
164
|
+
|
|
165
|
+
_DOC_PAT = re.compile(
|
|
166
|
+
r"[Pp]ipeline for ([\w\- ]+?) (?:generation|synthesis|using|based|with)", re.S)
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
@lru_cache(maxsize=512)
|
|
170
|
+
def _modality_from_doc(name: str) -> str:
|
|
171
|
+
"""Best-effort modality from a pipeline class docstring. Cached; never raises."""
|
|
172
|
+
try:
|
|
173
|
+
cls = resolve_attr(name)
|
|
174
|
+
doc = (cls.__doc__ or "").strip().lower()
|
|
175
|
+
except Exception:
|
|
176
|
+
return "other"
|
|
177
|
+
if not doc:
|
|
178
|
+
return "other"
|
|
179
|
+
m = _DOC_PAT.search(doc)
|
|
180
|
+
phrase = m.group(1).strip() if m else doc[:80]
|
|
181
|
+
for key, mod in _DOC_MODALITY:
|
|
182
|
+
if key in phrase:
|
|
183
|
+
return mod
|
|
184
|
+
return "other"
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
def modality_of(pipeline_name: str, use_doc: bool = False) -> str:
    """Return the modality for a pipeline class name.

    The name is tested against ``_MODALITY_RULES`` in order and the first
    matching rule decides. When no rule matches, the result is ``"other"``
    unless ``use_doc`` is true, in which case ``_modality_from_doc`` is
    consulted (this resolves the class, so it is slower).
    """
    by_name = next(
        (
            modality
            for pattern, modality in _MODALITY_RULES
            if re.search(pattern, pipeline_name)
        ),
        None,
    )
    if by_name is not None:
        return by_name
    if not use_doc:
        return "other"
    return _modality_from_doc(pipeline_name)
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
def tasks_by_modality(use_doc: bool = True) -> Dict[str, List[str]]:
    """Group every known pipeline name under its modality.

    Returns ``{modality: sorted pipeline names}``. ``use_doc`` is forwarded to
    ``modality_of`` and defaults to True here.
    """
    grouped: Dict[str, List[str]] = {}
    for name in pipelines():
        modality = modality_of(name, use_doc=use_doc)
        if modality not in grouped:
            grouped[modality] = []
        grouped[modality].append(name)
    for members in grouped.values():
        members.sort()
    return grouped
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
# World-foundation-model / action-capable pipelines (the ones the agent cares
|
|
209
|
+
# about for robotics). Detected by name — Cosmos*, *World*, action-conditioned.
|
|
210
|
+
def world_foundation_models() -> List[str]:
    """Return sorted pipeline names containing Cosmos, World, Wan, Hunyuan or Genie."""
    matches = [
        name
        for name in pipelines()
        if re.search(r"Cosmos|World|Wan|Hunyuan|Genie", name)
    ]
    matches.sort()
    return matches
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
# ───────────────────────── resolution & introspection ─────────────────────────
|
|
216
|
+
|
|
217
|
+
def resolve_attr(dotted: str, root_module: str = "diffusers") -> Any:
    """Resolve a dotted path against ``root_module`` (or one of its submodules).

    Examples:
        resolve_attr("StableDiffusionPipeline")
        resolve_attr("DiffusionPipeline.from_pretrained")
        resolve_attr("utils.export_to_video")
        resolve_attr("schedulers.UniPCMultistepScheduler.from_config")

    Strategies are tried in this order:
      1. import the full dotted path as a module;
      2. walk the attributes of the root module;
      3. import the longest importable prefix of the full path and walk the
         remaining segments as attributes;
      4. repeat the walk from the root module so that its AttributeError
         propagates to the caller.

    Raises:
        AttributeError: when the path cannot be resolved.
    """

    def _descend(start: Any, names: List[str]) -> Any:
        # Follow a chain of attribute names starting at ``start``.
        target = start
        for part in names:
            target = getattr(target, part)
        return target

    if dotted.startswith(root_module + "."):
        full = dotted
    else:
        full = f"{root_module}.{dotted}"

    # 1. The path may itself name a module.
    try:
        return importlib.import_module(full)
    except ImportError:
        pass

    # 2. Attribute chain on the root module (a missing attribute raises
    #    AttributeError, which sends us on to the next strategy).
    try:
        return _descend(importlib.import_module(root_module), dotted.split("."))
    except AttributeError:
        pass

    # 3. Longest importable module prefix, then attributes for the rest.
    segments = full.split(".")
    for cut in reversed(range(1, len(segments) + 1)):
        try:
            module = importlib.import_module(".".join(segments[:cut]))
        except Exception:
            continue
        try:
            return _descend(module, segments[cut:])
        except AttributeError:
            break

    # 4. Nothing worked: redo the root walk so its error reaches the caller.
    return _descend(importlib.import_module(root_module), dotted.split("."))
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
def pipeline_info(name: str) -> Optional[Dict[str, Any]]:
    """Describe one symbol: kind, modality and, when resolvable, its signature.

    Returns:
        * For a name present in ``all_symbols()``: a dict with ``name``,
          ``kind`` and ``modality``, plus ``call_params``,
          ``from_pretrained_doc`` (first 400 chars) and ``doc`` (first 600
          chars) when available. If resolution fails, a ``note`` holds the
          error instead.
        * For an unknown name that ends in ``Pipeline`` and contains one of
          Cosmos / World / Wan / Hunyuan / Genie / Omni: a placeholder dict
          with ``available: False`` and an install hint.
        * ``None`` for any other unknown name.
    """
    symbols = all_symbols()

    if name in symbols:
        info: Dict[str, Any] = {
            "name": name,
            "kind": symbols[name],
            "modality": modality_of(name),
        }
        try:
            cls = resolve_attr(name)
            call = getattr(cls, "__call__", None)
            if call is not None:
                info["call_params"] = _sig_params(call)
            fp = getattr(cls, "from_pretrained", None)
            if fp is not None and getattr(fp, "__doc__", None):
                info["from_pretrained_doc"] = fp.__doc__[:400]
            if cls.__doc__:
                info["doc"] = cls.__doc__[:600]
        except Exception as e:  # resolution may fail on from-source-only classes
            info["note"] = f"class not resolvable in this diffusers build: {e}"
        return info

    # Unknown name: only pipeline-looking names from the listed families get
    # a graceful "not available" answer; everything else is treated as a typo.
    is_pipeline_name = re.search(r"Pipeline$", name)
    if not (is_pipeline_name and re.search(
            r"Cosmos|World|Wan|Hunyuan|Genie|Omni", name)):
        return None

    return {
        "name": name,
        "kind": "pipeline",
        "modality": modality_of(name),
        "available": False,
        "note": (f"'{name}' is not in this diffusers build "
                 "(likely a from-source >0.38 class). Install with: "
                 "pip install 'git+https://github.com/huggingface/diffusers' "
                 "— use_diffusers resolves it dynamically once present."),
    }
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
def describe(obj: Any, max_doc: int = 600) -> Dict[str, Any]:
    """Summarize an arbitrary object as a plain dict.

    Always sets ``kind`` (the object's type name) and ``name`` (its
    ``__name__``, else the first 80 chars of ``str(obj)``). Then, depending on
    what the object is:
      * class    -> ``methods`` (up to 40 public callables) and, for each of
                    ``from_pretrained`` / ``__call__`` / ``__init__`` that
                    exists, ``<name>_params`` and ``<name>_doc``;
      * callable -> ``params`` and ``doc``;
      * module   -> ``public`` (up to 50 names without a leading underscore);
      * other    -> ``value`` (first 200 chars of ``str(obj)``).

    Docstrings are truncated to ``max_doc`` characters. Signatures that cannot
    be inspected (ValueError / TypeError) are skipped.
    """
    info: Dict[str, Any] = {}
    info["kind"] = type(obj).__name__
    info["name"] = getattr(obj, "__name__", str(obj)[:80])

    if inspect.isclass(obj):
        public_callables = [
            member for member in dir(obj)
            if not member.startswith("_") and callable(getattr(obj, member, None))
        ]
        info["methods"] = public_callables[:40]
        for ctor in ("from_pretrained", "__call__", "__init__"):
            fn = getattr(obj, ctor, None)
            if fn is None:
                continue
            try:
                info[f"{ctor}_params"] = _sig_params(fn)
                if fn.__doc__:
                    info[f"{ctor}_doc"] = fn.__doc__[:max_doc]
            except (ValueError, TypeError):
                continue
        return info

    if callable(obj):
        try:
            info["params"] = _sig_params(obj)
        except (ValueError, TypeError):
            pass
        if obj.__doc__:
            info["doc"] = obj.__doc__[:max_doc]
        return info

    if inspect.ismodule(obj):
        visible = [entry for entry in dir(obj) if not entry.startswith("_")]
        info["public"] = visible[:50]
    else:
        info["value"] = str(obj)[:200]
    return info
|
|
336
|
+
|
|
337
|
+
|
|
338
|
+
def _sig_params(fn: Any) -> Dict[str, Dict[str, Any]]:
|
|
339
|
+
sig = inspect.signature(fn)
|
|
340
|
+
return {
|
|
341
|
+
name: {
|
|
342
|
+
"default": ("REQUIRED" if p.default is inspect.Parameter.empty
|
|
343
|
+
else str(p.default)),
|
|
344
|
+
"annotation": (None if p.annotation is inspect.Parameter.empty
|
|
345
|
+
else str(p.annotation)),
|
|
346
|
+
}
|
|
347
|
+
for name, p in sig.parameters.items()
|
|
348
|
+
if name not in ("self", "args", "kwargs")
|
|
349
|
+
}
|
|
@@ -0,0 +1,256 @@
|
|
|
1
|
+
"""Visualize robot ACTION chunks — turn raw action tensors into something you can SEE.
|
|
2
|
+
|
|
3
|
+
A world-foundation model (Cosmos) emits an action chunk of shape
|
|
4
|
+
`[num_chunks, T, action_dim]` in normalized action space. Numbers alone are
|
|
5
|
+
opaque, so this renders them three ways:
|
|
6
|
+
|
|
7
|
+
1. time-series — every action dimension plotted over the chunk's timesteps
|
|
8
|
+
(joint / delta curves), with the gripper channel highlighted.
|
|
9
|
+
2. trajectory — if the first 3 dims look like an end-effector position/delta,
|
|
10
|
+
the cumulative 3D path of the gripper through space.
|
|
11
|
+
3. animation — an mp4/gif that sweeps a playhead across the time-series so you
|
|
12
|
+
can watch the action unfold, optionally side-by-side with the
|
|
13
|
+
generated world video frames.
|
|
14
|
+
|
|
15
|
+
All outputs are written to ARTIFACT_DIR and returned as paths.
|
|
16
|
+
|
|
17
|
+
Design notes:
|
|
18
|
+
- We DON'T hardcode an embodiment. action_dim varies (7-DoF arm, 10-DoF, etc.).
|
|
19
|
+
We label dims generically and treat the LAST dim as the gripper by convention
|
|
20
|
+
(override with gripper_index=None to disable). The first 3 dims are *optionally*
|
|
21
|
+
interpreted as an end-effector position for the 3D path (interpret_xyz).
|
|
22
|
+
"""
|
|
23
|
+
|
|
24
|
+
from __future__ import annotations
|
|
25
|
+
|
|
26
|
+
import time
|
|
27
|
+
from pathlib import Path
|
|
28
|
+
from typing import Any, List, Optional
|
|
29
|
+
|
|
30
|
+
from strands_diffusers.core.io import ARTIFACT_DIR
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _as_chunks(action: Any):
|
|
34
|
+
"""Normalize an action payload to a numpy array [num_chunks, T, action_dim]."""
|
|
35
|
+
import numpy as np
|
|
36
|
+
|
|
37
|
+
# Accept: serialized dict {"data": [...]}, raw nested list, tensor, ndarray,
|
|
38
|
+
# OR a JSON string (LLM tool-calls serialize list inputs to a string).
|
|
39
|
+
if isinstance(action, str):
|
|
40
|
+
import json
|
|
41
|
+
try:
|
|
42
|
+
action = json.loads(action)
|
|
43
|
+
except (ValueError, TypeError):
|
|
44
|
+
raise ValueError(
|
|
45
|
+
"action string is not valid JSON; pass a nested list "
|
|
46
|
+
"[num_chunks, T, action_dim], a .json path, or cached:key")
|
|
47
|
+
if isinstance(action, dict):
|
|
48
|
+
action = action.get("data", action)
|
|
49
|
+
a = np.asarray(action, dtype=float)
|
|
50
|
+
if a.ndim == 1: # [action_dim] → one chunk, one step
|
|
51
|
+
a = a[None, None, :]
|
|
52
|
+
elif a.ndim == 2: # [T, action_dim] → one chunk
|
|
53
|
+
a = a[None, :, :]
|
|
54
|
+
return a
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def visualize_action(
    action: Any,
    save_prefix: str = "action",
    interpret_xyz: bool = True,
    gripper_index: Optional[int] = -1,
    cumulative_xyz: bool = True,
    world_video: Optional[str] = None,
    fps: int = 5,
    dim_labels: Optional[List[str]] = None,
) -> dict:
    """Render an action chunk to plots + an animation. Returns artifact paths.

    Args:
        action: action payload (serialized dict, nested list, ndarray, or tensor),
            shape [num_chunks, T, action_dim] (lower ranks are promoted).
        save_prefix: filename prefix for artifacts.
        interpret_xyz: if action_dim >= 3, draw a 3D path from dims 0-2.
        gripper_index: dim treated as gripper (highlighted); None to disable.
            Taken modulo action_dim, so negative indices count from the end.
        cumulative_xyz: treat xyz dims as deltas and integrate into a path; if
            False, plot them as absolute positions.
        world_video: optional path to the generated world .mp4 to show beside the
            action animation (frame-synced as best as frame counts allow).
        fps: animation frames-per-second.
        dim_labels: optional human labels per dimension. Missing trailing
            labels are filled with ``dim<i>``; surplus labels are ignored.

    Returns:
        {"artifacts": [...], "summary": {...}}

    Raises:
        ValueError: if the action contains no values (zero chunks, steps or
            dimensions).
    """
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    import numpy as np

    chunks = _as_chunks(action)
    num_chunks, T, dim = chunks.shape
    if chunks.size == 0:
        # Fail with a clear message instead of a ZeroDivisionError from
        # ``gripper_index % dim`` or an IndexError on the first path point.
        raise ValueError(
            f"action is empty (shape {chunks.shape}); nothing to visualize")
    # Flatten chunks end-to-end into a single timeline for plotting.
    flat = chunks.reshape(num_chunks * T, dim)
    steps = np.arange(flat.shape[0])

    # Pad caller-supplied labels so a short list cannot raise IndexError below.
    labels = list(dim_labels) if dim_labels else []
    labels.extend(f"dim{i}" for i in range(len(labels), dim))
    g_idx = (gripper_index % dim) if (gripper_index is not None) else None

    artifacts: List[str] = []
    from strands_diffusers.core.io import _stamp
    ts = _stamp()

    # ── 1. time-series ──────────────────────────────────────────────
    fig, ax = plt.subplots(figsize=(10, 5))
    for i in range(dim):
        is_grip = (i == g_idx)
        ax.plot(steps, flat[:, i],
                label=("gripper" if is_grip else labels[i]),
                lw=2.4 if is_grip else 1.3,
                color="black" if is_grip else None,
                ls="--" if is_grip else "-",
                alpha=1.0 if is_grip else 0.85)
    # Dotted separators on the boundaries between consecutive chunks.
    for c in range(1, num_chunks):
        ax.axvline(c * T - 0.5, color="gray", ls=":", alpha=0.4)
    ax.set_xlabel("timestep")
    ax.set_ylabel("normalized action value")
    ax.set_title(f"Action chunk — {num_chunks}×{T} steps × {dim} dims")
    ax.legend(loc="upper right", fontsize=8, ncol=2)
    ax.grid(alpha=0.25)
    p_ts = ARTIFACT_DIR / f"{save_prefix}_timeseries_{ts}.png"
    fig.tight_layout()
    fig.savefig(p_ts, dpi=110)
    plt.close(fig)
    artifacts.append(str(p_ts))

    # ── 2. 3D end-effector path ─────────────────────────────────────
    path_xyz = None
    if interpret_xyz and dim >= 3:
        xyz = flat[:, :3].copy()
        if cumulative_xyz:
            xyz = np.cumsum(xyz, axis=0)  # treat as deltas → integrated trajectory
        path_xyz = xyz
        fig = plt.figure(figsize=(6, 5.5))
        ax = fig.add_subplot(111, projection="3d")
        ax.plot(xyz[:, 0], xyz[:, 1], xyz[:, 2], "-o", ms=3, lw=1.5)
        ax.scatter(*xyz[0], c="green", s=60, label="start")
        ax.scatter(*xyz[-1], c="red", s=60, label="end")
        # mark gripper close/open events along the path if we have a gripper channel
        if g_idx is not None:
            g = flat[:, g_idx]
            # Threshold at the midpoint of the observed gripper range.
            thr = (g.max() + g.min()) / 2.0
            closed = g > thr
            if closed.any():
                ax.scatter(xyz[closed, 0], xyz[closed, 1], xyz[closed, 2],
                           c="orange", s=18, alpha=0.7, label="gripper engaged")
        ax.set_xlabel("x")
        ax.set_ylabel("y")
        ax.set_zlabel("z")
        ax.set_title("End-effector path"
                     + (" (∫ deltas)" if cumulative_xyz else " (absolute)"))
        ax.legend(fontsize=8)
        p_traj = ARTIFACT_DIR / f"{save_prefix}_trajectory_{ts}.png"
        fig.tight_layout()
        fig.savefig(p_traj, dpi=110)
        plt.close(fig)
        artifacts.append(str(p_traj))

    # ── 3. animation (playhead sweep, optional world video beside it) ──
    anim_path = _animate(flat, labels, g_idx, path_xyz, world_video, fps,
                         save_prefix, ts)
    if anim_path:
        artifacts.append(anim_path)

    summary = {
        "num_chunks": int(num_chunks),
        "timesteps_per_chunk": int(T),
        "action_dim": int(dim),
        "gripper_index": g_idx,
        "value_range": [float(flat.min()), float(flat.max())],
        "has_3d_path": path_xyz is not None,
    }
    return {"artifacts": artifacts, "summary": summary}
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
def _animate(flat, labels, g_idx, path_xyz, world_video, fps, prefix, ts):
    """Render one frame per timestep and save the sequence as mp4, else gif.

    Every frame shows the action curves with a red playhead at the current
    step, plus (when available) the proportionally matching world-video frame
    and the 3D path with the current point marked. Returns the written file
    path as a string, or None when there are no frames or both writers fail.
    """
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    import numpy as np

    n, dim = flat.shape[0], flat.shape[1]

    # Optional side panel: frames of the generated world video. Any loading
    # problem just disables the panel.
    world_frames = None
    if world_video:
        try:
            import imageio.v3 as iio
            world_frames = iio.imread(world_video)  # [F,H,W,C]
        except Exception:
            world_frames = None

    have_world = world_frames is not None and len(world_frames) > 0
    have_3d = path_xyz is not None
    ncols = 1 + int(have_world) + int(have_3d)

    frames_out = []
    for k in range(n):
        fig = plt.figure(figsize=(5 * ncols, 4.2))
        col = 1

        # panel A: action curves with playhead
        axc = fig.add_subplot(1, ncols, col)
        col += 1
        for i in range(dim):
            is_grip = (i == g_idx)
            axc.plot(flat[:, i], lw=2.2 if is_grip else 1.0,
                     color="black" if is_grip else None,
                     ls="--" if is_grip else "-",
                     alpha=0.9 if is_grip else 0.7)
        axc.axvline(k, color="red", lw=2)
        axc.set_title("action")
        axc.set_xlabel("timestep")
        axc.grid(alpha=0.2)

        # panel B: world frame (synced by proportion)
        if have_world:
            axw = fig.add_subplot(1, ncols, col)
            col += 1
            last_frame = len(world_frames) - 1
            fi = min(int(k / max(n - 1, 1) * last_frame), last_frame)
            axw.imshow(world_frames[fi])
            axw.set_title(f"world frame {fi}")
            axw.axis("off")

        # panel C: 3D path with current point
        if have_3d:
            ax3 = fig.add_subplot(1, ncols, col, projection="3d")
            col += 1
            ax3.plot(path_xyz[:, 0], path_xyz[:, 1], path_xyz[:, 2],
                     "-", lw=1, alpha=0.5)
            ax3.scatter(*path_xyz[k], c="red", s=50)
            ax3.set_title("end-effector")
            ax3.set_xticks([])
            ax3.set_yticks([])
            ax3.set_zticks([])

        # Rasterize the figure and keep its RGB channels as one frame.
        fig.tight_layout()
        fig.canvas.draw()
        buf = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)
        w, h = fig.canvas.get_width_height()
        frames_out.append(buf.reshape(h, w, 4)[..., :3].copy())
        plt.close(fig)

    if not frames_out:
        return None

    # First choice: mp4 through imageio.
    out_mp4 = ARTIFACT_DIR / f"{prefix}_animation_{ts}.mp4"
    try:
        import imageio.v3 as iio
        iio.imwrite(str(out_mp4), np.stack(frames_out), fps=fps, codec="libx264")
        return str(out_mp4)
    except Exception:
        pass

    # Fallback: animated gif through Pillow.
    try:
        from PIL import Image
        out_gif = ARTIFACT_DIR / f"{prefix}_animation_{ts}.gif"
        imgs = [Image.fromarray(f) for f in frames_out]
        imgs[0].save(str(out_gif), save_all=True, append_images=imgs[1:],
                     duration=int(1000 / max(fps, 1)), loop=0)
        return str(out_gif)
    except Exception:
        return None
|