opentau 0.1.1-py3-none-any.whl → 0.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opentau/configs/default.py +16 -0
- opentau/configs/deployment.py +85 -0
- opentau/configs/train.py +5 -0
- opentau/datasets/factory.py +43 -10
- opentau/datasets/lerobot_dataset.py +19 -19
- opentau/datasets/video_utils.py +11 -6
- opentau/policies/pi05/configuration_pi05.py +9 -6
- opentau/policies/pi05/modeling_pi05.py +296 -30
- opentau/policies/pi05/paligemma_with_expert.py +20 -20
- opentau/scripts/grpc/__init__.py +19 -0
- opentau/scripts/grpc/client.py +601 -0
- opentau/scripts/grpc/robot_inference_pb2.py +61 -0
- opentau/scripts/grpc/robot_inference_pb2_grpc.py +210 -0
- opentau/scripts/grpc/server.py +313 -0
- opentau/scripts/launch.py +12 -4
- opentau/scripts/train.py +94 -17
- opentau/scripts/visualize_dataset.py +141 -38
- opentau/utils/transformers_patch.py +251 -20
- {opentau-0.1.1.dist-info → opentau-0.2.0.dist-info}/METADATA +37 -17
- {opentau-0.1.1.dist-info → opentau-0.2.0.dist-info}/RECORD +24 -21
- {opentau-0.1.1.dist-info → opentau-0.2.0.dist-info}/WHEEL +1 -1
- {opentau-0.1.1.dist-info → opentau-0.2.0.dist-info}/entry_points.txt +1 -0
- opentau/scripts/libero_simulation_parallel.py +0 -356
- opentau/scripts/libero_simulation_sequential.py +0 -122
- opentau/scripts/visualize_dataset_html.py +0 -507
- {opentau-0.1.1.dist-info → opentau-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {opentau-0.1.1.dist-info → opentau-0.2.0.dist-info}/top_level.txt +0 -0
opentau/policies/pi05/modeling_pi05.py

@@ -42,6 +42,7 @@ from opentau.policies.pi05.paligemma_with_expert import (
     PaliGemmaWithExpertModel,
 )
 from opentau.policies.pretrained import PreTrainedPolicy, T
+from opentau.utils.accelerate_utils import get_proc_accelerator
 from opentau.utils.utils import get_safe_dtype
 
 
@@ -151,8 +152,8 @@ def make_att_2d_masks(
 
     # Apply padding masks: pad_masks for rows, cross_att_pad_masks for columns
     cross_att_mask = cross_att_mask & pad_masks[:, :, None] & cross_att_pad_masks[:, None, :]
-
-    att_2d_masks = torch.cat((
+    # The cross_att_masks are concatenated before the att_2d_masks
+    att_2d_masks = torch.cat((cross_att_mask, att_2d_masks), dim=2)
 
     return att_2d_masks
 
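The reworked `make_att_2d_masks` widens the suffix self-attention mask with a cross-attention block along the last dimension. Below is a minimal, self-contained sketch of that shape manipulation; the toy sizes and example values are invented for illustration and are not taken from the package:

```python
# Sketch: widening a (B, S, S) self-attention mask with a (B, S, C) cross-attention
# mask along dim=2, mirroring torch.cat((cross_att_mask, att_2d_masks), dim=2) above.
import torch

B, S, C = 2, 4, 3  # batch, suffix tokens, cross-attended prefix tokens (toy sizes)
att_2d_masks = torch.ones(B, S, S, dtype=torch.bool)           # suffix self-attention
pad_masks = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]]).bool()  # per-row validity
cross_att_pad_masks = torch.ones(B, C, dtype=torch.bool)       # prefix validity

cross_att_mask = torch.ones(B, S, C, dtype=torch.bool)
cross_att_mask = cross_att_mask & pad_masks[:, :, None] & cross_att_pad_masks[:, None, :]

combined = torch.cat((cross_att_mask, att_2d_masks), dim=2)
print(combined.shape)  # torch.Size([2, 4, 7]): each suffix row sees C prefix + S suffix slots
```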
@@ -351,9 +352,12 @@ class PI05Policy(PreTrainedPolicy):
         model = cls(config, **kwargs)
 
         # Now manually load and remap the state dict
+        acc = get_proc_accelerator()
+        is_main_process = acc.is_main_process if acc else True
         try:
             # Try to load the pytorch_model.bin or model.safetensors file
-
+            if is_main_process:
+                print(f"Loading model from: {pretrained_name_or_path}")
             try:
                 from transformers.utils import cached_file
 
@@ -372,10 +376,12 @@ class PI05Policy(PreTrainedPolicy):
                 from safetensors.torch import load_file
 
                 original_state_dict = load_file(resolved_file)
-
+                if is_main_process:
+                    print("✓ Loaded state dict from model.safetensors")
             except Exception as e:
-
-
+                if is_main_process:
+                    print(f"Could not load state dict from remote files: {e}")
+                    print("Returning model without loading pretrained weights")
                 return model
 
             # First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys`
@@ -390,18 +396,18 @@ class PI05Policy(PreTrainedPolicy):
                     new_key = f"model.{key}"
                     remapped_state_dict[new_key] = value
                     remap_count += 1
-                    if remap_count <= 10: # Only print first 10 to avoid spam
+                    if remap_count <= 10 and is_main_process: # Only print first 10 to avoid spam
                         print(f"Remapped: {key} -> {new_key}")
                 else:
                     remapped_state_dict[key] = value
 
-            if remap_count > 0:
+            if remap_count > 0 and is_main_process:
                 print(f"Remapped {remap_count} state dict keys")
 
             # Load the remapped state dict into the model
             missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=False)
 
-            if missing_keys:
+            if missing_keys and is_main_process:
                 print(f"Missing keys when loading state dict: {len(missing_keys)} keys")
                 if len(missing_keys) <= 20:
                     for key in missing_keys:
@@ -411,7 +417,7 @@ class PI05Policy(PreTrainedPolicy):
                         print(f" - {key}")
                     print(f" ... and {len(missing_keys) - 20} more")
 
-            if unexpected_keys:
+            if unexpected_keys and is_main_process:
                 print(f"Unexpected keys when loading state dict: {len(unexpected_keys)} keys")
                 if len(unexpected_keys) <= 20:
                     for key in unexpected_keys:
@@ -421,11 +427,12 @@ class PI05Policy(PreTrainedPolicy):
                         print(f" - {key}")
                     print(f" ... and {len(unexpected_keys) - 20} more")
 
-            if not missing_keys and not unexpected_keys:
+            if not missing_keys and not unexpected_keys and is_main_process:
                 print("All keys loaded successfully!")
 
         except Exception as e:
-
+            if is_main_process:
+                print(f"Warning: Could not remap state dict keys: {e}")
 
         return model
 
@@ -596,6 +603,11 @@ class PI05Policy(PreTrainedPolicy):
         lang_tokens, lang_masks = self.prepare_language(
             batch
         )  # in lang_masks we have True for real tokens and False for padded tokens
+        # response prediction is to predict the response . It will attend to image and language inputs.
+        response_tokens, response_masks = self.prepare_response(
+            batch
+        )  # in response_masks we have True for real tokens and False for padded tokens
+        # discrete actions are to predict actions using autoregressive technique and not flow matching. It will attend to image, language and response inputs.
         discrete_actions, discrete_action_masks = self.prepare_discrete_actions(
             batch
         )  # in discrete_action_masks we have True for real tokens and False for padded tokens
@@ -610,6 +622,8 @@ class PI05Policy(PreTrainedPolicy):
             lang_tokens,
             lang_masks,
             actions,
+            response_tokens,
+            response_masks,
             noise,
             time,
             discrete_actions,
@@ -752,18 +766,25 @@ class PI05Policy(PreTrainedPolicy):
         device = batch["state"].device
         tasks = batch["prompt"]
 
-        # PaliGemma prompt has to end with a new line
-        tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
-
         # add state to the prompt
         state = self.prepare_discrete_state(batch)
-
+        # using <eos> to separate each modality
+        if self.config.predict_response:
+            prompt = [
+                f"Task: {task}<eos>State: {state}<eos>Response:"
+                for task, state in zip(tasks, state, strict=False)
+            ]
+        else:
+            prompt = [
+                f"Task: {task}<eos>State: {state}<eos>Actions:"
+                for task, state in zip(tasks, state, strict=False)
+            ]
 
         tokenized_prompt = self.language_tokenizer.__call__(
             prompt,
             padding="max_length",
             padding_side="right",
-            max_length=self.config.
+            max_length=self.config.prompt_max_length,
             return_tensors="pt",
             truncation=True,
         )
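With the trailing-newline requirement dropped, prompts are now assembled with `<eos>` separators between task, state, and a `Response:` or `Actions:` suffix. A toy illustration of the resulting strings; the task and state values are invented:

```python
# Sketch of the <eos>-separated prompt layout built above (example values invented).
tasks = ["pick up the red block"]
states = ["12 87 3 255"]  # discretized state rendered as text (assumed format)
predict_response = True

suffix = "Response:" if predict_response else "Actions:"
prompts = [f"Task: {t}<eos>State: {s}<eos>{suffix}" for t, s in zip(tasks, states)]
print(prompts[0])
# Task: pick up the red block<eos>State: 12 87 3 255<eos>Response:
```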
@@ -772,6 +793,39 @@ class PI05Policy(PreTrainedPolicy):
 
         return lang_tokens, lang_masks
 
+    def prepare_response(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
+        """Tokenize the response input.
+
+        Args:
+            batch: Batch of data containing the key "response".
+
+        Returns:
+            A tuple containing:
+            - response_tokens: Tensor of response language tokens.
+            - response_masks: Tensor of response language attention masks.
+        """
+
+        if not self.config.predict_response:
+            return None, None
+        device = batch["state"].device
+        responses = batch["response"]
+
+        # if '' is found in response then response is not for loss calculation (used for robotic dataset with no subtask), so add pad token to the response.
+        response_prompt = [f"{response}<eos>Actions:" for response in responses]
+
+        tokenized_response = self.language_tokenizer.__call__(
+            response_prompt,
+            padding="max_length",
+            padding_side="right",
+            max_length=self.config.response_max_length,
+            return_tensors="pt",
+            truncation=True,
+        )
+        response_tokens = tokenized_response["input_ids"].to(device=device)
+        response_masks = tokenized_response["attention_mask"].to(device=device, dtype=torch.bool)
+
+        return response_tokens, response_masks
+
 
 class PI05FlowMatching(nn.Module):
     """
@@ -828,6 +882,8 @@ class PI05FlowMatching(nn.Module):
         self.time_mlp_in = nn.Linear(self.config.proj_width, self.config.proj_width)
         self.time_mlp_out = nn.Linear(self.config.proj_width, self.config.proj_width)
 
+        self.language_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
+
         self._init_model()
 
     def _init_weights(self, module: nn.Module) -> None:
@@ -897,6 +953,8 @@ class PI05FlowMatching(nn.Module):
         img_masks: list[Tensor],
         lang_tokens: Tensor,
         lang_masks: Tensor,
+        response_tokens: Tensor | None = None,
+        response_masks: Tensor | None = None,
         discrete_actions: Tensor | None = None,
         discrete_action_masks: Tensor | None = None,
     ) -> tuple[Tensor, Tensor, Tensor]:
@@ -908,6 +966,8 @@ class PI05FlowMatching(nn.Module):
             img_masks: List of image mask tensors.
             lang_tokens: Language token tensor.
             lang_masks: Language mask tensor.
+            response_tokens: Optional Response language token tensor.
+            response_masks: Optional Response language mask tensor.
             discrete_actions: Optional discrete action tensor.
             discrete_action_masks: Optional discrete action mask tensor.
 
@@ -956,6 +1016,20 @@ class PI05FlowMatching(nn.Module):
         num_lang_embs = lang_emb.shape[1]
         att_masks += [0] * num_lang_embs
 
+        if response_tokens is not None:
+            response_emb = self.paligemma_with_expert.embed_language_tokens(response_tokens)
+
+            # Normalize response language embeddings
+            response_emb_dim = response_emb.shape[-1]
+            response_emb = response_emb * math.sqrt(response_emb_dim)
+
+            embs.append(response_emb)
+            pad_masks.append(response_masks)
+
+            # full attention between image, language and response inputs
+            num_response_embs = response_emb.shape[1]
+            att_masks += [1] * num_response_embs
+
         if discrete_actions is not None:
             discrete_action_emb = self.paligemma_with_expert.embed_discrete_actions(discrete_actions)
             embs.append(discrete_action_emb.to(dtype=torch.bfloat16))
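The response embeddings are scaled by the square root of the hidden size before being appended to the prefix, matching how the other language embeddings are normalized. A self-contained sketch of that scaling with toy dimensions:

```python
# Sketch: sqrt(hidden_dim) scaling of token embeddings, as applied to response_emb above.
import math
import torch
import torch.nn as nn

vocab_size, hidden_dim = 1000, 64        # toy sizes
embed = nn.Embedding(vocab_size, hidden_dim)

token_ids = torch.randint(0, vocab_size, (2, 5))
emb = embed(token_ids)
emb = emb * math.sqrt(emb.shape[-1])     # scale by sqrt(hidden_dim)
print(emb.shape)                         # torch.Size([2, 5, 64])
```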
@@ -1033,6 +1107,8 @@ class PI05FlowMatching(nn.Module):
         lang_tokens: Tensor,
         lang_masks: Tensor,
         actions: Tensor,
+        response_tokens: Tensor | None = None,
+        response_masks: Tensor | None = None,
         noise: Tensor | None = None,
         time: Tensor | None = None,
         discrete_actions: Tensor | None = None,
@@ -1045,6 +1121,8 @@ class PI05FlowMatching(nn.Module):
             img_masks: List of image mask tensors.
             lang_tokens: Language token tensor.
             lang_masks: Language mask tensor.
+            response_tokens: Response language token tensor.
+            response_masks: Response language mask tensor.
             actions: Action tensor.
             noise: Optional noise tensor.
             time: Optional time tensor.
@@ -1056,12 +1134,20 @@ class PI05FlowMatching(nn.Module):
         """
         # Run VLM first to get key value cache
         prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
-            images,
+            images,
+            img_masks,
+            lang_tokens,
+            lang_masks,
+            response_tokens,
+            response_masks,
+            discrete_actions,
+            discrete_action_masks,
         )
 
         vlm_2d_attention_mask = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
         vlm_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
 
+        # avoids using discrete action for predicting continuous flow matching action
         num_cross_att_tokens = prefix_embs.shape[1] - self.config.discrete_action_max_length
 
         (prefix_out, _), past_key_values = self.paligemma_with_expert.forward(
@@ -1070,7 +1156,7 @@ class PI05FlowMatching(nn.Module):
             past_key_values=None,
             inputs_embeds=[prefix_embs, None],
             n_cross_att_tokens=num_cross_att_tokens,
-            use_cache=
+            use_cache=False,
             fill_kv_cache=True,
         )
 
@@ -1093,7 +1179,7 @@ class PI05FlowMatching(nn.Module):
             n_cross_att_tokens=num_cross_att_tokens,
             cross_att_pad_masks=prefix_pad_masks[:, :num_cross_att_tokens],
         )
-        # We should skip the
+        # We should skip the discrete action tokens when numbering the position ids for the action expert
         prefix_offsets = torch.sum(prefix_pad_masks[:, : -self.config.discrete_action_max_length], dim=-1)[
             :, None
         ] # action expert position ids start after prefix
@@ -1124,24 +1210,60 @@ class PI05FlowMatching(nn.Module):
 
         # compute cross entropy loss for discrete actions
         batch_size, seq_len = discrete_actions.shape
-
+        discrete_token_start = -self.config.discrete_action_max_length
+        # The last token of response will predict the first token of discrete actions , so we need to slice from discrete_token_start -1.
+        # The predicted last token of discrete action is useless, so no need to include for loss calculation.
+        discrete_action_slice_object = slice(discrete_token_start - 1, -1)
+        discrete_action_out = prefix_out[:, discrete_action_slice_object]
         logits = self.paligemma_with_expert.da_head(discrete_action_out)
 
         logits = logits.to(dtype=torch.float32)  # upcast to float32 for loss calculation
         logits = rearrange(logits, "b s d -> (b s) d")
         labels = rearrange(discrete_actions, "b s -> (b s)")
-
+        discrete_action_ce_loss = F.cross_entropy(logits, labels, reduction="none")
 
-
+        discrete_action_ce_loss = rearrange(discrete_action_ce_loss, "(b s) -> b s", b=batch_size, s=seq_len)
 
         # remove pad tokens
         discrete_action_is_pad = ~discrete_action_masks  # convert into format where value for pad is True
-
+        discrete_action_ce_loss = discrete_action_ce_loss * ~discrete_action_is_pad
+
+        # compute mean
+        discrete_action_ce_loss = discrete_action_ce_loss.mean()
+
+        # compute cross entropy loss for response language
+        batch_size, seq_len = response_tokens.shape
+        response_token_start = -self.config.response_max_length - self.config.discrete_action_max_length
+        # The last token of language will predict <BOS> token of response, so no need to include for loss calculation. Hence slice starts from -self.config.discrete_action_max_length - self.config.response_max_length.
+        # The last token of response predicts first token of discrete actions, so no need to include for loss calculation. Hence slice ends at -self.config.discrete_action_max_length - 1.
+        response_token_end = -self.config.discrete_action_max_length - 1
+        response_slice_object = slice(response_token_start, response_token_end)
+        response_out = prefix_out[
+            :,
+            response_slice_object,
+        ]
+        response_logits = self.paligemma_with_expert.paligemma.lm_head(response_out)
+        # response slice to exclude the <BOS> token from response while calculating loss.
+        response_slice = slice(1, None)
+        response_logits = response_logits.to(dtype=torch.float32)  # upcast to float32 for loss calculation
+        response_logits = rearrange(response_logits, "b s d -> (b s) d")
+        response_labels = rearrange(response_tokens[:, response_slice], "b s -> (b s)")
+        response_ce_loss = F.cross_entropy(response_logits, response_labels, reduction="none")
+
+        response_ce_loss = rearrange(response_ce_loss, "(b s) -> b s", b=batch_size, s=seq_len - 1)
+
+        # remove pad tokens
+        response_is_pad = ~response_masks  # convert into format where value for pad is True
+        # helps to control loss for response tokens in case of robotic data and VQA data
+        response_ce_loss = response_ce_loss * ~response_is_pad[:, response_slice]
 
         # compute mean
-
+        response_ce_loss = response_ce_loss.mean()
 
-        return {
+        return {
+            "MSE": losses,
+            "CE": (discrete_action_ce_loss + response_ce_loss),
+        }
 
     def sample_actions(
         self,
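The new loss terms slice the prefix output so that each position predicts the following token, zero out padded targets, and then average. A model-free sketch of that shifted, pad-masked cross-entropy; shapes, masks, and logits are toy values:

```python
# Sketch: next-token cross-entropy with a one-position shift and pad masking.
import torch
import torch.nn.functional as F
from einops import rearrange

B, S, V = 2, 6, 11                      # batch, sequence length, vocab size (toy)
hidden_logits = torch.randn(B, S, V)    # stand-in for lm_head(prefix_out slice)
tokens = torch.randint(0, V, (B, S))
masks = torch.tensor([[1, 1, 1, 1, 0, 0], [1, 1, 1, 0, 0, 0]]).bool()

logits = rearrange(hidden_logits[:, :-1], "b s d -> (b s) d")  # position i predicts i+1
labels = rearrange(tokens[:, 1:], "b s -> (b s)")              # drop the first (BOS-like) token
ce = F.cross_entropy(logits, labels, reduction="none")
ce = rearrange(ce, "(b s) -> b s", b=B, s=S - 1)
ce = ce * masks[:, 1:]                  # zero out padded targets
loss = ce.mean()
```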
@@ -1176,19 +1298,48 @@ class PI05FlowMatching(nn.Module):
         prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
         prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
 
+        prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None] - 1
+
         num_cross_att_tokens = prefix_embs.shape[1]
 
         # Compute image and language key value cache
-        _, past_key_values = self.paligemma_with_expert.forward(
+        (prefix_out, _), past_key_values = self.paligemma_with_expert.forward(
             attention_mask=prefix_att_2d_masks,
             position_ids=prefix_position_ids,
             past_key_values=None,
             inputs_embeds=[prefix_embs, None],
             n_cross_att_tokens=num_cross_att_tokens,
-            use_cache=
+            use_cache=False,
             fill_kv_cache=True,
         )
 
+        # initialize response tokens to empty tensor for storing response tokens during inference
+        response_tokens = torch.empty((bsize, 0), device=device, dtype=torch.long)
+        # if response prediction is enabled, then predict response tokens autoregressively
+        if self.config.predict_response:
+            for auto_step in range(self.config.response_max_length):
+                (
+                    prefix_out,
+                    prefix_embs,
+                    prefix_pad_masks,
+                    prefix_att_masks,
+                    prefix_offsets,
+                    response_tokens,
+                    past_key_values,
+                ) = self.infer_response(
+                    prefix_out,
+                    prefix_embs,
+                    prefix_pad_masks,
+                    prefix_att_masks,
+                    past_key_values,
+                    prefix_offsets,
+                    response_tokens,
+                    auto_step,
+                    bsize,
+                    device,
+                )
+
+        # perform denoising steps to get the action
         dt = -1.0 / self.config.num_steps
         dt = torch.tensor(dt, dtype=torch.float32, device=device)
 
@@ -1235,8 +1386,7 @@ class PI05FlowMatching(nn.Module):
             n_cross_att_tokens=num_cross_att_tokens,
             cross_att_pad_masks=prefix_pad_masks[:, :num_cross_att_tokens],
         )
-
-        prefix_offsets = torch.sum(prefix_pad_masks[:, : -self.config.discrete_action_max_length], dim=-1)[
+        prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[
             :, None
         ] # action expert position ids start after prefix
         action_expert_position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
@@ -1255,3 +1405,119 @@ class PI05FlowMatching(nn.Module):
         v_t = self.action_out_proj(suffix_out)
         v_t = v_t.to(dtype=torch.float32)
         return v_t
+
+    def infer_response(
+        self,
+        prefix_out: Tensor,
+        prefix_embs: Tensor,
+        prefix_pad_masks: Tensor,
+        prefix_att_masks: Tensor,
+        past_key_values: list[dict[str, Tensor]],
+        prefix_offsets: Tensor,
+        response_tokens: Tensor,
+        auto_step: int,
+        bsize: int,
+        device: torch.device,
+    ) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, list[dict[str, Tensor]], Tensor]:
+        """Perform autoregressive inference for response generation.
+
+        This method generates the next token in the response sequence, updating the
+        various state tensors required for maintaining the generation context. It handles
+        the initial BOS token generation as well as subsequent tokens, and manages
+        padding masks to handle variable-length sequences properly.
+
+        Args:
+            prefix_out: Output tensor from the previous step.
+            prefix_embs: Embeddings for the current prefix context.
+            prefix_pad_masks: Boolean mask indicating valid (non-padding) tokens in the prefix.
+            prefix_att_masks: Attention mask for the prefix.
+            past_key_values: KV cache from previous transformer steps.
+            prefix_offsets: Position offsets for the current generation step.
+            response_tokens: Accumulated tokens generated so far.
+            auto_step: Current autoregressive step index (0 for first token).
+            bsize: Batch size.
+            device: Device to run the computation on.
+
+        Returns:
+            A tuple containing updated tensors for the next step:
+            (prefix_out, prefix_embs, prefix_pad_masks, prefix_att_masks,
+            prefix_offsets, response_tokens, past_key_values, response_token)
+        """
+        EOS_TOKEN = self.language_tokenizer.convert_tokens_to_ids(self.language_tokenizer.eos_token)  # noqa: N806
+        if auto_step == 0:
+            # Start the autoregressive inference with <bos> token
+            response_token = torch.full(
+                (bsize, 1),
+                self.language_tokenizer.bos_token_id,
+                device=device,
+                dtype=torch.long,
+            )
+        else:
+            # get the last predicted token from the prefix output which is predicted response
+            response_token = prefix_out[:, -1:]
+            response_token = self.paligemma_with_expert.paligemma.lm_head(response_token).argmax(dim=-1)
+
+        PAD_TOKEN = self.language_tokenizer.pad_token_id  # noqa: N806
+        # Create pad masks: False if previous token was EOS or PAD
+        if response_tokens.shape[1] > 1:
+            prev_tokens = response_tokens
+            has_eos = (prev_tokens == EOS_TOKEN).any(dim=1, keepdim=True)
+            has_pad = (prev_tokens == PAD_TOKEN).any(dim=1, keepdim=True)
+            # check if the previous token was EOS or PAD. If so, then the current token should be padded, so its not attended by flow matching action expert.
+            response_pad_masks = ~(has_eos | has_pad)
+            response_token = torch.where(
+                response_pad_masks,
+                response_token,
+                torch.tensor(PAD_TOKEN, device=device, dtype=response_token.dtype),
+            )
+        else:
+            response_pad_masks = torch.ones((bsize, 1), device=device, dtype=torch.bool)
+
+        # Updating response tokens with current predicted token
+        response_tokens = torch.cat([response_tokens, response_token], dim=1)
+
+        # Embed the current predicted token
+        response_emb = self.paligemma_with_expert.embed_language_tokens(response_token)
+
+        # Normalize response language embeddings
+        response_emb_dim = response_emb.shape[-1]
+        response_emb = response_emb * math.sqrt(response_emb_dim)
+
+        response_att_masks = torch.ones((bsize, 1), device=device, dtype=response_emb.dtype)
+
+        # update the prefix embs, pad masks and att masks, so it can be used by action experts
+        prefix_embs = torch.cat([prefix_embs, response_emb], dim=1)
+        prefix_pad_masks = torch.cat([prefix_pad_masks, response_pad_masks], dim=1)
+        prefix_att_masks = torch.cat([prefix_att_masks, response_att_masks], dim=1)
+
+        num_cross_att_tokens = prefix_pad_masks.shape[1]
+        # create the attention mask for the response tokens
+        response_att_2d_masks = make_att_2d_masks(
+            response_pad_masks,
+            response_att_masks,
+            n_cross_att_tokens=num_cross_att_tokens - 1,
+            cross_att_pad_masks=prefix_pad_masks[:, : num_cross_att_tokens - 1],
+        )
+        prefix_offsets = prefix_offsets + response_pad_masks.long()
+        prefix_position_ids = prefix_offsets
+
+        # Compute image and language key value cache
+        (prefix_out, _), past_key_values = self.paligemma_with_expert.forward(
+            attention_mask=response_att_2d_masks,
+            position_ids=prefix_position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=[response_emb, None],
+            n_cross_att_tokens=num_cross_att_tokens,
+            use_cache=True,
+            fill_kv_cache=True,
+        )
+
+        return (
+            prefix_out,
+            prefix_embs,
+            prefix_pad_masks,
+            prefix_att_masks,
+            prefix_offsets,
+            response_tokens,
+            past_key_values,
+        )
opentau/policies/pi05/paligemma_with_expert.py

@@ -260,6 +260,10 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
         self.paligemma.eval()
         for params in self.paligemma.parameters():
             params.requires_grad = False
+        for param in self.da_head.parameters():
+            param.requires_grad = False
+        for param in self.discrete_action_embedding.parameters():
+            param.requires_grad = False
 
     def train(self, mode: bool = True) -> None:
         """Sets the module in training mode.
@@ -416,27 +420,23 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
             query_states = apply_rope(query_states, position_ids)
             key_states = apply_rope(key_states, position_ids)
 
-            if use_cache and past_key_values is None:
-                past_key_values = {}
-
             if use_cache:
-
-
-
-
-
-
-
-
-
-
-
-
-                #
-                key_states
-                value_states
-
-                )
+                # TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before.
+                # so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach
+                # the max len, then we (for instance) double the cache size. This implementation already exists
+                # in `transformers`. (molbap)
+                key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1)
+                value_states = torch.cat([past_key_values[layer_idx]["value_states"], value_states], dim=1)
+            if fill_kv_cache:
+                if past_key_values is None:
+                    past_key_values = {}
+                if n_cross_att_tokens is None:
+                    raise ValueError("n_cross_att_tokens must be provided when fill_kv_cache is True")
+                past_key_values[layer_idx] = {
+                    # save the first n_cross_att_tokens for action expert cross attention
+                    "key_states": key_states[:, :n_cross_att_tokens, :, :],
+                    "value_states": value_states[:, :n_cross_att_tokens, :, :],
+                }
 
             attention_interface = self.get_attention_interface()
             att_output = attention_interface(
opentau/scripts/grpc/__init__.py (new file)

@@ -0,0 +1,19 @@
+# Copyright 2026 Tensor Auto Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""gRPC server and client for remote robot policy inference.
+
+This module provides gRPC-based communication between a robot running ROS 2
+and a remote server running ML model inference on GPU.
+"""