opentau 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. opentau/__init__.py +179 -0
  2. opentau/__version__.py +24 -0
  3. opentau/configs/__init__.py +19 -0
  4. opentau/configs/default.py +297 -0
  5. opentau/configs/libero.py +113 -0
  6. opentau/configs/parser.py +393 -0
  7. opentau/configs/policies.py +297 -0
  8. opentau/configs/reward.py +42 -0
  9. opentau/configs/train.py +370 -0
  10. opentau/configs/types.py +76 -0
  11. opentau/constants.py +52 -0
  12. opentau/datasets/__init__.py +84 -0
  13. opentau/datasets/backward_compatibility.py +78 -0
  14. opentau/datasets/compute_stats.py +333 -0
  15. opentau/datasets/dataset_mixture.py +460 -0
  16. opentau/datasets/factory.py +232 -0
  17. opentau/datasets/grounding/__init__.py +67 -0
  18. opentau/datasets/grounding/base.py +154 -0
  19. opentau/datasets/grounding/clevr.py +110 -0
  20. opentau/datasets/grounding/cocoqa.py +130 -0
  21. opentau/datasets/grounding/dummy.py +101 -0
  22. opentau/datasets/grounding/pixmo.py +177 -0
  23. opentau/datasets/grounding/vsr.py +141 -0
  24. opentau/datasets/image_writer.py +304 -0
  25. opentau/datasets/lerobot_dataset.py +1910 -0
  26. opentau/datasets/online_buffer.py +442 -0
  27. opentau/datasets/push_dataset_to_hub/utils.py +132 -0
  28. opentau/datasets/sampler.py +99 -0
  29. opentau/datasets/standard_data_format_mapping.py +278 -0
  30. opentau/datasets/transforms.py +330 -0
  31. opentau/datasets/utils.py +1243 -0
  32. opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +887 -0
  33. opentau/datasets/v2/convert_dataset_v1_to_v2.py +829 -0
  34. opentau/datasets/v21/_remove_language_instruction.py +109 -0
  35. opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +60 -0
  36. opentau/datasets/v21/convert_dataset_v20_to_v21.py +183 -0
  37. opentau/datasets/v21/convert_stats.py +150 -0
  38. opentau/datasets/video_utils.py +597 -0
  39. opentau/envs/__init__.py +18 -0
  40. opentau/envs/configs.py +178 -0
  41. opentau/envs/factory.py +99 -0
  42. opentau/envs/libero.py +439 -0
  43. opentau/envs/utils.py +204 -0
  44. opentau/optim/__init__.py +16 -0
  45. opentau/optim/factory.py +43 -0
  46. opentau/optim/optimizers.py +121 -0
  47. opentau/optim/schedulers.py +140 -0
  48. opentau/planner/__init__.py +82 -0
  49. opentau/planner/high_level_planner.py +366 -0
  50. opentau/planner/utils/memory.py +64 -0
  51. opentau/planner/utils/utils.py +65 -0
  52. opentau/policies/__init__.py +24 -0
  53. opentau/policies/factory.py +172 -0
  54. opentau/policies/normalize.py +315 -0
  55. opentau/policies/pi0/__init__.py +19 -0
  56. opentau/policies/pi0/configuration_pi0.py +250 -0
  57. opentau/policies/pi0/modeling_pi0.py +994 -0
  58. opentau/policies/pi0/paligemma_with_expert.py +516 -0
  59. opentau/policies/pi05/__init__.py +20 -0
  60. opentau/policies/pi05/configuration_pi05.py +231 -0
  61. opentau/policies/pi05/modeling_pi05.py +1257 -0
  62. opentau/policies/pi05/paligemma_with_expert.py +572 -0
  63. opentau/policies/pretrained.py +315 -0
  64. opentau/policies/utils.py +123 -0
  65. opentau/policies/value/__init__.py +18 -0
  66. opentau/policies/value/configuration_value.py +170 -0
  67. opentau/policies/value/modeling_value.py +512 -0
  68. opentau/policies/value/reward.py +87 -0
  69. opentau/policies/value/siglip_gemma.py +221 -0
  70. opentau/scripts/actions_mse_loss.py +89 -0
  71. opentau/scripts/bin_to_safetensors.py +116 -0
  72. opentau/scripts/compute_max_token_length.py +111 -0
  73. opentau/scripts/display_sys_info.py +90 -0
  74. opentau/scripts/download_libero_benchmarks.py +54 -0
  75. opentau/scripts/eval.py +877 -0
  76. opentau/scripts/export_to_onnx.py +180 -0
  77. opentau/scripts/fake_tensor_training.py +87 -0
  78. opentau/scripts/get_advantage_and_percentiles.py +220 -0
  79. opentau/scripts/high_level_planner_inference.py +114 -0
  80. opentau/scripts/inference.py +70 -0
  81. opentau/scripts/launch_train.py +63 -0
  82. opentau/scripts/libero_simulation_parallel.py +356 -0
  83. opentau/scripts/libero_simulation_sequential.py +122 -0
  84. opentau/scripts/nav_high_level_planner_inference.py +61 -0
  85. opentau/scripts/train.py +379 -0
  86. opentau/scripts/visualize_dataset.py +294 -0
  87. opentau/scripts/visualize_dataset_html.py +507 -0
  88. opentau/scripts/zero_to_fp32.py +760 -0
  89. opentau/utils/__init__.py +20 -0
  90. opentau/utils/accelerate_utils.py +79 -0
  91. opentau/utils/benchmark.py +98 -0
  92. opentau/utils/fake_tensor.py +81 -0
  93. opentau/utils/hub.py +209 -0
  94. opentau/utils/import_utils.py +79 -0
  95. opentau/utils/io_utils.py +137 -0
  96. opentau/utils/libero.py +214 -0
  97. opentau/utils/libero_dataset_recorder.py +460 -0
  98. opentau/utils/logging_utils.py +180 -0
  99. opentau/utils/monkey_patch.py +278 -0
  100. opentau/utils/random_utils.py +244 -0
  101. opentau/utils/train_utils.py +198 -0
  102. opentau/utils/utils.py +471 -0
  103. opentau-0.1.0.dist-info/METADATA +161 -0
  104. opentau-0.1.0.dist-info/RECORD +108 -0
  105. opentau-0.1.0.dist-info/WHEEL +5 -0
  106. opentau-0.1.0.dist-info/entry_points.txt +2 -0
  107. opentau-0.1.0.dist-info/licenses/LICENSE +508 -0
  108. opentau-0.1.0.dist-info/top_level.txt +1 -0
opentau/policies/pi0/modeling_pi0.py
@@ -0,0 +1,994 @@
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
4
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """π0: A Vision-Language-Action Flow Model for General Robot Control
19
+
20
+ [Paper](https://www.physicalintelligence.company/download/pi0.pdf)
21
+ """
22
+
23
+ import math
24
+ from collections import deque
25
+
26
+ import torch
27
+ import torch.nn.functional as F # noqa: N812
28
+ from torch import Tensor, nn
29
+ from transformers import AutoTokenizer
30
+
31
+ from opentau.policies.normalize import Normalize, Unnormalize
32
+ from opentau.policies.pi0.configuration_pi0 import PI0Config
33
+ from opentau.policies.pi0.paligemma_with_expert import (
34
+ PaliGemmaWithExpertConfig,
35
+ PaliGemmaWithExpertModel,
36
+ )
37
+ from opentau.policies.pretrained import PreTrainedPolicy
38
+ from opentau.policies.utils import log_model_loading_keys
39
+ from opentau.utils.utils import get_safe_dtype
40
+
41
+
42
+ def create_sinusoidal_pos_embedding(
43
+ time: Tensor, dimension: int, min_period: float, max_period: float, device: torch.device | str = "cpu"
44
+ ) -> Tensor:
45
+ """Computes sine-cosine positional embedding vectors for scalar positions.
46
+
47
+ Args:
48
+ time: A 1-D tensor of shape (batch_size,).
49
+ dimension: The dimension of the embedding vectors. Must be divisible by 2.
50
+ min_period: The minimum period of the sinusoidal functions.
51
+ max_period: The maximum period of the sinusoidal functions.
52
+ device: The device to create the tensors on. Defaults to "cpu".
53
+
54
+ Returns:
55
+ A tensor of shape (batch_size, dimension) containing the positional embeddings.
56
+
57
+ Raises:
58
+ ValueError: If dimension is not divisible by 2 or if time tensor is not 1-D.
59
+ """
60
+ if dimension % 2 != 0:
61
+ raise ValueError(f"dimension ({dimension}) must be divisible by 2")
62
+
63
+ if time.ndim != 1:
64
+ raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
65
+
66
+ dtype = (
67
+ get_safe_dtype(torch.float64, device.type)
68
+ if isinstance(device, torch.device)
69
+ else get_safe_dtype(torch.float64, device)
70
+ )
71
+ fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
72
+ period = min_period * (max_period / min_period) ** fraction
73
+
74
+ # Compute the outer product
75
+ scaling_factor = 1.0 / period * 2 * math.pi
76
+ sin_input = scaling_factor[None, :] * time[:, None]
77
+ pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
78
+ return pos_emb
79
+
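+ # Example (illustrative): for a batch of two scalar times and an 8-dim embedding,
+ # the function returns one sine/cosine vector per time value:
+ #   t = torch.tensor([0.25, 0.75])
+ #   emb = create_sinusoidal_pos_embedding(t, dimension=8, min_period=4e-3, max_period=4.0)
+ #   emb.shape  # torch.Size([2, 8]); the first 4 entries are sines, the last 4 are cosines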
80
+
81
+ def make_att_2d_masks(pad_masks: Tensor, att_masks: Tensor) -> Tensor:
82
+ """Creates a 2-D attention mask given padding and 1-D attention masks.
83
+
84
+ Tokens can attend to valid input tokens whose cumulative mask_ar is
85
+ smaller than or equal to theirs. This way `mask_ar` int[B, N] can be used to
86
+ set up several types of attention, for example:
87
+
88
+ [[1 1 1 1 1 1]]: pure causal attention.
89
+
90
+ [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend to one
92
+ another, while the last 3 tokens have causal attention. The first
92
+ entry could also be a 1 without changing behaviour.
93
+
94
+ [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
95
+ block can attend to all previous blocks and to all tokens in the same block.
96
+
97
+ Args:
98
+ pad_masks: bool[B, N], true if the token is part of the input, false if it is padding.
99
+ att_masks: int32[B, N], 1 where previous tokens cannot attend to this token,
100
+ and 0 where it shares the same attention block as the previous token.
101
+
102
+ Returns:
103
+ A 2D attention mask tensor of shape (B, N, N).
104
+
105
+ Raises:
106
+ ValueError: If att_masks or pad_masks are not 2D.
107
+ """
108
+ if att_masks.ndim != 2:
109
+ raise ValueError(f"att_masks must be 2D, got {att_masks.ndim}D")
110
+ if pad_masks.ndim != 2:
111
+ raise ValueError(f"pad_masks must be 2D, got {pad_masks.ndim}D")
112
+
113
+ cumsum = torch.cumsum(att_masks, dim=1)
114
+ att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
115
+ pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
116
+ att_2d_masks = att_2d_masks & pad_2d_masks
117
+ return att_2d_masks
118
+
119
+
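+ # Example (illustrative): with pad_masks = [[1, 1, 1, 1]] and att_masks = [[0, 0, 1, 1]]
+ # (a 2-token prefix followed by 2 causal tokens), the cumulative sum is [0, 0, 1, 2] and the
+ # resulting 2D mask is
+ #   [[1, 1, 0, 0],
+ #    [1, 1, 0, 0],
+ #    [1, 1, 1, 0],
+ #    [1, 1, 1, 1]]
+ # i.e. the prefix tokens attend to each other and the two suffix tokens are causal.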
120
+ def resize_with_pad(img: Tensor, width: int, height: int, pad_value: int = -1) -> Tensor:
121
+ """Resizes an image to fit within the specified dimensions while maintaining aspect ratio,
122
+ and pads the remaining area with the specified value.
123
+
124
+ Args:
125
+ img: Input image tensor of shape (batch_size, channels, current_height, current_width).
126
+ width: Target width.
127
+ height: Target height.
128
+ pad_value: Value to use for padding. Defaults to -1.
129
+
130
+ Returns:
131
+ The resized and padded image tensor of shape (batch_size, channels, height, width).
132
+
133
+ Raises:
134
+ ValueError: If the input image tensor does not have 4 dimensions.
135
+ """
136
+ # assume no-op when width height fits already
137
+ if img.ndim != 4:
138
+ raise ValueError(f"Expected a 4D (b, c, h, w) image tensor, but got shape {img.shape}")
139
+
140
+ cur_height, cur_width = img.shape[2:]
141
+
142
+ ratio = max(cur_width / width, cur_height / height)
143
+ resized_height = int(cur_height / ratio)
144
+ resized_width = int(cur_width / ratio)
145
+ resized_img = F.interpolate(
146
+ img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
147
+ )
148
+
149
+ pad_height = max(0, int(height - resized_height))
150
+ pad_width = max(0, int(width - resized_width))
151
+
152
+ # pad on left and top of image
153
+ padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
154
+ return padded_img
155
+
156
+
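+ # Example (illustrative): a (1, 3, 480, 640) image resized to width=224, height=224 is scaled so
+ # the wider side fits 224 columns, the leftover rows are filled with `pad_value` at the top, and
+ # the result is a (1, 3, 224, 224) tensor.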
157
+ def pad_vector(vector: Tensor, new_dim: int) -> Tensor:
158
+ """Pads the last dimension of a vector to a new size with zeros.
159
+
160
+ Args:
161
+ vector: Input tensor. Can be (batch_size x sequence_length x features_dimension)
162
+ or (batch_size x features_dimension).
163
+ new_dim: The new size for the last dimension.
164
+
165
+ Returns:
166
+ The padded tensor.
167
+ """
168
+ if vector.shape[-1] == new_dim:
169
+ return vector
170
+ shape = list(vector.shape)
171
+ current_dim = shape[-1]
172
+ shape[-1] = new_dim
173
+ new_vector = torch.zeros(*shape, dtype=vector.dtype, device=vector.device)
174
+ new_vector[..., :current_dim] = vector
175
+ return new_vector
176
+
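+ # Example (illustrative): pad_vector(torch.randn(2, 50, 7), 32) returns a (2, 50, 32) tensor whose
+ # first 7 feature channels hold the original values and whose remaining 25 channels are zeros.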
177
+
178
+ class PI0Policy(PreTrainedPolicy):
179
+ """Wrapper class around PI0FlowMatching model to train and run inference within OpenTau."""
180
+
181
+ config_class = PI0Config
182
+ name = "pi0"
183
+
184
+ def __init__(
185
+ self,
186
+ config: PI0Config,
187
+ dataset_stats: dict[str, dict[str, Tensor]] | None = None,
188
+ ):
189
+ """Initializes the PI0Policy.
190
+
191
+ Args:
192
+ config: Policy configuration class instance.
193
+ dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
194
+ that they will be passed with a call to `load_state_dict` before the policy is used.
195
+ """
196
+
197
+ super().__init__(config)
198
+ config.validate_features()
199
+ self.config = config
200
+ self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
201
+ self.normalize_targets = Normalize(
202
+ config.output_features, config.normalization_mapping, dataset_stats
203
+ )
204
+ self.unnormalize_outputs = Unnormalize(
205
+ config.output_features, config.normalization_mapping, dataset_stats
206
+ )
207
+
208
+ self.language_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
209
+ self.model = PI0FlowMatching(config)
210
+
211
+ self.reset()
212
+
213
+ def reset(self) -> None:
214
+ """This should be called whenever the environment is reset."""
215
+ self._action_queue = deque([], maxlen=self.config.n_action_steps)
216
+
217
+ @classmethod
218
+ def _transform_state_dict_keys(cls, state_dict: dict) -> dict:
219
+ """
220
+ Transform state dict keys to match expected model structure.
221
+
222
+ Transformations:
223
+ - model.paligemma_with_expert.paligemma.language_model.lm_head ->
224
+ model.paligemma_with_expert.paligemma.lm_head
225
+ - model.paligemma_with_expert.paligemma.language_model.model ->
226
+ model.paligemma_with_expert.paligemma.model.language_model
227
+ - model.paligemma_with_expert.paligemma.vision_tower ->
228
+ model.paligemma_with_expert.paligemma.model.vision_tower
229
+ - model.paligemma_with_expert.paligemma.multi_modal_projector ->
230
+ model.paligemma_with_expert.paligemma.model.multi_modal_projector
231
+
232
+ Also handles tied weights between lm_head.weight and
233
+ embed_tokens.weight.
234
+
235
+ Args:
236
+ state_dict: The state dictionary to transform.
237
+
238
+ Returns:
239
+ The transformed state dictionary.
240
+ """
241
+ import re
242
+
243
+ transformed_dict = {}
244
+
245
+ transformations = [
246
+ (
247
+ re.compile(r"\.paligemma_with_expert\.paligemma\.language_model\.lm_head"),
248
+ ".paligemma_with_expert.paligemma.lm_head",
249
+ ),
250
+ (
251
+ re.compile(r"\.paligemma_with_expert\.paligemma\.language_model\.model"),
252
+ ".paligemma_with_expert.paligemma.model.language_model",
253
+ ),
254
+ (
255
+ re.compile(r"\.paligemma_with_expert\.paligemma\.vision_tower"),
256
+ ".paligemma_with_expert.paligemma.model.vision_tower",
257
+ ),
258
+ (
259
+ re.compile(r"\.paligemma_with_expert\.paligemma\.multi_modal_projector"),
260
+ ".paligemma_with_expert.paligemma.model.multi_modal_projector",
261
+ ),
262
+ ]
263
+
264
+ for key, value in state_dict.items():
265
+ new_key = key
266
+ for pattern, replacement in transformations:
267
+ new_key = pattern.sub(replacement, new_key)
268
+ transformed_dict[new_key] = value
269
+
270
+ # Handle tied weights: lm_head.weight and embed_tokens.weight share memory
271
+ lm_head_key = None
272
+ embed_tokens_key = None
273
+
274
+ for key in transformed_dict:
275
+ if key.endswith(".paligemma_with_expert.paligemma.lm_head.weight"):
276
+ lm_head_key = key
277
+ elif key.endswith(".paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"):
278
+ embed_tokens_key = key
279
+ if lm_head_key and embed_tokens_key:
280
+ break
281
+
282
+ if lm_head_key and not embed_tokens_key:
283
+ embed_tokens_key = lm_head_key.replace(
284
+ ".lm_head.weight", ".model.language_model.embed_tokens.weight"
285
+ )
286
+ transformed_dict[embed_tokens_key] = transformed_dict[lm_head_key]
287
+ elif embed_tokens_key and not lm_head_key:
288
+ lm_head_key = embed_tokens_key.replace(
289
+ ".model.language_model.embed_tokens.weight", ".lm_head.weight"
290
+ )
291
+ transformed_dict[lm_head_key] = transformed_dict[embed_tokens_key]
292
+
293
+ return transformed_dict
294
+
295
+ @classmethod
296
+ def _load_as_safetensor(
297
+ cls, model: "PI0Policy", model_file: str, map_location: str, strict: bool
298
+ ) -> "PI0Policy":
299
+ """Override to apply key transformations before loading.
300
+
301
+ Args:
302
+ model: The model instance.
303
+ model_file: Path to the model file.
304
+ map_location: Device mapping location.
305
+ strict: Whether to strictly enforce state dict matching.
306
+
307
+ Returns:
308
+ The loaded model instance.
309
+ """
310
+ from safetensors.torch import load_file
311
+
312
+ # Load the state dict from file safely
313
+ state_dict = load_file(model_file, device=map_location)
314
+
315
+ # Apply key transformations
316
+ transformed_state_dict = cls._transform_state_dict_keys(state_dict)
317
+
318
+ # Apply tiling of linear input weights if needed
319
+ model._tile_linear_input_weight(transformed_state_dict)
320
+
321
+ # Load the transformed state dict
322
+ msg = model.load_state_dict(transformed_state_dict, strict=strict)
323
+
324
+ # Log message
325
+ log_model_loading_keys(msg.missing_keys, msg.unexpected_keys)
326
+ return model
327
+
328
+ def get_optim_params(self) -> dict:
329
+ """Returns the parameters to be optimized.
330
+
331
+ Returns:
332
+ A generator over the model parameters.
333
+ """
334
+ return self.parameters()
335
+
336
+ @classmethod
337
+ def from_pretrained(cls, *args, **kwargs):
338
+ """Override the from_pretrained method to display important disclaimer.
339
+
340
+ Args:
341
+ *args: Positional arguments passed to super().from_pretrained.
342
+ **kwargs: Keyword arguments passed to super().from_pretrained.
343
+
344
+ Returns:
345
+ The loaded model instance.
346
+ """
347
+ print(
348
+ "⚠️ DISCLAIMER: The PI0 model is ported from JAX by the Hugging Face team. \n"
349
+ " It is not expected to perform as well as the original implementation. \n"
350
+ " Original implementation: https://github.com/Physical-Intelligence/openpi"
351
+ )
352
+ return super().from_pretrained(*args, **kwargs)
353
+
354
+ @torch.no_grad()
355
+ def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
356
+ """Predict a chunk of actions given environment observations.
357
+
358
+ Args:
359
+ batch: Batch of data containing environment observations.
360
+
361
+ Returns:
362
+ The predicted action chunk.
363
+
364
+ Raises:
365
+ NotImplementedError: Always, as this method is not implemented for PI0.
366
+ """
367
+ raise NotImplementedError("Currently not implemented for PI0")
368
+
369
+ @torch.no_grad()
370
+ def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
371
+ """Select a single action given environment observations.
372
+
373
+ This method wraps `select_actions` in order to return one action at a time for execution in the
374
+ environment. It works by managing the actions in a queue and only calling `select_actions` when the
375
+ queue is empty.
376
+
377
+ Args:
378
+ batch: Batch of data containing environment observations.
379
+ noise: Optional noise tensor to be used during sampling.
380
+
381
+ Returns:
382
+ The selected action tensor.
383
+ """
384
+ self.eval()
385
+
386
+ # Action queue logic for n_action_steps > 1. When the action queue holds at most
387
+ # `safety_buffer` actions, repopulate it by querying the policy.
388
+ if len(self._action_queue) <= self.config.safety_buffer:
389
+ actions = self.sample_actions(batch, noise=noise)
390
+ self._action_queue.extend(actions)
391
+ return self._action_queue.popleft()
392
+
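+ # Note (illustrative): with, e.g., n_action_steps = 50 and safety_buffer = 0, the policy is only
+ # queried once every 50 environment steps; each call to `select_action` pops a single action while
+ # the rest of the sampled chunk stays queued.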
393
+ @torch.no_grad()
394
+ def sample_actions(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
395
+ """Sample actions from the policy given environment observations.
396
+
397
+ Args:
398
+ batch: Batch of data containing environment observations.
399
+ noise: Optional noise tensor.
400
+
401
+ Returns:
402
+ The sampled actions tensor of shape (batch_size, action_dim).
403
+ """
404
+ batch = self.normalize_inputs(batch)
405
+
406
+ images, img_masks = self.prepare_images(batch)
407
+ lang_tokens, lang_masks = self.prepare_language(batch)
408
+
409
+ state = batch["state"]
410
+ actions = self.model.sample_actions(
411
+ images,
412
+ img_masks,
413
+ lang_tokens,
414
+ lang_masks,
415
+ state,
416
+ noise=noise,
417
+ )
418
+
419
+ # Unpad actions
420
+ original_action_dim = self.config.action_feature.shape[0]
421
+ actions = actions[:, :, :original_action_dim]
422
+
423
+ actions = self.unnormalize_outputs({"actions": actions})["actions"]
424
+
425
+ # `self.model.sample_actions` returns a (batch_size, n_action_steps, action_dim) tensor, but the
426
+ # queue effectively has shape (n_action_steps, batch_size, *), hence the transpose.
427
+ actions = actions.transpose(0, 1)
428
+ return actions
429
+
430
+ def forward(
431
+ self, batch: dict[str, Tensor], noise: Tensor | None = None, time: Tensor | None = None
432
+ ) -> dict[str, Tensor]:
433
+ """Do a full training forward pass to compute the loss.
434
+
435
+ Args:
436
+ batch: Batch of data containing environment observations, actions, and targets.
437
+ noise: Optional noise tensor.
438
+ time: Optional time tensor.
439
+
440
+ Returns:
441
+ A dictionary containing the loss components ("MSE" and "CE").
442
+ """
443
+ batch = self.normalize_inputs(batch)
444
+ batch = self.normalize_targets(batch)
445
+
446
+ images, img_masks = self.prepare_images(batch)
447
+ state = batch["state"]
448
+ lang_tokens, lang_masks = self.prepare_language(batch)
449
+ actions = batch["actions"]
450
+ actions_is_pad = batch.get("action_is_pad")
451
+
452
+ losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
453
+
454
+ if actions_is_pad is not None:
455
+ in_episode_bound = ~actions_is_pad
456
+ losses = losses * in_episode_bound.unsqueeze(-1)
457
+
458
+ # Remove padding
459
+ losses = losses[:, :, : self.config.max_action_dim]
460
+
461
+ # For backward pass
462
+ loss = losses.mean()
463
+
464
+ return {"MSE": loss, "CE": torch.zeros_like(loss, requires_grad=True)}
465
+
466
+ def prepare_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]:
467
+ """Apply Pi0 preprocessing to the images.
468
+
469
+ Resizes to 224x224 with padding to keep the aspect ratio, and converts the pixel range
470
+ from [0.0, 1.0] to [-1.0, 1.0] as expected by SigLIP.
471
+
472
+ Args:
473
+ batch: Batch of data containing image tensors.
474
+
475
+ Returns:
476
+ A tuple containing:
477
+ - images: A list of processed image tensors.
478
+ - img_masks: A list of image mask tensors.
479
+
480
+ Raises:
481
+ ValueError: If no image features are present in the batch.
482
+ """
483
+ images = []
484
+ img_masks = []
485
+
486
+ present_img_keys = [key for key in self.config.image_features if key in batch]
487
+ missing_img_keys = [key for key in self.config.image_features if key not in batch]
488
+
489
+ if len(present_img_keys) == 0:
490
+ raise ValueError(
491
+ f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})"
492
+ )
493
+
494
+ # Preprocess image features present in the batch
495
+ for key in present_img_keys:
496
+ img = batch[key]
497
+
498
+ if self.config.resize_imgs_with_padding is not None:
499
+ img = resize_with_pad(img, *self.config.resize_imgs_with_padding, pad_value=0)
500
+
501
+ # Normalize from range [0,1] to [-1,1] as expected by siglip
502
+ img = img * 2.0 - 1.0
503
+
504
+ bsize = img.shape[0]
505
+ device = img.device
506
+ mask = torch.ones(bsize, dtype=torch.bool, device=device)
507
+ images.append(img)
508
+ img_masks.append(mask)
509
+
510
+ # Create image features not present in the batch as placeholder images filled with -1
511
+ # (black in SigLIP's [-1, 1] range), with their masks set to 0.
512
+ for num_empty_cameras in range(len(missing_img_keys)):
513
+ if num_empty_cameras >= self.config.empty_cameras:
514
+ break
515
+ img = torch.ones_like(img) * -1
516
+ mask = torch.zeros_like(mask)
517
+ images.append(img)
518
+ img_masks.append(mask)
519
+
520
+ return images, img_masks
521
+
522
+ def prepare_language(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
523
+ """Tokenize the text input.
524
+
525
+ Args:
526
+ batch: Batch of data containing "prompt" and potentially "advantage".
527
+
528
+ Returns:
529
+ A tuple containing:
530
+ - lang_tokens: Tensor of language tokens.
531
+ - lang_masks: Tensor of language attention masks.
532
+ """
533
+ device = batch["state"].device
534
+ tasks = batch["prompt"]
535
+
536
+ # PaliGemma prompt has to end with a new line
537
+ tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
538
+
539
+ for idx, task in enumerate(tasks):
540
+ if self.config.advantage == "on": # always add positive advantage
541
+ tasks[idx] = f"{task}Advantage: positive\n"
542
+ elif self.config.advantage == "use": # add advantage based on threshold
543
+ adv = batch["advantage"][idx] >= self.config.advantage_threshold
544
+ adv = "positive" if adv else "negative"
545
+ tasks[idx] = f"{task}Advantage: {adv}\n"
546
+
547
+ tokenized_prompt = self.language_tokenizer(
548
+ tasks,
549
+ padding="max_length",
550
+ padding_side="right",
551
+ max_length=self.config.tokenizer_max_length,
552
+ return_tensors="pt",
553
+ )
554
+ lang_tokens = tokenized_prompt["input_ids"].to(device=device)
555
+ lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool)
556
+
557
+ return lang_tokens, lang_masks
558
+
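+ # Example (illustrative): with config.advantage == "on", a prompt "pick up the red block" is
+ # tokenized as "pick up the red block\nAdvantage: positive\n"; with config.advantage == "use",
+ # the suffix is "Advantage: positive\n" or "Advantage: negative\n" depending on whether
+ # batch["advantage"] reaches config.advantage_threshold.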
559
+
560
+ class PI0FlowMatching(nn.Module):
561
+ """
562
+ π0: A Vision-Language-Action Flow Model for General Robot Control
563
+
564
+ [Paper](https://www.physicalintelligence.company/download/pi0.pdf)
565
+
566
+ ┌──────────────────────────────┐
567
+ │ actions │
568
+ │ ▲ │
569
+ │ ┌┴─────┐ │
570
+ │ kv cache │Gemma │ │
571
+ │ ┌──────────►│Expert│ │
572
+ │ │ │ │ │
573
+ │ ┌┴────────┐ │x 10 │ │
574
+ │ │ │ └▲──▲──┘ │
575
+ │ │PaliGemma│ │ │ │
576
+ │ │ │ │ robot state │
577
+ │ │ │ noise │
578
+ │ └▲──▲─────┘ │
579
+ │ │ │ │
580
+ │ │ image(s) │
581
+ │ language tokens │
582
+ └──────────────────────────────┘
583
+ """
584
+
585
+ def __init__(self, config: PI0Config):
586
+ """Initializes the PI0FlowMatching model.
587
+
588
+ Args:
589
+ config: Model configuration.
590
+ """
591
+ super().__init__()
592
+ self.config = config
593
+
594
+ load_pretrained_paligemma = (
595
+ self.config.init_strategy == "expert_only_he_init"
596
+ ) # only load pretrained paligemma if we are He-initializing the expert only
597
+ paligemma_with_expert_config = PaliGemmaWithExpertConfig(
598
+ freeze_vision_encoder=self.config.freeze_vision_encoder,
599
+ train_expert_only=self.config.train_expert_only,
600
+ attention_implementation=self.config.attention_implementation,
601
+ load_pretrained_paligemma=load_pretrained_paligemma,
602
+ dropout=self.config.dropout,
603
+ )
604
+ self.paligemma_with_expert = PaliGemmaWithExpertModel(paligemma_with_expert_config)
605
+
606
+ # Projections are float32
607
+ self.state_proj = nn.Linear(self.config.max_state_dim, self.config.proj_width)
608
+ self.action_in_proj = nn.Linear(self.config.max_action_dim, self.config.proj_width)
609
+ self.action_out_proj = nn.Linear(self.config.proj_width, self.config.max_action_dim)
610
+
611
+ self.action_time_mlp_in = nn.Linear(self.config.proj_width * 2, self.config.proj_width)
612
+ self.action_time_mlp_out = nn.Linear(self.config.proj_width, self.config.proj_width)
613
+
614
+ self.set_requires_grad()
615
+
616
+ self._init_model()
617
+
618
+ def set_requires_grad(self) -> None:
619
+ """Sets the requires_grad attribute for state projection parameters."""
620
+ for params in self.state_proj.parameters():
621
+ params.requires_grad = self.config.train_state_proj
622
+
623
+ def _init_weights(self, module: nn.Module) -> None:
624
+ """Initialize weights using He (Kaiming) initialization.
625
+
626
+ Args:
627
+ module: The module to initialize.
628
+ """
629
+ if isinstance(module, nn.Linear):
630
+ nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
631
+ if module.bias is not None:
632
+ nn.init.zeros_(module.bias)
633
+ elif isinstance(module, nn.LayerNorm):
634
+ nn.init.ones_(module.weight)
635
+ nn.init.zeros_(module.bias)
636
+
637
+ def _init_model(self) -> None:
638
+ """Initialize the model weights based on the configuration."""
639
+ if self.config.init_strategy == "no_init":
640
+ return
641
+ elif self.config.init_strategy == "full_he_init":
642
+ for m in self.modules():
643
+ self._init_weights(m)
644
+ elif self.config.init_strategy == "expert_only_he_init":
645
+ for m in self.paligemma_with_expert.gemma_expert.modules():
646
+ self._init_weights(m)
647
+ else:
648
+ raise ValueError(f"Invalid init strategy: {self.config.init_strategy}")
649
+
650
+ def sample_noise(self, shape: tuple[int, ...], device: torch.device | str) -> Tensor:
651
+ """Samples Gaussian noise.
652
+
653
+ Args:
654
+ shape: The shape of the noise tensor.
655
+ device: The device to create the tensor on.
656
+
657
+ Returns:
658
+ A tensor containing the sampled noise.
659
+ """
660
+ noise = torch.normal(
661
+ mean=0.0,
662
+ std=1.0,
663
+ size=shape,
664
+ dtype=torch.float32,
665
+ device=device,
666
+ )
667
+ return noise
668
+
669
+ def sample_time(self, bsize: int, device: torch.device | str) -> Tensor:
670
+ """Samples time steps from a Beta distribution.
671
+
672
+ Args:
673
+ bsize: Batch size.
674
+ device: The device to create the tensor on.
675
+
676
+ Returns:
677
+ A tensor containing the sampled time steps.
678
+ """
679
+ beta_dist = torch.distributions.Beta(concentration1=1.5, concentration0=1.0)
680
+ time_beta = beta_dist.sample((bsize,)).to(device=device, dtype=torch.float32)
681
+ time = time_beta * 0.999 + 0.001
682
+ return time
683
+
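+ # Note (illustrative): Beta(1.5, 1.0) puts more probability mass near t = 1 (the high-noise end),
+ # and the affine map t -> 0.999 * t + 0.001 keeps the sampled timestep inside [0.001, 1.0].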
684
+ def embed_prefix(
685
+ self,
686
+ images: list[Tensor],
687
+ img_masks: list[Tensor],
688
+ lang_tokens: Tensor,
689
+ lang_masks: Tensor,
690
+ ) -> tuple[Tensor, Tensor, Tensor]:
691
+ """Embed images with SigLIP and language tokens with embedding layer to prepare
692
+ for PaliGemma transformer processing.
693
+
694
+ Args:
695
+ images: List of image tensors.
696
+ img_masks: List of image mask tensors.
697
+ lang_tokens: Language token tensor.
698
+ lang_masks: Language mask tensor.
699
+
700
+ Returns:
701
+ A tuple containing:
702
+ - embs: Concatenated embeddings tensor.
703
+ - pad_masks: Concatenated padding masks tensor.
704
+ - att_masks: Attention masks tensor.
705
+ """
706
+ # TODO: avoid Python lists and torch.cat; prefer pre-allocation with torch.empty
707
+ embs = []
708
+ pad_masks = []
709
+ att_masks = []
710
+
711
+ # TODO: remove for loop
712
+ for img, img_mask in zip(images, img_masks, strict=False):
716
+ img_emb = self.paligemma_with_expert.embed_image(img)
717
+ img_emb = img_emb.to(dtype=torch.bfloat16)
718
+
719
+ # Image embeddings do not need to be unnormalized here: the `fix/lerobot_openpi` Hugging Face
720
+ # branch already removed the normalization inside PaliGemma.
722
+
723
+ bsize, num_img_embs = img_emb.shape[:2]
724
+ img_mask = img_mask[:, None].expand(bsize, num_img_embs)
725
+
726
+ embs.append(img_emb)
727
+ pad_masks.append(img_mask)
728
+
729
+ # Create attention masks so that image tokens attend to each other
730
+ att_masks += [0] * num_img_embs
731
+
732
+ lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)
733
+
734
+ # Scale language embeddings by sqrt(embedding dim)
735
+ lang_emb_dim = lang_emb.shape[-1]
736
+ lang_emb = lang_emb * math.sqrt(lang_emb_dim)
737
+
738
+ embs.append(lang_emb)
739
+ pad_masks.append(lang_masks)
740
+
741
+ # full attention between image and language inputs
742
+ num_lang_embs = lang_emb.shape[1]
743
+ att_masks += [0] * num_lang_embs
744
+
745
+ embs = torch.cat(embs, dim=1)
746
+ pad_masks = torch.cat(pad_masks, dim=1)
747
+ att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
748
+ att_masks = att_masks[None, :].expand(bsize, len(att_masks))
749
+
750
+ return embs, pad_masks, att_masks
751
+
752
+ def embed_suffix(
753
+ self, state: Tensor, noisy_actions: Tensor, timestep: Tensor
754
+ ) -> tuple[Tensor, Tensor, Tensor]:
755
+ """Embed state, noisy_actions, timestep to prepare for Expert Gemma processing.
756
+
757
+ Args:
758
+ state: State tensor.
759
+ noisy_actions: Tensor containing noisy actions.
760
+ timestep: Tensor containing timesteps.
761
+
762
+ Returns:
763
+ A tuple containing:
764
+ - embs: Concatenated embeddings tensor.
765
+ - pad_masks: Concatenated padding masks tensor.
766
+ - att_masks: Attention masks tensor.
767
+ """
768
+ embs = []
769
+ pad_masks = []
770
+ att_masks = []
771
+
772
+ # Embed state
773
+ state_emb = self.state_proj(state)
774
+ state_emb = state_emb.to(dtype=torch.bfloat16)
775
+ embs.append(state_emb[:, None, :])
776
+ bsize = state_emb.shape[0]
777
+ dtype = state_emb.dtype
778
+ device = state_emb.device
779
+
780
+ state_mask = torch.ones(bsize, 1, dtype=torch.bool, device=device)
781
+ pad_masks.append(state_mask)
782
+
783
+ # Set attention masks so that image and language inputs do not attend to state or actions
784
+ att_masks += [1]
785
+
786
+ # Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
787
+ time_emb = create_sinusoidal_pos_embedding(
788
+ timestep, self.config.proj_width, min_period=4e-3, max_period=4.0, device=device
789
+ )
790
+ time_emb = time_emb.type(dtype=dtype)
791
+
792
+ # Fuse timestep + action information using an MLP
793
+ noisy_actions = noisy_actions.to(dtype=dtype)
794
+ action_emb = self.action_in_proj(noisy_actions)
795
+
796
+ time_emb = time_emb[:, None, :].expand_as(action_emb)
797
+ action_time_emb = torch.cat([action_emb, time_emb], dim=2)
798
+
799
+ action_time_emb = self.action_time_mlp_in(action_time_emb)
800
+ action_time_emb = F.silu(action_time_emb) # swish == silu
801
+ action_time_emb = self.action_time_mlp_out(action_time_emb)
802
+
803
+ # Add to input tokens
804
+ embs.append(action_time_emb)
805
+
806
+ bsize, action_time_dim = action_time_emb.shape[:2]
807
+ action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=device)
808
+ pad_masks.append(action_time_mask)
809
+
810
+ # Set attention masks so that image, language and state inputs do not attend to action tokens
811
+ att_masks += [1] + ([0] * (self.config.n_action_steps - 1))
812
+
813
+ embs = torch.cat(embs, dim=1)
814
+ pad_masks = torch.cat(pad_masks, dim=1)
815
+ att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device)
816
+ att_masks = att_masks[None, :].expand(bsize, len(att_masks))
817
+
818
+ return embs, pad_masks, att_masks
819
+
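+ # Note (illustrative): the suffix is one state token followed by n_action_steps action tokens, so
+ # with n_action_steps = 50 the suffix att_masks pattern is [1, 1, 0, ..., 0] (51 entries): the
+ # state opens a new attention block after the image/language prefix, and the 50 action tokens form
+ # a further block that can attend to the prefix, the state, and one another.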
820
+ def forward(
821
+ self,
822
+ images: list[Tensor],
823
+ img_masks: list[Tensor],
824
+ lang_tokens: Tensor,
825
+ lang_masks: Tensor,
826
+ state: Tensor,
827
+ actions: Tensor,
828
+ noise: Tensor | None = None,
829
+ time: Tensor | None = None,
830
+ ) -> Tensor:
831
+ """Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors).
832
+
833
+ Args:
834
+ images: List of image tensors.
835
+ img_masks: List of image mask tensors.
836
+ lang_tokens: Language token tensor.
837
+ lang_masks: Language mask tensor.
838
+ state: State tensor.
839
+ actions: Action tensor.
840
+ noise: Optional noise tensor.
841
+ time: Optional time tensor.
842
+
843
+ Returns:
844
+ The computed loss tensor.
845
+ """
846
+ if noise is None:
847
+ noise = self.sample_noise(actions.shape, actions.device)
848
+
849
+ if time is None:
850
+ time = self.sample_time(actions.shape[0], actions.device)
851
+
852
+ time_expanded = time[:, None, None]
853
+ x_t = time_expanded * noise + (1 - time_expanded) * actions
854
+ u_t = noise - actions
855
+
856
+ prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
857
+ images, img_masks, lang_tokens, lang_masks
858
+ )
859
+ suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, time)
860
+
861
+ pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
862
+ att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
863
+
864
+ att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
865
+ position_ids = torch.cumsum(pad_masks, dim=1) - 1
866
+
867
+ (_, suffix_out), _ = self.paligemma_with_expert.forward(
868
+ attention_mask=att_2d_masks,
869
+ position_ids=position_ids,
870
+ past_key_values=None,
871
+ inputs_embeds=[prefix_embs, suffix_embs],
872
+ use_cache=False,
873
+ fill_kv_cache=False,
874
+ )
875
+ suffix_out = suffix_out[:, -self.config.n_action_steps :]
876
+ # As in the original openpi code, upcast the expert output to float32 before computing the loss
877
+ v_t = self.action_out_proj(suffix_out)
878
+ v_t = v_t.to(dtype=torch.float32)
879
+
880
+ losses = F.mse_loss(u_t, v_t, reduction="none")
881
+ return losses
882
+
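+ # Note (illustrative): the training target follows the flow-matching formulation used above:
+ # x_t = t * noise + (1 - t) * actions, and the network is regressed onto u_t = noise - actions,
+ # which is d(x_t)/dt. At inference, integrating dx/dt = v_t from t = 1 (pure noise) down to t = 0
+ # with the learned v_t recovers an action sample.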
883
+ def sample_actions(
884
+ self,
885
+ images: list[Tensor],
886
+ img_masks: list[Tensor],
887
+ lang_tokens: Tensor,
888
+ lang_masks: Tensor,
889
+ state: Tensor,
890
+ noise: Tensor | None = None,
891
+ ) -> Tensor:
892
+ """Do a full inference forward and compute the action (batch_size x num_steps x num_motors).
893
+
894
+ Args:
895
+ images: List of image tensors.
896
+ img_masks: List of image mask tensors.
897
+ lang_tokens: Language token tensor.
898
+ lang_masks: Language mask tensor.
899
+ state: State tensor.
900
+ noise: Optional noise tensor.
901
+
902
+ Returns:
903
+ The sampled action tensor.
904
+ """
905
+ bsize = state.shape[0]
906
+ device = state.device
907
+
908
+ if noise is None:
909
+ actions_shape = (bsize, self.config.n_action_steps, self.config.max_action_dim)
910
+ noise = self.sample_noise(actions_shape, device)
911
+
912
+ prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
913
+ images, img_masks, lang_tokens, lang_masks
914
+ )
915
+ prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
916
+ prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
917
+
918
+ # Compute image and language key value cache
919
+ _, past_key_values = self.paligemma_with_expert.forward(
920
+ attention_mask=prefix_att_2d_masks,
921
+ position_ids=prefix_position_ids,
922
+ past_key_values=None,
923
+ inputs_embeds=[prefix_embs, None],
924
+ use_cache=self.config.use_cache,
925
+ fill_kv_cache=True,
926
+ )
927
+
928
+ dt = -1.0 / self.config.num_steps
929
+ dt = torch.tensor(dt, dtype=torch.float32, device=device)
930
+
931
+ x_t = noise
932
+ time = torch.tensor(1.0, dtype=torch.float32, device=device)
933
+ while time >= -dt / 2:
934
+ expanded_time = time.expand(bsize)
935
+ v_t = self.denoise_step(
936
+ state,
937
+ prefix_pad_masks,
938
+ past_key_values,
939
+ x_t,
940
+ expanded_time,
941
+ )
942
+
943
+ # Euler step
944
+ x_t += dt * v_t
945
+ time += dt
946
+ return x_t
947
+
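+ # Note (illustrative): with config.num_steps = 10, dt = -0.1 and the loop above performs exactly
+ # ten Euler steps at t = 1.0, 0.9, ..., 0.1 before the `time >= -dt / 2` check stops it near t = 0.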
948
+ def denoise_step(
949
+ self,
950
+ state: Tensor,
951
+ prefix_pad_masks: Tensor,
952
+ past_key_values: list[dict[str, Tensor]],
953
+ x_t: Tensor,
954
+ timestep: Tensor,
955
+ ) -> Tensor:
956
+ """Apply one denoising step of the noise `x_t` at a given timestep.
957
+
958
+ Args:
959
+ state: State tensor.
960
+ prefix_pad_masks: Prefix padding masks.
961
+ past_key_values: Past key values from the VLM.
962
+ x_t: Current noise tensor.
963
+ timestep: Current timestep.
964
+
965
+ Returns:
966
+ The predicted velocity tensor (v_t).
967
+ """
968
+ suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, timestep)
969
+
970
+ suffix_len = suffix_pad_masks.shape[1]
971
+ batch_size = prefix_pad_masks.shape[0]
972
+ prefix_len = prefix_pad_masks.shape[1]
973
+ prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len)
974
+
975
+ suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
976
+
977
+ full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2)
978
+
979
+ prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
980
+ position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
981
+
982
+ outputs_embeds, _ = self.paligemma_with_expert.forward(
983
+ attention_mask=full_att_2d_masks,
984
+ position_ids=position_ids,
985
+ past_key_values=past_key_values,
986
+ inputs_embeds=[None, suffix_embs],
987
+ use_cache=self.config.use_cache,
988
+ fill_kv_cache=False,
989
+ )
990
+ suffix_out = outputs_embeds[1]
991
+ suffix_out = suffix_out[:, -self.config.n_action_steps :]
992
+ v_t = self.action_out_proj(suffix_out)
993
+ v_t = v_t.to(dtype=torch.float32)
994
+ return v_t
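+ # Note (illustrative): because the image/language prefix key/value cache is filled once in
+ # `sample_actions` (fill_kv_cache=True), each `denoise_step` only runs the expert over the short
+ # state/action suffix while attending to the cached prefix keys and values.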