opentau-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. opentau/__init__.py +179 -0
  2. opentau/__version__.py +24 -0
  3. opentau/configs/__init__.py +19 -0
  4. opentau/configs/default.py +297 -0
  5. opentau/configs/libero.py +113 -0
  6. opentau/configs/parser.py +393 -0
  7. opentau/configs/policies.py +297 -0
  8. opentau/configs/reward.py +42 -0
  9. opentau/configs/train.py +370 -0
  10. opentau/configs/types.py +76 -0
  11. opentau/constants.py +52 -0
  12. opentau/datasets/__init__.py +84 -0
  13. opentau/datasets/backward_compatibility.py +78 -0
  14. opentau/datasets/compute_stats.py +333 -0
  15. opentau/datasets/dataset_mixture.py +460 -0
  16. opentau/datasets/factory.py +232 -0
  17. opentau/datasets/grounding/__init__.py +67 -0
  18. opentau/datasets/grounding/base.py +154 -0
  19. opentau/datasets/grounding/clevr.py +110 -0
  20. opentau/datasets/grounding/cocoqa.py +130 -0
  21. opentau/datasets/grounding/dummy.py +101 -0
  22. opentau/datasets/grounding/pixmo.py +177 -0
  23. opentau/datasets/grounding/vsr.py +141 -0
  24. opentau/datasets/image_writer.py +304 -0
  25. opentau/datasets/lerobot_dataset.py +1910 -0
  26. opentau/datasets/online_buffer.py +442 -0
  27. opentau/datasets/push_dataset_to_hub/utils.py +132 -0
  28. opentau/datasets/sampler.py +99 -0
  29. opentau/datasets/standard_data_format_mapping.py +278 -0
  30. opentau/datasets/transforms.py +330 -0
  31. opentau/datasets/utils.py +1243 -0
  32. opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +887 -0
  33. opentau/datasets/v2/convert_dataset_v1_to_v2.py +829 -0
  34. opentau/datasets/v21/_remove_language_instruction.py +109 -0
  35. opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +60 -0
  36. opentau/datasets/v21/convert_dataset_v20_to_v21.py +183 -0
  37. opentau/datasets/v21/convert_stats.py +150 -0
  38. opentau/datasets/video_utils.py +597 -0
  39. opentau/envs/__init__.py +18 -0
  40. opentau/envs/configs.py +178 -0
  41. opentau/envs/factory.py +99 -0
  42. opentau/envs/libero.py +439 -0
  43. opentau/envs/utils.py +204 -0
  44. opentau/optim/__init__.py +16 -0
  45. opentau/optim/factory.py +43 -0
  46. opentau/optim/optimizers.py +121 -0
  47. opentau/optim/schedulers.py +140 -0
  48. opentau/planner/__init__.py +82 -0
  49. opentau/planner/high_level_planner.py +366 -0
  50. opentau/planner/utils/memory.py +64 -0
  51. opentau/planner/utils/utils.py +65 -0
  52. opentau/policies/__init__.py +24 -0
  53. opentau/policies/factory.py +172 -0
  54. opentau/policies/normalize.py +315 -0
  55. opentau/policies/pi0/__init__.py +19 -0
  56. opentau/policies/pi0/configuration_pi0.py +250 -0
  57. opentau/policies/pi0/modeling_pi0.py +994 -0
  58. opentau/policies/pi0/paligemma_with_expert.py +516 -0
  59. opentau/policies/pi05/__init__.py +20 -0
  60. opentau/policies/pi05/configuration_pi05.py +231 -0
  61. opentau/policies/pi05/modeling_pi05.py +1257 -0
  62. opentau/policies/pi05/paligemma_with_expert.py +572 -0
  63. opentau/policies/pretrained.py +315 -0
  64. opentau/policies/utils.py +123 -0
  65. opentau/policies/value/__init__.py +18 -0
  66. opentau/policies/value/configuration_value.py +170 -0
  67. opentau/policies/value/modeling_value.py +512 -0
  68. opentau/policies/value/reward.py +87 -0
  69. opentau/policies/value/siglip_gemma.py +221 -0
  70. opentau/scripts/actions_mse_loss.py +89 -0
  71. opentau/scripts/bin_to_safetensors.py +116 -0
  72. opentau/scripts/compute_max_token_length.py +111 -0
  73. opentau/scripts/display_sys_info.py +90 -0
  74. opentau/scripts/download_libero_benchmarks.py +54 -0
  75. opentau/scripts/eval.py +877 -0
  76. opentau/scripts/export_to_onnx.py +180 -0
  77. opentau/scripts/fake_tensor_training.py +87 -0
  78. opentau/scripts/get_advantage_and_percentiles.py +220 -0
  79. opentau/scripts/high_level_planner_inference.py +114 -0
  80. opentau/scripts/inference.py +70 -0
  81. opentau/scripts/launch_train.py +63 -0
  82. opentau/scripts/libero_simulation_parallel.py +356 -0
  83. opentau/scripts/libero_simulation_sequential.py +122 -0
  84. opentau/scripts/nav_high_level_planner_inference.py +61 -0
  85. opentau/scripts/train.py +379 -0
  86. opentau/scripts/visualize_dataset.py +294 -0
  87. opentau/scripts/visualize_dataset_html.py +507 -0
  88. opentau/scripts/zero_to_fp32.py +760 -0
  89. opentau/utils/__init__.py +20 -0
  90. opentau/utils/accelerate_utils.py +79 -0
  91. opentau/utils/benchmark.py +98 -0
  92. opentau/utils/fake_tensor.py +81 -0
  93. opentau/utils/hub.py +209 -0
  94. opentau/utils/import_utils.py +79 -0
  95. opentau/utils/io_utils.py +137 -0
  96. opentau/utils/libero.py +214 -0
  97. opentau/utils/libero_dataset_recorder.py +460 -0
  98. opentau/utils/logging_utils.py +180 -0
  99. opentau/utils/monkey_patch.py +278 -0
  100. opentau/utils/random_utils.py +244 -0
  101. opentau/utils/train_utils.py +198 -0
  102. opentau/utils/utils.py +471 -0
  103. opentau-0.1.0.dist-info/METADATA +161 -0
  104. opentau-0.1.0.dist-info/RECORD +108 -0
  105. opentau-0.1.0.dist-info/WHEEL +5 -0
  106. opentau-0.1.0.dist-info/entry_points.txt +2 -0
  107. opentau-0.1.0.dist-info/licenses/LICENSE +508 -0
  108. opentau-0.1.0.dist-info/top_level.txt +1 -0
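opentau/policies/pi05/modeling_pi05.py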
@@ -0,0 +1,1257 @@
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
4
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """π05: A Vision-Language-Action Flow Model for General Robot Control
19
+
20
+ [Paper](https://www.physicalintelligence.company/download/pi05.pdf)
21
+ """
22
+
23
+ import builtins
24
+ import logging
25
+ import math
26
+ from collections import deque
27
+ from pathlib import Path
28
+
29
+ import numpy as np
30
+ import torch
31
+ import torch.nn.functional as F # noqa: N812
32
+ from einops import rearrange
33
+ from torch import Tensor, nn
34
+ from transformers import AutoProcessor, AutoTokenizer
35
+
36
+ from opentau.configs.policies import PreTrainedConfig
37
+ from opentau.configs.types import NormalizationMode
38
+ from opentau.policies.normalize import Normalize, Unnormalize
39
+ from opentau.policies.pi05.configuration_pi05 import PI05Config
40
+ from opentau.policies.pi05.paligemma_with_expert import (
41
+ PaliGemmaWithExpertConfig,
42
+ PaliGemmaWithExpertModel,
43
+ )
44
+ from opentau.policies.pretrained import PreTrainedPolicy, T
45
+ from opentau.utils.utils import get_safe_dtype
46
+
47
+
48
+ def create_sinusoidal_pos_embedding(
49
+ time: Tensor, dimension: int, min_period: float, max_period: float, device: torch.device | str = "cpu"
50
+ ) -> Tensor:
51
+ """Computes sine-cosine positional embedding vectors for scalar positions.
52
+
53
+ Args:
54
+ time: A 1-D tensor of shape (batch_size,).
55
+ dimension: The dimension of the embedding vectors. Must be divisible by 2.
56
+ min_period: The minimum period of the sinusoidal functions.
57
+ max_period: The maximum period of the sinusoidal functions.
58
+ device: The device to create the tensors on. Defaults to "cpu".
59
+
60
+ Returns:
61
+ A tensor of shape (batch_size, dimension) containing the positional embeddings.
62
+
63
+ Raises:
64
+ ValueError: If dimension is not divisible by 2 or if time tensor is not 1-D.
65
+ """
66
+ if dimension % 2 != 0:
67
+ raise ValueError(f"dimension ({dimension}) must be divisible by 2")
68
+
69
+ if time.ndim != 1:
70
+ raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
71
+
72
+ dtype = (
73
+ get_safe_dtype(torch.float64, device.type)
74
+ if isinstance(device, torch.device)
75
+ else get_safe_dtype(torch.float64, device)
76
+ )
77
+ fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
78
+ period = min_period * (max_period / min_period) ** fraction
79
+
80
+ # Compute the outer product
81
+ scaling_factor = 1.0 / period * 2 * math.pi
82
+ sin_input = scaling_factor[None, :] * time[:, None]
83
+ pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
84
+ return pos_emb
85
+
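+ # Illustrative usage sketch (values chosen for illustration): embedding two
+ # flow-matching timesteps with the same period range used by `embed_suffix`.
+ # t = torch.tensor([0.0, 0.5])
+ # emb = create_sinusoidal_pos_embedding(t, dimension=8, min_period=4e-3, max_period=4.0)
+ # emb.shape  # torch.Size([2, 8]); first 4 columns are sines, last 4 are cosines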
86
+
87
+ def make_att_2d_masks(
88
+ pad_masks: Tensor,
89
+ att_masks: Tensor,
90
+ n_cross_att_tokens: int | None = None,
91
+ cross_att_pad_masks: Tensor | None = None,
92
+ ) -> Tensor:
93
+ """Creates a 2-D attention mask given padding and 1-D attention masks.
94
+
95
+ Tokens can attend to valid input tokens whose cumulative `att_masks` is
96
+ smaller than or equal to theirs. This way `att_masks` int[B, N] can be used to
97
+ set up several types of attention, for example:
98
+
99
+ [[1 1 1 1 1 1]]: pure causal attention.
100
+
101
+ [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
102
+ themselves and the last 3 tokens have a causal attention. The first
103
+ entry could also be a 1 without changing behaviour.
104
+
105
+ [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
106
+ block can attend all previous blocks and all tokens on the same block.
107
+
108
+ Args:
109
+ pad_masks: bool[B, N] true if it's part of the input, false if padding.
110
+ att_masks: int32[B, N] mask that's 1 where previous tokens cannot depend on
111
+ it and 0 where it shares the same attention mask as the previous token.
112
+ n_cross_att_tokens: Add attention mask for cross-attention tokens if
113
+ `n_cross_att_tokens` is provided.
114
+ cross_att_pad_masks: Padding masks for cross attention tokens. Required if
115
+ `n_cross_att_tokens` is provided.
116
+
117
+ Returns:
118
+ A 2D attention mask tensor of shape (B, N + n_cross_att_tokens, N + n_cross_att_tokens)
119
+ if n_cross_att_tokens is provided, else (B, N, N).
120
+
121
+ Raises:
122
+ ValueError: If att_masks or pad_masks are not 2D (including batch dimension).
123
+ AssertionError: If cross_att_pad_masks is missing when n_cross_att_tokens is set,
124
+ or if its shape is incorrect.
125
+ """
126
+ if att_masks.ndim != 2:
127
+ raise ValueError(f"att_masks must be 2D, got {att_masks.ndim} dims")
128
+ if pad_masks.ndim != 2:
129
+ raise ValueError(f"pad_masks must be 2D, got {pad_masks.ndim} dims")
130
+
131
+ cumsum = torch.cumsum(att_masks, dim=1)
132
+ att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
133
+ pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
134
+ att_2d_masks = att_2d_masks & pad_2d_masks
135
+
136
+ # If `n_cross_att_tokens` is provided, we add a mask for cross-attention tokens at the end of the sequence.
137
+ if n_cross_att_tokens is not None:
138
+ assert cross_att_pad_masks is not None, (
139
+ "cross_att_pad_masks must be provided if n_cross_att_tokens is provided"
140
+ )
141
+ assert cross_att_pad_masks.shape == (att_masks.size(0), n_cross_att_tokens), (
142
+ "cross_att_pad_masks must have shape (batch_size, n_cross_att_tokens)"
143
+ )
144
+
145
+ cross_att_mask = torch.full(
146
+ (att_masks.size(0), att_masks.size(1), n_cross_att_tokens),
147
+ True,
148
+ dtype=torch.bool,
149
+ device=att_masks.device,
150
+ )
151
+
152
+ # Apply padding masks: pad_masks for rows, cross_att_pad_masks for columns
153
+ cross_att_mask = cross_att_mask & pad_masks[:, :, None] & cross_att_pad_masks[:, None, :]
154
+
155
+ att_2d_masks = torch.cat((att_2d_masks, cross_att_mask), dim=2)
156
+
157
+ return att_2d_masks
158
+
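+ # Worked example (illustrative): with pad_masks = [[1, 1, 1, 1]] and
+ # att_masks = [[0, 0, 1, 1]], cumsum is [0, 0, 1, 2], so the first two tokens
+ # form a bidirectional prefix block while the last two are causal:
+ # make_att_2d_masks(torch.ones(1, 4, dtype=torch.bool), torch.tensor([[0, 0, 1, 1]]))
+ # [[1, 1, 0, 0],
+ #  [1, 1, 0, 0],
+ #  [1, 1, 1, 0],
+ #  [1, 1, 1, 1]]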
159
+
160
+ def resize_with_pad(img: Tensor, width: int, height: int, pad_value: int = -1) -> Tensor:
161
+ """Resizes an image to fit within the specified dimensions while maintaining aspect ratio,
162
+ and pads the remaining area with the specified value.
163
+
164
+ Args:
165
+ img: Input image tensor of shape (batch_size, channels, current_height, current_width).
166
+ width: Target width.
167
+ height: Target height.
168
+ pad_value: Value to use for padding. Defaults to -1.
169
+
170
+ Returns:
171
+ The resized and padded image tensor of shape (batch_size, channels, height, width).
172
+
173
+ Raises:
174
+ ValueError: If the input image tensor does not have 4 dimensions.
175
+ """
176
+ # assume no-op when the target width/height already fit
177
+ if img.ndim != 4:
178
+ raise ValueError(f"Expected shape (b, c, h, w), but got {img.shape}")
179
+
180
+ cur_height, cur_width = img.shape[2:]
181
+
182
+ ratio = max(cur_width / width, cur_height / height)
183
+ resized_height = int(cur_height / ratio)
184
+ resized_width = int(cur_width / ratio)
185
+ resized_img = F.interpolate(
186
+ img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
187
+ )
188
+
189
+ pad_height = max(0, int(height - resized_height))
190
+ pad_width = max(0, int(width - resized_width))
191
+
192
+ # pad on left and top of image
193
+ padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
194
+ return padded_img
195
+
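+ # Illustrative example: a 3x256x320 image resized to fit 224x224. The scale
+ # ratio is 320/224, giving a 179x224 resize, and 45 rows of padding are added
+ # at the top (filled with pad_value):
+ # out = resize_with_pad(torch.rand(1, 3, 256, 320), width=224, height=224, pad_value=0)
+ # out.shape  # torch.Size([1, 3, 224, 224])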
196
+
197
+ def pad_vector(vector: Tensor, new_dim: int) -> Tensor:
198
+ """Pads the last dimension of a vector to a new size with zeros.
199
+
200
+ Args:
201
+ vector: Input tensor. Can be (batch_size x sequence_length x features_dimension)
202
+ or (batch_size x features_dimension).
203
+ new_dim: The new size for the last dimension.
204
+
205
+ Returns:
206
+ The padded tensor.
207
+ """
208
+ if vector.shape[-1] == new_dim:
209
+ return vector
210
+ shape = list(vector.shape)
211
+ current_dim = shape[-1]
212
+ shape[-1] = new_dim
213
+ new_vector = torch.zeros(*shape, dtype=vector.dtype, device=vector.device)
214
+ new_vector[..., :current_dim] = vector
215
+ return new_vector
216
+
217
+
218
+ def pad_discrete_tokens(tokens: list[list[int]], max_length: int) -> tuple[np.ndarray, np.ndarray]:
219
+ """Pads or truncates a list of discrete action token sequences to a fixed length.
220
+
221
+ Args:
222
+ tokens: A list of discrete action token sequences (lists of integers).
223
+ max_length: The target length for the discrete action token sequences.
224
+
225
+ Returns:
226
+ A tuple containing:
227
+ - discrete_action_tokens: A numpy array of shape (len(tokens), max_length) containing the padded discrete action tokens.
228
+ - discrete_action_masks: A boolean numpy array of shape (len(tokens), max_length) indicating valid discrete action tokens (True) and padding (False).
229
+ """
230
+ discrete_action_tokens = []
231
+ discrete_action_masks = []
232
+ for token in tokens:
233
+ if len(token) > max_length:
234
+ discrete_action_tokens.append(np.array(token[:max_length]))
235
+ discrete_action_masks.append(np.ones(max_length, dtype=bool))
236
+ else:
237
+ discrete_action_masks.append(
238
+ np.concatenate(
239
+ [np.ones(len(token), dtype=bool), np.zeros(max_length - len(token), dtype=bool)]
240
+ )
241
+ )
242
+ discrete_action_tokens.append(np.pad(token, (0, max_length - len(token)), constant_values=0))
243
+ return np.array(discrete_action_tokens), np.array(discrete_action_masks)
244
+
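+ # Illustrative example: two token sequences padded/truncated to max_length=4.
+ # toks, masks = pad_discrete_tokens([[5, 9], [1, 2, 3, 4, 6]], max_length=4)
+ # toks   # [[5, 9, 0, 0], [1, 2, 3, 4]]
+ # masks  # [[True, True, False, False], [True, True, True, True]]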
245
+
246
+ class PI05Policy(PreTrainedPolicy):
247
+ """Wrapper class around PI05FlowMatching model to train and run inference within OpenTau."""
248
+
249
+ config_class = PI05Config
250
+ name = "pi05"
251
+
252
+ def __init__(
253
+ self,
254
+ config: PI05Config,
255
+ dataset_stats: dict[str, dict[str, Tensor]] | None = None,
256
+ ):
257
+ """Initializes the PI05Policy.
258
+
259
+ Args:
260
+ config: Policy configuration class instance or None, in which case the default instantiation of
261
+ the configuration class is used.
262
+ dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
263
+ that they will be passed with a call to `load_state_dict` before the policy is used.
264
+ """
265
+
266
+ super().__init__(config)
267
+ config.validate_features()
268
+ self.config = config
269
+ self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
270
+ self.normalize_targets = Normalize(
271
+ config.output_features, config.normalization_mapping, dataset_stats
272
+ )
273
+ self.normalize_actions = Normalize(
274
+ config.output_features, {"ACTION": NormalizationMode.MIN_MAX}, dataset_stats
275
+ )
276
+ self.unnormalize_outputs = Unnormalize(
277
+ config.output_features, config.normalization_mapping, dataset_stats
278
+ )
279
+
280
+ self.language_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
281
+
282
+ self.discrete_action_processor = AutoProcessor.from_pretrained(
283
+ "physical-intelligence/fast", trust_remote_code=True
284
+ )
285
+ # Get vocab size from processor
286
+ discrete_action_vocab_size = getattr(self.discrete_action_processor, "vocab_size", None)
287
+ self.model = PI05FlowMatching(config, discrete_action_vocab_size=discrete_action_vocab_size)
288
+
289
+ self.reset()
290
+
291
+ def reset(self) -> None:
292
+ """This should be called whenever the environment is reset."""
293
+ self._action_queue = deque([], maxlen=self.config.n_action_steps)
294
+
295
+ @classmethod
296
+ def from_pretrained(
297
+ cls: builtins.type[T],
298
+ pretrained_name_or_path: str | Path,
299
+ *,
300
+ config: PreTrainedConfig | None = None,
301
+ force_download: bool = False,
302
+ resume_download: bool | None = None,
303
+ proxies: dict | None = None,
304
+ token: str | bool | None = None,
305
+ cache_dir: str | Path | None = None,
306
+ local_files_only: bool = False,
307
+ revision: str | None = None,
308
+ strict: bool = True,
309
+ **kwargs,
310
+ ) -> T:
311
+ """Override the from_pretrained method to handle key remapping.
312
+
313
+ Args:
314
+ pretrained_name_or_path: Path to the pretrained model or its name on the Hub.
315
+ config: Configuration object.
316
+ force_download: Whether to force download the model weights.
317
+ resume_download: Whether to resume download.
318
+ proxies: Proxy configuration.
319
+ token: Authentication token.
320
+ cache_dir: Directory to cache downloaded files.
321
+ local_files_only: Whether to only look for files locally.
322
+ revision: Specific model revision.
323
+ strict: Whether to strictly enforce state dict matching.
324
+ **kwargs: Additional keyword arguments.
325
+
326
+ Returns:
327
+ The loaded model instance.
328
+
329
+ Raises:
330
+ ValueError: If pretrained_name_or_path is None.
331
+ """
332
+ if pretrained_name_or_path is None:
333
+ raise ValueError("pretrained_name_or_path is required")
334
+
335
+ # Use provided config if available, otherwise create default config
336
+ if config is None:
337
+ config = PreTrainedConfig.from_pretrained(
338
+ pretrained_name_or_path=pretrained_name_or_path,
339
+ force_download=force_download,
340
+ resume_download=resume_download,
341
+ proxies=proxies,
342
+ token=token,
343
+ cache_dir=cache_dir,
344
+ local_files_only=local_files_only,
345
+ revision=revision,
346
+ **kwargs,
347
+ )
348
+
349
+ # Initialize model without loading weights
350
+ # Check if dataset_stats were provided in kwargs
351
+ model = cls(config, **kwargs)
352
+
353
+ # Now manually load and remap the state dict
354
+ try:
355
+ # Try to load the pytorch_model.bin or model.safetensors file
356
+ print(f"Loading model from: {pretrained_name_or_path}")
357
+ try:
358
+ from transformers.utils import cached_file
359
+
360
+ # Try safetensors first
361
+ resolved_file = cached_file(
362
+ pretrained_name_or_path,
363
+ "model.safetensors",
364
+ cache_dir=kwargs.get("cache_dir"),
365
+ force_download=kwargs.get("force_download", False),
366
+ resume_download=kwargs.get("resume_download"),
367
+ proxies=kwargs.get("proxies"),
368
+ use_auth_token=kwargs.get("use_auth_token"),
369
+ revision=kwargs.get("revision"),
370
+ local_files_only=kwargs.get("local_files_only", False),
371
+ )
372
+ from safetensors.torch import load_file
373
+
374
+ original_state_dict = load_file(resolved_file)
375
+ print("✓ Loaded state dict from model.safetensors")
376
+ except Exception as e:
377
+ print(f"Could not load state dict from remote files: {e}")
378
+ print("Returning model without loading pretrained weights")
379
+ return model
380
+
381
+ # First, fix any key differences (see openpi `model.py`, `_fix_pytorch_state_dict_keys`)
382
+ fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config)
383
+
384
+ # Then add "model." prefix for all keys that don't already have it
385
+ remapped_state_dict = {}
386
+ remap_count = 0
387
+
388
+ for key, value in fixed_state_dict.items():
389
+ if not key.startswith("model.") and "normalize" not in key:
390
+ new_key = f"model.{key}"
391
+ remapped_state_dict[new_key] = value
392
+ remap_count += 1
393
+ if remap_count <= 10: # Only print first 10 to avoid spam
394
+ print(f"Remapped: {key} -> {new_key}")
395
+ else:
396
+ remapped_state_dict[key] = value
397
+
398
+ if remap_count > 0:
399
+ print(f"Remapped {remap_count} state dict keys")
400
+
401
+ # Load the remapped state dict into the model
402
+ missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=False)
403
+
404
+ if missing_keys:
405
+ print(f"Missing keys when loading state dict: {len(missing_keys)} keys")
406
+ if len(missing_keys) <= 20:
407
+ for key in missing_keys:
408
+ print(f" - {key}")
409
+ else:
410
+ for key in missing_keys[:20]:
411
+ print(f" - {key}")
412
+ print(f" ... and {len(missing_keys) - 20} more")
413
+
414
+ if unexpected_keys:
415
+ print(f"Unexpected keys when loading state dict: {len(unexpected_keys)} keys")
416
+ if len(unexpected_keys) <= 20:
417
+ for key in unexpected_keys:
418
+ print(f" - {key}")
419
+ else:
420
+ for key in unexpected_keys[:20]:
421
+ print(f" - {key}")
422
+ print(f" ... and {len(unexpected_keys) - 20} more")
423
+
424
+ if not missing_keys and not unexpected_keys:
425
+ print("All keys loaded successfully!")
426
+
427
+ except Exception as e:
428
+ print(f"Warning: Could not remap state dict keys: {e}")
429
+
430
+ return model
431
+
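+ # Illustrative usage (hypothetical checkpoint id):
+ # policy = PI05Policy.from_pretrained("your-org/pi05-checkpoint")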
432
+ def _fix_pytorch_state_dict_keys(
433
+ self, state_dict: dict[str, Tensor], model_config: PreTrainedConfig
434
+ ) -> dict[str, Tensor]: # see openpi `BaseModelConfig._fix_pytorch_state_dict_keys`
435
+ """Fix state dict keys to match current model architecture.
436
+
437
+ Args:
438
+ state_dict: The state dictionary to fix.
439
+ model_config: The model configuration.
440
+
441
+ Returns:
442
+ The fixed state dictionary.
443
+ """
444
+ import re
445
+
446
+ fixed_state_dict = {}
447
+
448
+ for key, value in state_dict.items():
449
+ new_key = key
450
+
451
+ # Handle layer norm structure changes: .weight -> .dense.weight + .dense.bias
452
+ # For gemma expert layers
453
+ if re.match(
454
+ r"paligemma_with_expert\.gemma_expert\.model\.layers\.\d+\.(input_layernorm|post_attention_layernorm)\.weight",
455
+ key,
456
+ ):
457
+ # Check if the model actually has adaRMS enabled for the expert
458
+ expert_uses_adarms = getattr(
459
+ self.model.paligemma_with_expert.gemma_expert.config, "use_adarms", False
460
+ )
461
+ if expert_uses_adarms:
462
+ logging.warning(f"Skipping layer norm key (adaRMS mismatch): {key}")
463
+ continue
464
+
465
+ if re.match(r"paligemma_with_expert\.gemma_expert\.model\.norm\.weight", key):
466
+ # Check if the model actually has adaRMS enabled for the expert
467
+ expert_uses_adarms = getattr(
468
+ self.model.paligemma_with_expert.gemma_expert.config, "use_adarms", False
469
+ )
470
+ if expert_uses_adarms:
471
+ logging.warning(f"Skipping norm key (adaRMS mismatch): {key}")
472
+ continue
473
+
474
+ # Handle MLP naming changes for pi05
475
+ # pi05 model expects time_mlp_*, but checkpoint might have action_time_mlp_*
476
+ if key.startswith("action_time_mlp_in."):
477
+ new_key = key.replace("action_time_mlp_in.", "time_mlp_in.")
478
+ elif key.startswith("action_time_mlp_out."):
479
+ new_key = key.replace("action_time_mlp_out.", "time_mlp_out.")
480
+ # Also handle state_proj which shouldn't exist in pi05
481
+ if key.startswith("state_proj."):
482
+ logging.warning(f"Skipping state_proj key in pi05 mode: {key}")
483
+ continue
484
+
485
+ # Handle vision tower embedding layer potential differences
486
+ if "patch_embedding" in key:
487
+ # Some checkpoints might have this, but current model expects different structure
488
+ logging.warning(f"Vision embedding key might need handling: {key}")
489
+
490
+ fixed_state_dict[new_key] = value
491
+
492
+ return fixed_state_dict
493
+
494
+ def get_optim_params(self) -> dict:
495
+ """Returns the parameters to be optimized.
496
+
497
+ Returns:
498
+ A generator over the model parameters.
499
+ """
500
+ return self.parameters()
501
+
502
+ @torch.no_grad()
503
+ def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
504
+ """Predict a chunk of actions given environment observations.
505
+
506
+ Args:
507
+ batch: Batch of data containing environment observations.
508
+
509
+ Returns:
510
+ The predicted action chunk.
511
+
512
+ Raises:
513
+ NotImplementedError: Always, as this method is not implemented for PI05.
514
+ """
515
+ raise NotImplementedError("Currently not implemented for PI05")
516
+
517
+ @torch.no_grad()
518
+ def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
519
+ """Select a single action given environment observations.
520
+
521
+ This method wraps `select_actions` in order to return one action at a time for execution in the
522
+ environment. It works by managing the actions in a queue and only calling `select_actions` when the
523
+ queue is empty.
524
+
525
+ Args:
526
+ batch: Batch of data containing environment observations.
527
+ noise: Optional noise tensor to be used during sampling.
528
+
529
+ Returns:
530
+ The selected action tensor.
531
+ """
532
+ self.eval()
533
+
534
+ # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
535
+ # querying the policy.
536
+ if len(self._action_queue) == 0:
537
+ actions = self.sample_actions(batch, noise=noise)
538
+ self._action_queue.extend(actions)
539
+ return self._action_queue.popleft()
540
+
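+ # Illustrative rollout sketch (the `env` API here is hypothetical):
+ # policy.reset()
+ # obs = env.reset()
+ # action = policy.select_action(obs)  # repeat for each environment step
+ # obs = env.step(action)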
541
+ @torch.no_grad()
542
+ def sample_actions(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
543
+ """Sample actions from the policy given environment observations.
544
+
545
+ Args:
546
+ batch: Batch of data containing environment observations.
547
+ noise: Optional noise tensor.
548
+
549
+ Returns:
550
+ The sampled actions tensor of shape (n_action_steps, batch_size, action_dim).
551
+ """
552
+ batch = self.normalize_inputs(batch)
553
+
554
+ images, img_masks = self.prepare_images(batch)
555
+ lang_tokens, lang_masks = self.prepare_language(batch)
556
+
557
+ actions = self.model.sample_actions(
558
+ images,
559
+ img_masks,
560
+ lang_tokens,
561
+ lang_masks,
562
+ noise=noise,
563
+ )
564
+
565
+ # Unpad actions
566
+ original_action_dim = self.config.action_feature.shape[0]
567
+ actions = actions[:, :, :original_action_dim]
568
+
569
+ actions = self.unnormalize_outputs({"actions": actions})["actions"]
570
+
571
+ # `self.model.sample_actions` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
572
+ # effectively has shape (n_action_steps, batch_size, *), hence the transpose.
573
+ actions = actions.transpose(0, 1)
574
+ return actions
575
+
576
+ def forward(
577
+ self, batch: dict[str, Tensor], noise: Tensor | None = None, time: Tensor | None = None
578
+ ) -> dict[str, Tensor]:
579
+ """Do a full training forward pass to compute the loss.
580
+
581
+ Args:
582
+ batch: Batch of data containing environment observations, actions, and targets.
583
+ noise: Optional noise tensor.
584
+ time: Optional time tensor.
585
+
586
+ Returns:
587
+ A dictionary containing the loss components ("MSE" and "CE").
588
+ """
589
+ batch = self.normalize_inputs(batch)
590
+ batch["discrete_actions"] = self.normalize_actions(dict(batch))["actions"]
591
+ batch = self.normalize_targets(batch)
592
+
593
+ images, img_masks = self.prepare_images(
594
+ batch
595
+ ) # in img_masks we have True for real images and False for padded images
596
+ lang_tokens, lang_masks = self.prepare_language(
597
+ batch
598
+ ) # in lang_masks we have True for real tokens and False for padded tokens
599
+ discrete_actions, discrete_action_masks = self.prepare_discrete_actions(
600
+ batch
601
+ ) # in discrete_action_masks we have True for real tokens and False for padded tokens
602
+ actions = batch["actions"]
603
+ actions_is_pad = batch.get(
604
+ "action_is_pad"
605
+ ) # in actions_is_pad we have False for real actions and True for padded actions
606
+
607
+ losses = self.model.forward(
608
+ images,
609
+ img_masks,
610
+ lang_tokens,
611
+ lang_masks,
612
+ actions,
613
+ noise,
614
+ time,
615
+ discrete_actions,
616
+ discrete_action_masks,
617
+ )
618
+
619
+ mse_loss = losses["MSE"]
620
+ ce_loss = losses["CE"]
621
+ if actions_is_pad is not None:
622
+ in_episode_bound = ~actions_is_pad
623
+ mse_loss = mse_loss * in_episode_bound.unsqueeze(-1)
624
+
625
+ # Remove padding
626
+ mse_loss = mse_loss[:, :, : self.config.max_action_dim]
627
+
628
+ # For backward pass
629
+ loss = mse_loss.mean()
630
+
631
+ return {"MSE": loss, "CE": ce_loss}
632
+
633
+ def prepare_discrete_state(self, batch: dict[str, Tensor]) -> list[str]:
634
+ """Discretizes the state into bins and converts it to a string representation.
635
+
636
+ Each dimension of the state vector is discretized into 256 bins.
637
+ The values of each dimension of the state are expected to be in the range [-1, 1].
638
+ The discretization bins are linearly spaced between -1 and 1.
639
+ The index of the bin for each dimension is then concatenated into a space-separated string.
640
+
641
+ Args:
642
+ batch: Batch of data containing the "state" tensor.
643
+
644
+ Returns:
645
+ A list of strings, where each string is a space-separated list of discretized state values.
646
+
647
+ Note:
648
+ If the state values are not within [-1, 1], a warning is logged and the values are clipped.
649
+ """
650
+ state = batch["state"]
651
+ state_np = state.to(device="cpu", dtype=torch.float32).numpy()
652
+ if np.any(state_np < -1.0) or np.any(state_np > 1.0):
653
+ logging.warning(
654
+ f"State values are not normalized between -1 and 1. Min: {state_np.min()}, Max: {state_np.max()}"
655
+ )
656
+ state_np = np.clip(state_np, -1.0, 1.0)
657
+ discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
658
+ return [
659
+ " ".join(map(str, row)) for row in discretized_states
660
+ ] # TODO: return a tensor instead of a list of strings?
661
+
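+ # Illustrative example: a 3-dim state of [-1.0, 0.0, 1.0] falls into bins
+ # 0, 128 and 255 respectively, so this method would return ["0 128 255"].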
662
+ def prepare_discrete_actions(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
663
+ """Prepares discrete actions for the model by tokenizing and padding them.
664
+
665
+ Args:
666
+ batch: Batch of data containing the key "discrete_actions".
667
+
668
+ Returns:
669
+ A tuple containing:
670
+ - discrete_action_tokens: A tensor of shape (batch_size, max_length) containing the tokenized actions.
671
+ - discrete_action_masks: A tensor of shape (batch_size, max_length) indicating valid tokens.
672
+ """
673
+ device = batch["discrete_actions"].device
674
+ discrete_actions = batch["discrete_actions"].to(device="cpu", dtype=torch.float32)
675
+ tokens = self.discrete_action_processor.__call__(discrete_actions)
676
+ discrete_action_tokens, discrete_action_masks = pad_discrete_tokens(
677
+ tokens, self.config.discrete_action_max_length
678
+ )
679
+ return torch.from_numpy(discrete_action_tokens).to(device=device, dtype=torch.long), torch.from_numpy(
680
+ discrete_action_masks
681
+ ).to(device=device, dtype=torch.bool)
682
+
683
+ def prepare_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]:
684
+ """Apply preprocessing to the images.
685
+
686
+ Resizes to 224x224 with padding to preserve the aspect ratio, and converts the pixel range
687
+ from [0.0, 1.0] to [-1.0, 1.0] as expected by SigLIP.
688
+
689
+ Args:
690
+ batch: Batch of data containing image tensors.
691
+
692
+ Returns:
693
+ A tuple containing:
694
+ - images: A list of processed image tensors.
695
+ - img_masks: A list of image mask tensors.
696
+
697
+ Raises:
698
+ ValueError: If no image features are present in the batch.
699
+ """
700
+ images = []
701
+ img_masks = []
702
+
703
+ present_img_keys = [key for key in self.config.image_features if key in batch]
704
+ missing_img_keys = [key for key in self.config.image_features if key not in batch]
705
+
706
+ if len(present_img_keys) == 0:
707
+ raise ValueError(
708
+ f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})"
709
+ )
710
+
711
+ # Preprocess image features present in the batch
712
+ for key in present_img_keys:
713
+ img = batch[key]
714
+
715
+ if self.config.resize_imgs_with_padding is not None:
716
+ img = resize_with_pad(img, *self.config.resize_imgs_with_padding, pad_value=0)
717
+
718
+ # Normalize from range [0,1] to [-1,1] as expected by siglip
719
+ img = img * 2.0 - 1.0
720
+
721
+ bsize = img.shape[0]
722
+ device = img.device
723
+ mask = torch.ones(bsize, dtype=torch.bool, device=device)
724
+ images.append(img)
725
+ img_masks.append(mask)
726
+
727
+ # Create image features not present in the batch
728
+ # as fully zero (black) images, i.e. -1 after the [-1, 1] normalization above.
729
+ for num_empty_cameras in range(len(missing_img_keys)):
730
+ if num_empty_cameras >= self.config.empty_cameras:
731
+ break
732
+ img = torch.ones_like(img) * -1
733
+ mask = torch.zeros_like(mask)
734
+ images.append(img)
735
+ img_masks.append(mask)
736
+
737
+ return images, img_masks
738
+
739
+ def prepare_language(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
740
+ """Tokenize the text input.
741
+
742
+ The state is discretized into a space-separated string and appended to the prompt.
743
+
744
+ Args:
745
+ batch: Batch of data containing the keys "prompt" and "state".
746
+
747
+ Returns:
748
+ A tuple containing:
749
+ - lang_tokens: Tensor of language tokens.
750
+ - lang_masks: Tensor of language attention masks.
751
+ """
752
+ device = batch["state"].device
753
+ tasks = batch["prompt"]
754
+
755
+ # PaliGemma prompts have to end with a newline
756
+ tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
757
+
758
+ # add state to the prompt
759
+ state = self.prepare_discrete_state(batch)
760
+ prompt = [f"Task: {task}State: {state}\nActions:" for task, state in zip(tasks, state, strict=False)]
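+ # e.g. "Task: pick up the cube\nState: 0 128 255\nActions:" (illustrative task string)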
761
+
762
+ tokenized_prompt = self.language_tokenizer.__call__(
763
+ prompt,
764
+ padding="max_length",
765
+ padding_side="right",
766
+ max_length=self.config.tokenizer_max_length,
767
+ return_tensors="pt",
768
+ truncation=True,
769
+ )
770
+ lang_tokens = tokenized_prompt["input_ids"].to(device=device)
771
+ lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool)
772
+
773
+ return lang_tokens, lang_masks
774
+
775
+
776
+ class PI05FlowMatching(nn.Module):
777
+ """
778
+ π05: A Vision-Language-Action Flow Model for General Robot Control
779
+
780
+ [Paper](https://www.physicalintelligence.company/download/pi05.pdf)
781
+
782
+ ┌──────────────────────────────────────────┐
783
+ │ actions │
784
+ │ ▲ │
785
+ │ ┌┴─────┐ │
786
+ │ kv cache │Gemma │ │
787
+ │ ┌──────────►│Expert│ │
788
+ │ │ │ │ │
789
+ │ ┌┴─────────┐ │x 10 │ │
790
+ │ │ │ └▲─────┘ │
791
+ │ │PaliGemma │ │ │
792
+ │ │ │ noise │
793
+ │ └▲──▲──▲──▲ │
794
+ │ │ │ │ └── discrete actions │
795
+ │ │ │ └───── robot state │
796
+ │ │ └──────── language tokens │
797
+ │ └─────────── image(s) │
798
+ └──────────────────────────────────────────┘
799
+ """
800
+
801
+ def __init__(self, config: PI05Config, discrete_action_vocab_size: int | None = None):
802
+ """Initializes the PI05FlowMatching model.
803
+
804
+ Args:
805
+ config: Model configuration.
806
+ discrete_action_vocab_size: Size of the discrete action vocabulary.
807
+ """
808
+ super().__init__()
809
+ self.config = config
810
+
811
+ load_pretrained_paligemma = (
812
+ self.config.init_strategy == "expert_only_he_init"
813
+ ) # only load pretrained paligemma if we are He-initializing the expert only
814
+ paligemma_with_expert_config = PaliGemmaWithExpertConfig(
815
+ freeze_vision_encoder=self.config.freeze_vision_encoder,
816
+ train_expert_only=self.config.train_expert_only,
817
+ attention_implementation=self.config.attention_implementation,
818
+ load_pretrained_paligemma=load_pretrained_paligemma,
819
+ discrete_action_vocab_size=discrete_action_vocab_size,
820
+ dropout=self.config.dropout,
821
+ )
822
+ self.paligemma_with_expert = PaliGemmaWithExpertModel(paligemma_with_expert_config)
823
+
824
+ # Projections are float32
825
+ self.action_in_proj = nn.Linear(self.config.max_action_dim, self.config.proj_width)
826
+ self.action_out_proj = nn.Linear(self.config.proj_width, self.config.max_action_dim)
827
+
828
+ self.time_mlp_in = nn.Linear(self.config.proj_width, self.config.proj_width)
829
+ self.time_mlp_out = nn.Linear(self.config.proj_width, self.config.proj_width)
830
+
831
+ self._init_model()
832
+
833
+ def _init_weights(self, module: nn.Module) -> None:
834
+ """Initialize weights using He (Kaiming) initialization.
835
+
836
+ Args:
837
+ module: The module to initialize.
838
+ """
839
+ if isinstance(module, nn.Linear):
840
+ nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
841
+ if module.bias is not None:
842
+ nn.init.zeros_(module.bias)
843
+ elif isinstance(module, nn.LayerNorm):
844
+ nn.init.ones_(module.weight)
845
+ nn.init.zeros_(module.bias)
846
+
847
+ def _init_model(self) -> None:
848
+ """Initialize the model weights based on the configuration."""
849
+ if self.config.init_strategy == "no_init":
850
+ return
851
+ elif self.config.init_strategy == "full_he_init":
852
+ for m in self.modules():
853
+ self._init_weights(m)
854
+ elif self.config.init_strategy == "expert_only_he_init":
855
+ for m in self.paligemma_with_expert.gemma_expert.modules():
856
+ self._init_weights(m)
857
+ else:
858
+ raise ValueError(f"Invalid init strategy: {self.config.init_strategy}")
859
+
860
+ def sample_noise(self, shape: tuple[int, ...], device: torch.device | str) -> Tensor:
861
+ """Samples Gaussian noise.
862
+
863
+ Args:
864
+ shape: The shape of the noise tensor.
865
+ device: The device to create the tensor on.
866
+
867
+ Returns:
868
+ A tensor containing the sampled noise.
869
+ """
870
+ noise = torch.normal(
871
+ mean=0.0,
872
+ std=1.0,
873
+ size=shape,
874
+ dtype=torch.float32,
875
+ device=device,
876
+ )
877
+ return noise
878
+
879
+ def sample_time(self, bsize: int, device: torch.device | str) -> Tensor:
880
+ """Samples time steps from a Beta distribution.
881
+
882
+ Args:
883
+ bsize: Batch size.
884
+ device: The device to create the tensor on.
885
+
886
+ Returns:
887
+ A tensor containing the sampled time steps.
888
+ """
889
+ beta_dist = torch.distributions.Beta(concentration1=1.5, concentration0=1.0)
890
+ time_beta = beta_dist.sample((bsize,)).to(device=device, dtype=torch.float32)
891
+ time = time_beta * 0.999 + 0.001
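+ # The affine map above keeps t in [0.001, 1.0], so training never sits exactly at
+ # t = 0 (clean actions); Beta(1.5, 1.0) skews samples toward t = 1 (pure noise).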
892
+ return time
893
+
894
+ def embed_prefix(
895
+ self,
896
+ images: list[Tensor],
897
+ img_masks: list[Tensor],
898
+ lang_tokens: Tensor,
899
+ lang_masks: Tensor,
900
+ discrete_actions: Tensor | None = None,
901
+ discrete_action_masks: Tensor | None = None,
902
+ ) -> tuple[Tensor, Tensor, Tensor]:
903
+ """Embed images with SigLIP and language tokens with embedding layer to prepare
904
+ for PaliGemma transformer processing.
905
+
906
+ Args:
907
+ images: List of image tensors.
908
+ img_masks: List of image mask tensors.
909
+ lang_tokens: Language token tensor.
910
+ lang_masks: Language mask tensor.
911
+ discrete_actions: Optional discrete action tensor.
912
+ discrete_action_masks: Optional discrete action mask tensor.
913
+
914
+ Returns:
915
+ A tuple containing:
916
+ - embs: Concatenated embeddings tensor.
917
+ - pad_masks: Concatenated padding masks tensor.
918
+ - att_masks: Attention masks tensor.
919
+ """
920
+ # TODO: avoid list in python and torch.cat ; prefer pre-allocation with torch.empty
921
+ embs = []
922
+ pad_masks = []
923
+ att_masks = []
924
+
925
+ # TODO: remove for loop
926
+ for (
927
+ img,
928
+ img_mask,
929
+ ) in zip(images, img_masks, strict=False):
930
+ img_emb = self.paligemma_with_expert.embed_image(img)
931
+ img_emb = img_emb.to(dtype=torch.bfloat16)
932
+
933
+ # image embeddings don't need to be unnormalized because `fix/lerobot_openpi` branch of huggingface
934
+ # already removed the normalization inside PaliGemma
935
+ pass
936
+
937
+ bsize, num_img_embs = img_emb.shape[:2]
938
+ img_mask = img_mask[:, None].expand(bsize, num_img_embs)
939
+
940
+ embs.append(img_emb)
941
+ pad_masks.append(img_mask)
942
+
943
+ # Create attention masks so that image tokens attend to each other
944
+ att_masks += [0] * num_img_embs
945
+
946
+ lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)
947
+
948
+ # Scale language embeddings by sqrt(embedding dim)
949
+ lang_emb_dim = lang_emb.shape[-1]
950
+ lang_emb = lang_emb * math.sqrt(lang_emb_dim)
951
+
952
+ embs.append(lang_emb)
953
+ pad_masks.append(lang_masks)
954
+
955
+ # full attention between image and language inputs
956
+ num_lang_embs = lang_emb.shape[1]
957
+ att_masks += [0] * num_lang_embs
958
+
959
+ if discrete_actions is not None:
960
+ discrete_action_emb = self.paligemma_with_expert.embed_discrete_actions(discrete_actions)
961
+ embs.append(discrete_action_emb.to(dtype=torch.bfloat16))
962
+ pad_masks.append(discrete_action_masks)
963
+ att_masks += [1] * discrete_action_emb.shape[1]
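+ # Discrete action tokens are causally masked (att_mask 1 each): they attend to
+ # the image/language prefix and to earlier discrete action tokens, while prefix
+ # tokens cannot attend back to them.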
964
+
965
+ embs = torch.cat(embs, dim=1)
966
+ pad_masks = torch.cat(pad_masks, dim=1)
967
+ att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
968
+ att_masks = att_masks[None, :].expand(bsize, len(att_masks))
969
+
970
+ return embs, pad_masks, att_masks
971
+
972
+ def embed_suffix(self, noisy_actions: Tensor, timestep: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]:
973
+ """Embed noisy_actions, timestep to prepare for Expert Gemma processing.
974
+
975
+ Args:
976
+ noisy_actions: Tensor containing noisy actions.
977
+ timestep: Tensor containing timesteps.
978
+
979
+ Returns:
980
+ A tuple containing:
981
+ - embs: Concatenated embeddings tensor.
982
+ - pad_masks: Concatenated padding masks tensor.
983
+ - att_masks: Attention masks tensor.
984
+ - adarms_cond: AdaRMS conditioning tensor.
985
+ """
986
+ embs = []
987
+ pad_masks = []
988
+ att_masks = []
989
+
990
+ bsize = noisy_actions.shape[0]
991
+ dtype = torch.bfloat16
992
+ device = noisy_actions.device
993
+
994
+ # Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
995
+ time_emb = create_sinusoidal_pos_embedding(
996
+ timestep, self.config.proj_width, min_period=4e-3, max_period=4.0, device=device
997
+ )
998
+
999
+ # Fuse timestep + action information using an MLP
1000
+ noisy_actions = noisy_actions.to(dtype=dtype)
1001
+ action_emb = self.action_in_proj(noisy_actions)
1002
+
1003
+ def time_mlp_func(time_emb):
1004
+ x = self.time_mlp_in(time_emb)
1005
+ x = F.silu(x)
1006
+ x = self.time_mlp_out(x)
1007
+ return F.silu(x)
1008
+
1009
+ time_emb = time_emb.to(dtype=dtype)
1010
+ adarms_cond = time_mlp_func(time_emb)
1011
+
1012
+ # Add to input tokens
1013
+ embs.append(action_emb)
1014
+
1015
+ bsize, action_dim = action_emb.shape[:2]
1016
+ action_mask = torch.ones(bsize, action_dim, dtype=torch.bool, device=device)
1017
+ pad_masks.append(action_mask)
1018
+
1019
+ # Set attention masks so that image, language and state inputs do not attend to action tokens
1020
+ att_masks += [1] + ([0] * (self.config.n_action_steps - 1))
1021
+
1022
+ embs = torch.cat(embs, dim=1)
1023
+ pad_masks = torch.cat(pad_masks, dim=1)
1024
+ att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device)
1025
+ att_masks = att_masks[None, :].expand(bsize, len(att_masks))
1026
+
1027
+ return embs, pad_masks, att_masks, adarms_cond
1028
+
1029
+ def forward(
1030
+ self,
1031
+ images: list[Tensor],
1032
+ img_masks: list[Tensor],
1033
+ lang_tokens: Tensor,
1034
+ lang_masks: Tensor,
1035
+ actions: Tensor,
1036
+ noise: Tensor | None = None,
1037
+ time: Tensor | None = None,
1038
+ discrete_actions: Tensor | None = None,
1039
+ discrete_action_masks: Tensor | None = None,
1040
+ ) -> dict[str, Tensor]:
1041
+ """Do a full training forward pass and compute the loss.
1042
+
1043
+ Args:
1044
+ images: List of image tensors.
1045
+ img_masks: List of image mask tensors.
1046
+ lang_tokens: Language token tensor.
1047
+ lang_masks: Language mask tensor.
1048
+ actions: Action tensor.
1049
+ noise: Optional noise tensor.
1050
+ time: Optional time tensor.
1051
+ discrete_actions: Optional discrete action tensor.
1052
+ discrete_action_masks: Optional discrete action mask tensor.
1053
+
1054
+ Returns:
1055
+ A dictionary containing the loss components ("MSE" and "CE").
1056
+ """
1057
+ # Run VLM first to get key value cache
1058
+ prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
1059
+ images, img_masks, lang_tokens, lang_masks, discrete_actions, discrete_action_masks
1060
+ )
1061
+
1062
+ vlm_2d_attention_mask = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
1063
+ vlm_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
1064
+
1065
+ num_cross_att_tokens = prefix_embs.shape[1] - self.config.discrete_action_max_length
1066
+
1067
+ (prefix_out, _), past_key_values = self.paligemma_with_expert.forward(
1068
+ attention_mask=vlm_2d_attention_mask,
1069
+ position_ids=vlm_position_ids,
1070
+ past_key_values=None,
1071
+ inputs_embeds=[prefix_embs, None],
1072
+ n_cross_att_tokens=num_cross_att_tokens,
1073
+ use_cache=True,
1074
+ fill_kv_cache=True,
1075
+ )
1076
+
1077
+ # Now run action expert
1078
+ if noise is None:
1079
+ noise = self.sample_noise(actions.shape, actions.device)
1080
+
1081
+ if time is None:
1082
+ time = self.sample_time(actions.shape[0], actions.device)
1083
+
1084
+ time_expanded = time[:, None, None]
1085
+ x_t = time_expanded * noise + (1 - time_expanded) * actions
1086
+ u_t = noise - actions
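+ # Flow matching: x_t interpolates linearly between clean actions (t = 0) and
+ # pure noise (t = 1); the target u_t = noise - actions is the constant velocity
+ # of that path, which the action expert is trained to predict.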
1087
+
1088
+ suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time)
1089
+
1090
+ action_expert_2d_attention_mask = make_att_2d_masks(
1091
+ suffix_pad_masks,
1092
+ suffix_att_masks,
1093
+ n_cross_att_tokens=num_cross_att_tokens,
1094
+ cross_att_pad_masks=prefix_pad_masks[:, :num_cross_att_tokens],
1095
+ )
1096
+ # We should skip the response tokens when numbering the position ids for the action expert
1097
+ prefix_offsets = torch.sum(prefix_pad_masks[:, : -self.config.discrete_action_max_length], dim=-1)[
1098
+ :, None
1099
+ ] # action expert position ids start after prefix
1100
+ action_expert_position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
1101
+
1102
+ # stop gradient to avoid backpropagating from action expert to VLM
1103
+ for layer_idx in past_key_values:
1104
+ past_key_values[layer_idx]["key_states"] = past_key_values[layer_idx]["key_states"].detach()
1105
+ past_key_values[layer_idx]["value_states"] = past_key_values[layer_idx]["value_states"].detach()
1106
+
1107
+ (_, suffix_out), _ = self.paligemma_with_expert.forward(
1108
+ attention_mask=action_expert_2d_attention_mask,
1109
+ position_ids=action_expert_position_ids,
1110
+ past_key_values=past_key_values,
1111
+ inputs_embeds=[None, suffix_embs],
1112
+ use_cache=True,
1113
+ fill_kv_cache=False,
1114
+ adarms_cond=[None, adarms_cond],
1115
+ )
1116
+
1117
+ # compute mse loss for velocity
1118
+ suffix_out = suffix_out[:, -self.config.n_action_steps :]
1119
+ # Original openpi code, upcast attention output
1120
+ v_t = self.action_out_proj(suffix_out)
1121
+ v_t = v_t.to(dtype=torch.float32)
1122
+
1123
+ losses = F.mse_loss(u_t, v_t, reduction="none")
1124
+
1125
+ # compute cross entropy loss for discrete actions
1126
+ batch_size, seq_len = discrete_actions.shape
1127
+ discrete_action_out = prefix_out[:, -self.config.discrete_action_max_length - 1 : -1]
1128
+ logits = self.paligemma_with_expert.da_head(discrete_action_out)
1129
+
1130
+ logits = logits.to(dtype=torch.float32) # upcast to float32 for loss calculation
1131
+ logits = rearrange(logits, "b s d -> (b s) d")
1132
+ labels = rearrange(discrete_actions, "b s -> (b s)")
1133
+ ce_loss = F.cross_entropy(logits, labels, reduction="none")
1134
+
1135
+ ce_loss = rearrange(ce_loss, "(b s) -> b s", b=batch_size, s=seq_len)
1136
+
1137
+ # remove pad tokens
1138
+ discrete_action_is_pad = ~discrete_action_masks # convert into format where value for pad is True
1139
+ ce_loss = ce_loss * ~discrete_action_is_pad
1140
+
1141
+ # compute mean
1142
+ ce_loss = ce_loss.mean()
1143
+
1144
+ return {"MSE": losses, "CE": ce_loss}
1145
+
1146
+ def sample_actions(
1147
+ self,
1148
+ images: list[Tensor],
1149
+ img_masks: list[Tensor],
1150
+ lang_tokens: Tensor,
1151
+ lang_masks: Tensor,
1152
+ noise: Tensor | None = None,
1153
+ ) -> Tensor:
1154
+ """Do a full inference forward and compute the action.
1155
+
1156
+ Args:
1157
+ images: List of image tensors.
1158
+ img_masks: List of image mask tensors.
1159
+ lang_tokens: Language token tensor.
1160
+ lang_masks: Language mask tensor.
1161
+ noise: Optional noise tensor.
1162
+
1163
+ Returns:
1164
+ The sampled action tensor.
1165
+ """
1166
+ bsize = lang_tokens.shape[0]
1167
+ device = lang_tokens.device
1168
+
1169
+ if noise is None:
1170
+ actions_shape = (bsize, self.config.n_action_steps, self.config.max_action_dim)
1171
+ noise = self.sample_noise(actions_shape, device)
1172
+
1173
+ prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
1174
+ images, img_masks, lang_tokens, lang_masks
1175
+ )
1176
+ prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
1177
+ prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
1178
+
1179
+ num_cross_att_tokens = prefix_embs.shape[1]
1180
+
1181
+ # Compute image and language key value cache
1182
+ _, past_key_values = self.paligemma_with_expert.forward(
1183
+ attention_mask=prefix_att_2d_masks,
1184
+ position_ids=prefix_position_ids,
1185
+ past_key_values=None,
1186
+ inputs_embeds=[prefix_embs, None],
1187
+ n_cross_att_tokens=num_cross_att_tokens,
1188
+ use_cache=self.config.use_cache,
1189
+ fill_kv_cache=True,
1190
+ )
1191
+
1192
+ dt = -1.0 / self.config.num_steps
1193
+ dt = torch.tensor(dt, dtype=torch.float32, device=device)
1194
+
1195
+ x_t = noise
1196
+ time = torch.tensor(1.0, dtype=torch.float32, device=device)
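+ # Integrate the learned velocity field from t = 1 (pure noise) down to t = 0
+ # (clean actions) in num_steps Euler steps of size |dt|; the loop bound
+ # tolerates floating-point drift in `time`.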
1197
+ while time >= -dt / 2:
1198
+ expanded_time = time.expand(bsize)
1199
+ v_t = self.denoise_step(
1200
+ prefix_pad_masks,
1201
+ past_key_values,
1202
+ x_t,
1203
+ expanded_time,
1204
+ )
1205
+
1206
+ # Euler step
1207
+ x_t += dt * v_t
1208
+ time += dt
1209
+ return x_t
1210
+
1211
+ def denoise_step(
1212
+ self,
1213
+ prefix_pad_masks: Tensor,
1214
+ past_key_values: list[dict[str, Tensor]],
1215
+ x_t: Tensor,
1216
+ timestep: Tensor,
1217
+ ) -> Tensor:
1218
+ """Apply one denoising step of the noise `x_t` at a given timestep.
1219
+
1220
+ Args:
1221
+ prefix_pad_masks: Prefix padding masks.
1222
+ past_key_values: Past key values from the VLM.
1223
+ x_t: Current noise tensor.
1224
+ timestep: Current timestep.
1225
+
1226
+ Returns:
1227
+ The predicted velocity tensor (v_t).
1228
+ """
1229
+ suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, timestep)
1230
+
1231
+ num_cross_att_tokens = prefix_pad_masks.shape[1]
1232
+ action_expert_2d_attention_mask = make_att_2d_masks(
1233
+ suffix_pad_masks,
1234
+ suffix_att_masks,
1235
+ n_cross_att_tokens=num_cross_att_tokens,
1236
+ cross_att_pad_masks=prefix_pad_masks[:, :num_cross_att_tokens],
1237
+ )
1238
+ # We should skip the response tokens when numbering the position ids for the action expert
1239
+ prefix_offsets = torch.sum(prefix_pad_masks[:, : -self.config.discrete_action_max_length], dim=-1)[
1240
+ :, None
1241
+ ] # action expert position ids start after prefix
1242
+ action_expert_position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
1243
+
1244
+ outputs_embeds, _ = self.paligemma_with_expert.forward(
1245
+ attention_mask=action_expert_2d_attention_mask,
1246
+ position_ids=action_expert_position_ids,
1247
+ past_key_values=past_key_values,
1248
+ inputs_embeds=[None, suffix_embs],
1249
+ use_cache=True,
1250
+ fill_kv_cache=False,
1251
+ adarms_cond=[None, adarms_cond],
1252
+ )
1253
+ suffix_out = outputs_embeds[1]
1254
+ suffix_out = suffix_out[:, -self.config.n_action_steps :]
1255
+ v_t = self.action_out_proj(suffix_out)
1256
+ v_t = v_t.to(dtype=torch.float32)
1257
+ return v_t