opentau 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. opentau/__init__.py +179 -0
  2. opentau/__version__.py +24 -0
  3. opentau/configs/__init__.py +19 -0
  4. opentau/configs/default.py +297 -0
  5. opentau/configs/libero.py +113 -0
  6. opentau/configs/parser.py +393 -0
  7. opentau/configs/policies.py +297 -0
  8. opentau/configs/reward.py +42 -0
  9. opentau/configs/train.py +370 -0
  10. opentau/configs/types.py +76 -0
  11. opentau/constants.py +52 -0
  12. opentau/datasets/__init__.py +84 -0
  13. opentau/datasets/backward_compatibility.py +78 -0
  14. opentau/datasets/compute_stats.py +333 -0
  15. opentau/datasets/dataset_mixture.py +460 -0
  16. opentau/datasets/factory.py +232 -0
  17. opentau/datasets/grounding/__init__.py +67 -0
  18. opentau/datasets/grounding/base.py +154 -0
  19. opentau/datasets/grounding/clevr.py +110 -0
  20. opentau/datasets/grounding/cocoqa.py +130 -0
  21. opentau/datasets/grounding/dummy.py +101 -0
  22. opentau/datasets/grounding/pixmo.py +177 -0
  23. opentau/datasets/grounding/vsr.py +141 -0
  24. opentau/datasets/image_writer.py +304 -0
  25. opentau/datasets/lerobot_dataset.py +1910 -0
  26. opentau/datasets/online_buffer.py +442 -0
  27. opentau/datasets/push_dataset_to_hub/utils.py +132 -0
  28. opentau/datasets/sampler.py +99 -0
  29. opentau/datasets/standard_data_format_mapping.py +278 -0
  30. opentau/datasets/transforms.py +330 -0
  31. opentau/datasets/utils.py +1243 -0
  32. opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +887 -0
  33. opentau/datasets/v2/convert_dataset_v1_to_v2.py +829 -0
  34. opentau/datasets/v21/_remove_language_instruction.py +109 -0
  35. opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +60 -0
  36. opentau/datasets/v21/convert_dataset_v20_to_v21.py +183 -0
  37. opentau/datasets/v21/convert_stats.py +150 -0
  38. opentau/datasets/video_utils.py +597 -0
  39. opentau/envs/__init__.py +18 -0
  40. opentau/envs/configs.py +178 -0
  41. opentau/envs/factory.py +99 -0
  42. opentau/envs/libero.py +439 -0
  43. opentau/envs/utils.py +204 -0
  44. opentau/optim/__init__.py +16 -0
  45. opentau/optim/factory.py +43 -0
  46. opentau/optim/optimizers.py +121 -0
  47. opentau/optim/schedulers.py +140 -0
  48. opentau/planner/__init__.py +82 -0
  49. opentau/planner/high_level_planner.py +366 -0
  50. opentau/planner/utils/memory.py +64 -0
  51. opentau/planner/utils/utils.py +65 -0
  52. opentau/policies/__init__.py +24 -0
  53. opentau/policies/factory.py +172 -0
  54. opentau/policies/normalize.py +315 -0
  55. opentau/policies/pi0/__init__.py +19 -0
  56. opentau/policies/pi0/configuration_pi0.py +250 -0
  57. opentau/policies/pi0/modeling_pi0.py +994 -0
  58. opentau/policies/pi0/paligemma_with_expert.py +516 -0
  59. opentau/policies/pi05/__init__.py +20 -0
  60. opentau/policies/pi05/configuration_pi05.py +231 -0
  61. opentau/policies/pi05/modeling_pi05.py +1257 -0
  62. opentau/policies/pi05/paligemma_with_expert.py +572 -0
  63. opentau/policies/pretrained.py +315 -0
  64. opentau/policies/utils.py +123 -0
  65. opentau/policies/value/__init__.py +18 -0
  66. opentau/policies/value/configuration_value.py +170 -0
  67. opentau/policies/value/modeling_value.py +512 -0
  68. opentau/policies/value/reward.py +87 -0
  69. opentau/policies/value/siglip_gemma.py +221 -0
  70. opentau/scripts/actions_mse_loss.py +89 -0
  71. opentau/scripts/bin_to_safetensors.py +116 -0
  72. opentau/scripts/compute_max_token_length.py +111 -0
  73. opentau/scripts/display_sys_info.py +90 -0
  74. opentau/scripts/download_libero_benchmarks.py +54 -0
  75. opentau/scripts/eval.py +877 -0
  76. opentau/scripts/export_to_onnx.py +180 -0
  77. opentau/scripts/fake_tensor_training.py +87 -0
  78. opentau/scripts/get_advantage_and_percentiles.py +220 -0
  79. opentau/scripts/high_level_planner_inference.py +114 -0
  80. opentau/scripts/inference.py +70 -0
  81. opentau/scripts/launch_train.py +63 -0
  82. opentau/scripts/libero_simulation_parallel.py +356 -0
  83. opentau/scripts/libero_simulation_sequential.py +122 -0
  84. opentau/scripts/nav_high_level_planner_inference.py +61 -0
  85. opentau/scripts/train.py +379 -0
  86. opentau/scripts/visualize_dataset.py +294 -0
  87. opentau/scripts/visualize_dataset_html.py +507 -0
  88. opentau/scripts/zero_to_fp32.py +760 -0
  89. opentau/utils/__init__.py +20 -0
  90. opentau/utils/accelerate_utils.py +79 -0
  91. opentau/utils/benchmark.py +98 -0
  92. opentau/utils/fake_tensor.py +81 -0
  93. opentau/utils/hub.py +209 -0
  94. opentau/utils/import_utils.py +79 -0
  95. opentau/utils/io_utils.py +137 -0
  96. opentau/utils/libero.py +214 -0
  97. opentau/utils/libero_dataset_recorder.py +460 -0
  98. opentau/utils/logging_utils.py +180 -0
  99. opentau/utils/monkey_patch.py +278 -0
  100. opentau/utils/random_utils.py +244 -0
  101. opentau/utils/train_utils.py +198 -0
  102. opentau/utils/utils.py +471 -0
  103. opentau-0.1.0.dist-info/METADATA +161 -0
  104. opentau-0.1.0.dist-info/RECORD +108 -0
  105. opentau-0.1.0.dist-info/WHEEL +5 -0
  106. opentau-0.1.0.dist-info/entry_points.txt +2 -0
  107. opentau-0.1.0.dist-info/licenses/LICENSE +508 -0
  108. opentau-0.1.0.dist-info/top_level.txt +1 -0
opentau/policies/normalize.py
@@ -0,0 +1,315 @@
+ #!/usr/bin/env python
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Normalization and Unnormalization utilities for policies.
+
+ This module provides classes and functions to normalize and unnormalize data
+ (e.g., observations and actions) based on statistical properties (mean, std, min, max).
+ It handles different normalization modes and supports creating buffers for statistics.
+ """
+
+ import sys
+
+ import numpy as np
+ import torch
+ from torch import Tensor, nn
+
+ from opentau.configs.types import FeatureType, NormalizationMode, PolicyFeature
+
+ EPS = 1e-8  # Small epsilon value for numerical stability in normalization
+
+
+ def warn_missing_keys(features: dict[str, PolicyFeature], batch: dict[str, Tensor], mode: str) -> None:
+     """Warns if expected features are missing from the batch.
+
+     Args:
+         features: Dictionary of expected policy features.
+         batch: Dictionary containing the data batch.
+         mode: The operation mode (e.g., "normalization" or "unnormalization") for the warning message.
+     """
+     for missing_key in set(features) - set(batch):
+         red_seq = "\033[91m"
+         reset_seq = "\033[0m"
+         print(
+             f"{red_seq}Warning: {missing_key} was missing from the batch during {mode}.{reset_seq}",
+             file=sys.stderr,
+         )
+
+
+ def create_stats_buffers(
+     features: dict[str, PolicyFeature],
+     norm_map: dict[str, NormalizationMode],
+     stats: dict[str, dict[str, Tensor]] | None = None,
+ ) -> dict[str, dict[str, nn.ParameterDict]]:
+     """
+     Create buffers per modality (e.g. "observation.image", "action") containing their mean, std, min, max
+     statistics.
+
+     Args:
+         features: Dictionary mapping feature names to PolicyFeature objects.
+         norm_map: Dictionary mapping feature types to NormalizationMode.
+         stats: Optional dictionary containing pre-computed statistics (mean, std, min, max)
+             for each feature. If None, buffers are initialized with infinity.
+
+     Returns:
+         dict: A dictionary where keys are modalities and values are `nn.ParameterDict` containing
+             `nn.Parameters` set to `requires_grad=False`, suitable to not be updated during backpropagation.
+
+     Raises:
+         ValueError: If stats contain types other than np.ndarray or torch.Tensor.
+     """
+     stats_buffers = {}
+
+     for key, ft in features.items():
+         norm_mode = norm_map.get(ft.type, NormalizationMode.IDENTITY)
+         if norm_mode is NormalizationMode.IDENTITY:
+             continue
+
+         assert isinstance(norm_mode, NormalizationMode)
+
+         shape = tuple(ft.shape)
+
+         if ft.type is FeatureType.VISUAL:
+             # sanity checks
+             assert len(shape) == 3, f"number of dimensions of {key} != 3 ({shape=})"
+             c, h, w = shape
+             assert c < h and c < w, f"{key} is not channel first ({shape=})"
+             # override image shape to be invariant to height and width
+             shape = (c, 1, 1)
+
+         # Note: we initialize mean, std, min, max to infinity. They should be overwritten
+         # downstream by `stats` or `policy.load_state_dict`, as expected. During forward,
+         # we assert they are not infinity anymore.
+
+         buffer = {}
+         if norm_mode is NormalizationMode.MEAN_STD:
+             mean = torch.ones(shape, dtype=torch.float32) * torch.inf
+             std = torch.ones(shape, dtype=torch.float32) * torch.inf
+             buffer = nn.ParameterDict(
+                 {
+                     "mean": nn.Parameter(mean, requires_grad=False),
+                     "std": nn.Parameter(std, requires_grad=False),
+                 }
+             )
+         elif norm_mode is NormalizationMode.MIN_MAX:
+             min = torch.ones(shape, dtype=torch.float32) * torch.inf
+             max = torch.ones(shape, dtype=torch.float32) * torch.inf
+             buffer = nn.ParameterDict(
+                 {
+                     "min": nn.Parameter(min, requires_grad=False),
+                     "max": nn.Parameter(max, requires_grad=False),
+                 }
+             )
+
+         # TODO(aliberts, rcadene): harmonize this to only use one framework (np or torch)
+         if stats:
+             if isinstance(stats[key]["mean"], np.ndarray):
+                 if norm_mode is NormalizationMode.MEAN_STD:
+                     buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
+                     buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32)
+                 elif norm_mode is NormalizationMode.MIN_MAX:
+                     buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
+                     buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
+             elif isinstance(stats[key]["mean"], torch.Tensor):
+                 # Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
+                 # tensors anywhere (for example, when we use the same stats for normalization and
+                 # unnormalization). See the logic here
+                 # https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
+                 if norm_mode is NormalizationMode.MEAN_STD:
+                     buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32)
+                     buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
+                 elif norm_mode is NormalizationMode.MIN_MAX:
+                     buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32)
+                     buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
+             else:
+                 type_ = type(stats[key]["mean"])
+                 raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.")
+
+         stats_buffers[key] = buffer
+     return stats_buffers
+
+
+ def _no_stats_error_str(name: str) -> str:
+     """Returns an error message string for missing statistics.
+
+     Args:
+         name: Name of the statistic (e.g., "mean", "std").
+
+     Returns:
+         str: The error message string.
+     """
+     return (
+         f"`{name}` is infinity. You should either initialize with `stats` as an argument, or use a "
+         "pretrained model."
+     )
+
+
+ class Normalize(nn.Module):
+     """Normalizes data (e.g. "observation.image") for more stable and faster convergence during training."""
+
+     def __init__(
+         self,
+         features: dict[str, PolicyFeature],
+         norm_map: dict[str, NormalizationMode],
+         stats: dict[str, dict[str, Tensor]] | None = None,
+     ):
+         """Initializes the Normalize module.
+
+         Args:
+             features: A dictionary where keys are input modalities (e.g. "observation.image") and values
+                 are their PolicyFeature definitions.
+             norm_map: A dictionary where keys are feature types and values are their normalization modes.
+             stats: A dictionary where keys are output modalities (e.g. "observation.image")
+                 and values are dictionaries of statistic types and their values (e.g.
+                 `{"mean": torch.randn(3,1,1), "std": torch.randn(3,1,1)}`). If provided, as expected for
+                 training the model for the first time, these statistics will overwrite the default buffers. If
+                 not provided, as expected for finetuning or evaluation, the default buffers are expected to be
+                 overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the
+                 dataset is not needed to get the stats, since they are already in the policy state_dict.
+         """
+         super().__init__()
+         self.features = features
+         self.norm_map = norm_map
+         self.stats = stats
+         stats_buffers = create_stats_buffers(features, norm_map, stats)
+         for key, buffer in stats_buffers.items():
+             setattr(self, "buffer_" + key.replace(".", "_"), buffer)
+
+     # TODO(rcadene): should we remove torch.no_grad?
+     @torch.no_grad
+     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
+         """Normalizes the batch data.
+
+         Args:
+             batch: Dictionary containing the data to normalize.
+
+         Returns:
+             dict[str, Tensor]: The normalized batch data.
+
+         Raises:
+             ValueError: If an unknown normalization mode is encountered.
+         """
+         warn_missing_keys(self.features, batch, "normalization")
+         batch = dict(batch)  # shallow copy avoids mutating the input batch
+         for key, ft in self.features.items():
+             if key not in batch:
+                 # FIXME(aliberts, rcadene): This might lead to silent fail!
+                 continue
+
+             norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
+             if norm_mode is NormalizationMode.IDENTITY:
+                 continue
+
+             if batch[key].numel() == 0:  # skip empty tensors, which won't broadcast well
+                 continue
+
+             buffer = getattr(self, "buffer_" + key.replace(".", "_"))
+
+             if norm_mode is NormalizationMode.MEAN_STD:
+                 mean = buffer["mean"]
+                 std = buffer["std"]
+                 assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
+                 assert not torch.isinf(std).any(), _no_stats_error_str("std")
+                 batch[key] = (batch[key] - mean) / (std + EPS)
+             elif norm_mode is NormalizationMode.MIN_MAX:
+                 min = buffer["min"]
+                 max = buffer["max"]
+                 assert not torch.isinf(min).any(), _no_stats_error_str("min")
+                 assert not torch.isinf(max).any(), _no_stats_error_str("max")
+                 batch[key] = (batch[key] - min) / (max - min + EPS)
+                 # normalize to [-1, 1]
+                 batch[key] = batch[key] * 2 - 1
+             else:
+                 raise ValueError(norm_mode)
+         return batch
+
+
+ class Unnormalize(nn.Module):
+     """
+     Similar to `Normalize` but unnormalizes output data (e.g. `{"action": torch.randn(b,c)}`) to the
+     original range used by the environment.
+     """
+
+     def __init__(
+         self,
+         features: dict[str, PolicyFeature],
+         norm_map: dict[str, NormalizationMode],
+         stats: dict[str, dict[str, Tensor]] | None = None,
+     ):
+         """Initializes the Unnormalize module.
+
+         Args:
+             features: A dictionary where keys are input modalities (e.g. "observation.image") and values
+                 are their PolicyFeature definitions.
+             norm_map: A dictionary where keys are feature types and values are their normalization modes.
+             stats: A dictionary where keys are output modalities (e.g. "observation.image")
+                 and values are dictionaries of statistic types and their values. If provided,
+                 these statistics will overwrite the default buffers.
+         """
+         super().__init__()
+         self.features = features
+         self.norm_map = norm_map
+         self.stats = stats
+         # `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)`
+         stats_buffers = create_stats_buffers(features, norm_map, stats)
+         for key, buffer in stats_buffers.items():
+             setattr(self, "buffer_" + key.replace(".", "_"), buffer)
+
+     # TODO(rcadene): should we remove torch.no_grad?
+     @torch.no_grad
+     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
+         """Unnormalizes the batch data.
+
+         Args:
+             batch: Dictionary containing the data to unnormalize.
+
+         Returns:
+             dict[str, Tensor]: The unnormalized batch data.
+
+         Raises:
+             ValueError: If an unknown normalization mode is encountered.
+         """
+         warn_missing_keys(self.features, batch, "unnormalization")
+         batch = dict(batch)  # shallow copy avoids mutating the input batch
+         for key, ft in self.features.items():
+             if key not in batch:
+                 continue
+
+             norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
+             if norm_mode is NormalizationMode.IDENTITY:
+                 continue
+
+             if batch[key].numel() == 0:  # skip empty tensors, which won't broadcast well
+                 continue
+
+             buffer = getattr(self, "buffer_" + key.replace(".", "_"))
+
+             if norm_mode is NormalizationMode.MEAN_STD:
+                 mean = buffer["mean"]
+                 std = buffer["std"]
+                 assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
+                 assert not torch.isinf(std).any(), _no_stats_error_str("std")
+                 batch[key] = batch[key] * (std + EPS) + mean
+             elif norm_mode is NormalizationMode.MIN_MAX:
+                 min = buffer["min"]
+                 max = buffer["max"]
+                 assert not torch.isinf(min).any(), _no_stats_error_str("min")
+                 assert not torch.isinf(max).any(), _no_stats_error_str("max")
+                 batch[key] = (batch[key] + 1) / 2
+                 batch[key] = batch[key] * (max - min + EPS) + min
+             else:
+                 raise ValueError(norm_mode)
+         return batch
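As a quick sanity check on the round trip implemented by `Normalize` and `Unnormalize` above, here is a minimal usage sketch. The feature key, shape, and statistic values are hypothetical; it also assumes `FeatureType.ACTION` exists and that `norm_map` may be keyed by feature-type name, matching the `normalization_mapping` used by the policy configs in this package. The stats dict carries all four statistics because `create_stats_buffers` reads `stats[key]["mean"]` to dispatch on the array type even in MIN_MAX mode.

import torch

from opentau.configs.types import FeatureType, NormalizationMode, PolicyFeature
from opentau.policies.normalize import Normalize, Unnormalize

# Hypothetical 7-dim action feature, min-max normalized.
features = {"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,))}
norm_map = {"ACTION": NormalizationMode.MIN_MAX}
stats = {
    "action": {
        "mean": 0.5 * torch.ones(7),
        "std": 0.3 * torch.ones(7),
        "min": torch.zeros(7),
        "max": torch.ones(7),
    }
}

normalize = Normalize(features, norm_map, stats)
unnormalize = Unnormalize(features, norm_map, stats)

batch = {"action": torch.rand(2, 7)}  # two samples, values in [0, 1)
normed = normalize(batch)             # mapped to roughly [-1, 1]
restored = unnormalize(normed)        # mapped back to the original range
assert torch.allclose(restored["action"], batch["action"], atol=1e-5)

Because MIN_MAX first rescales to [0, 1] with `(x - min) / (max - min + EPS)` and then shifts to [-1, 1] with `x * 2 - 1`, and `Unnormalize` applies the exact inverse of both steps, the round trip recovers the input up to the `EPS` term.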
opentau/policies/pi0/__init__.py
@@ -0,0 +1,19 @@
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """PI0 Policy Module.
+
+ This module implements the π0 (Pi0) Vision-Language-Action Flow Model policy,
+ designed for general robot control. It includes the policy definition,
+ configuration, and model architecture.
+ """
opentau/policies/pi0/configuration_pi0.py
@@ -0,0 +1,250 @@
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """
+ Configuration module for the PI0 Policy.
+
+ This module defines the `PI0Config` class, which handles the configuration parameters
+ for the PI0 Vision-Language-Action Flow Model. It includes settings for the model architecture,
+ optimization, scheduling, and data processing.
+ """
+
+ import logging
+ from dataclasses import dataclass, field
+ from typing import Literal
+
+ from opentau.configs.policies import PreTrainedConfig
+ from opentau.configs.types import FeatureType, NormalizationMode, PolicyFeature
+ from opentau.optim.optimizers import AdamWConfig
+ from opentau.optim.schedulers import (
+     CosineDecayWithWarmupSchedulerConfig,
+     LRSchedulerConfig,
+ )
+
+
+ @PreTrainedConfig.register_subclass("pi0")
+ @dataclass
+ class PI0Config(PreTrainedConfig):
+     """Configuration class for the PI0 Policy.
+
+     This class defines the configuration parameters for the PI0 model, including
+     input/output structure, model architecture, training settings, and preprocessing.
+
+     Args:
+         n_obs_steps: Number of observation steps to use. Defaults to 1.
+         chunk_size: Size of the action chunk; the upper bound for `n_action_steps`. Defaults to 50.
+         n_action_steps: Number of action steps to predict. Defaults to 50.
+         safety_buffer: Safety buffer size. Defaults to 0.
+         normalization_mapping: Mapping of feature types to normalization modes.
+             Defaults to identity for visual features and mean-std for state and action.
+         max_state_dim: Maximum dimension for state vectors. Shorter vectors are padded. Defaults to 32.
+         max_action_dim: Maximum dimension for action vectors. Shorter vectors are padded. Defaults to 32.
+         resize_imgs_with_padding: Target size (height, width) for image resizing with padding.
+             Defaults to (224, 224).
+         empty_cameras: Number of empty camera inputs to add. Used for specific adaptations like
+             Aloha simulation. Defaults to 0.
+         tokenizer_max_length: Maximum length for the tokenizer. Defaults to 48.
+         proj_width: Width of the projection layer. Defaults to 1024.
+         dropout: Dropout rate. Defaults to 0.1.
+         num_steps: Number of flow matching steps for decoding. Defaults to 10.
+         advantage_threshold: Advantage binarization threshold for AWR. Defaults to 0.0.
+         advantage: Advantage conditioning mode. One of "ignore", "on", "use".
+             "use" uses values from the dataset, "ignore" disables conditioning,
+             "on" sets advantage to True (for expert demos). Defaults to "use".
+         init_strategy: Initialization strategy. One of "no_init", "full_he_init", "expert_only_he_init".
+             Defaults to "full_he_init".
+         use_cache: Whether to use the KV cache during inference. Defaults to True.
+         attention_implementation: Attention implementation to use ("eager" or "fa2"). Defaults to "eager".
+         freeze_vision_encoder: Whether to freeze the vision encoder during fine-tuning. Defaults to True.
+         train_expert_only: Whether to train only the expert module. Defaults to False.
+         train_state_proj: Whether to train the state projection layer. Defaults to True.
+         optimizer_lr: Learning rate for the optimizer. Defaults to 2.5e-5.
+         optimizer_betas: Beta parameters for the AdamW optimizer. Defaults to (0.9, 0.95).
+         optimizer_eps: Epsilon parameter for the AdamW optimizer. Defaults to 1e-8.
+         optimizer_weight_decay: Weight decay for the AdamW optimizer. Defaults to 1e-10.
+         scheduler_warmup_steps: Number of warmup steps for the scheduler. Defaults to 1_000.
+         scheduler_decay_steps: Number of decay steps for the scheduler. Defaults to 30_000.
+         scheduler_decay_lr: Target learning rate after decay. Defaults to 2.5e-6.
+     """
+
+     # Input / output structure.
+     n_obs_steps: int = 1
+     chunk_size: int = 50
+     n_action_steps: int = 50
+     safety_buffer: int = 0
+
+     normalization_mapping: dict[str, NormalizationMode] = field(
+         default_factory=lambda: {
+             "VISUAL": NormalizationMode.IDENTITY,
+             "STATE": NormalizationMode.MEAN_STD,
+             "ACTION": NormalizationMode.MEAN_STD,
+         }
+     )
+
+     # Shorter state and action vectors will be padded
+     max_state_dim: int = 32
+     max_action_dim: int = 32
+
+     # Image preprocessing
+     resize_imgs_with_padding: tuple[int, int] = (224, 224)
+
+     # Add empty images. Used by pi0_aloha_sim which adds the empty
+     # left and right wrist cameras in addition to the top camera.
+     empty_cameras: int = 0
+
+     # Tokenizer
+     tokenizer_max_length: int = 48
+
+     # Projector
+     proj_width: int = 1024
+
+     # Dropout
+     dropout: float = 0.1
+
+     # Decoding
+     num_steps: int = 10
+
+     # advantage binarization threshold for AWR
+     advantage_threshold: float = 0.0
+
+     # When set to "use", the advantage values provided in the dataset will be used.
+     # When set to "ignore", no advantage conditioning will be applied.
+     # When set to "on", the advantage will always be True.
+     # This should only be "on" when training on expert demonstrations or interventions.
+     advantage: Literal["ignore", "on", "use"] = "use"
+
+     # Initialization strategy
+     init_strategy: Literal["no_init", "full_he_init", "expert_only_he_init"] = "full_he_init"
+
+     # Attention utils
+     use_cache: bool = True
+     attention_implementation: str = "eager"  # or "fa2"
+
+     # Finetuning settings
+     freeze_vision_encoder: bool = True
+     train_expert_only: bool = False
+     train_state_proj: bool = True
+
+     # Training presets
+     optimizer_lr: float = 2.5e-5
+     optimizer_betas: tuple[float, float] = (0.9, 0.95)
+     optimizer_eps: float = 1e-8
+     optimizer_weight_decay: float = 1e-10
+
+     scheduler_warmup_steps: int = 1_000
+     scheduler_decay_steps: int = 30_000
+     scheduler_decay_lr: float = 2.5e-6
+
+     # TODO: Add EMA
+
+     def __post_init__(self):
+         """Post-initialization validation."""
+         super().__post_init__()
+
+         # TODO(Steven): Validate device and amp? in all policy configs?
+         # Input validation (not exhaustive).
+         if self.n_action_steps > self.chunk_size:
+             raise ValueError(
+                 f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
+                 f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
+             )
+         if self.n_obs_steps != 1:
+             raise ValueError(
+                 f"Multiple observation steps not handled yet. Got `n_obs_steps={self.n_obs_steps}`"
+             )
+
+         assert self.init_strategy in ["no_init", "full_he_init", "expert_only_he_init"], (
+             f"Invalid init strategy: {self.init_strategy} must be one of ['no_init', 'full_he_init', 'expert_only_he_init']"
+         )
+
+         if self.init_strategy == "expert_only_he_init" and self.pretrained_path == "lerobot/pi0":
+             raise ValueError(
+                 "You cannot load a pretrained PI0 model when init_strategy is 'expert_only_he_init' due to differences in PaliGemma tokenizer vocab sizes."
+             )
+
+         if self.pretrained_path is not None and self.pretrained_path != "lerobot/pi0":
+             logging.info("Setting init_strategy to 'no_init' because we are resuming from a checkpoint.")
+             self.init_strategy = "no_init"
+
+     def validate_features(self) -> None:
+         """Validates the features and adds empty cameras if configured.
+
+         This method checks feature configurations and dynamically adds empty camera inputs
+         to `self.input_features` based on the `empty_cameras` parameter.
+         """
+         # TODO: implement value error
+         # if not self.image_features and not self.env_state_feature:
+         #     raise ValueError("You must provide at least one image or the environment state among the inputs.")
+
+         for i in range(self.empty_cameras):
+             key = f"observation.images.empty_camera_{i}"
+             empty_camera = PolicyFeature(
+                 type=FeatureType.VISUAL,
+                 shape=(3, 480, 640),
+             )
+             self.input_features[key] = empty_camera
+
+     def get_optimizer_preset(self) -> AdamWConfig:
+         """Returns the default optimizer configuration.
+
+         Returns:
+             AdamWConfig: The optimizer configuration with default parameters.
+         """
+         return AdamWConfig(
+             lr=self.optimizer_lr,
+             betas=self.optimizer_betas,
+             eps=self.optimizer_eps,
+             weight_decay=self.optimizer_weight_decay,
+         )
+
+     def get_scheduler_preset(self) -> LRSchedulerConfig:
+         """Returns the default scheduler configuration.
+
+         Returns:
+             CosineDecayWithWarmupSchedulerConfig: The scheduler configuration with default parameters.
+         """
+         return CosineDecayWithWarmupSchedulerConfig(
+             peak_lr=self.optimizer_lr,
+             decay_lr=self.scheduler_decay_lr,
+             num_warmup_steps=self.scheduler_warmup_steps,
+             num_decay_steps=self.scheduler_decay_steps,
+         )
+
+     @property
+     def observation_delta_indices(self) -> None:
+         """Indices for observation deltas.
+
+         Returns:
+             None: As observation deltas are not used.
+         """
+         return None
+
+     @property
+     def action_delta_indices(self) -> list[int]:
+         """Indices for action deltas.
+
+         Returns:
+             list[int]: A list of indices corresponding to the chunk size.
+         """
+         return list(range(self.chunk_size))
+
+     @property
+     def reward_delta_indices(self) -> None:
+         """Indices for reward deltas.
+
+         Returns:
+             None: As reward deltas are not used.
+         """
+         return None
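To make the validation and preset logic above concrete, here is a hedged usage sketch. It assumes the inherited `PreTrainedConfig` fields all carry defaults (as in the upstream lerobot API this config extends), so a bare `PI0Config()` is constructible.

from opentau.policies.pi0.configuration_pi0 import PI0Config

# Defaults: 50-step action chunks, He init for the full model,
# AdamW at 2.5e-5 with cosine decay after 1k warmup steps.
config = PI0Config()
optimizer_cfg = config.get_optimizer_preset()
scheduler_cfg = config.get_scheduler_preset()
assert config.action_delta_indices == list(range(50))

# __post_init__ enforces n_action_steps <= chunk_size, so an inconsistent
# configuration fails at construction time.
try:
    PI0Config(chunk_size=10, n_action_steps=20)
except ValueError as err:
    print(err)  # "The chunk size is the upper bound ..."

Because the check runs in `__post_init__`, a bad `chunk_size`/`n_action_steps` pairing surfaces when the config is built rather than at the first model invocation.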