vlalab 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vlalab/__init__.py +8 -1
- vlalab/apps/streamlit/app.py +310 -37
- vlalab/apps/streamlit/pages/eval_viewer.py +374 -0
- vlalab/cli.py +1 -1
- vlalab/eval/__init__.py +15 -0
- vlalab/eval/adapters/__init__.py +14 -0
- vlalab/eval/adapters/dp_adapter.py +279 -0
- vlalab/eval/adapters/groot_adapter.py +253 -0
- vlalab/eval/open_loop_eval.py +542 -0
- vlalab/eval/policy_interface.py +155 -0
- {vlalab-0.1.0.dist-info → vlalab-0.1.1.dist-info}/METADATA +12 -70
- {vlalab-0.1.0.dist-info → vlalab-0.1.1.dist-info}/RECORD +16 -9
- {vlalab-0.1.0.dist-info → vlalab-0.1.1.dist-info}/WHEEL +0 -0
- {vlalab-0.1.0.dist-info → vlalab-0.1.1.dist-info}/entry_points.txt +0 -0
- {vlalab-0.1.0.dist-info → vlalab-0.1.1.dist-info}/licenses/LICENSE +0 -0
- {vlalab-0.1.0.dist-info → vlalab-0.1.1.dist-info}/top_level.txt +0 -0
vlalab/apps/streamlit/pages/eval_viewer.py
ADDED
@@ -0,0 +1,374 @@
"""
VLA-Lab Open-Loop Evaluation Viewer

Visualize and compare predicted vs ground-truth actions from open-loop evaluation.
"""

import streamlit as st
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import json
from typing import Dict, List, Any, Optional

# Setup matplotlib fonts
try:
    from vlalab.viz.mpl_fonts import setup_matplotlib_fonts
    setup_matplotlib_fonts(verbose=False)
except Exception:
    pass


def load_eval_results(results_path: Path) -> Dict[str, Any]:
    """Load evaluation results from JSON file."""
    with open(results_path, "r") as f:
        return json.load(f)


def load_trajectory_arrays(results_dir: Path, traj_id: int) -> Dict[str, np.ndarray]:
    """
    Load trajectory arrays (GT and pred actions) if saved as .npy files.

    This is a placeholder - actual implementation depends on how arrays are saved.
    """
    arrays = {}

    gt_path = results_dir / f"traj_{traj_id}_gt.npy"
    pred_path = results_dir / f"traj_{traj_id}_pred.npy"

    if gt_path.exists():
        arrays["gt_actions"] = np.load(gt_path)
    if pred_path.exists():
        arrays["pred_actions"] = np.load(pred_path)

    return arrays


def plot_action_comparison(
    gt_actions: np.ndarray,
    pred_actions: np.ndarray,
    dim_idx: int,
    action_horizon: int = 16,
    dim_label: str = "",
) -> plt.Figure:
    """Plot GT vs predicted actions for a single dimension."""
    fig, ax = plt.subplots(figsize=(12, 4))

    steps = np.arange(len(gt_actions))

    ax.plot(steps, gt_actions[:, dim_idx], label="Ground Truth", alpha=0.8, linewidth=2)
    ax.plot(steps, pred_actions[:, dim_idx], label="Predicted", alpha=0.8, linewidth=2, linestyle="--")

    # Mark inference points
    for j in range(0, len(gt_actions), action_horizon):
        ax.axvline(x=j, color='gray', linestyle=':', alpha=0.3)

    # Error band
    error = np.abs(gt_actions[:, dim_idx] - pred_actions[:, dim_idx])
    ax.fill_between(steps,
                    gt_actions[:, dim_idx] - error,
                    gt_actions[:, dim_idx] + error,
                    alpha=0.2, color='red', label='Error')

    ax.set_xlabel("Step")
    ax.set_ylabel("Value")
    ax.set_title(f"Action Dimension: {dim_label or dim_idx}")
    ax.legend()
    ax.grid(True, alpha=0.3)

    return fig


def plot_all_dimensions(
    gt_actions: np.ndarray,
    pred_actions: np.ndarray,
    action_horizon: int = 16,
    action_labels: Optional[List[str]] = None,
) -> plt.Figure:
    """Plot all action dimensions in a grid."""
    num_dims = gt_actions.shape[1]
    n_cols = min(3, num_dims)
    n_rows = (num_dims + n_cols - 1) // n_cols

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 3 * n_rows))
    axes = np.atleast_2d(axes)

    steps = np.arange(len(gt_actions))

    for i in range(num_dims):
        row, col = i // n_cols, i % n_cols
        ax = axes[row, col]

        ax.plot(steps, gt_actions[:, i], label="GT", alpha=0.8)
        ax.plot(steps, pred_actions[:, i], label="Pred", alpha=0.8, linestyle="--")

        label = action_labels[i] if action_labels and i < len(action_labels) else f"Dim {i}"
        ax.set_title(label, fontsize=10)
        ax.grid(True, alpha=0.3)

        if i == 0:
            ax.legend(fontsize=8)

    # Hide empty subplots
    for i in range(num_dims, n_rows * n_cols):
        row, col = i // n_cols, i % n_cols
        axes[row, col].axis('off')

    plt.tight_layout()
    return fig


def plot_3d_trajectory(
    gt_actions: np.ndarray,
    pred_actions: np.ndarray,
    xyz_dims: List[int] = [0, 1, 2],
) -> plt.Figure:
    """Plot 3D trajectory comparison (if actions contain position)."""
    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection='3d')

    x, y, z = xyz_dims

    # GT trajectory
    ax.plot(gt_actions[:, x], gt_actions[:, y], gt_actions[:, z],
            'b-', label='Ground Truth', alpha=0.8, linewidth=2)
    ax.scatter(gt_actions[0, x], gt_actions[0, y], gt_actions[0, z],
               c='green', s=100, label='Start')
    ax.scatter(gt_actions[-1, x], gt_actions[-1, y], gt_actions[-1, z],
               c='red', s=100, label='End')

    # Pred trajectory
    ax.plot(pred_actions[:, x], pred_actions[:, y], pred_actions[:, z],
            'r--', label='Predicted', alpha=0.6, linewidth=2)

    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.legend()
    ax.set_title('3D Trajectory Comparison')

    return fig


def plot_error_histogram(
    gt_actions: np.ndarray,
    pred_actions: np.ndarray,
) -> plt.Figure:
    """Plot error distribution histogram."""
    errors = (gt_actions - pred_actions).flatten()

    fig, axes = plt.subplots(1, 2, figsize=(12, 4))

    # Error histogram
    axes[0].hist(errors, bins=50, alpha=0.7, edgecolor='black')
    axes[0].axvline(x=0, color='r', linestyle='--', label='Zero')
    axes[0].set_xlabel('Error')
    axes[0].set_ylabel('Frequency')
    axes[0].set_title('Error Distribution')
    axes[0].legend()

    # Absolute error histogram
    abs_errors = np.abs(errors)
    axes[1].hist(abs_errors, bins=50, alpha=0.7, edgecolor='black', color='orange')
    axes[1].axvline(x=np.mean(abs_errors), color='r', linestyle='--',
                    label=f'Mean: {np.mean(abs_errors):.4f}')
    axes[1].set_xlabel('Absolute Error')
    axes[1].set_ylabel('Frequency')
    axes[1].set_title('Absolute Error Distribution')
    axes[1].legend()

    plt.tight_layout()
    return fig


def render():
    """Render the evaluation viewer page."""
    st.title("📊 Open-Loop 评估结果")

    st.markdown("""
    可视化 VLA 模型的 Open-Loop 评估结果,比较预测动作与真实动作。

    **使用方法:**
    1. 上传评估结果 JSON 文件
    2. 或指定包含评估图片的目录
    """)

    # Sidebar options
    st.sidebar.markdown("### 数据来源")

    source = st.sidebar.radio(
        "选择数据来源",
        ["上传 JSON", "浏览目录", "演示数据"],
    )

    results = None
    gt_actions = None
    pred_actions = None

    if source == "上传 JSON":
        uploaded_file = st.sidebar.file_uploader(
            "上传评估结果 JSON",
            type=["json"],
        )

        if uploaded_file:
            results = json.load(uploaded_file)
            st.success(f"已加载评估结果: {len(results.get('results', []))} 条轨迹")

    elif source == "浏览目录":
        results_dir = st.sidebar.text_input(
            "评估结果目录",
            value="",
            placeholder="/path/to/eval_results/",
        )

        if results_dir and Path(results_dir).exists():
            # Find JSON files
            json_files = list(Path(results_dir).glob("*.json"))
            if json_files:
                selected_json = st.sidebar.selectbox(
                    "选择结果文件",
                    json_files,
                    format_func=lambda p: p.name,
                )
                if selected_json:
                    results = load_eval_results(selected_json)

            # Find plot images
            png_files = sorted(Path(results_dir).glob("*.png"))
            if png_files:
                st.markdown("### 已生成的评估图")
                for png_file in png_files:
                    st.image(str(png_file), caption=png_file.name)

    elif source == "演示数据":
        st.info("使用随机生成的演示数据")

        # Generate demo data
        np.random.seed(42)
        num_steps = 200
        action_dim = 8

        # Simulate GT actions (smooth trajectory)
        t = np.linspace(0, 4 * np.pi, num_steps)
        gt_actions = np.zeros((num_steps, action_dim))
        gt_actions[:, 0] = 0.5 * np.sin(t) + 0.1 * np.random.randn(num_steps)
        gt_actions[:, 1] = 0.3 * np.cos(t) + 0.1 * np.random.randn(num_steps)
        gt_actions[:, 2] = 0.2 * np.sin(0.5 * t) + 0.1 * np.random.randn(num_steps)
        for i in range(3, action_dim):
            gt_actions[:, i] = 0.1 * np.sin(t * (i - 2)) + 0.05 * np.random.randn(num_steps)

        # Simulate pred actions (GT + noise + slight bias)
        pred_actions = gt_actions + 0.05 * np.random.randn(num_steps, action_dim)
        pred_actions += 0.02 * np.ones_like(pred_actions)  # slight bias

        # Calculate metrics
        mse = float(np.mean((gt_actions - pred_actions) ** 2))
        mae = float(np.mean(np.abs(gt_actions - pred_actions)))

        results = {
            "results": [{"trajectory_id": 0, "mse": mse, "mae": mae, "num_steps": num_steps}],
            "avg_mse": mse,
            "avg_mae": mae,
            "num_trajectories": 1,
        }

    # Display results
    if results:
        st.markdown("---")

        # Summary metrics
        st.markdown("### 📈 评估指标")
        col1, col2, col3 = st.columns(3)

        col1.metric("平均 MSE", f"{results.get('avg_mse', 0):.6f}")
        col2.metric("平均 MAE", f"{results.get('avg_mae', 0):.6f}")
        col3.metric("评估轨迹数", results.get('num_trajectories', 0))

        # Per-trajectory results
        if results.get("results"):
            st.markdown("### 📋 轨迹详情")

            traj_data = []
            for r in results["results"]:
                traj_data.append({
                    "轨迹 ID": r["trajectory_id"],
                    "MSE": f"{r['mse']:.6f}",
                    "MAE": f"{r['mae']:.6f}",
                    "步数": r["num_steps"],
                })

            st.dataframe(traj_data, use_container_width=True)

    # Visualization tabs
    if gt_actions is not None and pred_actions is not None:
        st.markdown("---")
        st.markdown("### 🎨 可视化")

        tab1, tab2, tab3, tab4 = st.tabs([
            "📊 时序对比",
            "🗺️ 3D 轨迹",
            "📈 误差分布",
            "🔍 逐维度分析",
        ])

        with tab1:
            action_horizon = st.slider("Action Horizon", 4, 32, 16)
            fig = plot_all_dimensions(gt_actions, pred_actions, action_horizon)
            st.pyplot(fig)
            plt.close(fig)

        with tab2:
            if gt_actions.shape[1] >= 3:
                st.markdown("假设前三个维度为 XYZ 位置")
                fig = plot_3d_trajectory(gt_actions, pred_actions)
                st.pyplot(fig)
                plt.close(fig)
            else:
                st.warning("动作维度不足,无法绘制 3D 轨迹")

        with tab3:
            fig = plot_error_histogram(gt_actions, pred_actions)
            st.pyplot(fig)
            plt.close(fig)

        with tab4:
            dim_idx = st.selectbox(
                "选择维度",
                range(gt_actions.shape[1]),
                format_func=lambda x: f"维度 {x}",
            )
            fig = plot_action_comparison(gt_actions, pred_actions, dim_idx)
            st.pyplot(fig)
            plt.close(fig)

    # Usage instructions
    with st.expander("💡 如何生成评估结果"):
        st.markdown("""
        ```python
        from vlalab.eval import OpenLoopEvaluator
        from vlalab.eval.adapters import GR00TAdapter

        # 1. 创建 Policy 适配器
        adapter = GR00TAdapter(policy)

        # 2. 创建评估器
        evaluator = OpenLoopEvaluator(
            policy=adapter,
            dataset_path="/path/to/dataset.zarr",
        )

        # 3. 运行评估
        results = evaluator.evaluate(
            traj_ids=[0, 1, 2],
            max_steps=200,
            save_plots_dir="eval_outputs/",
        )

        # 4. 保存结果
        evaluator.evaluate_and_save(
            "eval_outputs/results.json",
            traj_ids=[0, 1, 2],
        )
        ```
        """)
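For reference, the results file the viewer reads is a flat JSON document whose keys match what `render()` accesses: a `results` list of per-trajectory entries plus `avg_mse`, `avg_mae`, and `num_trajectories` aggregates, with optional `traj_<id>_gt.npy` / `traj_<id>_pred.npy` arrays alongside it for `load_trajectory_arrays()`. The sketch below is a minimal example of that shape with placeholder numbers; the file actually written by `OpenLoopEvaluator.evaluate_and_save()` may carry additional fields.

```python
import json
from pathlib import Path

import numpy as np

out_dir = Path("eval_outputs")
out_dir.mkdir(exist_ok=True)

# Minimal results file using only the keys eval_viewer.py reads (values are placeholders).
results = {
    "results": [
        {"trajectory_id": 0, "mse": 0.0012, "mae": 0.021, "num_steps": 200},
    ],
    "avg_mse": 0.0012,
    "avg_mae": 0.021,
    "num_trajectories": 1,
}
with open(out_dir / "results.json", "w") as f:
    json.dump(results, f, indent=2)

# Optional per-trajectory arrays picked up by load_trajectory_arrays().
np.save(out_dir / "traj_0_gt.npy", np.zeros((200, 8)))
np.save(out_dir / "traj_0_pred.npy", np.zeros((200, 8)))
```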
vlalab/cli.py
CHANGED
@@ -44,7 +44,7 @@ def view(port: int, run_dir: str):
     package_dir = Path(vlalab.__file__).parent
     app_path = package_dir / "apps" / "streamlit" / "app.py"

-    cmd = [sys.executable, "-m", "streamlit", "run", str(app_path), "--server.port", str(port)]
+    cmd = [sys.executable, "-m", "streamlit", "run", str(app_path), "--server.port", str(port), "--server.address", "0.0.0.0"]

     if run_dir:
         cmd.extend(["--", "--run-dir", run_dir])
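The only behavioral change to the CLI is that `vlalab view` now binds Streamlit to all network interfaces, so the dashboard becomes reachable from other machines rather than only from localhost. A minimal sketch of the command the updated code assembles; the concrete app path and port value are placeholders here:

```python
import subprocess
import sys

# Illustration of the command `vlalab view` builds after this change;
# "vlalab/apps/streamlit/app.py" and port 8501 are placeholder values.
cmd = [
    sys.executable, "-m", "streamlit", "run", "vlalab/apps/streamlit/app.py",
    "--server.port", "8501",
    "--server.address", "0.0.0.0",  # new in 0.1.1: listen on all interfaces
]
subprocess.run(cmd)
```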
vlalab/eval/__init__.py
ADDED
@@ -0,0 +1,15 @@
"""
VLA-Lab Evaluation Module

Provides open-loop evaluation tools for VLA models.
"""

from vlalab.eval.policy_interface import EvalPolicy, ModalityConfig
from vlalab.eval.open_loop_eval import OpenLoopEvaluator, evaluate_trajectory

__all__ = [
    "EvalPolicy",
    "ModalityConfig",
    "OpenLoopEvaluator",
    "evaluate_trajectory",
]
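The module's public surface is the `EvalPolicy` interface plus `ModalityConfig`, so a custom policy can be evaluated without one of the bundled adapters. A minimal sketch, assuming the contract is the three methods the Diffusion Policy adapter below implements (`get_action()`, `get_modality_config()`, `reset()`) and that `ModalityConfig` exposes its constructor fields as attributes:

```python
from typing import Any, Dict, Optional

import numpy as np

from vlalab.eval import EvalPolicy, ModalityConfig


class RandomPolicy(EvalPolicy):
    """Toy policy returning random action chunks, useful for wiring tests."""

    def __init__(self, action_horizon: int = 16, action_dim: int = 7):
        self._config = ModalityConfig(
            state_keys=["robot_state"],
            action_keys=["action"],
            image_keys=["front"],
            language_keys=[],
            action_horizon=action_horizon,
            action_dim=action_dim,
        )

    def get_action(self, obs: Dict[str, Any], task_description: Optional[str] = None) -> np.ndarray:
        # Assumes ModalityConfig exposes action_horizon / action_dim as attributes.
        cfg = self._config
        return np.random.uniform(-1.0, 1.0, size=(cfg.action_horizon, cfg.action_dim))

    def get_modality_config(self) -> ModalityConfig:
        return self._config

    def reset(self) -> None:
        pass
```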
vlalab/eval/adapters/__init__.py
ADDED
@@ -0,0 +1,14 @@
"""
VLA-Lab Policy Adapters

Adapters that wrap specific VLA model implementations to conform to
the unified EvalPolicy interface.
"""

from vlalab.eval.adapters.groot_adapter import GR00TAdapter
from vlalab.eval.adapters.dp_adapter import DiffusionPolicyAdapter

__all__ = [
    "GR00TAdapter",
    "DiffusionPolicyAdapter",
]
vlalab/eval/adapters/dp_adapter.py
ADDED
@@ -0,0 +1,279 @@
"""
Diffusion Policy Adapter

Wraps Diffusion Policy implementations to conform to the unified EvalPolicy interface.
This adapter handles the conversion between VLA-Lab's standardized observation
format and Diffusion Policy's specific input/output formats.
"""

from typing import Any, Dict, List, Optional, Callable
import numpy as np

from vlalab.eval.policy_interface import EvalPolicy, ModalityConfig


def parse_observation_dp(
    obs: Dict[str, Any],
    state_key: str = "state",
    image_key: str = "image",
) -> Dict[str, Any]:
    """
    Convert standardized observation to Diffusion Policy's expected format.

    Diffusion Policy typically expects:
        {
            "state": np.ndarray (state_dim,) or (T, state_dim),
            "image": np.ndarray (H, W, C) or (T, H, W, C),
        }

    Args:
        obs: Standardized observation dict with:
            - "state": Dict[str, np.ndarray] - state vectors
            - "images": Dict[str, np.ndarray] - images (H, W, C)
        state_key: Key to use for concatenated state
        image_key: Key to use for primary image

    Returns:
        DP-formatted observation dict
    """
    dp_obs = {}

    # Concatenate all state keys into single state vector
    if "state" in obs:
        state_parts = []
        for key, arr in obs["state"].items():
            arr = np.atleast_1d(arr).astype(np.float32)
            state_parts.append(arr)
        if state_parts:
            dp_obs[state_key] = np.concatenate(state_parts, axis=-1)

    # Use first image as primary image
    if "images" in obs:
        for cam_name, img in obs["images"].items():
            dp_obs[image_key] = img
            break  # Use first image

    return dp_obs


def parse_action_dp(
    action: Any,
) -> np.ndarray:
    """
    Convert Diffusion Policy action output to standardized array format.

    Diffusion Policy outputs action chunks as:
        - np.ndarray of shape (action_horizon, action_dim)
        - or torch.Tensor of same shape

    Args:
        action: DP action output

    Returns:
        Action array of shape (action_horizon, action_dim)
    """
    # Handle torch tensors
    if hasattr(action, "cpu"):
        action = action.cpu().numpy()

    action = np.asarray(action)

    # Ensure 2D
    if action.ndim == 1:
        action = action[None, :]

    return action


class DiffusionPolicyAdapter(EvalPolicy):
    """
    Adapter for Diffusion Policy.

    This adapter is designed to work with various Diffusion Policy implementations.
    You can either:
    1. Pass a policy object with a predict() or get_action() method
    2. Pass a callable inference function directly

    Usage:
        # Option 1: Wrap policy object
        from dp_server import DPPolicy
        policy = DPPolicy.load(checkpoint_path)
        adapter = DiffusionPolicyAdapter(policy)

        # Option 2: Wrap inference client
        from dp_client import DPClient
        client = DPClient(host="localhost", port=5000)
        adapter = DiffusionPolicyAdapter(
            client,
            inference_fn=lambda p, obs: p.predict(obs["state"], obs["image"])
        )

        # Option 3: Wrap callable
        adapter = DiffusionPolicyAdapter(
            inference_fn=my_inference_function,
            action_horizon=8,
            action_dim=7,
        )
    """

    def __init__(
        self,
        policy: Any = None,
        inference_fn: Optional[Callable] = None,
        action_horizon: int = 8,
        action_dim: int = 7,
        state_keys: Optional[List[str]] = None,
        image_keys: Optional[List[str]] = None,
    ):
        """
        Initialize the Diffusion Policy adapter.

        Args:
            policy: DP policy instance (optional if inference_fn provided)
            inference_fn: Custom inference function (policy, obs) -> action
            action_horizon: Number of future actions predicted
            action_dim: Dimension of action space
            state_keys: List of state modality keys
            image_keys: List of image modality keys
        """
        self.policy = policy
        self._inference_fn = inference_fn

        self._modality_config = ModalityConfig(
            state_keys=state_keys or ["robot_state"],
            action_keys=["action"],
            image_keys=image_keys or ["front"],
            language_keys=[],  # DP typically doesn't use language
            action_horizon=action_horizon,
            action_dim=action_dim,
        )

    def _default_inference(self, obs: Dict[str, Any]) -> np.ndarray:
        """Default inference logic when no custom inference_fn provided."""
        dp_obs = parse_observation_dp(obs)

        # Try different method names
        if hasattr(self.policy, "predict"):
            action = self.policy.predict(dp_obs)
        elif hasattr(self.policy, "get_action"):
            action = self.policy.get_action(dp_obs)
        elif hasattr(self.policy, "__call__"):
            action = self.policy(dp_obs)
        else:
            raise ValueError(
                "Policy must have predict(), get_action(), or __call__() method. "
                "Alternatively, provide inference_fn parameter."
            )

        return parse_action_dp(action)

    def get_action(
        self,
        obs: Dict[str, Any],
        task_description: Optional[str] = None,
    ) -> np.ndarray:
        """
        Get action from Diffusion Policy.

        Args:
            obs: Standardized observation dict
            task_description: Ignored (DP doesn't use language)

        Returns:
            Action array of shape (action_horizon, action_dim)
        """
        if self._inference_fn is not None:
            if self.policy is not None:
                action = self._inference_fn(self.policy, obs)
            else:
                action = self._inference_fn(obs)
            return parse_action_dp(action)

        return self._default_inference(obs)

    def get_modality_config(self) -> ModalityConfig:
        """Get modality configuration."""
        return self._modality_config

    def reset(self) -> None:
        """Reset the policy."""
        if self.policy is not None and hasattr(self.policy, "reset"):
            self.policy.reset()


class DiffusionPolicyClientAdapter(DiffusionPolicyAdapter):
    """
    Adapter for Diffusion Policy inference client (e.g., ZMQ client).

    This is a convenience class for wrapping remote DP inference servers.

    Usage:
        adapter = DiffusionPolicyClientAdapter(
            host="localhost",
            port=5000,
            action_horizon=8,
            action_dim=7,
        )
        action = adapter.get_action(obs)
    """

    def __init__(
        self,
        host: str = "localhost",
        port: int = 5000,
        action_horizon: int = 8,
        action_dim: int = 7,
        state_keys: Optional[List[str]] = None,
        image_keys: Optional[List[str]] = None,
    ):
        """
        Initialize client adapter.

        Args:
            host: Server hostname
            port: Server port
            action_horizon: Number of future actions
            action_dim: Action dimension
            state_keys: State modality keys
            image_keys: Image modality keys
        """
        super().__init__(
            policy=None,
            action_horizon=action_horizon,
            action_dim=action_dim,
            state_keys=state_keys,
            image_keys=image_keys,
        )

        self.host = host
        self.port = port
        self._client = None

    def _get_client(self):
        """Lazy initialization of ZMQ client."""
        if self._client is None:
            try:
                import zmq
                context = zmq.Context()
                self._client = context.socket(zmq.REQ)
                self._client.connect(f"tcp://{self.host}:{self.port}")
            except ImportError:
                raise ImportError("pyzmq required for DiffusionPolicyClientAdapter")
        return self._client

    def get_action(
        self,
        obs: Dict[str, Any],
        task_description: Optional[str] = None,
    ) -> np.ndarray:
        """
        Get action from remote DP server.

        Note: This is a placeholder. Actual implementation depends on
        the specific DP server protocol used.
        """
        raise NotImplementedError(
            "DiffusionPolicyClientAdapter.get_action() requires implementation "
            "specific to your DP server protocol. Override this method or use "
            "DiffusionPolicyAdapter with a custom inference_fn."
        )