opentau-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. opentau/__init__.py +179 -0
  2. opentau/__version__.py +24 -0
  3. opentau/configs/__init__.py +19 -0
  4. opentau/configs/default.py +297 -0
  5. opentau/configs/libero.py +113 -0
  6. opentau/configs/parser.py +393 -0
  7. opentau/configs/policies.py +297 -0
  8. opentau/configs/reward.py +42 -0
  9. opentau/configs/train.py +370 -0
  10. opentau/configs/types.py +76 -0
  11. opentau/constants.py +52 -0
  12. opentau/datasets/__init__.py +84 -0
  13. opentau/datasets/backward_compatibility.py +78 -0
  14. opentau/datasets/compute_stats.py +333 -0
  15. opentau/datasets/dataset_mixture.py +460 -0
  16. opentau/datasets/factory.py +232 -0
  17. opentau/datasets/grounding/__init__.py +67 -0
  18. opentau/datasets/grounding/base.py +154 -0
  19. opentau/datasets/grounding/clevr.py +110 -0
  20. opentau/datasets/grounding/cocoqa.py +130 -0
  21. opentau/datasets/grounding/dummy.py +101 -0
  22. opentau/datasets/grounding/pixmo.py +177 -0
  23. opentau/datasets/grounding/vsr.py +141 -0
  24. opentau/datasets/image_writer.py +304 -0
  25. opentau/datasets/lerobot_dataset.py +1910 -0
  26. opentau/datasets/online_buffer.py +442 -0
  27. opentau/datasets/push_dataset_to_hub/utils.py +132 -0
  28. opentau/datasets/sampler.py +99 -0
  29. opentau/datasets/standard_data_format_mapping.py +278 -0
  30. opentau/datasets/transforms.py +330 -0
  31. opentau/datasets/utils.py +1243 -0
  32. opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +887 -0
  33. opentau/datasets/v2/convert_dataset_v1_to_v2.py +829 -0
  34. opentau/datasets/v21/_remove_language_instruction.py +109 -0
  35. opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +60 -0
  36. opentau/datasets/v21/convert_dataset_v20_to_v21.py +183 -0
  37. opentau/datasets/v21/convert_stats.py +150 -0
  38. opentau/datasets/video_utils.py +597 -0
  39. opentau/envs/__init__.py +18 -0
  40. opentau/envs/configs.py +178 -0
  41. opentau/envs/factory.py +99 -0
  42. opentau/envs/libero.py +439 -0
  43. opentau/envs/utils.py +204 -0
  44. opentau/optim/__init__.py +16 -0
  45. opentau/optim/factory.py +43 -0
  46. opentau/optim/optimizers.py +121 -0
  47. opentau/optim/schedulers.py +140 -0
  48. opentau/planner/__init__.py +82 -0
  49. opentau/planner/high_level_planner.py +366 -0
  50. opentau/planner/utils/memory.py +64 -0
  51. opentau/planner/utils/utils.py +65 -0
  52. opentau/policies/__init__.py +24 -0
  53. opentau/policies/factory.py +172 -0
  54. opentau/policies/normalize.py +315 -0
  55. opentau/policies/pi0/__init__.py +19 -0
  56. opentau/policies/pi0/configuration_pi0.py +250 -0
  57. opentau/policies/pi0/modeling_pi0.py +994 -0
  58. opentau/policies/pi0/paligemma_with_expert.py +516 -0
  59. opentau/policies/pi05/__init__.py +20 -0
  60. opentau/policies/pi05/configuration_pi05.py +231 -0
  61. opentau/policies/pi05/modeling_pi05.py +1257 -0
  62. opentau/policies/pi05/paligemma_with_expert.py +572 -0
  63. opentau/policies/pretrained.py +315 -0
  64. opentau/policies/utils.py +123 -0
  65. opentau/policies/value/__init__.py +18 -0
  66. opentau/policies/value/configuration_value.py +170 -0
  67. opentau/policies/value/modeling_value.py +512 -0
  68. opentau/policies/value/reward.py +87 -0
  69. opentau/policies/value/siglip_gemma.py +221 -0
  70. opentau/scripts/actions_mse_loss.py +89 -0
  71. opentau/scripts/bin_to_safetensors.py +116 -0
  72. opentau/scripts/compute_max_token_length.py +111 -0
  73. opentau/scripts/display_sys_info.py +90 -0
  74. opentau/scripts/download_libero_benchmarks.py +54 -0
  75. opentau/scripts/eval.py +877 -0
  76. opentau/scripts/export_to_onnx.py +180 -0
  77. opentau/scripts/fake_tensor_training.py +87 -0
  78. opentau/scripts/get_advantage_and_percentiles.py +220 -0
  79. opentau/scripts/high_level_planner_inference.py +114 -0
  80. opentau/scripts/inference.py +70 -0
  81. opentau/scripts/launch_train.py +63 -0
  82. opentau/scripts/libero_simulation_parallel.py +356 -0
  83. opentau/scripts/libero_simulation_sequential.py +122 -0
  84. opentau/scripts/nav_high_level_planner_inference.py +61 -0
  85. opentau/scripts/train.py +379 -0
  86. opentau/scripts/visualize_dataset.py +294 -0
  87. opentau/scripts/visualize_dataset_html.py +507 -0
  88. opentau/scripts/zero_to_fp32.py +760 -0
  89. opentau/utils/__init__.py +20 -0
  90. opentau/utils/accelerate_utils.py +79 -0
  91. opentau/utils/benchmark.py +98 -0
  92. opentau/utils/fake_tensor.py +81 -0
  93. opentau/utils/hub.py +209 -0
  94. opentau/utils/import_utils.py +79 -0
  95. opentau/utils/io_utils.py +137 -0
  96. opentau/utils/libero.py +214 -0
  97. opentau/utils/libero_dataset_recorder.py +460 -0
  98. opentau/utils/logging_utils.py +180 -0
  99. opentau/utils/monkey_patch.py +278 -0
  100. opentau/utils/random_utils.py +244 -0
  101. opentau/utils/train_utils.py +198 -0
  102. opentau/utils/utils.py +471 -0
  103. opentau-0.1.0.dist-info/METADATA +161 -0
  104. opentau-0.1.0.dist-info/RECORD +108 -0
  105. opentau-0.1.0.dist-info/WHEEL +5 -0
  106. opentau-0.1.0.dist-info/entry_points.txt +2 -0
  107. opentau-0.1.0.dist-info/licenses/LICENSE +508 -0
  108. opentau-0.1.0.dist-info/top_level.txt +1 -0

opentau/planner/high_level_planner.py
@@ -0,0 +1,366 @@
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from abc import ABC, abstractmethod
+ from typing import Optional
+
+ import torch
+ from openai import OpenAI
+ from PIL import Image
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoModelForImageTextToText,
+     AutoModelForVision2Seq,
+     AutoProcessor,
+     LlamaTokenizer,
+ )
+
+ from opentau.planner.utils.memory import Memory
+ from opentau.planner.utils.utils import load_prompt_library, tensor_to_base64
+
+
+ class BaseHighLevelPlanner(ABC):
+     """
+     A high-level planner that can run inference with various open-source models and closed models such as GPT-4o.
+     Generates a list of low-level plans from a high-level plan.
+     """
+
+     def __init__(self):
+         self.prompts_dict = load_prompt_library("src/opentau/prompts/planner/prompts.yaml")
+
+     @abstractmethod
+     def inference(
+         self, image_dict: dict[str, torch.Tensor], model_name: str, task: str, mem: Optional[Memory] = None
+     ) -> str:
+         """
+         Runs planner inference given image and language inputs.
+         """
+
+         pass
+
+     def calculate_usage(self, response) -> float:
+         """
+         Calculates the cost of a single GPT-4o call.
+
+         Args:
+             response: a response object returned by the chat completions endpoint.
+
+         Returns:
+             cost (float): cost of one call.
+         """
+
+         prompt_tokens = response.usage.prompt_tokens
+         completion_tokens = response.usage.completion_tokens
+
+         cost = (prompt_tokens / 1000) * 0.0025 + (completion_tokens / 1000) * 0.01
+
+         return cost
+
+
+ class HighLevelPlanner(BaseHighLevelPlanner):
+     """
+     A high-level planner that can run inference with various open-source models and closed models such as GPT-4o.
+     Generates a list of low-level plans from a high-level plan.
+     """
+
+     def __init__(self):
+         super().__init__()
+
+     def model_and_tokenizer(self, model_name: str, device: str):
+         if model_name == "cogvlm-chat-hf":
+             processor = LlamaTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
+             model = AutoModelForCausalLM.from_pretrained(
+                 "THUDM/cogvlm-chat-hf",
+                 torch_dtype=torch.float16,
+                 low_cpu_mem_usage=True,
+                 trust_remote_code=True,
+             ).eval()
+         elif model_name == "SmolVLM-256M-Instruct":
+             processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-256M-Instruct")
+             model = AutoModelForVision2Seq.from_pretrained(
+                 "HuggingFaceTB/SmolVLM-256M-Instruct", torch_dtype=torch.float16
+             ).to(device)
+         elif model_name == "SmolVLM-500M-Instruct":
+             processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-500M-Instruct")
+             model = AutoModelForVision2Seq.from_pretrained(
+                 "HuggingFaceTB/SmolVLM-500M-Instruct", torch_dtype=torch.float16
+             ).to(device)
+         elif model_name == "SmolVLM2-2.2B-Instruct":
+             processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct")
+             model = AutoModelForImageTextToText.from_pretrained(
+                 "HuggingFaceTB/SmolVLM2-2.2B-Instruct", torch_dtype=torch.float32
+             ).to(device)
+         else:
+             raise RuntimeError(f"The specified model {model_name} is not supported")
+
+         return model, processor
+
+     def generate_prompt(self, task: str, mem: Memory | None) -> str:
+         """
+         Generates the prompt for the GPT-4o model based on memory.
+
+         Args:
+             task (str): a high-level command expressed in language.
+             mem (Memory | None): instance of the Memory class from the utils file, which stores all
+                 previous conversations for the task at hand. Set to None if memory is not required
+                 during inference.
+
+         Returns:
+             prompt (str): the generated prompt.
+         """
+
+         if mem is None:
+             prompt = (
+                 f"Look at the image <image>. The task is {task}."
+                 + "\n"
+                 + self.prompts_dict["prompts"]["robot_user_without_memory_manipulation"]["template"]
+             )
+
+         else:
+             prompt = (
+                 f"Look at the image <image>. The task is {task}."
+                 + "\n"
+                 + self.prompts_dict["prompts"]["robot_user_with_memory_manipulation"]["template"]
+             )
+
+         return prompt
+
+     def gpt_inference(self, image_dict: dict[str, torch.Tensor], task: str, mem: Memory | None) -> str:
+         """
+         Calls the OpenAI API with the high-level task, images, and memory.
+
+         Args:
+             image_dict (dict[str, torch.Tensor]): dict of image tensors to be encoded as base64.
+             task (str): a high-level command expressed in language.
+             mem (Memory | None): instance of the Memory class from the utils file, which stores all
+                 previous conversations for the task at hand. Set to None if memory is not required
+                 during inference.
+
+         Returns:
+             response (str): a low-level language command that the low-level planner can understand.
+         """
+
+         images = tensor_to_base64(image_dict, "mani")
+
+         client = OpenAI()
+
+         prompt = self.generate_prompt(task, mem)
+
+         content = [
+             {"type": "text", "text": f"{prompt}"},
+         ]
+
+         for image_base64 in images:
+             content.append(
+                 # The images are PNG-encoded by tensor_to_base64.
+                 {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}}
+             )
+
+         if mem is None:
+             message = [
+                 {
+                     "role": "system",
+                     "content": [
+                         {
+                             "type": "text",
+                             "text": f"{self.prompts_dict['prompts']['robot_system_manipulation']['template']}",
+                         }
+                     ],
+                 },
+                 {
+                     "role": "user",
+                     "content": content,
+                 },
+             ]
+         else:
+             mem.add_conversation("user", content)
+             message = mem.get_conversation()
+
+         response = client.chat.completions.create(
+             model="gpt-4o",
+             messages=message,
+             max_tokens=500,
+             temperature=0.0,
+             tool_choice=None,
+         )
+
+         res = response.choices[0].message.content
+
+         return res
+
+     def opensource_inference(self, image_path, task, model_name):
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+
+         model, processor = self.model_and_tokenizer(model_name, device)
+
+         # Memory is not supported for the open-source models, so the prompt is built without it.
+         prompt = self.generate_prompt(task, None)
+
+         image1 = Image.open(image_path).convert("RGB")
+
+         inputs = processor(text=prompt, images=[image1], return_tensors="pt")
+         inputs = inputs.to(device)
+
+         # Generate outputs
+         generated_ids = model.generate(**inputs, max_new_tokens=500)
+         generated_texts = processor.batch_decode(
+             generated_ids,
+             skip_special_tokens=True,
+         )
+
+         return generated_texts[0]
+
+     def inference(
+         self,
+         image_dict: dict[str, torch.Tensor],
+         model_name: str,
+         task: str,
+         mem: Optional[Memory] = None,
+         image_path: Optional[str] = None,
+     ) -> str:
+         """
+         Dispatches inference to GPT-4o or to an open-source model.
+
+         Args:
+             image_dict (dict[str, torch.Tensor]): dict of image tensors to be encoded as base64.
+             model_name (str): name of the open-source model to run. To use GPT models, pass "gpt4o".
+             task (str): a high-level command expressed in language.
+             mem (Memory | None): instance of the Memory class from the utils file, which stores all
+                 previous conversations for the task at hand. Set to None if memory is not required
+                 during inference.
+             image_path (str | None): path to an image file; required by the open-source models.
+
+         Returns:
+             response (str): a low-level language command that the low-level planner can understand.
+         """
+
+         if model_name == "gpt4o":
+             actions = self.gpt_inference(image_dict, task, mem)
+             # Extract the JSON list payload from the fenced "json" code block in the response.
+             actions = actions.split("```json")[1].split("[")[1].split("]")[0]
+         else:
+             actions = self.opensource_inference(image_path, task, model_name)
+
+         return actions
+
+
+ class NavHighLevelPlanner(BaseHighLevelPlanner):
+     """
+     A high-level planner for navigation that runs inference with closed models such as GPT-4o.
+     Generates a list of low-level plans from a high-level plan.
+     """
+
+     def __init__(self):
+         super().__init__()
+
+     def generate_prompt(self, task: str, mem: Memory | None) -> str:
+         """
+         Generates the prompt for the GPT-4o model based on memory.
+
+         Args:
+             task (str): a high-level command expressed in language.
+             mem (Memory | None): instance of the Memory class from the utils file, which stores all
+                 previous conversations for the task at hand. Set to None if memory is not required
+                 during inference.
+
+         Returns:
+             prompt (str): the generated prompt.
+         """
+
+         prompt = (
+             self.prompts_dict["prompts"]["robot_user_navigation"]["template"]
+             + f"Look at the given images {' '.join(['<image>'] * 21)} from starting point. The task is {task}."
+         )
+
+         return prompt
+
+     def gpt_inference(self, image_dict: dict[str, torch.Tensor], task: str, mem: Memory | None) -> str:
+         """
+         Calls the OpenAI API with the high-level task, images, and memory.
+
+         Args:
+             image_dict (dict[str, torch.Tensor]): dict of image tensors to be encoded as base64.
+             task (str): a high-level command expressed in language.
+             mem (Memory | None): instance of the Memory class from the utils file, which stores all
+                 previous conversations for the task at hand. Set to None if memory is not required
+                 during inference.
+
+         Returns:
+             response (str): a low-level language command that the low-level planner can understand.
+         """
+
+         images = tensor_to_base64(image_dict, "nav")
+
+         client = OpenAI()
+
+         prompt = self.generate_prompt(task, mem)
+
+         content = [
+             {"type": "text", "text": f"{prompt}"},
+         ]
+
+         for image_base64 in images:
+             content.append(
+                 # The images are PNG-encoded by tensor_to_base64.
+                 {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}}
+             )
+
+         if mem is None:
+             message = [
+                 {
+                     "role": "system",
+                     "content": [
+                         {
+                             "type": "text",
+                             "text": f"{self.prompts_dict['prompts']['robot_system_navigation']['template']}",
+                         }
+                     ],
+                 },
+                 {
+                     "role": "user",
+                     "content": content,
+                 },
+             ]
+         else:
+             mem.add_conversation("user", content)
+             message = mem.get_conversation()
+
+         response = client.chat.completions.create(
+             model="gpt-4o",
+             messages=message,
+             max_tokens=500,
+             temperature=0.0,
+             tool_choice=None,
+         )
+
+         res = response.choices[0].message.content
+
+         return res
+
+     def inference(
+         self, image_dict: dict[str, torch.Tensor], model_name: str, task: str, mem: Optional[Memory] = None
+     ) -> str:
+         """
+         Runs navigation planning with GPT-4o.
+
+         Args:
+             image_dict (dict[str, torch.Tensor]): dict of image tensors to be encoded as base64.
+             model_name (str): must be "gpt4o"; open-source models are not supported by this planner.
+             task (str): a high-level command expressed in language.
+             mem (Memory | None): instance of the Memory class from the utils file, which stores all
+                 previous conversations for the task at hand. Set to None if memory is not required
+                 during inference.
+
+         Returns:
+             response (str): a low-level language command that the low-level planner can understand.
+         """
+
+         if model_name == "gpt4o":
+             actions = self.gpt_inference(image_dict, task, mem)
+         else:
+             raise NotImplementedError(f"Model {model_name} is not supported by NavHighLevelPlanner.")
+
+         return actions
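
A minimal usage sketch for HighLevelPlanner might look as follows. It assumes a valid OPENAI_API_KEY in the environment and a prompt library at the path hard-coded in BaseHighLevelPlanner.__init__; the camera key "cam_front", the image size, and the task string are illustrative:

    import torch

    from opentau.planner.high_level_planner import HighLevelPlanner
    from opentau.planner.utils.memory import Memory

    # A (1, 3, H, W) float image in [0, 1], the layout tensor_to_base64 expects for "mani".
    frame = torch.rand(1, 3, 224, 224)

    planner = HighLevelPlanner()
    mem = Memory()  # seeds the system prompt; pass mem=None to run without memory

    # "gpt4o" routes to gpt_inference; the return value is the parsed JSON list payload.
    plan = planner.inference({"cam_front": frame}, model_name="gpt4o", task="put the mug on the shelf", mem=mem)
    print(plan)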

opentau/planner/utils/memory.py
@@ -0,0 +1,64 @@
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from collections import deque
+
+ from opentau.planner.utils.utils import load_prompt_library
+
+
+ class Memory:
+     """
+     A memory-like class that stores and retrieves all conversations between the user and the LLM
+     assistant for a particular task.
+     """
+
+     def __init__(self, conversation=None, len=1000):
+         """
+         Initializes the conversation buffer.
+
+         Args:
+             conversation: an earlier conversation to seed the memory with, if any.
+             len: fixed buffer size of the conversation.
+         """
+
+         self.len = len
+         self.prompts_dict = load_prompt_library("src/opentau/planner/prompts.yaml")
+
+         if conversation:
+             self.conversation = conversation
+         else:
+             self.conversation = deque()
+             context = [
+                 {
+                     "type": "text",
+                     "text": f"{self.prompts_dict['prompts']['robot_system_manipulation']['template']}",
+                 }
+             ]
+             self.add_conversation("system", context)
+
+     def add_conversation(self, role: str, message: list[dict[str, str]]) -> None:
+         """
+         Adds a new message to the conversation history.
+
+         Args:
+             role (str): the author of the message (system, user, or assistant).
+             message (list[dict[str, str]]): message content containing text and/or images.
+         """
+
+         if len(self.conversation) >= self.len:
+             self.conversation.popleft()
+
+         self.conversation.append({"role": role, "content": message})
+
+     def get_conversation(self) -> list[dict[str, str]]:
+         """
+         Returns the stored conversation history.
+         """
+         return list(self.conversation)
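
A short sketch of the Memory round trip, assuming the prompt library exists at the hard-coded path; the message texts are made up:

    from opentau.planner.utils.memory import Memory

    mem = Memory()  # a fresh Memory seeds itself with the manipulation system prompt

    # Messages follow the OpenAI chat format: a role plus a list of content parts.
    mem.add_conversation("user", [{"type": "text", "text": "Pick up the red block."}])
    mem.add_conversation("assistant", [{"type": "text", "text": "1. move to the block 2. grasp it"}])

    for turn in mem.get_conversation():
        print(turn["role"], "->", turn["content"][0]["text"])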

opentau/planner/utils/utils.py
@@ -0,0 +1,65 @@
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import base64
+ import io
+
+ import torch
+ import yaml
+ from PIL import Image
+
+
+ def tensor_to_base64(image_dict: dict[str, torch.Tensor], task: str) -> list[str]:
+     """
+     Converts a dictionary of image tensors into a list of base64-encoded PNG strings.
+
+     Args:
+         image_dict (dict[str, torch.Tensor]): dictionary of camera image tensors.
+         task (str): one of the two supported tasks, "mani" or "nav".
+
+     Returns:
+         images (list[str]): list of base64-encoded images.
+     """
+
+     images = []
+
+     for img in image_dict.values():
+         if task == "mani":
+             # Manipulation images arrive as (1, C, H, W) float tensors in [0, 1].
+             img_tensor = img.squeeze(0)
+             img_tensor = img_tensor.to(dtype=torch.float32, device="cpu")
+             img_tensor = img_tensor.clamp(0, 1) * 255.0
+             img = Image.fromarray(img_tensor.to(torch.uint8).permute(1, 2, 0).numpy())
+
+         # For "nav", the values are assumed to already be PIL Images and are saved directly.
+         buffered = io.BytesIO()
+         img.save(buffered, format="PNG")
+
+         images.append(base64.b64encode(buffered.getvalue()).decode("utf-8"))
+
+     return images
+
+
+ def load_prompt_library(filepath: str) -> dict | None:
+     """
+     Loads a YAML file and returns its content as a dictionary, or None if loading fails.
+     """
+     try:
+         with open(filepath) as file:
+             # Use yaml.safe_load() to parse the YAML file
+             return yaml.safe_load(file)
+     except FileNotFoundError:
+         print(f"Error: The file at '{filepath}' was not found.")
+         return None
+     except yaml.YAMLError as e:
+         print(f"Error parsing the YAML file: {e}")
+         return None
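
A quick sketch of tensor_to_base64 on a dummy frame; the camera key and resolution are arbitrary:

    import torch

    from opentau.planner.utils.utils import tensor_to_base64

    # One (1, 3, 64, 64) float image in [0, 1], the layout the "mani" branch expects.
    image_dict = {"cam_front": torch.rand(1, 3, 64, 64)}

    encoded = tensor_to_base64(image_dict, task="mani")
    print(len(encoded), encoded[0][:32])  # one base64 PNG string per camera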

opentau/policies/__init__.py
@@ -0,0 +1,24 @@
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Policies module for OpenTau.
+
+ This module exports the configuration classes for available policies,
+ such as PI0, PI05, and Value policy.
+ """
+
+ from .pi0.configuration_pi0 import PI0Config as PI0Config
+ from .pi05.configuration_pi05 import PI05Config as PI05Config
+ from .value.configuration_value import ValueConfig as ValueConfig
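
These re-exports let the config classes be imported directly from the package, e.g.:

    from opentau.policies import PI0Config, PI05Config, ValueConfig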

opentau/policies/factory.py
@@ -0,0 +1,172 @@
+ #!/usr/bin/env python
+
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Factory functions for creating policy instances and configurations.
+
+ This module provides utility functions to instantiate policy classes and their
+ corresponding configurations based on policy names and types. It handles the
+ logic for creating fresh policies or loading pretrained ones, as well as
+ parsing features from datasets or environments to properly configure the policies.
+ """
+
+ from typing import Optional
+
+ import numpy as np
+ from torch import nn
+
+ from opentau.configs.policies import PreTrainedConfig
+ from opentau.configs.types import FeatureType
+ from opentau.datasets.lerobot_dataset import LeRobotDatasetMetadata
+ from opentau.datasets.utils import dataset_to_policy_features
+ from opentau.policies.pi0.configuration_pi0 import PI0Config
+ from opentau.policies.pi05.configuration_pi05 import PI05Config
+ from opentau.policies.pretrained import PreTrainedPolicy
+ from opentau.policies.value.configuration_value import ValueConfig
+
+
+ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
+     """Get the policy's class given a name.
+
+     Args:
+         name: The name of the policy (e.g., "pi0", "pi05", "value").
+             Must match the policy class's `name` attribute.
+
+     Returns:
+         type[PreTrainedPolicy]: The policy class corresponding to the given name.
+
+     Raises:
+         NotImplementedError: If the policy with the given name is not implemented.
+     """
+     if name == "pi0":
+         from opentau.policies.pi0.modeling_pi0 import PI0Policy
+
+         return PI0Policy
+     elif name == "pi05":
+         from opentau.policies.pi05.modeling_pi05 import PI05Policy
+
+         return PI05Policy
+     elif name == "value":
+         from opentau.policies.value.modeling_value import ValueFunction
+
+         return ValueFunction
+     else:
+         raise NotImplementedError(f"Policy with name {name} is not implemented.")
+
+
+ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
+     """Creates a policy configuration object based on the policy type.
+
+     Args:
+         policy_type: The type of the policy (e.g., "pi0", "pi05", "value").
+         **kwargs: Keyword arguments to be passed to the configuration class constructor.
+
+     Returns:
+         PreTrainedConfig: An instance of the corresponding policy configuration class.
+
+     Raises:
+         ValueError: If the policy type is not available.
+     """
+     if policy_type == "pi0":
+         return PI0Config(**kwargs)
+     elif policy_type == "pi05":
+         return PI05Config(**kwargs)
+     elif policy_type == "value":
+         return ValueConfig(**kwargs)
+     else:
+         raise ValueError(f"Policy type '{policy_type}' is not available.")
+
+
+ def make_policy(
+     cfg: PreTrainedConfig,
+     ds_meta: LeRobotDatasetMetadata | None = None,
+     features: dict[str, FeatureType] | None = None,
+     stats: dict[str, dict[str, np.ndarray]] | None = None,
+     # None for unified training, "robot" for robot action decoder inference,
+     # "cloud" for VLM on cloud inference.
+     execution_target: Optional[str] = None,
+ ) -> PreTrainedPolicy:
+     """Make an instance of a policy class.
+
+     This function exists because (for now) we need to parse features from either a dataset or an
+     environment in order to properly dimension and instantiate a policy for that dataset or environment.
+
+     Args:
+         cfg: The config of the policy to make. If `pretrained_path` is set, the policy will
+             be loaded with the weights from that path.
+         ds_meta: Dataset metadata to take input/output shapes and statistics to use for
+             (un)normalization of inputs/outputs in the policy. Defaults to None.
+         features: Input and output features. Defaults to None.
+         stats: Dictionary of statistics for normalization. Defaults to None.
+         execution_target: Target for execution. Can be "robot", "cloud", or None.
+             None implies unified training. "robot" implies robot action decoder inference.
+             "cloud" implies VLM on cloud inference. Defaults to None.
+
+     Returns:
+         PreTrainedPolicy: An instance of the created policy.
+
+     Raises:
+         ValueError: If neither or both of `ds_meta` and `features` are provided when features are
+             not already set in the config.
+         ValueError: If `execution_target` is invalid.
+     """
+     features_already_set = (
+         isinstance(cfg.input_features, dict)
+         and cfg.input_features
+         and isinstance(cfg.output_features, dict)
+         and cfg.output_features
+     )
+     if (bool(ds_meta) + (features is not None) != 1) and not features_already_set:
+         raise ValueError("Exactly one of ds_meta or features must be provided.")
+
+     if execution_target not in ["robot", "cloud", None]:
+         raise ValueError(
+             f"execution_target must be one of ['robot', 'cloud', None], but got {execution_target}."
+         )
+
+     policy_cls = get_policy_class(cfg.type)
+
+     kwargs = {}
+
+     if ds_meta is not None:
+         features = dataset_to_policy_features(ds_meta.features)
+         kwargs["dataset_stats"] = ds_meta.stats
+
+     if not features_already_set:
+         cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
+         cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
+
+     if stats is not None:
+         kwargs["dataset_stats"] = stats
+
+     if execution_target is not None:
+         kwargs["execution_target"] = execution_target
+
+     kwargs["config"] = cfg
+
+     if cfg.pretrained_path:
+         # Load a pretrained policy and override the config if needed (for example, if there are
+         # inference-time hyperparameters that we want to vary).
+         kwargs["pretrained_name_or_path"] = cfg.pretrained_path
+         policy = policy_cls.from_pretrained(**kwargs)
+     else:
+         # Make a fresh policy.
+         policy = policy_cls(**kwargs)
+
+     assert isinstance(policy, nn.Module)
+
+     # policy = torch.compile(policy, mode="reduce-overhead")
+
+     return policy
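
A sketch of the factory flow, assuming the dataset repo id "user/my_dataset" (a placeholder) resolves to metadata that supplies feature shapes and stats:

    from opentau.datasets.lerobot_dataset import LeRobotDatasetMetadata
    from opentau.policies.factory import make_policy, make_policy_config

    cfg = make_policy_config("pi0")

    # Exactly one of ds_meta or features is required, unless the config already
    # carries input/output features.
    ds_meta = LeRobotDatasetMetadata("user/my_dataset")
    policy = make_policy(cfg, ds_meta=ds_meta)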