opentau-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opentau/__init__.py +179 -0
- opentau/__version__.py +24 -0
- opentau/configs/__init__.py +19 -0
- opentau/configs/default.py +297 -0
- opentau/configs/libero.py +113 -0
- opentau/configs/parser.py +393 -0
- opentau/configs/policies.py +297 -0
- opentau/configs/reward.py +42 -0
- opentau/configs/train.py +370 -0
- opentau/configs/types.py +76 -0
- opentau/constants.py +52 -0
- opentau/datasets/__init__.py +84 -0
- opentau/datasets/backward_compatibility.py +78 -0
- opentau/datasets/compute_stats.py +333 -0
- opentau/datasets/dataset_mixture.py +460 -0
- opentau/datasets/factory.py +232 -0
- opentau/datasets/grounding/__init__.py +67 -0
- opentau/datasets/grounding/base.py +154 -0
- opentau/datasets/grounding/clevr.py +110 -0
- opentau/datasets/grounding/cocoqa.py +130 -0
- opentau/datasets/grounding/dummy.py +101 -0
- opentau/datasets/grounding/pixmo.py +177 -0
- opentau/datasets/grounding/vsr.py +141 -0
- opentau/datasets/image_writer.py +304 -0
- opentau/datasets/lerobot_dataset.py +1910 -0
- opentau/datasets/online_buffer.py +442 -0
- opentau/datasets/push_dataset_to_hub/utils.py +132 -0
- opentau/datasets/sampler.py +99 -0
- opentau/datasets/standard_data_format_mapping.py +278 -0
- opentau/datasets/transforms.py +330 -0
- opentau/datasets/utils.py +1243 -0
- opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +887 -0
- opentau/datasets/v2/convert_dataset_v1_to_v2.py +829 -0
- opentau/datasets/v21/_remove_language_instruction.py +109 -0
- opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +60 -0
- opentau/datasets/v21/convert_dataset_v20_to_v21.py +183 -0
- opentau/datasets/v21/convert_stats.py +150 -0
- opentau/datasets/video_utils.py +597 -0
- opentau/envs/__init__.py +18 -0
- opentau/envs/configs.py +178 -0
- opentau/envs/factory.py +99 -0
- opentau/envs/libero.py +439 -0
- opentau/envs/utils.py +204 -0
- opentau/optim/__init__.py +16 -0
- opentau/optim/factory.py +43 -0
- opentau/optim/optimizers.py +121 -0
- opentau/optim/schedulers.py +140 -0
- opentau/planner/__init__.py +82 -0
- opentau/planner/high_level_planner.py +366 -0
- opentau/planner/utils/memory.py +64 -0
- opentau/planner/utils/utils.py +65 -0
- opentau/policies/__init__.py +24 -0
- opentau/policies/factory.py +172 -0
- opentau/policies/normalize.py +315 -0
- opentau/policies/pi0/__init__.py +19 -0
- opentau/policies/pi0/configuration_pi0.py +250 -0
- opentau/policies/pi0/modeling_pi0.py +994 -0
- opentau/policies/pi0/paligemma_with_expert.py +516 -0
- opentau/policies/pi05/__init__.py +20 -0
- opentau/policies/pi05/configuration_pi05.py +231 -0
- opentau/policies/pi05/modeling_pi05.py +1257 -0
- opentau/policies/pi05/paligemma_with_expert.py +572 -0
- opentau/policies/pretrained.py +315 -0
- opentau/policies/utils.py +123 -0
- opentau/policies/value/__init__.py +18 -0
- opentau/policies/value/configuration_value.py +170 -0
- opentau/policies/value/modeling_value.py +512 -0
- opentau/policies/value/reward.py +87 -0
- opentau/policies/value/siglip_gemma.py +221 -0
- opentau/scripts/actions_mse_loss.py +89 -0
- opentau/scripts/bin_to_safetensors.py +116 -0
- opentau/scripts/compute_max_token_length.py +111 -0
- opentau/scripts/display_sys_info.py +90 -0
- opentau/scripts/download_libero_benchmarks.py +54 -0
- opentau/scripts/eval.py +877 -0
- opentau/scripts/export_to_onnx.py +180 -0
- opentau/scripts/fake_tensor_training.py +87 -0
- opentau/scripts/get_advantage_and_percentiles.py +220 -0
- opentau/scripts/high_level_planner_inference.py +114 -0
- opentau/scripts/inference.py +70 -0
- opentau/scripts/launch_train.py +63 -0
- opentau/scripts/libero_simulation_parallel.py +356 -0
- opentau/scripts/libero_simulation_sequential.py +122 -0
- opentau/scripts/nav_high_level_planner_inference.py +61 -0
- opentau/scripts/train.py +379 -0
- opentau/scripts/visualize_dataset.py +294 -0
- opentau/scripts/visualize_dataset_html.py +507 -0
- opentau/scripts/zero_to_fp32.py +760 -0
- opentau/utils/__init__.py +20 -0
- opentau/utils/accelerate_utils.py +79 -0
- opentau/utils/benchmark.py +98 -0
- opentau/utils/fake_tensor.py +81 -0
- opentau/utils/hub.py +209 -0
- opentau/utils/import_utils.py +79 -0
- opentau/utils/io_utils.py +137 -0
- opentau/utils/libero.py +214 -0
- opentau/utils/libero_dataset_recorder.py +460 -0
- opentau/utils/logging_utils.py +180 -0
- opentau/utils/monkey_patch.py +278 -0
- opentau/utils/random_utils.py +244 -0
- opentau/utils/train_utils.py +198 -0
- opentau/utils/utils.py +471 -0
- opentau-0.1.0.dist-info/METADATA +161 -0
- opentau-0.1.0.dist-info/RECORD +108 -0
- opentau-0.1.0.dist-info/WHEEL +5 -0
- opentau-0.1.0.dist-info/entry_points.txt +2 -0
- opentau-0.1.0.dist-info/licenses/LICENSE +508 -0
- opentau-0.1.0.dist-info/top_level.txt +1 -0

opentau/policies/value/modeling_value.py

@@ -0,0 +1,512 @@
#!/usr/bin/env python

# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
# Copyright 2026 Tensor Auto Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Value Function Model using SIGLIP and Gemma 3 270M

A value function model that estimates state values for reinforcement learning.
Uses SIGLIP for vision encoding and Gemma 3 270M for language processing.
"""

import torch
import torch.nn.functional as F  # noqa: N812
from einops import rearrange
from torch import Tensor, nn
from transformers import AutoTokenizer

from opentau.policies.normalize import Normalize
from opentau.policies.pretrained import PreTrainedPolicy
from opentau.policies.value.configuration_value import ValueConfig
from opentau.policies.value.siglip_gemma import (
    SiglipGemmaValueConfig,
    SiglipGemmaValueModel,
)


def make_att_2d_masks(pad_masks, att_masks):
    """Creates a 2-D attention mask given padding and 1-D attention masks.

    Tokens can attend to valid input tokens which have a cumulative mask_ar
    smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to
    set up several types of attention, for example:

      [[1 1 1 1 1 1]]: pure causal attention.

      [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
          themselves and the last 3 tokens have a causal attention. The first
          entry could also be a 1 without changing behaviour.

      [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
          block can attend all previous blocks and all tokens on the same block.

    Args:
      pad_masks: bool[B, N] true if it is part of the input, false if padding.
      att_masks: int32[B, N] mask that's 1 where previous tokens cannot depend on
        it and 0 where it shares the same attention mask as the previous token.

    Returns:
        torch.Tensor: The 2D attention masks.

    Raises:
        ValueError: If input masks are not 2D.
    """
    if att_masks.ndim != 2:
        raise ValueError(att_masks.ndim)
    if pad_masks.ndim != 2:
        raise ValueError(pad_masks.ndim)

    cumsum = torch.cumsum(att_masks, dim=1)
    att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
    pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
    att_2d_masks = att_2d_masks & pad_2d_masks
    return att_2d_masks
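
# Example (illustrative sketch): with a prefix-LM layout the cumulative sum of
# `att_masks` decides who attends to whom. For instance,
#   pad = torch.ones(1, 6, dtype=torch.bool)
#   att = torch.tensor([[0, 0, 0, 1, 1, 1]])  # 3 prefix tokens, then causal
#   make_att_2d_masks(pad, att)[0, 3]
# yields [True, True, True, True, False, False]: token 3 sees the prefix and itself,
# but not tokens 4 and 5, whose cumulative mask values are larger.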


def resize_with_pad(img, width, height, pad_value=-1):
    """Resizes an image while preserving aspect ratio and padding to target dimensions.

    Args:
        img: Input image tensor of shape (B, C, H, W).
        width: Target width.
        height: Target height.
        pad_value: Value to use for padding. Defaults to -1.

    Returns:
        torch.Tensor: Resized and padded image tensor.

    Raises:
        ValueError: If image dimensions are not 4D (B, C, H, W).
    """
    # assume no-op when width height fits already
    if img.ndim != 4:
        raise ValueError(f"(b,c,h,w) expected, but {img.shape}")

    cur_height, cur_width = img.shape[2:]

    ratio = max(cur_width / width, cur_height / height)
    resized_height = int(cur_height / ratio)
    resized_width = int(cur_width / ratio)
    resized_img = F.interpolate(
        img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
    )

    pad_height = max(0, int(height - resized_height))
    pad_width = max(0, int(width - resized_width))

    # pad on left and top of image
    padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
    return padded_img
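
# Example (illustrative): a (1, 3, 480, 640) image resized to width=224, height=224
# uses ratio = max(640 / 224, 480 / 224) ≈ 2.857, giving a (168, 224) resize, and is
# then padded with 56 rows on top to reach (1, 3, 224, 224), keeping the aspect ratio.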


class ValueFunction(PreTrainedPolicy):
    """Wrapper class around Value Function model to train and run inference within OpenTau."""

    config_class = ValueConfig
    name = "value"

    def __init__(
        self,
        config: ValueConfig,
        dataset_stats: dict[str, dict[str, Tensor]] | None = None,
    ):
        """Initializes the ValueFunction policy.

        Args:
            config: Value Function configuration class instance or None, in which case the default
                instantiation of the configuration class is used.
            dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
                that they will be passed with a call to `load_state_dict` before the policy is used.
        """

        super().__init__(config)
        config.validate_features()
        self.config = config
        self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)

        self.language_tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-270m")
        self.model = ValueModel(config)

    def reset(self):
        """Resets the internal state of the policy.

        This method is called at the beginning of each episode. For value functions,
        there is no internal state to reset.
        """
        pass  # Value functions don't need state reset

    def get_optim_params(self) -> dict:
        """Returns the parameters to be optimized.

        Returns:
            dict: Dictionary of parameters to optimize.
        """
        return self.parameters()

    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
        """Selects an action based on the current policy.

        Args:
            batch: Dictionary containing observation data.

        Returns:
            Tensor: The selected action.

        Raises:
            NotImplementedError: Value functions do not select actions.
        """
        raise NotImplementedError("Value functions do not select actions. Use predict_value() instead.")

    def sample_actions(self, batch: dict[str, Tensor], noise: Tensor = None):
        """Samples actions from the policy.

        Args:
            batch: Dictionary containing observation data.
            noise: Optional noise tensor.

        Raises:
            NotImplementedError: Value functions do not sample actions.
        """
        raise NotImplementedError("Value functions do not sample actions. Use predict_value() instead.")

    def calculate_value(self, logits: Tensor) -> Tensor:
        """Calculates the expected value from the logits distribution.

        Args:
            logits: Tensor containing the logits for value bins.

        Returns:
            Tensor: The expected value.
        """
        start_idx = torch.linspace(
            -1,
            -1 / self.config.reward_config.number_of_bins,
            self.config.reward_config.number_of_bins,
            device=logits.device,
        )
        end_idx = torch.linspace(
            -1 + 1 / self.config.reward_config.number_of_bins,
            0,
            self.config.reward_config.number_of_bins,
            device=logits.device,
        )

        mid_idx = rearrange(
            (start_idx + end_idx) / 2,
            "n -> 1 n",
        )

        value = torch.softmax(logits, dim=-1).to(dtype=torch.float32) @ mid_idx.T

        return rearrange(value, "b 1 -> b")
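
    # Example (illustrative): with number_of_bins = 4 the bins cover [-1, 0] with
    # midpoints [-0.875, -0.625, -0.375, -0.125]; the value is the softmax-weighted
    # average of these midpoints, so logits concentrated on the last bin give a value
    # close to -0.125, while a uniform distribution gives -0.5.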

    @torch.no_grad()
    def predict_value(self, batch: dict[str, Tensor]) -> Tensor:
        """Predict value estimates given environment observations.

        Args:
            batch: Dictionary containing observations (images, state, prompt)

        Returns:
            Tensor of shape [batch_size] containing value estimates
        """
        self.eval()

        batch = self.normalize_inputs(batch)

        images, img_masks = self.prepare_images(batch)
        lang_tokens, lang_masks = self.prepare_language(batch)
        state = batch.get("state")

        logits = self.model.forward(images, img_masks, lang_tokens, lang_masks, state)
        return self.calculate_value(logits)

    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
        """Do a full training forward pass to compute the value losses.

        Args:
            batch: Dictionary containing observations and target values

        Returns:
            Dictionary of loss terms: cross-entropy ("CE"), L1 error ("L1"), a zero
            "MSE" placeholder, and bin classification "Accuracy".
        """
        batch = self.normalize_inputs(batch)

        images, img_masks = self.prepare_images(batch)
        lang_tokens, lang_masks = self.prepare_language(batch)
        state = batch.get("state")

        logits = self.model.forward(images, img_masks, lang_tokens, lang_masks, state)
        values = self.calculate_value(logits)
        # Compute Cross-Entropy loss
        logits = logits.to(dtype=torch.float32)  # upcast to float32 for loss calculation
        batch["return_bin_idx"] = batch["return_bin_idx"].to(dtype=torch.long)
        loss = F.cross_entropy(logits, batch["return_bin_idx"])

        l1_loss = F.l1_loss(values, batch["return_continuous"])

        accuracy = (logits.argmax(dim=-1) == batch["return_bin_idx"]).float().mean()

        return {
            "MSE": torch.zeros_like(loss, requires_grad=False),
            "CE": loss,
            "L1": l1_loss,
            "Accuracy": accuracy,
        }

    def prepare_images(self, batch):
        """Preprocesses images for the model.

        Resizes images to 224x224, pads to keep aspect ratio, and converts pixel range
        from [0.0, 1.0] to [-1.0, 1.0] as expected by SigLIP. Also handles missing images
        by creating empty placeholders.

        Args:
            batch: Dictionary containing batch data.

        Returns:
            tuple: A tuple containing a list of image tensors and a list of mask tensors.

        Raises:
            ValueError: If all image features are missing from the batch.
        """
        images = []
        img_masks = []

        present_img_keys = [key for key in self.config.image_features if key in batch]
        missing_img_keys = [key for key in self.config.image_features if key not in batch]

        if len(present_img_keys) == 0:
            raise ValueError(
                f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features: {self.config.image_features})"
            )

        # Preprocess image features present in the batch
        for key in present_img_keys:
            img = batch[key]

            if self.config.resize_imgs_with_padding is not None:
                img = resize_with_pad(img, *self.config.resize_imgs_with_padding, pad_value=0)

            # Normalize from range [0,1] to [-1,1] as expected by siglip
            img = img * 2.0 - 1.0

            bsize = img.shape[0]
            device = img.device
            mask = torch.ones(bsize, dtype=torch.bool, device=device)
            images.append(img)
            img_masks.append(mask)

        # Create image features not present in the batch
        # as fully 0 padded images.
        for num_empty_cameras in range(len(missing_img_keys)):
            if num_empty_cameras >= self.config.empty_cameras:
                break
            img = torch.ones_like(img) * -1
            mask = torch.zeros_like(mask)
            images.append(img)
            img_masks.append(mask)

        return images, img_masks
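
    # Note (illustrative): with config.empty_cameras = 1 and one camera key absent from
    # the batch, a single all -1 placeholder image with an all-False mask is appended,
    # so the model always sees a fixed number of camera slots.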

    def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
        """Tokenizes the text input for the model.

        Args:
            batch: Dictionary containing batch data, including "prompt".

        Returns:
            tuple: A tuple containing language token tensors and attention mask tensors.
        """
        device = batch.get("state", list(batch.values())[0]).device
        tasks = batch["prompt"]

        # PaliGemma prompt has to end with a new line
        tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]

        tokenized_prompt = self.language_tokenizer(
            tasks,
            padding="max_length",
            padding_side="right",
            max_length=self.config.tokenizer_max_length,
            return_tensors="pt",
        )
        lang_tokens = tokenized_prompt["input_ids"].to(device=device)
        lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool)

        return lang_tokens, lang_masks
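
# Example usage (illustrative sketch; the image key and shapes are assumptions, the
# actual keys come from config.image_features):
#   policy = ValueFunction(config)
#   batch = {
#       "observation.images.top": torch.rand(1, 3, 480, 640),  # pixels in [0, 1]
#       "state": torch.zeros(1, config.max_state_dim),
#       "prompt": ["pick up the red block"],
#   }
#   values = policy.predict_value(batch)  # shape (1,), values lie in [-1, 0]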


class ValueModel(nn.Module):
    """
    Value Function Model using SIGLIP and Gemma 3 270M

    Estimates state values for reinforcement learning by processing:
    - Images through SIGLIP vision encoder
    - Language tokens through Gemma 3 270M
    - Optional state information

    ┌──────────────────────────────┐
    │                value         │
    │                  ▲           │
    │                 ┌┴─────┐     │
    │                 │Gemma │     │
    │                 │3 270M│     │
    │                 │      │     │
    │ ┌──────────┐    └▲──▲──┘     │
    │ │          │     │  │        │
    │ │  SIGLIP  ├─────┘  │        │
    │ │          │      language   │
    │ └────▲─────┘                 │
    │      │                       │
    │   image(s)                   │
    │                              │
    └──────────────────────────────┘
    """

    CLASSIFICATION_TOKEN_ID = 6  # unused token id in Gemma 3 270M that we repurpose for classification

    def __init__(self, config):
        """Initializes the ValueModel.

        Args:
            config: Configuration object for the model.
        """
        super().__init__()
        self.config = config

        siglip_gemma_value_config = SiglipGemmaValueConfig(
            num_value_bins=self.config.reward_config.number_of_bins
        )
        self.siglip_gemma_value = SiglipGemmaValueModel(siglip_gemma_value_config)

        # Projection for state if provided
        self.state_proj = nn.Linear(self.config.max_state_dim, 640)
        self.multi_modal_proj = nn.Linear(1152, 640)
        self.bins = config.reward_config.number_of_bins
        self.c_neg = config.reward_config.C_neg

    def embed_sequence(
        self, images, img_masks, lang_tokens, lang_masks, state
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Embeds sequence of images and language tokens.

        Prepares embeddings for SiglipGemmaValueModel transformer processing.

        Args:
            images: List of image tensors.
            img_masks: List of image mask tensors.
            lang_tokens: Language token tensor.
            lang_masks: Language mask tensor.
            state: State tensor.

        Returns:
            tuple: A tuple containing embeddings, padding masks, and attention masks.
        """
        # TODO: avoid list in python and torch.cat ; prefer pre-allocation with torch.empty
        embs = []
        pad_masks = []
        att_masks = []

        # TODO: remove for loop
        for img, img_mask in zip(images, img_masks, strict=False):
            img_emb = self.siglip_gemma_value.embed_image(img)
            img_emb = img_emb.to(dtype=torch.bfloat16)
            img_emb = self.multi_modal_proj(img_emb)

            # image embeddings don't need to be unnormalized because they were not normalized in the first place

            bsize, num_img_embs = img_emb.shape[:2]
            img_mask = img_mask[:, None].expand(bsize, num_img_embs)

            embs.append(img_emb)
            pad_masks.append(img_mask)

            # Create attention masks so that image tokens attend to each other
            att_masks += [0] * num_img_embs

        # Gemma3 already scales by sqrt(d)
        lang_emb = self.siglip_gemma_value.embed_language_tokens(lang_tokens)

        embs.append(lang_emb)
        pad_masks.append(lang_masks)

        # full attention between image and language inputs
        num_lang_embs = lang_emb.shape[1]
        att_masks += [0] * num_lang_embs

        # embed state
        state_emb = self.state_proj(state)
        state_emb = state_emb.to(dtype=torch.bfloat16)
        embs.append(state_emb[:, None, :])

        state_mask = torch.ones(state_emb.shape[0], 1, dtype=torch.bool, device=state_emb.device)
        pad_masks.append(state_mask)

        # full attention between state and image and language inputs
        att_masks += [0]

        # add classification token
        cls_token = torch.full(
            (bsize, 1), self.CLASSIFICATION_TOKEN_ID, device=state_emb.device, dtype=torch.long
        )
        cls_token_emb = self.siglip_gemma_value.gemma.embed_tokens(cls_token)
        embs.append(cls_token_emb)
        pad_masks.append(torch.ones(bsize, 1, dtype=torch.bool, device=state_emb.device))
        att_masks += [0]

        embs = torch.cat(embs, dim=1)
        pad_masks = torch.cat(pad_masks, dim=1)
        att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
        att_masks = att_masks[None, :].expand(bsize, len(att_masks))

        return embs, pad_masks, att_masks
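
    # Note (illustrative): the resulting sequence is laid out as
    #   [image tokens per camera] + [language tokens] + [state token] + [classification token],
    # and since every att_masks entry is 0, make_att_2d_masks yields full bidirectional
    # attention over all non-padded positions.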

    def forward(
        self,
        images: list[torch.Tensor],
        img_masks: list[torch.Tensor],
        lang_tokens: torch.Tensor,
        lang_masks: torch.Tensor,
        state: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Predict value estimates given observations.

        Args:
            images: List of image tensors
            img_masks: List of image masks
            lang_tokens: Language token IDs
            lang_masks: Language attention masks
            state: Optional state tensor

        Returns:
            Tensor of shape [batch_size, num_value_bins] containing value-bin logits
        """
        embs, pad_masks, att_masks = self.embed_sequence(images, img_masks, lang_tokens, lang_masks, state)

        att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
        position_ids = torch.cumsum(pad_masks, dim=1) - 1

        logits = self.siglip_gemma_value.forward(
            inputs_embeds=embs,
            attention_mask=att_2d_masks,
            position_ids=position_ids,
        )

        return logits

opentau/policies/value/reward.py

@@ -0,0 +1,87 @@
# Copyright 2026 Tensor Auto Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Reward calculation utilities for Value Policy training.

This module contains functions to calculate returns and discretize them into bins
for value function training and advantage calculation.
"""


def calculate_return_bins_with_equal_width(
    success: bool,
    b: int,
    episode_end_idx: int,
    reward_normalizer: int,
    current_idx: int,
    c_neg: float = -100.0,
) -> tuple[int, float]:
    """Defines the sparse reward function used to train the value function network for the pi0.6 policy.

    Args:
        success: Defines if the episode was successful or failed.
        b: Number of bins to discretize the reward into, including the special bin 0.
        episode_end_idx: Index of the end of the episode, exclusive of the last step.
        reward_normalizer: Maximum length of the episode for normalization.
        current_idx: Current index of the episode.
        c_neg: Negative reward for failed episodes. Defaults to -100.0.

    Returns:
        tuple[int, float]: A tuple containing:
            - bin_idx: The index of the reward bin.
            - return_normalized: The normalized return value in range [-1, 0].
    """
    # calculate the reward for each step, i.e. -1 until the end of the episode, excluding the last step
    return_value = current_idx - episode_end_idx + 1
    # add negative reward for last step if episode is a failure, else add nothing for a successful episode
    if not success:
        return_value += c_neg

    # normalize the reward to the range of -1 to 0
    return_normalized = return_value / reward_normalizer
    # map normalized reward [-1,0) to bin index [0,b-1]
    bin_idx = int((return_normalized + 1) * (b - 1))
    return bin_idx, return_normalized
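
# Example (illustrative numbers): for a failed episode with episode_end_idx=200,
# reward_normalizer=400, c_neg=-100 and current_idx=150, the return is
# (150 - 200 + 1) - 100 = -149, normalized to -0.3725; with b=201 bins this maps to
# bin_idx = int((1 - 0.3725) * 200) = 125. A successful episode at the same step
# gives -49 / 400 = -0.1225 and bin 175.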


def calculate_n_step_return(
    success: bool,
    n_steps_look_ahead: int,
    episode_end_idx: int,
    reward_normalizer: int,
    current_idx: int,
    c_neg: float = -100.0,
) -> float:
    """Defines the sparse reward function used to calculate the advantage for the pi0.6 policy.

    Args:
        success: Defines if the episode was successful or failed.
        n_steps_look_ahead: Number of steps to look ahead for calculating reward.
        episode_end_idx: Index of the end of the episode.
        reward_normalizer: Maximum length of the episode for normalization.
        current_idx: Current index of the episode.
        c_neg: Negative reward for failed episodes. Defaults to -100.0.

    Returns:
        float: The normalized continuous reward for the n-step lookahead.
    """
    # calculate the reward over the next n_steps_look_ahead steps
    return_value = max(current_idx - episode_end_idx + 1, -1 * n_steps_look_ahead)
    # add the negative reward for the last step if the episode is a failure and the
    # end of the episode falls within the lookahead window
    if not success and current_idx + n_steps_look_ahead >= episode_end_idx:
        return_value += c_neg

    # normalize the reward to the range of -1 to 0
    return_normalized = return_value / reward_normalizer

    return return_normalized
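
# Example (illustrative numbers): for a failed episode with episode_end_idx=200,
# n_steps_look_ahead=50, reward_normalizer=400 and current_idx=180, the clipped
# return is max(-19, -50) = -19; the failure lies within the 50-step window, so
# c_neg is added, giving (-19 - 100) / 400 = -0.2975.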