opentau 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108) hide show
  1. opentau/__init__.py +179 -0
  2. opentau/__version__.py +24 -0
  3. opentau/configs/__init__.py +19 -0
  4. opentau/configs/default.py +297 -0
  5. opentau/configs/libero.py +113 -0
  6. opentau/configs/parser.py +393 -0
  7. opentau/configs/policies.py +297 -0
  8. opentau/configs/reward.py +42 -0
  9. opentau/configs/train.py +370 -0
  10. opentau/configs/types.py +76 -0
  11. opentau/constants.py +52 -0
  12. opentau/datasets/__init__.py +84 -0
  13. opentau/datasets/backward_compatibility.py +78 -0
  14. opentau/datasets/compute_stats.py +333 -0
  15. opentau/datasets/dataset_mixture.py +460 -0
  16. opentau/datasets/factory.py +232 -0
  17. opentau/datasets/grounding/__init__.py +67 -0
  18. opentau/datasets/grounding/base.py +154 -0
  19. opentau/datasets/grounding/clevr.py +110 -0
  20. opentau/datasets/grounding/cocoqa.py +130 -0
  21. opentau/datasets/grounding/dummy.py +101 -0
  22. opentau/datasets/grounding/pixmo.py +177 -0
  23. opentau/datasets/grounding/vsr.py +141 -0
  24. opentau/datasets/image_writer.py +304 -0
  25. opentau/datasets/lerobot_dataset.py +1910 -0
  26. opentau/datasets/online_buffer.py +442 -0
  27. opentau/datasets/push_dataset_to_hub/utils.py +132 -0
  28. opentau/datasets/sampler.py +99 -0
  29. opentau/datasets/standard_data_format_mapping.py +278 -0
  30. opentau/datasets/transforms.py +330 -0
  31. opentau/datasets/utils.py +1243 -0
  32. opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +887 -0
  33. opentau/datasets/v2/convert_dataset_v1_to_v2.py +829 -0
  34. opentau/datasets/v21/_remove_language_instruction.py +109 -0
  35. opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +60 -0
  36. opentau/datasets/v21/convert_dataset_v20_to_v21.py +183 -0
  37. opentau/datasets/v21/convert_stats.py +150 -0
  38. opentau/datasets/video_utils.py +597 -0
  39. opentau/envs/__init__.py +18 -0
  40. opentau/envs/configs.py +178 -0
  41. opentau/envs/factory.py +99 -0
  42. opentau/envs/libero.py +439 -0
  43. opentau/envs/utils.py +204 -0
  44. opentau/optim/__init__.py +16 -0
  45. opentau/optim/factory.py +43 -0
  46. opentau/optim/optimizers.py +121 -0
  47. opentau/optim/schedulers.py +140 -0
  48. opentau/planner/__init__.py +82 -0
  49. opentau/planner/high_level_planner.py +366 -0
  50. opentau/planner/utils/memory.py +64 -0
  51. opentau/planner/utils/utils.py +65 -0
  52. opentau/policies/__init__.py +24 -0
  53. opentau/policies/factory.py +172 -0
  54. opentau/policies/normalize.py +315 -0
  55. opentau/policies/pi0/__init__.py +19 -0
  56. opentau/policies/pi0/configuration_pi0.py +250 -0
  57. opentau/policies/pi0/modeling_pi0.py +994 -0
  58. opentau/policies/pi0/paligemma_with_expert.py +516 -0
  59. opentau/policies/pi05/__init__.py +20 -0
  60. opentau/policies/pi05/configuration_pi05.py +231 -0
  61. opentau/policies/pi05/modeling_pi05.py +1257 -0
  62. opentau/policies/pi05/paligemma_with_expert.py +572 -0
  63. opentau/policies/pretrained.py +315 -0
  64. opentau/policies/utils.py +123 -0
  65. opentau/policies/value/__init__.py +18 -0
  66. opentau/policies/value/configuration_value.py +170 -0
  67. opentau/policies/value/modeling_value.py +512 -0
  68. opentau/policies/value/reward.py +87 -0
  69. opentau/policies/value/siglip_gemma.py +221 -0
  70. opentau/scripts/actions_mse_loss.py +89 -0
  71. opentau/scripts/bin_to_safetensors.py +116 -0
  72. opentau/scripts/compute_max_token_length.py +111 -0
  73. opentau/scripts/display_sys_info.py +90 -0
  74. opentau/scripts/download_libero_benchmarks.py +54 -0
  75. opentau/scripts/eval.py +877 -0
  76. opentau/scripts/export_to_onnx.py +180 -0
  77. opentau/scripts/fake_tensor_training.py +87 -0
  78. opentau/scripts/get_advantage_and_percentiles.py +220 -0
  79. opentau/scripts/high_level_planner_inference.py +114 -0
  80. opentau/scripts/inference.py +70 -0
  81. opentau/scripts/launch_train.py +63 -0
  82. opentau/scripts/libero_simulation_parallel.py +356 -0
  83. opentau/scripts/libero_simulation_sequential.py +122 -0
  84. opentau/scripts/nav_high_level_planner_inference.py +61 -0
  85. opentau/scripts/train.py +379 -0
  86. opentau/scripts/visualize_dataset.py +294 -0
  87. opentau/scripts/visualize_dataset_html.py +507 -0
  88. opentau/scripts/zero_to_fp32.py +760 -0
  89. opentau/utils/__init__.py +20 -0
  90. opentau/utils/accelerate_utils.py +79 -0
  91. opentau/utils/benchmark.py +98 -0
  92. opentau/utils/fake_tensor.py +81 -0
  93. opentau/utils/hub.py +209 -0
  94. opentau/utils/import_utils.py +79 -0
  95. opentau/utils/io_utils.py +137 -0
  96. opentau/utils/libero.py +214 -0
  97. opentau/utils/libero_dataset_recorder.py +460 -0
  98. opentau/utils/logging_utils.py +180 -0
  99. opentau/utils/monkey_patch.py +278 -0
  100. opentau/utils/random_utils.py +244 -0
  101. opentau/utils/train_utils.py +198 -0
  102. opentau/utils/utils.py +471 -0
  103. opentau-0.1.0.dist-info/METADATA +161 -0
  104. opentau-0.1.0.dist-info/RECORD +108 -0
  105. opentau-0.1.0.dist-info/WHEEL +5 -0
  106. opentau-0.1.0.dist-info/entry_points.txt +2 -0
  107. opentau-0.1.0.dist-info/licenses/LICENSE +508 -0
  108. opentau-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,507 @@
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ """ Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset.
18
+
19
+ Note: The last frame of the episode doesn't always correspond to a final state.
20
+ That's because our datasets are composed of transitions from state to state up to
21
+ the antepenultimate state associated to the ultimate action to arrive in the final state.
22
+ However, there might not be a transition from a final state to another state.
23
+
24
+ Note: This script aims to visualize the data used to train the neural networks.
25
+ ~What you see is what you get~. When visualizing image modality, it is often expected to observe
26
+ lossy compression artifacts since these images have been decoded from compressed mp4 videos to
27
+ save disk space. The compression factor applied has been tuned to not affect success rate.
28
+
29
+ Example of usage:
30
+
31
+ - Visualize data stored on a local machine:
32
+ ```bash
33
+ local$ python src/opentau/scripts/visualize_dataset_html.py \
34
+ --repo-id lerobot/pusht
35
+
36
+ local$ open http://localhost:9090
37
+ ```
38
+
39
+ - Visualize data stored on a distant machine with a local viewer:
40
+ ```bash
41
+ distant$ python src/opentau/scripts/visualize_dataset_html.py \
42
+ --repo-id lerobot/pusht
43
+
44
+ local$ ssh -L 9090:localhost:9090 distant # create a ssh tunnel
45
+ local$ open http://localhost:9090
46
+ ```
47
+
48
+ - Select episodes to visualize:
49
+ ```bash
50
+ python src/opentau/scripts/visualize_dataset_html.py \
51
+ --repo-id lerobot/pusht \
52
+ --episodes 7 3 5 1 4
53
+ ```
54
+ """
55
+
56
+ import argparse
57
+ import csv
58
+ import json
59
+ import logging
60
+ import re
61
+ import shutil
62
+ import tempfile
63
+ from io import StringIO
64
+ from pathlib import Path
65
+
66
+ import numpy as np
67
+ import pandas as pd
68
+ import requests
69
+ from flask import Flask, redirect, render_template, request, url_for
70
+
71
+ from opentau import available_datasets
72
+ from opentau.configs.default import DatasetMixtureConfig, WandBConfig
73
+ from opentau.configs.train import TrainPipelineConfig
74
+ from opentau.datasets.lerobot_dataset import LeRobotDataset
75
+ from opentau.datasets.utils import IterableNamespace
76
+ from opentau.utils.utils import init_logging
77
+
78
+
79
def _remote_episode_tasks(repo_id: str, episode_id: int):
    """Return the `tasks` entry of episode `episode_id` from a Hub dataset's `meta/episodes.jsonl`.

    Returns an empty list when the episode is not listed in the file.

    Raises:
        requests.HTTPError: if the Hub cannot serve `meta/episodes.jsonl`.
    """
    response = requests.get(
        f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/episodes.jsonl", timeout=5
    )
    response.raise_for_status()
    # JSON Lines format: one JSON object per non-empty line.
    rows = [json.loads(line) for line in response.text.splitlines() if line.strip()]
    matches = [row["tasks"] for row in rows if row["episode_index"] == episode_id]
    return matches[0] if matches else []


