opentau 0.1.2__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -96,6 +96,11 @@ class DatasetConfig:
96
96
  data_features_name_mapping: dict[str, str] | None = None
97
97
  loss_type_mapping: str | None = None
98
98
 
99
+ # Ratio of the dataset to be used for validation. Please specify a value.
100
+ # If `val_freq` is set to 0, a validation dataset will not be created and this value will be ignored.
101
+ # Defaults to 0.05.
102
+ val_split_ratio: float = 0.05
103
+
99
104
  def __post_init__(self):
100
105
  """Validate dataset configuration and register custom mappings if provided."""
101
106
  if (self.repo_id is None) == (self.grounding is None):
@@ -148,6 +153,11 @@ class DatasetMixtureConfig:
148
153
  image_resample_strategy: str = "nearest"
149
154
  # Resample strategy for non-image features, such as action or state
150
155
  vector_resample_strategy: str = "nearest"
156
+ # Ratio of the dataset to be used for validation. Please specify a value.
157
+ # If `val_freq` is set to 0, a validation dataset will not be created and this value will be ignored.
158
+ # This value is applied to all datasets in the mixture.
159
+ # Defaults to 0.05.
160
+ val_split_ratio: float = 0.05
151
161
 
152
162
  def __post_init__(self):
153
163
  """Validate dataset mixture configuration."""
@@ -163,6 +173,12 @@ class DatasetMixtureConfig:
163
173
  raise ValueError(
164
174
  f"`vector_resample_strategy` must be one of ['linear', 'nearest'], got {self.vector_resample_strategy}."
165
175
  )
176
+ if self.val_split_ratio < 0 or self.val_split_ratio > 1:
177
+ raise ValueError(f"`val_split_ratio` must be between 0 and 1, got {self.val_split_ratio}.")
178
+
179
+ # set the val_split_ratio for all datasets in the mixture
180
+ for dataset_cfg in self.datasets:
181
+ dataset_cfg.val_split_ratio = self.val_split_ratio
166
182
 
167
183
 
168
184
  @dataclass
@@ -0,0 +1,85 @@
1
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Deployment configuration classes for inference servers.
15
+
16
+ This module provides configuration classes for deploying trained models
17
+ as inference servers, including gRPC server settings.
18
+ """
19
+
20
+ from dataclasses import dataclass
21
+
22
+
23
@dataclass
class ServerConfig:
    """Configuration for the gRPC inference server.

    This class contains all configuration parameters needed to run a gRPC
    inference server for robot policy models.

    Args:
        port: Port number to serve on. Must be between 1 and 65535.
            Defaults to 50051.
        max_workers: Maximum number of gRPC worker threads for handling
            concurrent requests. Must be at least 1. Defaults to 4.
        max_send_message_length_mb: Maximum size of outgoing messages in
            megabytes. Must be at least 1. Defaults to 100.
        max_receive_message_length_mb: Maximum size of incoming messages in
            megabytes. Must be at least 1. Defaults to 100.

    Raises:
        ValueError: If any field is not an integer, if `port` is outside
            1..65535, or if `max_workers`, `max_send_message_length_mb` or
            `max_receive_message_length_mb` is less than 1.

    Example:
        >>> config = ServerConfig(port=50051, max_workers=8)
        >>> config.port
        50051
    """

    port: int = 50051
    max_workers: int = 4
    max_send_message_length_mb: int = 100
    max_receive_message_length_mb: int = 100

    def __post_init__(self):
        """Validate server configuration parameters.

        Each field is type-checked before its range check so that a
        non-integer value is reported as a `ValueError` rather than slipping
        through (e.g. a float) or surfacing as a `TypeError` (e.g. a string).
        """
        self._require_int("port")
        if not 1 <= self.port <= 65535:
            raise ValueError(f"`port` must be between 1 and 65535, got {self.port}.")

        self._require_int("max_workers")
        if self.max_workers < 1:
            raise ValueError(f"`max_workers` must be at least 1, got {self.max_workers}.")

        self._require_int("max_send_message_length_mb")
        if self.max_send_message_length_mb < 1:
            raise ValueError(
                f"`max_send_message_length_mb` must be at least 1, got {self.max_send_message_length_mb}."
            )

        self._require_int("max_receive_message_length_mb")
        if self.max_receive_message_length_mb < 1:
            raise ValueError(
                f"`max_receive_message_length_mb` must be at least 1, got {self.max_receive_message_length_mb}."
            )

    def _require_int(self, name: str) -> None:
        """Raise `ValueError` unless the field called `name` holds a real integer."""
        value = getattr(self, name)
        # `bool` is a subclass of `int`; reject it explicitly so that
        # e.g. `port=True` is not silently accepted as port 1.
        if isinstance(value, bool) or not isinstance(value, int):
            raise ValueError(f"`{name}` must be an integer, got {value!r}.")

    @property
    def max_send_message_length(self) -> int:
        """Get maximum send message length in bytes.

        Returns:
            Maximum send message length in bytes.
        """
        return self.max_send_message_length_mb * 1024 * 1024

    @property
    def max_receive_message_length(self) -> int:
        """Get maximum receive message length in bytes.

        Returns:
            Maximum receive message length in bytes.
        """
        return self.max_receive_message_length_mb * 1024 * 1024
opentau/configs/train.py CHANGED
@@ -32,6 +32,7 @@ from huggingface_hub.errors import HfHubHTTPError
32
32
 
33
33
  from opentau.configs import parser
34
34
  from opentau.configs.default import DatasetMixtureConfig, EvalConfig, WandBConfig
35
+ from opentau.configs.deployment import ServerConfig
35
36
  from opentau.configs.policies import PreTrainedConfig
36
37
  from opentau.envs.configs import EnvConfig
37
38
  from opentau.optim import OptimizerConfig
@@ -116,6 +117,7 @@ class TrainPipelineConfig(HubMixin):
116
117
  is disabled. Defaults to 0.
117
118
  last_checkpoint_only: If True, only evaluate the last checkpoint.
118
119
  Defaults to True.
120
+ server: Configuration for the gRPC inference server. Defaults to ServerConfig().
119
121
  """
