vlalab 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vlalab/__init__.py +8 -1
- vlalab/apps/streamlit/app.py +310 -37
- vlalab/apps/streamlit/pages/eval_viewer.py +374 -0
- vlalab/cli.py +1 -1
- vlalab/eval/__init__.py +15 -0
- vlalab/eval/adapters/__init__.py +14 -0
- vlalab/eval/adapters/dp_adapter.py +279 -0
- vlalab/eval/adapters/groot_adapter.py +253 -0
- vlalab/eval/open_loop_eval.py +542 -0
- vlalab/eval/policy_interface.py +155 -0
- {vlalab-0.1.0.dist-info → vlalab-0.1.1.dist-info}/METADATA +12 -70
- {vlalab-0.1.0.dist-info → vlalab-0.1.1.dist-info}/RECORD +16 -9
- {vlalab-0.1.0.dist-info → vlalab-0.1.1.dist-info}/WHEEL +0 -0
- {vlalab-0.1.0.dist-info → vlalab-0.1.1.dist-info}/entry_points.txt +0 -0
- {vlalab-0.1.0.dist-info → vlalab-0.1.1.dist-info}/licenses/LICENSE +0 -0
- {vlalab-0.1.0.dist-info → vlalab-0.1.1.dist-info}/top_level.txt +0 -0
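
The substance of 0.1.1 is the new model-agnostic open-loop evaluation stack under vlalab/eval/, reproduced in full in the diff hunks below. As a quick orientation, here is a minimal smoke-test sketch using only classes added in this release (OpenLoopEvaluator, DummyPolicy); the dataset path is a placeholder and the snippet itself is not part of the package:

from vlalab.eval.open_loop_eval import OpenLoopEvaluator
from vlalab.eval.policy_interface import DummyPolicy

# DummyPolicy emits random (action_horizon, action_dim) chunks, so this only
# exercises dataset loading, the chunked inference loop, metrics, and plotting.
evaluator = OpenLoopEvaluator(
    policy=DummyPolicy(action_dim=7, action_horizon=16),
    dataset_path="/path/to/dataset.zarr",  # placeholder path
    dataset_format="zarr",
)
results = evaluator.evaluate(traj_ids=[0], max_steps=100, save_plots_dir="outputs/")
print(results["avg_mse"], results["avg_mae"])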
vlalab/eval/open_loop_eval.py
@@ -0,0 +1,542 @@
"""
VLA-Lab Open-Loop Evaluation

Model-agnostic open-loop evaluation for VLA policies.
Compares predicted actions against ground-truth actions from a dataset.
"""

from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
import logging
import json

import numpy as np
import matplotlib.pyplot as plt

from vlalab.eval.policy_interface import EvalPolicy, ModalityConfig


@dataclass
class EvalResult:
    """Result of evaluating a single trajectory."""
    trajectory_id: int
    mse: float
    mae: float
    num_steps: int
    gt_actions: np.ndarray  # (T, action_dim)
    pred_actions: np.ndarray  # (T, action_dim)
    states: Optional[np.ndarray] = None  # (T, state_dim)

    def to_dict(self) -> Dict[str, Any]:
        return {
            "trajectory_id": self.trajectory_id,
            "mse": float(self.mse),
            "mae": float(self.mae),
            "num_steps": self.num_steps,
        }


@dataclass
class EvalConfig:
    """Configuration for open-loop evaluation."""
    max_steps: int = 300
    action_horizon: int = 16
    task_description: Optional[str] = None
    save_plot_path: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        return {
            "max_steps": self.max_steps,
            "action_horizon": self.action_horizon,
            "task_description": self.task_description,
            "save_plot_path": self.save_plot_path,
        }


class DatasetLoader:
    """
    Abstract interface for loading trajectory data from datasets.

    Subclass this to support different dataset formats (Zarr, LeRobot, HDF5, etc.)
    """

    def __len__(self) -> int:
        """Return number of trajectories."""
        raise NotImplementedError

    def get_trajectory_length(self, traj_id: int) -> int:
        """Return length of a specific trajectory."""
        raise NotImplementedError

    def get_step_data(
        self,
        traj_id: int,
        step_idx: int,
        state_keys: List[str],
        image_keys: List[str],
    ) -> Dict[str, Any]:
        """
        Get observation data for a specific step.

        Returns:
            Dict with:
            - "state": Dict[str, np.ndarray] - state vectors by key
            - "images": Dict[str, np.ndarray] - images by camera name
        """
        raise NotImplementedError

    def get_action(
        self,
        traj_id: int,
        step_idx: int,
        action_keys: List[str],
    ) -> np.ndarray:
        """
        Get ground-truth action for a specific step.

        Returns:
            Action array of shape (action_dim,)
        """
        raise NotImplementedError

    def get_trajectory_actions(
        self,
        traj_id: int,
        action_keys: List[str],
        max_steps: Optional[int] = None,
    ) -> np.ndarray:
        """
        Get all ground-truth actions for a trajectory.

        Returns:
            Action array of shape (T, action_dim)
        """
        raise NotImplementedError


class ZarrDatasetLoader(DatasetLoader):
    """
    Loader for Zarr format datasets (used by Diffusion Policy, etc.)

    Expected structure:
        dataset.zarr/
        ├── data/
        │   ├── action (T, action_dim)
        │   ├── state (T, state_dim)  # or robot_state, etc.
        │   └── image_* (T, H, W, C)
        └── meta/
            └── episode_ends (num_episodes,)
    """

    def __init__(self, zarr_path: str):
        import zarr
        self.zarr_path = Path(zarr_path)
        self.root = zarr.open(str(zarr_path), mode='r')
        self.data = self.root['data']
        self.meta = self.root['meta']
        self.episode_ends = self.meta['episode_ends'][:]

    def __len__(self) -> int:
        return len(self.episode_ends)

    def _get_episode_slice(self, traj_id: int) -> Tuple[int, int]:
        start = 0 if traj_id == 0 else int(self.episode_ends[traj_id - 1])
        end = int(self.episode_ends[traj_id])
        return start, end

    def get_trajectory_length(self, traj_id: int) -> int:
        start, end = self._get_episode_slice(traj_id)
        return end - start

    def get_step_data(
        self,
        traj_id: int,
        step_idx: int,
        state_keys: List[str],
        image_keys: List[str],
    ) -> Dict[str, Any]:
        start, end = self._get_episode_slice(traj_id)
        global_idx = start + step_idx

        obs = {"state": {}, "images": {}}

        # Load state data
        for key in state_keys:
            if key in self.data:
                obs["state"][key] = self.data[key][global_idx]

        # Load image data
        for key in image_keys:
            # Try different naming conventions
            for img_key in [key, f"image_{key}", f"img_{key}"]:
                if img_key in self.data:
                    img = self.data[img_key][global_idx]
                    # Handle (C, H, W) -> (H, W, C)
                    if img.ndim == 3 and img.shape[0] in [1, 3]:
                        img = np.transpose(img, (1, 2, 0))
                    obs["images"][key] = img
                    break

        return obs

    def get_action(
        self,
        traj_id: int,
        step_idx: int,
        action_keys: List[str],
    ) -> np.ndarray:
        start, _ = self._get_episode_slice(traj_id)
        global_idx = start + step_idx

        # For Zarr, action is typically stored as single array
        return self.data['action'][global_idx]

    def get_trajectory_actions(
        self,
        traj_id: int,
        action_keys: List[str],
        max_steps: Optional[int] = None,
    ) -> np.ndarray:
        start, end = self._get_episode_slice(traj_id)
        if max_steps:
            end = min(end, start + max_steps)
        return self.data['action'][start:end]


def evaluate_trajectory(
    policy: EvalPolicy,
    dataset: DatasetLoader,
    traj_id: int,
    config: EvalConfig,
) -> EvalResult:
    """
    Evaluate a policy on a single trajectory.

    Args:
        policy: Policy adapter implementing EvalPolicy
        dataset: Dataset loader
        traj_id: Trajectory ID to evaluate
        config: Evaluation configuration

    Returns:
        EvalResult with metrics and action arrays
    """
    modality = policy.get_modality_config()
    action_horizon = config.action_horizon or modality.action_horizon

    # Get trajectory length
    traj_length = dataset.get_trajectory_length(traj_id)
    actual_steps = min(config.max_steps, traj_length)

    logging.info(f"Evaluating trajectory {traj_id}: {actual_steps} steps")

    # Collect predicted actions
    pred_actions_list = []

    for step_idx in range(0, actual_steps, action_horizon):
        logging.debug(f"Inferencing at step {step_idx}")

        # Get observation
        obs = dataset.get_step_data(
            traj_id,
            step_idx,
            modality.state_keys,
            modality.image_keys,
        )

        # Get action from policy
        action_chunk = policy.get_action(obs, config.task_description)

        # Collect actions from chunk
        for j in range(action_horizon):
            if step_idx + j >= actual_steps:
                break
            if j < len(action_chunk):
                pred_actions_list.append(action_chunk[j])

    # Get ground truth actions
    gt_actions = dataset.get_trajectory_actions(
        traj_id,
        modality.action_keys,
        max_steps=actual_steps,
    )

    pred_actions = np.array(pred_actions_list)[:actual_steps]
    gt_actions = gt_actions[:actual_steps]

    # Ensure shapes match
    min_len = min(len(pred_actions), len(gt_actions))
    pred_actions = pred_actions[:min_len]
    gt_actions = gt_actions[:min_len]

    # Handle dimension mismatch (warn before truncating so both original dims are reported)
    if pred_actions.shape[-1] != gt_actions.shape[-1]:
        # Take minimum dimension (common case: pred has extra dims)
        min_dim = min(pred_actions.shape[-1], gt_actions.shape[-1])
        logging.warning(
            f"Dimension mismatch: pred {pred_actions.shape[-1]}, gt {gt_actions.shape[-1]}. "
            f"Using first {min_dim} dims."
        )
        pred_actions = pred_actions[..., :min_dim]
        gt_actions = gt_actions[..., :min_dim]

    # Calculate metrics
    mse = float(np.mean((gt_actions - pred_actions) ** 2))
    mae = float(np.mean(np.abs(gt_actions - pred_actions)))

    logging.info(f"Trajectory {traj_id}: MSE={mse:.6f}, MAE={mae:.6f}")

    return EvalResult(
        trajectory_id=traj_id,
        mse=mse,
        mae=mae,
        num_steps=min_len,
        gt_actions=gt_actions,
        pred_actions=pred_actions,
    )


def plot_trajectory_results(
    result: EvalResult,
    action_keys: Optional[List[str]] = None,
    action_horizon: int = 16,
    save_path: Optional[str] = None,
    show: bool = False,
) -> plt.Figure:
    """
    Plot evaluation results comparing GT vs predicted actions.

    Args:
        result: EvalResult from evaluate_trajectory
        action_keys: Optional labels for action dimensions
        action_horizon: Action horizon for marking inference points
        save_path: Path to save plot (optional)
        show: Whether to display plot

    Returns:
        Matplotlib figure
    """
    gt_actions = result.gt_actions
    pred_actions = result.pred_actions

    num_dims = gt_actions.shape[1]
    actual_steps = len(gt_actions)

    # Create figure
    fig, axes = plt.subplots(
        nrows=num_dims,
        ncols=1,
        figsize=(10, 3 * num_dims),
        squeeze=False,
    )

    fig.suptitle(
        f"Trajectory {result.trajectory_id} | MSE: {result.mse:.4f} | MAE: {result.mae:.4f}",
        fontsize=14,
    )

    for i in range(num_dims):
        ax = axes[i, 0]

        # Plot GT and predicted
        ax.plot(gt_actions[:, i], label="GT Action", alpha=0.8)
        ax.plot(pred_actions[:, i], label="Pred Action", alpha=0.8, linestyle="--")

        # Mark inference points
        for j in range(0, actual_steps, action_horizon):
            ax.axvline(x=j, color='gray', linestyle=':', alpha=0.3)
            if j == 0:
                ax.plot(j, gt_actions[j, i], "ro", markersize=4, label="Inference Point")
            else:
                ax.plot(j, gt_actions[j, i], "ro", markersize=4)

        # Labels
        dim_label = action_keys[i] if action_keys and i < len(action_keys) else f"Dim {i}"
        ax.set_title(f"Action {dim_label}")
        ax.set_xlabel("Step")
        ax.set_ylabel("Value")
        ax.legend(loc="upper right")
        ax.grid(True, alpha=0.3)

    plt.tight_layout()

    if save_path:
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)
        plt.savefig(save_path, dpi=150, bbox_inches="tight")
        logging.info(f"Plot saved to {save_path}")

    if show:
        plt.show()

    return fig


class OpenLoopEvaluator:
    """
    High-level evaluator for running open-loop evaluation.

    Usage:
        from vlalab.eval import OpenLoopEvaluator
        from vlalab.eval.adapters import GR00TAdapter

        # Create adapter
        adapter = GR00TAdapter(policy)

        # Create evaluator
        evaluator = OpenLoopEvaluator(
            policy=adapter,
            dataset_path="/path/to/dataset.zarr",
            dataset_format="zarr",
        )

        # Evaluate
        results = evaluator.evaluate(
            traj_ids=[0, 1, 2],
            max_steps=200,
            save_plots_dir="outputs/",
        )

        print(f"Average MSE: {results['avg_mse']:.4f}")
    """

    def __init__(
        self,
        policy: EvalPolicy,
        dataset_path: str,
        dataset_format: str = "zarr",
        task_description: Optional[str] = None,
    ):
        """
        Initialize evaluator.

        Args:
            policy: Policy adapter implementing EvalPolicy
            dataset_path: Path to dataset
            dataset_format: Dataset format ("zarr", "lerobot", etc.)
            task_description: Default task description for language-conditioned models
        """
        self.policy = policy
        self.dataset_path = dataset_path
        self.task_description = task_description

        # Load dataset
        if dataset_format == "zarr":
            self.dataset = ZarrDatasetLoader(dataset_path)
        else:
            raise ValueError(f"Unsupported dataset format: {dataset_format}")

        logging.info(f"Loaded dataset with {len(self.dataset)} trajectories")

    def evaluate(
        self,
        traj_ids: Optional[List[int]] = None,
        max_steps: int = 300,
        action_horizon: Optional[int] = None,
        save_plots_dir: Optional[str] = None,
        task_description: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Run evaluation on specified trajectories.

        Args:
            traj_ids: List of trajectory IDs to evaluate (default: [0])
            max_steps: Maximum steps per trajectory
            action_horizon: Action horizon (default: from policy config)
            save_plots_dir: Directory to save plots (optional)
            task_description: Override default task description

        Returns:
            Dict with results:
            - "results": List of EvalResult dicts
            - "avg_mse": Average MSE across trajectories
            - "avg_mae": Average MAE across trajectories
            - "num_trajectories": Number of trajectories evaluated
        """
        if traj_ids is None:
            traj_ids = [0]

        modality = self.policy.get_modality_config()
        config = EvalConfig(
            max_steps=max_steps,
            action_horizon=action_horizon or modality.action_horizon,
            task_description=task_description or self.task_description,
            save_plot_path=None,
        )

        results = []
        all_mse = []
        all_mae = []

        for traj_id in traj_ids:
            if traj_id >= len(self.dataset):
                logging.warning(f"Trajectory {traj_id} out of range, skipping")
                continue

            # Reset policy state
            self.policy.reset()

            # Evaluate
            result = evaluate_trajectory(
                self.policy,
                self.dataset,
                traj_id,
                config,
            )

            results.append(result)
            all_mse.append(result.mse)
            all_mae.append(result.mae)

            # Save plot
            if save_plots_dir:
                plot_path = Path(save_plots_dir) / f"traj_{traj_id}.png"
                plot_trajectory_results(
                    result,
                    action_keys=modality.action_keys,
                    action_horizon=config.action_horizon,
                    save_path=str(plot_path),
                )
                plt.close()

        # Aggregate results
        output = {
            "results": [r.to_dict() for r in results],
            "num_trajectories": len(results),
        }

        if all_mse:
            output["avg_mse"] = float(np.mean(all_mse))
            output["avg_mae"] = float(np.mean(all_mae))
            logging.info(f"Average MSE: {output['avg_mse']:.6f}")
            logging.info(f"Average MAE: {output['avg_mae']:.6f}")

        return output

    def evaluate_and_save(
        self,
        output_path: str,
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Run evaluation and save results to JSON.

        Args:
            output_path: Path to save results JSON
            **kwargs: Arguments passed to evaluate()

        Returns:
            Evaluation results dict
        """
        results = self.evaluate(**kwargs)

        # Save to JSON
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        with open(output_path, "w") as f:
            json.dump(results, f, indent=2)

        logging.info(f"Results saved to {output_path}")

        return results
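
The DatasetLoader base class above is the intended extension point for other dataset formats (its docstring names LeRobot and HDF5, and evaluate_trajectory depends only on the abstract interface). Below is a sketch of such a subclass under an assumed layout of one HDF5 file per episode, each holding an "action" array plus per-key state and image datasets; nothing in it is part of the released package:

import h5py  # assumed dependency, not declared by vlalab
import numpy as np
from pathlib import Path
from typing import Any, Dict, List, Optional

from vlalab.eval.open_loop_eval import DatasetLoader


class HDF5DatasetLoader(DatasetLoader):
    """Hypothetical loader: one .hdf5 file per episode in a directory."""

    def __init__(self, episodes_dir: str):
        # Sort so trajectory IDs are stable across runs.
        self.files = sorted(Path(episodes_dir).glob("*.hdf5"))

    def __len__(self) -> int:
        return len(self.files)

    def get_trajectory_length(self, traj_id: int) -> int:
        with h5py.File(self.files[traj_id], "r") as f:
            return int(f["action"].shape[0])

    def get_step_data(
        self,
        traj_id: int,
        step_idx: int,
        state_keys: List[str],
        image_keys: List[str],
    ) -> Dict[str, Any]:
        obs = {"state": {}, "images": {}}
        with h5py.File(self.files[traj_id], "r") as f:
            for key in state_keys:
                if key in f:
                    obs["state"][key] = f[key][step_idx]
            for key in image_keys:
                if key in f:
                    obs["images"][key] = f[key][step_idx]
        return obs

    def get_action(self, traj_id: int, step_idx: int, action_keys: List[str]) -> np.ndarray:
        with h5py.File(self.files[traj_id], "r") as f:
            return f["action"][step_idx]

    def get_trajectory_actions(
        self,
        traj_id: int,
        action_keys: List[str],
        max_steps: Optional[int] = None,
    ) -> np.ndarray:
        with h5py.File(self.files[traj_id], "r") as f:
            actions = f["action"][:]
        return actions if max_steps is None else actions[:max_steps]

Note that OpenLoopEvaluator.__init__ only constructs ZarrDatasetLoader, so a loader like this would currently be passed to evaluate_trajectory directly rather than selected via dataset_format.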
vlalab/eval/policy_interface.py
@@ -0,0 +1,155 @@
"""
VLA-Lab Unified Policy Interface

Defines a standard interface for VLA policies to enable model-agnostic evaluation.
Each model (GR00T, Diffusion Policy, OpenVLA, etc.) implements an adapter that
conforms to this interface.
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple
import numpy as np


@dataclass
class ModalityConfig:
    """
    Configuration describing the modalities a policy expects/produces.

    This enables the evaluator to understand what data to extract from
    datasets and how to format inputs/outputs.
    """
    # State modality keys (e.g., ["joint_position", "gripper_position"])
    state_keys: List[str] = field(default_factory=list)

    # Action modality keys (e.g., ["arm_action", "gripper_action"])
    action_keys: List[str] = field(default_factory=list)

    # Image/video modality keys (e.g., ["ego_view", "front_view"])
    image_keys: List[str] = field(default_factory=list)

    # Language modality keys (e.g., ["annotation.human.action.task_description"])
    language_keys: List[str] = field(default_factory=list)

    # Action horizon (number of future actions predicted)
    action_horizon: int = 16

    # Action dimension (total dim after concatenating all action keys)
    action_dim: Optional[int] = None

    # State dimension (total dim after concatenating all state keys)
    state_dim: Optional[int] = None

    def to_dict(self) -> Dict[str, Any]:
        return {
            "state_keys": self.state_keys,
            "action_keys": self.action_keys,
            "image_keys": self.image_keys,
            "language_keys": self.language_keys,
            "action_horizon": self.action_horizon,
            "action_dim": self.action_dim,
            "state_dim": self.state_dim,
        }

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> "ModalityConfig":
        return cls(
            state_keys=d.get("state_keys", []),
            action_keys=d.get("action_keys", []),
            image_keys=d.get("image_keys", []),
            language_keys=d.get("language_keys", []),
            action_horizon=d.get("action_horizon", 16),
            action_dim=d.get("action_dim"),
            state_dim=d.get("state_dim"),
        )


class EvalPolicy(ABC):
    """
    Abstract base class for VLA policies in evaluation mode.

    Each model implementation (GR00T, DP, OpenVLA) provides an adapter
    that wraps the actual policy and implements this interface.

    The interface is designed to be minimal and model-agnostic:
    - get_action(): Takes a standardized observation dict, returns action array
    - get_modality_config(): Returns what modalities the policy expects
    - reset(): Resets any internal state (for stateful policies)
    """

    @abstractmethod
    def get_action(
        self,
        obs: Dict[str, Any],
        task_description: Optional[str] = None,
    ) -> np.ndarray:
        """
        Get action from the policy given an observation.

        Args:
            obs: Standardized observation dictionary with keys:
                - "state": Dict[str, np.ndarray] - state vectors by key
                - "images": Dict[str, np.ndarray] - images by camera name (H, W, C)
            task_description: Optional language instruction

        Returns:
            Action array of shape (action_horizon, action_dim)
            The action_dim is the concatenation of all action modality keys.
        """
        pass

    @abstractmethod
    def get_modality_config(self) -> ModalityConfig:
        """
        Get the modality configuration for this policy.

        Returns:
            ModalityConfig describing expected inputs and outputs
        """
        pass

    def reset(self) -> None:
        """
        Reset any internal state.

        Override this for stateful policies (e.g., those with history buffers).
        """
        pass

    @property
    def action_horizon(self) -> int:
        """Convenience property for action horizon."""
        return self.get_modality_config().action_horizon


class DummyPolicy(EvalPolicy):
    """
    A dummy policy for testing that returns random actions.
    """

    def __init__(
        self,
        action_dim: int = 8,
        action_horizon: int = 16,
        state_keys: Optional[List[str]] = None,
        action_keys: Optional[List[str]] = None,
        image_keys: Optional[List[str]] = None,
    ):
        self._config = ModalityConfig(
            state_keys=state_keys or ["joint_position"],
            action_keys=action_keys or ["action"],
            image_keys=image_keys or ["front"],
            action_horizon=action_horizon,
            action_dim=action_dim,
        )

    def get_action(
        self,
        obs: Dict[str, Any],
        task_description: Optional[str] = None,
    ) -> np.ndarray:
        return np.random.randn(self._config.action_horizon, self._config.action_dim)

    def get_modality_config(self) -> ModalityConfig:
        return self._config
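
For completeness, a minimal concrete adapter sketch against the EvalPolicy/ModalityConfig contract above. The wrapped model and its predict(...) signature are hypothetical placeholders, not part of vlalab or any particular framework:

import numpy as np
from typing import Any, Dict, Optional

from vlalab.eval.policy_interface import EvalPolicy, ModalityConfig


class MyModelAdapter(EvalPolicy):
    """Hypothetical adapter around a model exposing predict(images, state, text)."""

    def __init__(self, model, action_horizon: int = 16, action_dim: int = 7):
        self.model = model
        self._config = ModalityConfig(
            state_keys=["joint_position"],
            action_keys=["action"],
            image_keys=["front"],
            action_horizon=action_horizon,
            action_dim=action_dim,
        )

    def get_action(
        self,
        obs: Dict[str, Any],
        task_description: Optional[str] = None,
    ) -> np.ndarray:
        # Flatten the standardized observation into whatever the wrapped model expects.
        state_parts = [
            np.asarray(obs["state"][k]).ravel()
            for k in self._config.state_keys
            if k in obs["state"]
        ]
        state = np.concatenate(state_parts) if state_parts else np.zeros(0)
        images = [obs["images"][k] for k in self._config.image_keys if k in obs["images"]]
        chunk = self.model.predict(images=images, state=state, text=task_description)
        return np.asarray(chunk)  # expected shape: (action_horizon, action_dim)

    def get_modality_config(self) -> ModalityConfig:
        return self._config

Because get_action returns a full (action_horizon, action_dim) chunk, evaluate_trajectory re-queries the adapter only every action_horizon steps rather than at every timestep.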