vlalab 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,253 @@
1
+ """
2
+ GR00T Policy Adapter
3
+
4
+ Wraps NVIDIA GR00T policy to conform to the unified EvalPolicy interface.
5
+ This adapter handles the conversion between VLA-Lab's standardized observation
6
+ format and GR00T's specific input/output formats.
7
+ """
8
+
9
+ from typing import Any, Dict, List, Optional
10
+ import numpy as np
11
+
12
+ from vlalab.eval.policy_interface import EvalPolicy, ModalityConfig
13
+
14
+
15
def _gr00t_modality_keys(config: Any) -> List[str]:
    """Return the ``modality_keys`` listed by one modality config entry.

    The entry may be an object exposing a ``modality_keys`` attribute or a
    mapping with a ``"modality_keys"`` item. Anything else, or a missing/empty
    value, yields an empty list so callers can iterate unconditionally.
    """
    if hasattr(config, "modality_keys"):
        keys = config.modality_keys
    elif hasattr(config, "get"):
        keys = config.get("modality_keys")
    else:
        keys = None
    return list(keys) if keys else []


def _gr00t_add_leading_axes(value: Any, unbatched_ndim: int) -> Any:
    """Prepend the batch axis (and the time axis when absent) to ``value``.

    Args:
        value: Array-like observation entry. Objects without an ``ndim``
            attribute (lists, tuples, scalars) are converted with
            ``np.asarray`` first; array objects are indexed as they are.
        unbatched_ndim: Rank of a single time step (1 for a state vector,
            3 for an image).

    Returns:
        ``value`` with two axes prepended when its rank equals
        ``unbatched_ndim``, one axis prepended when its rank is one higher,
        and unchanged for any other rank.
    """
    if not hasattr(value, "ndim"):
        value = np.asarray(value)
    if value.ndim == unbatched_ndim:
        return value[None, None, ...]
    if value.ndim == unbatched_ndim + 1:
        return value[None, ...]
    return value


def parse_observation_gr00t(
    obs: Dict[str, Any],
    modality_configs: Dict[str, Any],
) -> Dict[str, Any]:
    """
    Convert standardized observation to GR00T's expected format.

    GR00T expects observations in the format:
    {
        "video": {camera_name: array (1, T, H, W, C)},
        "state": {state_key: array (1, T, D)},
        "language": {lang_key: [[text]]}
    }

    Only keys listed in ``modality_configs`` are copied; configured keys that
    are absent from ``obs`` are skipped silently.

    Args:
        obs: Standardized observation dict with:
            - "state": Dict[str, np.ndarray] - state vectors, (D,) or (T, D)
            - "images": Dict[str, np.ndarray] - images, (H, W, C) or (T, H, W, C)
            - "task_description": Optional[str] - language instruction
        modality_configs: GR00T modality configuration. Each of the "state",
            "video" and "language" entries is either an object with a
            ``modality_keys`` attribute or a mapping with a "modality_keys" item.

    Returns:
        GR00T-formatted observation dict
    """
    new_obs: Dict[str, Any] = {
        "video": {},
        "state": {},
        "language": {},
    }

    # Process state modalities: (D,) -> (1, 1, D) or (T, D) -> (1, T, D)
    states = obs.get("state") or {}
    for key in _gr00t_modality_keys(modality_configs.get("state", {})):
        if key in states:
            new_obs["state"][key] = _gr00t_add_leading_axes(states[key], 1)

    # Process video/image modalities:
    # (H, W, C) -> (1, 1, H, W, C) or (T, H, W, C) -> (1, T, H, W, C)
    images = obs.get("images") or {}
    for key in _gr00t_modality_keys(modality_configs.get("video", {})):
        if key in images:
            new_obs["video"][key] = _gr00t_add_leading_axes(images[key], 3)

    # Process language modalities. "task_description" is optional and may be
    # present but None; send an empty instruction rather than [[None]].
    task_desc = obs.get("task_description") or ""
    for key in _gr00t_modality_keys(modality_configs.get("language", {})):
        new_obs["language"][key] = [[task_desc]]

    return new_obs
82
+
83
+
84
def parse_action_gr00t(
    action: Dict[str, Any],
    action_keys: List[str],
) -> np.ndarray:
    """
    Convert GR00T action output to standardized array format.

    GR00T outputs actions as a dict of per-part arrays, e.g.
        {"arm_action": array (T, D1), "gripper_action": array (T, D2), ...}
    The dict keys may or may not carry an "action." prefix, and so may the
    entries of ``action_keys``; each requested key is matched against the
    prefixed spelling first and the bare spelling second.

    This function concatenates all matched action keys, in the order given by
    ``action_keys``, along the last axis. Requested keys that are absent from
    ``action`` are skipped.

    Args:
        action: GR00T action dict
        action_keys: List of action keys to concatenate

    Returns:
        Action array of shape (action_horizon, total_action_dim)

    Raises:
        ValueError: If none of the requested keys is present in ``action``.
    """
    prefix = "action."
    action_parts = []
    for key in action_keys:
        bare_key = key[len(prefix):] if key.startswith(prefix) else key
        # Prefixed spelling first so dicts containing both keep resolving the
        # same entry they always did.
        for candidate in (prefix + bare_key, bare_key):
            if candidate not in action:
                continue
            arr = np.atleast_1d(action[candidate])
            # Ensure at least 2D: a 1-D array is treated as T steps of a
            # single-dimensional action, i.e. (T,) -> (T, 1).
            if arr.ndim == 1:
                arr = arr[:, None]
            action_parts.append(arr)
            break

    if not action_parts:
        raise ValueError(f"No action keys found. Available: {list(action.keys())}")

    return np.concatenate(action_parts, axis=-1)