def run_server(
    dataset: LeRobotDataset | IterableNamespace | None,
    episodes: list[int] | None,
    host: str,
    port: int,
    static_folder: Path,
    template_folder: Path,
):
    """Build the Flask visualization app and serve it (blocking) on `host:port`.

    Args:
        dataset: Dataset the viewer is pinned to. A `LeRobotDataset` is read locally and its
            videos are served from `static_folder`. An `IterableNamespace` (the Hub
            `meta/info.json` wrapped by `get_dataset_info`) makes the viewer fetch data and
            videos from the Hugging Face Hub. With `None`, the homepage lets the user pick a
            Hub dataset, optionally via the `?dataset=<namespace>/<name>&episode=<idx>` query.
        episodes: Episode indices listed in the episode selector; `None` lists all episodes.
        host: Interface the Flask development server binds to.
        port: TCP port the Flask development server listens on.
        static_folder: Directory served under `/static` (for local datasets,
            `visualize_dataset_html` symlinks the dataset's `videos` directory into it).
        template_folder: Directory holding the Jinja templates rendered by the routes.
    """
    app = Flask(__name__, static_folder=static_folder.resolve(), template_folder=template_folder.resolve())
    app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0  # specifying not to cache

    @app.route("/")
    def hommepage(dataset=dataset):
        # NOTE: the (misspelled) function name is also the Flask endpoint name; it is kept so
        # that any existing `url_for("hommepage")` keeps resolving.
        # `is not None` rather than truthiness: a `LeRobotDataset` with zero frames is falsy.
        if dataset is not None:
            dataset_namespace, dataset_name = dataset.repo_id.split("/")
            return redirect(
                url_for(
                    "show_episode",
                    dataset_namespace=dataset_namespace,
                    dataset_name=dataset_name,
                    episode_id=0,
                )
            )

        # Query parameters are user-controlled. An unparsable `episode` falls back to `None`
        # (-> first episode) and a malformed `dataset` is rejected instead of raising a 500.
        dataset_param = request.args.get("dataset")
        episode_param = request.args.get("episode", type=int)

        if dataset_param:
            parts = dataset_param.split("/")
            if len(parts) != 2 or not all(parts):
                return "Expected `dataset` in the form `<namespace>/<name>`.", 400
            dataset_namespace, dataset_name = parts
            return redirect(
                url_for(
                    "show_episode",
                    dataset_namespace=dataset_namespace,
                    dataset_name=dataset_name,
                    episode_id=episode_param if episode_param is not None else 0,
                )
            )

        featured_datasets = [
            "lerobot/aloha_static_cups_open",
            "lerobot/columbia_cairlab_pusht_real",
            "lerobot/taco_play",
        ]
        return render_template(
            "visualize_dataset_homepage.html",
            featured_datasets=featured_datasets,
            lerobot_datasets=available_datasets,
        )

    @app.route("/<string:dataset_namespace>/<string:dataset_name>")
    def show_first_episode(dataset_namespace, dataset_name):
        first_episode_id = 0
        return redirect(
            url_for(
                "show_episode",
                dataset_namespace=dataset_namespace,
                dataset_name=dataset_name,
                episode_id=first_episode_id,
            )
        )

    @app.route("/<string:dataset_namespace>/<string:dataset_name>/episode_<int:episode_id>")
    def show_episode(dataset_namespace, dataset_name, episode_id, dataset=dataset, episodes=episodes):
        repo_id = f"{dataset_namespace}/{dataset_name}"
        if dataset is None:
            try:
                dataset = get_dataset_info(repo_id)
            except (FileNotFoundError, requests.HTTPError) as err:
                # `get_dataset_info` surfaces a missing `meta/info.json` as an HTTP 404;
                # presumably the dataset predates the v2 layout, which this message asks the
                # user to fix. Any other HTTP failure (auth, server error) is re-raised.
                if isinstance(err, requests.HTTPError) and (
                    err.response is None or err.response.status_code != 404
                ):
                    raise
                return (
                    "Make sure to convert your LeRobotDataset to v2 & above. See how to convert your dataset at https://github.com/huggingface/lerobot/pull/461",
                    400,
                )

        is_local = isinstance(dataset, LeRobotDataset)

        dataset_version = str(dataset.meta._version) if is_local else dataset.codebase_version
        match = re.search(r"v(\d+)\.", dataset_version)
        if match and int(match.group(1)) < 2:
            return "Make sure to convert your LeRobotDataset to v2 & above.", 400

        num_episodes = dataset.num_episodes if is_local else dataset.total_episodes
        if episode_id >= num_episodes:
            return f"Episode {episode_id} does not exist in {repo_id} ({num_episodes} episodes).", 404

        episode_data_csv_str, columns, ignored_columns = get_episode_data(dataset, episode_id)
        dataset_info = {
            "repo_id": repo_id,
            "num_samples": dataset.num_frames if is_local else dataset.total_frames,
            "num_episodes": num_episodes,
            "fps": dataset.fps,
        }

        if is_local:
            video_paths = [
                dataset.meta.get_video_file_path(episode_id, key) for key in dataset.meta.video_keys
            ]
            videos_info = [
                {"url": url_for("static", filename=video_path), "filename": video_path.parent.name}
                for video_path in video_paths
            ]
            tasks = dataset.meta.episodes[episode_id]["tasks"]
        else:
            video_keys = [key for key, ft in dataset.features.items() if ft["dtype"] == "video"]
            videos_info = [
                {
                    "url": f"https://huggingface.co/datasets/{repo_id}/resolve/main/"
                    + dataset.video_path.format(
                        episode_chunk=int(episode_id) // dataset.chunks_size,
                        video_key=video_key,
                        episode_index=episode_id,
                    ),
                    "filename": video_key,
                }
                for video_key in video_keys
            ]
            tasks = _remote_episode_tasks(repo_id, episode_id)

        # The tasks are attached to the first video entry (presumably where the template shows
        # them); datasets without video features have no entry to attach them to.
        if videos_info:
            videos_info[0]["language_instruction"] = tasks

        if episodes is None:
            episodes = list(range(num_episodes))

        return render_template(
            "visualize_dataset_template.html",
            episode_id=episode_id,
            episodes=episodes,
            dataset_info=dataset_info,
            videos_info=videos_info,
            episode_data_csv_str=episode_data_csv_str,
            columns=columns,
            ignored_columns=ignored_columns,
        )

    app.run(host=host, port=port)
