opentau 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. opentau/__init__.py +179 -0
  2. opentau/__version__.py +24 -0
  3. opentau/configs/__init__.py +19 -0
  4. opentau/configs/default.py +297 -0
  5. opentau/configs/libero.py +113 -0
  6. opentau/configs/parser.py +393 -0
  7. opentau/configs/policies.py +297 -0
  8. opentau/configs/reward.py +42 -0
  9. opentau/configs/train.py +370 -0
  10. opentau/configs/types.py +76 -0
  11. opentau/constants.py +52 -0
  12. opentau/datasets/__init__.py +84 -0
  13. opentau/datasets/backward_compatibility.py +78 -0
  14. opentau/datasets/compute_stats.py +333 -0
  15. opentau/datasets/dataset_mixture.py +460 -0
  16. opentau/datasets/factory.py +232 -0
  17. opentau/datasets/grounding/__init__.py +67 -0
  18. opentau/datasets/grounding/base.py +154 -0
  19. opentau/datasets/grounding/clevr.py +110 -0
  20. opentau/datasets/grounding/cocoqa.py +130 -0
  21. opentau/datasets/grounding/dummy.py +101 -0
  22. opentau/datasets/grounding/pixmo.py +177 -0
  23. opentau/datasets/grounding/vsr.py +141 -0
  24. opentau/datasets/image_writer.py +304 -0
  25. opentau/datasets/lerobot_dataset.py +1910 -0
  26. opentau/datasets/online_buffer.py +442 -0
  27. opentau/datasets/push_dataset_to_hub/utils.py +132 -0
  28. opentau/datasets/sampler.py +99 -0
  29. opentau/datasets/standard_data_format_mapping.py +278 -0
  30. opentau/datasets/transforms.py +330 -0
  31. opentau/datasets/utils.py +1243 -0
  32. opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +887 -0
  33. opentau/datasets/v2/convert_dataset_v1_to_v2.py +829 -0
  34. opentau/datasets/v21/_remove_language_instruction.py +109 -0
  35. opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +60 -0
  36. opentau/datasets/v21/convert_dataset_v20_to_v21.py +183 -0
  37. opentau/datasets/v21/convert_stats.py +150 -0
  38. opentau/datasets/video_utils.py +597 -0
  39. opentau/envs/__init__.py +18 -0
  40. opentau/envs/configs.py +178 -0
  41. opentau/envs/factory.py +99 -0
  42. opentau/envs/libero.py +439 -0
  43. opentau/envs/utils.py +204 -0
  44. opentau/optim/__init__.py +16 -0
  45. opentau/optim/factory.py +43 -0
  46. opentau/optim/optimizers.py +121 -0
  47. opentau/optim/schedulers.py +140 -0
  48. opentau/planner/__init__.py +82 -0
  49. opentau/planner/high_level_planner.py +366 -0
  50. opentau/planner/utils/memory.py +64 -0
  51. opentau/planner/utils/utils.py +65 -0
  52. opentau/policies/__init__.py +24 -0
  53. opentau/policies/factory.py +172 -0
  54. opentau/policies/normalize.py +315 -0
  55. opentau/policies/pi0/__init__.py +19 -0
  56. opentau/policies/pi0/configuration_pi0.py +250 -0
  57. opentau/policies/pi0/modeling_pi0.py +994 -0
  58. opentau/policies/pi0/paligemma_with_expert.py +516 -0
  59. opentau/policies/pi05/__init__.py +20 -0
  60. opentau/policies/pi05/configuration_pi05.py +231 -0
  61. opentau/policies/pi05/modeling_pi05.py +1257 -0
  62. opentau/policies/pi05/paligemma_with_expert.py +572 -0
  63. opentau/policies/pretrained.py +315 -0
  64. opentau/policies/utils.py +123 -0
  65. opentau/policies/value/__init__.py +18 -0
  66. opentau/policies/value/configuration_value.py +170 -0
  67. opentau/policies/value/modeling_value.py +512 -0
  68. opentau/policies/value/reward.py +87 -0
  69. opentau/policies/value/siglip_gemma.py +221 -0
  70. opentau/scripts/actions_mse_loss.py +89 -0
  71. opentau/scripts/bin_to_safetensors.py +116 -0
  72. opentau/scripts/compute_max_token_length.py +111 -0
  73. opentau/scripts/display_sys_info.py +90 -0
  74. opentau/scripts/download_libero_benchmarks.py +54 -0
  75. opentau/scripts/eval.py +877 -0
  76. opentau/scripts/export_to_onnx.py +180 -0
  77. opentau/scripts/fake_tensor_training.py +87 -0
  78. opentau/scripts/get_advantage_and_percentiles.py +220 -0
  79. opentau/scripts/high_level_planner_inference.py +114 -0
  80. opentau/scripts/inference.py +70 -0
  81. opentau/scripts/launch_train.py +63 -0
  82. opentau/scripts/libero_simulation_parallel.py +356 -0
  83. opentau/scripts/libero_simulation_sequential.py +122 -0
  84. opentau/scripts/nav_high_level_planner_inference.py +61 -0
  85. opentau/scripts/train.py +379 -0
  86. opentau/scripts/visualize_dataset.py +294 -0
  87. opentau/scripts/visualize_dataset_html.py +507 -0
  88. opentau/scripts/zero_to_fp32.py +760 -0
  89. opentau/utils/__init__.py +20 -0
  90. opentau/utils/accelerate_utils.py +79 -0
  91. opentau/utils/benchmark.py +98 -0
  92. opentau/utils/fake_tensor.py +81 -0
  93. opentau/utils/hub.py +209 -0
  94. opentau/utils/import_utils.py +79 -0
  95. opentau/utils/io_utils.py +137 -0
  96. opentau/utils/libero.py +214 -0
  97. opentau/utils/libero_dataset_recorder.py +460 -0
  98. opentau/utils/logging_utils.py +180 -0
  99. opentau/utils/monkey_patch.py +278 -0
  100. opentau/utils/random_utils.py +244 -0
  101. opentau/utils/train_utils.py +198 -0
  102. opentau/utils/utils.py +471 -0
  103. opentau-0.1.0.dist-info/METADATA +161 -0
  104. opentau-0.1.0.dist-info/RECORD +108 -0
  105. opentau-0.1.0.dist-info/WHEEL +5 -0
  106. opentau-0.1.0.dist-info/entry_points.txt +2 -0
  107. opentau-0.1.0.dist-info/licenses/LICENSE +508 -0
  108. opentau-0.1.0.dist-info/top_level.txt +1 -0
opentau/scripts/export_to_onnx.py
@@ -0,0 +1,180 @@
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import logging
+ from pathlib import Path
+
+ import torch
+
+ from opentau.configs import parser
+ from opentau.configs.train import TrainPipelineConfig
+ from opentau.policies.factory import get_policy_class
+ from opentau.policies.pi0.modeling_pi0 import PI0Policy
+ from opentau.policies.pi05.modeling_pi05 import PI05Policy
+ from opentau.utils.monkey_patch import (
+     torch_cumsum_patch,
+     torch_full_patch,
+     torch_pow_patch,
+ )
+ from opentau.utils.utils import auto_torch_device
+
+ # Some patches are necessary only for dynamo export, which has current upstream bugs.
+ # Nonetheless, we apply them here to ensure future compatibility.
+ patches = [
+     torch_cumsum_patch,  # This is always necessary to load the ONNX artifact without error.
+     torch_full_patch,
+     torch_pow_patch,
+ ]
+
+ KEY_STATES = "key_states"
+ VALUE_STATES = "value_states"
+
+
+ class InferenceWrapper(torch.nn.Module):
+     r"""Helper class to wrap the robot action decoder for ONNX export,
+     such that each input tensor is an individual argument to the `forward` method.
+     """
+
+     def __init__(
+         self,
+         decoder: PI0Policy | PI05Policy,
+         *,
+         prefix_pad_masks: torch.Tensor,
+         prefix_offsets: torch.Tensor,
+         num_cross_att_tokens: int,
+         layer_idx: int,
+     ):
+         super().__init__()
+         self.decoder = decoder
+         self.prefix_pad_masks = prefix_pad_masks
+         self.prefix_offsets = prefix_offsets
+         self.num_cross_att_tokens = num_cross_att_tokens
+         self.layer_idx = layer_idx
+
+     def forward(self, key_states, value_states, state):
+         vlm_tokens = (
+             {
+                 self.layer_idx: {
+                     KEY_STATES: key_states,
+                     VALUE_STATES: value_states,
+                 },
+             },
+             self.prefix_pad_masks,
+             self.prefix_offsets,
+             self.num_cross_att_tokens,
+         )
+         observation = {
+             "state": state,
+         }
+         actions = self.decoder.sample_actions(
+             observation,
+             vlm_token_cache_override=vlm_tokens,
+         )
+         return actions
+
+
+ # Get the VLM cache for the dummy observation. This guarantees consistency with post-loading usage.
+ def get_vlm_cache(cfg: TrainPipelineConfig, device: torch.device, dtype: torch.dtype):
+     logging.info("Getting VLM cache...")
+     policy_class = get_policy_class(cfg.policy.type)
+     cloud_vlm = policy_class.from_pretrained(cfg.policy.pretrained_path, config=cfg.policy)
+     cloud_vlm.set_execution_target("cloud")
+     cloud_vlm.to(device=device, dtype=torch.bfloat16)
+     cloud_vlm.eval()
+
+     vlm_camera_observation = {
+         f"camera{i}": torch.zeros((1, 3, *cfg.resolution), dtype=torch.bfloat16, device=device)
+         for i in range(cfg.num_cams)
+     }
+     vlm_observation = {
+         **vlm_camera_observation,
+         "prompt": ["Pick up yellow lego block and put it in the bin"],
+         "state": torch.zeros((1, cfg.max_state_dim), dtype=torch.bfloat16, device=device),
+         "img_is_pad": torch.zeros((1, cfg.num_cams), dtype=torch.bool, device=device),
+     }
+     cache, prefix_pad_masks, prefix_offsets, num_cross_att_tokens = cloud_vlm.get_vlm_tokens(vlm_observation)
+     assert len(cache) == 1, f"Expected only one cache entry for the dummy observation. Got {len(cache)}."
+     idx = list(cache)[0]
+     return (
+         cache[idx][KEY_STATES].to(dtype=dtype),
+         cache[idx][VALUE_STATES].to(dtype=dtype),
+         prefix_pad_masks,
+         prefix_offsets,
+         num_cross_att_tokens,
+         idx,
+     )
+
+
+ @parser.wrap()
+ def main(cfg: TrainPipelineConfig):
+     device = auto_torch_device()
+     dtype = torch.float32
+
+     # arguments for the dummy observation
+     (
+         key_states,
+         value_states,
+         prefix_pad_masks,
+         prefix_offsets,
+         num_cross_att_tokens,
+         layer_idx,
+     ) = get_vlm_cache(cfg, device, dtype)
+     state = torch.zeros((1, cfg.max_state_dim), device=device, dtype=dtype)
+     args = (key_states, value_states, state)
+     logging.info("Generated example args")
+
+     policy_class = get_policy_class(cfg.policy.type)
+     robot_action_decoder = policy_class.from_pretrained(cfg.policy.pretrained_path, config=cfg.policy)
+     robot_action_decoder.set_execution_target("robot")
+     robot_action_decoder.to(device)
+     robot_action_decoder.to(dtype=dtype)
+     robot_action_decoder.eval()
+     inference_wrapper = InferenceWrapper(
+         robot_action_decoder,
+         prefix_pad_masks=prefix_pad_masks,
+         prefix_offsets=prefix_offsets,
+         num_cross_att_tokens=num_cross_att_tokens,
+         layer_idx=layer_idx,
+     )
+     logging.info("Loaded policy")
+
+     logging.info("Applying monkey patches...")
+     for patch in patches:
+         patch()
+
+     logging.info("Exporting model to ONNX...")
+     with torch.inference_mode():
+         path = Path(cfg.policy.pretrained_path) / "robot_action_decoder.onnx"
+         path = path.resolve()
+         path.parent.mkdir(parents=True, exist_ok=True)  # Should be a no-op
+         print("Exporting model to ONNX at path:", path)
+         print("Current directory:", Path.cwd())
+         print("Trying to write to:", path)
+         with open(path, "wb"):
+             print("Write permissions check passed for:", path)
+         print("Running torch.onnx.export...")
+         torch.onnx.export(
+             inference_wrapper.eval(),
+             args,
+             path,
+             input_names=[KEY_STATES, VALUE_STATES, "state"],
+             output_names=["action_chunk"],
+             opset_version=18,
+             do_constant_folding=False,  # constant folding causes weird errors (getting dim -1 from a 0-dim scalar) after forward pass succeeds
+         )
+     logging.info(f"Successfully exported model to '{path}'.")
+
+
+ if __name__ == "__main__":
+     main()
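
Note: the exported robot_action_decoder.onnx takes exactly the three inputs named above (key_states, value_states, state) and returns a single action_chunk output. The snippet below is a minimal smoke-test sketch using onnxruntime, which is assumed to be installed; the cache-tensor shapes are placeholders, since the real shapes depend on the policy configuration and on what get_vlm_cache produced at export time.

# Minimal smoke test for the exported decoder (sketch; shapes below are placeholders).
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("robot_action_decoder.onnx", providers=["CPUExecutionProvider"])

# Inspect the graph to recover the expected input shapes instead of guessing.
for inp in session.get_inputs():
    print(inp.name, inp.shape, inp.type)

# Placeholder inputs: replace the shapes with the ones reported above.
feeds = {
    "key_states": np.zeros((1, 8, 256, 64), dtype=np.float32),    # hypothetical shape
    "value_states": np.zeros((1, 8, 256, 64), dtype=np.float32),  # hypothetical shape
    "state": np.zeros((1, 32), dtype=np.float32),                 # max_state_dim is config-dependent
}
(action_chunk,) = session.run(["action_chunk"], feeds)
print(action_chunk.shape)
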
opentau/scripts/fake_tensor_training.py
@@ -0,0 +1,87 @@
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from dataclasses import dataclass
+
+ import torch
+
+ from opentau.configs import parser
+ from opentau.utils.fake_tensor import FakeTensorContext
+ from opentau.utils.utils import auto_torch_device
+
+
+ @dataclass
+ class Args:
+     device: str | None = None
+     num_workers: int = 4
+     batch_size: int = 2
+     dim_in: int = 3
+     dim_out: int = 5
+     large_hidden_dim: int = 10**9
+
+
+ @parser.wrap()
+ def main(args: Args):
+     device = args.device or auto_torch_device()
+     data = [(torch.rand(args.dim_in), torch.rand(args.dim_out)) for _ in range(10)]
+     dataloader = torch.utils.data.DataLoader(
+         data,
+         num_workers=args.num_workers,
+         batch_size=args.batch_size,
+     )
+
+     # ---------------------------------------------------------------
+     # Everything inside this context will use FakeTensorMode.
+     with FakeTensorContext():
+         # Creating a model in FakeTensorContext shouldn't cost real memory for the model parameters.
+         model = torch.nn.Sequential(
+             torch.nn.Linear(args.dim_in, args.large_hidden_dim),
+             torch.nn.Linear(args.large_hidden_dim, args.large_hidden_dim),
+             torch.nn.Linear(args.large_hidden_dim, args.dim_out),
+         ).to(device)
+         optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
+     # End of FakeTensorContext
+     # ---------------------------------------------------------------
+
+     print("Model parameters:")
+     for name, param in model.named_parameters():
+         print(f"{name}: {param}")
+     print()
+
+     losses = []
+     for i, (x, y) in enumerate(dataloader):
+         # ---------------------------------------------------------------
+         # Everything inside this context will use FakeTensorMode.
+         # Ideally, we would iterate the dataloader in FakeTensorMode as well.
+         # However, it does not work with multiple workers due to a serialization issue.
+         with FakeTensorContext():
+             x = x.to(device=device)
+             y = y.to(device=device)
+             optimizer.zero_grad()
+             loss = torch.nn.functional.mse_loss(model(x), y)
+             loss.backward()
+
+             optimizer.step()
+             losses.append(loss.item())
+             print(
+                 f"Step {i}: symbolic loss = {loss.item()}, dummy numpy array = {loss.detach().cpu().numpy()}"
+             )
+         # End of FakeTensorContext
+         # ---------------------------------------------------------------
+
+     print("\nSymbolic mean loss:", sum(losses) / len(losses))
+
+
+ if __name__ == "__main__":
+     main()
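
Note: FakeTensorContext is defined in opentau/utils/fake_tensor.py, which is not shown in this hunk. Conceptually it is what makes the billion-wide hidden layers above affordable: inside the context, tensor operations only propagate metadata (shape, dtype, device) and allocate no real storage. As a rough illustration only, and assuming a recent PyTorch, a similar effect can be sketched directly with torch._subclasses.fake_tensor.FakeTensorMode; this is not the package's actual implementation.

# Rough stand-in for FakeTensorContext using PyTorch's FakeTensorMode directly.
# This is an assumption about what the helper does, not opentau's implementation.
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    # Parameters created here are "fake": they carry shape/dtype/device metadata
    # but no real storage, so even an absurdly wide layer costs essentially no memory.
    huge = torch.nn.Linear(3, 10**9)
    x = torch.rand(2, 3)
    y = huge(x)
    print(y.shape, y.dtype)  # metadata is still tracked, e.g. torch.Size([2, 1000000000])
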
opentau/scripts/get_advantage_and_percentiles.py
@@ -0,0 +1,220 @@
+ #!/usr/bin/env python
+
+ r"""Usage:
+
+ python src/opentau/scripts/get_advantage_and_percentiles.py \
+     --config_path=outputs/train/2025-11-29/00-38-59_value/checkpoints/00520000 \
+     --batch_size=20 \
+     --dataloader_batch_size=20 \
+     --dataset_mixture=examples/advantage_config.json
+ """
+
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import json
+ import logging
+ import sys
+ from collections import defaultdict
+ from pathlib import Path
+
+ import draccus
+ import numpy as np
+ import torch
+ from torch.utils.data import DataLoader
+
+ from opentau.configs import parser
+ from opentau.configs.default import DatasetMixtureConfig
+ from opentau.configs.train import TrainPipelineConfig
+ from opentau.datasets.factory import make_dataset
+ from opentau.datasets.lerobot_dataset import LeRobotDataset
+ from opentau.datasets.utils import ADVANTAGES_PATH
+ from opentau.policies.factory import get_policy_class
+ from opentau.policies.value.reward import calculate_n_step_return
+ from opentau.utils.random_utils import set_seed
+ from opentau.utils.utils import (
+     auto_torch_device,
+     init_logging,
+ )
+
+
+ def ensure_primitive(maybe_tensor):
+     """Convert a single-element tensor or array to a plain Python scalar; pass anything else through."""
+     if isinstance(maybe_tensor, np.ndarray):
+         return ensure_primitive(torch.from_numpy(maybe_tensor))
+     if isinstance(maybe_tensor, torch.Tensor):
+         assert maybe_tensor.numel() == 1, f"Tensor must be a single value, got numel={maybe_tensor.numel()}"
+         # Use .item() so the result can be used as a dict key and serialized to JSON.
+         return maybe_tensor.item()
+     return maybe_tensor
+
+
+ _default0 = defaultdict(int)
+
+ # Store dataset_mixture_path before filtering (needed for parsing inside main).
+ # Handle both --dataset_mixture_path=<path> and --dataset_mixture=<path> (without nested fields).
+ _dataset_mixture_path_value = None
+ for arg in sys.argv:
+     if arg.startswith("--dataset_mixture_path="):
+         _dataset_mixture_path_value = arg.split("=", 1)[1]
+         break
+     elif arg.startswith("--dataset_mixture=") and "." not in arg.split("=", 1)[0]:
+         # --dataset_mixture=<path> without nested fields (e.g., not --dataset_mixture.datasets.0.repo_id=...)
+         _dataset_mixture_path_value = arg.split("=", 1)[1]
+         break
+
+ # Create a wrapper that filters dataset_mixture path arguments before draccus parsing.
+ _original_wrap = parser.wrap()
+
+
+ def _filter_dataset_mixture_path(fn):
+     """Wrapper that filters --dataset_mixture_path and --dataset_mixture=<path> from sys.argv before draccus sees it."""
+     wrapped_fn = _original_wrap(fn)
+
+     def filtered_wrapper(*args, **kwargs):
+         # If config is already provided, just call the function.
+         if len(args) > 0:
+             return wrapped_fn(*args, **kwargs)
+
+         # Otherwise, filter dataset_mixture path arguments from sys.argv before draccus parses.
+         original_argv = sys.argv.copy()
+         try:
+             filtered_args = []
+             for arg in sys.argv:
+                 # Filter --dataset_mixture_path=<path> and bare --dataset_mixture=<path>.
+                 if (
+                     arg.startswith("--dataset_mixture_path=")
+                     or arg.startswith("--dataset_mixture=")
+                     and "." not in arg.split("=", 1)[0]
+                 ):
+                     continue
+                 else:
+                     filtered_args.append(arg)
+             sys.argv = filtered_args
+             return wrapped_fn(*args, **kwargs)
+         finally:
+             sys.argv = original_argv
+
+     return filtered_wrapper
+
+
+ @_filter_dataset_mixture_path
+ def main(cfg: TrainPipelineConfig):
+     dataset_mixture_path = _dataset_mixture_path_value
+
+     if dataset_mixture_path:
+         logging.info(f"Loading dataset config from separate file: {dataset_mixture_path}")
+         mixture_cfg = draccus.parse(
+             config_class=DatasetMixtureConfig, config_path=dataset_mixture_path, args=[]
+         )
+     else:
+         logging.info("Using the dataset mixture config from the TrainPipelineConfig")
+         mixture_cfg = cfg.dataset_mixture
+
+     device = auto_torch_device()
+     # torch.autograd.set_detect_anomaly(True)
+
+     # TODO(shuheng): Do we need the random seed here?
+     if cfg.seed is not None:
+         set_seed(cfg.seed)
+
+     logging.info("Creating policy")
+     policy_class = get_policy_class(cfg.policy.type)
+     policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=cfg.policy)
+     policy.to(device=device, dtype=torch.bfloat16)
+     policy.eval()
+
+     # Advantages across all datasets.
+     advantages = []
+
+     for dataset_idx, dataset_cfg in enumerate(mixture_cfg.datasets):
+         logging.info(f"Creating dataset {dataset_idx}")
+         dataset = make_dataset(dataset_cfg, cfg, return_advantage_input=True)
+         assert isinstance(dataset, LeRobotDataset), (
+             f"Expected instance of LeRobotDataset, got {type(dataset)}"
+         )
+         dataloader = DataLoader(
+             dataset,
+             batch_size=cfg.batch_size,
+             shuffle=False,
+             drop_last=False,
+             num_workers=cfg.num_workers,
+             pin_memory=torch.cuda.is_available(),
+             prefetch_factor=cfg.prefetch_factor,
+         )
+
+         values = {}
+         ds_advantage = {}  # per-dataset advantages
+         with torch.inference_mode():
+             # First pass to get the values.
+             for batch in dataloader:
+                 for key, value in batch.items():
+                     if isinstance(value, torch.Tensor):
+                         batch[key] = value.to(device)
+
+                 for success, episode_index, episode_end_idx, current_idx, v0 in zip(
+                     batch["success"],
+                     batch["episode_index"],
+                     batch["episode_end_idx"],
+                     batch["current_idx"],
+                     policy.predict_value(batch),
+                     strict=True,
+                 ):
+                     success, episode_index, episode_end_idx, current_idx, v0 = map(
+                         ensure_primitive, (success, episode_index, episode_end_idx, current_idx, v0)
+                     )
+                     reward = calculate_n_step_return(
+                         success=success,
+                         n_steps_look_ahead=cfg.policy.reward_config.N_steps_look_ahead,
+                         episode_end_idx=episode_end_idx,
+                         max_episode_length=cfg.policy.reward_config.reward_normalizer,
+                         current_idx=current_idx,
+                         c_neg=cfg.policy.reward_config.C_neg,
+                     )
+
+                     values[(episode_index, current_idx)] = {"v0": v0, "reward": reward}
+
+             # Second pass to compute the advantages.
+             for batch in dataloader:
+                 for episode_index, current_idx, timestamp in zip(
+                     batch["episode_index"],
+                     batch["current_idx"],
+                     batch["timestamp"],
+                     strict=True,
+                 ):
+                     episode_index, current_idx, timestamp = map(
+                         ensure_primitive, (episode_index, current_idx, timestamp)
+                     )
+                     # Check if the value for the next n_steps_look_ahead steps is available, else set it to 0.
+                     look_ahead_idx = current_idx + cfg.policy.reward_config.N_steps_look_ahead
+                     vn = values.get((episode_index, look_ahead_idx), _default0)["v0"]
+                     reward = values.get((episode_index, current_idx), _default0)["reward"]
+                     v0 = values.get((episode_index, current_idx), _default0)["v0"]
+                     advantage = ensure_primitive(reward + vn - v0)
+                     advantages.append(advantage)
+                     ds_advantage[(episode_index, timestamp)] = advantage
+
+         # Convert tuple keys to strings for JSON serialization.
+         advantage_data_json = {f"{ep_idx},{ts}": val for (ep_idx, ts), val in ds_advantage.items()}
+
+         # TODO(shuheng): avoid overwriting existing advantage files.
+         with open(Path(dataset.root) / ADVANTAGES_PATH, "w") as f:
+             json.dump(advantage_data_json, f, indent=4)
+
+     # Calculate percentiles of advantages: 0th, 5th, 10th, ..., 100th.
+     percentiles = list(range(0, 101, 5))  # [0, 5, 10, 15, ..., 100]
+     advantage_percentiles = np.percentile(np.array(advantages), percentiles)
+
+     print("Advantage percentiles for deciding epsilon threshold:")
+     for p, val in zip(percentiles, advantage_percentiles, strict=False):
+         print(f"  {p:03d}th percentile: {val:.6f}")
+
+
+ if __name__ == "__main__":
+     init_logging()
+     main()
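
Note: the second pass above reduces to an n-step advantage estimate per frame, advantage = reward + V(s_{t+N}) - V(s_t), where the look-ahead value falls back to 0 whenever the frame (episode_index, current_idx + N) was never scored (for example near the end of an episode). The toy illustration below shows that bookkeeping with plain floats; the numbers are made up and the reward shaping done by calculate_n_step_return is omitted.

# Toy illustration of the advantage bookkeeping (floats only; reward shaping omitted).
N_STEPS = 5

# First-pass result: value estimates and rewards keyed by (episode_index, frame_index).
values = {(0, t): {"v0": 0.1 * t, "reward": 1.0 if t >= 45 else 0.0} for t in range(50)}

advantages = {}
for (ep, t), entry in values.items():
    vn = values.get((ep, t + N_STEPS), {"v0": 0.0})["v0"]  # 0 past the episode end
    advantages[(ep, t)] = entry["reward"] + vn - entry["v0"]

print(advantages[(0, 10)])  # 0.0 + 1.5 - 1.0 = 0.5
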
opentau/scripts/high_level_planner_inference.py
@@ -0,0 +1,114 @@
+ #!/usr/bin/env python
+
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import logging
+ from dataclasses import asdict
+ from pprint import pformat
+
+ import torch
+ from dotenv import load_dotenv
+
+ from opentau.configs import parser
+ from opentau.configs.train import TrainPipelineConfig
+ from opentau.planner import HighLevelPlanner, Memory
+ from opentau.policies.factory import get_policy_class
+ from opentau.utils.random_utils import set_seed
+ from opentau.utils.utils import (
+     init_logging,
+ )
+
+ load_dotenv()
+
+
+ @parser.wrap()
+ def inference_main(cfg: TrainPipelineConfig):
+     """
+     Run the whole pipeline, from passing tasks to the high-level planner
+     to generating actions with the low-level policy.
+
+     Args:
+         cfg: pipeline configuration. See examples/dev_config.json for an example.
+     """
+
+     logging.info(pformat(asdict(cfg)))
+
+     # Check which device is available.
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+
+     if cfg.seed is not None:
+         set_seed(cfg.seed)
+
+     logging.info("Creating policy")
+     policy_class = get_policy_class(cfg.policy.type)
+     policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=cfg.policy)
+     policy.to(device)
+     policy.to(dtype=torch.bfloat16)
+     policy.eval()
+
+     hlp = HighLevelPlanner()
+     mem = Memory(len=2)
+
+     # Compile the model if possible.
+     if hasattr(torch, "compile"):
+         logging.info("Attempting to compile the policy with torch.compile()...")
+         try:
+             # Other options: "default", "max-autotune" (longer compile time)
+             policy = torch.compile(policy)
+             logging.info("Policy compiled successfully.")
+         except Exception as e:
+             logging.warning(f"torch.compile failed with error: {e}. Proceeding without compilation.")
+     else:
+         logging.warning(
+             "torch.compile is not available. Requires PyTorch 2.0+. Proceeding without compilation."
+         )
+
+     # Always reset policy before an episode to clear out the action cache.
+     policy.reset()
+
+     for _ in range(5):
+         # Create a dummy observation for pi05.
+         camera_observations = {
+             f"camera{i}": torch.zeros((1, 3, *cfg.resolution), dtype=torch.bfloat16, device=device)
+             for i in range(cfg.num_cams)
+         }
+
+         task = "Pick up yellow lego block and put it in the bin"
+
+         sub_task = hlp.inference(camera_observations, "", task, "gpt4o", mem).split("\n")[1].split('"')[1]
+
+         if mem:
+             mem.add_conversation("assistant", [{"type": "text", "text": sub_task}])
+
+         logging.info(f"{sub_task}")
+         observation = {
+             **camera_observations,
+             "state": torch.zeros((1, cfg.max_state_dim), dtype=torch.bfloat16, device=device),
+             "prompt": [sub_task],
+             "img_is_pad": torch.zeros((1, cfg.num_cams), dtype=torch.bool, device=device),
+             "action_is_pad": torch.zeros((1, cfg.action_chunk), dtype=torch.bool, device=device),
+         }
+
+         with torch.inference_mode():
+             for _ in range(1000):
+                 action = policy.select_action(observation)
+                 action = action.to("cpu").numpy()
+                 print(f"Output dummy action: {action}")
+
+     logging.info("End of inference")
+
+
+ if __name__ == "__main__":
+     init_logging()
+     inference_main()
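
Note: the sub_task extraction above assumes the planner reply always carries the quoted sub-task on its second line; any deviation raises an IndexError. A more defensive variant could fall back to a default when no quoted phrase is found. The helper below is only a sketch (the actual reply format is defined by HighLevelPlanner, which is not part of this hunk), and extract_sub_task is a hypothetical name.

# Defensive extraction of the first double-quoted phrase from a planner reply (sketch).
import re


def extract_sub_task(reply: str, default: str) -> str:
    """Return the first double-quoted phrase in the reply, or the given default if none is found."""
    match = re.search(r'"([^"]+)"', reply)
    return match.group(1) if match else default


# e.g. extract_sub_task('plan:\n"move to the bin"', default=task) -> "move to the bin"
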
opentau/scripts/inference.py
@@ -0,0 +1,70 @@
+ #!/usr/bin/env python
+
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import logging
+ from dataclasses import asdict
+ from pprint import pformat
+
+ import torch
+
+ from opentau.configs import parser
+ from opentau.configs.train import TrainPipelineConfig
+ from opentau.policies.factory import get_policy_class
+ from opentau.utils.random_utils import set_seed
+ from opentau.utils.utils import (
+     attempt_torch_compile,
+     auto_torch_device,
+     create_dummy_observation,
+     init_logging,
+ )
+
+
+ @parser.wrap()
+ def inference_main(cfg: TrainPipelineConfig):
+     logging.info(pformat(asdict(cfg)))
+
+     # Check which device is available.
+     device = auto_torch_device()
+
+     if cfg.seed is not None:
+         set_seed(cfg.seed)
+
+     logging.info("Creating policy")
+     policy_class = get_policy_class(cfg.policy.type)
+     policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=cfg.policy)
+     policy.to(device=device, dtype=torch.bfloat16)
+     policy.eval()
+     policy = attempt_torch_compile(policy, device_hint=device)
+
+     # Always reset policy before an episode to clear out the action cache.
+     policy.reset()
+
+     observation = create_dummy_observation(cfg, device, dtype=torch.bfloat16)
+
+     print(observation.keys())
+
+     with torch.inference_mode():
+         for _ in range(1000):
+             action = policy.select_action(observation)
+             action = action.to("cpu", torch.float32).numpy()
+             print(f"Output shape: {action.shape}")
+
+     logging.info("End of inference")
+
+
+ if __name__ == "__main__":
+     init_logging()
+     inference_main()