120
122
 
121
123
  dataset_mixture: DatasetMixtureConfig
@@ -163,7 +165,10 @@ class TrainPipelineConfig(HubMixin):
163
165
  env: EnvConfig | None = None
164
166
  eval: EvalConfig | None = field(default_factory=EvalConfig)
165
167
  eval_freq: int = 0 # evaluate every eval_freq steps
168
+ val_freq: int = 0 # validate every val_freq steps, if 0, then a validation split is not created
166
169
  last_checkpoint_only: bool = True
170
+ # gRPC inference server configuration
171
+ server: ServerConfig = field(default_factory=ServerConfig)
167
172
 
168
173
  def __post_init__(self):
169
174
  """Initialize post-creation attributes and validate batch size configuration."""
@@ -61,7 +61,11 @@ Example:
61
61
  >>> dataloader = mixture.get_dataloader()
62
62
  """
63
63
 
64
+ import copy
65
+ from typing import Tuple, Union
66
+
64
67
  import numpy as np
68
+ import torch
65
69
 
66
70
  # NOTE: Don't delete; imported for side effects.
67
71
  import opentau.datasets.grounding.clevr # noqa: F401
@@ -151,9 +155,13 @@ def make_dataset(
151
155
  cfg: DatasetConfig,
152
156
  train_cfg: TrainPipelineConfig,
153
157
  return_advantage_input: bool = False,
154
- ) -> BaseDataset:
158
+ ) -> Union[BaseDataset, Tuple[BaseDataset, BaseDataset]]:
155
159
  """Handles the logic of setting up delta timestamps and image transforms before creating a dataset.
156
160
 
161
+ A train and validation dataset are returned if `train_cfg.val_freq` is greater than 0.
162
+ The validation dataset is a subset of the train dataset, and is used for evaluation during training.
163
+ The validation dataset is created by splitting the train dataset into train and validation sets based on `cfg.val_split_ratio`.
164
+
157
165
  Args:
158
166
  cfg (DatasetConfig): A DatasetConfig used to create a LeRobotDataset.
159
167
  train_cfg (TrainPipelineConfig): A TrainPipelineConfig config which contains a DatasetConfig and a PreTrainedConfig.
