vlalab 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vlalab/__init__.py +8 -1
- vlalab/apps/streamlit/app.py +310 -37
- vlalab/apps/streamlit/pages/eval_viewer.py +374 -0
- vlalab/cli.py +1 -1
- vlalab/eval/__init__.py +15 -0
- vlalab/eval/adapters/__init__.py +14 -0
- vlalab/eval/adapters/dp_adapter.py +279 -0
- vlalab/eval/adapters/groot_adapter.py +253 -0
- vlalab/eval/open_loop_eval.py +542 -0
- vlalab/eval/policy_interface.py +155 -0
- {vlalab-0.1.0.dist-info → vlalab-0.1.1.dist-info}/METADATA +12 -70
- {vlalab-0.1.0.dist-info → vlalab-0.1.1.dist-info}/RECORD +16 -9
- {vlalab-0.1.0.dist-info → vlalab-0.1.1.dist-info}/WHEEL +0 -0
- {vlalab-0.1.0.dist-info → vlalab-0.1.1.dist-info}/entry_points.txt +0 -0
- {vlalab-0.1.0.dist-info → vlalab-0.1.1.dist-info}/licenses/LICENSE +0 -0
- {vlalab-0.1.0.dist-info → vlalab-0.1.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,253 @@
|
|
|
1
|
+
"""
|
|
2
|
+
GR00T Policy Adapter
|
|
3
|
+
|
|
4
|
+
Wraps NVIDIA GR00T policy to conform to the unified EvalPolicy interface.
|
|
5
|
+
This adapter handles the conversion between VLA-Lab's standardized observation
|
|
6
|
+
format and GR00T's specific input/output formats.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from typing import Any, Dict, List, Optional
|
|
10
|
+
import numpy as np
|
|
11
|
+
|
|
12
|
+
from vlalab.eval.policy_interface import EvalPolicy, ModalityConfig
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def _modality_keys(config: Any) -> List[str]:
    """Return ``modality_keys`` from an attribute-style or dict-style config.

    GR00T modality configs may be objects exposing ``modality_keys`` as an
    attribute, or plain dicts. A missing/``None`` config, or one without the
    field, yields an empty list.
    """
    if config is None:
        return []
    if hasattr(config, "modality_keys"):
        keys = config.modality_keys
    elif hasattr(config, "get"):
        keys = config.get("modality_keys", [])
    else:
        keys = []
    return list(keys or [])


def parse_observation_gr00t(
    obs: Dict[str, Any],
    modality_configs: Dict[str, Any],
) -> Dict[str, Any]:
    """
    Convert standardized observation to GR00T's expected format.

    GR00T expects observations in the format:
        {
            "video": {camera_name: array (1, T, H, W, C)},
            "state": {state_key: array (1, T, D)},
            "language": {lang_key: [[text]]}
        }

    Args:
        obs: Standardized observation dict with:
            - "state": Dict[str, array-like] - state vectors, (D,) or (T, D)
            - "images": Dict[str, array-like] - images, (H, W, C) or (T, H, W, C)
            - "task_description": Optional[str] - language instruction
        modality_configs: GR00T modality configuration; each of the
            "state"/"video"/"language" entries may be an object with a
            ``modality_keys`` attribute or a dict with that key.

    Returns:
        GR00T-formatted observation dict. Only keys listed in the modality
        config AND present in ``obs`` are included; a missing or ``None``
        task description becomes ``""``.
    """
    new_obs: Dict[str, Dict[str, Any]] = {
        "video": {},
        "state": {},
        "language": {},
    }

    # State: add batch (and, for single vectors, time) dimensions.
    states = obs.get("state") or {}
    for key in _modality_keys(modality_configs.get("state")):
        if key not in states:
            continue
        # asarray lets lists/scalars through instead of failing on `.ndim`.
        arr = np.asarray(states[key])
        if arr.ndim <= 1:
            arr = arr.reshape(1, 1, -1)  # (D,) -> (1, 1, D)
        elif arr.ndim == 2:
            arr = arr[None]  # (T, D) -> (1, T, D)
        new_obs["state"][key] = arr

    # Video: camera images are taken from obs["images"] under the video keys.
    images = obs.get("images") or {}
    for key in _modality_keys(modality_configs.get("video")):
        if key not in images:
            continue
        img = np.asarray(images[key])
        if img.ndim == 3:
            img = img[None, None]  # (H, W, C) -> (1, 1, H, W, C)
        elif img.ndim == 4:
            img = img[None]  # (T, H, W, C) -> (1, T, H, W, C)
        new_obs["video"][key] = img

    # Language: every language key receives the same instruction, nested
    # as [[text]] (batch of one, one entry).
    task_desc = obs.get("task_description") or ""
    for key in _modality_keys(modality_configs.get("language")):
        new_obs["language"][key] = [[task_desc]]

    return new_obs
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def parse_action_gr00t(
    action: Dict[str, Any],
    action_keys: List[str],
) -> np.ndarray:
    """
    Convert GR00T action output to standardized array format.

    GR00T outputs actions as a dict of per-part arrays, keyed either with an
    ``action.`` prefix (``{"action.arm": ...}``) or without one
    (``{"arm_action": array (T, D1), "gripper_action": array (T, D2), ...}``).
    This function concatenates the requested parts along the last axis.

    Args:
        action: GR00T action dict.
        action_keys: Action keys to concatenate, with or without the
            ``action.`` prefix. For each key the prefixed entry is preferred;
            the bare entry is used as a fallback. Keys absent from ``action``
            are skipped.

    Returns:
        Action array of shape (action_horizon, total_action_dim).

    Raises:
        ValueError: If none of ``action_keys`` is present in ``action``.
    """
    prefix = "action."
    action_parts = []
    for key in action_keys:
        bare_key = key[len(prefix):] if key.startswith(prefix) else key
        for candidate in (prefix + bare_key, bare_key):
            if candidate in action:
                arr = np.atleast_1d(np.asarray(action[candidate]))
                break
        else:
            continue  # key not produced by the policy

        if arr.ndim == 1:
            # (T,) -> (T, 1): treat a flat vector as T steps of a 1-D action.
            arr = arr[:, None]
        elif arr.ndim == 3 and arr.shape[0] == 1:
            # The observation is sent with a leading batch dim of 1, so the
            # policy may return (1, T, D); drop it to honour the (T, D) contract.
            arr = arr[0]
        action_parts.append(arr)

    if not action_parts:
        raise ValueError(f"No action keys found. Available: {list(action.keys())}")

    return np.concatenate(action_parts, axis=-1)
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
class GR00TAdapter(EvalPolicy):
    """
    Adapter for NVIDIA GR00T policy.

    Usage:
        from gr00t.policy.gr00t_policy import Gr00tPolicy
        from gr00t.policy.server_client import PolicyClient

        # Option 1: Wrap local policy
        gr00t_policy = Gr00tPolicy(embodiment_tag=..., model_path=...)
        adapter = GR00TAdapter(gr00t_policy)

        # Option 2: Wrap remote policy client
        client = PolicyClient(host="localhost", port=5555)
        adapter = GR00TAdapter(client)

        # Use with evaluator
        action = adapter.get_action(obs, task_description="pick up the cube")
    """

    # Horizon reported when the GR00T action config exposes no
    # ``delta_indices`` from which to derive it.
    DEFAULT_ACTION_HORIZON = 16

    def __init__(
        self,
        policy: Any,
        embodiment_tag: Optional[str] = None,
    ):
        """
        Initialize the GR00T adapter.

        Args:
            policy: GR00T policy instance (Gr00tPolicy or PolicyClient). Must
                provide ``get_action``. ``get_modality_config()`` or
                ``modality_configs`` is used to discover modality keys, and
                ``reset()`` is forwarded when present.
            embodiment_tag: Optional embodiment tag. It is stored on the
                adapter as-is; this adapter neither uses nor auto-detects it.
        """
        self.policy = policy
        self.embodiment_tag = embodiment_tag

        # Modality config is read once from the policy at construction time.
        self._raw_modality_config = self._get_raw_modality_config()
        self._modality_config = self._build_modality_config()

    @staticmethod
    def _cfg_field(cfg: Any, name: str, default: Any = None) -> Any:
        """Read ``name`` from an attribute-style or dict-style config entry.

        Returns ``default`` when ``cfg`` is ``None`` or exposes neither an
        attribute nor a ``get`` lookup for ``name``.
        """
        if cfg is None:
            return default
        if hasattr(cfg, name):
            return getattr(cfg, name)
        getter = getattr(cfg, "get", None)
        return getter(name, default) if callable(getter) else default

    def _get_raw_modality_config(self) -> Dict[str, Any]:
        """Return the underlying policy's per-modality config, or ``{}``.

        ``policy.get_modality_config()`` is preferred when it returns a dict
        (used as-is, without conversion). Otherwise the
        ``policy.modality_configs`` attribute is used if it is set.
        """
        if hasattr(self.policy, "get_modality_config"):
            config = self.policy.get_modality_config()
            if isinstance(config, dict):
                return config

        configs = getattr(self.policy, "modality_configs", None)
        if configs is not None:
            return configs

        # Fallback: no discoverable config.
        return {}

    def _build_modality_config(self) -> ModalityConfig:
        """Build a VLA-Lab ModalityConfig from the raw GR00T config."""
        raw = self._raw_modality_config

        def keys_for(modality: str) -> List[str]:
            cfg = raw[modality] if modality in raw else None
            return list(self._cfg_field(cfg, "modality_keys") or [])

        # GR00T encodes the predicted horizon as the action's delta_indices.
        action_horizon = self.DEFAULT_ACTION_HORIZON
        if "action" in raw:
            delta_indices = self._cfg_field(raw["action"], "delta_indices")
            if delta_indices is not None:
                action_horizon = len(delta_indices)

        return ModalityConfig(
            state_keys=keys_for("state"),
            action_keys=keys_for("action"),
            image_keys=keys_for("video"),
            language_keys=keys_for("language"),
            action_horizon=action_horizon,
        )

    def get_action(
        self,
        obs: Dict[str, Any],
        task_description: Optional[str] = None,
    ) -> np.ndarray:
        """
        Get action from GR00T policy.

        Args:
            obs: Standardized observation dict. It is shallow-copied and not
                mutated.
            task_description: Language instruction. When given (non-empty),
                it overrides ``obs["task_description"]``.

        Returns:
            Action array of shape (action_horizon, action_dim)
        """
        obs_with_lang = dict(obs)
        if task_description:
            obs_with_lang["task_description"] = task_description

        gr00t_obs = parse_observation_gr00t(obs_with_lang, self._raw_modality_config)

        # The GR00T policy/client returns an (action, info) pair. Also accept a
        # bare action dict rather than mis-unpacking its keys.
        result = self.policy.get_action(gr00t_obs)
        action_dict = result[0] if isinstance(result, (tuple, list)) else result

        return parse_action_gr00t(action_dict, self._modality_config.action_keys)

    def get_modality_config(self) -> ModalityConfig:
        """Get modality configuration."""
        return self._modality_config

    def reset(self) -> None:
        """Reset the underlying policy if it supports ``reset()``."""
        if hasattr(self.policy, "reset"):
            self.policy.reset()
|