228
+
229
+
230
def get_ep_csv_fname(episode_id: int):
    """Return the CSV file name associated with episode `episode_id` (e.g. `episode_3.csv`)."""
    return "episode_{}.csv".format(episode_id)
233
+
234
+
235
def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index):
    """Build a CSV string with the numeric time series of one episode (e.g. state and action).

    The CSV is loaded by the Dygraph JavaScript plot of the HTML viewer. Only features whose
    dtype is `float32` or `int32` are exported; among those, features whose shape has more
    than one dimension cannot be flattened into plot lines and are reported as ignored.
    The first CSV column is always `timestamp`.

    Args:
        dataset: A local `LeRobotDataset` (rows are read from its `hf_dataset`) or an
            `IterableNamespace` built from the Hub `meta/info.json` (the episode parquet file
            is then read from the Hugging Face Hub).
        episode_index: Index of the episode to export.

    Returns:
        A tuple `(csv_string, columns, ignored_columns)`: `columns` is a list of
        `{"key": feature_name, "value": [per-dimension names]}` dicts in CSV order, and
        `ignored_columns` lists the multi-dimensional features that were skipped.
    """
    # Numeric features are plot candidates; `timestamp` is removed here because it is always
    # emitted separately as the first CSV column.
    numeric_columns = [col for col, ft in dataset.features.items() if ft["dtype"] in ["float32", "int32"]]
    numeric_columns.remove("timestamp")

    # Partition into two fresh lists instead of calling `remove` on the list being iterated,
    # which silently skips the element following each removed one.
    selected_columns = []
    ignored_columns = []
    for column_name in numeric_columns:
        if len(dataset.features[column_name]["shape"]) > 1:
            ignored_columns.append(column_name)
        else:
            selected_columns.append(column_name)

    # CSV header: `timestamp`, then one name per dimension of every selected feature.
    header = ["timestamp"]
    columns = []
    for column_name in selected_columns:
        dim_state = (
            dataset.meta.shapes[column_name][0]
            if isinstance(dataset, LeRobotDataset)
            else dataset.features[column_name].shape[0]
        )

        feature = dataset.features[column_name]
        if "names" in feature and feature["names"]:
            column_names = feature["names"]
            # `names` may be a (possibly nested) mapping: descend into its first value until a
            # plain list of names is reached.
            while not isinstance(column_names, list):
                column_names = list(column_names.values())[0]
        else:
            column_names = [f"{column_name}_{i}" for i in range(dim_state)]
        columns.append({"key": column_name, "value": column_names})

        header += column_names

    selected_columns = ["timestamp", *selected_columns]

    if isinstance(dataset, LeRobotDataset):
        from_idx = dataset.episode_data_index["from"][episode_index]
        to_idx = dataset.episode_data_index["to"][episode_index]
        data = (
            dataset.hf_dataset.select(range(from_idx, to_idx))
            .select_columns(selected_columns)
            .with_format("pandas")
        )
    else:
        url = f"https://huggingface.co/datasets/{dataset.repo_id}/resolve/main/" + dataset.data_path.format(
            episode_chunk=int(episode_index) // dataset.chunks_size, episode_index=episode_index
        )
        data = pd.read_parquet(url)[selected_columns]

    # One row per frame: the timestamp followed by the values of every selected feature,
    # each feature column stacked into a (frames, dims) block.
    rows = np.hstack(
        (
            np.expand_dims(data["timestamp"], axis=1),
            *[np.vstack(data[col]) for col in selected_columns[1:]],
        )
    ).tolist()

    csv_buffer = StringIO()
    csv_writer = csv.writer(csv_buffer)
    csv_writer.writerow(header)
    csv_writer.writerows(rows)
    return csv_buffer.getvalue(), columns, ignored_columns
