opentau-0.1.0-py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
Files changed (108)
  1. opentau/__init__.py +179 -0
  2. opentau/__version__.py +24 -0
  3. opentau/configs/__init__.py +19 -0
  4. opentau/configs/default.py +297 -0
  5. opentau/configs/libero.py +113 -0
  6. opentau/configs/parser.py +393 -0
  7. opentau/configs/policies.py +297 -0
  8. opentau/configs/reward.py +42 -0
  9. opentau/configs/train.py +370 -0
  10. opentau/configs/types.py +76 -0
  11. opentau/constants.py +52 -0
  12. opentau/datasets/__init__.py +84 -0
  13. opentau/datasets/backward_compatibility.py +78 -0
  14. opentau/datasets/compute_stats.py +333 -0
  15. opentau/datasets/dataset_mixture.py +460 -0
  16. opentau/datasets/factory.py +232 -0
  17. opentau/datasets/grounding/__init__.py +67 -0
  18. opentau/datasets/grounding/base.py +154 -0
  19. opentau/datasets/grounding/clevr.py +110 -0
  20. opentau/datasets/grounding/cocoqa.py +130 -0
  21. opentau/datasets/grounding/dummy.py +101 -0
  22. opentau/datasets/grounding/pixmo.py +177 -0
  23. opentau/datasets/grounding/vsr.py +141 -0
  24. opentau/datasets/image_writer.py +304 -0
  25. opentau/datasets/lerobot_dataset.py +1910 -0
  26. opentau/datasets/online_buffer.py +442 -0
  27. opentau/datasets/push_dataset_to_hub/utils.py +132 -0
  28. opentau/datasets/sampler.py +99 -0
  29. opentau/datasets/standard_data_format_mapping.py +278 -0
  30. opentau/datasets/transforms.py +330 -0
  31. opentau/datasets/utils.py +1243 -0
  32. opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +887 -0
  33. opentau/datasets/v2/convert_dataset_v1_to_v2.py +829 -0
  34. opentau/datasets/v21/_remove_language_instruction.py +109 -0
  35. opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +60 -0
  36. opentau/datasets/v21/convert_dataset_v20_to_v21.py +183 -0
  37. opentau/datasets/v21/convert_stats.py +150 -0
  38. opentau/datasets/video_utils.py +597 -0
  39. opentau/envs/__init__.py +18 -0
  40. opentau/envs/configs.py +178 -0
  41. opentau/envs/factory.py +99 -0
  42. opentau/envs/libero.py +439 -0
  43. opentau/envs/utils.py +204 -0
  44. opentau/optim/__init__.py +16 -0
  45. opentau/optim/factory.py +43 -0
  46. opentau/optim/optimizers.py +121 -0
  47. opentau/optim/schedulers.py +140 -0
  48. opentau/planner/__init__.py +82 -0
  49. opentau/planner/high_level_planner.py +366 -0
  50. opentau/planner/utils/memory.py +64 -0
  51. opentau/planner/utils/utils.py +65 -0
  52. opentau/policies/__init__.py +24 -0
  53. opentau/policies/factory.py +172 -0
  54. opentau/policies/normalize.py +315 -0
  55. opentau/policies/pi0/__init__.py +19 -0
  56. opentau/policies/pi0/configuration_pi0.py +250 -0
  57. opentau/policies/pi0/modeling_pi0.py +994 -0
  58. opentau/policies/pi0/paligemma_with_expert.py +516 -0
  59. opentau/policies/pi05/__init__.py +20 -0
  60. opentau/policies/pi05/configuration_pi05.py +231 -0
  61. opentau/policies/pi05/modeling_pi05.py +1257 -0
  62. opentau/policies/pi05/paligemma_with_expert.py +572 -0
  63. opentau/policies/pretrained.py +315 -0
  64. opentau/policies/utils.py +123 -0
  65. opentau/policies/value/__init__.py +18 -0
  66. opentau/policies/value/configuration_value.py +170 -0
  67. opentau/policies/value/modeling_value.py +512 -0
  68. opentau/policies/value/reward.py +87 -0
  69. opentau/policies/value/siglip_gemma.py +221 -0
  70. opentau/scripts/actions_mse_loss.py +89 -0
  71. opentau/scripts/bin_to_safetensors.py +116 -0
  72. opentau/scripts/compute_max_token_length.py +111 -0
  73. opentau/scripts/display_sys_info.py +90 -0
  74. opentau/scripts/download_libero_benchmarks.py +54 -0
  75. opentau/scripts/eval.py +877 -0
  76. opentau/scripts/export_to_onnx.py +180 -0
  77. opentau/scripts/fake_tensor_training.py +87 -0
  78. opentau/scripts/get_advantage_and_percentiles.py +220 -0
  79. opentau/scripts/high_level_planner_inference.py +114 -0
  80. opentau/scripts/inference.py +70 -0
  81. opentau/scripts/launch_train.py +63 -0
  82. opentau/scripts/libero_simulation_parallel.py +356 -0
  83. opentau/scripts/libero_simulation_sequential.py +122 -0
  84. opentau/scripts/nav_high_level_planner_inference.py +61 -0
  85. opentau/scripts/train.py +379 -0
  86. opentau/scripts/visualize_dataset.py +294 -0
  87. opentau/scripts/visualize_dataset_html.py +507 -0
  88. opentau/scripts/zero_to_fp32.py +760 -0
  89. opentau/utils/__init__.py +20 -0
  90. opentau/utils/accelerate_utils.py +79 -0
  91. opentau/utils/benchmark.py +98 -0
  92. opentau/utils/fake_tensor.py +81 -0
  93. opentau/utils/hub.py +209 -0
  94. opentau/utils/import_utils.py +79 -0
  95. opentau/utils/io_utils.py +137 -0
  96. opentau/utils/libero.py +214 -0
  97. opentau/utils/libero_dataset_recorder.py +460 -0
  98. opentau/utils/logging_utils.py +180 -0
  99. opentau/utils/monkey_patch.py +278 -0
  100. opentau/utils/random_utils.py +244 -0
  101. opentau/utils/train_utils.py +198 -0
  102. opentau/utils/utils.py +471 -0
  103. opentau-0.1.0.dist-info/METADATA +161 -0
  104. opentau-0.1.0.dist-info/RECORD +108 -0
  105. opentau-0.1.0.dist-info/WHEEL +5 -0
  106. opentau-0.1.0.dist-info/entry_points.txt +2 -0
  107. opentau-0.1.0.dist-info/licenses/LICENSE +508 -0
  108. opentau-0.1.0.dist-info/top_level.txt +1 -0
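
The wheel ships a single top-level package, `opentau` (see `top_level.txt`), so each path above corresponds to an importable module. The following is a minimal, illustrative sketch of that layout from Python, assuming the wheel has been installed into the current environment; the module names are taken directly from the file list above:

    # Illustrative only: import paths inferred from the file list, assuming
    # the opentau-0.1.0 wheel is installed in the current Python environment.
    from opentau.configs import train                 # opentau/configs/train.py
    from opentau.datasets import lerobot_dataset      # opentau/datasets/lerobot_dataset.py
    from opentau.policies.pi0 import modeling_pi0, paligemma_with_expert
    from opentau.policies.pi05 import configuration_pi05, modeling_pi05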
opentau/policies/pi0/paligemma_with_expert.py
@@ -0,0 +1,516 @@
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """PaliGemma with Expert Module for PI0.
+
+ This module implements the PaliGemma model with an additional expert module,
+ specifically designed for the Pi0 policy. It combines a pre-trained PaliGemma
+ Vision-Language Model (VLM) with a Gemma-based expert model to handle
+ action generation and conditioning.
+ """
+
+ import torch
+ import torch.version
+ from transformers.cache_utils import Cache
+ from torch import nn
+ from transformers import (
+     AutoConfig,
+     GemmaForCausalLM,
+     PaliGemmaForConditionalGeneration,
+     PretrainedConfig,
+     PreTrainedModel,
+ )
+ from transformers.models.auto import CONFIG_MAPPING
+
+
+ def apply_rope(x: torch.Tensor, positions: torch.Tensor, max_wavelength: int = 10_000) -> torch.Tensor:
+     """Applies RoPE positions to the input tensor.
+
+     Args:
+         x: Input tensor of shape [B, L, H, D].
+         positions: Position tensor of shape [B, L].
+         max_wavelength: Maximum wavelength for RoPE. Defaults to 10_000.
+
+     Returns:
+         Tensor: The input tensor with RoPE applied, of shape [B, L, H, D].
+     """
+     d_half = x.shape[-1] // 2
+     device = x.device
+     dtype = x.dtype
+     x = x.to(torch.float32)
+
+     freq_exponents = (2.0 / x.shape[-1]) * torch.arange(d_half, dtype=torch.float32, device=device)
+     timescale = max_wavelength**freq_exponents
+     radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(torch.float32)
+
+     radians = radians[..., None, :]
+
+     sin = torch.sin(radians)  # .to(dtype=dtype)
+     cos = torch.cos(radians)  # .to(dtype=dtype)
+
+     x1, x2 = x.split(d_half, dim=-1)
+     res = torch.empty_like(x)
+     res[..., :d_half] = x1 * cos - x2 * sin
+     res[..., d_half:] = x2 * cos + x1 * sin
+
+     return res.to(dtype)
+
+
+ class PaliGemmaWithExpertConfig(PretrainedConfig):
+     """Configuration class for PaliGemmaWithExpertModel."""
+
+     model_type = "PaliGemmaWithExpertModel"
+     sub_configs = {"paligemma_config": AutoConfig, "gemma_expert_config": AutoConfig}
+
+     def __init__(
+         self,
+         paligemma_config: dict | None = None,
+         gemma_expert_config: dict | None = None,
+         freeze_vision_encoder: bool = True,
+         train_expert_only: bool = True,
+         attention_implementation: str = "eager",
+         load_pretrained_paligemma: bool = False,
+         dropout: float = 0.1,
+         **kwargs,
+     ):
+         """Initializes the configuration.
+
+         Args:
+             paligemma_config: Configuration dictionary for the PaliGemma model.
+             gemma_expert_config: Configuration dictionary for the Gemma expert model.
+             freeze_vision_encoder: Whether to freeze the vision encoder. Defaults to True.
+             train_expert_only: Whether to train only the expert model. Defaults to True.
+             attention_implementation: Attention implementation to use ("eager" or "fa2"). Defaults to "eager".
+             load_pretrained_paligemma: Whether to load a pretrained PaliGemma model. Defaults to False.
+             dropout: Dropout probability. Defaults to 0.1.
+             **kwargs: Additional keyword arguments passed to PretrainedConfig.
+         """
+         self.freeze_vision_encoder = freeze_vision_encoder
+         self.train_expert_only = train_expert_only
+         self.attention_implementation = attention_implementation
+         self.load_pretrained_paligemma = load_pretrained_paligemma
+         self.dropout = dropout
+
+         if paligemma_config is None:
+             # Default config from Pi0
+             self.paligemma_config = CONFIG_MAPPING["paligemma"](
+                 transformers_version="4.48.1",
+                 _vocab_size=257152,
+                 bos_token_id=2,
+                 eos_token_id=1,
+                 hidden_size=2048,
+                 image_token_index=257152,
+                 model_type="paligemma",
+                 pad_token_id=0,
+                 projection_dim=2048,
+                 text_config={
+                     "hidden_activation": "gelu_pytorch_tanh",
+                     "hidden_size": 2048,
+                     "intermediate_size": 16384,
+                     "model_type": "gemma",
+                     "num_attention_heads": 8,
+                     "num_hidden_layers": 18,
+                     "num_image_tokens": 256,
+                     "num_key_value_heads": 1,
+                     "torch_dtype": "float32",
+                     "vocab_size": 257152,
+                 },
+                 vision_config={
+                     "hidden_size": 1152,
+                     "intermediate_size": 4304,
+                     "model_type": "siglip_vision_model",
+                     "num_attention_heads": 16,
+                     "num_hidden_layers": 27,
+                     "num_image_tokens": 256,
+                     "patch_size": 14,
+                     "projection_dim": 2048,
+                     "projector_hidden_act": "gelu_fast",
+                     "torch_dtype": "float32",
+                     "vision_use_head": False,
+                 },
+             )
+         elif isinstance(paligemma_config, dict):
+             # Override Pi0 default config for PaliGemma
+             if "model_type" not in paligemma_config:
+                 paligemma_config["model_type"] = "paligemma"
+
+             cfg_cls = CONFIG_MAPPING[paligemma_config["model_type"]]
+             self.paligemma_config = cfg_cls(**paligemma_config)
+
+         if gemma_expert_config is None:
+             # Default config from Pi0
+             self.gemma_expert_config = CONFIG_MAPPING["gemma"](
+                 attention_bias=False,
+                 attention_dropout=0.0,
+                 bos_token_id=2,
+                 eos_token_id=1,
+                 head_dim=256,
+                 hidden_act="gelu_pytorch_tanh",
+                 hidden_activation="gelu_pytorch_tanh",
+                 hidden_size=1024,
+                 initializer_range=0.02,
+                 intermediate_size=4096,
+                 max_position_embeddings=8192,
+                 model_type="gemma",
+                 num_attention_heads=8,
+                 num_hidden_layers=18,
+                 num_key_value_heads=1,
+                 pad_token_id=0,
+                 rms_norm_eps=1e-06,
+                 rope_theta=10000.0,
+                 torch_dtype="float32",
+                 transformers_version="4.48.1",
+                 use_cache=True,
+                 vocab_size=257152,
+             )
+         elif isinstance(gemma_expert_config, dict):
+             # Override Pi0 default config for Gemma Expert
+             if "model_type" not in gemma_expert_config:
+                 gemma_expert_config["model_type"] = "gemma"
+
+             cfg_cls = CONFIG_MAPPING[gemma_expert_config["model_type"]]
+             self.gemma_expert_config = cfg_cls(**gemma_expert_config)
+
+         super().__init__(**kwargs)
+
+     def __post_init__(self):
+         """Validates configuration parameters."""
+         super().__post_init__()
+         if self.train_expert_only and not self.freeze_vision_encoder:
+             raise ValueError(
+                 "You set `freeze_vision_encoder=False` and `train_expert_only=True` which are not compatible."
+             )
+
+         if self.attention_implementation not in ["eager", "fa2"]:
+             raise ValueError(
+                 f"Wrong value provided for `attention_implementation` ({self.attention_implementation}). Expected 'eager' or 'fa2'."
+             )
+
+
+ class PaliGemmaWithExpertModel(PreTrainedModel):
+     """PaliGemma model with an additional expert module for action generation."""
+
+     config_class = PaliGemmaWithExpertConfig
+
+     def __init__(self, config: PaliGemmaWithExpertConfig):
+         """Initializes the PaliGemmaWithExpertModel.
+
+         Args:
+             config: Configuration object for the model.
+         """
+         super().__init__(config=config)
+         self.config = config
+
+         if config.load_pretrained_paligemma:
+             self.paligemma = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma-3b-pt-224")
+         else:
+             self.paligemma = PaliGemmaForConditionalGeneration(config=config.paligemma_config)
+         self.gemma_expert = GemmaForCausalLM(config=config.gemma_expert_config)
+         # Remove unused embed_tokens
+         self.gemma_expert.model.embed_tokens = None
+
+         self.dropout = nn.Dropout(config.dropout)
+
+         self.to_bfloat16_like_physical_intelligence()
+         self.set_requires_grad()
+
+     def set_requires_grad(self) -> None:
+         """Sets the requires_grad attribute for model parameters based on configuration."""
+         if self.config.freeze_vision_encoder:
+             self.paligemma.vision_tower.eval()
+             for params in self.paligemma.vision_tower.parameters():
+                 params.requires_grad = False
+
+         if self.config.train_expert_only:
+             self.paligemma.eval()
+             for params in self.paligemma.parameters():
+                 params.requires_grad = False
+
+     def train(self, mode: bool = True) -> None:
+         """Sets the module in training mode.
+
+         Args:
+             mode: whether to set training mode (True) or evaluation mode (False). Defaults to True.
+         """
+         super().train(mode)
+
+         if self.config.freeze_vision_encoder:
+             self.paligemma.vision_tower.eval()
+
+         if self.config.train_expert_only:
+             self.paligemma.eval()
+
+     def to_bfloat16_like_physical_intelligence(self) -> None:
+         """Casts specific model components to bfloat16 dtype."""
+         self.paligemma = self.paligemma.to(dtype=torch.bfloat16)
+
+         params_to_change_dtype = [
+             "language_model.model.layers",
+             "gemma_expert.model.layers",
+             "vision_tower",
+             "multi_modal",
+         ]
+         for name, param in self.named_parameters():
+             if any(selector in name for selector in params_to_change_dtype):
+                 param.data = param.data.to(dtype=torch.bfloat16)
+
+     def embed_image(self, image: torch.Tensor) -> torch.Tensor:
+         """Computes image embeddings.
+
+         Args:
+             image: Input image tensor.
+
+         Returns:
+             torch.Tensor: Image embeddings.
+         """
+         # Handle different transformers versions
+         if hasattr(self.paligemma, "get_image_features"):
+             return self.paligemma.get_image_features(image)
+         else:
+             return self.paligemma.model.get_image_features(image)
+
+     def embed_language_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
+         """Embeds language tokens.
+
+         Args:
+             tokens: Input token indices.
+
+         Returns:
+             torch.Tensor: Token embeddings.
+         """
+         return self.paligemma.language_model.embed_tokens(tokens)
+
+     # TODO: break down this huge forward into modules or functions
+     def forward(
+         self,
+         attention_mask: torch.Tensor | None = None,
+         position_ids: torch.LongTensor | None = None,
+         past_key_values: list[torch.FloatTensor] | Cache | None = None,
+         inputs_embeds: list[torch.FloatTensor] = None,
+         use_cache: bool | None = None,
+         fill_kv_cache: bool | None = None,
+     ) -> tuple[list[torch.FloatTensor | None], list[torch.FloatTensor] | Cache | None]:
+         """Forward pass of the model.
+
+         Args:
+             attention_mask: Attention mask tensor.
+             position_ids: Position IDs tensor.
+             past_key_values: Past key values for caching.
+             inputs_embeds: List of input embeddings for the different model parts.
+             use_cache: Whether to use KV cache.
+             fill_kv_cache: Whether to fill the KV cache.
+
+         Returns:
+             tuple: A tuple containing:
+                 - outputs_embeds: List of output embeddings.
+                 - past_key_values: Updated past key values.
+         """
+         models = [self.paligemma.language_model, self.gemma_expert.model]
+
+         for hidden_states in inputs_embeds:
+             # TODO this is very inefficient
+             # dtype is always the same, batch size too (if > 1 len)
+             # device could be trickier in multi gpu edge cases but that's it
+             if hidden_states is None:
+                 continue
+             batch_size = hidden_states.shape[0]
+
+         # RMSNorm
+         num_layers = self.paligemma.config.text_config.num_hidden_layers
+         head_dim = self.paligemma.config.text_config.head_dim
+         for layer_idx in range(num_layers):
+             query_states = []
+             key_states = []
+             value_states = []
+             for i, hidden_states in enumerate(inputs_embeds):
+                 if hidden_states is None:
+                     continue
+                 layer = models[i].layers[layer_idx]
+                 # normalizer = torch.tensor(models[i].config.hidden_size**0.5, dtype=hidden_states.dtype)
+                 # hidden_states = hidden_states * normalizer
+                 hidden_states, _ = layer.input_layernorm(hidden_states)
+
+                 input_shape = hidden_states.shape[:-1]
+                 hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
+
+                 hidden_states = hidden_states.to(dtype=torch.bfloat16)
+                 query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape)
+                 key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape)
+                 value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape)
+
+                 query_states.append(query_state)
+                 key_states.append(key_state)
+                 value_states.append(value_state)
+
+             # B,L,H,D with L sequence length, H number of heads, D head dim
+             # concatenate on the number of embeddings/tokens
+             query_states = torch.cat(query_states, dim=1)
+             key_states = torch.cat(key_states, dim=1)
+             value_states = torch.cat(value_states, dim=1)
+
+             query_states = apply_rope(query_states, position_ids)
+             key_states = apply_rope(key_states, position_ids)
+
+             if use_cache and past_key_values is None:
+                 past_key_values = {}
+
+             if use_cache:
+                 if fill_kv_cache:
+                     past_key_values[layer_idx] = {
+                         "key_states": key_states,
+                         "value_states": value_states,
+                     }
+                 else:
+                     # TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before.
+                     # so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach
+                     # the max len, then we (for instance) double the cache size. This implementation already exists
+                     # in `transformers`. (molbap)
+                     key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1)
+                     value_states = torch.cat(
+                         [past_key_values[layer_idx]["value_states"], value_states], dim=1
+                     )
+
+             attention_interface = self.get_attention_interface()
+             att_output = attention_interface(
+                 attention_mask, batch_size, head_dim, query_states, key_states, value_states
+             )
+             att_output = att_output.to(dtype=torch.bfloat16)
+
+             # first part of att_output is prefix (up to sequence length, [:, 0:prefix_seq_len])
+             outputs_embeds = []
+             start = 0
+             for i, hidden_states in enumerate(inputs_embeds):
+                 layer = models[i].layers[layer_idx]
+
+                 if hidden_states is not None:
+                     end = start + hidden_states.shape[1]
+
+                     if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
+                         att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
+                     out_emb = layer.self_attn.o_proj(att_output[:, start:end])
+
+                     out_emb = self.dropout(out_emb)
+
+                     # first residual
+                     out_emb += hidden_states
+                     after_first_residual = out_emb.clone()
+
+                     out_emb, _ = layer.post_attention_layernorm(out_emb)
+                     out_emb = layer.mlp(out_emb)
+
+                     out_emb = self.dropout(out_emb)
+
+                     # second residual
+                     out_emb += after_first_residual
+
+                     outputs_embeds.append(out_emb)
+
+                     start = end
+                 else:
+                     outputs_embeds.append(None)
+
+             inputs_embeds = outputs_embeds
+
+         # final norm
+         outputs_embeds = []
+         for i, hidden_states in enumerate(inputs_embeds):
+             if hidden_states is not None:
+                 out_emb, _ = models[i].norm(hidden_states)
+                 outputs_embeds.append(out_emb)
+             else:
+                 outputs_embeds.append(None)
+
+         return outputs_embeds, past_key_values
+
+     def get_attention_interface(self):
+         """Returns the attention implementation function based on config.
+
+         Returns:
+             callable: The attention function to use.
+         """
+         return self.eager_attention_forward
+
+     def eager_attention_forward(
+         self,
+         attention_mask: torch.Tensor,
+         batch_size: int,
+         head_dim: int,
+         query_states: torch.Tensor,
+         key_states: torch.Tensor,
+         value_states: torch.Tensor,
+     ) -> torch.Tensor:
+         """Eager attention forward pass using standard matrix multiplications.
+
+         Args:
+             attention_mask: Attention mask tensor.
+             batch_size: Batch size.
+             head_dim: Head dimension.
+             query_states: Query states tensor.
+             key_states: Key states tensor.
+             value_states: Value states tensor.
+
+         Returns:
+             torch.Tensor: Attention output.
+         """
+         num_att_heads = self.config.paligemma_config.text_config.num_attention_heads
+         num_key_value_heads = self.config.paligemma_config.text_config.num_key_value_heads
+         num_key_value_groups = num_att_heads // num_key_value_heads
+
+         # query_states: batch_size, sequence_length, num_att_head, head_dim
+         # key_states: batch_size, sequence_length, num_key_value_head, head_dim
+         # value_states: batch_size, sequence_length, num_key_value_head, head_dim
+         sequence_length = key_states.shape[1]
+
+         key_states = key_states[:, :, :, None, :].expand(
+             batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
+         )
+         key_states = key_states.reshape(
+             batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
+         )
+
+         value_states = value_states[:, :, :, None, :].expand(
+             batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
+         )
+         value_states = value_states.reshape(
+             batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
+         )
+
+         # Attention here is upcasted to float32 to match the original eager implementation.
+
+         query_states = query_states.to(dtype=torch.float32)
+         key_states = key_states.to(dtype=torch.float32)
+
+         query_states = query_states.transpose(1, 2)
+         key_states = key_states.transpose(1, 2)
+
+         att_weights = torch.matmul(query_states, key_states.transpose(2, 3))
+         att_weights *= head_dim**-0.5
+         big_neg = -2.3819763e38  # See gemma/modules.py
+
+         masked_att_weights = torch.where(attention_mask[:, None, :, :], att_weights, big_neg)
+
+         probs = nn.functional.softmax(masked_att_weights, dim=-1)
+         probs = probs.to(dtype=value_states.dtype)
+
+         # probs: batch_size, num_key_value_head, num_att_head, sequence_length, sequence_length
+         # value_states: batch_size, sequence_length, num_att_heads, head_dim
+
+         att_output = torch.matmul(probs, value_states.permute(0, 2, 1, 3))
+
+         att_output = att_output.permute(0, 2, 1, 3)
+         # we use -1 because sequence length can change
+         att_output = att_output.reshape(batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim)
+
+         return att_output
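
The `apply_rope` helper above implements the standard rotary position embedding: it splits the last dimension in half, rotates each pair by an angle determined by the token position and a per-dimension timescale, and casts the result back to the input dtype. A minimal usage sketch (illustrative only, not part of the package; it assumes `torch` and an installed `opentau` wheel) that checks the shape and dtype contract stated in the docstring:

    import torch
    from opentau.policies.pi0.paligemma_with_expert import apply_rope

    # Shapes follow the docstring: x is [B, L, H, D], positions is [B, L].
    B, L, H, D = 2, 16, 8, 256
    x = torch.randn(B, L, H, D)
    positions = torch.arange(L).unsqueeze(0).expand(B, L)

    out = apply_rope(x, positions)
    assert out.shape == (B, L, H, D)
    assert out.dtype == x.dtype  # result is cast back to the input dtype
    # Position 0 corresponds to a zero rotation (cos=1, sin=0), so it is unchanged.
    torch.testing.assert_close(out[:, 0], x[:, 0])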
opentau/policies/pi05/__init__.py
@@ -0,0 +1,20 @@
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """
+ PI05 Policy Module.
+
+ This module implements the π05 (Pi05) Vision-Language-Action Flow Model policy,
+ designed for general robot control. It includes the policy definition,
+ configuration, and model architecture.
+ """