opentau 0.1.1__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -42,6 +42,7 @@ from opentau.policies.pi05.paligemma_with_expert import (
  PaliGemmaWithExpertModel,
  )
  from opentau.policies.pretrained import PreTrainedPolicy, T
+ from opentau.utils.accelerate_utils import get_proc_accelerator
  from opentau.utils.utils import get_safe_dtype

@@ -151,8 +152,8 @@ def make_att_2d_masks(

  # Apply padding masks: pad_masks for rows, cross_att_pad_masks for columns
  cross_att_mask = cross_att_mask & pad_masks[:, :, None] & cross_att_pad_masks[:, None, :]
-
- att_2d_masks = torch.cat((att_2d_masks, cross_att_mask), dim=2)
+ # The cross_att_masks are concatenated before the att_2d_masks
+ att_2d_masks = torch.cat((cross_att_mask, att_2d_masks), dim=2)

  return att_2d_masks

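The hunk above flips the concatenation order in make_att_2d_masks so the cross-attention columns now precede the self-attention block. A minimal sketch of the resulting mask layout, with hypothetical shapes (not the package's real sequence lengths):

import torch

# Hypothetical shapes: batch 1, 3 suffix tokens, 4 cross-attention (prefix) tokens.
att_2d_masks = torch.ones(1, 3, 3, dtype=torch.bool)    # suffix self-attention block
cross_att_mask = torch.ones(1, 3, 4, dtype=torch.bool)  # suffix-to-prefix cross-attention block

# New ordering: cross-attention columns first, so key index j < 4 refers to a cached
# prefix position and j >= 4 to a suffix position.
full_mask = torch.cat((cross_att_mask, att_2d_masks), dim=2)
print(full_mask.shape)  # torch.Size([1, 3, 7])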
@@ -351,9 +352,12 @@ class PI05Policy(PreTrainedPolicy):
  model = cls(config, **kwargs)

  # Now manually load and remap the state dict
+ acc = get_proc_accelerator()
+ is_main_process = acc.is_main_process if acc else True
  try:
  # Try to load the pytorch_model.bin or model.safetensors file
- print(f"Loading model from: {pretrained_name_or_path}")
+ if is_main_process:
+ print(f"Loading model from: {pretrained_name_or_path}")
  try:
  from transformers.utils import cached_file

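The new get_proc_accelerator helper is used to gate logging to the main process so multi-GPU runs print once instead of once per rank. A minimal sketch of the pattern, assuming the helper returns the process's accelerate Accelerator, or None when accelerate is not active (which the `if acc else True` fallback suggests):

from opentau.utils.accelerate_utils import get_proc_accelerator

acc = get_proc_accelerator()
is_main_process = acc.is_main_process if acc else True

if is_main_process:
    print("Loading model ...")  # emitted by rank 0 only under accelerate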
@@ -372,10 +376,12 @@ class PI05Policy(PreTrainedPolicy):
  from safetensors.torch import load_file

  original_state_dict = load_file(resolved_file)
- print("✓ Loaded state dict from model.safetensors")
+ if is_main_process:
+ print("✓ Loaded state dict from model.safetensors")
  except Exception as e:
- print(f"Could not load state dict from remote files: {e}")
- print("Returning model without loading pretrained weights")
+ if is_main_process:
+ print(f"Could not load state dict from remote files: {e}")
+ print("Returning model without loading pretrained weights")
  return model

  # First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys`
@@ -390,18 +396,18 @@ class PI05Policy(PreTrainedPolicy):
  new_key = f"model.{key}"
  remapped_state_dict[new_key] = value
  remap_count += 1
- if remap_count <= 10: # Only print first 10 to avoid spam
+ if remap_count <= 10 and is_main_process: # Only print first 10 to avoid spam
  print(f"Remapped: {key} -> {new_key}")
  else:
  remapped_state_dict[key] = value

- if remap_count > 0:
+ if remap_count > 0 and is_main_process:
  print(f"Remapped {remap_count} state dict keys")

  # Load the remapped state dict into the model
  missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=False)

- if missing_keys:
+ if missing_keys and is_main_process:
  print(f"Missing keys when loading state dict: {len(missing_keys)} keys")
  if len(missing_keys) <= 20:
  for key in missing_keys:
@@ -411,7 +417,7 @@ class PI05Policy(PreTrainedPolicy):
  print(f" - {key}")
  print(f" ... and {len(missing_keys) - 20} more")

- if unexpected_keys:
+ if unexpected_keys and is_main_process:
  print(f"Unexpected keys when loading state dict: {len(unexpected_keys)} keys")
  if len(unexpected_keys) <= 20:
  for key in unexpected_keys:
@@ -421,11 +427,12 @@ class PI05Policy(PreTrainedPolicy):
  print(f" - {key}")
  print(f" ... and {len(unexpected_keys) - 20} more")

- if not missing_keys and not unexpected_keys:
+ if not missing_keys and not unexpected_keys and is_main_process:
  print("All keys loaded successfully!")

  except Exception as e:
- print(f"Warning: Could not remap state dict keys: {e}")
+ if is_main_process:
+ print(f"Warning: Could not remap state dict keys: {e}")

  return model

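The remapping loop in the hunks above prefixes checkpoint keys with "model." so they line up with the module hierarchy before load_state_dict(strict=False). A standalone sketch of the idea, using hypothetical keys rather than the real checkpoint contents:

# Hypothetical checkpoint keys; values would be tensors in practice.
original_state_dict = {
    "paligemma_with_expert.lm_head.weight": "tensor-0",
    "model.action_out_proj.weight": "tensor-1",  # already prefixed, left as-is
}

remapped_state_dict = {}
remap_count = 0
for key, value in original_state_dict.items():
    if not key.startswith("model."):
        remapped_state_dict[f"model.{key}"] = value
        remap_count += 1
    else:
        remapped_state_dict[key] = value

print(remap_count, sorted(remapped_state_dict))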
@@ -596,6 +603,11 @@ class PI05Policy(PreTrainedPolicy):
  lang_tokens, lang_masks = self.prepare_language(
  batch
  ) # in lang_masks we have True for real tokens and False for padded tokens
+ # response prediction is to predict the response . It will attend to image and language inputs.
+ response_tokens, response_masks = self.prepare_response(
+ batch
+ ) # in response_masks we have True for real tokens and False for padded tokens
+ # discrete actions are to predict actions using autoregressive technique and not flow matching. It will attend to image, language and response inputs.
  discrete_actions, discrete_action_masks = self.prepare_discrete_actions(
  batch
  ) # in discrete_action_masks we have True for real tokens and False for padded tokens
@@ -610,6 +622,8 @@ class PI05Policy(PreTrainedPolicy):
  lang_tokens,
  lang_masks,
  actions,
+ response_tokens,
+ response_masks,
  noise,
  time,
  discrete_actions,
@@ -752,18 +766,25 @@ class PI05Policy(PreTrainedPolicy):
  device = batch["state"].device
  tasks = batch["prompt"]

- # PaliGemma prompt has to end with a new line
- tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
-
  # add state to the prompt
  state = self.prepare_discrete_state(batch)
- prompt = [f"Task: {task}State: {state}\nActions:" for task, state in zip(tasks, state, strict=False)]
+ # using <eos> to separate each modality
+ if self.config.predict_response:
+ prompt = [
+ f"Task: {task}<eos>State: {state}<eos>Response:"
+ for task, state in zip(tasks, state, strict=False)
+ ]
+ else:
+ prompt = [
+ f"Task: {task}<eos>State: {state}<eos>Actions:"
+ for task, state in zip(tasks, state, strict=False)
+ ]

  tokenized_prompt = self.language_tokenizer.__call__(
  prompt,
  padding="max_length",
  padding_side="right",
- max_length=self.config.tokenizer_max_length,
+ max_length=self.config.prompt_max_length,
  return_tensors="pt",
  truncation=True,
  )
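prepare_language now builds prompts with <eos> separating each modality and switches the trailing cue between "Response:" and "Actions:" depending on config.predict_response. A small sketch of the strings this produces, with hypothetical task and state text:

tasks = ["pick up the red block"]  # hypothetical task string
states = ["12 40 88"]              # hypothetical discretized state string
predict_response = True

suffix = "Response:" if predict_response else "Actions:"
prompts = [f"Task: {t}<eos>State: {s}<eos>{suffix}" for t, s in zip(tasks, states)]
print(prompts[0])
# Task: pick up the red block<eos>State: 12 40 88<eos>Response: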
@@ -772,6 +793,39 @@ class PI05Policy(PreTrainedPolicy):

  return lang_tokens, lang_masks

+ def prepare_response(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
+ """Tokenize the response input.
+
+ Args:
+ batch: Batch of data containing the key "response".
+
+ Returns:
+ A tuple containing:
+ - response_tokens: Tensor of response language tokens.
+ - response_masks: Tensor of response language attention masks.
+ """
+
+ if not self.config.predict_response:
+ return None, None
+ device = batch["state"].device
+ responses = batch["response"]
+
+ # if '' is found in response then response is not for loss calculation (used for robotic dataset with no subtask), so add pad token to the response.
+ response_prompt = [f"{response}<eos>Actions:" for response in responses]
+
+ tokenized_response = self.language_tokenizer.__call__(
+ response_prompt,
+ padding="max_length",
+ padding_side="right",
+ max_length=self.config.response_max_length,
+ return_tensors="pt",
+ truncation=True,
+ )
+ response_tokens = tokenized_response["input_ids"].to(device=device)
+ response_masks = tokenized_response["attention_mask"].to(device=device, dtype=torch.bool)
+
+ return response_tokens, response_masks
+

  class PI05FlowMatching(nn.Module):
  """
@@ -828,6 +882,8 @@ class PI05FlowMatching(nn.Module):
  self.time_mlp_in = nn.Linear(self.config.proj_width, self.config.proj_width)
  self.time_mlp_out = nn.Linear(self.config.proj_width, self.config.proj_width)

+ self.language_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
+
  self._init_model()

  def _init_weights(self, module: nn.Module) -> None:
@@ -897,6 +953,8 @@ class PI05FlowMatching(nn.Module):
  img_masks: list[Tensor],
  lang_tokens: Tensor,
  lang_masks: Tensor,
+ response_tokens: Tensor | None = None,
+ response_masks: Tensor | None = None,
  discrete_actions: Tensor | None = None,
  discrete_action_masks: Tensor | None = None,
  ) -> tuple[Tensor, Tensor, Tensor]:
@@ -908,6 +966,8 @@ class PI05FlowMatching(nn.Module):
  img_masks: List of image mask tensors.
  lang_tokens: Language token tensor.
  lang_masks: Language mask tensor.
+ response_tokens: Optional Response language token tensor.
+ response_masks: Optional Response language mask tensor.
  discrete_actions: Optional discrete action tensor.
  discrete_action_masks: Optional discrete action mask tensor.

@@ -956,6 +1016,20 @@ class PI05FlowMatching(nn.Module):
  num_lang_embs = lang_emb.shape[1]
  att_masks += [0] * num_lang_embs

+ if response_tokens is not None:
+ response_emb = self.paligemma_with_expert.embed_language_tokens(response_tokens)
+
+ # Normalize response language embeddings
+ response_emb_dim = response_emb.shape[-1]
+ response_emb = response_emb * math.sqrt(response_emb_dim)
+
+ embs.append(response_emb)
+ pad_masks.append(response_masks)
+
+ # full attention between image, language and response inputs
+ num_response_embs = response_emb.shape[1]
+ att_masks += [1] * num_response_embs
+
  if discrete_actions is not None:
  discrete_action_emb = self.paligemma_with_expert.embed_discrete_actions(discrete_actions)
  embs.append(discrete_action_emb.to(dtype=torch.bfloat16))
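Response embeddings are scaled the same way as the other language embeddings before being appended to the prefix, and each response position gets attention-mask value 1 (full attention among image, language, and response). A tiny sketch with hypothetical dimensions:

import math
import torch

response_emb = torch.randn(2, 8, 2048)  # hypothetical: batch 2, 8 tokens, hidden 2048
response_emb = response_emb * math.sqrt(response_emb.shape[-1])  # sqrt(dim) scaling
att_masks = [1] * response_emb.shape[1]  # full attention for every response token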
@@ -1033,6 +1107,8 @@ class PI05FlowMatching(nn.Module):
  lang_tokens: Tensor,
  lang_masks: Tensor,
  actions: Tensor,
+ response_tokens: Tensor | None = None,
+ response_masks: Tensor | None = None,
  noise: Tensor | None = None,
  time: Tensor | None = None,
  discrete_actions: Tensor | None = None,
@@ -1045,6 +1121,8 @@ class PI05FlowMatching(nn.Module):
  img_masks: List of image mask tensors.
  lang_tokens: Language token tensor.
  lang_masks: Language mask tensor.
+ response_tokens: Response language token tensor.
+ response_masks: Response language mask tensor.
  actions: Action tensor.
  noise: Optional noise tensor.
  time: Optional time tensor.
@@ -1056,12 +1134,20 @@ class PI05FlowMatching(nn.Module):
  """
  # Run VLM first to get key value cache
  prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
- images, img_masks, lang_tokens, lang_masks, discrete_actions, discrete_action_masks
+ images,
+ img_masks,
+ lang_tokens,
+ lang_masks,
+ response_tokens,
+ response_masks,
+ discrete_actions,
+ discrete_action_masks,
  )

  vlm_2d_attention_mask = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
  vlm_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1

+ # avoids using discrete action for predicting continuous flow matching action
  num_cross_att_tokens = prefix_embs.shape[1] - self.config.discrete_action_max_length

  (prefix_out, _), past_key_values = self.paligemma_with_expert.forward(
@@ -1070,7 +1156,7 @@ class PI05FlowMatching(nn.Module):
  past_key_values=None,
  inputs_embeds=[prefix_embs, None],
  n_cross_att_tokens=num_cross_att_tokens,
- use_cache=True,
+ use_cache=False,
  fill_kv_cache=True,
  )

@@ -1093,7 +1179,7 @@ class PI05FlowMatching(nn.Module):
  n_cross_att_tokens=num_cross_att_tokens,
  cross_att_pad_masks=prefix_pad_masks[:, :num_cross_att_tokens],
  )
- # We should skip the response tokens when numbering the position ids for the action expert
+ # We should skip the discrete action tokens when numbering the position ids for the action expert
  prefix_offsets = torch.sum(prefix_pad_masks[:, : -self.config.discrete_action_max_length], dim=-1)[
  :, None
  ] # action expert position ids start after prefix
@@ -1124,24 +1210,60 @@ class PI05FlowMatching(nn.Module):

  # compute cross entropy loss for discrete actions
  batch_size, seq_len = discrete_actions.shape
- discrete_action_out = prefix_out[:, -self.config.discrete_action_max_length - 1 : -1]
+ discrete_token_start = -self.config.discrete_action_max_length
+ # The last token of response will predict the first token of discrete actions , so we need to slice from discrete_token_start -1.
+ # The predicted last token of discrete action is useless, so no need to include for loss calculation.
+ discrete_action_slice_object = slice(discrete_token_start - 1, -1)
+ discrete_action_out = prefix_out[:, discrete_action_slice_object]
  logits = self.paligemma_with_expert.da_head(discrete_action_out)

  logits = logits.to(dtype=torch.float32) # upcast to float32 for loss calculation
  logits = rearrange(logits, "b s d -> (b s) d")
  labels = rearrange(discrete_actions, "b s -> (b s)")
- ce_loss = F.cross_entropy(logits, labels, reduction="none")
+ discrete_action_ce_loss = F.cross_entropy(logits, labels, reduction="none")

- ce_loss = rearrange(ce_loss, "(b s) -> b s", b=batch_size, s=seq_len)
+ discrete_action_ce_loss = rearrange(discrete_action_ce_loss, "(b s) -> b s", b=batch_size, s=seq_len)

  # remove pad tokens
  discrete_action_is_pad = ~discrete_action_masks # convert into format where value for pad is True
- ce_loss = ce_loss * ~discrete_action_is_pad
+ discrete_action_ce_loss = discrete_action_ce_loss * ~discrete_action_is_pad
+
+ # compute mean
+ discrete_action_ce_loss = discrete_action_ce_loss.mean()
+
+ # compute cross entropy loss for response language
+ batch_size, seq_len = response_tokens.shape
+ response_token_start = -self.config.response_max_length - self.config.discrete_action_max_length
+ # The last token of language will predict <BOS> token of response, so no need to include for loss calculation. Hence slice starts from -self.config.discrete_action_max_length - self.config.response_max_length.
+ # The last token of response predicts first token of discrete actions, so no need to include for loss calculation. Hence slice ends at -self.config.discrete_action_max_length - 1.
+ response_token_end = -self.config.discrete_action_max_length - 1
+ response_slice_object = slice(response_token_start, response_token_end)
+ response_out = prefix_out[
+ :,
+ response_slice_object,
+ ]
+ response_logits = self.paligemma_with_expert.paligemma.lm_head(response_out)
+ # response slice to exclude the <BOS> token from response while calculating loss.
+ response_slice = slice(1, None)
+ response_logits = response_logits.to(dtype=torch.float32) # upcast to float32 for loss calculation
+ response_logits = rearrange(response_logits, "b s d -> (b s) d")
+ response_labels = rearrange(response_tokens[:, response_slice], "b s -> (b s)")
+ response_ce_loss = F.cross_entropy(response_logits, response_labels, reduction="none")
+
+ response_ce_loss = rearrange(response_ce_loss, "(b s) -> b s", b=batch_size, s=seq_len - 1)
+
+ # remove pad tokens
+ response_is_pad = ~response_masks # convert into format where value for pad is True
+ # helps to control loss for response tokens in case of robotic data and VQA data
+ response_ce_loss = response_ce_loss * ~response_is_pad[:, response_slice]

  # compute mean
- ce_loss = ce_loss.mean()
+ response_ce_loss = response_ce_loss.mean()

- return {"MSE": losses, "CE": ce_loss}
+ return {
+ "MSE": losses,
+ "CE": (discrete_action_ce_loss + response_ce_loss),
+ }

  def sample_actions(
  self,
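The loss section above relies on next-token alignment: the hidden state at position i predicts token i + 1, so discrete-action logits come from slice(-L - 1, -1) and response logits from slice(-R - L, -L - 1), where R is response_max_length and L is discrete_action_max_length. A small sketch checking the slice arithmetic with hypothetical lengths:

response_max_length = 48         # hypothetical R
discrete_action_max_length = 16  # hypothetical L
seq_len = 300                    # hypothetical total prefix length

discrete_action_slice = slice(-discrete_action_max_length - 1, -1)
response_slice_out = slice(
    -response_max_length - discrete_action_max_length,
    -discrete_action_max_length - 1,
)

idx = list(range(seq_len))
print(len(idx[discrete_action_slice]))  # 16 positions -> 16 discrete-action targets
print(len(idx[response_slice_out]))     # 47 positions -> response tokens 1..47 (BOS target excluded)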
@@ -1176,19 +1298,48 @@ class PI05FlowMatching(nn.Module):
  prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
  prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1

+ prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None] - 1
+
  num_cross_att_tokens = prefix_embs.shape[1]

  # Compute image and language key value cache
- _, past_key_values = self.paligemma_with_expert.forward(
+ (prefix_out, _), past_key_values = self.paligemma_with_expert.forward(
  attention_mask=prefix_att_2d_masks,
  position_ids=prefix_position_ids,
  past_key_values=None,
  inputs_embeds=[prefix_embs, None],
  n_cross_att_tokens=num_cross_att_tokens,
- use_cache=self.config.use_cache,
+ use_cache=False,
  fill_kv_cache=True,
  )

+ # initialize response tokens to empty tensor for storing response tokens during inference
+ response_tokens = torch.empty((bsize, 0), device=device, dtype=torch.long)
+ # if response prediction is enabled, then predict response tokens autoregressively
+ if self.config.predict_response:
+ for auto_step in range(self.config.response_max_length):
+ (
+ prefix_out,
+ prefix_embs,
+ prefix_pad_masks,
+ prefix_att_masks,
+ prefix_offsets,
+ response_tokens,
+ past_key_values,
+ ) = self.infer_response(
+ prefix_out,
+ prefix_embs,
+ prefix_pad_masks,
+ prefix_att_masks,
+ past_key_values,
+ prefix_offsets,
+ response_tokens,
+ auto_step,
+ bsize,
+ device,
+ )
+
+ # perform denoising steps to get the action
  dt = -1.0 / self.config.num_steps
  dt = torch.tensor(dt, dtype=torch.float32, device=device)

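At inference the response is decoded greedily, one token per infer_response call, before the flow-matching denoising loop runs. A generic sketch of that decoding pattern with hypothetical stand-ins for the lm head and hidden states (not the package's classes):

import torch

vocab_size, hidden, bsize = 1000, 64, 2        # hypothetical sizes
lm_head = torch.nn.Linear(hidden, vocab_size)  # stand-in for the language-model head
bos_token_id, eos_token_id, max_len = 1, 2, 8

tokens = torch.full((bsize, 1), bos_token_id, dtype=torch.long)
last_hidden = torch.randn(bsize, 1, hidden)    # would come from the VLM forward pass

for _ in range(max_len):
    next_token = lm_head(last_hidden[:, -1:]).argmax(dim=-1)  # greedy pick
    tokens = torch.cat([tokens, next_token], dim=1)
    if (tokens == eos_token_id).any(dim=1).all():             # every sequence hit <eos>
        break
    last_hidden = torch.randn(bsize, 1, hidden)               # placeholder for the next forward pass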
@@ -1235,8 +1386,7 @@ class PI05FlowMatching(nn.Module):
  n_cross_att_tokens=num_cross_att_tokens,
  cross_att_pad_masks=prefix_pad_masks[:, :num_cross_att_tokens],
  )
- # We should skip the response tokens when numbering the position ids for the action expert
- prefix_offsets = torch.sum(prefix_pad_masks[:, : -self.config.discrete_action_max_length], dim=-1)[
+ prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[
  :, None
  ] # action expert position ids start after prefix
  action_expert_position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
@@ -1255,3 +1405,119 @@ class PI05FlowMatching(nn.Module):
  v_t = self.action_out_proj(suffix_out)
  v_t = v_t.to(dtype=torch.float32)
  return v_t
+
+ def infer_response(
+ self,
+ prefix_out: Tensor,
+ prefix_embs: Tensor,
+ prefix_pad_masks: Tensor,
+ prefix_att_masks: Tensor,
+ past_key_values: list[dict[str, Tensor]],
+ prefix_offsets: Tensor,
+ response_tokens: Tensor,
+ auto_step: int,
+ bsize: int,
+ device: torch.device,
+ ) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, list[dict[str, Tensor]], Tensor]:
+ """Perform autoregressive inference for response generation.
+
+ This method generates the next token in the response sequence, updating the
+ various state tensors required for maintaining the generation context. It handles
+ the initial BOS token generation as well as subsequent tokens, and manages
+ padding masks to handle variable-length sequences properly.
+
+ Args:
+ prefix_out: Output tensor from the previous step.
+ prefix_embs: Embeddings for the current prefix context.
+ prefix_pad_masks: Boolean mask indicating valid (non-padding) tokens in the prefix.
+ prefix_att_masks: Attention mask for the prefix.
+ past_key_values: KV cache from previous transformer steps.
+ prefix_offsets: Position offsets for the current generation step.
+ response_tokens: Accumulated tokens generated so far.
+ auto_step: Current autoregressive step index (0 for first token).
+ bsize: Batch size.
+ device: Device to run the computation on.
+
+ Returns:
+ A tuple containing updated tensors for the next step:
+ (prefix_out, prefix_embs, prefix_pad_masks, prefix_att_masks,
+ prefix_offsets, response_tokens, past_key_values, response_token)
+ """
+ EOS_TOKEN = self.language_tokenizer.convert_tokens_to_ids(self.language_tokenizer.eos_token) # noqa: N806
+ if auto_step == 0:
+ # Start the autoregressive inference with <bos> token
+ response_token = torch.full(
+ (bsize, 1),
+ self.language_tokenizer.bos_token_id,
+ device=device,
+ dtype=torch.long,
+ )
+ else:
+ # get the last predicted token from the prefix output which is predicted response
+ response_token = prefix_out[:, -1:]
+ response_token = self.paligemma_with_expert.paligemma.lm_head(response_token).argmax(dim=-1)
+
+ PAD_TOKEN = self.language_tokenizer.pad_token_id # noqa: N806
+ # Create pad masks: False if previous token was EOS or PAD
+ if response_tokens.shape[1] > 1:
+ prev_tokens = response_tokens
+ has_eos = (prev_tokens == EOS_TOKEN).any(dim=1, keepdim=True)
+ has_pad = (prev_tokens == PAD_TOKEN).any(dim=1, keepdim=True)
+ # check if the previous token was EOS or PAD. If so, then the current token should be padded, so its not attended by flow matching action expert.
+ response_pad_masks = ~(has_eos | has_pad)
+ response_token = torch.where(
+ response_pad_masks,
+ response_token,
+ torch.tensor(PAD_TOKEN, device=device, dtype=response_token.dtype),
+ )
+ else:
+ response_pad_masks = torch.ones((bsize, 1), device=device, dtype=torch.bool)
+
+ # Updating response tokens with current predicted token
+ response_tokens = torch.cat([response_tokens, response_token], dim=1)
+
+ # Embed the current predicted token
+ response_emb = self.paligemma_with_expert.embed_language_tokens(response_token)
+
+ # Normalize response language embeddings
+ response_emb_dim = response_emb.shape[-1]
+ response_emb = response_emb * math.sqrt(response_emb_dim)
+
+ response_att_masks = torch.ones((bsize, 1), device=device, dtype=response_emb.dtype)
+
+ # update the prefix embs, pad masks and att masks, so it can be used by action experts
+ prefix_embs = torch.cat([prefix_embs, response_emb], dim=1)
+ prefix_pad_masks = torch.cat([prefix_pad_masks, response_pad_masks], dim=1)
+ prefix_att_masks = torch.cat([prefix_att_masks, response_att_masks], dim=1)
+
+ num_cross_att_tokens = prefix_pad_masks.shape[1]
+ # create the attention mask for the response tokens
+ response_att_2d_masks = make_att_2d_masks(
+ response_pad_masks,
+ response_att_masks,
+ n_cross_att_tokens=num_cross_att_tokens - 1,
+ cross_att_pad_masks=prefix_pad_masks[:, : num_cross_att_tokens - 1],
+ )
+ prefix_offsets = prefix_offsets + response_pad_masks.long()
+ prefix_position_ids = prefix_offsets
+
+ # Compute image and language key value cache
+ (prefix_out, _), past_key_values = self.paligemma_with_expert.forward(
+ attention_mask=response_att_2d_masks,
+ position_ids=prefix_position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=[response_emb, None],
+ n_cross_att_tokens=num_cross_att_tokens,
+ use_cache=True,
+ fill_kv_cache=True,
+ )
+
+ return (
+ prefix_out,
+ prefix_embs,
+ prefix_pad_masks,
+ prefix_att_masks,
+ prefix_offsets,
+ response_tokens,
+ past_key_values,
+ )
@@ -260,6 +260,10 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
  self.paligemma.eval()
  for params in self.paligemma.parameters():
  params.requires_grad = False
+ for param in self.da_head.parameters():
+ param.requires_grad = False
+ for param in self.discrete_action_embedding.parameters():
+ param.requires_grad = False

  def train(self, mode: bool = True) -> None:
  """Sets the module in training mode.
@@ -416,27 +420,23 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
  query_states = apply_rope(query_states, position_ids)
  key_states = apply_rope(key_states, position_ids)

- if use_cache and past_key_values is None:
- past_key_values = {}
-
  if use_cache:
- if fill_kv_cache:
- if n_cross_att_tokens is None:
- raise ValueError("n_cross_att_tokens must be provided when fill_kv_cache is True")
- past_key_values[layer_idx] = {
- # save the first n_cross_att_tokens for action expert cross attention
- "key_states": key_states[:, :n_cross_att_tokens, :, :],
- "value_states": value_states[:, :n_cross_att_tokens, :, :],
- }
- else:
- # TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before.
- # so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach
- # the max len, then we (for instance) double the cache size. This implementation already exists
- # in `transformers`. (molbap)
- key_states = torch.cat([key_states, past_key_values[layer_idx]["key_states"]], dim=1)
- value_states = torch.cat(
- [value_states, past_key_values[layer_idx]["value_states"]], dim=1
- )
+ # TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before.
+ # so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach
+ # the max len, then we (for instance) double the cache size. This implementation already exists
+ # in `transformers`. (molbap)
+ key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1)
+ value_states = torch.cat([past_key_values[layer_idx]["value_states"], value_states], dim=1)
+ if fill_kv_cache:
+ if past_key_values is None:
+ past_key_values = {}
+ if n_cross_att_tokens is None:
+ raise ValueError("n_cross_att_tokens must be provided when fill_kv_cache is True")
+ past_key_values[layer_idx] = {
+ # save the first n_cross_att_tokens for action expert cross attention
+ "key_states": key_states[:, :n_cross_att_tokens, :, :],
+ "value_states": value_states[:, :n_cross_att_tokens, :, :],
+ }

  attention_interface = self.get_attention_interface()
  att_output = attention_interface(
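The cache handling above now prepends the cached prefix keys/values to the newly computed ones (matching the mask layout where cross-attention columns come first) and refreshes the cache only when fill_kv_cache is set. A small sketch of the concatenation order with hypothetical tensors:

import torch

# Hypothetical cached prefix (4 tokens) and newly computed states (2 tokens), 2 heads, head dim 8.
cached = {"key_states": torch.randn(1, 4, 2, 8), "value_states": torch.randn(1, 4, 2, 8)}
new_k, new_v = torch.randn(1, 2, 2, 8), torch.randn(1, 2, 2, 8)

# Cached entries come first, so attention column j < 4 lines up with the prefix mask columns.
key_states = torch.cat([cached["key_states"], new_k], dim=1)
value_states = torch.cat([cached["value_states"], new_v], dim=1)
print(key_states.shape)  # torch.Size([1, 6, 2, 8])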
@@ -0,0 +1,19 @@
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """gRPC server and client for remote robot policy inference.
+
+ This module provides gRPC-based communication between a robot running ROS 2
+ and a remote server running ML model inference on GPU.
+ """