vlalab 0.1.0-py3-none-any.whl → 0.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,542 @@
+ """
+ VLA-Lab Open-Loop Evaluation
+
+ Model-agnostic open-loop evaluation for VLA policies.
+ Compares predicted actions against ground-truth actions from a dataset.
+ """
+
+ from dataclasses import dataclass, field
+ from pathlib import Path
+ from typing import Any, Dict, List, Optional, Tuple, Union
+ import logging
+ import json
+
+ import numpy as np
+ import matplotlib.pyplot as plt
+
+ from vlalab.eval.policy_interface import EvalPolicy, ModalityConfig
+
+
+ @dataclass
+ class EvalResult:
+     """Result of evaluating a single trajectory."""
+     trajectory_id: int
+     mse: float
+     mae: float
+     num_steps: int
+     gt_actions: np.ndarray  # (T, action_dim)
+     pred_actions: np.ndarray  # (T, action_dim)
+     states: Optional[np.ndarray] = None  # (T, state_dim)
+
+     def to_dict(self) -> Dict[str, Any]:
+         return {
+             "trajectory_id": self.trajectory_id,
+             "mse": float(self.mse),
+             "mae": float(self.mae),
+             "num_steps": self.num_steps,
+         }
+
+
+ @dataclass
+ class EvalConfig:
+     """Configuration for open-loop evaluation."""
+     max_steps: int = 300
+     action_horizon: int = 16
+     task_description: Optional[str] = None
+     save_plot_path: Optional[str] = None
+
+     def to_dict(self) -> Dict[str, Any]:
+         return {
+             "max_steps": self.max_steps,
+             "action_horizon": self.action_horizon,
+             "task_description": self.task_description,
+             "save_plot_path": self.save_plot_path,
+         }
+
+
+ class DatasetLoader:
+     """
+     Abstract interface for loading trajectory data from datasets.
+
+     Subclass this to support different dataset formats (Zarr, LeRobot, HDF5, etc.)
+     """
+
+     def __len__(self) -> int:
+         """Return number of trajectories."""
+         raise NotImplementedError
+
+     def get_trajectory_length(self, traj_id: int) -> int:
+         """Return length of a specific trajectory."""
+         raise NotImplementedError
+
+     def get_step_data(
+         self,
+         traj_id: int,
+         step_idx: int,
+         state_keys: List[str],
+         image_keys: List[str],
+     ) -> Dict[str, Any]:
+         """
+         Get observation data for a specific step.
+
+         Returns:
+             Dict with:
+                 - "state": Dict[str, np.ndarray] - state vectors by key
+                 - "images": Dict[str, np.ndarray] - images by camera name
+         """
+         raise NotImplementedError
+
+     def get_action(
+         self,
+         traj_id: int,
+         step_idx: int,
+         action_keys: List[str],
+     ) -> np.ndarray:
+         """
+         Get ground-truth action for a specific step.
+
+         Returns:
+             Action array of shape (action_dim,)
+         """
+         raise NotImplementedError
+
+     def get_trajectory_actions(
+         self,
+         traj_id: int,
+         action_keys: List[str],
+         max_steps: Optional[int] = None,
+     ) -> np.ndarray:
+         """
+         Get all ground-truth actions for a trajectory.
+
+         Returns:
+             Action array of shape (T, action_dim)
+         """
+         raise NotImplementedError
+
+
+ class ZarrDatasetLoader(DatasetLoader):
+     """
+     Loader for Zarr format datasets (used by Diffusion Policy, etc.)
+
+     Expected structure:
+         dataset.zarr/
+         ├── data/
+         │   ├── action (T, action_dim)
+         │   ├── state (T, state_dim)  # or robot_state, etc.
+         │   └── image_* (T, H, W, C)
+         └── meta/
+             └── episode_ends (num_episodes,)
+     """
+
+     def __init__(self, zarr_path: str):
+         import zarr
+         self.zarr_path = Path(zarr_path)
+         self.root = zarr.open(str(zarr_path), mode='r')
+         self.data = self.root['data']
+         self.meta = self.root['meta']
+         self.episode_ends = self.meta['episode_ends'][:]
+
+     def __len__(self) -> int:
+         return len(self.episode_ends)
+
+     def _get_episode_slice(self, traj_id: int) -> Tuple[int, int]:
+         start = 0 if traj_id == 0 else int(self.episode_ends[traj_id - 1])
+         end = int(self.episode_ends[traj_id])
+         return start, end
+
+     def get_trajectory_length(self, traj_id: int) -> int:
+         start, end = self._get_episode_slice(traj_id)
+         return end - start
+
+     def get_step_data(
+         self,
+         traj_id: int,
+         step_idx: int,
+         state_keys: List[str],
+         image_keys: List[str],
+     ) -> Dict[str, Any]:
+         start, end = self._get_episode_slice(traj_id)
+         global_idx = start + step_idx
+
+         obs = {"state": {}, "images": {}}
+
+         # Load state data
+         for key in state_keys:
+             if key in self.data:
+                 obs["state"][key] = self.data[key][global_idx]
+
+         # Load image data
+         for key in image_keys:
+             # Try different naming conventions
+             for img_key in [key, f"image_{key}", f"img_{key}"]:
+                 if img_key in self.data:
+                     img = self.data[img_key][global_idx]
+                     # Handle (C, H, W) -> (H, W, C)
+                     if img.ndim == 3 and img.shape[0] in [1, 3]:
+                         img = np.transpose(img, (1, 2, 0))
+                     obs["images"][key] = img
+                     break
+
+         return obs
+
+     def get_action(
+         self,
+         traj_id: int,
+         step_idx: int,
+         action_keys: List[str],
+     ) -> np.ndarray:
+         start, _ = self._get_episode_slice(traj_id)
+         global_idx = start + step_idx
+
+         # For Zarr, action is typically stored as single array
+         return self.data['action'][global_idx]
+
+     def get_trajectory_actions(
+         self,
+         traj_id: int,
+         action_keys: List[str],
+         max_steps: Optional[int] = None,
+     ) -> np.ndarray:
+         start, end = self._get_episode_slice(traj_id)
+         if max_steps:
+             end = min(end, start + max_steps)
+         return self.data['action'][start:end]
+
+
+ def evaluate_trajectory(
+     policy: EvalPolicy,
+     dataset: DatasetLoader,
+     traj_id: int,
+     config: EvalConfig,
+ ) -> EvalResult:
+     """
+     Evaluate a policy on a single trajectory.
+
+     Args:
+         policy: Policy adapter implementing EvalPolicy
+         dataset: Dataset loader
+         traj_id: Trajectory ID to evaluate
+         config: Evaluation configuration
+
+     Returns:
+         EvalResult with metrics and action arrays
+     """
+     modality = policy.get_modality_config()
+     action_horizon = config.action_horizon or modality.action_horizon
+
+     # Get trajectory length
+     traj_length = dataset.get_trajectory_length(traj_id)
+     actual_steps = min(config.max_steps, traj_length)
+
+     logging.info(f"Evaluating trajectory {traj_id}: {actual_steps} steps")
+
+     # Collect predicted actions
+     pred_actions_list = []
+
+     for step_idx in range(0, actual_steps, action_horizon):
+         logging.debug(f"Inferencing at step {step_idx}")
+
+         # Get observation
+         obs = dataset.get_step_data(
+             traj_id,
+             step_idx,
+             modality.state_keys,
+             modality.image_keys,
+         )
+
+         # Get action from policy
+         action_chunk = policy.get_action(obs, config.task_description)
+
+         # Collect actions from chunk
+         for j in range(action_horizon):
+             if step_idx + j >= actual_steps:
+                 break
+             if j < len(action_chunk):
+                 pred_actions_list.append(action_chunk[j])
+
+     # Get ground truth actions
+     gt_actions = dataset.get_trajectory_actions(
+         traj_id,
+         modality.action_keys,
+         max_steps=actual_steps,
+     )
+
+     pred_actions = np.array(pred_actions_list)[:actual_steps]
+     gt_actions = gt_actions[:actual_steps]
+
+     # Ensure shapes match
+     min_len = min(len(pred_actions), len(gt_actions))
+     pred_actions = pred_actions[:min_len]
+     gt_actions = gt_actions[:min_len]
+
+     # Handle dimension mismatch
+     if pred_actions.shape[-1] != gt_actions.shape[-1]:
+         # Take minimum dimension (common case: pred has extra dims)
+         min_dim = min(pred_actions.shape[-1], gt_actions.shape[-1])
+         logging.warning(
+             f"Dimension mismatch: pred {pred_actions.shape[-1]}, gt {gt_actions.shape[-1]}. "
+             f"Using first {min_dim} dims."
+         )
+         pred_actions = pred_actions[..., :min_dim]
+         gt_actions = gt_actions[..., :min_dim]
+
+     # Calculate metrics
+     mse = float(np.mean((gt_actions - pred_actions) ** 2))
+     mae = float(np.mean(np.abs(gt_actions - pred_actions)))
+
+     logging.info(f"Trajectory {traj_id}: MSE={mse:.6f}, MAE={mae:.6f}")
+
+     return EvalResult(
+         trajectory_id=traj_id,
+         mse=mse,
+         mae=mae,
+         num_steps=min_len,
+         gt_actions=gt_actions,
+         pred_actions=pred_actions,
+     )
+
+
+ def plot_trajectory_results(
+     result: EvalResult,
+     action_keys: Optional[List[str]] = None,
+     action_horizon: int = 16,
+     save_path: Optional[str] = None,
+     show: bool = False,
+ ) -> plt.Figure:
+     """
+     Plot evaluation results comparing GT vs predicted actions.
+
+     Args:
+         result: EvalResult from evaluate_trajectory
+         action_keys: Optional labels for action dimensions
+         action_horizon: Action horizon for marking inference points
+         save_path: Path to save plot (optional)
+         show: Whether to display plot
+
+     Returns:
+         Matplotlib figure
+     """
+     gt_actions = result.gt_actions
+     pred_actions = result.pred_actions
+
+     num_dims = gt_actions.shape[1]
+     actual_steps = len(gt_actions)
+
+     # Create figure
+     fig, axes = plt.subplots(
+         nrows=num_dims,
+         ncols=1,
+         figsize=(10, 3 * num_dims),
+         squeeze=False,
+     )
+
+     fig.suptitle(
+         f"Trajectory {result.trajectory_id} | MSE: {result.mse:.4f} | MAE: {result.mae:.4f}",
+         fontsize=14,
+     )
+
+     for i in range(num_dims):
+         ax = axes[i, 0]
+
+         # Plot GT and predicted
+         ax.plot(gt_actions[:, i], label="GT Action", alpha=0.8)
+         ax.plot(pred_actions[:, i], label="Pred Action", alpha=0.8, linestyle="--")
+
+         # Mark inference points
+         for j in range(0, actual_steps, action_horizon):
+             ax.axvline(x=j, color='gray', linestyle=':', alpha=0.3)
+             if j == 0:
+                 ax.plot(j, gt_actions[j, i], "ro", markersize=4, label="Inference Point")
+             else:
+                 ax.plot(j, gt_actions[j, i], "ro", markersize=4)
+
+         # Labels
+         dim_label = action_keys[i] if action_keys and i < len(action_keys) else f"Dim {i}"
+         ax.set_title(f"Action {dim_label}")
+         ax.set_xlabel("Step")
+         ax.set_ylabel("Value")
+         ax.legend(loc="upper right")
+         ax.grid(True, alpha=0.3)
+
+     plt.tight_layout()
+
+     if save_path:
+         Path(save_path).parent.mkdir(parents=True, exist_ok=True)
+         plt.savefig(save_path, dpi=150, bbox_inches="tight")
+         logging.info(f"Plot saved to {save_path}")
+
+     if show:
+         plt.show()
+
+     return fig
+
+
+ class OpenLoopEvaluator:
+     """
+     High-level evaluator for running open-loop evaluation.
+
+     Usage:
+         from vlalab.eval import OpenLoopEvaluator
+         from vlalab.eval.adapters import GR00TAdapter
+
+         # Create adapter
+         adapter = GR00TAdapter(policy)
+
+         # Create evaluator
+         evaluator = OpenLoopEvaluator(
+             policy=adapter,
+             dataset_path="/path/to/dataset.zarr",
+             dataset_format="zarr",
+         )
+
+         # Evaluate
+         results = evaluator.evaluate(
+             traj_ids=[0, 1, 2],
+             max_steps=200,
+             save_plots_dir="outputs/",
+         )
+
+         print(f"Average MSE: {results['avg_mse']:.4f}")
+     """
+
+     def __init__(
+         self,
+         policy: EvalPolicy,
+         dataset_path: str,
+         dataset_format: str = "zarr",
+         task_description: Optional[str] = None,
+     ):
+         """
+         Initialize evaluator.
+
+         Args:
+             policy: Policy adapter implementing EvalPolicy
+             dataset_path: Path to dataset
+             dataset_format: Dataset format ("zarr", "lerobot", etc.)
+             task_description: Default task description for language-conditioned models
+         """
+         self.policy = policy
+         self.dataset_path = dataset_path
+         self.task_description = task_description
+
+         # Load dataset
+         if dataset_format == "zarr":
+             self.dataset = ZarrDatasetLoader(dataset_path)
+         else:
+             raise ValueError(f"Unsupported dataset format: {dataset_format}")
+
+         logging.info(f"Loaded dataset with {len(self.dataset)} trajectories")
+
+     def evaluate(
+         self,
+         traj_ids: Optional[List[int]] = None,
+         max_steps: int = 300,
+         action_horizon: Optional[int] = None,
+         save_plots_dir: Optional[str] = None,
+         task_description: Optional[str] = None,
+     ) -> Dict[str, Any]:
+         """
+         Run evaluation on specified trajectories.
+
+         Args:
+             traj_ids: List of trajectory IDs to evaluate (default: [0])
+             max_steps: Maximum steps per trajectory
+             action_horizon: Action horizon (default: from policy config)
+             save_plots_dir: Directory to save plots (optional)
+             task_description: Override default task description
+
+         Returns:
+             Dict with results:
+                 - "results": List of EvalResult dicts
+                 - "avg_mse": Average MSE across trajectories
+                 - "avg_mae": Average MAE across trajectories
+                 - "num_trajectories": Number of trajectories evaluated
+         """
+         if traj_ids is None:
+             traj_ids = [0]
+
+         modality = self.policy.get_modality_config()
+         config = EvalConfig(
+             max_steps=max_steps,
+             action_horizon=action_horizon or modality.action_horizon,
+             task_description=task_description or self.task_description,
+             save_plot_path=None,
+         )
+
+         results = []
+         all_mse = []
+         all_mae = []
+
+         for traj_id in traj_ids:
+             if traj_id >= len(self.dataset):
+                 logging.warning(f"Trajectory {traj_id} out of range, skipping")
+                 continue
+
+             # Reset policy state
+             self.policy.reset()
+
+             # Evaluate
+             result = evaluate_trajectory(
+                 self.policy,
+                 self.dataset,
+                 traj_id,
+                 config,
+             )
+
+             results.append(result)
+             all_mse.append(result.mse)
+             all_mae.append(result.mae)
+
+             # Save plot
+             if save_plots_dir:
+                 plot_path = Path(save_plots_dir) / f"traj_{traj_id}.png"
+                 plot_trajectory_results(
+                     result,
+                     action_keys=modality.action_keys,
+                     action_horizon=config.action_horizon,
+                     save_path=str(plot_path),
+                 )
+                 plt.close()
+
+         # Aggregate results
+         output = {
+             "results": [r.to_dict() for r in results],
+             "num_trajectories": len(results),
+         }
+
+         if all_mse:
+             output["avg_mse"] = float(np.mean(all_mse))
+             output["avg_mae"] = float(np.mean(all_mae))
+             logging.info(f"Average MSE: {output['avg_mse']:.6f}")
+             logging.info(f"Average MAE: {output['avg_mae']:.6f}")
+
+         return output
+
+     def evaluate_and_save(
+         self,
+         output_path: str,
+         **kwargs,
+     ) -> Dict[str, Any]:
+         """
+         Run evaluation and save results to JSON.
+
+         Args:
+             output_path: Path to save results JSON
+             **kwargs: Arguments passed to evaluate()
+
+         Returns:
+             Evaluation results dict
+         """
+         results = self.evaluate(**kwargs)
+
+         # Save to JSON
+         output_path = Path(output_path)
+         output_path.parent.mkdir(parents=True, exist_ok=True)
+
+         with open(output_path, "w") as f:
+             json.dump(results, f, indent=2)
+
+         logging.info(f"Results saved to {output_path}")
+
+         return results
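The DatasetLoader class above is deliberately thin: anything that can report trajectory lengths and return per-step observations and actions can back the evaluator. As a rough illustration of that contract (not part of the package), the sketch below serves a single trajectory from pre-loaded NumPy arrays; the class name InMemoryDatasetLoader and its constructor arguments are hypothetical.

import numpy as np
from typing import Any, Dict, List, Optional

# Illustrative only: assumes the DatasetLoader base class from the file above is in scope.
class InMemoryDatasetLoader(DatasetLoader):
    """Single-trajectory loader backed by pre-loaded NumPy arrays (hypothetical)."""

    def __init__(self, states: np.ndarray, images: np.ndarray, actions: np.ndarray):
        self.states = states    # (T, state_dim)
        self.images = images    # (T, H, W, C)
        self.actions = actions  # (T, action_dim)

    def __len__(self) -> int:
        return 1  # one trajectory

    def get_trajectory_length(self, traj_id: int) -> int:
        return len(self.actions)

    def get_step_data(
        self, traj_id: int, step_idx: int, state_keys: List[str], image_keys: List[str]
    ) -> Dict[str, Any]:
        # A real loader would index a separate array per key; here every requested
        # key is served from the same array, which is enough for a smoke test.
        return {
            "state": {key: self.states[step_idx] for key in state_keys},
            "images": {key: self.images[step_idx] for key in image_keys},
        }

    def get_action(self, traj_id: int, step_idx: int, action_keys: List[str]) -> np.ndarray:
        return self.actions[step_idx]

    def get_trajectory_actions(
        self, traj_id: int, action_keys: List[str], max_steps: Optional[int] = None
    ) -> np.ndarray:
        end = len(self.actions) if max_steps is None else min(max_steps, len(self.actions))
        return self.actions[:end]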
@@ -0,0 +1,155 @@
+ """
+ VLA-Lab Unified Policy Interface
+
+ Defines a standard interface for VLA policies to enable model-agnostic evaluation.
+ Each model (GR00T, Diffusion Policy, OpenVLA, etc.) implements an adapter that
+ conforms to this interface.
+ """
+
+ from abc import ABC, abstractmethod
+ from dataclasses import dataclass, field
+ from typing import Any, Dict, List, Optional, Tuple
+ import numpy as np
+
+
+ @dataclass
+ class ModalityConfig:
+     """
+     Configuration describing the modalities a policy expects/produces.
+
+     This enables the evaluator to understand what data to extract from
+     datasets and how to format inputs/outputs.
+     """
+     # State modality keys (e.g., ["joint_position", "gripper_position"])
+     state_keys: List[str] = field(default_factory=list)
+
+     # Action modality keys (e.g., ["arm_action", "gripper_action"])
+     action_keys: List[str] = field(default_factory=list)
+
+     # Image/video modality keys (e.g., ["ego_view", "front_view"])
+     image_keys: List[str] = field(default_factory=list)
+
+     # Language modality keys (e.g., ["annotation.human.action.task_description"])
+     language_keys: List[str] = field(default_factory=list)
+
+     # Action horizon (number of future actions predicted)
+     action_horizon: int = 16
+
+     # Action dimension (total dim after concatenating all action keys)
+     action_dim: Optional[int] = None
+
+     # State dimension (total dim after concatenating all state keys)
+     state_dim: Optional[int] = None
+
+     def to_dict(self) -> Dict[str, Any]:
+         return {
+             "state_keys": self.state_keys,
+             "action_keys": self.action_keys,
+             "image_keys": self.image_keys,
+             "language_keys": self.language_keys,
+             "action_horizon": self.action_horizon,
+             "action_dim": self.action_dim,
+             "state_dim": self.state_dim,
+         }
+
+     @classmethod
+     def from_dict(cls, d: Dict[str, Any]) -> "ModalityConfig":
+         return cls(
+             state_keys=d.get("state_keys", []),
+             action_keys=d.get("action_keys", []),
+             image_keys=d.get("image_keys", []),
+             language_keys=d.get("language_keys", []),
+             action_horizon=d.get("action_horizon", 16),
+             action_dim=d.get("action_dim"),
+             state_dim=d.get("state_dim"),
+         )
+
+
+ class EvalPolicy(ABC):
+     """
+     Abstract base class for VLA policies in evaluation mode.
+
+     Each model implementation (GR00T, DP, OpenVLA) provides an adapter
+     that wraps the actual policy and implements this interface.
+
+     The interface is designed to be minimal and model-agnostic:
+     - get_action(): Takes a standardized observation dict, returns action array
+     - get_modality_config(): Returns what modalities the policy expects
+     - reset(): Resets any internal state (for stateful policies)
+     """
+
+     @abstractmethod
+     def get_action(
+         self,
+         obs: Dict[str, Any],
+         task_description: Optional[str] = None,
+     ) -> np.ndarray:
+         """
+         Get action from the policy given an observation.
+
+         Args:
+             obs: Standardized observation dictionary with keys:
+                 - "state": Dict[str, np.ndarray] - state vectors by key
+                 - "images": Dict[str, np.ndarray] - images by camera name (H, W, C)
+             task_description: Optional language instruction
+
+         Returns:
+             Action array of shape (action_horizon, action_dim)
+             The action_dim is the concatenation of all action modality keys.
+         """
+         pass
+
+     @abstractmethod
+     def get_modality_config(self) -> ModalityConfig:
+         """
+         Get the modality configuration for this policy.
+
+         Returns:
+             ModalityConfig describing expected inputs and outputs
+         """
+         pass
+
+     def reset(self) -> None:
+         """
+         Reset any internal state.
+
+         Override this for stateful policies (e.g., those with history buffers).
+         """
+         pass
+
+     @property
+     def action_horizon(self) -> int:
+         """Convenience property for action horizon."""
+         return self.get_modality_config().action_horizon
+
+
+ class DummyPolicy(EvalPolicy):
+     """
+     A dummy policy for testing that returns random actions.
+     """
+
+     def __init__(
+         self,
+         action_dim: int = 8,
+         action_horizon: int = 16,
+         state_keys: Optional[List[str]] = None,
+         action_keys: Optional[List[str]] = None,
+         image_keys: Optional[List[str]] = None,
+     ):
+         self._config = ModalityConfig(
+             state_keys=state_keys or ["joint_position"],
+             action_keys=action_keys or ["action"],
+             image_keys=image_keys or ["front"],
+             action_horizon=action_horizon,
+             action_dim=action_dim,
+         )
+
+     def get_action(
+         self,
+         obs: Dict[str, Any],
+         task_description: Optional[str] = None,
+     ) -> np.ndarray:
+         return np.random.randn(self._config.action_horizon, self._config.action_dim)
+
+     def get_modality_config(self) -> ModalityConfig:
+         return self._config
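Taken together, the two files compose into a small open-loop harness: a policy adapter reports its ModalityConfig, the evaluator pulls matching observations from the dataset, and evaluate_trajectory scores the chunked predictions against ground-truth actions. The snippet below is a minimal smoke-test sketch wiring the bundled DummyPolicy into that loop; it reuses the hypothetical InMemoryDatasetLoader sketched earlier, and the import of the evaluation module is left as a placeholder because this diff does not show file paths.

import numpy as np

from vlalab.eval.policy_interface import DummyPolicy
# Placeholder import: the evaluation module's path is not visible in this diff.
# from vlalab.eval.<evaluation_module> import EvalConfig, evaluate_trajectory, plot_trajectory_results

def smoke_test() -> None:
    T, state_dim, action_dim = 64, 7, 8

    # Random stand-in data; a real run would point at a recorded dataset.
    loader = InMemoryDatasetLoader(
        states=np.random.randn(T, state_dim),
        images=np.random.randint(0, 255, size=(T, 96, 96, 3), dtype=np.uint8),
        actions=np.random.randn(T, action_dim),
    )

    policy = DummyPolicy(action_dim=action_dim, action_horizon=16)
    config = EvalConfig(max_steps=T, action_horizon=policy.action_horizon)

    # Open-loop evaluation of one trajectory: chunked predictions vs. ground truth.
    result = evaluate_trajectory(policy, loader, traj_id=0, config=config)
    print(f"MSE={result.mse:.4f}  MAE={result.mae:.4f}  steps={result.num_steps}")

    # Optionally dump the per-dimension GT-vs-predicted plot.
    plot_trajectory_results(result, action_horizon=config.action_horizon, save_path="outputs/traj_0.png")

Since DummyPolicy returns random actions, the reported MSE simply reflects the spread of the stand-in data; the point of the sketch is that any adapter implementing EvalPolicy slots into the same loop without changes to the evaluator.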