opentau-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. opentau/__init__.py +179 -0
  2. opentau/__version__.py +24 -0
  3. opentau/configs/__init__.py +19 -0
  4. opentau/configs/default.py +297 -0
  5. opentau/configs/libero.py +113 -0
  6. opentau/configs/parser.py +393 -0
  7. opentau/configs/policies.py +297 -0
  8. opentau/configs/reward.py +42 -0
  9. opentau/configs/train.py +370 -0
  10. opentau/configs/types.py +76 -0
  11. opentau/constants.py +52 -0
  12. opentau/datasets/__init__.py +84 -0
  13. opentau/datasets/backward_compatibility.py +78 -0
  14. opentau/datasets/compute_stats.py +333 -0
  15. opentau/datasets/dataset_mixture.py +460 -0
  16. opentau/datasets/factory.py +232 -0
  17. opentau/datasets/grounding/__init__.py +67 -0
  18. opentau/datasets/grounding/base.py +154 -0
  19. opentau/datasets/grounding/clevr.py +110 -0
  20. opentau/datasets/grounding/cocoqa.py +130 -0
  21. opentau/datasets/grounding/dummy.py +101 -0
  22. opentau/datasets/grounding/pixmo.py +177 -0
  23. opentau/datasets/grounding/vsr.py +141 -0
  24. opentau/datasets/image_writer.py +304 -0
  25. opentau/datasets/lerobot_dataset.py +1910 -0
  26. opentau/datasets/online_buffer.py +442 -0
  27. opentau/datasets/push_dataset_to_hub/utils.py +132 -0
  28. opentau/datasets/sampler.py +99 -0
  29. opentau/datasets/standard_data_format_mapping.py +278 -0
  30. opentau/datasets/transforms.py +330 -0
  31. opentau/datasets/utils.py +1243 -0
  32. opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +887 -0
  33. opentau/datasets/v2/convert_dataset_v1_to_v2.py +829 -0
  34. opentau/datasets/v21/_remove_language_instruction.py +109 -0
  35. opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +60 -0
  36. opentau/datasets/v21/convert_dataset_v20_to_v21.py +183 -0
  37. opentau/datasets/v21/convert_stats.py +150 -0
  38. opentau/datasets/video_utils.py +597 -0
  39. opentau/envs/__init__.py +18 -0
  40. opentau/envs/configs.py +178 -0
  41. opentau/envs/factory.py +99 -0
  42. opentau/envs/libero.py +439 -0
  43. opentau/envs/utils.py +204 -0
  44. opentau/optim/__init__.py +16 -0
  45. opentau/optim/factory.py +43 -0
  46. opentau/optim/optimizers.py +121 -0
  47. opentau/optim/schedulers.py +140 -0
  48. opentau/planner/__init__.py +82 -0
  49. opentau/planner/high_level_planner.py +366 -0
  50. opentau/planner/utils/memory.py +64 -0
  51. opentau/planner/utils/utils.py +65 -0
  52. opentau/policies/__init__.py +24 -0
  53. opentau/policies/factory.py +172 -0
  54. opentau/policies/normalize.py +315 -0
  55. opentau/policies/pi0/__init__.py +19 -0
  56. opentau/policies/pi0/configuration_pi0.py +250 -0
  57. opentau/policies/pi0/modeling_pi0.py +994 -0
  58. opentau/policies/pi0/paligemma_with_expert.py +516 -0
  59. opentau/policies/pi05/__init__.py +20 -0
  60. opentau/policies/pi05/configuration_pi05.py +231 -0
  61. opentau/policies/pi05/modeling_pi05.py +1257 -0
  62. opentau/policies/pi05/paligemma_with_expert.py +572 -0
  63. opentau/policies/pretrained.py +315 -0
  64. opentau/policies/utils.py +123 -0
  65. opentau/policies/value/__init__.py +18 -0
  66. opentau/policies/value/configuration_value.py +170 -0
  67. opentau/policies/value/modeling_value.py +512 -0
  68. opentau/policies/value/reward.py +87 -0
  69. opentau/policies/value/siglip_gemma.py +221 -0
  70. opentau/scripts/actions_mse_loss.py +89 -0
  71. opentau/scripts/bin_to_safetensors.py +116 -0
  72. opentau/scripts/compute_max_token_length.py +111 -0
  73. opentau/scripts/display_sys_info.py +90 -0
  74. opentau/scripts/download_libero_benchmarks.py +54 -0
  75. opentau/scripts/eval.py +877 -0
  76. opentau/scripts/export_to_onnx.py +180 -0
  77. opentau/scripts/fake_tensor_training.py +87 -0
  78. opentau/scripts/get_advantage_and_percentiles.py +220 -0
  79. opentau/scripts/high_level_planner_inference.py +114 -0
  80. opentau/scripts/inference.py +70 -0
  81. opentau/scripts/launch_train.py +63 -0
  82. opentau/scripts/libero_simulation_parallel.py +356 -0
  83. opentau/scripts/libero_simulation_sequential.py +122 -0
  84. opentau/scripts/nav_high_level_planner_inference.py +61 -0
  85. opentau/scripts/train.py +379 -0
  86. opentau/scripts/visualize_dataset.py +294 -0
  87. opentau/scripts/visualize_dataset_html.py +507 -0
  88. opentau/scripts/zero_to_fp32.py +760 -0
  89. opentau/utils/__init__.py +20 -0
  90. opentau/utils/accelerate_utils.py +79 -0
  91. opentau/utils/benchmark.py +98 -0
  92. opentau/utils/fake_tensor.py +81 -0
  93. opentau/utils/hub.py +209 -0
  94. opentau/utils/import_utils.py +79 -0
  95. opentau/utils/io_utils.py +137 -0
  96. opentau/utils/libero.py +214 -0
  97. opentau/utils/libero_dataset_recorder.py +460 -0
  98. opentau/utils/logging_utils.py +180 -0
  99. opentau/utils/monkey_patch.py +278 -0
  100. opentau/utils/random_utils.py +244 -0
  101. opentau/utils/train_utils.py +198 -0
  102. opentau/utils/utils.py +471 -0
  103. opentau-0.1.0.dist-info/METADATA +161 -0
  104. opentau-0.1.0.dist-info/RECORD +108 -0
  105. opentau-0.1.0.dist-info/WHEEL +5 -0
  106. opentau-0.1.0.dist-info/entry_points.txt +2 -0
  107. opentau-0.1.0.dist-info/licenses/LICENSE +508 -0
  108. opentau-0.1.0.dist-info/top_level.txt +1 -0
opentau/configs/reward.py ADDED
@@ -0,0 +1,42 @@
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Reward configuration module.
+
+ This module provides the RewardConfig class which contains configuration
+ parameters for reward computation in reinforcement learning scenarios.
+ """
+
+ from dataclasses import dataclass
+
+
+ @dataclass
+ class RewardConfig:
+     """Configuration for reward computation settings.
+
+     This configuration is used for reward modeling and computation in reinforcement
+     learning scenarios.
+
+     Args:
+         number_of_bins: Number of bins used for reward discretization or binning.
+             Defaults to 201.
+         C_neg: Negative constant used in reward computation. Defaults to -1000.0.
+         reward_normalizer: Normalization factor for rewards. Defaults to 400.
+         N_steps_look_ahead: Number of steps to look ahead when computing rewards.
+             Defaults to 50.
+     """
+
+     number_of_bins: int = 201
+     C_neg: float = -1000.0
+     reward_normalizer: int = 400
+     N_steps_look_ahead: int = 50
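
A minimal usage sketch of the RewardConfig dataclass added above (the override values in the second construction are illustrative, not defaults from the package):

from opentau.configs.reward import RewardConfig

# Bare construction is valid because every field has a default.
reward_cfg = RewardConfig()
assert reward_cfg.number_of_bins == 201 and reward_cfg.N_steps_look_ahead == 50

# Overriding a subset of fields, e.g. a finer reward discretization (illustrative values).
fine_cfg = RewardConfig(number_of_bins=401, reward_normalizer=800)
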
opentau/configs/train.py ADDED
@@ -0,0 +1,370 @@
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Training pipeline configuration module.
+
+ This module provides the TrainPipelineConfig class which contains all configuration
+ parameters needed to run a training pipeline, including dataset settings, policy
+ configuration, training hyperparameters, and evaluation settings.
+ """
+
+ import datetime as dt
+ import os
+ import sys
+ from dataclasses import dataclass, field
+ from pathlib import Path
+ from typing import Type
+
+ import draccus
+ from huggingface_hub import hf_hub_download
+ from huggingface_hub.errors import HfHubHTTPError
+
+ from opentau.configs import parser
+ from opentau.configs.default import DatasetMixtureConfig, EvalConfig, WandBConfig
+ from opentau.configs.policies import PreTrainedConfig
+ from opentau.envs.configs import EnvConfig
+ from opentau.optim import OptimizerConfig
+ from opentau.optim.schedulers import LRSchedulerConfig
+ from opentau.utils.hub import HubMixin
+
+ TRAIN_CONFIG_NAME = "train_config.json"
+
+
+ # Somehow, calling `logging.warning()` sets the logger level to WARNING.
+ # We print directly to stderr instead.
+ def warn(*args, **kwargs):
+     """Print a warning message to stderr.
+
+     This function is used instead of logging.warning() to avoid setting the logger
+     level to WARNING.
+
+     Args:
+         *args: Variable length argument list to print.
+         **kwargs: Arbitrary keyword arguments passed to print().
+     """
+     print("WARNING:", *args, **kwargs, file=sys.stderr)
+
+
+ @dataclass
+ class TrainPipelineConfig(HubMixin):
+     """Configuration for the training pipeline.
+
+     This class contains all configuration parameters needed to run a training
+     pipeline, including dataset settings, policy configuration, training hyperparameters,
+     and evaluation settings.
+
+     Args:
+         dataset_mixture: Configuration for the dataset mixture to use during training.
+         policy: Configuration for the policy model. If None, must be set via CLI or
+             from a pretrained checkpoint.
+         output_dir: Directory where all run outputs will be saved. If another training
+             session uses the same directory, its contents will be overwritten unless
+             `resume` is set to True.
+         job_name: Name identifier for the training job. If not provided, defaults to
+             the policy type.
+         resume: If True, resume a previous run. Requires `output_dir` to point to
+             an existing run directory with at least one checkpoint. When resuming,
+             the configuration from the checkpoint is used by default, regardless of
+             command-line arguments.
+         seed: Random seed used for training (model initialization, dataset shuffling)
+             and for evaluation environments. Defaults to 1000.
+         resolution: Resolution of images (height, width) in data samples. Defaults to (224, 224).
+         num_cams: Number of cameras for the cloud VLM in each data sample. Defaults to 2.
+         max_state_dim: Maximum dimension of the state vector. Defaults to 32.
+         max_action_dim: Maximum dimension of the action vector. Defaults to 32.
+         action_chunk: Size of action chunk. Defaults to 50.
+         loss_weighting: Dictionary mapping loss type names to their weights.
+             Defaults to {"MSE": 1, "CE": 1}.
+         num_workers: Number of workers for the dataloader. Defaults to 4.
+         batch_size: Total batch size for training. If None, calculated from
+             `dataloader_batch_size * gradient_accumulation_steps`.
+         gradient_accumulation_steps: Number of gradient accumulation steps.
+             Defaults to 1.
+         dataloader_batch_size: Batch size used by the dataloader. If None, calculated
+             from `batch_size // gradient_accumulation_steps`.
+         prefetch_factor: Prefetch factor for the dataloader. If None, uses default.
+         steps: Total number of training steps. Defaults to 100,000.
+         log_freq: Frequency of logging in training iterations. Defaults to 200.
+         save_checkpoint: Whether to save checkpoints during training. Defaults to True.
+         save_freq: Frequency of checkpoint saving in training iterations. Checkpoints
+             are saved every `save_freq` steps and after the last training step.
+             Defaults to 20,000.
+         use_policy_training_preset: If True, use optimizer and scheduler presets from
+             the policy configuration. Defaults to False.
+         optimizer: Configuration for the optimizer. Required if
+             `use_policy_training_preset` is False.
+         scheduler: Configuration for the learning rate scheduler. Required if
+             `use_policy_training_preset` is False.
+         wandb: Configuration for Weights & Biases logging. Defaults to WandBConfig().
+         debug: If True, set logging level to DEBUG. Defaults to False.
+         trace_nans: Enable anomaly detection for debugging NaN/Inf values.
+             Warning: causes large computational overhead. Defaults to False.
+         env: Optional environment configuration for evaluation. Defaults to None.
+         eval: Configuration for evaluation settings. Defaults to EvalConfig().
+         eval_freq: Frequency of evaluation in training steps. If 0, evaluation
+             is disabled. Defaults to 0.
+         last_checkpoint_only: If True, only evaluate the last checkpoint.
+             Defaults to True.
+     """
+
+     dataset_mixture: DatasetMixtureConfig
+     policy: PreTrainedConfig | None = None
+     # Set `dir` to where you would like to save all of the run outputs. If you run another training session
+     # with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
+     output_dir: Path | None = None
+     job_name: str | None = None
+     # Set `resume` to true to resume a previous run. In order for this to work, you will need to make sure
+     # `dir` is the directory of an existing run with at least one checkpoint in it.
+     # Note that when resuming a run, the default behavior is to use the configuration from the checkpoint,
+     # regardless of what's provided with the training command at the time of resumption.
+     resume: bool = False
+     # `seed` is used for training (eg: model initialization, dataset shuffling)
+     # AND for the evaluation environments.
+     seed: int | None = 1000
+     # parameters for the Standard Data Format
+     resolution: tuple[int, int] = (224, 224)  # resolution of images (H, W) in data sample
+     num_cams: int = 2  # number of cameras for the cloud VLM in each data sample
+     max_state_dim: int = 32  # maximum dimension of the state vector
+     max_action_dim: int = 32  # maximum dimension of the action vector
+     action_chunk: int = 50  # size of action chunk
+     loss_weighting: dict[str, float] = field(default_factory=lambda: {"MSE": 1, "CE": 1})
+     # Number of workers for the dataloader.
+     num_workers: int = 4
+     batch_size: int | None = None
+     gradient_accumulation_steps: int = 1
+     dataloader_batch_size: int | None = None
+     # Prefetch factor for the dataloader.
+     prefetch_factor: int | None = None
+     steps: int = 100_000
+     log_freq: int = 200
+     save_checkpoint: bool = True
+     # Checkpoint is saved every `save_freq` training iterations and after the last training step.
+     save_freq: int = 20_000
+     use_policy_training_preset: bool = False
+     optimizer: OptimizerConfig | None = None
+     scheduler: LRSchedulerConfig | None = None
+     wandb: WandBConfig = field(default_factory=WandBConfig)
+     # Whether to set the logging level to DEBUG. By default, the logging level will be INFO.
+     debug: bool = False
+     # Enable anomaly detection for debugging NaN/Inf values (warning: large computational overhead)
+     trace_nans: bool = False
+     # optional environment and evaluation config for evaluation
+     env: EnvConfig | None = None
+     eval: EvalConfig | None = field(default_factory=EvalConfig)
+     eval_freq: int = 0  # evaluate every eval_freq steps
+     last_checkpoint_only: bool = True
+
+     def __post_init__(self):
+         """Initialize post-creation attributes and validate batch size configuration."""
+         self.checkpoint_path = None
+
+         if self.dataloader_batch_size is None and self.batch_size is None:
+             raise ValueError("At least one of `batch_size` and `dataloader_batch_size` should be set.")
+         if self.batch_size is None:
+             self.batch_size = self.dataloader_batch_size * self.gradient_accumulation_steps
+         if self.dataloader_batch_size is None:
+             if self.batch_size % self.gradient_accumulation_steps != 0:
+                 raise ValueError(
+                     "`batch_size` must be divisible by `gradient_accumulation_steps` "
+                     "when `dataloader_batch_size` is not set. "
+                     f"Got {self.batch_size=}, {self.gradient_accumulation_steps=}."
+                 )
+             self.dataloader_batch_size = self.batch_size // self.gradient_accumulation_steps
+         if self.dataloader_batch_size * self.gradient_accumulation_steps != self.batch_size:
+             raise ValueError(
+                 "`batch_size` must be equal to `dataloader_batch_size * gradient_accumulation_steps`. "
+                 f"Got {self.batch_size=}, {self.dataloader_batch_size=}, {self.gradient_accumulation_steps=}."
+             )
+         assert (
+             self.batch_size >= 1 and self.gradient_accumulation_steps >= 1 and self.dataloader_batch_size >= 1
+         )
+
+         if self.policy:
+             self.policy.max_state_dim = self.max_state_dim
+             self.policy.max_action_state = self.max_action_dim
+             self.policy.chunk_size = self.action_chunk
+         if self.job_name:
+             warn(
+                 "cfg.job_name is deprecated and ignored. Set cfg.wandb.project and/or cfg.wandb.name instead."
+             )
+
+     def validate(self):
+         """Validate and finalize the training configuration.
+
+         This method performs several validation and setup tasks:
+         - Loads policy configuration from CLI arguments or pretrained path if specified
+         - Sets up checkpoint paths for resuming training
+         - Validates output directory and creates default if needed
+         - Sets up optimizer and scheduler from presets if enabled
+         - Updates policy configuration with training parameters
+
+         Raises:
+             ValueError: If required configurations are missing or invalid.
+             FileExistsError: If output directory exists and resume is False.
+             NotADirectoryError: If config_path for resuming doesn't exist locally.
+         """
+         # HACK: We parse again the cli args here to get the pretrained paths if there was some.
+         policy_path = parser.get_path_arg("policy")
+         if policy_path:
+             # Only load the policy config
+             cli_overrides = parser.get_cli_overrides("policy")
+             self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
+             self.policy.pretrained_path = policy_path
+         elif self.resume:
+             # The entire train config is already loaded, we just need to get the checkpoint dir
+             config_path = parser.parse_arg("config_path")
+             if not config_path:
+                 raise ValueError(
+                     f"A config_path is expected when resuming a run. Please specify path to {TRAIN_CONFIG_NAME}"
+                 )
+             if not Path(config_path).resolve().exists():
+                 raise NotADirectoryError(
+                     f"{config_path=} is expected to be a local path. "
+                     "Resuming from the hub is not supported for now."
+                 )
+             policy_path = Path(config_path).parent
+             self.policy.pretrained_path = policy_path
+             self.checkpoint_path = policy_path
+
+         if not self.job_name:
+             self.job_name = f"{self.policy.type}"
+
+         if not self.resume and isinstance(self.output_dir, Path) and self.output_dir.is_dir():
+             raise FileExistsError(
+                 f"Output directory {self.output_dir} already exists and resume is {self.resume}. "
+                 f"Please change your output directory so that {self.output_dir} is not overwritten."
+             )
+         elif not self.output_dir:
+             now = dt.datetime.now()
+             train_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}"
+             self.output_dir = Path("outputs/train") / train_dir
+
+         if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None):
+             raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.")
+         elif self.use_policy_training_preset and not self.resume:
+             self.optimizer = self.policy.get_optimizer_preset()
+             self.scheduler = self.policy.get_scheduler_preset()
+
+         if self.policy:
+             self.policy.max_state_dim = self.max_state_dim
+             self.policy.max_action_state = self.max_action_dim
+             self.policy.chunk_size = self.action_chunk
+
+     @classmethod
+     def __get_path_fields__(cls) -> list[str]:
+         """Get list of field names that support path-based loading.
+
+         This enables the parser to load config from the policy using
+         `--policy.path=local/dir`.
+
+         Returns:
+             List of field names that support path-based configuration loading.
+         """
+         return ["policy"]
+
+     def to_dict(self) -> dict:
+         """Convert the configuration to a dictionary.
+
+         Returns:
+             Dictionary representation of the configuration.
+         """
+         return draccus.encode(self)
+
+     def _save_pretrained(self, save_directory: Path) -> None:
+         """Save the configuration to a directory.
+
+         Args:
+             save_directory: Directory path where the configuration will be saved.
+         """
+         with open(save_directory / TRAIN_CONFIG_NAME, "w") as f, draccus.config_type("json"):
+             draccus.dump(self, f, indent=4)
+
+     @classmethod
+     def from_pretrained(
+         cls: Type["TrainPipelineConfig"],
+         pretrained_name_or_path: str | Path,
+         *,
+         force_download: bool = False,
+         resume_download: bool = None,
+         proxies: dict | None = None,
+         token: str | bool | None = None,
+         cache_dir: str | Path | None = None,
+         local_files_only: bool = False,
+         revision: str | None = None,
+         **kwargs,
+     ) -> "TrainPipelineConfig":
+         """Load a training configuration from a pretrained model or local path.
+
+         Args:
+             cls: The class to instantiate.
+             pretrained_name_or_path: Can be either:
+
+                 - A string, the model id of a pretrained config hosted inside a model
+                   repo on huggingface.co.
+                 - A path to a directory containing a configuration file saved using
+                   the `_save_pretrained` method.
+                 - A path to a saved configuration JSON file.
+             force_download: Whether to force (re-)downloading the config files and
+                 configuration from the HuggingFace Hub. Defaults to False.
+             resume_download: Whether to resume downloading the config files.
+                 Defaults to None.
+             proxies: Dictionary of proxies to use for requests. Defaults to None.
+             token: The token to use as HTTP bearer authorization. If True, will use
+                 the token generated when running `huggingface-cli login`. Defaults to None.
+             cache_dir: Path to a directory in which a downloaded pretrained model
+                 configuration should be cached. Defaults to None.
+             local_files_only: Whether to only look at local files (i.e., do not try
+                 to download the config). Defaults to False.
+             revision: The specific model version to use. It can be a branch name, a
+                 tag name, or a commit id. Defaults to None.
+             **kwargs: Additional keyword arguments passed to the parser.
+
+         Returns:
+             An instance of TrainPipelineConfig loaded from the specified path.
+
+         Raises:
+             FileNotFoundError: If the configuration file is not found on the
+                 HuggingFace Hub or in the local path.
+         """
+         model_id = str(pretrained_name_or_path)
+         config_file: str | None = None
+         if Path(model_id).is_dir():
+             if TRAIN_CONFIG_NAME in os.listdir(model_id):
+                 config_file = os.path.join(model_id, TRAIN_CONFIG_NAME)
+             else:
+                 print(f"{TRAIN_CONFIG_NAME} not found in {Path(model_id).resolve()}")
+         elif Path(model_id).is_file():
+             config_file = model_id
+         else:
+             try:
+                 config_file = hf_hub_download(
+                     repo_id=model_id,
+                     filename=TRAIN_CONFIG_NAME,
+                     revision=revision,
+                     cache_dir=cache_dir,
+                     force_download=force_download,
+                     proxies=proxies,
+                     resume_download=resume_download,
+                     token=token,
+                     local_files_only=local_files_only,
+                 )
+             except HfHubHTTPError as e:
+                 raise FileNotFoundError(
+                     f"{TRAIN_CONFIG_NAME} not found on the HuggingFace Hub in {model_id}"
+                 ) from e
+
+         cli_args = kwargs.pop("cli_args", [])
+         cfg = draccus.parse(cls, config_file, args=cli_args)
+
+         return cfg
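
The batch-size bookkeeping in `TrainPipelineConfig.__post_init__` above derives whichever of `batch_size` and `dataloader_batch_size` is left unset from the other and `gradient_accumulation_steps`. The standalone sketch below mirrors that arithmetic without constructing the full config (which would also require a DatasetMixtureConfig); the helper name is ours, not part of the package:

def derive_batch_sizes(batch_size=None, dataloader_batch_size=None, grad_accum=1):
    """Mirror of the derivation in TrainPipelineConfig.__post_init__ (sketch, not package code)."""
    if batch_size is None and dataloader_batch_size is None:
        raise ValueError("set at least one of batch_size / dataloader_batch_size")
    if batch_size is None:
        # Effective batch size is the dataloader batch times the accumulation steps.
        batch_size = dataloader_batch_size * grad_accum
    if dataloader_batch_size is None:
        if batch_size % grad_accum != 0:
            raise ValueError("batch_size must be divisible by gradient_accumulation_steps")
        dataloader_batch_size = batch_size // grad_accum
    # Final cross-check, as in the config's last ValueError branch.
    assert dataloader_batch_size * grad_accum == batch_size
    return batch_size, dataloader_batch_size

# batch_size=32 with 4 accumulation steps -> the dataloader yields batches of 8.
print(derive_batch_sizes(batch_size=32, grad_accum=4))            # (32, 8)
# Conversely, dataloader_batch_size=8 with 4 steps -> effective batch of 32.
print(derive_batch_sizes(dataloader_batch_size=8, grad_accum=4))  # (32, 8)
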
opentau/configs/types.py ADDED
@@ -0,0 +1,76 @@
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Type definitions for configuration classes.
+
+ This module provides type definitions used across configuration classes, including
+ enumerations for feature types and normalization modes, as well as protocol
+ definitions and dataclasses for policy features.
+ """
+
+ # Note: We subclass str so that serialization is straightforward
+ # https://stackoverflow.com/questions/24481852/serialising-an-enum-member-to-json
+ from dataclasses import dataclass
+ from enum import Enum
+ from typing import Any, Protocol
+
+
+ class FeatureType(str, Enum):
+     """Enumeration of feature types used in policy configurations."""
+
+     STATE = "STATE"
+     """Robot state features."""
+     VISUAL = "VISUAL"
+     """Visual/image features."""
+     ENV = "ENV"
+     """Environment state features."""
+     ACTION = "ACTION"
+     """Action features."""
+
+
+ class NormalizationMode(str, Enum):
+     """Enumeration of normalization modes for features."""
+
+     MIN_MAX = "MIN_MAX"
+     """Normalize using min-max scaling."""
+     MEAN_STD = "MEAN_STD"
+     """Normalize using mean and standard deviation."""
+     IDENTITY = "IDENTITY"
+     """No normalization (identity transformation)."""
+
+
+ class DictLike(Protocol):
+     """Protocol for dictionary-like objects that support item access.
+
+     This protocol defines the interface for objects that can be accessed
+     using dictionary-style indexing with the `[]` operator.
+     """
+
+     def __getitem__(self, key: Any) -> Any: ...
+
+
+ @dataclass
+ class PolicyFeature:
+     """Configuration for a policy feature.
+
+     This class describes a single feature used by a policy, including its
+     type and shape information.
+
+     Args:
+         type: The type of feature (STATE, VISUAL, ENV, or ACTION).
+         shape: The shape of the feature as a tuple.
+     """
+
+     type: FeatureType
+     shape: tuple
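
A short sketch of how these types compose when declaring policy features; the feature names and shapes below are illustrative only, and the JSON round-trip shows why the enums subclass `str`:

import json

from opentau.configs.types import FeatureType, NormalizationMode, PolicyFeature

# Hypothetical feature map for a policy; keys and shapes are made up for illustration.
features = {
    "observation.images.top": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
    "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(32,)),
    "action": PolicyFeature(type=FeatureType.ACTION, shape=(32,)),
}

# Because the enums subclass str, they serialize directly to their string values.
print(json.dumps({"mode": NormalizationMode.MEAN_STD}))  # {"mode": "MEAN_STD"}
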
opentau/constants.py ADDED
@@ -0,0 +1,52 @@
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Constants used throughout the OpenTau library.
+
+ This module defines key constants for:
+ - Observation and action keys used in datasets and environments
+ - File and directory names for checkpoints, training state, and model storage
+ - Cache directory configuration for HuggingFace Hub integration
+
+ These constants ensure consistent naming conventions across the codebase and
+ provide a centralized location for configuration values.
+ """
+
+ import os
+ from pathlib import Path
+
+ from huggingface_hub.constants import HF_HOME
+
+ OBS_STATE = "observation.state"
+ OBS_ENV = "observation.environment_state"  # TODO: remove
+ OBS_ROBOT = "state"  # TODO: remove
+ OBS_IMAGE = "observation.image"  # TODO: remove
+ OBS_IMAGES = "observation.images"  # TODO: remove
+ ACTION = "actions"  # TODO: remove
+ OBS_ENV_STATE = "observation.environment_state"
+
+ # files & directories
+ CHECKPOINTS_DIR = "checkpoints"
+ LAST_CHECKPOINT_LINK = "last"
+ PRETRAINED_MODEL_DIR = "pretrained_model"
+ TRAINING_STATE_DIR = "training_state"
+ RNG_STATE = "rng_state.safetensors"
+ TRAINING_STEP = "training_step.json"
+ OPTIMIZER_STATE = "optimizer_state.safetensors"
+ OPTIMIZER_PARAM_GROUPS = "optimizer_param_groups.json"
+ SCHEDULER_STATE = "scheduler_state.json"
+
+ # cache dir
+ default_cache_path = Path(HF_HOME) / "opentau"
+ HF_OPENTAU_HOME = Path(os.getenv("HF_OPENTAU_HOME", default_cache_path)).expanduser()
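
As a hedged illustration of how the checkpoint constants above might be composed into paths; the directory layout shown is an assumption inferred from the constant names, not confirmed by this diff:

from pathlib import Path

from opentau.constants import CHECKPOINTS_DIR, HF_OPENTAU_HOME, LAST_CHECKPOINT_LINK, PRETRAINED_MODEL_DIR

output_dir = Path("outputs/train/2026-01-01/12-00-00_pi0")  # hypothetical run directory
# Assumed layout: <output_dir>/checkpoints/last/pretrained_model
last_model_dir = output_dir / CHECKPOINTS_DIR / LAST_CHECKPOINT_LINK / PRETRAINED_MODEL_DIR
print(last_model_dir)

# Cache root resolves to $HF_OPENTAU_HOME if set, otherwise <HF_HOME>/opentau.
print(HF_OPENTAU_HOME)
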
opentau/datasets/__init__.py ADDED
@@ -0,0 +1,84 @@
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Dataset management and processing utilities for robot learning and vision-language tasks.
+
+ This module provides a comprehensive toolkit for loading, creating, managing, and
+ processing datasets for training vision-language-action (VLA) models. It supports
+ both robot learning datasets (with actions and states) and vision-language
+ grounding datasets (for multimodal understanding tasks).
+
+ The module is organized into several key components:
+
+ - **Core Datasets**: LeRobotDataset for robot learning data with support for
+   temporal alignment, multi-modal data, and version compatibility.
+ - **Grounding Datasets**: Vision-language datasets (CLEVR, COCO-QA, PIXMO, VSR)
+   for training visual understanding without robot actions.
+ - **Dataset Mixtures**: WeightedDatasetMixture for combining multiple datasets
+   with controlled sampling proportions.
+ - **Data Processing**: Utilities for statistics computation, image/video
+   handling, transforms, and format standardization.
+ - **Factory Functions**: High-level functions for creating datasets and mixtures
+   from configuration objects.
+
+ Key Features:
+
+ - **HuggingFace Integration**: Seamless loading from HuggingFace Hub with
+   automatic version checking and backward compatibility.
+ - **Temporal Alignment**: Delta timestamps enable sampling features at
+   different time offsets with optional Gaussian noise for data augmentation.
+ - **Multi-modal Support**: Handles images, videos, state vectors, actions,
+   and text prompts with automatic format conversion.
+ - **Weighted Sampling**: Combine heterogeneous datasets with configurable
+   sampling weights for balanced training.
+ - **Standard Data Format**: Unified data format across all datasets for
+   consistent model input/output interfaces.
+ - **Statistics Management**: Automatic computation and aggregation of dataset
+   statistics for normalization.
+ - **Video Handling**: Multiple video backends (torchcodec, pyav, video_reader)
+   for efficient frame extraction and encoding.
+ - **Asynchronous I/O**: High-performance image writing for real-time data
+   recording without blocking.
+
+ Main Modules:
+
+ - **lerobot_dataset**: Core dataset implementation for robot learning data.
+ - **grounding**: Vision-language grounding datasets (CLEVR, COCO-QA, PIXMO, VSR).
+ - **dataset_mixture**: Weighted combination of multiple datasets.
+ - **factory**: Factory functions for creating datasets from configurations.
+ - **utils**: Utility functions for I/O, metadata management, and validation.
+ - **compute_stats**: Statistics computation and aggregation utilities.
+ - **transforms**: Image transformation pipelines for data augmentation.
+ - **video_utils**: Video encoding, decoding, and metadata extraction.
+ - **image_writer**: Asynchronous image writing for high-frequency recording.
+ - **sampler**: Episode-aware sampling with boundary frame filtering.
+ - **standard_data_format_mapping**: Feature name and loss type mappings.
+
+ Example:
+     Create a dataset mixture from configuration:
+
+     >>> from opentau.datasets.factory import make_dataset_mixture
+     >>> mixture = make_dataset_mixture(train_cfg)
+     >>> dataloader = mixture.get_dataloader()
+
+     Load a single dataset:
+
+     >>> from opentau.datasets.factory import make_dataset
+     >>> dataset = make_dataset(dataset_cfg, train_cfg)
+
+     Access grounding datasets:
+
+     >>> from opentau import available_grounding_datasets
+     >>> print(list(available_grounding_datasets.keys()))
+     ['clevr', 'cocoqa', 'dummy', 'pixmo', 'vsr']
+ """