opentau 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opentau/__init__.py +179 -0
- opentau/__version__.py +24 -0
- opentau/configs/__init__.py +19 -0
- opentau/configs/default.py +297 -0
- opentau/configs/libero.py +113 -0
- opentau/configs/parser.py +393 -0
- opentau/configs/policies.py +297 -0
- opentau/configs/reward.py +42 -0
- opentau/configs/train.py +370 -0
- opentau/configs/types.py +76 -0
- opentau/constants.py +52 -0
- opentau/datasets/__init__.py +84 -0
- opentau/datasets/backward_compatibility.py +78 -0
- opentau/datasets/compute_stats.py +333 -0
- opentau/datasets/dataset_mixture.py +460 -0
- opentau/datasets/factory.py +232 -0
- opentau/datasets/grounding/__init__.py +67 -0
- opentau/datasets/grounding/base.py +154 -0
- opentau/datasets/grounding/clevr.py +110 -0
- opentau/datasets/grounding/cocoqa.py +130 -0
- opentau/datasets/grounding/dummy.py +101 -0
- opentau/datasets/grounding/pixmo.py +177 -0
- opentau/datasets/grounding/vsr.py +141 -0
- opentau/datasets/image_writer.py +304 -0
- opentau/datasets/lerobot_dataset.py +1910 -0
- opentau/datasets/online_buffer.py +442 -0
- opentau/datasets/push_dataset_to_hub/utils.py +132 -0
- opentau/datasets/sampler.py +99 -0
- opentau/datasets/standard_data_format_mapping.py +278 -0
- opentau/datasets/transforms.py +330 -0
- opentau/datasets/utils.py +1243 -0
- opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +887 -0
- opentau/datasets/v2/convert_dataset_v1_to_v2.py +829 -0
- opentau/datasets/v21/_remove_language_instruction.py +109 -0
- opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +60 -0
- opentau/datasets/v21/convert_dataset_v20_to_v21.py +183 -0
- opentau/datasets/v21/convert_stats.py +150 -0
- opentau/datasets/video_utils.py +597 -0
- opentau/envs/__init__.py +18 -0
- opentau/envs/configs.py +178 -0
- opentau/envs/factory.py +99 -0
- opentau/envs/libero.py +439 -0
- opentau/envs/utils.py +204 -0
- opentau/optim/__init__.py +16 -0
- opentau/optim/factory.py +43 -0
- opentau/optim/optimizers.py +121 -0
- opentau/optim/schedulers.py +140 -0
- opentau/planner/__init__.py +82 -0
- opentau/planner/high_level_planner.py +366 -0
- opentau/planner/utils/memory.py +64 -0
- opentau/planner/utils/utils.py +65 -0
- opentau/policies/__init__.py +24 -0
- opentau/policies/factory.py +172 -0
- opentau/policies/normalize.py +315 -0
- opentau/policies/pi0/__init__.py +19 -0
- opentau/policies/pi0/configuration_pi0.py +250 -0
- opentau/policies/pi0/modeling_pi0.py +994 -0
- opentau/policies/pi0/paligemma_with_expert.py +516 -0
- opentau/policies/pi05/__init__.py +20 -0
- opentau/policies/pi05/configuration_pi05.py +231 -0
- opentau/policies/pi05/modeling_pi05.py +1257 -0
- opentau/policies/pi05/paligemma_with_expert.py +572 -0
- opentau/policies/pretrained.py +315 -0
- opentau/policies/utils.py +123 -0
- opentau/policies/value/__init__.py +18 -0
- opentau/policies/value/configuration_value.py +170 -0
- opentau/policies/value/modeling_value.py +512 -0
- opentau/policies/value/reward.py +87 -0
- opentau/policies/value/siglip_gemma.py +221 -0
- opentau/scripts/actions_mse_loss.py +89 -0
- opentau/scripts/bin_to_safetensors.py +116 -0
- opentau/scripts/compute_max_token_length.py +111 -0
- opentau/scripts/display_sys_info.py +90 -0
- opentau/scripts/download_libero_benchmarks.py +54 -0
- opentau/scripts/eval.py +877 -0
- opentau/scripts/export_to_onnx.py +180 -0
- opentau/scripts/fake_tensor_training.py +87 -0
- opentau/scripts/get_advantage_and_percentiles.py +220 -0
- opentau/scripts/high_level_planner_inference.py +114 -0
- opentau/scripts/inference.py +70 -0
- opentau/scripts/launch_train.py +63 -0
- opentau/scripts/libero_simulation_parallel.py +356 -0
- opentau/scripts/libero_simulation_sequential.py +122 -0
- opentau/scripts/nav_high_level_planner_inference.py +61 -0
- opentau/scripts/train.py +379 -0
- opentau/scripts/visualize_dataset.py +294 -0
- opentau/scripts/visualize_dataset_html.py +507 -0
- opentau/scripts/zero_to_fp32.py +760 -0
- opentau/utils/__init__.py +20 -0
- opentau/utils/accelerate_utils.py +79 -0
- opentau/utils/benchmark.py +98 -0
- opentau/utils/fake_tensor.py +81 -0
- opentau/utils/hub.py +209 -0
- opentau/utils/import_utils.py +79 -0
- opentau/utils/io_utils.py +137 -0
- opentau/utils/libero.py +214 -0
- opentau/utils/libero_dataset_recorder.py +460 -0
- opentau/utils/logging_utils.py +180 -0
- opentau/utils/monkey_patch.py +278 -0
- opentau/utils/random_utils.py +244 -0
- opentau/utils/train_utils.py +198 -0
- opentau/utils/utils.py +471 -0
- opentau-0.1.0.dist-info/METADATA +161 -0
- opentau-0.1.0.dist-info/RECORD +108 -0
- opentau-0.1.0.dist-info/WHEEL +5 -0
- opentau-0.1.0.dist-info/entry_points.txt +2 -0
- opentau-0.1.0.dist-info/licenses/LICENSE +508 -0
- opentau-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,572 @@
|
|
|
1
|
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
2
|
+
# Copyright 2026 Tensor Auto Inc. All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
"""
|
|
17
|
+
PaliGemma with Expert Module.
|
|
18
|
+
|
|
19
|
+
This module implements the PaliGemma model with an additional expert module,
|
|
20
|
+
specifically designed for the Pi05 policy. It combines a pre-trained PaliGemma
|
|
21
|
+
Vision-Language Model (VLM) with a Gemma-based expert model to handle
|
|
22
|
+
action generation and conditioning.
|
|
23
|
+
"""
|
|
24
|
+
|
|
25
|
+
import torch
|
|
26
|
+
import torch.version
|
|
27
|
+
from pytest import Cache
|
|
28
|
+
from torch import nn
|
|
29
|
+
from transformers import (
|
|
30
|
+
AutoConfig,
|
|
31
|
+
GemmaForCausalLM,
|
|
32
|
+
PaliGemmaForConditionalGeneration,
|
|
33
|
+
PretrainedConfig,
|
|
34
|
+
PreTrainedModel,
|
|
35
|
+
)
|
|
36
|
+
from transformers.models.auto import CONFIG_MAPPING
|
|
37
|
+
from transformers.models.gemma import modeling_gemma
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def apply_rope(x: torch.Tensor, positions: torch.Tensor, max_wavelength: int = 10_000) -> torch.Tensor:
    """Rotate the last dimension of ``x`` with rotary position embeddings (RoPE).

    The last dimension of size D is split into two halves ``(x1, x2)``. Feature pair k
    is rotated by the angle ``position / max_wavelength ** (2k / D)``. The math runs in
    float32 and the result is cast back to the input dtype.

    Args:
        x: Input tensor of shape [B, L, H, D].
        positions: Position tensor of shape [B, L].
        max_wavelength: Maximum wavelength for RoPE. Defaults to 10_000.

    Returns:
        Tensor: The input tensor with RoPE applied, of shape [B, L, H, D] and the dtype of ``x``.
    """
    orig_dtype = x.dtype
    head_dim = x.shape[-1]
    half = head_dim // 2
    xf = x.to(torch.float32)

    # Per-frequency timescales: max_wavelength ** (2k / D) for k in [0, D/2).
    exponents = (2.0 / head_dim) * torch.arange(half, dtype=torch.float32, device=x.device)
    timescale = max_wavelength**exponents

    # [B, L, half] -> [B, L, 1, half] so the angles broadcast over the head axis.
    angles = positions[..., None].to(torch.float32) / timescale[None, None, :].to(torch.float32)
    angles = angles[..., None, :]
    cos_part = torch.cos(angles)
    sin_part = torch.sin(angles)

    first, second = xf.split(half, dim=-1)
    rotated = torch.cat(
        [first * cos_part - second * sin_part, second * cos_part + first * sin_part],
        dim=-1,
    )
    return rotated.to(orig_dtype)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class PaliGemmaWithExpertConfig(PretrainedConfig):
    """Configuration class for PaliGemmaWithExpertModel.

    Holds one sub-config for the PaliGemma vision-language model and one for the
    Gemma action expert, plus training switches (freezing, dropout) and the size of
    the discrete-action vocabulary.
    """

    model_type = "PaliGemmaWithExpertModel"
    sub_configs = {"paligemma_config": AutoConfig, "gemma_expert_config": AutoConfig}

    def __init__(
        self,
        paligemma_config: dict | PretrainedConfig | None = None,
        gemma_expert_config: dict | PretrainedConfig | None = None,
        freeze_vision_encoder: bool = True,
        train_expert_only: bool = True,
        attention_implementation: str = "eager",
        load_pretrained_paligemma: bool = False,
        discrete_action_vocab_size: int | None = None,
        dropout: float = 0.1,
        **kwargs,
    ):
        """Initializes the configuration.

        Args:
            paligemma_config: PaliGemma configuration. ``None`` uses the Pi0 default
                below. A dict is turned into a config via ``CONFIG_MAPPING``, with its
                ``model_type`` defaulting to ``"paligemma"``. An existing config object
                is used as-is.
            gemma_expert_config: Gemma expert configuration. ``None`` uses the Pi0
                default below. A dict is turned into a config via ``CONFIG_MAPPING``,
                with its ``model_type`` defaulting to ``"gemma"``. An existing config
                object is used as-is.
            freeze_vision_encoder: Whether to freeze the vision encoder. Defaults to True.
            train_expert_only: Whether to train only the expert model. Defaults to True.
            attention_implementation: Attention implementation to use ("eager" or "fa2"). Defaults to "eager".
            load_pretrained_paligemma: Whether to load a pretrained PaliGemma model. Defaults to False.
            discrete_action_vocab_size: Vocabulary size for discrete actions. PaliGemmaWithExpertModel
                builds an ``nn.Embedding`` and an ``nn.Linear`` of this size, so it must be set
                before a model is constructed.
            dropout: Dropout probability. Defaults to 0.1.
            **kwargs: Additional keyword arguments passed to PretrainedConfig.
        """
        self.freeze_vision_encoder = freeze_vision_encoder
        self.train_expert_only = train_expert_only
        self.attention_implementation = attention_implementation
        self.load_pretrained_paligemma = load_pretrained_paligemma
        self.discrete_action_vocab_size = discrete_action_vocab_size
        self.dropout = dropout

        if paligemma_config is None:
            # Default config from Pi0
            self.paligemma_config = CONFIG_MAPPING["paligemma"](
                transformers_version="4.48.1",
                _vocab_size=257152,
                bos_token_id=2,
                eos_token_id=1,
                hidden_size=2048,
                image_token_index=257152,
                model_type="paligemma",
                pad_token_id=0,
                projection_dim=2048,
                text_config={
                    "hidden_activation": "gelu_pytorch_tanh",
                    "hidden_size": 2048,
                    "intermediate_size": 16384,
                    "model_type": "gemma",
                    "num_attention_heads": 8,
                    "num_hidden_layers": 18,
                    "num_image_tokens": 256,
                    "num_key_value_heads": 1,
                    "torch_dtype": "float32",
                    "vocab_size": 257152,
                    "use_adarms": False,
                    "adarms_cond_dim": None,
                },
                vision_config={
                    "hidden_size": 1152,
                    "intermediate_size": 4304,
                    "model_type": "siglip_vision_model",
                    "num_attention_heads": 16,
                    "num_hidden_layers": 27,
                    "num_image_tokens": 256,
                    "patch_size": 14,
                    "projection_dim": 2048,
                    "projector_hidden_act": "gelu_fast",
                    "torch_dtype": "float32",
                    "vision_use_head": False,
                },
            )
        elif isinstance(paligemma_config, dict):
            # Override Pi0 default config for PaliGemma. Build a new dict so the caller's
            # dict is not mutated. An explicit "model_type" in the input takes precedence.
            paligemma_config = {"model_type": "paligemma", **paligemma_config}
            cfg_cls = CONFIG_MAPPING[paligemma_config["model_type"]]
            self.paligemma_config = cfg_cls(**paligemma_config)
        else:
            # Already an instantiated config object.
            self.paligemma_config = paligemma_config

        if gemma_expert_config is None:
            # Default config from Pi0
            self.gemma_expert_config = CONFIG_MAPPING["gemma"](
                attention_bias=False,
                attention_dropout=0.0,
                bos_token_id=2,
                eos_token_id=1,
                head_dim=256,
                hidden_act="gelu_pytorch_tanh",
                hidden_activation="gelu_pytorch_tanh",
                hidden_size=1024,
                initializer_range=0.02,
                intermediate_size=4096,
                max_position_embeddings=8192,
                model_type="gemma",
                num_attention_heads=8,
                num_hidden_layers=18,
                num_key_value_heads=1,
                pad_token_id=0,
                rms_norm_eps=1e-06,
                rope_theta=10000.0,
                torch_dtype="float32",
                use_adarms=True,
                adarms_cond_dim=1024,
                transformers_version="4.48.1",
                use_cache=True,
                vocab_size=257152,
            )
        elif isinstance(gemma_expert_config, dict):
            # Override Pi0 default config for Gemma Expert. Build a new dict so the caller's
            # dict is not mutated. The config class is picked from the expert's OWN model_type.
            gemma_expert_config = {"model_type": "gemma", **gemma_expert_config}
            cfg_cls = CONFIG_MAPPING[gemma_expert_config["model_type"]]
            self.gemma_expert_config = cfg_cls(**gemma_expert_config)
        else:
            # Already an instantiated config object.
            self.gemma_expert_config = gemma_expert_config

        super().__init__(**kwargs)

    def __post_init__(self):
        """Validates configuration parameters.

        ``__init__`` above does not call this method, so the validation only runs when a
        caller invokes it explicitly.

        Raises:
            ValueError: If ``train_expert_only`` is set while ``freeze_vision_encoder`` is not,
                or if ``attention_implementation`` is not "eager" or "fa2".
        """
        # Delegate to the parent hook only if it exists; PretrainedConfig may not define one.
        parent_post_init = getattr(super(), "__post_init__", None)
        if parent_post_init is not None:
            parent_post_init()
        if self.train_expert_only and not self.freeze_vision_encoder:
            raise ValueError(
                "You set `freeze_vision_encoder=False` and `train_expert_only=True` which are not compatible."
            )

        if self.attention_implementation not in ["eager", "fa2"]:
            raise ValueError(
                f"Wrong value provided for `attention_implementation` ({self.attention_implementation}). Expected 'eager' or 'fa2'."
            )
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
class PaliGemmaWithExpertModel(PreTrainedModel):
    """PaliGemma model with an additional expert module for action generation.

    The PaliGemma language model and a Gemma "expert" run in lockstep, layer by layer.
    At every layer, the queries/keys/values of the two streams are concatenated along
    the sequence axis and go through one shared attention call governed by
    ``attention_mask``. Output projections, MLPs and norms stay per-stream.
    """

    config_class = PaliGemmaWithExpertConfig

    def __init__(self, config: PaliGemmaWithExpertConfig):
        """Initializes the PaliGemmaWithExpertModel.

        Args:
            config: Configuration object for the model.
        """
        super().__init__(config=config)
        self.config = config

        if config.load_pretrained_paligemma:
            self.paligemma = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma-3b-pt-224")
        else:
            self.paligemma = PaliGemmaForConditionalGeneration(config=config.paligemma_config)
        self.gemma_expert = GemmaForCausalLM(config=config.gemma_expert_config)
        # Remove unused embed_tokens: the expert only receives precomputed embeddings in forward().
        self.gemma_expert.model.embed_tokens = None

        # Learned embedding layer for discrete actions.
        # Its width is the PaliGemma language-model hidden size
        # (paligemma_config.text_config.hidden_size), not the expert's hidden size.
        # NOTE(review): nn.Embedding fails if discrete_action_vocab_size is None (the config default).
        self.discrete_action_embedding = nn.Embedding(
            num_embeddings=config.discrete_action_vocab_size,
            embedding_dim=config.paligemma_config.text_config.hidden_size,
            padding_idx=0,  # 0 is used for padding in pad_fast_tokens
        )

        # discrete action head that maps to action vocab size and not language vocab size
        self.da_head = nn.Linear(
            in_features=config.paligemma_config.text_config.hidden_size,
            out_features=config.discrete_action_vocab_size,
        )

        # Applied after attention output projection and after the MLP in forward().
        self.dropout = nn.Dropout(config.dropout)

        self.to_bfloat16_like_physical_intelligence()
        self.set_requires_grad()

    def set_requires_grad(self) -> None:
        """Sets the requires_grad attribute for model parameters based on configuration.

        Frozen sub-modules are also switched to eval mode.
        """
        if self.config.freeze_vision_encoder:
            self.paligemma.vision_tower.eval()
            for params in self.paligemma.vision_tower.parameters():
                params.requires_grad = False

        if self.config.train_expert_only:
            self.paligemma.eval()
            for params in self.paligemma.parameters():
                params.requires_grad = False

    def train(self, mode: bool = True) -> "PaliGemmaWithExpertModel":
        """Sets the module in training mode, keeping frozen sub-modules in eval mode.

        Args:
            mode: whether to set training mode (True) or evaluation mode (False). Defaults to True.

        Returns:
            PaliGemmaWithExpertModel: ``self``, as required by the ``nn.Module.train`` contract
            (allows chaining such as ``model.train().to(device)``).
        """
        super().train(mode)

        # Frozen parts must never switch back to training behaviour, whatever ``mode`` is.
        if self.config.freeze_vision_encoder:
            self.paligemma.vision_tower.eval()

        if self.config.train_expert_only:
            self.paligemma.eval()

        return self

    def to_bfloat16_like_physical_intelligence(self) -> None:
        """Casts specific model components to bfloat16 dtype.

        The whole PaliGemma module is cast, plus every parameter whose name contains one
        of the selectors below (this covers the expert's decoder layers). Parameters that
        match no selector keep their dtype, for example ``discrete_action_embedding`` and
        ``da_head``.
        """
        self.paligemma = self.paligemma.to(dtype=torch.bfloat16)

        params_to_change_dtype = [
            "language_model.model.layers",
            "gemma_expert.model.layers",
            "vision_tower",
            "multi_modal",
        ]
        for name, param in self.named_parameters():
            if any(selector in name for selector in params_to_change_dtype):
                param.data = param.data.to(dtype=torch.bfloat16)

    def embed_image(self, image: torch.Tensor) -> torch.Tensor:
        """Computes image embeddings.

        Args:
            image: Input image tensor.

        Returns:
            torch.Tensor: Image embeddings.
        """
        # Handle different transformers versions: get_image_features lives either on the
        # PaliGemma wrapper or on its inner ``model``.
        if hasattr(self.paligemma, "get_image_features"):
            return self.paligemma.get_image_features(image)
        else:
            return self.paligemma.model.get_image_features(image)

    def embed_language_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
        """Embeds language tokens.

        Args:
            tokens: Input token indices.

        Returns:
            torch.Tensor: Token embeddings.
        """
        return self.paligemma.language_model.embed_tokens(tokens)

    def embed_discrete_actions(self, actions: torch.Tensor) -> torch.Tensor:
        """Embeds discrete action tokens.

        Args:
            actions: Input discrete action indices (cast to ``torch.long`` if needed).

        Returns:
            torch.Tensor: Action embeddings.
        """
        # Ensure actions are long integers for embedding lookup
        if actions.dtype != torch.long:
            actions = actions.long()

        # Apply embedding layer
        embedded = self.discrete_action_embedding(actions)

        return embedded

    # TODO: break down this huge forward into modules or functions
    def forward(
        self,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: list[torch.FloatTensor] | Cache | None = None,
        inputs_embeds: list[torch.FloatTensor] | None = None,
        n_cross_att_tokens: int | None = None,
        use_cache: bool | None = None,
        fill_kv_cache: bool | None = None,
        adarms_cond: list[torch.Tensor] | None = None,
    ) -> tuple[list[torch.FloatTensor | None], list[torch.FloatTensor] | Cache | None]:
        """Runs the PaliGemma language model and the Gemma expert jointly, layer by layer.

        Args:
            attention_mask: Boolean mask indexed as ``[B, L_query, L_key]`` (it is
                broadcast over heads in ``eager_attention_forward``). The key axis follows
                the order ``[current tokens, cached tokens]`` when a cache is consumed.
            position_ids: Positions passed to ``apply_rope`` for the concatenated
                query/key sequence of all non-None streams.
            past_key_values: KV cache. This method builds and reads it as a dict mapping
                layer index to ``{"key_states": ..., "value_states": ...}``.
            inputs_embeds: Per-stream input embeddings. Entry 0 goes through the PaliGemma
                language model and entry 1 through the Gemma expert. Either entry may be None.
            n_cross_att_tokens: Number of leading tokens whose keys/values are cached when
                ``fill_kv_cache`` is True.
            use_cache: Whether to use KV cache.
            fill_kv_cache: If True (with ``use_cache``), store keys/values in the cache.
                Otherwise (with ``use_cache``), append the cached keys/values to the
                current ones.
            adarms_cond: Per-stream conditioning passed as ``cond=`` to the layer norms.
                Defaults to ``[None, None]``.

        Returns:
            tuple: A tuple containing:
                - outputs_embeds: Per-stream outputs after the final norm (None where the
                  corresponding input was None).
                - past_key_values: Updated past key values.

        Raises:
            ValueError: If `n_cross_att_tokens` is not provided when `fill_kv_cache` is True.
        """
        if adarms_cond is None:
            adarms_cond = [None, None]

        # Stream i of ``inputs_embeds`` is processed by ``models[i]``.
        models = [self.paligemma.language_model, self.gemma_expert.model]

        # All streams are expected to share the batch size; it is read from the last
        # non-None stream.
        for hidden_states in inputs_embeds:
            if hidden_states is None:
                continue
            batch_size = hidden_states.shape[0]

        # Both streams are iterated over the PaliGemma layer count and use its head_dim.
        num_layers = self.paligemma.config.text_config.num_hidden_layers
        head_dim = self.paligemma.config.text_config.head_dim
        # Loop-invariant: look up the attention function once.
        attention_interface = self.get_attention_interface()

        for layer_idx in range(num_layers):
            query_states = []
            key_states = []
            value_states = []
            gates = []
            for i, hidden_states in enumerate(inputs_embeds):
                if hidden_states is None:
                    gates.append(None)
                    continue
                layer = models[i].layers[layer_idx]
                # Pre-attention norm (optionally conditioned). It also returns the gate that
                # the first gated residual uses below.
                hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i])
                gates.append(gate)
                input_shape = hidden_states.shape[:-1]
                hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)

                hidden_states = hidden_states.to(dtype=torch.bfloat16)
                query_states.append(layer.self_attn.q_proj(hidden_states).view(hidden_shape))
                key_states.append(layer.self_attn.k_proj(hidden_states).view(hidden_shape))
                value_states.append(layer.self_attn.v_proj(hidden_states).view(hidden_shape))

            # B,L,H,D with L sequence length, H number of heads, D head dim.
            # Concatenate the streams on the token axis so one attention call covers both.
            query_states = torch.cat(query_states, dim=1)
            key_states = torch.cat(key_states, dim=1)
            value_states = torch.cat(value_states, dim=1)

            query_states = apply_rope(query_states, position_ids)
            key_states = apply_rope(key_states, position_ids)

            if use_cache and past_key_values is None:
                past_key_values = {}

            if use_cache:
                if fill_kv_cache:
                    if n_cross_att_tokens is None:
                        raise ValueError("n_cross_att_tokens must be provided when fill_kv_cache is True")
                    past_key_values[layer_idx] = {
                        # save the first n_cross_att_tokens for action expert cross attention
                        "key_states": key_states[:, :n_cross_att_tokens, :, :],
                        "value_states": value_states[:, :n_cross_att_tokens, :, :],
                    }
                else:
                    # Cached (already RoPE-rotated) keys/values are appended AFTER the current
                    # ones, so the mask's key axis must follow [current, cached] order.
                    # TODO here, some optimization can be done - similar to a `StaticCache` we can declare
                    # the `max_len` before. So we create an empty cache, with just one cuda malloc, and if
                    # (in autoregressive case) we reach the max len, then we (for instance) double the cache
                    # size. This implementation already exists in `transformers`. (molbap)
                    key_states = torch.cat([key_states, past_key_values[layer_idx]["key_states"]], dim=1)
                    value_states = torch.cat(
                        [value_states, past_key_values[layer_idx]["value_states"]], dim=1
                    )

            att_output = attention_interface(
                attention_mask, batch_size, head_dim, query_states, key_states, value_states
            )
            att_output = att_output.to(dtype=torch.bfloat16)

            # att_output follows the concatenated token order: slice it back per stream
            # (the prefix comes first, at [:, 0:prefix_seq_len]).
            outputs_embeds = []
            start = 0
            for i, hidden_states in enumerate(inputs_embeds):
                layer = models[i].layers[layer_idx]

                if hidden_states is not None:
                    end = start + hidden_states.shape[1]

                    if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
                        att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
                    out_emb = layer.self_attn.o_proj(att_output[:, start:end])

                    out_emb = self.dropout(out_emb)

                    # first residual, added to the pre-norm input of this layer
                    out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i])  # noqa: SLF001
                    after_first_residual = out_emb.clone()

                    out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i])
                    out_emb = layer.mlp(out_emb)

                    out_emb = self.dropout(out_emb)

                    # second residual
                    out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate)  # noqa: SLF001

                    outputs_embeds.append(out_emb)

                    start = end
                else:
                    outputs_embeds.append(None)

            inputs_embeds = outputs_embeds

        # final norm
        outputs_embeds = []
        for i, hidden_states in enumerate(inputs_embeds):
            if hidden_states is not None:
                out_emb, _ = models[i].norm(hidden_states, cond=adarms_cond[i])
                outputs_embeds.append(out_emb)
            else:
                outputs_embeds.append(None)

        return outputs_embeds, past_key_values

    def get_attention_interface(self):
        """Returns the attention implementation function.

        Currently this always returns ``eager_attention_forward``. The value of
        ``config.attention_implementation`` (e.g. "fa2") is not consulted here.

        Returns:
            callable: The attention function to use.
        """
        return self.eager_attention_forward

    def eager_attention_forward(
        self,
        attention_mask: torch.Tensor,
        batch_size: int,
        head_dim: int,
        query_states: torch.Tensor,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
    ) -> torch.Tensor:
        """Eager attention forward pass using standard matrix multiplications.

        Grouped-query attention: key/value heads are repeated to match the number of
        query heads from the PaliGemma text config.

        Args:
            attention_mask: Boolean attention mask indexed as ``[B, L_query, L_key]``
                (True = attend).
            batch_size: Batch size.
            head_dim: Head dimension.
            query_states: Query states tensor, ``[B, L_query, num_att_heads, head_dim]``.
            key_states: Key states tensor, ``[B, L_key, num_key_value_heads, head_dim]``.
            value_states: Value states tensor, ``[B, L_key, num_key_value_heads, head_dim]``.

        Returns:
            torch.Tensor: Attention output of shape ``[B, L_query, num_att_heads * head_dim]``.
        """
        num_att_heads = self.config.paligemma_config.text_config.num_attention_heads
        num_key_value_heads = self.config.paligemma_config.text_config.num_key_value_heads
        num_key_value_groups = num_att_heads // num_key_value_heads

        # query_states: batch_size, sequence_length, num_att_head, head_dim
        # key_states: batch_size, sequence_length, num_key_value_head, head_dim
        # value_states: batch_size, sequence_length, num_key_value_head, head_dim
        sequence_length = key_states.shape[1]

        key_states = key_states[:, :, :, None, :].expand(
            batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
        )
        key_states = key_states.reshape(
            batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
        )

        value_states = value_states[:, :, :, None, :].expand(
            batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
        )
        value_states = value_states.reshape(
            batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
        )

        # Attention here is upcasted to float32 to match the original eager implementation.
        query_states = query_states.to(dtype=torch.float32)
        key_states = key_states.to(dtype=torch.float32)

        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)

        att_weights = torch.matmul(query_states, key_states.transpose(2, 3))
        att_weights *= head_dim**-0.5
        big_neg = -2.3819763e38  # See gemma/modules.py

        masked_att_weights = torch.where(attention_mask[:, None, :, :], att_weights, big_neg)

        probs = nn.functional.softmax(masked_att_weights, dim=-1)
        probs = probs.to(dtype=value_states.dtype)

        # probs: batch_size, num_att_heads, query_length, key_length
        # value_states: batch_size, key_length, num_att_heads, head_dim

        att_output = torch.matmul(probs, value_states.permute(0, 2, 1, 3))

        att_output = att_output.permute(0, 2, 1, 3)
        # we use -1 because sequence length can change
        att_output = att_output.reshape(batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim)

        return att_output
|