opentau 0.1.1__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opentau/configs/default.py +16 -0
- opentau/configs/deployment.py +85 -0
- opentau/configs/train.py +5 -0
- opentau/datasets/factory.py +43 -10
- opentau/datasets/lerobot_dataset.py +19 -19
- opentau/datasets/video_utils.py +11 -6
- opentau/policies/pi05/configuration_pi05.py +9 -6
- opentau/policies/pi05/modeling_pi05.py +296 -30
- opentau/policies/pi05/paligemma_with_expert.py +20 -20
- opentau/scripts/grpc/__init__.py +19 -0
- opentau/scripts/grpc/client.py +601 -0
- opentau/scripts/grpc/robot_inference_pb2.py +61 -0
- opentau/scripts/grpc/robot_inference_pb2_grpc.py +210 -0
- opentau/scripts/grpc/server.py +313 -0
- opentau/scripts/launch.py +12 -4
- opentau/scripts/train.py +94 -17
- opentau/scripts/visualize_dataset.py +141 -38
- opentau/utils/transformers_patch.py +251 -20
- {opentau-0.1.1.dist-info → opentau-0.2.0.dist-info}/METADATA +37 -17
- {opentau-0.1.1.dist-info → opentau-0.2.0.dist-info}/RECORD +24 -21
- {opentau-0.1.1.dist-info → opentau-0.2.0.dist-info}/WHEEL +1 -1
- {opentau-0.1.1.dist-info → opentau-0.2.0.dist-info}/entry_points.txt +1 -0
- opentau/scripts/libero_simulation_parallel.py +0 -356
- opentau/scripts/libero_simulation_sequential.py +0 -122
- opentau/scripts/visualize_dataset_html.py +0 -507
- {opentau-0.1.1.dist-info → opentau-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {opentau-0.1.1.dist-info → opentau-0.2.0.dist-info}/top_level.txt +0 -0
opentau/configs/default.py
CHANGED
|
@@ -96,6 +96,11 @@ class DatasetConfig:
|
|
|
96
96
|
data_features_name_mapping: dict[str, str] | None = None
|
|
97
97
|
loss_type_mapping: str | None = None
|
|
98
98
|
|
|
99
|
+
# Ratio of the dataset to be used for validation. Please specify a value.
|
|
100
|
+
# If `val_freq` is set to 0, a validation dataset will not be created and this value will be ignored.
|
|
101
|
+
# Defaults to 0.05.
|
|
102
|
+
val_split_ratio: float = 0.05
|
|
103
|
+
|
|
99
104
|
def __post_init__(self):
|
|
100
105
|
"""Validate dataset configuration and register custom mappings if provided."""
|
|
101
106
|
if (self.repo_id is None) == (self.grounding is None):
|
|
@@ -148,6 +153,11 @@ class DatasetMixtureConfig:
|
|
|
148
153
|
image_resample_strategy: str = "nearest"
|
|
149
154
|
# Resample strategy for non-image features, such as action or state
|
|
150
155
|
vector_resample_strategy: str = "nearest"
|
|
156
|
+
# Ratio of the dataset to be used for validation. Please specify a value.
|
|
157
|
+
# If `val_freq` is set to 0, a validation dataset will not be created and this value will be ignored.
|
|
158
|
+
# This value is applied to all datasets in the mixture.
|
|
159
|
+
# Defaults to 0.05.
|
|
160
|
+
val_split_ratio: float = 0.05
|
|
151
161
|
|
|
152
162
|
def __post_init__(self):
|
|
153
163
|
"""Validate dataset mixture configuration."""
|
|
@@ -163,6 +173,12 @@ class DatasetMixtureConfig:
|
|
|
163
173
|
raise ValueError(
|
|
164
174
|
f"`vector_resample_strategy` must be one of ['linear', 'nearest'], got {self.vector_resample_strategy}."
|
|
165
175
|
)
|
|
176
|
+
if self.val_split_ratio < 0 or self.val_split_ratio > 1:
|
|
177
|
+
raise ValueError(f"`val_split_ratio` must be between 0 and 1, got {self.val_split_ratio}.")
|
|
178
|
+
|
|
179
|
+
# set the val_split_ratio for all datasets in the mixture
|
|
180
|
+
for dataset_cfg in self.datasets:
|
|
181
|
+
dataset_cfg.val_split_ratio = self.val_split_ratio
|
|
166
182
|
|
|
167
183
|
|
|
168
184
|
@dataclass
|
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
# Copyright 2026 Tensor Auto Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""Deployment configuration classes for inference servers.
|
|
15
|
+
|
|
16
|
+
This module provides configuration classes for deploying trained models
|
|
17
|
+
as inference servers, including gRPC server settings.
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
from dataclasses import dataclass
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@dataclass
class ServerConfig:
    """Settings for the gRPC inference server that serves robot policy models.

    All values are validated once, right after construction, in
    ``__post_init__``.

    Args:
        port: Port number to serve on. Must be between 1 and 65535.
            Defaults to 50051.
        max_workers: Maximum number of gRPC worker threads for handling
            concurrent requests. Must be at least 1. Defaults to 4.
        max_send_message_length_mb: Maximum size of outgoing messages in
            megabytes. Must be at least 1. Defaults to 100.
        max_receive_message_length_mb: Maximum size of incoming messages in
            megabytes. Must be at least 1. Defaults to 100.

    Raises:
        ValueError: If `port` is outside the range 1..65535, or if any of
            `max_workers`, `max_send_message_length_mb` or
            `max_receive_message_length_mb` is less than 1.

    Example:
        >>> config = ServerConfig(port=50051, max_workers=8)
        >>> config.port
        50051
    """

    port: int = 50051
    max_workers: int = 4
    max_send_message_length_mb: int = 100
    max_receive_message_length_mb: int = 100

    def __post_init__(self):
        """Reject out-of-range settings, checking fields in declaration order."""
        port_is_valid = 1 <= self.port <= 65535
        if not port_is_valid:
            raise ValueError(f"`port` must be between 1 and 65535, got {self.port}.")

        if self.max_workers < 1:
            raise ValueError(f"`max_workers` must be at least 1, got {self.max_workers}.")

        # Both message-size limits share the same rule; send is checked before
        # receive so the first reported error matches the field order.
        for field_name in ("max_send_message_length_mb", "max_receive_message_length_mb"):
            limit_mb = getattr(self, field_name)
            if limit_mb < 1:
                raise ValueError(f"`{field_name}` must be at least 1, got {limit_mb}.")

    @property
    def max_send_message_length(self) -> int:
        """Maximum outgoing message size converted from megabytes to bytes.

        Returns:
            `max_send_message_length_mb` multiplied by 1024 * 1024.
        """
        bytes_per_mb = 1024 * 1024
        return self.max_send_message_length_mb * bytes_per_mb

    @property
    def max_receive_message_length(self) -> int:
        """Maximum incoming message size converted from megabytes to bytes.

        Returns:
            `max_receive_message_length_mb` multiplied by 1024 * 1024.
        """
        bytes_per_mb = 1024 * 1024
        return self.max_receive_message_length_mb * bytes_per_mb
|
opentau/configs/train.py
CHANGED
|
@@ -32,6 +32,7 @@ from huggingface_hub.errors import HfHubHTTPError
|
|
|
32
32
|
|
|
33
33
|
from opentau.configs import parser
|
|
34
34
|
from opentau.configs.default import DatasetMixtureConfig, EvalConfig, WandBConfig
|
|
35
|
+
from opentau.configs.deployment import ServerConfig
|
|
35
36
|
from opentau.configs.policies import PreTrainedConfig
|
|
36
37
|
from opentau.envs.configs import EnvConfig
|
|
37
38
|
from opentau.optim import OptimizerConfig
|
|
@@ -116,6 +117,7 @@ class TrainPipelineConfig(HubMixin):
|
|
|
116
117
|
is disabled. Defaults to 0.
|
|
117
118
|
last_checkpoint_only: If True, only evaluate the last checkpoint.
|
|
118
119
|
Defaults to True.
|
|
120
|
+
server: Configuration for the gRPC inference server. Defaults to ServerConfig().
|
|
119
121
|
"""
|
|
120
122
|
|
|
121
123
|
dataset_mixture: DatasetMixtureConfig
|
|
@@ -163,7 +165,10 @@ class TrainPipelineConfig(HubMixin):
|
|
|
163
165
|
env: EnvConfig | None = None
|
|
164
166
|
eval: EvalConfig | None = field(default_factory=EvalConfig)
|
|
165
167
|
eval_freq: int = 0 # evaluate every eval_freq steps
|
|
168
|
+
val_freq: int = 0 # validate every val_freq steps, if 0, then a validation split is not created
|
|
166
169
|
last_checkpoint_only: bool = True
|
|
170
|
+
# gRPC inference server configuration
|
|
171
|
+
server: ServerConfig = field(default_factory=ServerConfig)
|
|
167
172
|
|
|
168
173
|
def __post_init__(self):
|
|
169
174
|
"""Initialize post-creation attributes and validate batch size configuration."""
|
opentau/datasets/factory.py
CHANGED
|
@@ -61,7 +61,11 @@ Example:
|
|
|
61
61
|
>>> dataloader = mixture.get_dataloader()
|
|
62
62
|
"""
|
|
63
63
|
|
|
64
|
+
import copy
|
|
65
|
+
from typing import Tuple, Union
|
|
66
|
+
|
|
64
67
|
import numpy as np
|
|
68
|
+
import torch
|
|
65
69
|
|
|
66
70
|
# NOTE: Don't delete; imported for side effects.
|
|
67
71
|
import opentau.datasets.grounding.clevr # noqa: F401
|
|
@@ -151,9 +155,13 @@ def make_dataset(
|
|
|
151
155
|
cfg: DatasetConfig,
|
|
152
156
|
train_cfg: TrainPipelineConfig,
|
|
153
157
|
return_advantage_input: bool = False,
|
|
154
|
-
) -> BaseDataset:
|
|
158
|
+
) -> Union[BaseDataset, Tuple[BaseDataset, BaseDataset]]:
|
|
155
159
|
"""Handles the logic of setting up delta timestamps and image transforms before creating a dataset.
|
|
156
160
|
|
|
161
|
+
A train and validation dataset are returned if `train_cfg.val_freq` is greater than 0.
|
|
162
|
+
The validation dataset is a subset of the train dataset, and is used for evaluation during training.
|
|
163
|
+
The validation dataset is created by splitting the train dataset into train and validation sets based on `cfg.val_split_ratio`.
|
|
164
|
+
|
|
157
165
|
Args:
|
|
158
166
|
cfg (DatasetConfig): A DatasetConfig used to create a LeRobotDataset.
|
|
159
167
|
train_cfg (TrainPipelineConfig): A TrainPipelineConfig config which contains a DatasetConfig and a PreTrainedConfig.
|
|
@@ -161,10 +169,11 @@ def make_dataset(
|
|
|
161
169
|
"episode_end_idx", "current_idx", "last_step", "episode_index", and "timestamp". Defaults to False.
|
|
162
170
|
|
|
163
171
|
Raises:
|
|
164
|
-
|
|
172
|
+
ValueError: If `cfg.grounding` and `cfg.repo_id` are both provided or both omitted (exactly one must be set).
|
|
173
|
+
ValueError: If `cfg.grounding` is not a supported grounding dataset.
|
|
165
174
|
|
|
166
175
|
Returns:
|
|
167
|
-
BaseDataset
|
|
176
|
+
BaseDataset or Tuple[BaseDataset, BaseDataset]: A single dataset or a tuple of (train_dataset, val_dataset) if val_freq > 0.
|
|
168
177
|
"""
|
|
169
178
|
image_transforms = ImageTransforms(cfg.image_transforms) if cfg.image_transforms.enable else None
|
|
170
179
|
|
|
@@ -209,12 +218,20 @@ def make_dataset(
|
|
|
209
218
|
dataset.meta.stats[key] = {}
|
|
210
219
|
dataset.meta.stats[key][stats_type] = np.array(stats, dtype=np.float32)
|
|
211
220
|
|
|
221
|
+
if train_cfg.val_freq > 0:
|
|
222
|
+
val_size = int(len(dataset) * cfg.val_split_ratio)
|
|
223
|
+
train_size = len(dataset) - val_size
|
|
224
|
+
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
|
|
225
|
+
train_dataset.meta = copy.deepcopy(dataset.meta)
|
|
226
|
+
val_dataset.meta = copy.deepcopy(dataset.meta)
|
|
227
|
+
return train_dataset, val_dataset
|
|
228
|
+
|
|
212
229
|
return dataset
|
|
213
230
|
|
|
214
231
|
|
|
215
232
|
def make_dataset_mixture(
|
|
216
233
|
cfg: TrainPipelineConfig, return_advantage_input: bool = False
|
|
217
|
-
) -> WeightedDatasetMixture:
|
|
234
|
+
) -> Union[WeightedDatasetMixture, Tuple[WeightedDatasetMixture, WeightedDatasetMixture]]:
|
|
218
235
|
"""Creates a dataset mixture from the provided TrainPipelineConfig.
|
|
219
236
|
|
|
220
237
|
Args:
|
|
@@ -223,10 +240,26 @@ def make_dataset_mixture(
|
|
|
223
240
|
"episode_end_idx", "current_idx", "last_step", "episode_index", and "timestamp". Defaults to False.
|
|
224
241
|
|
|
225
242
|
Returns:
|
|
226
|
-
WeightedDatasetMixture: An instance of WeightedDatasetMixture containing the datasets.
|
|
243
|
+
WeightedDatasetMixture or Tuple[WeightedDatasetMixture, WeightedDatasetMixture]: An instance of WeightedDatasetMixture containing the datasets, or a tuple of (train_mixture, val_mixture) if val_freq > 0.
|
|
227
244
|
"""
|
|
228
|
-
datasets = [
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
245
|
+
datasets = []
|
|
246
|
+
val_datasets = []
|
|
247
|
+
for dataset_cfg in cfg.dataset_mixture.datasets:
|
|
248
|
+
res = make_dataset(dataset_cfg, cfg, return_advantage_input=return_advantage_input)
|
|
249
|
+
if isinstance(res, tuple):
|
|
250
|
+
datasets.append(res[0])
|
|
251
|
+
val_datasets.append(res[1])
|
|
252
|
+
else:
|
|
253
|
+
datasets.append(res)
|
|
254
|
+
|
|
255
|
+
train_mixture = WeightedDatasetMixture(
|
|
256
|
+
cfg, datasets, cfg.dataset_mixture.weights, cfg.dataset_mixture.action_freq
|
|
257
|
+
)
|
|
258
|
+
|
|
259
|
+
if val_datasets:
|
|
260
|
+
val_mixture = WeightedDatasetMixture(
|
|
261
|
+
cfg, val_datasets, cfg.dataset_mixture.weights, cfg.dataset_mixture.action_freq
|
|
262
|
+
)
|
|
263
|
+
return train_mixture, val_mixture
|
|
264
|
+
|
|
265
|
+
return train_mixture
|
|
@@ -150,6 +150,7 @@ from opentau.policies.value.configuration_value import ValueConfig
|
|
|
150
150
|
from opentau.policies.value.reward import (
|
|
151
151
|
calculate_return_bins_with_equal_width,
|
|
152
152
|
)
|
|
153
|
+
from opentau.utils.accelerate_utils import get_proc_accelerator
|
|
153
154
|
from opentau.utils.utils import on_accelerate_main_proc
|
|
154
155
|
|
|
155
156
|
|
|
@@ -324,8 +325,17 @@ class LeRobotDatasetMetadata(DatasetMetadata):
|
|
|
324
325
|
if is_valid_version(self.revision):
|
|
325
326
|
self.revision = get_safe_version(self.repo_id, self.revision)
|
|
326
327
|
|
|
327
|
-
|
|
328
|
-
|
|
328
|
+
# In distributed training, only rank 0 downloads to avoid race conditions
|
|
329
|
+
# where other ranks read metadata before the download has finished.
|
|
330
|
+
acc = get_proc_accelerator()
|
|
331
|
+
if acc is not None and acc.num_processes > 1:
|
|
332
|
+
if acc.is_main_process:
|
|
333
|
+
(self.root / "meta").mkdir(exist_ok=True, parents=True)
|
|
334
|
+
self.pull_from_repo(allow_patterns="meta/")
|
|
335
|
+
acc.wait_for_everyone()
|
|
336
|
+
else:
|
|
337
|
+
(self.root / "meta").mkdir(exist_ok=True, parents=True)
|
|
338
|
+
self.pull_from_repo(allow_patterns="meta/")
|
|
329
339
|
self.load_metadata()
|
|
330
340
|
|
|
331
341
|
def load_metadata(self) -> None:
|
|
@@ -633,7 +643,9 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
|
633
643
|
For example, {"image_key": torch.zeros(2, 3, 224, 224), "image_key_is_pad": [False, True] } will become
|
|
634
644
|
{
|
|
635
645
|
"image_key": torch.zeros(3, 224, 224),
|
|
646
|
+
"image_key_local": torch.zeros(3, 224, 224),
|
|
636
647
|
"image_key_is_pad": False,
|
|
648
|
+
"image_key_local_is_pad": True,
|
|
637
649
|
}.
|
|
638
650
|
"""
|
|
639
651
|
raise NotImplementedError
|
|
@@ -723,14 +735,6 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
|
723
735
|
if isinstance(value, torch.Tensor) and value.dtype.is_floating_point:
|
|
724
736
|
standard_item[key] = value.to(dtype=torch.bfloat16)
|
|
725
737
|
|
|
726
|
-
# ensure that non-empty strings contain exactly one newline character at the end of the string
|
|
727
|
-
for key in ["prompt", "response"]:
|
|
728
|
-
if standard_item[key].endswith(
|
|
729
|
-
"\n"
|
|
730
|
-
): # ensure there isn't going to be an extra space at the end after calling replace
|
|
731
|
-
standard_item[key] = standard_item[key][:-1]
|
|
732
|
-
standard_item[key] = standard_item[key].replace("\n", " ") + "\n"
|
|
733
|
-
|
|
734
738
|
return standard_item
|
|
735
739
|
|
|
736
740
|
def resize_with_pad(self, img, width, height, pad_value=0) -> torch.Tensor:
|
|
@@ -1787,16 +1791,12 @@ class LeRobotDataset(BaseDataset):
|
|
|
1787
1791
|
cam_keys = {v for k, v in name_map.items() if k.startswith("camera")}
|
|
1788
1792
|
for k in cam_keys:
|
|
1789
1793
|
images = item.pop(k)
|
|
1790
|
-
|
|
1791
|
-
|
|
1792
|
-
)
|
|
1793
|
-
item[k + "_local"], item[k] = images
|
|
1794
|
+
if len(images) == 2:
|
|
1795
|
+
item[k + "_local"], item[k] = images
|
|
1794
1796
|
|
|
1795
|
-
pads = item.
|
|
1796
|
-
|
|
1797
|
-
|
|
1798
|
-
)
|
|
1799
|
-
item[k + "_local_is_pad"], item[k + "_is_pad"] = pads
|
|
1797
|
+
pads = item.get(k + "_is_pad")
|
|
1798
|
+
if hasattr(pads, "__len__") and len(pads) == 2:
|
|
1799
|
+
item[k + "_local_is_pad"], item[k + "_is_pad"] = pads
|
|
1800
1800
|
|
|
1801
1801
|
@staticmethod
|
|
1802
1802
|
def compute_delta_params(
|
opentau/datasets/video_utils.py
CHANGED
|
@@ -108,6 +108,7 @@ import pyarrow as pa
|
|
|
108
108
|
import torch
|
|
109
109
|
import torchvision
|
|
110
110
|
from datasets.features.features import register_feature
|
|
111
|
+
from packaging import version
|
|
111
112
|
from PIL import Image
|
|
112
113
|
|
|
113
114
|
|
|
@@ -117,13 +118,17 @@ def get_safe_default_codec() -> str:
|
|
|
117
118
|
Returns:
|
|
118
119
|
Backend name: "pyav" when torch >= 2.8.0; otherwise "torchcodec" if available, else "pyav".
|
|
119
120
|
"""
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
else:
|
|
123
|
-
logging.warning(
|
|
124
|
-
"'torchcodec' is not available in your platform, falling back to 'pyav' as a default decoder"
|
|
125
|
-
)
|
|
121
|
+
|
|
122
|
+
if version.parse(torch.__version__) >= version.parse("2.8.0"):
|
|
126
123
|
return "pyav"
|
|
124
|
+
else:
|
|
125
|
+
if importlib.util.find_spec("torchcodec"):
|
|
126
|
+
return "torchcodec"
|
|
127
|
+
else:
|
|
128
|
+
logging.warning(
|
|
129
|
+
"'torchcodec' is not available in your platform, falling back to 'pyav' as a default decoder"
|
|
130
|
+
)
|
|
131
|
+
return "pyav"
|
|
127
132
|
|
|
128
133
|
|
|
129
134
|
def decode_video_frames(
|
|
@@ -49,18 +49,18 @@ class PI05Config(PreTrainedConfig):
|
|
|
49
49
|
Defaults to identity for visual features and mean-std for state and action.
|
|
50
50
|
max_state_dim: Maximum dimension for state vectors. Shorter vectors are padded. Defaults to 32.
|
|
51
51
|
max_action_dim: Maximum dimension for action vectors. Shorter vectors are padded. Defaults to 32.
|
|
52
|
+
predict_response: Whether to predict the response. Defaults to False.
|
|
52
53
|
resize_imgs_with_padding: Target size (height, width) for image resizing with padding.
|
|
53
54
|
Defaults to (224, 224).
|
|
54
55
|
empty_cameras: Number of empty camera inputs to add. Used for specific adaptations like
|
|
55
56
|
Aloha simulation. Defaults to 0.
|
|
56
|
-
|
|
57
|
+
prompt_max_length: Maximum length for tokenizer. Defaults to 256.
|
|
57
58
|
discrete_action_max_length: Maximum length for discrete action tokens. Defaults to 32.
|
|
58
59
|
proj_width: Width of the projection layer. Defaults to 1024.
|
|
59
60
|
dropout: Dropout rate. Defaults to 0.1.
|
|
60
61
|
num_steps: Number of flow matching steps for decoding. Defaults to 10.
|
|
61
62
|
init_strategy: Initialization strategy. One of "no_init", "full_he_init", "expert_only_he_init".
|
|
62
63
|
Defaults to "full_he_init".
|
|
63
|
-
use_cache: Whether to use KV cache during inference. Defaults to True.
|
|
64
64
|
attention_implementation: Attention implementation to use ("eager" or "fa2"). Defaults to "eager".
|
|
65
65
|
freeze_vision_encoder: Whether to freeze the vision encoder during fine-tuning. Defaults to True.
|
|
66
66
|
train_expert_only: Whether to train only the expert module. Defaults to False.
|
|
@@ -89,6 +89,7 @@ class PI05Config(PreTrainedConfig):
|
|
|
89
89
|
# Shorter state and action vectors will be padded
|
|
90
90
|
max_state_dim: int = 32
|
|
91
91
|
max_action_dim: int = 32
|
|
92
|
+
predict_response: bool = False
|
|
92
93
|
|
|
93
94
|
# Image preprocessing
|
|
94
95
|
resize_imgs_with_padding: tuple[int, int] = (224, 224)
|
|
@@ -97,8 +98,11 @@ class PI05Config(PreTrainedConfig):
|
|
|
97
98
|
# left and right wrist cameras in addition to the top camera.
|
|
98
99
|
empty_cameras: int = 0
|
|
99
100
|
|
|
100
|
-
# Tokenizer
|
|
101
|
-
|
|
101
|
+
# Language Tokenizer
|
|
102
|
+
prompt_max_length: int = 256
|
|
103
|
+
|
|
104
|
+
# Response Tokenizer
|
|
105
|
+
response_max_length: int = 52
|
|
102
106
|
|
|
103
107
|
# Maximum length of the action tokens
|
|
104
108
|
discrete_action_max_length: int = 32
|
|
@@ -116,8 +120,7 @@ class PI05Config(PreTrainedConfig):
|
|
|
116
120
|
init_strategy: Literal["no_init", "full_he_init", "expert_only_he_init"] = "full_he_init"
|
|
117
121
|
|
|
118
122
|
# Attention utils
|
|
119
|
-
|
|
120
|
-
attention_implementation: str = "eager" # or fa2
|
|
123
|
+
attention_implementation: str = "eager"
|
|
121
124
|
|
|
122
125
|
# Finetuning settings
|
|
123
126
|
freeze_vision_encoder: bool = True
|