307
+
308
+
309
def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]:
    """Return, for every video key of the dataset, the video file path of episode `ep_index`.

    The path is taken from the `path` field of the episode's first frame, a shortcut that
    avoids a separate metadata lookup.
    """
    start = dataset.episode_data_index["from"][ep_index].item()
    paths = []
    for video_key in dataset.meta.video_keys:
        first_frame = dataset.hf_dataset.select_columns(video_key)[start]
        paths.append(first_frame[video_key]["path"])
    return paths
316
+
317
+
318
def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) -> str | None:
    """Return the language instruction of episode `ep_index`, or `None` if the dataset has none.

    The instruction is read from the episode's first frame. Some openx strings were stored
    with a `tf.Tensor(b'...', shape=(), dtype=string)` wrapper around the sentence; that
    wrapper is stripped.
    """
    # check if the dataset has language instructions
    if "language_instruction" not in dataset.features:
        return None

    # get first frame index
    first_frame_idx = dataset.episode_data_index["from"][ep_index].item()

    language_instruction = dataset.hf_dataset[first_frame_idx]["language_instruction"]
    # TODO (michel-aractingi) hack to get the sentence, some strings in openx are badly stored
    # with the tf.tensor appearing in the string
    return language_instruction.removeprefix("tf.Tensor(b'").removesuffix("', shape=(), dtype=string)")