@@ -161,10 +169,11 @@ def make_dataset(
161
169
  "episode_end_idx", "current_idx", "last_step", "episode_index", and "timestamp". Defaults to False.
162
170
 
163
171
  Raises:
164
- NotImplementedError: The MultiLeRobotDataset is currently deactivated.
172
+ ValueError: If neither or both of `cfg.grounding` and `cfg.repo_id` are provided (exactly one is required).
173
+ ValueError: If `cfg.grounding` is not a supported grounding dataset.
165
174
 
166
175
  Returns:
167
- BaseDataset
176
+ BaseDataset or Tuple[BaseDataset, BaseDataset]: A single dataset or a tuple of (train_dataset, val_dataset) if val_freq > 0.
168
177
  """
169
178
  image_transforms = ImageTransforms(cfg.image_transforms) if cfg.image_transforms.enable else None
170
179
 
@@ -209,12 +218,20 @@ def make_dataset(
209
218
  dataset.meta.stats[key] = {}
210
219
  dataset.meta.stats[key][stats_type] = np.array(stats, dtype=np.float32)
211
220
 
221
+ if train_cfg.val_freq > 0:
222
+ val_size = int(len(dataset) * cfg.val_split_ratio)
223
+ train_size = len(dataset) - val_size
224
+ train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
225
+ train_dataset.meta = copy.deepcopy(dataset.meta)
226
+ val_dataset.meta = copy.deepcopy(dataset.meta)
227
+ return train_dataset, val_dataset
228
+
212
229
  return dataset
213
230
 
214
231
 
215
232
  def make_dataset_mixture(
216
233
  cfg: TrainPipelineConfig, return_advantage_input: bool = False
217
- ) -> WeightedDatasetMixture:
234
+ ) -> Union[WeightedDatasetMixture, Tuple[WeightedDatasetMixture, WeightedDatasetMixture]]:
218
235
  """Creates a dataset mixture from the provided TrainPipelineConfig.
219
236
 
220
237
  Args:
@@ -223,10 +240,26 @@ def make_dataset_mixture(
223
240
  "episode_end_idx", "current_idx", "last_step", "episode_index", and "timestamp". Defaults to False.
224
241
 
225
242
  Returns:
226
- WeightedDatasetMixture: An instance of WeightedDatasetMixture containing the datasets.
243
+ WeightedDatasetMixture or Tuple[WeightedDatasetMixture, WeightedDatasetMixture]: An instance of WeightedDatasetMixture containing the datasets, or a tuple of (train_mixture, val_mixture) if val_freq > 0.
227
244
  """
228
- datasets = [
229
- make_dataset(dataset_cfg, cfg, return_advantage_input=return_advantage_input)
230
- for dataset_cfg in cfg.dataset_mixture.datasets
231
- ]
232
- return WeightedDatasetMixture(cfg, datasets, cfg.dataset_mixture.weights, cfg.dataset_mixture.action_freq)
245
+ datasets = []
246
+ val_datasets = []
247
+ for dataset_cfg in cfg.dataset_mixture.datasets:
248
+ res = make_dataset(dataset_cfg, cfg, return_advantage_input=return_advantage_input)
249
+ if isinstance(res, tuple):
250
+ datasets.append(res[0])
251
+ val_datasets.append(res[1])
252
+ else:
253
+ datasets.append(res)
254
+
255
+ train_mixture = WeightedDatasetMixture(
256
+ cfg, datasets, cfg.dataset_mixture.weights, cfg.dataset_mixture.action_freq
257
+ )
258
+
259
+ if val_datasets:
260
+ val_mixture = WeightedDatasetMixture(
261
+ cfg, val_datasets, cfg.dataset_mixture.weights, cfg.dataset_mixture.action_freq
262
+ )
263
+ return train_mixture, val_mixture
264
+
265
+ return train_mixture
@@ -150,6 +150,7 @@ from opentau.policies.value.configuration_value import ValueConfig
150
150
  from opentau.policies.value.reward import (
151
151
  calculate_return_bins_with_equal_width,
152
152
  )
153
+ from opentau.utils.accelerate_utils import get_proc_accelerator
153
154
  from opentau.utils.utils import on_accelerate_main_proc
154
155
 
155
156
 
@@ -324,8 +325,17 @@ class LeRobotDatasetMetadata(DatasetMetadata):
324
325
  if is_valid_version(self.revision):
325
326
  self.revision = get_safe_version(self.repo_id, self.revision)
326
327
 
327
- (self.root / "meta").mkdir(exist_ok=True, parents=True)
328
- self.pull_from_repo(allow_patterns="meta/")
328
+ # In distributed training, only rank 0 downloads to avoid race conditions
329
+ # where other ranks read metadata before the download has finished.
330
+ acc = get_proc_accelerator()
331
+ if acc is not None and acc.num_processes > 1:
332
+ if acc.is_main_process:
333
+ (self.root / "meta").mkdir(exist_ok=True, parents=True)
334
+ self.pull_from_repo(allow_patterns="meta/")
335
+ acc.wait_for_everyone()
336
+ else:
337
+ (self.root / "meta").mkdir(exist_ok=True, parents=True)
338
+ self.pull_from_repo(allow_patterns="meta/")
329
339
  self.load_metadata()
330
340
 
331
341
  def load_metadata(self) -> None:
@@ -725,14 +735,6 @@ class BaseDataset(torch.utils.data.Dataset):
725
735
  if isinstance(value, torch.Tensor) and value.dtype.is_floating_point:
726
736
  standard_item[key] = value.to(dtype=torch.bfloat16)
727
737
 
728
- # ensure that non-empty strings contain exactly one newline character at the end of the string
729
- for key in ["prompt", "response"]:
730
- if standard_item[key].endswith(
731
- "\n"
732
- ): # ensure there isn't going to be an extra space at the end after calling replace
733
- standard_item[key] = standard_item[key][:-1]
734
- standard_item[key] = standard_item[key].replace("\n", " ") + "\n"
735
-
736
738
  return standard_item
737
739
 
738
740
  def resize_with_pad(self, img, width, height, pad_value=0) -> torch.Tensor:
@@ -108,6 +108,7 @@ import pyarrow as pa
108
108
  import torch
109
109
  import torchvision
110
110
  from datasets.features.features import register_feature
111
+ from packaging import version
111
112
  from PIL import Image
112
113
 
113
114
 
@@ -117,13 +118,17 @@ def get_safe_default_codec() -> str:
117
118
  Returns:
118
119
  Backend name: "torchcodec" if available, otherwise "pyav".
119
120
  """
120
- if importlib.util.find_spec("torchcodec"):
121
- return "torchcodec"
122
- else:
123
- logging.warning(
124
- "'torchcodec' is not available in your platform, falling back to 'pyav' as a default decoder"
125
- )
121
+
122
+ if version.parse(torch.__version__) >= version.parse("2.8.0"):
126
123
  return "pyav"
124
+ else:
125
+ if importlib.util.find_spec("torchcodec"):
126
+ return "torchcodec"
127
+ else:
128
+ logging.warning(
129
+ "'torchcodec' is not available in your platform, falling back to 'pyav' as a default decoder"
130
+ )
131
+ return "pyav"
127
132
 
128
133
 
129
134
  def decode_video_frames(
@@ -49,18 +49,18 @@ class PI05Config(PreTrainedConfig):
49
49
  Defaults to identity for visual features and mean-std for state and action.
50
50
  max_state_dim: Maximum dimension for state vectors. Shorter vectors are padded. Defaults to 32.
51
51
  max_action_dim: Maximum dimension for action vectors. Shorter vectors are padded. Defaults to 32.
52
+ predict_response: Whether to predict the response. Defaults to False.
52
53
  resize_imgs_with_padding: Target size (height, width) for image resizing with padding.
53
54
  Defaults to (224, 224).
54
55
  empty_cameras: Number of empty camera inputs to add. Used for specific adaptations like
55
56
  Aloha simulation. Defaults to 0.
56
- tokenizer_max_length: Maximum length for tokenizer. Defaults to 256.
57
+ prompt_max_length: Maximum length for tokenizer. Defaults to 256.
57
58
  discrete_action_max_length: Maximum length for discrete action tokens. Defaults to 32.
58
59
  proj_width: Width of the projection layer. Defaults to 1024.
59
60
  dropout: Dropout rate. Defaults to 0.1.
60
61
  num_steps: Number of flow matching steps for decoding. Defaults to 10.
61
62
  init_strategy: Initialization strategy. One of "no_init", "full_he_init", "expert_only_he_init".
62
63
  Defaults to "full_he_init".
63
- use_cache: Whether to use KV cache during inference. Defaults to True.
64
64
  attention_implementation: Attention implementation to use ("eager" or "fa2"). Defaults to "eager".
65
65
  freeze_vision_encoder: Whether to freeze the vision encoder during fine-tuning. Defaults to True.
66
66
  train_expert_only: Whether to train only the expert module. Defaults to False.
@@ -89,6 +89,7 @@ class PI05Config(PreTrainedConfig):
89
89
  # Shorter state and action vectors will be padded
90
90
  max_state_dim: int = 32
91
91
  max_action_dim: int = 32
92
+ predict_response: bool = False
92
93
 
93
94
  # Image preprocessing
94
95
  resize_imgs_with_padding: tuple[int, int] = (224, 224)
@@ -97,8 +98,11 @@ class PI05Config(PreTrainedConfig):
97
98
  # left and right wrist cameras in addition to the top camera.
98
99
  empty_cameras: int = 0
99
100
 
100
- # Tokenizer
101
- tokenizer_max_length: int = 256
101
+ # Language Tokenizer
102
+ prompt_max_length: int = 256
103
+
104
+ # Response Tokenizer
105
+ response_max_length: int = 52
102
106
 
103
107
  # Maximum length of the action tokens
104
108
  discrete_action_max_length: int = 32
@@ -116,8 +120,7 @@ class PI05Config(PreTrainedConfig):
116
120
  init_strategy: Literal["no_init", "full_he_init", "expert_only_he_init"] = "full_he_init"
117
121
 
118
122
  # Attention utils
119
- use_cache: bool = True
120
- attention_implementation: str = "eager" # or fa2
123
+ attention_implementation: str = "eager"
121
124
 
122
125
  # Finetuning settings
123
126
  freeze_vision_encoder: bool = True