opentau-0.1.0-py3-none-any.whl
This diff represents the content of a publicly available package version that has been released to one of the supported registries. The information is provided for informational purposes only and reflects the package contents as they appear in the public registry.
- opentau/__init__.py +179 -0
- opentau/__version__.py +24 -0
- opentau/configs/__init__.py +19 -0
- opentau/configs/default.py +297 -0
- opentau/configs/libero.py +113 -0
- opentau/configs/parser.py +393 -0
- opentau/configs/policies.py +297 -0
- opentau/configs/reward.py +42 -0
- opentau/configs/train.py +370 -0
- opentau/configs/types.py +76 -0
- opentau/constants.py +52 -0
- opentau/datasets/__init__.py +84 -0
- opentau/datasets/backward_compatibility.py +78 -0
- opentau/datasets/compute_stats.py +333 -0
- opentau/datasets/dataset_mixture.py +460 -0
- opentau/datasets/factory.py +232 -0
- opentau/datasets/grounding/__init__.py +67 -0
- opentau/datasets/grounding/base.py +154 -0
- opentau/datasets/grounding/clevr.py +110 -0
- opentau/datasets/grounding/cocoqa.py +130 -0
- opentau/datasets/grounding/dummy.py +101 -0
- opentau/datasets/grounding/pixmo.py +177 -0
- opentau/datasets/grounding/vsr.py +141 -0
- opentau/datasets/image_writer.py +304 -0
- opentau/datasets/lerobot_dataset.py +1910 -0
- opentau/datasets/online_buffer.py +442 -0
- opentau/datasets/push_dataset_to_hub/utils.py +132 -0
- opentau/datasets/sampler.py +99 -0
- opentau/datasets/standard_data_format_mapping.py +278 -0
- opentau/datasets/transforms.py +330 -0
- opentau/datasets/utils.py +1243 -0
- opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +887 -0
- opentau/datasets/v2/convert_dataset_v1_to_v2.py +829 -0
- opentau/datasets/v21/_remove_language_instruction.py +109 -0
- opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +60 -0
- opentau/datasets/v21/convert_dataset_v20_to_v21.py +183 -0
- opentau/datasets/v21/convert_stats.py +150 -0
- opentau/datasets/video_utils.py +597 -0
- opentau/envs/__init__.py +18 -0
- opentau/envs/configs.py +178 -0
- opentau/envs/factory.py +99 -0
- opentau/envs/libero.py +439 -0
- opentau/envs/utils.py +204 -0
- opentau/optim/__init__.py +16 -0
- opentau/optim/factory.py +43 -0
- opentau/optim/optimizers.py +121 -0
- opentau/optim/schedulers.py +140 -0
- opentau/planner/__init__.py +82 -0
- opentau/planner/high_level_planner.py +366 -0
- opentau/planner/utils/memory.py +64 -0
- opentau/planner/utils/utils.py +65 -0
- opentau/policies/__init__.py +24 -0
- opentau/policies/factory.py +172 -0
- opentau/policies/normalize.py +315 -0
- opentau/policies/pi0/__init__.py +19 -0
- opentau/policies/pi0/configuration_pi0.py +250 -0
- opentau/policies/pi0/modeling_pi0.py +994 -0
- opentau/policies/pi0/paligemma_with_expert.py +516 -0
- opentau/policies/pi05/__init__.py +20 -0
- opentau/policies/pi05/configuration_pi05.py +231 -0
- opentau/policies/pi05/modeling_pi05.py +1257 -0
- opentau/policies/pi05/paligemma_with_expert.py +572 -0
- opentau/policies/pretrained.py +315 -0
- opentau/policies/utils.py +123 -0
- opentau/policies/value/__init__.py +18 -0
- opentau/policies/value/configuration_value.py +170 -0
- opentau/policies/value/modeling_value.py +512 -0
- opentau/policies/value/reward.py +87 -0
- opentau/policies/value/siglip_gemma.py +221 -0
- opentau/scripts/actions_mse_loss.py +89 -0
- opentau/scripts/bin_to_safetensors.py +116 -0
- opentau/scripts/compute_max_token_length.py +111 -0
- opentau/scripts/display_sys_info.py +90 -0
- opentau/scripts/download_libero_benchmarks.py +54 -0
- opentau/scripts/eval.py +877 -0
- opentau/scripts/export_to_onnx.py +180 -0
- opentau/scripts/fake_tensor_training.py +87 -0
- opentau/scripts/get_advantage_and_percentiles.py +220 -0
- opentau/scripts/high_level_planner_inference.py +114 -0
- opentau/scripts/inference.py +70 -0
- opentau/scripts/launch_train.py +63 -0
- opentau/scripts/libero_simulation_parallel.py +356 -0
- opentau/scripts/libero_simulation_sequential.py +122 -0
- opentau/scripts/nav_high_level_planner_inference.py +61 -0
- opentau/scripts/train.py +379 -0
- opentau/scripts/visualize_dataset.py +294 -0
- opentau/scripts/visualize_dataset_html.py +507 -0
- opentau/scripts/zero_to_fp32.py +760 -0
- opentau/utils/__init__.py +20 -0
- opentau/utils/accelerate_utils.py +79 -0
- opentau/utils/benchmark.py +98 -0
- opentau/utils/fake_tensor.py +81 -0
- opentau/utils/hub.py +209 -0
- opentau/utils/import_utils.py +79 -0
- opentau/utils/io_utils.py +137 -0
- opentau/utils/libero.py +214 -0
- opentau/utils/libero_dataset_recorder.py +460 -0
- opentau/utils/logging_utils.py +180 -0
- opentau/utils/monkey_patch.py +278 -0
- opentau/utils/random_utils.py +244 -0
- opentau/utils/train_utils.py +198 -0
- opentau/utils/utils.py +471 -0
- opentau-0.1.0.dist-info/METADATA +161 -0
- opentau-0.1.0.dist-info/RECORD +108 -0
- opentau-0.1.0.dist-info/WHEEL +5 -0
- opentau-0.1.0.dist-info/entry_points.txt +2 -0
- opentau-0.1.0.dist-info/licenses/LICENSE +508 -0
- opentau-0.1.0.dist-info/top_level.txt +1 -0
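
For orientation, the paths above map one-to-one to importable modules once the wheel is installed (top_level.txt has a single entry, which suggests everything lives under the `opentau` package). A minimal, hypothetical sketch, assuming the wheel is installed from its registry or directly from the .whl file:

# Hypothetical usage sketch; module paths are taken from the file listing above,
# e.g. after `pip install opentau-0.1.0-py3-none-any.whl`.
from opentau.policies.pi0 import paligemma_with_expert  # file shown in full below
from opentau.policies import pi05                       # package __init__ shown below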
opentau/policies/pi0/paligemma_with_expert.py
@@ -0,0 +1,516 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2026 Tensor Auto Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""PaliGemma with Expert Module for PI0.

This module implements the PaliGemma model with an additional expert module,
specifically designed for the Pi0 policy. It combines a pre-trained PaliGemma
Vision-Language Model (VLM) with a Gemma-based expert model to handle
action generation and conditioning.
"""

import torch
import torch.version
from transformers import Cache
from torch import nn
from transformers import (
    AutoConfig,
    GemmaForCausalLM,
    PaliGemmaForConditionalGeneration,
    PretrainedConfig,
    PreTrainedModel,
)
from transformers.models.auto import CONFIG_MAPPING


def apply_rope(x: torch.Tensor, positions: torch.Tensor, max_wavelength: int = 10_000) -> torch.Tensor:
    """Applies RoPE positions to the input tensor.

    Args:
        x: Input tensor of shape [B, L, H, D].
        positions: Position tensor of shape [B, L].
        max_wavelength: Maximum wavelength for RoPE. Defaults to 10_000.

    Returns:
        Tensor: The input tensor with RoPE applied, of shape [B, L, H, D].
    """
    d_half = x.shape[-1] // 2
    device = x.device
    dtype = x.dtype
    x = x.to(torch.float32)

    freq_exponents = (2.0 / x.shape[-1]) * torch.arange(d_half, dtype=torch.float32, device=device)
    timescale = max_wavelength**freq_exponents
    radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(torch.float32)

    radians = radians[..., None, :]

    sin = torch.sin(radians)  # .to(dtype=dtype)
    cos = torch.cos(radians)  # .to(dtype=dtype)

    x1, x2 = x.split(d_half, dim=-1)
    res = torch.empty_like(x)
    res[..., :d_half] = x1 * cos - x2 * sin
    res[..., d_half:] = x2 * cos + x1 * sin

    return res.to(dtype)


class PaliGemmaWithExpertConfig(PretrainedConfig):
    """Configuration class for PaliGemmaWithExpertModel."""

    model_type = "PaliGemmaWithExpertModel"
    sub_configs = {"paligemma_config": AutoConfig, "gemma_expert_config": AutoConfig}

    def __init__(
        self,
        paligemma_config: dict | None = None,
        gemma_expert_config: dict | None = None,
        freeze_vision_encoder: bool = True,
        train_expert_only: bool = True,
        attention_implementation: str = "eager",
        load_pretrained_paligemma: bool = False,
        dropout: float = 0.1,
        **kwargs,
    ):
        """Initializes the configuration.

        Args:
            paligemma_config: Configuration dictionary for the PaliGemma model.
            gemma_expert_config: Configuration dictionary for the Gemma expert model.
            freeze_vision_encoder: Whether to freeze the vision encoder. Defaults to True.
            train_expert_only: Whether to train only the expert model. Defaults to True.
            attention_implementation: Attention implementation to use ("eager" or "fa2"). Defaults to "eager".
            load_pretrained_paligemma: Whether to load a pretrained PaliGemma model. Defaults to False.
            dropout: Dropout probability. Defaults to 0.1.
            **kwargs: Additional keyword arguments passed to PretrainedConfig.
        """
        self.freeze_vision_encoder = freeze_vision_encoder
        self.train_expert_only = train_expert_only
        self.attention_implementation = attention_implementation
        self.load_pretrained_paligemma = load_pretrained_paligemma
        self.dropout = dropout

        if paligemma_config is None:
            # Default config from Pi0
            self.paligemma_config = CONFIG_MAPPING["paligemma"](
                transformers_version="4.48.1",
                _vocab_size=257152,
                bos_token_id=2,
                eos_token_id=1,
                hidden_size=2048,
                image_token_index=257152,
                model_type="paligemma",
                pad_token_id=0,
                projection_dim=2048,
                text_config={
                    "hidden_activation": "gelu_pytorch_tanh",
                    "hidden_size": 2048,
                    "intermediate_size": 16384,
                    "model_type": "gemma",
                    "num_attention_heads": 8,
                    "num_hidden_layers": 18,
                    "num_image_tokens": 256,
                    "num_key_value_heads": 1,
                    "torch_dtype": "float32",
                    "vocab_size": 257152,
                },
                vision_config={
                    "hidden_size": 1152,
                    "intermediate_size": 4304,
                    "model_type": "siglip_vision_model",
                    "num_attention_heads": 16,
                    "num_hidden_layers": 27,
                    "num_image_tokens": 256,
                    "patch_size": 14,
                    "projection_dim": 2048,
                    "projector_hidden_act": "gelu_fast",
                    "torch_dtype": "float32",
                    "vision_use_head": False,
                },
            )
        elif isinstance(paligemma_config, dict):
            # Override Pi0 default config for PaliGemma
            if "model_type" not in paligemma_config:
                paligemma_config["model_type"] = "paligemma"

            cfg_cls = CONFIG_MAPPING[paligemma_config["model_type"]]
            self.paligemma_config = cfg_cls(**paligemma_config)

        if gemma_expert_config is None:
            # Default config from Pi0
            self.gemma_expert_config = CONFIG_MAPPING["gemma"](
                attention_bias=False,
                attention_dropout=0.0,
                bos_token_id=2,
                eos_token_id=1,
                head_dim=256,
                hidden_act="gelu_pytorch_tanh",
                hidden_activation="gelu_pytorch_tanh",
                hidden_size=1024,
                initializer_range=0.02,
                intermediate_size=4096,
                max_position_embeddings=8192,
                model_type="gemma",
                num_attention_heads=8,
                num_hidden_layers=18,
                num_key_value_heads=1,
                pad_token_id=0,
                rms_norm_eps=1e-06,
                rope_theta=10000.0,
                torch_dtype="float32",
                transformers_version="4.48.1",
                use_cache=True,
                vocab_size=257152,
            )
        elif isinstance(gemma_expert_config, dict):
            # Override Pi0 default config for Gemma Expert
            if "model_type" not in gemma_expert_config:
                gemma_expert_config["model_type"] = "gemma"

            cfg_cls = CONFIG_MAPPING[gemma_expert_config["model_type"]]
            self.gemma_expert_config = cfg_cls(**gemma_expert_config)

        super().__init__(**kwargs)

    def __post_init__(self):
        """Validates configuration parameters."""
        super().__post_init__()
        if self.train_expert_only and not self.freeze_vision_encoder:
            raise ValueError(
                "You set `freeze_vision_encoder=False` and `train_expert_only=True` which are not compatible."
            )

        if self.attention_implementation not in ["eager", "fa2"]:
            raise ValueError(
                f"Wrong value provided for `attention_implementation` ({self.attention_implementation}). Expected 'eager' or 'fa2'."
            )


class PaliGemmaWithExpertModel(PreTrainedModel):
    """PaliGemma model with an additional expert module for action generation."""

    config_class = PaliGemmaWithExpertConfig

    def __init__(self, config: PaliGemmaWithExpertConfig):
        """Initializes the PaliGemmaWithExpertModel.

        Args:
            config: Configuration object for the model.
        """
        super().__init__(config=config)
        self.config = config

        if config.load_pretrained_paligemma:
            self.paligemma = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma-3b-pt-224")
        else:
            self.paligemma = PaliGemmaForConditionalGeneration(config=config.paligemma_config)
        self.gemma_expert = GemmaForCausalLM(config=config.gemma_expert_config)
        # Remove unused embed_tokens
        self.gemma_expert.model.embed_tokens = None

        self.dropout = nn.Dropout(config.dropout)

        self.to_bfloat16_like_physical_intelligence()
        self.set_requires_grad()

    def set_requires_grad(self) -> None:
        """Sets the requires_grad attribute for model parameters based on configuration."""
        if self.config.freeze_vision_encoder:
            self.paligemma.vision_tower.eval()
            for params in self.paligemma.vision_tower.parameters():
                params.requires_grad = False

        if self.config.train_expert_only:
            self.paligemma.eval()
            for params in self.paligemma.parameters():
                params.requires_grad = False

    def train(self, mode: bool = True) -> None:
        """Sets the module in training mode.

        Args:
            mode: whether to set training mode (True) or evaluation mode (False). Defaults to True.
        """
        super().train(mode)

        if self.config.freeze_vision_encoder:
            self.paligemma.vision_tower.eval()

        if self.config.train_expert_only:
            self.paligemma.eval()

    def to_bfloat16_like_physical_intelligence(self) -> None:
        """Casts specific model components to bfloat16 dtype."""
        self.paligemma = self.paligemma.to(dtype=torch.bfloat16)

        params_to_change_dtype = [
            "language_model.model.layers",
            "gemma_expert.model.layers",
            "vision_tower",
            "multi_modal",
        ]
        for name, param in self.named_parameters():
            if any(selector in name for selector in params_to_change_dtype):
                param.data = param.data.to(dtype=torch.bfloat16)

    def embed_image(self, image: torch.Tensor) -> torch.Tensor:
        """Computes image embeddings.

        Args:
            image: Input image tensor.

        Returns:
            torch.Tensor: Image embeddings.
        """
        # Handle different transformers versions
        if hasattr(self.paligemma, "get_image_features"):
            return self.paligemma.get_image_features(image)
        else:
            return self.paligemma.model.get_image_features(image)

    def embed_language_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
        """Embeds language tokens.

        Args:
            tokens: Input token indices.

        Returns:
            torch.Tensor: Token embeddings.
        """
        return self.paligemma.language_model.embed_tokens(tokens)

    # TODO: break down this huge forward into modules or functions
    def forward(
        self,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: list[torch.FloatTensor] | Cache | None = None,
        inputs_embeds: list[torch.FloatTensor] = None,
        use_cache: bool | None = None,
        fill_kv_cache: bool | None = None,
    ) -> tuple[list[torch.FloatTensor | None], list[torch.FloatTensor] | Cache | None]:
        """Forward pass of the model.

        Args:
            attention_mask: Attention mask tensor.
            position_ids: Position IDs tensor.
            past_key_values: Past key values for caching.
            inputs_embeds: List of input embeddings for the different model parts.
            use_cache: Whether to use KV cache.
            fill_kv_cache: Whether to fill the KV cache.

        Returns:
            tuple: A tuple containing:
                - outputs_embeds: List of output embeddings.
                - past_key_values: Updated past key values.
        """
        models = [self.paligemma.language_model, self.gemma_expert.model]

        for hidden_states in inputs_embeds:
            # TODO this is very inefficient
            # dtype is always the same, batch size too (if > 1 len)
            # device could be trickier in multi gpu edge cases but that's it
            if hidden_states is None:
                continue
            batch_size = hidden_states.shape[0]

        # RMSNorm
        num_layers = self.paligemma.config.text_config.num_hidden_layers
        head_dim = self.paligemma.config.text_config.head_dim
        for layer_idx in range(num_layers):
            query_states = []
            key_states = []
            value_states = []
            for i, hidden_states in enumerate(inputs_embeds):
                if hidden_states is None:
                    continue
                layer = models[i].layers[layer_idx]
                # normalizer = torch.tensor(models[i].config.hidden_size**0.5, dtype=hidden_states.dtype)
                # hidden_states = hidden_states * normalizer
                hidden_states, _ = layer.input_layernorm(hidden_states)

                input_shape = hidden_states.shape[:-1]
                hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)

                hidden_states = hidden_states.to(dtype=torch.bfloat16)
                query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape)
                key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape)
                value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape)

                query_states.append(query_state)
                key_states.append(key_state)
                value_states.append(value_state)

            # B,L,H,D with L sequence length, H number of heads, D head dim
            # concatenate on the number of embeddings/tokens
            query_states = torch.cat(query_states, dim=1)
            key_states = torch.cat(key_states, dim=1)
            value_states = torch.cat(value_states, dim=1)

            query_states = apply_rope(query_states, position_ids)
            key_states = apply_rope(key_states, position_ids)

            if use_cache and past_key_values is None:
                past_key_values = {}

            if use_cache:
                if fill_kv_cache:
                    past_key_values[layer_idx] = {
                        "key_states": key_states,
                        "value_states": value_states,
                    }
                else:
                    # TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before.
                    # so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach
                    # the max len, then we (for instance) double the cache size. This implementation already exists
                    # in `transformers`. (molbap)
                    key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1)
                    value_states = torch.cat(
                        [past_key_values[layer_idx]["value_states"], value_states], dim=1
                    )

            attention_interface = self.get_attention_interface()
            att_output = attention_interface(
                attention_mask, batch_size, head_dim, query_states, key_states, value_states
            )
            att_output = att_output.to(dtype=torch.bfloat16)

            # first part of att_output is prefix (up to sequence length, [:, 0:prefix_seq_len])
            outputs_embeds = []
            start = 0
            for i, hidden_states in enumerate(inputs_embeds):
                layer = models[i].layers[layer_idx]

                if hidden_states is not None:
                    end = start + hidden_states.shape[1]

                    if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
                        att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
                    out_emb = layer.self_attn.o_proj(att_output[:, start:end])

                    out_emb = self.dropout(out_emb)

                    # first residual
                    out_emb += hidden_states
                    after_first_residual = out_emb.clone()

                    out_emb, _ = layer.post_attention_layernorm(out_emb)
                    out_emb = layer.mlp(out_emb)

                    out_emb = self.dropout(out_emb)

                    # second residual
                    out_emb += after_first_residual

                    outputs_embeds.append(out_emb)

                    start = end
                else:
                    outputs_embeds.append(None)

            inputs_embeds = outputs_embeds

        # final norm
        outputs_embeds = []
        for i, hidden_states in enumerate(inputs_embeds):
            if hidden_states is not None:
                out_emb, _ = models[i].norm(hidden_states)
                outputs_embeds.append(out_emb)
            else:
                outputs_embeds.append(None)

        return outputs_embeds, past_key_values

    def get_attention_interface(self):
        """Returns the attention implementation function based on config.

        Returns:
            callable: The attention function to use.
        """
        return self.eager_attention_forward

    def eager_attention_forward(
        self,
        attention_mask: torch.Tensor,
        batch_size: int,
        head_dim: int,
        query_states: torch.Tensor,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
    ) -> torch.Tensor:
        """Eager attention forward pass using standard matrix multiplications.

        Args:
            attention_mask: Attention mask tensor.
            batch_size: Batch size.
            head_dim: Head dimension.
            query_states: Query states tensor.
            key_states: Key states tensor.
            value_states: Value states tensor.

        Returns:
            torch.Tensor: Attention output.
        """
        num_att_heads = self.config.paligemma_config.text_config.num_attention_heads
        num_key_value_heads = self.config.paligemma_config.text_config.num_key_value_heads
        num_key_value_groups = num_att_heads // num_key_value_heads

        # query_states: batch_size, sequence_length, num_att_head, head_dim
        # key_states: batch_size, sequence_length, num_key_value_head, head_dim
        # value_states: batch_size, sequence_length, num_key_value_head, head_dim
        sequence_length = key_states.shape[1]

        key_states = key_states[:, :, :, None, :].expand(
            batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
        )
        key_states = key_states.reshape(
            batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
        )

        value_states = value_states[:, :, :, None, :].expand(
            batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
        )
        value_states = value_states.reshape(
            batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
        )

        # Attention here is upcasted to float32 to match the original eager implementation.

        query_states = query_states.to(dtype=torch.float32)
        key_states = key_states.to(dtype=torch.float32)

        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)

        att_weights = torch.matmul(query_states, key_states.transpose(2, 3))
        att_weights *= head_dim**-0.5
        big_neg = -2.3819763e38  # See gemma/modules.py

        masked_att_weights = torch.where(attention_mask[:, None, :, :], att_weights, big_neg)

        probs = nn.functional.softmax(masked_att_weights, dim=-1)
        probs = probs.to(dtype=value_states.dtype)

        # probs: batch_size, num_key_value_head, num_att_head, sequence_length, sequence_length
        # value_states: batch_size, sequence_length, num_att_heads, head_dim

        att_output = torch.matmul(probs, value_states.permute(0, 2, 1, 3))

        att_output = att_output.permute(0, 2, 1, 3)
        # we use -1 because sequence length can change
        att_output = att_output.reshape(batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim)

        return att_output
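
To make the shapes and configuration knobs above concrete, here is a small, hypothetical usage sketch. The import path comes from the file listing, the symbols from the diff above, and the tensor shapes follow the apply_rope docstring ([B, L, H, D] activations, [B, L] positions); building the full model allocates the entire PaliGemma backbone plus the Gemma expert, so that step is left commented out.

# Hypothetical sketch based on the module shown above; not part of the package.
import torch

from opentau.policies.pi0.paligemma_with_expert import (
    PaliGemmaWithExpertConfig,
    PaliGemmaWithExpertModel,
    apply_rope,
)

# apply_rope: [B, L, H, D] activations rotated according to [B, L] integer positions.
q = torch.randn(1, 8, 8, 256)          # batch=1, seq_len=8, heads=8, head_dim=256
positions = torch.arange(8)[None, :]   # positions 0..7 for the single sequence
q_rotated = apply_rope(q, positions)
assert q_rotated.shape == q.shape

# The defaults reproduce the Pi0 setup: a PaliGemma VLM with a frozen SigLIP
# vision tower and a smaller Gemma "action expert" that is the only part trained.
config = PaliGemmaWithExpertConfig(
    freeze_vision_encoder=True,
    train_expert_only=True,
    attention_implementation="eager",
)

# Instantiating the model allocates both networks and casts most weights to
# bfloat16 via to_bfloat16_like_physical_intelligence(); omitted here because it is heavy.
# model = PaliGemmaWithExpertModel(config)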
opentau/policies/pi05/__init__.py
@@ -0,0 +1,20 @@
# Copyright 2026 Tensor Auto Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
PI05 Policy Module.

This module implements the π05 (Pi05) Vision-Language-Action Flow Model policy,
designed for general robot control. It includes the policy definition,
configuration, and model architecture.
"""