330
+
331
+
332
def get_dataset_info(repo_id: str) -> IterableNamespace:
    """Download `meta/info.json` of a Hub dataset and wrap it in an `IterableNamespace`.

    The namespace exposes every top-level key of `info.json`, plus `repo_id`.

    Raises:
        requests.HTTPError: if the Hub answers with an error status (e.g. the file is missing).
    """
    info_url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json"
    reply = requests.get(info_url, timeout=5)
    reply.raise_for_status()
    info = {**reply.json(), "repo_id": repo_id}
    return IterableNamespace(info)
340
+
341
+
342
def visualize_dataset_html(
    dataset: LeRobotDataset | IterableNamespace | None,
    episodes: list[int] | None = None,
    output_dir: Path | None = None,
    serve: bool = True,
    host: str = "127.0.0.1",
    port: int = 9090,
    force_override: bool = False,
) -> Path | None:
    """Prepare the viewer's output directory and optionally serve the HTML visualization.

    Args:
        dataset: A local `LeRobotDataset`, an `IterableNamespace` from `get_dataset_info`
            (data is then fetched from the Hub), or `None` to let the user choose a dataset
            on the homepage.
        episodes: Episode indices offered in the viewer; `None` offers all of them. Ignored
            when `dataset` is `None`.
        output_dir: Directory whose `static` sub-folder is served by Flask. When `None`, a
            new directory is created with `tempfile.mkdtemp`; it is NOT deleted afterwards.
        serve: Whether to start the (blocking) Flask server.
        host: Interface the server binds to.
        port: Port the server listens on.
        force_override: Delete a user-provided `output_dir` first if it already exists.

    Returns:
        The output directory (when `serve` is true, only after the server stops).
    """
    init_logging()

    template_dir = Path(__file__).resolve().parent.parent / "templates"

    if output_dir is None:
        # `mkdtemp` creates the directory and leaves it on disk after the process exits.
        output_dir = Path(tempfile.mkdtemp(prefix="lerobot_visualize_dataset_"))
    else:
        output_dir = Path(output_dir)
        if output_dir.exists():
            if force_override:
                shutil.rmtree(output_dir)
            else:
                logging.info(f"Output directory already exists. Loading from it: '{output_dir}'")

    output_dir.mkdir(parents=True, exist_ok=True)

    static_dir = output_dir / "static"
    static_dir.mkdir(parents=True, exist_ok=True)

    if isinstance(dataset, LeRobotDataset):
        # Create a symlink from the dataset video folder containing mp4 files to the output
        # directory so that the http server can get access to the mp4 files.
        ln_videos_dir = static_dir / "videos"
        if ln_videos_dir.is_symlink() and not ln_videos_dir.exists():
            # Dangling link left by a previous run (e.g. the dataset was moved): `exists()`
            # reports False for it, yet `symlink_to` would fail with FileExistsError.
            ln_videos_dir.unlink()
        if not ln_videos_dir.exists():
            ln_videos_dir.symlink_to((dataset.root / "videos").resolve())

    if serve:
        run_server(
            dataset=dataset,
            # Without a pinned dataset the user picks one on the homepage, so an episode
            # selection cannot apply to it and is not forwarded.
            episodes=episodes if dataset is not None else None,
            host=host,
            port=port,
            static_folder=static_dir,
            template_folder=template_dir,
        )

    return output_dir