118
+
119
+
120
class GR00TAdapter(EvalPolicy):
    """
    Adapter for NVIDIA GR00T policy.

    Wraps a GR00T policy object and exposes it through the ``EvalPolicy``
    methods ``get_action``, ``get_modality_config`` and ``reset``. The
    modality configuration is read from the wrapped policy once, at
    construction time.

    Usage:
        from gr00t.policy.gr00t_policy import Gr00tPolicy
        from gr00t.policy.server_client import PolicyClient

        # Option 1: Wrap local policy
        gr00t_policy = Gr00tPolicy(embodiment_tag=..., model_path=...)
        adapter = GR00TAdapter(gr00t_policy)

        # Option 2: Wrap remote policy client
        client = PolicyClient(host="localhost", port=5555)
        adapter = GR00TAdapter(client)

        # Use with evaluator
        action = adapter.get_action(obs, task_description="pick up the cube")
    """

    # Horizon reported when the action config exposes no ``delta_indices``.
    _DEFAULT_ACTION_HORIZON = 16

    def __init__(
        self,
        policy: Any,
        embodiment_tag: Optional[str] = None,
    ):
        """
        Initialize the GR00T adapter.

        Args:
            policy: GR00T policy instance (Gr00tPolicy or PolicyClient)
            embodiment_tag: Optional embodiment tag. It is stored on the
                adapter; no method of this class reads it.
        """
        self.policy = policy
        self.embodiment_tag = embodiment_tag

        # Get modality config from policy (queried once and cached)
        self._raw_modality_config = self._get_raw_modality_config()
        self._modality_config = self._build_modality_config()

    @staticmethod
    def _config_field(cfg: Any, name: str, default: Any = None) -> Any:
        """Read ``name`` from a config entry that is an object or a mapping.

        Attribute access wins when the attribute exists; otherwise a mapping
        style ``get`` is used; otherwise ``default`` is returned.
        """
        if hasattr(cfg, name):
            return getattr(cfg, name)
        if hasattr(cfg, "get"):
            return cfg.get(name, default)
        return default

    def _get_raw_modality_config(self) -> Dict[str, Any]:
        """Get raw modality config from the underlying policy.

        Tries ``policy.get_modality_config()`` first and uses its result when
        it is a dict, then falls back to a ``policy.modality_configs``
        attribute, and finally to an empty dict.
        """
        if hasattr(self.policy, "get_modality_config"):
            config = self.policy.get_modality_config()
            # A dict is used as returned; no conversion is applied here.
            if isinstance(config, dict):
                return config

        configs = getattr(self.policy, "modality_configs", None)
        if configs is not None:
            return configs

        # Fallback: return empty config
        return {}

    def _build_modality_config(self) -> ModalityConfig:
        """Build VLA-Lab ModalityConfig from GR00T config.

        Key lists default to empty for modalities missing from the raw
        config. The action horizon is the length of the action entry's
        ``delta_indices`` when present, else ``_DEFAULT_ACTION_HORIZON``.
        """
        raw = self._raw_modality_config

        def keys_of(modality: str) -> List[str]:
            if modality not in raw:
                return []
            return self._config_field(raw[modality], "modality_keys", [])

        action_horizon = self._DEFAULT_ACTION_HORIZON
        if "action" in raw:
            # Get action horizon from delta_indices
            delta_indices = self._config_field(raw["action"], "delta_indices")
            if delta_indices is not None:
                action_horizon = len(delta_indices)

        return ModalityConfig(
            state_keys=keys_of("state"),
            action_keys=keys_of("action"),
            image_keys=keys_of("video"),
            language_keys=keys_of("language"),
            action_horizon=action_horizon,
        )

    def get_action(
        self,
        obs: Dict[str, Any],
        task_description: Optional[str] = None,
    ) -> np.ndarray:
        """
        Get action from GR00T policy.

        Args:
            obs: Standardized observation dict. It is shallow-copied, so the
                caller's dict is not modified.
            task_description: Language instruction. When non-empty it
                overrides any "task_description" already present in ``obs``.

        Returns:
            Action array of shape (action_horizon, action_dim)
        """
        # Add task description to obs
        obs_with_lang = dict(obs)
        if task_description:
            obs_with_lang["task_description"] = task_description

        # Convert to GR00T format
        gr00t_obs = parse_observation_gr00t(obs_with_lang, self._raw_modality_config)

        # Get action from policy. The wrapped policy may return either an
        # (action, info) pair or the bare action dict; a pair may arrive as a
        # list rather than a tuple. Only the action dict is used.
        result = self.policy.get_action(gr00t_obs)
        action_dict = result[0] if isinstance(result, (tuple, list)) else result

        # Convert action to array
        action_array = parse_action_gr00t(
            action_dict,
            self._modality_config.action_keys,
        )

        # The observation is sent with a batch axis of size 1; if the policy
        # echoes that axis back, drop it to honour the documented 2-D shape.
        if action_array.ndim == 3 and action_array.shape[0] == 1:
            action_array = action_array[0]

        return action_array

    def get_modality_config(self) -> ModalityConfig:
        """Get modality configuration."""
        return self._modality_config

    def reset(self) -> None:
        """Reset the policy (no-op when the wrapped policy has no ``reset``)."""
        if hasattr(self.policy, "reset"):
            self.policy.reset()