opentau-0.1.0-py3-none-any.whl

Files changed (108)
  1. opentau/__init__.py +179 -0
  2. opentau/__version__.py +24 -0
  3. opentau/configs/__init__.py +19 -0
  4. opentau/configs/default.py +297 -0
  5. opentau/configs/libero.py +113 -0
  6. opentau/configs/parser.py +393 -0
  7. opentau/configs/policies.py +297 -0
  8. opentau/configs/reward.py +42 -0
  9. opentau/configs/train.py +370 -0
  10. opentau/configs/types.py +76 -0
  11. opentau/constants.py +52 -0
  12. opentau/datasets/__init__.py +84 -0
  13. opentau/datasets/backward_compatibility.py +78 -0
  14. opentau/datasets/compute_stats.py +333 -0
  15. opentau/datasets/dataset_mixture.py +460 -0
  16. opentau/datasets/factory.py +232 -0
  17. opentau/datasets/grounding/__init__.py +67 -0
  18. opentau/datasets/grounding/base.py +154 -0
  19. opentau/datasets/grounding/clevr.py +110 -0
  20. opentau/datasets/grounding/cocoqa.py +130 -0
  21. opentau/datasets/grounding/dummy.py +101 -0
  22. opentau/datasets/grounding/pixmo.py +177 -0
  23. opentau/datasets/grounding/vsr.py +141 -0
  24. opentau/datasets/image_writer.py +304 -0
  25. opentau/datasets/lerobot_dataset.py +1910 -0
  26. opentau/datasets/online_buffer.py +442 -0
  27. opentau/datasets/push_dataset_to_hub/utils.py +132 -0
  28. opentau/datasets/sampler.py +99 -0
  29. opentau/datasets/standard_data_format_mapping.py +278 -0
  30. opentau/datasets/transforms.py +330 -0
  31. opentau/datasets/utils.py +1243 -0
  32. opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +887 -0
  33. opentau/datasets/v2/convert_dataset_v1_to_v2.py +829 -0
  34. opentau/datasets/v21/_remove_language_instruction.py +109 -0
  35. opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +60 -0
  36. opentau/datasets/v21/convert_dataset_v20_to_v21.py +183 -0
  37. opentau/datasets/v21/convert_stats.py +150 -0
  38. opentau/datasets/video_utils.py +597 -0
  39. opentau/envs/__init__.py +18 -0
  40. opentau/envs/configs.py +178 -0
  41. opentau/envs/factory.py +99 -0
  42. opentau/envs/libero.py +439 -0
  43. opentau/envs/utils.py +204 -0
  44. opentau/optim/__init__.py +16 -0
  45. opentau/optim/factory.py +43 -0
  46. opentau/optim/optimizers.py +121 -0
  47. opentau/optim/schedulers.py +140 -0
  48. opentau/planner/__init__.py +82 -0
  49. opentau/planner/high_level_planner.py +366 -0
  50. opentau/planner/utils/memory.py +64 -0
  51. opentau/planner/utils/utils.py +65 -0
  52. opentau/policies/__init__.py +24 -0
  53. opentau/policies/factory.py +172 -0
  54. opentau/policies/normalize.py +315 -0
  55. opentau/policies/pi0/__init__.py +19 -0
  56. opentau/policies/pi0/configuration_pi0.py +250 -0
  57. opentau/policies/pi0/modeling_pi0.py +994 -0
  58. opentau/policies/pi0/paligemma_with_expert.py +516 -0
  59. opentau/policies/pi05/__init__.py +20 -0
  60. opentau/policies/pi05/configuration_pi05.py +231 -0
  61. opentau/policies/pi05/modeling_pi05.py +1257 -0
  62. opentau/policies/pi05/paligemma_with_expert.py +572 -0
  63. opentau/policies/pretrained.py +315 -0
  64. opentau/policies/utils.py +123 -0
  65. opentau/policies/value/__init__.py +18 -0
  66. opentau/policies/value/configuration_value.py +170 -0
  67. opentau/policies/value/modeling_value.py +512 -0
  68. opentau/policies/value/reward.py +87 -0
  69. opentau/policies/value/siglip_gemma.py +221 -0
  70. opentau/scripts/actions_mse_loss.py +89 -0
  71. opentau/scripts/bin_to_safetensors.py +116 -0
  72. opentau/scripts/compute_max_token_length.py +111 -0
  73. opentau/scripts/display_sys_info.py +90 -0
  74. opentau/scripts/download_libero_benchmarks.py +54 -0
  75. opentau/scripts/eval.py +877 -0
  76. opentau/scripts/export_to_onnx.py +180 -0
  77. opentau/scripts/fake_tensor_training.py +87 -0
  78. opentau/scripts/get_advantage_and_percentiles.py +220 -0
  79. opentau/scripts/high_level_planner_inference.py +114 -0
  80. opentau/scripts/inference.py +70 -0
  81. opentau/scripts/launch_train.py +63 -0
  82. opentau/scripts/libero_simulation_parallel.py +356 -0
  83. opentau/scripts/libero_simulation_sequential.py +122 -0
  84. opentau/scripts/nav_high_level_planner_inference.py +61 -0
  85. opentau/scripts/train.py +379 -0
  86. opentau/scripts/visualize_dataset.py +294 -0
  87. opentau/scripts/visualize_dataset_html.py +507 -0
  88. opentau/scripts/zero_to_fp32.py +760 -0
  89. opentau/utils/__init__.py +20 -0
  90. opentau/utils/accelerate_utils.py +79 -0
  91. opentau/utils/benchmark.py +98 -0
  92. opentau/utils/fake_tensor.py +81 -0
  93. opentau/utils/hub.py +209 -0
  94. opentau/utils/import_utils.py +79 -0
  95. opentau/utils/io_utils.py +137 -0
  96. opentau/utils/libero.py +214 -0
  97. opentau/utils/libero_dataset_recorder.py +460 -0
  98. opentau/utils/logging_utils.py +180 -0
  99. opentau/utils/monkey_patch.py +278 -0
  100. opentau/utils/random_utils.py +244 -0
  101. opentau/utils/train_utils.py +198 -0
  102. opentau/utils/utils.py +471 -0
  103. opentau-0.1.0.dist-info/METADATA +161 -0
  104. opentau-0.1.0.dist-info/RECORD +108 -0
  105. opentau-0.1.0.dist-info/WHEEL +5 -0
  106. opentau-0.1.0.dist-info/entry_points.txt +2 -0
  107. opentau-0.1.0.dist-info/licenses/LICENSE +508 -0
  108. opentau-0.1.0.dist-info/top_level.txt +1 -0
opentau/policies/pi05/paligemma_with_expert.py
@@ -0,0 +1,572 @@
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """
+ PaliGemma with Expert Module.
+
+ This module implements the PaliGemma model with an additional expert module,
+ specifically designed for the Pi05 policy. It combines a pre-trained PaliGemma
+ Vision-Language Model (VLM) with a Gemma-based expert model to handle
+ action generation and conditioning.
+ """
+
+ import torch
+ import torch.version
+ from torch import nn
+ from transformers import (
+     AutoConfig,
+     GemmaForCausalLM,
+     PaliGemmaForConditionalGeneration,
+     PretrainedConfig,
+     PreTrainedModel,
+ )
+ from transformers.cache_utils import Cache
+ from transformers.models.auto import CONFIG_MAPPING
+ from transformers.models.gemma import modeling_gemma
+
+
+ def apply_rope(x: torch.Tensor, positions: torch.Tensor, max_wavelength: int = 10_000) -> torch.Tensor:
+     """Applies RoPE positions to the input tensor.
+
+     Args:
+         x: Input tensor of shape [B, L, H, D].
+         positions: Position tensor of shape [B, L].
+         max_wavelength: Maximum wavelength for RoPE. Defaults to 10_000.
+
+     Returns:
+         Tensor: The input tensor with RoPE applied, of shape [B, L, H, D].
+     """
+     d_half = x.shape[-1] // 2
+     device = x.device
+     dtype = x.dtype
+     x = x.to(torch.float32)
+
+     freq_exponents = (2.0 / x.shape[-1]) * torch.arange(d_half, dtype=torch.float32, device=device)
+     timescale = max_wavelength**freq_exponents
+     radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(torch.float32)
+
+     radians = radians[..., None, :]
+
+     sin = torch.sin(radians)  # .to(dtype=dtype)
+     cos = torch.cos(radians)  # .to(dtype=dtype)
+
+     x1, x2 = x.split(d_half, dim=-1)
+     res = torch.empty_like(x)
+     res[..., :d_half] = x1 * cos - x2 * sin
+     res[..., d_half:] = x2 * cos + x1 * sin
+
+     return res.to(dtype)
+
+
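A minimal usage sketch of apply_rope; the shapes and position layout below are illustrative assumptions, not values taken from the package:

    import torch

    # Hypothetical shapes: batch of 2, 5 tokens, 8 heads, head_dim 256.
    B, L, H, D = 2, 5, 8, 256
    x = torch.randn(B, L, H, D, dtype=torch.bfloat16)
    positions = torch.arange(L).unsqueeze(0).expand(B, L)  # [B, L] token positions

    rotated = apply_rope(x, positions)
    assert rotated.shape == (B, L, H, D)
    assert rotated.dtype == x.dtype  # the result is cast back to the input dtype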
+ class PaliGemmaWithExpertConfig(PretrainedConfig):
+     """Configuration class for PaliGemmaWithExpertModel."""
+
+     model_type = "PaliGemmaWithExpertModel"
+     sub_configs = {"paligemma_config": AutoConfig, "gemma_expert_config": AutoConfig}
+
+     def __init__(
+         self,
+         paligemma_config: dict | None = None,
+         gemma_expert_config: dict | None = None,
+         freeze_vision_encoder: bool = True,
+         train_expert_only: bool = True,
+         attention_implementation: str = "eager",
+         load_pretrained_paligemma: bool = False,
+         discrete_action_vocab_size: int | None = None,
+         dropout: float = 0.1,
+         **kwargs,
+     ):
+         """Initializes the configuration.
+
+         Args:
+             paligemma_config: Configuration dictionary for the PaliGemma model.
+             gemma_expert_config: Configuration dictionary for the Gemma expert model.
+             freeze_vision_encoder: Whether to freeze the vision encoder. Defaults to True.
+             train_expert_only: Whether to train only the expert model. Defaults to True.
+             attention_implementation: Attention implementation to use ("eager" or "fa2"). Defaults to "eager".
+             load_pretrained_paligemma: Whether to load a pretrained PaliGemma model. Defaults to False.
+             discrete_action_vocab_size: Vocabulary size for discrete actions.
+             dropout: Dropout probability. Defaults to 0.1.
+             **kwargs: Additional keyword arguments passed to PretrainedConfig.
+         """
+         self.freeze_vision_encoder = freeze_vision_encoder
+         self.train_expert_only = train_expert_only
+         self.attention_implementation = attention_implementation
+         self.load_pretrained_paligemma = load_pretrained_paligemma
+         self.discrete_action_vocab_size = discrete_action_vocab_size
+         self.dropout = dropout
+
+         if paligemma_config is None:
+             # Default config from Pi0
+             self.paligemma_config = CONFIG_MAPPING["paligemma"](
+                 transformers_version="4.48.1",
+                 _vocab_size=257152,
+                 bos_token_id=2,
+                 eos_token_id=1,
+                 hidden_size=2048,
+                 image_token_index=257152,
+                 model_type="paligemma",
+                 pad_token_id=0,
+                 projection_dim=2048,
+                 text_config={
+                     "hidden_activation": "gelu_pytorch_tanh",
+                     "hidden_size": 2048,
+                     "intermediate_size": 16384,
+                     "model_type": "gemma",
+                     "num_attention_heads": 8,
+                     "num_hidden_layers": 18,
+                     "num_image_tokens": 256,
+                     "num_key_value_heads": 1,
+                     "torch_dtype": "float32",
+                     "vocab_size": 257152,
+                     "use_adarms": False,
+                     "adarms_cond_dim": None,
+                 },
+                 vision_config={
+                     "hidden_size": 1152,
+                     "intermediate_size": 4304,
+                     "model_type": "siglip_vision_model",
+                     "num_attention_heads": 16,
+                     "num_hidden_layers": 27,
+                     "num_image_tokens": 256,
+                     "patch_size": 14,
+                     "projection_dim": 2048,
+                     "projector_hidden_act": "gelu_fast",
+                     "torch_dtype": "float32",
+                     "vision_use_head": False,
+                 },
+             )
+         elif isinstance(paligemma_config, dict):
+             # Override Pi0 default config for PaliGemma
+             if "model_type" not in paligemma_config:
+                 paligemma_config["model_type"] = "paligemma"
+
+             cfg_cls = CONFIG_MAPPING[paligemma_config["model_type"]]
+             self.paligemma_config = cfg_cls(**paligemma_config)
+
+         if gemma_expert_config is None:
+             # Default config from Pi0
+             self.gemma_expert_config = CONFIG_MAPPING["gemma"](
+                 attention_bias=False,
+                 attention_dropout=0.0,
+                 bos_token_id=2,
+                 eos_token_id=1,
+                 head_dim=256,
+                 hidden_act="gelu_pytorch_tanh",
+                 hidden_activation="gelu_pytorch_tanh",
+                 hidden_size=1024,
+                 initializer_range=0.02,
+                 intermediate_size=4096,
+                 max_position_embeddings=8192,
+                 model_type="gemma",
+                 num_attention_heads=8,
+                 num_hidden_layers=18,
+                 num_key_value_heads=1,
+                 pad_token_id=0,
+                 rms_norm_eps=1e-06,
+                 rope_theta=10000.0,
+                 torch_dtype="float32",
+                 use_adarms=True,
+                 adarms_cond_dim=1024,
+                 transformers_version="4.48.1",
+                 use_cache=True,
+                 vocab_size=257152,
+             )
+         elif isinstance(gemma_expert_config, dict):
+             # Override Pi0 default config for Gemma Expert
+             if "model_type" not in gemma_expert_config:
+                 gemma_expert_config["model_type"] = "gemma"
+
+             cfg_cls = CONFIG_MAPPING[gemma_expert_config["model_type"]]
+             self.gemma_expert_config = cfg_cls(**gemma_expert_config)
+
+         super().__init__(**kwargs)
+
+     def __post_init__(self):
+         """Validates configuration parameters."""
+         super().__post_init__()
+         if self.train_expert_only and not self.freeze_vision_encoder:
+             raise ValueError(
+                 "You set `freeze_vision_encoder=False` and `train_expert_only=True` which are not compatible."
+             )
+
+         if self.attention_implementation not in ["eager", "fa2"]:
+             raise ValueError(
+                 f"Wrong value provided for `attention_implementation` ({self.attention_implementation}). Expected 'eager' or 'fa2'."
+             )
+
+
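A hedged construction sketch for this configuration: the keyword names come from __init__ above, while the specific values (in particular the discrete action vocabulary size) are illustrative assumptions:

    # Keep the default Pi0-style PaliGemma/Gemma sub-configs, customize the rest.
    config = PaliGemmaWithExpertConfig(
        freeze_vision_encoder=True,
        train_expert_only=True,
        attention_implementation="eager",  # "eager" or "fa2" per __post_init__
        discrete_action_vocab_size=1024,   # hypothetical action-token vocabulary
        dropout=0.1,
    )
    print(config.gemma_expert_config.hidden_size)  # 1024 with the defaults above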
+ class PaliGemmaWithExpertModel(PreTrainedModel):
+     """PaliGemma model with an additional expert module for action generation."""
+
+     config_class = PaliGemmaWithExpertConfig
+
+     def __init__(self, config: PaliGemmaWithExpertConfig):
+         """Initializes the PaliGemmaWithExpertModel.
+
+         Args:
+             config: Configuration object for the model.
+         """
+         super().__init__(config=config)
+         self.config = config
+
+         if config.load_pretrained_paligemma:
+             self.paligemma = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma-3b-pt-224")
+         else:
+             self.paligemma = PaliGemmaForConditionalGeneration(config=config.paligemma_config)
+         self.gemma_expert = GemmaForCausalLM(config=config.gemma_expert_config)
+         # Remove unused embed_tokens
+         self.gemma_expert.model.embed_tokens = None
+
+         # Learned embedding layer for discrete actions.
+         # Embedding dimension matches the PaliGemma language-model hidden size.
+         self.discrete_action_embedding = nn.Embedding(
+             num_embeddings=config.discrete_action_vocab_size,
+             embedding_dim=config.paligemma_config.text_config.hidden_size,
+             padding_idx=0,  # 0 is used for padding in pad_fast_tokens
+         )
+
+         # Discrete action head that maps to the action vocab size rather than the language vocab size.
+         self.da_head = nn.Linear(
+             in_features=config.paligemma_config.text_config.hidden_size,
+             out_features=config.discrete_action_vocab_size,
+         )
+
+         self.dropout = nn.Dropout(config.dropout)
+
+         self.to_bfloat16_like_physical_intelligence()
+         self.set_requires_grad()
+
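Continuing the configuration sketch above, instantiating the model with load_pretrained_paligemma left at its default of False builds randomly initialized weights (no checkpoint download); discrete_action_vocab_size must be set, since it sizes both the action embedding and da_head:

    model = PaliGemmaWithExpertModel(config)

    # The action-specific modules are sized from the config.
    assert model.discrete_action_embedding.num_embeddings == config.discrete_action_vocab_size
    assert model.da_head.out_features == config.discrete_action_vocab_size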
+     def set_requires_grad(self) -> None:
+         """Sets the requires_grad attribute for model parameters based on configuration."""
+         if self.config.freeze_vision_encoder:
+             self.paligemma.vision_tower.eval()
+             for params in self.paligemma.vision_tower.parameters():
+                 params.requires_grad = False
+
+         if self.config.train_expert_only:
+             self.paligemma.eval()
+             for params in self.paligemma.parameters():
+                 params.requires_grad = False
+
+     def train(self, mode: bool = True) -> None:
+         """Sets the module in training mode.
+
+         Args:
+             mode: Whether to set training mode (True) or evaluation mode (False). Defaults to True.
+         """
+         super().train(mode)
+
+         if self.config.freeze_vision_encoder:
+             self.paligemma.vision_tower.eval()
+
+         if self.config.train_expert_only:
+             self.paligemma.eval()
+
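Continuing the instantiation sketch, the freezing behaviour can be checked directly; with the defaults (freeze_vision_encoder=True, train_expert_only=True) the whole PaliGemma tower is frozen and kept in eval mode, while the Gemma expert stays trainable:

    model.train()

    assert not model.paligemma.vision_tower.training  # kept in eval mode by the overridden train()
    assert not any(p.requires_grad for p in model.paligemma.parameters())
    assert any(p.requires_grad for p in model.gemma_expert.parameters())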
+     def to_bfloat16_like_physical_intelligence(self) -> None:
+         """Casts specific model components to bfloat16 dtype."""
+         self.paligemma = self.paligemma.to(dtype=torch.bfloat16)
+
+         params_to_change_dtype = [
+             "language_model.model.layers",
+             "gemma_expert.model.layers",
+             "vision_tower",
+             "multi_modal",
+         ]
+         for name, param in self.named_parameters():
+             if any(selector in name for selector in params_to_change_dtype):
+                 param.data = param.data.to(dtype=torch.bfloat16)
+
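To see where the selective bfloat16 cast lands, a small hedged inspection loop (illustrative only) can tally parameter dtypes per top-level sub-module; the expectation is bfloat16 for the PaliGemma tower and the expert's transformer layers, with the remaining modules (e.g. discrete_action_embedding, da_head) left in float32:

    from collections import Counter

    counts = Counter()
    for name, param in model.named_parameters():
        counts[(name.split(".")[0], str(param.dtype))] += param.numel()
    for (module, dtype), numel in sorted(counts.items()):
        print(f"{module:>28s}  {dtype:>16s}  {numel:,}")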
+     def embed_image(self, image: torch.Tensor) -> torch.Tensor:
+         """Computes image embeddings.
+
+         Args:
+             image: Input image tensor.
+
+         Returns:
+             torch.Tensor: Image embeddings.
+         """
+         # Handle different transformers versions
+         if hasattr(self.paligemma, "get_image_features"):
+             return self.paligemma.get_image_features(image)
+         else:
+             return self.paligemma.model.get_image_features(image)
+
+     def embed_language_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
+         """Embeds language tokens.
+
+         Args:
+             tokens: Input token indices.
+
+         Returns:
+             torch.Tensor: Token embeddings.
+         """
+         return self.paligemma.language_model.embed_tokens(tokens)
+
+     def embed_discrete_actions(self, actions: torch.Tensor) -> torch.Tensor:
+         """Embeds discrete action tokens.
+
+         Args:
+             actions: Input discrete action indices.
+
+         Returns:
+             torch.Tensor: Action embeddings.
+         """
+         # Ensure actions are long integers for embedding lookup
+         if actions.dtype != torch.long:
+             actions = actions.long()
+
+         # Apply embedding layer
+         embedded = self.discrete_action_embedding(actions)
+
+         return embedded
+
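These three helpers are what the Pi05 modeling code would use to assemble the VLM prefix. A hedged assembly sketch follows; the image resolution, token ids, and concatenation order are illustrative assumptions, and the final cast reflects that the PaliGemma tower runs in bfloat16 while the action embedding stays in float32:

    B = 1
    images = torch.randn(B, 3, 224, 224, dtype=torch.bfloat16)  # assumed 224x224 SigLIP input
    lang_tokens = torch.randint(0, 257152, (B, 20))             # tokenized instruction (assumed length)
    action_tokens = torch.randint(1, 1024, (B, 8))              # hypothetical discrete action ids

    img_emb = model.embed_image(images)                    # [B, 256, hidden]
    txt_emb = model.embed_language_tokens(lang_tokens)     # [B, 20, hidden]
    act_emb = model.embed_discrete_actions(action_tokens)  # [B, 8, hidden], float32

    prefix_embeds = torch.cat([img_emb, txt_emb, act_emb.to(img_emb.dtype)], dim=1)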
+     # TODO: break down this huge forward into modules or functions
+     def forward(
+         self,
+         attention_mask: torch.Tensor | None = None,
+         position_ids: torch.LongTensor | None = None,
+         past_key_values: list[torch.FloatTensor] | Cache | None = None,
+         inputs_embeds: list[torch.FloatTensor] | None = None,
+         n_cross_att_tokens: int | None = None,
+         use_cache: bool | None = None,
+         fill_kv_cache: bool | None = None,
+         adarms_cond: list[torch.Tensor] | None = None,
+     ) -> tuple[list[torch.FloatTensor | None], list[torch.FloatTensor] | Cache | None]:
+         """Forward pass of the model.
+
+         Args:
+             attention_mask: Attention mask tensor.
+             position_ids: Position IDs tensor.
+             past_key_values: Past key values for caching.
+             inputs_embeds: List of input embeddings for the different model parts.
+             n_cross_att_tokens: Number of cross-attention tokens.
+             use_cache: Whether to use KV cache.
+             fill_kv_cache: Whether to fill the KV cache.
+             adarms_cond: List of AdaRMS conditioning tensors.
+
+         Returns:
+             tuple: A tuple containing:
+                 - outputs_embeds: List of output embeddings.
+                 - past_key_values: Updated past key values.
+
+         Raises:
+             ValueError: If `n_cross_att_tokens` is not provided when `fill_kv_cache` is True.
+         """
+         if adarms_cond is None:
+             adarms_cond = [None, None]
+
+         models = [self.paligemma.language_model, self.gemma_expert.model]
+
+         for hidden_states in inputs_embeds:
+             # TODO this is very inefficient
+             # dtype is always the same, batch size too (if > 1 len)
+             # device could be trickier in multi gpu edge cases but that's it
+             if hidden_states is None:
+                 continue
+             batch_size = hidden_states.shape[0]
+
+         # RMSNorm
+         num_layers = self.paligemma.config.text_config.num_hidden_layers
+         head_dim = self.paligemma.config.text_config.head_dim
+         for layer_idx in range(num_layers):
+             query_states = []
+             key_states = []
+             value_states = []
+             gates = []
+             for i, hidden_states in enumerate(inputs_embeds):
+                 if hidden_states is None:
+                     gates.append(None)
+                     continue
+                 layer = models[i].layers[layer_idx]
+                 # normalizer = torch.tensor(models[i].config.hidden_size**0.5, dtype=hidden_states.dtype)
+                 # hidden_states = hidden_states * normalizer
+                 hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i])
+                 gates.append(gate)
+                 input_shape = hidden_states.shape[:-1]
+                 hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
+
+                 hidden_states = hidden_states.to(dtype=torch.bfloat16)
+                 query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape)
+                 key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape)
+                 value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape)
+
+                 query_states.append(query_state)
+                 key_states.append(key_state)
+                 value_states.append(value_state)
+
+             # B,L,H,D with L sequence length, H number of heads, D head dim
+             # concatenate on the number of embeddings/tokens
+             query_states = torch.cat(query_states, dim=1)
+             key_states = torch.cat(key_states, dim=1)
+             value_states = torch.cat(value_states, dim=1)
+
+             query_states = apply_rope(query_states, position_ids)
+             key_states = apply_rope(key_states, position_ids)
+
+             if use_cache and past_key_values is None:
+                 past_key_values = {}
+
+             if use_cache:
+                 if fill_kv_cache:
+                     if n_cross_att_tokens is None:
+                         raise ValueError("n_cross_att_tokens must be provided when fill_kv_cache is True")
+                     past_key_values[layer_idx] = {
+                         # save the first n_cross_att_tokens for action expert cross attention
+                         "key_states": key_states[:, :n_cross_att_tokens, :, :],
+                         "value_states": value_states[:, :n_cross_att_tokens, :, :],
+                     }
+                 else:
+                     # TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before.
+                     # so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach
+                     # the max len, then we (for instance) double the cache size. This implementation already exists
+                     # in `transformers`. (molbap)
+                     key_states = torch.cat([key_states, past_key_values[layer_idx]["key_states"]], dim=1)
+                     value_states = torch.cat(
+                         [value_states, past_key_values[layer_idx]["value_states"]], dim=1
+                     )
+
+             attention_interface = self.get_attention_interface()
+             att_output = attention_interface(
+                 attention_mask, batch_size, head_dim, query_states, key_states, value_states
+             )
+             att_output = att_output.to(dtype=torch.bfloat16)
+
+             # first part of att_output is prefix (up to sequence length, [:, 0:prefix_seq_len])
+             outputs_embeds = []
+             start = 0
+             for i, hidden_states in enumerate(inputs_embeds):
+                 layer = models[i].layers[layer_idx]
+
+                 if hidden_states is not None:
+                     end = start + hidden_states.shape[1]
+
+                     if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
+                         att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
+                     out_emb = layer.self_attn.o_proj(att_output[:, start:end])
+
+                     out_emb = self.dropout(out_emb)
+
+                     # first residual
+                     out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i])  # noqa: SLF001
+                     after_first_residual = out_emb.clone()
+
+                     out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i])
+                     out_emb = layer.mlp(out_emb)
+
+                     out_emb = self.dropout(out_emb)
+
+                     # second residual
+                     out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate)  # noqa: SLF001
+
+                     outputs_embeds.append(out_emb)
+
+                     start = end
+                 else:
+                     outputs_embeds.append(None)
+
+             inputs_embeds = outputs_embeds
+
+         # final norm
+         outputs_embeds = []
+         for i, hidden_states in enumerate(inputs_embeds):
+             if hidden_states is not None:
+                 out_emb, _ = models[i].norm(hidden_states, cond=adarms_cond[i])
+                 outputs_embeds.append(out_emb)
+             else:
+                 outputs_embeds.append(None)
+
+         return outputs_embeds, past_key_values
+
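The forward above treats inputs_embeds as a list of per-stream embeddings (PaliGemma prefix, Gemma expert) whose projected q/k/v are concatenated along the token axis before a single attention call, so attention_mask must be a boolean [B, L_q, L_kv] tensor and position_ids a [B, L] tensor over the concatenated sequence. The sketch below builds one plausible block layout; the actual prefix/suffix masking rules live in modeling_pi05, so this structure is an assumption for illustration:

    B, n_prefix, n_expert = 1, 284, 51  # illustrative token counts
    L = n_prefix + n_expert

    # Consumed as attention_mask[:, None, :, :] inside eager_attention_forward.
    mask = torch.zeros(B, L, L, dtype=torch.bool)
    mask[:, :n_prefix, :n_prefix] = True  # prefix tokens attend within the prefix
    mask[:, n_prefix:, :] = True          # expert tokens attend to prefix + expert (assumed)

    position_ids = torch.arange(L).unsqueeze(0).expand(B, L)

    # outputs_embeds, past = model(
    #     attention_mask=mask,
    #     position_ids=position_ids,
    #     inputs_embeds=[prefix_embeds, expert_embeds],  # hidden sizes 2048 and 1024
    #     use_cache=False,
    #     adarms_cond=[None, expert_cond],  # expert_cond: [B, 1024] AdaRMS conditioning (assumed)
    # )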
+     def get_attention_interface(self):
+         """Returns the attention implementation function (currently always the eager implementation).
+
+         Returns:
+             callable: The attention function to use.
+         """
+         return self.eager_attention_forward
+
+     def eager_attention_forward(
+         self,
+         attention_mask: torch.Tensor,
+         batch_size: int,
+         head_dim: int,
+         query_states: torch.Tensor,
+         key_states: torch.Tensor,
+         value_states: torch.Tensor,
+     ) -> torch.Tensor:
+         """Eager attention forward pass using standard matrix multiplications.
+
+         Args:
+             attention_mask: Attention mask tensor.
+             batch_size: Batch size.
+             head_dim: Head dimension.
+             query_states: Query states tensor.
+             key_states: Key states tensor.
+             value_states: Value states tensor.
+
+         Returns:
+             torch.Tensor: Attention output.
+         """
+         num_att_heads = self.config.paligemma_config.text_config.num_attention_heads
+         num_key_value_heads = self.config.paligemma_config.text_config.num_key_value_heads
+         num_key_value_groups = num_att_heads // num_key_value_heads
+
+         # query_states: batch_size, sequence_length, num_att_head, head_dim
+         # key_states: batch_size, sequence_length, num_key_value_head, head_dim
+         # value_states: batch_size, sequence_length, num_key_value_head, head_dim
+         sequence_length = key_states.shape[1]
+
+         key_states = key_states[:, :, :, None, :].expand(
+             batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
+         )
+         key_states = key_states.reshape(
+             batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
+         )
+
+         value_states = value_states[:, :, :, None, :].expand(
+             batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
+         )
+         value_states = value_states.reshape(
+             batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
+         )
+
+         # Attention here is upcast to float32 to match the original eager implementation.
+
+         query_states = query_states.to(dtype=torch.float32)
+         key_states = key_states.to(dtype=torch.float32)
+
+         query_states = query_states.transpose(1, 2)
+         key_states = key_states.transpose(1, 2)
+
+         att_weights = torch.matmul(query_states, key_states.transpose(2, 3))
+         att_weights *= head_dim**-0.5
+         big_neg = -2.3819763e38  # See gemma/modules.py
+
+         masked_att_weights = torch.where(attention_mask[:, None, :, :], att_weights, big_neg)
+
+         probs = nn.functional.softmax(masked_att_weights, dim=-1)
+         probs = probs.to(dtype=value_states.dtype)
+
+         # probs: batch_size, num_att_heads, q_sequence_length, kv_sequence_length
+         # value_states: batch_size, sequence_length, num_att_heads, head_dim
+
+         att_output = torch.matmul(probs, value_states.permute(0, 2, 1, 3))
+
+         att_output = att_output.permute(0, 2, 1, 3)
+         # we use -1 because sequence length can change
+         att_output = att_output.reshape(batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim)
+
+         return att_output
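Since eager_attention_forward implements grouped-query attention by expanding the single KV head, one hedged sanity check is to compare it against torch.nn.functional.scaled_dot_product_attention on random float32 inputs; the shapes below match the default text_config, the causal mask is illustrative, and the comparison assumes every query row can attend to at least one key:

    import torch
    import torch.nn.functional as F

    B, L, n_heads, n_kv, d = 1, 16, 8, 1, 256
    q = torch.randn(B, L, n_heads, d)
    k = torch.randn(B, L, n_kv, d)
    v = torch.randn(B, L, n_kv, d)
    mask = torch.tril(torch.ones(L, L, dtype=torch.bool)).expand(B, L, L)

    out = model.eager_attention_forward(mask, B, d, q, k, v)  # [B, L, n_heads * d]

    # Reference: expand the single KV head to all query heads and run SDPA.
    ref = F.scaled_dot_product_attention(
        q.transpose(1, 2),
        k.expand(B, L, n_heads, d).transpose(1, 2),
        v.expand(B, L, n_heads, d).transpose(1, 2),
        attn_mask=mask[:, None, :, :],
    ).transpose(1, 2).reshape(B, L, n_heads * d)

    assert torch.allclose(out, ref, atol=1e-4)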