opentau 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
opentau/__init__.py CHANGED
@@ -56,6 +56,7 @@ When implementing a new policy class (e.g., `DiffusionPolicy`), follow these steps
  import itertools

  from opentau.__version__ import __version__  # noqa: F401
+ from opentau.utils import transformers_patch  # noqa: F401

  # TODO(rcadene): Improve policies and envs. As of now, an item in `available_policies`
  # refers to a yaml file AND a modeling name. Same for `available_envs` which refers to
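The new import above pulls in `opentau.utils.transformers_patch` purely for its side effects (hence the `# noqa: F401`): the module rewrites parts of `transformers` at import time. A minimal sketch of what that means for downstream code, assuming this release and its pinned `transformers==4.53.3` are installed:

```
# Sketch only: importing opentau applies the Gemma patches defined in
# opentau/utils/transformers_patch.py (shown later in this diff).
import opentau  # noqa: F401  # side effect: GemmaConfig gains the ADARMS kwargs

from transformers.models.gemma.configuration_gemma import GemmaConfig

cfg = GemmaConfig(use_adarms=True)             # accepted only after the patch
print(cfg.adarms_cond_dim == cfg.hidden_size)  # defaults to hidden_size -> True
```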
opentau/datasets/lerobot_dataset.py CHANGED
@@ -633,7 +633,9 @@ class BaseDataset(torch.utils.data.Dataset):
          For example, {"image_key": torch.zeros(2, 3, 224, 224), "image_key_is_pad": [False, True]} will become
          {
              "image_key": torch.zeros(3, 224, 224),
+             "image_key_local": torch.zeros(3, 224, 224),
              "image_key_is_pad": False,
+             "image_key_local_is_pad": True,
          }.
          """
          raise NotImplementedError
@@ -1787,16 +1789,12 @@ class LeRobotDataset(BaseDataset):
          cam_keys = {v for k, v in name_map.items() if k.startswith("camera")}
          for k in cam_keys:
              images = item.pop(k)
-             assert len(images) == 2, (
-                 f"{k} in {self.__class__} is expected to have length 2, got shape={images.shape}"
-             )
-             item[k + "_local"], item[k] = images
+             if len(images) == 2:
+                 item[k + "_local"], item[k] = images

-             pads = item.pop(k + "_is_pad")
-             assert len(pads) == 2, (
-                 f"{k} in {self.__class__} is expected to have length 2, got shape={pads.shape}"
-             )
-             item[k + "_local_is_pad"], item[k + "_is_pad"] = pads
+             pads = item.get(k + "_is_pad")
+             if hasattr(pads, "__len__") and len(pads) == 2:
+                 item[k + "_local_is_pad"], item[k + "_is_pad"] = pads

      @staticmethod
      def compute_delta_params(
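For reference, a self-contained sketch of what the relaxed unpacking above does to a batch item (the `image_key` name is just an illustration, mirroring the `BaseDataset` docstring earlier in this diff): the first frame of a length-2 stack populates the `*_local` keys, and items without a stacked pair are now left untouched instead of tripping an assertion.

```
import torch

# Illustrative item with a stacked [local, global] camera pair.
item = {
    "image_key": torch.zeros(2, 3, 224, 224),
    "image_key_is_pad": torch.tensor([False, True]),
}

k = "image_key"
images = item.pop(k)
if len(images) == 2:                      # only unpack a stacked pair
    item[k + "_local"], item[k] = images  # first entry becomes the local view

pads = item.get(k + "_is_pad")
if hasattr(pads, "__len__") and len(pads) == 2:
    item[k + "_local_is_pad"], item[k + "_is_pad"] = pads

print(sorted(item))
# ['image_key', 'image_key_is_pad', 'image_key_local', 'image_key_local_is_pad']
```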
opentau/scripts/launch.py ADDED
@@ -0,0 +1,84 @@
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import argparse
+ import subprocess
+ import sys
+ from pathlib import Path
+ from types import ModuleType
+
+ import opentau.scripts.eval as eval_script
+ import opentau.scripts.export_to_onnx as export_script
+ import opentau.scripts.train as train_script
+ import opentau.scripts.visualize_dataset as visualize_script
+
+
+ def launch(script_module: ModuleType, description: str, use_accelerate: bool = True):
+     """Generic launcher for OpenTau scripts using Accelerate or Python."""
+     parser = argparse.ArgumentParser(
+         description=description,
+         usage=f"{Path(sys.argv[0]).name} {'[--accelerate-config CONFIG] ' if use_accelerate else ''}[ARGS]",
+     )
+     if use_accelerate:
+         parser.add_argument(
+             "--accelerate-config", type=str, help="Path to accelerate config file (yaml)", default=None
+         )
+
+     # We use parse_known_args so that all other arguments are collected
+     # These will be passed to the target script
+     args, unknown_args = parser.parse_known_args()
+
+     # Base command
+     if use_accelerate:
+         cmd = ["accelerate", "launch"]
+         # Add accelerate config if provided
+         if args.accelerate_config:
+             cmd.extend(["--config_file", args.accelerate_config])
+     else:
+         cmd = [sys.executable]
+
+     # Add the path to the target script
+     # We resolve the path to ensure it's absolute
+     script_path = Path(script_module.__file__).resolve()
+     cmd.append(str(script_path))
+
+     # Add all other arguments (passed to the target script)
+     cmd.extend(unknown_args)
+
+     # Print the command for transparency
+     print(f"Executing: {' '.join(cmd)}")
+
+     # Run the target script as a subprocess and mirror its exit code
+     try:
+         subprocess.run(cmd, check=True)
+     except subprocess.CalledProcessError as e:
+         sys.exit(e.returncode)
+     except KeyboardInterrupt:
+         sys.exit(130)
+
+
+ def train():
+     launch(train_script, "Launch OpenTau training with Accelerate")
+
+
+ def eval():
+     launch(eval_script, "Launch OpenTau evaluation with Accelerate")
+
+
+ def export():
+     launch(export_script, "Launch OpenTau ONNX export", use_accelerate=False)
+
+
+ def visualize():
+     launch(visualize_script, "Launch OpenTau visualization", use_accelerate=False)
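The launcher builds a plain `accelerate launch` (or `python`) command line and forwards every argument it does not recognise to the target script. As a rough illustration (the `--steps 100` flag here is only a stand-in for whatever the target script accepts), `opentau-train --accelerate-config acc.yaml --steps 100` expands to something like:

```
from pathlib import Path

import opentau.scripts.train as train_script

# Hypothetical invocation: opentau-train --accelerate-config acc.yaml --steps 100
unknown_args = ["--steps", "100"]  # everything launch() leaves for the target script
cmd = [
    "accelerate", "launch",
    "--config_file", "acc.yaml",
    str(Path(train_script.__file__).resolve()),
    *unknown_args,
]
print(" ".join(cmd))
# -> accelerate launch --config_file acc.yaml /.../opentau/scripts/train.py --steps 100
```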
opentau/scripts/train.py CHANGED
@@ -73,16 +73,16 @@ def update_policy(
          train_config.loss_weighting["MSE"] * losses["MSE"] + train_config.loss_weighting["CE"] * losses["CE"]
      )

-     # accelerator.backward(loss)
-     # accelerator.unscale_gradients(optimizer=optimizer)
+     accelerator.backward(loss)
+     accelerator.unscale_gradients(optimizer=optimizer)

-     # if accelerator.sync_gradients:
-     #     grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm)
-     #     if accelerator.is_main_process:
-     #         train_metrics.grad_norm = grad_norm
+     if accelerator.sync_gradients:
+         grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm)
+         if accelerator.is_main_process:
+             train_metrics.grad_norm = grad_norm

-     # optimizer.step()
-     # optimizer.zero_grad()
+     optimizer.step()
+     optimizer.zero_grad()

      # Step through pytorch scheduler at every batch instead of epoch
      if lr_scheduler is not None:
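The previously commented-out block is re-enabled here, restoring the usual Accelerate update sequence: backward, unscale, optional clipping when gradients are synchronized, then step and zero. A minimal standalone sketch of the same pattern, with a toy model standing in for the policy:

```
import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
model, optimizer = accelerator.prepare(model, optimizer)

x = torch.randn(8, 4, device=accelerator.device)
loss = model(x).pow(2).mean()

accelerator.backward(loss)                         # handles scaling/accumulation
accelerator.unscale_gradients(optimizer=optimizer)
if accelerator.sync_gradients:                     # True outside accumulation steps
    grad_norm = accelerator.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
optimizer.zero_grad()
```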
opentau/scripts/visualize_dataset.py CHANGED
@@ -14,7 +14,7 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
- """ Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset.
+ """Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset.

  Note: The last frame of the episode doesn't always correspond to a final state.
  That's because our datasets are composed of transition from state to state up to
@@ -30,34 +30,21 @@ Examples:

  - Visualize data stored on a local machine:
  ```
- local$ python src/opentau/scripts/visualize_dataset.py \
-     --repo-id lerobot/pusht \
-     --episode-index 0
+ local$ opentau-dataset-viz --repo-id lerobot/pusht --episode-index 0
  ```

  - Visualize data stored on a distant machine with a local viewer:
  ```
- distant$ python src/opentau/scripts/visualize_dataset.py \
-     --repo-id lerobot/pusht \
-     --episode-index 0 \
-     --save 1 \
-     --output-dir path/to/directory
+ distant$ opentau-dataset-viz --repo-id lerobot/pusht --episode-index 0 --save 1 --output-dir path/to/directory

  local$ scp distant:path/to/directory/lerobot_pusht_episode_0.rrd .
  local$ rerun lerobot_pusht_episode_0.rrd
  ```

  - Visualize data stored on a distant machine through streaming:
- (You need to forward the websocket port to the distant machine, with
- `ssh -L 9087:localhost:9087 username@remote-host`)
  ```
- distant$ python src/opentau/scripts/visualize_dataset.py \
-     --repo-id lerobot/pusht \
-     --episode-index 0 \
-     --mode distant \
-     --ws-port 9087

- local$ rerun ws://localhost:9087
+ distant$ opentau-dataset-viz --repo-id lerobot/pusht --episode-index 0 --mode distant --web-port 9090
  ```

  """
@@ -75,8 +62,34 @@ import torch
  import torch.utils.data
  import tqdm

+ from opentau.configs.default import DatasetMixtureConfig, WandBConfig
+ from opentau.configs.train import TrainPipelineConfig
  from opentau.datasets.lerobot_dataset import LeRobotDataset
- from opentau.scripts.visualize_dataset_html import create_mock_train_config
+
+
+ def create_mock_train_config() -> TrainPipelineConfig:
+     """Create a mock TrainPipelineConfig for dataset visualization.
+
+     Returns:
+         TrainPipelineConfig: A mock config with default values.
+     """
+     return TrainPipelineConfig(
+         dataset_mixture=DatasetMixtureConfig(),  # Will be set by the dataset
+         resolution=(224, 224),
+         num_cams=2,
+         max_state_dim=32,
+         max_action_dim=32,
+         action_chunk=50,
+         loss_weighting={"MSE": 1, "CE": 1},
+         num_workers=4,
+         batch_size=8,
+         steps=100_000,
+         log_freq=200,
+         save_checkpoint=True,
+         save_freq=20_000,
+         use_policy_training_preset=True,
+         wandb=WandBConfig(),
+     )


  class EpisodeSampler(torch.utils.data.Sampler):
@@ -108,7 +121,6 @@ def visualize_dataset(
      num_workers: int = 0,
      mode: str = "local",
      web_port: int = 9090,
-     ws_port: int = 9087,
      save: bool = False,
      output_dir: Path | None = None,
  ) -> Path | None:
@@ -142,7 +154,7 @@
      gc.collect()

      if mode == "distant":
-         rr.serve(open_browser=False, web_port=web_port, ws_port=ws_port)
+         rr.serve_web_viewer(open_browser=False, web_port=web_port)

      logging.info("Logging to Rerun")

@@ -194,7 +206,7 @@
          print("Ctrl-C received. Exiting.")


- def main():
+ def parse_args() -> dict:
      parser = argparse.ArgumentParser()

      parser.add_argument(
@@ -250,12 +262,6 @@
          default=9090,
          help="Web port for rerun.io when `--mode distant` is set.",
      )
-     parser.add_argument(
-         "--ws-port",
-         type=int,
-         default=9087,
-         help="Web socket port for rerun.io when `--mode distant` is set.",
-     )
      parser.add_argument(
          "--save",
          type=int,
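The removed `--ws-port` flag goes hand in hand with the switch from `rr.serve(...)` (which needed a separate websocket port) to `rr.serve_web_viewer(...)` a few hunks earlier: only the web port matters in distant mode now. A rough sketch of that workflow, assuming the web viewer is reachable on `web_port` (host names and port are illustrative):

```
import rerun as rr

rr.init("opentau_dataset_viz")
rr.serve_web_viewer(open_browser=False, web_port=9090)
# ... rr.log(...) calls, as in visualize_dataset() ...

# On the local machine, forward the web port and open the viewer:
#   ssh -L 9090:localhost:9090 user@distant
#   open http://localhost:9090
```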
@@ -279,15 +285,25 @@
      )

      args = parser.parse_args()
-     kwargs = vars(args)
+     return vars(args)
+
+
+ def main():
+     kwargs = parse_args()
      repo_id = kwargs.pop("repo_id")
      root = kwargs.pop("root")
      tolerance_s = kwargs.pop("tolerance_s")

      logging.info("Loading dataset")
-     dataset = LeRobotDataset(create_mock_train_config(), repo_id, root=root, tolerance_s=tolerance_s)
+     dataset = LeRobotDataset(
+         create_mock_train_config(),
+         repo_id,
+         root=root,
+         tolerance_s=tolerance_s,
+         standardize=False,
+     )

-     visualize_dataset(dataset, **vars(args))
+     visualize_dataset(dataset, **kwargs)


  if __name__ == "__main__":
opentau/utils/transformers_patch.py ADDED
@@ -0,0 +1,276 @@
1
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Module for patching transformers
16
+
17
+ Most patches come from the branch fix/lerobot-openpi
18
+ """
19
+
20
+ from typing import Optional, Tuple
21
+
22
+ import torch
23
+ from torch import nn
24
+ from transformers.models.gemma import modeling_gemma
25
+ from transformers.models.gemma.configuration_gemma import GemmaConfig
26
+ from transformers.models.paligemma.modeling_paligemma import PaliGemmaModel
27
+
28
+ # Monkey patch __init__ of GemmaConfig to fix or modify its behavior as needed.
29
+
30
+ _original_gemma_config_init = GemmaConfig.__init__
31
+
32
+
33
+ def patched_gemma_config_init(
34
+ self, *args, use_adarms: bool = False, adarms_cond_dim: Optional[int] = None, **kwargs
35
+ ):
36
+ """Initializes the GemmaConfig with added ADARMS support.
37
+
38
+ Args:
39
+ self: The GemmaConfig instance.
40
+ *args: Variable length argument list.
41
+ use_adarms: Whether to use Adaptive RMS normalization.
42
+ adarms_cond_dim: The dimension of the conditioning vector for ADARMS.
43
+ **kwargs: Arbitrary keyword arguments.
44
+ """
45
+ # Call the original init with all other arguments
46
+ _original_gemma_config_init(self, *args, **kwargs)
47
+
48
+ # Initialize custom attributes
49
+ self.use_adarms = use_adarms
50
+ self.adarms_cond_dim = adarms_cond_dim
51
+
52
+ # Set default for adarms_cond_dim if use_adarms is True
53
+ if self.use_adarms and self.adarms_cond_dim is None:
54
+ # hidden_size is set by _original_gemma_config_init
55
+ self.adarms_cond_dim = self.hidden_size
56
+
57
+
58
+ GemmaConfig.__init__ = patched_gemma_config_init
59
+
60
+
61
+ # --- Modeling Patches ---
62
+
63
+
64
+ def _gated_residual(x, y, gate):
65
+ """
66
+ Applies gated residual connection with optional gate parameter.
67
+
68
+ Args:
69
+ x: Input tensor (residual)
70
+ y: Output tensor to be added
71
+ gate: Optional gate tensor to modulate the addition
72
+
73
+ Returns:
74
+ x + y if gate is None, otherwise x + y * gate
75
+ """
76
+ if x is None and y is None:
77
+ return None
78
+ if x is None or y is None:
79
+ return x if x is not None else y
80
+ if gate is None:
81
+ return x + y
82
+ return x + y * gate
83
+
84
+
85
+ modeling_gemma._gated_residual = _gated_residual
86
+
87
+
88
+ class PatchedGemmaRMSNorm(nn.Module):
89
+ """RMS normalization with optional adaptive support (ADARMS)."""
90
+
91
+ def __init__(self, dim: int, eps: float = 1e-6, cond_dim: Optional[int] = None):
92
+ """Initializes the PatchedGemmaRMSNorm.
93
+
94
+ Args:
95
+ dim: The dimension of the input tensor.
96
+ eps: The epsilon value for numerical stability.
97
+ cond_dim: The dimension of the conditioning vector (if using ADARMS).
98
+ """
99
+ super().__init__()
100
+ self.eps = eps
101
+ self.dim = dim
102
+ self.cond_dim = cond_dim
103
+
104
+ # Dense layer for adaptive normalization (if cond_dim is provided)
105
+ if cond_dim is not None:
106
+ self.dense = nn.Linear(cond_dim, dim * 3, bias=True)
107
+ # Initialize with zeros (matches source implementation)
108
+ nn.init.zeros_(self.dense.weight)
109
+ else:
110
+ self.weight = nn.Parameter(torch.zeros(dim))
111
+ self.dense = None
112
+
113
+ def _norm(self, x: torch.Tensor) -> torch.Tensor:
114
+ """Applies RMS normalization.
115
+
116
+ Args:
117
+ x: The input tensor.
118
+
119
+ Returns:
120
+ The normalized tensor.
121
+ """
122
+ # Compute variance in float32 (like the source implementation)
123
+ var = torch.mean(torch.square(x.float()), dim=-1, keepdim=True)
124
+ # Compute normalization in float32
125
+ normed_inputs = x * torch.rsqrt(var + self.eps)
126
+ return normed_inputs
127
+
128
+ def forward(
129
+ self, x: torch.Tensor, cond: Optional[torch.Tensor] = None
130
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
131
+ """Forward pass of the normalization layer.
132
+
133
+ Args:
134
+ x: The input tensor.
135
+ cond: The conditioning tensor for adaptive normalization.
136
+
137
+ Returns:
138
+ A tuple containing the normalized tensor and the gate tensor (if applicable).
139
+ If cond is None, the gate tensor will be None.
140
+
141
+ Raises:
142
+ ValueError: If cond dimension does not match the configured cond_dim.
143
+ """
144
+ dtype = x.dtype # original dtype, could be half-precision
145
+ normed_inputs = self._norm(x)
146
+
147
+ if cond is None or self.dense is None:
148
+ # regular RMSNorm
149
+ # scale by learned parameter in float32 (matches source implementation)
150
+ normed_inputs = normed_inputs * (1.0 + self.weight.float())
151
+ return normed_inputs.to(dtype), None # return in original dtype with None gate
152
+
153
+ # adaptive RMSNorm (if cond is provided and dense layer exists)
154
+ if cond.shape[-1] != self.cond_dim:
155
+ raise ValueError(f"Expected cond dimension {self.cond_dim}, got {cond.shape[-1]}")
156
+
157
+ modulation = self.dense(cond)
158
+ # Reshape modulation to broadcast properly: [batch, 1, features] for [batch, seq, features]
159
+ if len(x.shape) == 3: # [batch, seq, features]
160
+ modulation = modulation.unsqueeze(1)
161
+
162
+ scale, shift, gate = torch.chunk(modulation, 3, dim=-1)
163
+
164
+ normed_inputs = normed_inputs * (1 + scale.to(torch.float32)) + shift.to(torch.float32)
165
+
166
+ return normed_inputs.to(dtype), gate.to(dtype)
167
+
168
+ def extra_repr(self) -> str:
169
+ """Returns the extra representation of the module."""
170
+ repr_str = f"{tuple(self.weight.shape)}, eps={self.eps}"
171
+ if self.dense is not None:
172
+ repr_str += f", adaptive=True, cond_dim={self.cond_dim}"
173
+ return repr_str
174
+
175
+
176
+ # Apply patches
177
+ modeling_gemma.GemmaRMSNorm = PatchedGemmaRMSNorm
178
+
179
+
180
+ def patched_gemma_decoder_layer_init(self, config: GemmaConfig, layer_idx: int):
181
+ """Initializes a GemmaDecoderLayer with potential ADARMS support.
182
+
183
+ Args:
184
+ self: The GemmaDecoderLayer instance.
185
+ config: The configuration object.
186
+ layer_idx: The index of the layer.
187
+ """
188
+ modeling_gemma.GradientCheckpointingLayer.__init__(self)
189
+ self.hidden_size = config.hidden_size
190
+
191
+ self.self_attn = modeling_gemma.GemmaAttention(config=config, layer_idx=layer_idx)
192
+
193
+ self.mlp = modeling_gemma.GemmaMLP(config)
194
+
195
+ cond_dim = getattr(config, "adarms_cond_dim", None) if getattr(config, "use_adarms", False) else None
196
+ self.input_layernorm = modeling_gemma.GemmaRMSNorm(
197
+ config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim
198
+ )
199
+ self.post_attention_layernorm = modeling_gemma.GemmaRMSNorm(
200
+ config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim
201
+ )
202
+
203
+
204
+ modeling_gemma.GemmaDecoderLayer.__init__ = patched_gemma_decoder_layer_init
205
+
206
+
207
+ def patched_gemma_model_init(self, config: GemmaConfig):
208
+ """Initializes the GemmaModel with potential ADARMS support.
209
+
210
+ Args:
211
+ self: The GemmaModel instance.
212
+ config: The configuration object.
213
+ """
214
+ modeling_gemma.GemmaPreTrainedModel.__init__(self, config)
215
+ self.padding_idx = config.pad_token_id
216
+ self.vocab_size = config.vocab_size
217
+
218
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
219
+ self.layers = nn.ModuleList(
220
+ [modeling_gemma.GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
221
+ )
222
+
223
+ cond_dim = getattr(config, "adarms_cond_dim", None) if getattr(config, "use_adarms", False) else None
224
+ self.norm = modeling_gemma.GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim)
225
+ self.rotary_emb = modeling_gemma.GemmaRotaryEmbedding(config=config)
226
+ self.gradient_checkpointing = False
227
+
228
+ # Initialize weights and apply final processing
229
+ self.post_init()
230
+
231
+
232
+ modeling_gemma.GemmaModel.__init__ = patched_gemma_model_init
233
+
234
+
235
+ def patched_gemma_pretrained_model_init_weights(self, module: nn.Module):
236
+ """Initializes the weights of the GemmaPreTrainedModel.
237
+
238
+ Args:
239
+ self: The GemmaPreTrainedModel instance.
240
+ module: The module to initialize.
241
+ """
242
+ std = self.config.initializer_range
243
+ if isinstance(module, nn.Linear):
244
+ module.weight.data.normal_(mean=0.0, std=std)
245
+ if module.bias is not None:
246
+ module.bias.data.zero_()
247
+ elif isinstance(module, nn.Embedding):
248
+ module.weight.data.normal_(mean=0.0, std=std)
249
+ if module.padding_idx is not None:
250
+ module.weight.data[module.padding_idx].zero_()
251
+ elif isinstance(module, modeling_gemma.GemmaRMSNorm):
252
+ if hasattr(module, "weight"):
253
+ module.weight.data.fill_(1.0)
254
+
255
+
256
+ modeling_gemma.GemmaPreTrainedModel._init_weights = patched_gemma_pretrained_model_init_weights
257
+
258
+
259
+ def patched_paligemma_model_get_image_features(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
260
+ """Obtains image last hidden states from the vision tower and apply multimodal projection.
261
+
262
+ Args:
263
+ self: The PaliGemmaModel instance.
264
+ pixel_values: The tensors corresponding to the input images.
265
+ Shape: (batch_size, channels, height, width).
266
+
267
+ Returns:
268
+ Image feature tensor of shape (num_images, image_length, embed_dim).
269
+ """
270
+ image_outputs = self.vision_tower(pixel_values)
271
+ selected_image_feature = image_outputs.last_hidden_state
272
+ image_features = self.multi_modal_projector(selected_image_feature)
273
+ return image_features
274
+
275
+
276
+ PaliGemmaModel.get_image_features = patched_paligemma_model_get_image_features
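Taken together, the patches above swap in an adaptive (conditioned) RMSNorm and a gated residual helper for Gemma. A small sketch of the patched behaviour, assuming `opentau.utils.transformers_patch` has been imported (directly or via `import opentau`):

```
import torch

import opentau.utils.transformers_patch  # noqa: F401  # applies the patches
from transformers.models.gemma import modeling_gemma

x = torch.randn(2, 5, 8)

plain = modeling_gemma.GemmaRMSNorm(8)                 # regular RMSNorm path
y_plain, gate = plain(x)                               # gate is None here

ada = modeling_gemma.GemmaRMSNorm(8, cond_dim=4)       # adaptive (ADARMS) path
y_ada, gate_ada = ada(x, cond=torch.randn(2, 4))       # scale/shift/gate from cond

out = modeling_gemma._gated_residual(x, y_ada, gate_ada)  # x + y_ada * gate_ada
print(out.shape)  # torch.Size([2, 5, 8])
```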
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: opentau
- Version: 0.1.0
+ Version: 0.1.2
  Summary: OpenTau: Tensor's VLA Training Infrastructure for Real-World Robotics in Pytorch
  Author-email: Shuheng Liu <wish1104@icloud.com>, William Yue <williamyue37@gmail.com>, Akshay Shah <akshayhitendrashah@gmail.com>, Xingrui Gu <xingrui_gu@berkeley.edu>
  License: Apache-2.0
@@ -52,7 +52,7 @@ Requires-Dist: onnxruntime>=1.22.1; sys_platform == "darwin" or platform_machine
  Requires-Dist: onnxruntime-gpu>=1.22.0; (sys_platform == "linux" and platform_machine == "x86_64") or (sys_platform == "win32" and (platform_machine == "AMD64" or platform_machine == "x86_64"))
  Requires-Dist: onnxscript>=0.3.1
  Requires-Dist: onnx-ir>=0.1.4
- Requires-Dist: opentau-transformers==4.53.3
+ Requires-Dist: transformers==4.53.3
  Requires-Dist: scipy>=1.15.2
  Requires-Dist: pytest>=8.1.0
  Requires-Dist: pytest-cov>=5.0.0
@@ -94,6 +94,9 @@ Requires-Dist: numpy<2; extra == "libero"
  Requires-Dist: gym<0.27,>=0.25; extra == "libero"
  Requires-Dist: pyopengl-accelerate==3.1.7; sys_platform == "linux" and extra == "libero"
  Requires-Dist: gymnasium[other]>=0.29; extra == "libero"
+ Requires-Dist: mujoco>=3.1.6; sys_platform == "linux" and extra == "libero"
+ Requires-Dist: pyopengl==3.1.7; sys_platform == "linux" and extra == "libero"
+ Requires-Dist: numpy==1.26.4; sys_platform == "linux" and extra == "libero"
  Dynamic: license-file

  <p align="center">
@@ -134,10 +137,10 @@ OpenTau ($\tau$) is a tool developed by *[Tensor][1]* to bridge this gap, and we
  ## Quick Start
  If you are familiar with LeRobot, getting started with OpenTau is very easy.
  Because OpenTau is a fork of the popular LeRobot repository, any LeRobot-compliant policy and dataset can be used directly with OpenTau.
- Check out our documentation to get started quickly.
- We provide a quick start guide to help you get started with OpenTau.
+ Check out our [documentation](https://opentau.readthedocs.io/) to get started quickly.
+ We provide a [quick start guide](https://opentau.readthedocs.io/en/latest/getting_started.html) to help you get started with OpenTau.

- For using local notebooks to train and evaluate models, find the notebooks at `notebooks/pi05_training.ipynb` and `notebooks/pi05_evaluation_only.ipynb`.
+ For using local notebooks to train and evaluate models, find the notebooks at [notebooks/pi05_training.ipynb](https://github.com/TensorAuto/OpenTau/blob/main/notebooks/pi05_training.ipynb) and [notebooks/pi05_evaluation_only.ipynb](https://github.com/TensorAuto/OpenTau/blob/main/notebooks/pi05_evaluation_only.ipynb).

  For using the Google Colab notebooks to train and evaluate models, find the colab notebooks here: [pi05_training](https://colab.research.google.com/drive/1DeU0lNnEzs1KHo0Nkgh4YKBr-xu9moBM?usp=sharing) and [pi05_evaluation_only](https://colab.research.google.com/drive/1U_AyuH9WYMT4anEWvsOtIT7g01jA0WGm?usp=sharing) respectively.

@@ -1,4 +1,4 @@
1
- opentau/__init__.py,sha256=KktmFuQQqDhhzcJKczJ99crma9qjxiJZUPv-_d0ZeLU,6697
1
+ opentau/__init__.py,sha256=fIqOYsZsF-bitUI-4taSNke_D1YJYCehGseZNe29GG0,6756
2
2
  opentau/__version__.py,sha256=junxoss59Jz_hmg3YnzhpVk_Q5Fo6uha23P1ET81N1c,889
3
3
  opentau/constants.py,sha256=-_CbJujCp6hbBjJHgYMguCTcSAkVkmdpM4wHqZp7vRQ,2020
4
4
  opentau/configs/__init__.py,sha256=hC-KkeCfq1mtMw9WjPCZfOTxrzQWW7hAa1w8BRC_Bqw,784
@@ -15,7 +15,7 @@ opentau/datasets/compute_stats.py,sha256=N359TDuJicLKMtxxy0JVEcUtnTOB57gL5G8e9Dq
15
15
  opentau/datasets/dataset_mixture.py,sha256=8UWjY9oKn9jEMe-e9Dy6no1p_21H0kXKv8A10Ku_8_o,19850
16
16
  opentau/datasets/factory.py,sha256=NKWpbuNBve0PsmK1midj8g1IpQapeHn-VrxCOC3X4eI,10480
17
17
  opentau/datasets/image_writer.py,sha256=JYCkImHFYpLuE88t16cYqXqQS7EHS7g6kLWXPCJmWgw,11072
18
- opentau/datasets/lerobot_dataset.py,sha256=rz_3BcXqpcIzYr0NEVmkfLf1dY7vcTdo6zuV1CZkIuI,84747
18
+ opentau/datasets/lerobot_dataset.py,sha256=c6bGOz75yEJfYkYqlcfszGkap0VBAMBFXrH8fz1P1WQ,84651
19
19
  opentau/datasets/online_buffer.py,sha256=x14P8tBz25s-hRlE8loFJs5CAvh65RGWeogF271hiF0,19671
20
20
  opentau/datasets/sampler.py,sha256=5g-6prsWItVjqkt1J7mA9JPNQPhDSFx3r6rA4JphP9U,4012
21
21
  opentau/datasets/standard_data_format_mapping.py,sha256=wEKilksMUjJGeIhvyLuR9qhyhtiJMK1e1AzCkbyx-l4,9667
@@ -78,13 +78,12 @@ opentau/scripts/fake_tensor_training.py,sha256=y4F3CFs2jjpIJcT1wKvsrgFEebU9QFzba
78
78
  opentau/scripts/get_advantage_and_percentiles.py,sha256=JdjlADYzdS1Jc_19H6lLYMRnPlWxeckRSUQqwqb0rC4,8993
79
79
  opentau/scripts/high_level_planner_inference.py,sha256=nbXr8Hp64YGeprMTpT8kvT_NgpBlI02CUlO6Mm2Js_E,3846
80
80
  opentau/scripts/inference.py,sha256=_lp9YjPzarAnjiA8k2jBlIKZxza6PEHw--UyaqLPdNo,2110
81
- opentau/scripts/launch_train.py,sha256=ThyZ0IqRfarvD3qEqa8sazSgYYL3Bh22zz6-z1JtZBs,2066
81
+ opentau/scripts/launch.py,sha256=kcJtdO1WHYxiHSJpJ_y618tbIvBuGXy8FmH5BEEdVdI,2826
82
82
  opentau/scripts/libero_simulation_parallel.py,sha256=qMee6T0EwMoAT1J2u8X4w8rsbOJYwyqD3LRAPe2Ta1g,13105
83
83
  opentau/scripts/libero_simulation_sequential.py,sha256=xFSUQEuyai20QD-pYitp-UJPGE9zlaaIu4YSO0bhYKg,4775
84
84
  opentau/scripts/nav_high_level_planner_inference.py,sha256=z2WHw68NWi-fJUd5TV4CrJHzxo-L7e2UliGjfOlqifM,1878
85
- opentau/scripts/train.py,sha256=UftedMDTlTDLidV8fMPtiCu8xK9G16ivtIqaJp38su0,16866
86
- opentau/scripts/visualize_dataset.py,sha256=_xGfAXQqhjGYMi__6L7qRH2xS5XQ2-GQRXjNw3KXMlY,10109
87
- opentau/scripts/visualize_dataset_html.py,sha256=gEX-E5fFqBhINthf7xLMICHySvw9e3Kcf1HPRnJIyug,17979
85
+ opentau/scripts/train.py,sha256=nkvsvna5yliphp7pwVyFY6yBwCA_kmffyohRO2wjiHU,16850
86
+ opentau/scripts/visualize_dataset.py,sha256=RsON_13oqTm7HN14tGnDBIVAJPCW_-EJzpMHeiXxp24,10492
88
87
  opentau/scripts/zero_to_fp32.py,sha256=Rkl1ZczytKix9vGMg0EELzdJYFqUM1yB9p3xvSaK9k8,33272
89
88
  opentau/utils/__init__.py,sha256=hIUeGPpZHf2AVf0-5C2p0BOcY0cFHCTT5yHn-SpEPwY,856
90
89
  opentau/utils/accelerate_utils.py,sha256=vXnSGo1hXCUNof-oNKLMJ_SOMjpKhpZ1gx21ObSsopI,2630
@@ -99,10 +98,11 @@ opentau/utils/logging_utils.py,sha256=zd7ypmk7aqVposPhA7Kg-PYrstapY4MsuTklsTD4r4
99
98
  opentau/utils/monkey_patch.py,sha256=cVgZ1N-NNVnlRKPA1dwO9FM4IbxR0V_Hbil6p-6knhA,9558
100
99
  opentau/utils/random_utils.py,sha256=k3Ab3Y98LozGdsBzKoP8xSsFTcnaRqUzY34BsETCrrA,9102
101
100
  opentau/utils/train_utils.py,sha256=0d7yvk8wlP-75pwB55gr095b_b1sWG5nlqdVxyH6_o0,6796
101
+ opentau/utils/transformers_patch.py,sha256=-3Fvf-_owtT_QDUkoGfMWO-pxN5xeQikPljtLMn4MRs,9906
102
102
  opentau/utils/utils.py,sha256=DrMStfjBEkw_8WVhYMnCQJNBxMeozIJ8LBSpOtMQhFM,15760
103
- opentau-0.1.0.dist-info/licenses/LICENSE,sha256=tl3_NkxplsgU86xSvEWnDlE1UR_JsIvGo7t4hPtsIbE,27680
104
- opentau-0.1.0.dist-info/METADATA,sha256=D89t5Nd6jD5bsHXc532C0yypVj2fOWYeThanhqj8Q_8,10456
105
- opentau-0.1.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
106
- opentau-0.1.0.dist-info/entry_points.txt,sha256=Q2Jf-g98RrhE7trmkvW3KrOgSBQJfdRRgUxaraF-A3c,68
107
- opentau-0.1.0.dist-info/top_level.txt,sha256=7_yrS4x5KSeTRr2LICTCNOZmF-1_kSOFPKHvtJPL1Dw,8
108
- opentau-0.1.0.dist-info/RECORD,,
103
+ opentau-0.1.2.dist-info/licenses/LICENSE,sha256=tl3_NkxplsgU86xSvEWnDlE1UR_JsIvGo7t4hPtsIbE,27680
104
+ opentau-0.1.2.dist-info/METADATA,sha256=Up5VRGhf8RVjBA0mBy6xKA21-6R_t51xvGmG-YgC1EQ,10943
105
+ opentau-0.1.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
106
+ opentau-0.1.2.dist-info/entry_points.txt,sha256=NGF_MWpSKri0lvjR9WGN4pBUap8B-z21f7XMluxc1M4,208
107
+ opentau-0.1.2.dist-info/top_level.txt,sha256=7_yrS4x5KSeTRr2LICTCNOZmF-1_kSOFPKHvtJPL1Dw,8
108
+ opentau-0.1.2.dist-info/RECORD,,
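The RECORD changes above only track the new file hashes and sizes. For context, a wheel RECORD entry is the urlsafe-base64 SHA-256 digest of the file (with padding stripped) plus its size in bytes; a small sketch of how such a line is produced:

```
import base64
import hashlib
from pathlib import Path

def record_entry(path: str) -> str:
    """Build a `path,sha256=...,size` line like the ones in RECORD."""
    data = Path(path).read_bytes()
    digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b"=").decode()
    return f"{path},sha256={digest},{len(data)}"

# e.g. record_entry("opentau/utils/transformers_patch.py")
```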
@@ -0,0 +1,5 @@
+ [console_scripts]
+ opentau-dataset-viz = opentau.scripts.launch:visualize
+ opentau-eval = opentau.scripts.launch:eval
+ opentau-export = opentau.scripts.launch:export
+ opentau-train = opentau.scripts.launch:train
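These four console scripts replace the single `opentau-train` entry point from 0.1.0 (removed at the end of this diff). After installing the wheel they can be listed programmatically, e.g. on Python 3.10+:

```
from importlib.metadata import entry_points

for ep in sorted(entry_points(group="console_scripts"), key=lambda e: e.name):
    if ep.name.startswith("opentau-"):
        print(f"{ep.name} -> {ep.value}")
# opentau-dataset-viz -> opentau.scripts.launch:visualize
# opentau-eval -> opentau.scripts.launch:eval
# opentau-export -> opentau.scripts.launch:export
# opentau-train -> opentau.scripts.launch:train
```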
opentau/scripts/launch_train.py DELETED
@@ -1,63 +0,0 @@
1
- # Copyright 2026 Tensor Auto Inc. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import argparse
16
- import subprocess
17
- import sys
18
- from pathlib import Path
19
-
20
- import opentau.scripts.train as train_script
21
-
22
-
23
- def main():
24
- parser = argparse.ArgumentParser(
25
- description="Launch OpenTau training with Accelerate",
26
- usage="opentau-train [--accelerate-config CONFIG] [TRAINING_ARGS]",
27
- )
28
- parser.add_argument(
29
- "--accelerate-config", type=str, help="Path to accelerate config file (yaml)", default=None
30
- )
31
- # We use parse_known_args so that all other arguments are collected
32
- # These will be passed to the training script
33
- args, unknown_args = parser.parse_known_args()
34
-
35
- # Base command
36
- cmd = ["accelerate", "launch"]
37
-
38
- # Add accelerate config if provided
39
- if args.accelerate_config:
40
- cmd.extend(["--config_file", args.accelerate_config])
41
-
42
- # Add the path to the training script
43
- # We resolve the path to ensure it's absolute
44
- train_script_path = Path(train_script.__file__).resolve()
45
- cmd.append(str(train_script_path))
46
-
47
- # Add all other arguments (passed to the training script)
48
- cmd.extend(unknown_args)
49
-
50
- # Print the command for transparency
51
- print(f"Executing: {' '.join(cmd)}")
52
-
53
- # Replace the current process with the accelerate launch command
54
- try:
55
- subprocess.run(cmd, check=True)
56
- except subprocess.CalledProcessError as e:
57
- sys.exit(e.returncode)
58
- except KeyboardInterrupt:
59
- sys.exit(130)
60
-
61
-
62
- if __name__ == "__main__":
63
- main()
opentau/scripts/visualize_dataset_html.py DELETED
@@ -1,507 +0,0 @@
1
- #!/usr/bin/env python
2
-
3
- # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
- # Copyright 2026 Tensor Auto Inc. All rights reserved.
5
- #
6
- # Licensed under the Apache License, Version 2.0 (the "License");
7
- # you may not use this file except in compliance with the License.
8
- # You may obtain a copy of the License at
9
- #
10
- # http://www.apache.org/licenses/LICENSE-2.0
11
- #
12
- # Unless required by applicable law or agreed to in writing, software
13
- # distributed under the License is distributed on an "AS IS" BASIS,
14
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
- # See the License for the specific language governing permissions and
16
- # limitations under the License.
17
- """ Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset.
18
-
19
- Note: The last frame of the episode doesnt always correspond to a final state.
20
- That's because our datasets are composed of transition from state to state up to
21
- the antepenultimate state associated to the ultimate action to arrive in the final state.
22
- However, there might not be a transition from a final state to another state.
23
-
24
- Note: This script aims to visualize the data used to train the neural networks.
25
- ~What you see is what you get~. When visualizing image modality, it is often expected to observe
26
- lossly compression artifacts since these images have been decoded from compressed mp4 videos to
27
- save disk space. The compression factor applied has been tuned to not affect success rate.
28
-
29
- Example of usage:
30
-
31
- - Visualize data stored on a local machine:
32
- ```bash
33
- local$ python src/opentau/scripts/visualize_dataset_html.py \
34
- --repo-id lerobot/pusht
35
-
36
- local$ open http://localhost:9090
37
- ```
38
-
39
- - Visualize data stored on a distant machine with a local viewer:
40
- ```bash
41
- distant$ python src/opentau/scripts/visualize_dataset_html.py \
42
- --repo-id lerobot/pusht
43
-
44
- local$ ssh -L 9090:localhost:9090 distant # create a ssh tunnel
45
- local$ open http://localhost:9090
46
- ```
47
-
48
- - Select episodes to visualize:
49
- ```bash
50
- python src/opentau/scripts/visualize_dataset_html.py \
51
- --repo-id lerobot/pusht \
52
- --episodes 7 3 5 1 4
53
- ```
54
- """
55
-
56
- import argparse
57
- import csv
58
- import json
59
- import logging
60
- import re
61
- import shutil
62
- import tempfile
63
- from io import StringIO
64
- from pathlib import Path
65
-
66
- import numpy as np
67
- import pandas as pd
68
- import requests
69
- from flask import Flask, redirect, render_template, request, url_for
70
-
71
- from opentau import available_datasets
72
- from opentau.configs.default import DatasetMixtureConfig, WandBConfig
73
- from opentau.configs.train import TrainPipelineConfig
74
- from opentau.datasets.lerobot_dataset import LeRobotDataset
75
- from opentau.datasets.utils import IterableNamespace
76
- from opentau.utils.utils import init_logging
77
-
78
-
79
- def run_server(
80
- dataset: LeRobotDataset | IterableNamespace | None,
81
- episodes: list[int] | None,
82
- host: str,
83
- port: str,
84
- static_folder: Path,
85
- template_folder: Path,
86
- ):
87
- app = Flask(__name__, static_folder=static_folder.resolve(), template_folder=template_folder.resolve())
88
- app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0 # specifying not to cache
89
-
90
- @app.route("/")
91
- def hommepage(dataset=dataset):
92
- if dataset:
93
- dataset_namespace, dataset_name = dataset.repo_id.split("/")
94
- return redirect(
95
- url_for(
96
- "show_episode",
97
- dataset_namespace=dataset_namespace,
98
- dataset_name=dataset_name,
99
- episode_id=0,
100
- )
101
- )
102
-
103
- dataset_param, episode_param = None, None
104
- all_params = request.args
105
- if "dataset" in all_params:
106
- dataset_param = all_params["dataset"]
107
- if "episode" in all_params:
108
- episode_param = int(all_params["episode"])
109
-
110
- if dataset_param:
111
- dataset_namespace, dataset_name = dataset_param.split("/")
112
- return redirect(
113
- url_for(
114
- "show_episode",
115
- dataset_namespace=dataset_namespace,
116
- dataset_name=dataset_name,
117
- episode_id=episode_param if episode_param is not None else 0,
118
- )
119
- )
120
-
121
- featured_datasets = [
122
- "lerobot/aloha_static_cups_open",
123
- "lerobot/columbia_cairlab_pusht_real",
124
- "lerobot/taco_play",
125
- ]
126
- return render_template(
127
- "visualize_dataset_homepage.html",
128
- featured_datasets=featured_datasets,
129
- lerobot_datasets=available_datasets,
130
- )
131
-
132
- @app.route("/<string:dataset_namespace>/<string:dataset_name>")
133
- def show_first_episode(dataset_namespace, dataset_name):
134
- first_episode_id = 0
135
- return redirect(
136
- url_for(
137
- "show_episode",
138
- dataset_namespace=dataset_namespace,
139
- dataset_name=dataset_name,
140
- episode_id=first_episode_id,
141
- )
142
- )
143
-
144
- @app.route("/<string:dataset_namespace>/<string:dataset_name>/episode_<int:episode_id>")
145
- def show_episode(dataset_namespace, dataset_name, episode_id, dataset=dataset, episodes=episodes):
146
- repo_id = f"{dataset_namespace}/{dataset_name}"
147
- try:
148
- if dataset is None:
149
- dataset = get_dataset_info(repo_id)
150
- except FileNotFoundError:
151
- return (
152
- "Make sure to convert your LeRobotDataset to v2 & above. See how to convert your dataset at https://github.com/huggingface/lerobot/pull/461",
153
- 400,
154
- )
155
- dataset_version = (
156
- str(dataset.meta._version) if isinstance(dataset, LeRobotDataset) else dataset.codebase_version
157
- )
158
- match = re.search(r"v(\d+)\.", dataset_version)
159
- if match:
160
- major_version = int(match.group(1))
161
- if major_version < 2:
162
- return "Make sure to convert your LeRobotDataset to v2 & above."
163
-
164
- episode_data_csv_str, columns, ignored_columns = get_episode_data(dataset, episode_id)
165
- dataset_info = {
166
- "repo_id": f"{dataset_namespace}/{dataset_name}",
167
- "num_samples": dataset.num_frames
168
- if isinstance(dataset, LeRobotDataset)
169
- else dataset.total_frames,
170
- "num_episodes": dataset.num_episodes
171
- if isinstance(dataset, LeRobotDataset)
172
- else dataset.total_episodes,
173
- "fps": dataset.fps,
174
- }
175
- if isinstance(dataset, LeRobotDataset):
176
- video_paths = [
177
- dataset.meta.get_video_file_path(episode_id, key) for key in dataset.meta.video_keys
178
- ]
179
- videos_info = [
180
- {"url": url_for("static", filename=video_path), "filename": video_path.parent.name}
181
- for video_path in video_paths
182
- ]
183
- tasks = dataset.meta.episodes[episode_id]["tasks"]
184
- else:
185
- video_keys = [key for key, ft in dataset.features.items() if ft["dtype"] == "video"]
186
- videos_info = [
187
- {
188
- "url": f"https://huggingface.co/datasets/{repo_id}/resolve/main/"
189
- + dataset.video_path.format(
190
- episode_chunk=int(episode_id) // dataset.chunks_size,
191
- video_key=video_key,
192
- episode_index=episode_id,
193
- ),
194
- "filename": video_key,
195
- }
196
- for video_key in video_keys
197
- ]
198
-
199
- response = requests.get(
200
- f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/episodes.jsonl", timeout=5
201
- )
202
- response.raise_for_status()
203
- # Split into lines and parse each line as JSON
204
- tasks_jsonl = [json.loads(line) for line in response.text.splitlines() if line.strip()]
205
-
206
- filtered_tasks_jsonl = [row for row in tasks_jsonl if row["episode_index"] == episode_id]
207
- tasks = filtered_tasks_jsonl[0]["tasks"]
208
-
209
- videos_info[0]["language_instruction"] = tasks
210
-
211
- if episodes is None:
212
- episodes = list(
213
- range(dataset.num_episodes if isinstance(dataset, LeRobotDataset) else dataset.total_episodes)
214
- )
215
-
216
- return render_template(
217
- "visualize_dataset_template.html",
218
- episode_id=episode_id,
219
- episodes=episodes,
220
- dataset_info=dataset_info,
221
- videos_info=videos_info,
222
- episode_data_csv_str=episode_data_csv_str,
223
- columns=columns,
224
- ignored_columns=ignored_columns,
225
- )
226
-
227
- app.run(host=host, port=port)
228
-
229
-
230
- def get_ep_csv_fname(episode_id: int):
231
- ep_csv_fname = f"episode_{episode_id}.csv"
232
- return ep_csv_fname
233
-
234
-
235
- def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index):
236
- """Get a csv str containing timeseries data of an episode (e.g. state and action).
237
- This file will be loaded by Dygraph javascript to plot data in real time."""
238
- columns = []
239
-
240
- selected_columns = [col for col, ft in dataset.features.items() if ft["dtype"] in ["float32", "int32"]]
241
- selected_columns.remove("timestamp")
242
-
243
- ignored_columns = []
244
- for column_name in selected_columns:
245
- shape = dataset.features[column_name]["shape"]
246
- shape_dim = len(shape)
247
- if shape_dim > 1:
248
- selected_columns.remove(column_name)
249
- ignored_columns.append(column_name)
250
-
251
- # init header of csv with state and action names
252
- header = ["timestamp"]
253
-
254
- for column_name in selected_columns:
255
- dim_state = (
256
- dataset.meta.shapes[column_name][0]
257
- if isinstance(dataset, LeRobotDataset)
258
- else dataset.features[column_name].shape[0]
259
- )
260
-
261
- if "names" in dataset.features[column_name] and dataset.features[column_name]["names"]:
262
- column_names = dataset.features[column_name]["names"]
263
- while not isinstance(column_names, list):
264
- column_names = list(column_names.values())[0]
265
- else:
266
- column_names = [f"{column_name}_{i}" for i in range(dim_state)]
267
- columns.append({"key": column_name, "value": column_names})
268
-
269
- header += column_names
270
-
271
- selected_columns.insert(0, "timestamp")
272
-
273
- if isinstance(dataset, LeRobotDataset):
274
- from_idx = dataset.episode_data_index["from"][episode_index]
275
- to_idx = dataset.episode_data_index["to"][episode_index]
276
- data = (
277
- dataset.hf_dataset.select(range(from_idx, to_idx))
278
- .select_columns(selected_columns)
279
- .with_format("pandas")
280
- )
281
- else:
282
- repo_id = dataset.repo_id
283
-
284
- url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/" + dataset.data_path.format(
285
- episode_chunk=int(episode_index) // dataset.chunks_size, episode_index=episode_index
286
- )
287
- df = pd.read_parquet(url)
288
- data = df[selected_columns] # Select specific columns
289
-
290
- rows = np.hstack(
291
- (
292
- np.expand_dims(data["timestamp"], axis=1),
293
- *[np.vstack(data[col]) for col in selected_columns[1:]],
294
- )
295
- ).tolist()
296
-
297
- # Convert data to CSV string
298
- csv_buffer = StringIO()
299
- csv_writer = csv.writer(csv_buffer)
300
- # Write header
301
- csv_writer.writerow(header)
302
- # Write data rows
303
- csv_writer.writerows(rows)
304
- csv_string = csv_buffer.getvalue()
305
-
306
- return csv_string, columns, ignored_columns
307
-
308
-
309
- def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]:
310
- # get first frame of episode (hack to get video_path of the episode)
311
- first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
312
- return [
313
- dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"]
314
- for key in dataset.meta.video_keys
315
- ]
316
-
317
-
318
- def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) -> list[str]:
319
- # check if the dataset has language instructions
320
- if "language_instruction" not in dataset.features:
321
- return None
322
-
323
- # get first frame index
324
- first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
325
-
326
- language_instruction = dataset.hf_dataset[first_frame_idx]["language_instruction"]
327
- # TODO (michel-aractingi) hack to get the sentence, some strings in openx are badly stored
328
- # with the tf.tensor appearing in the string
329
- return language_instruction.removeprefix("tf.Tensor(b'").removesuffix("', shape=(), dtype=string)")
330
-
331
-
332
- def get_dataset_info(repo_id: str) -> IterableNamespace:
333
- response = requests.get(
334
- f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json", timeout=5
335
- )
336
- response.raise_for_status() # Raises an HTTPError for bad responses
337
- dataset_info = response.json()
338
- dataset_info["repo_id"] = repo_id
339
- return IterableNamespace(dataset_info)
340
-
341
-
342
- def visualize_dataset_html(
343
- dataset: LeRobotDataset | None,
344
- episodes: list[int] | None = None,
345
- output_dir: Path | None = None,
346
- serve: bool = True,
347
- host: str = "127.0.0.1",
348
- port: int = 9090,
349
- force_override: bool = False,
350
- ) -> Path | None:
351
- init_logging()
352
-
353
- template_dir = Path(__file__).resolve().parent.parent / "templates"
354
-
355
- if output_dir is None:
356
- # Create a temporary directory that will be automatically cleaned up
357
- output_dir = tempfile.mkdtemp(prefix="lerobot_visualize_dataset_")
358
-
359
- output_dir = Path(output_dir)
360
- if output_dir.exists():
361
- if force_override:
362
- shutil.rmtree(output_dir)
363
- else:
364
- logging.info(f"Output directory already exists. Loading from it: '{output_dir}'")
365
-
366
- output_dir.mkdir(parents=True, exist_ok=True)
367
-
368
- static_dir = output_dir / "static"
369
- static_dir.mkdir(parents=True, exist_ok=True)
370
-
371
- if dataset is None:
372
- if serve:
373
- run_server(
374
- dataset=None,
375
- episodes=None,
376
- host=host,
377
- port=port,
378
- static_folder=static_dir,
379
- template_folder=template_dir,
380
- )
381
- else:
382
- # Create a simlink from the dataset video folder containing mp4 files to the output directory
383
- # so that the http server can get access to the mp4 files.
384
- if isinstance(dataset, LeRobotDataset):
385
- ln_videos_dir = static_dir / "videos"
386
- if not ln_videos_dir.exists():
387
- ln_videos_dir.symlink_to((dataset.root / "videos").resolve())
388
-
389
- if serve:
390
- run_server(dataset, episodes, host, port, static_dir, template_dir)
391
-
392
-
393
- def create_mock_train_config() -> TrainPipelineConfig:
394
- """Create a mock TrainPipelineConfig for dataset visualization.
395
-
396
- Returns:
397
- TrainPipelineConfig: A mock config with default values.
398
- """
399
- return TrainPipelineConfig(
400
- dataset_mixture=DatasetMixtureConfig(), # Will be set by the dataset
401
- resolution=(224, 224),
402
- num_cams=2,
403
- max_state_dim=32,
404
- max_action_dim=32,
405
- action_chunk=50,
406
- loss_weighting={"MSE": 1, "CE": 1},
407
- num_workers=4,
408
- batch_size=8,
409
- steps=100_000,
410
- log_freq=200,
411
- save_checkpoint=True,
412
- save_freq=20_000,
413
- use_policy_training_preset=True,
414
- wandb=WandBConfig(),
415
- )
416
-
417
-
418
- def main():
419
- parser = argparse.ArgumentParser()
420
-
421
- parser.add_argument(
422
- "--repo-id",
423
- type=str,
424
- default=None,
425
- help="Name of hugging face repositery containing a LeRobotDataset dataset (e.g. `lerobot/pusht` for https://huggingface.co/datasets/lerobot/pusht).",
426
- )
427
- parser.add_argument(
428
- "--root",
429
- type=Path,
430
- default=None,
431
- help="Root directory for a dataset stored locally (e.g. `--root data`). By default, the dataset will be loaded from hugging face cache folder, or downloaded from the hub if available.",
432
- )
433
- parser.add_argument(
434
- "--load-from-hf-hub",
435
- type=int,
436
- default=0,
437
- help="Load videos and parquet files from HF Hub rather than local system.",
438
- )
439
- parser.add_argument(
440
- "--episodes",
441
- type=int,
442
- nargs="*",
443
- default=None,
444
- help="Episode indices to visualize (e.g. `0 1 5 6` to load episodes of index 0, 1, 5 and 6). By default loads all episodes.",
445
- )
446
- parser.add_argument(
447
- "--output-dir",
448
- type=Path,
449
- default=None,
450
- help="Directory path to write html files and kickoff a web server. By default write them to 'outputs/visualize_dataset/REPO_ID'.",
451
- )
452
- parser.add_argument(
453
- "--serve",
454
- type=int,
455
- default=1,
456
- help="Launch web server.",
457
- )
458
- parser.add_argument(
459
- "--host",
460
- type=str,
461
- default="127.0.0.1",
462
- help="Web host used by the http server.",
463
- )
464
- parser.add_argument(
465
- "--port",
466
- type=int,
467
- default=9090,
468
- help="Web port used by the http server.",
469
- )
470
- parser.add_argument(
471
- "--force-override",
472
- type=int,
473
- default=0,
474
- help="Delete the output directory if it exists already.",
475
- )
476
-
477
- parser.add_argument(
478
- "--tolerance-s",
479
- type=float,
480
- default=1e-4,
481
- help=(
482
- "Tolerance in seconds used to ensure data timestamps respect the dataset fps value"
483
- "This is argument passed to the constructor of LeRobotDataset and maps to its tolerance_s constructor argument"
484
- "If not given, defaults to 1e-4."
485
- ),
486
- )
487
-
488
- args = parser.parse_args()
489
- kwargs = vars(args)
490
- repo_id = kwargs.pop("repo_id")
491
- load_from_hf_hub = kwargs.pop("load_from_hf_hub")
492
- root = kwargs.pop("root")
493
- tolerance_s = kwargs.pop("tolerance_s")
494
-
495
- dataset = None
496
- if repo_id:
497
- dataset = (
498
- LeRobotDataset(create_mock_train_config(), repo_id, root=root, tolerance_s=tolerance_s)
499
- if not load_from_hf_hub
500
- else get_dataset_info(repo_id)
501
- )
502
-
503
- visualize_dataset_html(dataset, **vars(args))
504
-
505
-
506
- if __name__ == "__main__":
507
- main()
@@ -1,2 +0,0 @@
- [console_scripts]
- opentau-train = opentau.scripts.launch_train:main