391
+
392
+
393
def create_mock_train_config() -> TrainPipelineConfig:
    """Build a placeholder `TrainPipelineConfig` required to construct a `LeRobotDataset`.

    No training happens in the visualizer; every field is a fixed default, and the dataset
    mixture is left empty because the dataset itself is supplied separately.

    Returns:
        TrainPipelineConfig: A mock config with default values.
    """
    settings = dict(
        dataset_mixture=DatasetMixtureConfig(),  # placeholder, the real dataset is given separately
        resolution=(224, 224),
        num_cams=2,
        max_state_dim=32,
        max_action_dim=32,
        action_chunk=50,
        loss_weighting={"MSE": 1, "CE": 1},
        num_workers=4,
        batch_size=8,
        steps=100_000,
        log_freq=200,
        save_checkpoint=True,
        save_freq=20_000,
        use_policy_training_preset=True,
        wandb=WandBConfig(),
    )
    return TrainPipelineConfig(**settings)
416
+
417
+
418
def main():
    """Command-line entry point: parse arguments, load the dataset and launch the viewer.

    With `--repo-id`, the dataset is loaded locally as a `LeRobotDataset` (default) or,
    with `--load-from-hf-hub 1`, described by its Hub `meta/info.json` so that data and
    videos are fetched from the Hub. Without `--repo-id`, the server starts on a homepage
    where a Hub dataset can be chosen.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--repo-id",
        type=str,
        default=None,
        help="Name of hugging face repository containing a LeRobotDataset dataset (e.g. `lerobot/pusht` for https://huggingface.co/datasets/lerobot/pusht).",
    )
    parser.add_argument(
        "--root",
        type=Path,
        default=None,
        help="Root directory for a dataset stored locally (e.g. `--root data`). By default, the dataset will be loaded from hugging face cache folder, or downloaded from the hub if available.",
    )
    parser.add_argument(
        "--load-from-hf-hub",
        type=int,
        default=0,
        help="Load videos and parquet files from HF Hub rather than local system.",
    )
    parser.add_argument(
        "--episodes",
        type=int,
        nargs="*",
        default=None,
        help="Episode indices to visualize (e.g. `0 1 5 6` to load episodes of index 0, 1, 5 and 6). By default loads all episodes.",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=None,
        help="Directory path to write html files and kickoff a web server. By default, a new temporary directory is created.",
    )
    parser.add_argument(
        "--serve",
        type=int,
        default=1,
        help="Launch web server.",
    )
    parser.add_argument(
        "--host",
        type=str,
        default="127.0.0.1",
        help="Web host used by the http server.",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=9090,
        help="Web port used by the http server.",
    )
    parser.add_argument(
        "--force-override",
        type=int,
        default=0,
        help="Delete the output directory if it exists already.",
    )

    parser.add_argument(
        "--tolerance-s",
        type=float,
        default=1e-4,
        help=(
            "Tolerance in seconds used to ensure data timestamps respect the dataset fps value. "
            "This argument is passed to the constructor of LeRobotDataset and maps to its tolerance_s constructor argument. "
            "If not given, defaults to 1e-4."
        ),
    )

    args = parser.parse_args()
    kwargs = vars(args)
    # Pop the dataset-loading options; the remaining keys are exactly the keyword arguments
    # accepted by `visualize_dataset_html`.
    repo_id = kwargs.pop("repo_id")
    load_from_hf_hub = kwargs.pop("load_from_hf_hub")
    root = kwargs.pop("root")
    tolerance_s = kwargs.pop("tolerance_s")

    dataset = None
    if repo_id:
        dataset = (
            LeRobotDataset(create_mock_train_config(), repo_id, root=root, tolerance_s=tolerance_s)
            if not load_from_hf_hub
            else get_dataset_info(repo_id)
        )

    visualize_dataset_html(dataset, **kwargs)


if __name__ == "__main__":
    main()