opentau 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. opentau/__init__.py +179 -0
  2. opentau/__version__.py +24 -0
  3. opentau/configs/__init__.py +19 -0
  4. opentau/configs/default.py +297 -0
  5. opentau/configs/libero.py +113 -0
  6. opentau/configs/parser.py +393 -0
  7. opentau/configs/policies.py +297 -0
  8. opentau/configs/reward.py +42 -0
  9. opentau/configs/train.py +370 -0
  10. opentau/configs/types.py +76 -0
  11. opentau/constants.py +52 -0
  12. opentau/datasets/__init__.py +84 -0
  13. opentau/datasets/backward_compatibility.py +78 -0
  14. opentau/datasets/compute_stats.py +333 -0
  15. opentau/datasets/dataset_mixture.py +460 -0
  16. opentau/datasets/factory.py +232 -0
  17. opentau/datasets/grounding/__init__.py +67 -0
  18. opentau/datasets/grounding/base.py +154 -0
  19. opentau/datasets/grounding/clevr.py +110 -0
  20. opentau/datasets/grounding/cocoqa.py +130 -0
  21. opentau/datasets/grounding/dummy.py +101 -0
  22. opentau/datasets/grounding/pixmo.py +177 -0
  23. opentau/datasets/grounding/vsr.py +141 -0
  24. opentau/datasets/image_writer.py +304 -0
  25. opentau/datasets/lerobot_dataset.py +1910 -0
  26. opentau/datasets/online_buffer.py +442 -0
  27. opentau/datasets/push_dataset_to_hub/utils.py +132 -0
  28. opentau/datasets/sampler.py +99 -0
  29. opentau/datasets/standard_data_format_mapping.py +278 -0
  30. opentau/datasets/transforms.py +330 -0
  31. opentau/datasets/utils.py +1243 -0
  32. opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +887 -0
  33. opentau/datasets/v2/convert_dataset_v1_to_v2.py +829 -0
  34. opentau/datasets/v21/_remove_language_instruction.py +109 -0
  35. opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +60 -0
  36. opentau/datasets/v21/convert_dataset_v20_to_v21.py +183 -0
  37. opentau/datasets/v21/convert_stats.py +150 -0
  38. opentau/datasets/video_utils.py +597 -0
  39. opentau/envs/__init__.py +18 -0
  40. opentau/envs/configs.py +178 -0
  41. opentau/envs/factory.py +99 -0
  42. opentau/envs/libero.py +439 -0
  43. opentau/envs/utils.py +204 -0
  44. opentau/optim/__init__.py +16 -0
  45. opentau/optim/factory.py +43 -0
  46. opentau/optim/optimizers.py +121 -0
  47. opentau/optim/schedulers.py +140 -0
  48. opentau/planner/__init__.py +82 -0
  49. opentau/planner/high_level_planner.py +366 -0
  50. opentau/planner/utils/memory.py +64 -0
  51. opentau/planner/utils/utils.py +65 -0
  52. opentau/policies/__init__.py +24 -0
  53. opentau/policies/factory.py +172 -0
  54. opentau/policies/normalize.py +315 -0
  55. opentau/policies/pi0/__init__.py +19 -0
  56. opentau/policies/pi0/configuration_pi0.py +250 -0
  57. opentau/policies/pi0/modeling_pi0.py +994 -0
  58. opentau/policies/pi0/paligemma_with_expert.py +516 -0
  59. opentau/policies/pi05/__init__.py +20 -0
  60. opentau/policies/pi05/configuration_pi05.py +231 -0
  61. opentau/policies/pi05/modeling_pi05.py +1257 -0
  62. opentau/policies/pi05/paligemma_with_expert.py +572 -0
  63. opentau/policies/pretrained.py +315 -0
  64. opentau/policies/utils.py +123 -0
  65. opentau/policies/value/__init__.py +18 -0
  66. opentau/policies/value/configuration_value.py +170 -0
  67. opentau/policies/value/modeling_value.py +512 -0
  68. opentau/policies/value/reward.py +87 -0
  69. opentau/policies/value/siglip_gemma.py +221 -0
  70. opentau/scripts/actions_mse_loss.py +89 -0
  71. opentau/scripts/bin_to_safetensors.py +116 -0
  72. opentau/scripts/compute_max_token_length.py +111 -0
  73. opentau/scripts/display_sys_info.py +90 -0
  74. opentau/scripts/download_libero_benchmarks.py +54 -0
  75. opentau/scripts/eval.py +877 -0
  76. opentau/scripts/export_to_onnx.py +180 -0
  77. opentau/scripts/fake_tensor_training.py +87 -0
  78. opentau/scripts/get_advantage_and_percentiles.py +220 -0
  79. opentau/scripts/high_level_planner_inference.py +114 -0
  80. opentau/scripts/inference.py +70 -0
  81. opentau/scripts/launch_train.py +63 -0
  82. opentau/scripts/libero_simulation_parallel.py +356 -0
  83. opentau/scripts/libero_simulation_sequential.py +122 -0
  84. opentau/scripts/nav_high_level_planner_inference.py +61 -0
  85. opentau/scripts/train.py +379 -0
  86. opentau/scripts/visualize_dataset.py +294 -0
  87. opentau/scripts/visualize_dataset_html.py +507 -0
  88. opentau/scripts/zero_to_fp32.py +760 -0
  89. opentau/utils/__init__.py +20 -0
  90. opentau/utils/accelerate_utils.py +79 -0
  91. opentau/utils/benchmark.py +98 -0
  92. opentau/utils/fake_tensor.py +81 -0
  93. opentau/utils/hub.py +209 -0
  94. opentau/utils/import_utils.py +79 -0
  95. opentau/utils/io_utils.py +137 -0
  96. opentau/utils/libero.py +214 -0
  97. opentau/utils/libero_dataset_recorder.py +460 -0
  98. opentau/utils/logging_utils.py +180 -0
  99. opentau/utils/monkey_patch.py +278 -0
  100. opentau/utils/random_utils.py +244 -0
  101. opentau/utils/train_utils.py +198 -0
  102. opentau/utils/utils.py +471 -0
  103. opentau-0.1.0.dist-info/METADATA +161 -0
  104. opentau-0.1.0.dist-info/RECORD +108 -0
  105. opentau-0.1.0.dist-info/WHEEL +5 -0
  106. opentau-0.1.0.dist-info/entry_points.txt +2 -0
  107. opentau-0.1.0.dist-info/licenses/LICENSE +508 -0
  108. opentau-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,460 @@
1
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Utilities for recording and converting LIBERO datasets.
16
+
17
+ This module provides functions for converting LIBERO episode data into the standard
18
+ dataset format, including consolidation of task results, aggregation of multiple
19
+ results, and generation of dataset metadata.
20
+ """
21
+
22
+ import datetime
23
+ import io
24
+ import json
25
+ import logging
26
+ import math
27
+ import shutil
28
+ from itertools import count
29
+ from pathlib import Path
30
+
31
+ import numpy as np
32
+ import pandas as pd
33
+ import pyarrow as pa
34
+ import pyarrow.parquet as pq
35
+ import torch
36
+ from PIL import Image
37
+
38
# Natural-language task instructions covered by this recorder. Episode prompts are
# mapped back to indices via LIBERO_TASK_TO_IDX, so every recorded prompt must
# appear verbatim in this list (see consolidate_task_result).
LIBERO_TASKS = [
    "put the white mug on the left plate and put the yellow and white mug on the right plate",
    "put the white mug on the plate and put the chocolate pudding to the right of the plate",
    "put the yellow and white mug in the microwave and close it",
    "turn on the stove and put the moka pot on it",
    "put both the alphabet soup and the cream cheese box in the basket",
    "put both the alphabet soup and the tomato sauce in the basket",
    "put both moka pots on the stove",
    "put both the cream cheese box and the butter in the basket",
    "put the black bowl in the bottom drawer of the cabinet and close it",
    "pick up the book and place it in the back compartment of the caddy",
    "put the bowl on the plate",
    "put the wine bottle on the rack",
    "open the top drawer and put the bowl inside",
    "put the cream cheese in the bowl",
    "put the wine bottle on top of the cabinet",
    "push the plate to the front of the stove",
    "turn on the stove",
    "put the bowl on the stove",
    "put the bowl on top of the cabinet",
    "open the middle drawer of the cabinet",
    "pick up the orange juice and place it in the basket",
    "pick up the ketchup and place it in the basket",
    "pick up the cream cheese and place it in the basket",
    "pick up the bbq sauce and place it in the basket",
    "pick up the alphabet soup and place it in the basket",
    "pick up the milk and place it in the basket",
    "pick up the salad dressing and place it in the basket",
    "pick up the butter and place it in the basket",
    "pick up the tomato sauce and place it in the basket",
    "pick up the chocolate pudding and place it in the basket",
    "pick up the black bowl next to the cookie box and place it on the plate",
    "pick up the black bowl in the top drawer of the wooden cabinet and place it on the plate",
    "pick up the black bowl on the ramekin and place it on the plate",
    "pick up the black bowl on the stove and place it on the plate",
    "pick up the black bowl between the plate and the ramekin and place it on the plate",
    "pick up the black bowl on the cookie box and place it on the plate",
    "pick up the black bowl next to the plate and place it on the plate",
    "pick up the black bowl next to the ramekin and place it on the plate",
    "pick up the black bowl from table center and place it on the plate",
    "pick up the black bowl on the wooden cabinet and place it on the plate",
]

# Reverse lookup: task string -> task_index used in parquet rows and tasks.jsonl.
LIBERO_TASK_TO_IDX = {task: idx for idx, task in enumerate(LIBERO_TASKS)}
# Recording frame rate (frames per second) written to the dataset info.
LIBERO_FPS = 10
# Number of episodes per on-disk chunk directory (chunk-000, chunk-001, ...).
LIBERO_CHUNK_SIZE = 1000
# Default per-feature normalization statistics written to meta/stats.json.
# Image stats are per-channel (RGB) in a [[value]] shape; "state" has 8 dims and
# "actions" 7 dims, matching the feature shapes in make_libero_info().
# NOTE(review): presumably precomputed over a LIBERO training split — confirm origin.
LIBERO_DEFAULT_STATS = {
    "image": {
        "mean": [
            [[0.48068472743034363]],
            [[0.4485854208469391]],
            [[0.4106878638267517]],
        ],
        "std": [
            [[0.229267880320549]],
            [[0.22267605364322662]],
            [[0.21524138748645782]],
        ],
        "max": [
            [[1.0]],
            [[1.0]],
            [[1.0]],
        ],
        "min": [
            [[0.0]],
            [[0.0]],
            [[0.0]],
        ],
    },
    "wrist_image": {
        "mean": [
            [[0.5057708024978638]],
            [[0.46449801325798035]],
            [[0.42354270815849304]],
        ],
        "std": [
            [[0.2672027349472046]],
            [[0.25637131929397583]],
            [[0.24345873296260834]],
        ],
        "max": [
            [[1.0]],
            [[1.0]],
            [[1.0]],
        ],
        "min": [
            [[0.0]],
            [[0.0]],
            [[0.0]],
        ],
    },
    "state": {
        "mean": [
            -0.04651879519224167,
            0.03440921753644943,
            0.7645525336265564,
            2.972202777862549,
            -0.22047005593776703,
            -0.1255796253681183,
            0.026914266869425774,
            -0.02719070389866829,
        ],
        "std": [
            0.10494378954172134,
            0.15176637470722198,
            0.3785160183906555,
            0.34427398443222046,
            0.9069469571113586,
            0.3253920078277588,
            0.014175914227962494,
            0.014058894477784634,
        ],
        "max": [
            0.21031762659549713,
            0.39128610491752625,
            1.3660105466842651,
            3.6714255809783936,
            3.560650587081909,
            1.386339545249939,
            0.04233968257904053,
            0.0013633022317662835,
        ],
        "min": [
            -0.4828203022480011,
            -0.3255046010017395,
            0.008128180168569088,
            0.35277295112609863,
            -3.641430377960205,
            -1.842738389968872,
            -0.0013586411951109767,
            -0.042040832340717316,
        ],
    },
    "actions": {
        "mean": [
            0.06278137117624283,
            0.0868409126996994,
            -0.09037282317876816,
            0.0005407406715676188,
            0.005643361248075962,
            -0.005229088477790356,
            -0.04964079707860947,
        ],
        "std": [
            0.3355240225791931,
            0.3784470558166504,
            0.44472837448120117,
            0.03924351558089256,
            0.06339313089847565,
            0.07797032594680786,
            0.9987710118293762,
        ],
        "max": [0.9375, 0.9375, 0.9375, 0.3557142913341522, 0.375, 0.375, 1.0],
        "min": [-0.9375, -0.9375, -0.9375, -0.2582142949104309, -0.375, -0.3675000071525574, -1.0],
    },
}
194
+
195
+
196
def make_readme() -> str:
    """Build the README text for automatically generated datasets.

    Returns:
        A single-line string recording the generating module and the UTC
        timestamp at which the dataset was produced.
    """
    generated_at = datetime.datetime.now(datetime.timezone.utc).isoformat()
    return f"This dataset was automatically generated by {__file__} on {generated_at}"
204
+
205
+
206
def make_libero_info(total_episodes: int, total_frames: int, total_chunks: int) -> dict:
    """Generate metadata dictionary for LIBERO dataset.

    Args:
        total_episodes: Total number of episodes in the dataset.
        total_frames: Total number of frames across all episodes.
        total_chunks: Total number of data chunks.

    Returns:
        Dictionary containing dataset metadata including codebase version, robot
        type, episode/frame counts, task information, and feature specifications.
    """

    def _scalar(dtype: str) -> dict:
        # Spec for a per-frame scalar column (timestamp, indices, ...).
        return {"dtype": dtype, "shape": [1], "names": None}

    # Both cameras share the same 256x256 RGB layout; copy so each feature
    # entry is an independent dict.
    image_spec = {"dtype": "image", "shape": [256, 256, 3], "names": ["height", "width", "channel"]}
    features = {
        "image": dict(image_spec),
        "wrist_image": dict(image_spec),
        "state": {"dtype": "float32", "shape": [8], "names": ["state"]},
        "actions": {"dtype": "float32", "shape": [7], "names": ["actions"]},
        "timestamp": _scalar("float32"),
        "frame_index": _scalar("int64"),
        "episode_index": _scalar("int64"),
        "index": _scalar("int64"),
        "task_index": _scalar("int64"),
    }

    return {
        "codebase_version": "v2.0",
        "robot_type": "panda",
        "total_episodes": total_episodes,
        "total_frames": total_frames,
        "total_tasks": len(LIBERO_TASKS),
        "total_videos": 0,
        "total_chunks": total_chunks,
        "chunks_size": LIBERO_CHUNK_SIZE,
        "fps": LIBERO_FPS,
        "splits": {"train": f"0:{total_episodes}"},
        "data_path": "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet",
        "video_path": "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4",
        "features": features,
    }
267
+
268
+
269
def uint8_arr_to_png_bytes(arr: np.ndarray) -> bytes:
    """Convert a uint8 image array to PNG bytes.

    Args:
        arr: Image array in CHW format (channels, height, width) with uint8 values.

    Returns:
        PNG-encoded image bytes.
    """
    # PIL expects HWC ordering, so rotate the channel axis to the back.
    hwc = arr.transpose(1, 2, 0)
    buffer = io.BytesIO()
    Image.fromarray(hwc, mode="RGB").save(buffer, format="PNG")
    return buffer.getvalue()
283
+
284
+
285
def _png_column(column: "pd.Series") -> "pd.Series":
    """Encode a column of CHW uint8 arrays as HF image records with sequential paths."""
    counter = count()
    return column.apply(
        lambda x: {
            "bytes": uint8_arr_to_png_bytes(x),
            "path": f"frame_{next(counter):06d}.png",
        }
    )


# Hugging Face-style feature schema injected into every parquet file so that
# `datasets` reconstructs the right features. Loop-invariant, so built once.
# Target: {b'huggingface': b'{"info": {"features": {"image": {"_type": "Image"}, ...}}}'}
_HF_FEATURES = {
    "image": {"_type": "Image"},
    "wrist_image": {"_type": "Image"},
    "state": {
        "feature": {"dtype": "float32", "_type": "Value"},
        "length": 8,
        "_type": "Sequence",
    },
    "actions": {
        "feature": {"dtype": "float32", "_type": "Value"},
        "length": 7,
        "_type": "Sequence",
    },
    "timestamp": {"dtype": "float32", "_type": "Value"},
    "frame_index": {"dtype": "int64", "_type": "Value"},
    "episode_index": {"dtype": "int64", "_type": "Value"},
    "index": {"dtype": "int64", "_type": "Value"},
    "task_index": {"dtype": "int64", "_type": "Value"},
}


def consolidate_task_result(task_result: dict, output_dir: str | Path, allow_overwrite: bool = False) -> None:
    """Consolidate LIBERO task results into a standardized dataset format.

    Writes dataset metadata (tasks.jsonl, stats.json, info.json, episodes.jsonl,
    README.md) and one parquet file per episode under ``output_dir``.

    Args:
        task_result: Dictionary containing episode data with observations, actions,
            and metadata.
        output_dir: Directory path where the consolidated dataset will be saved.
        allow_overwrite: If True, overwrite existing output directory. Defaults to False.

    Raises:
        FileExistsError: If output_dir already exists and allow_overwrite is False.
    """
    output_dir = Path(output_dir)
    if allow_overwrite and output_dir.is_dir():
        shutil.rmtree(output_dir)
    output_dir.mkdir(parents=True, exist_ok=False)
    meta_dir = output_dir / "meta"
    meta_dir.mkdir(parents=True, exist_ok=False)

    with open(meta_dir / "tasks.jsonl", "w") as f:
        for idx, task in enumerate(LIBERO_TASKS):
            task_entry = json.dumps({"task_index": idx, "task": task})
            f.write(task_entry + "\n")

    with open(meta_dir / "stats.json", "w") as f:
        json.dump(LIBERO_DEFAULT_STATS, f, indent=4)

    # done_index may be a 0-dim tensor/numpy scalar; coerce to plain int once so
    # later arithmetic and JSON serialization (and the zero-episode case, where
    # sum([]) is an int) do not depend on tensor types.
    all_lengths = [int(d["done_index"]) + 1 for d in task_result["per_episode"]]
    all_successes = [d["success"] for d in task_result["per_episode"]]
    n_episodes = len(all_lengths)

    with open(meta_dir / "info.json", "w") as f:
        json.dump(
            make_libero_info(
                total_episodes=n_episodes,
                total_frames=sum(all_lengths),
                # Use the shared chunking constant (was hard-coded to 1000).
                total_chunks=math.ceil(n_episodes / LIBERO_CHUNK_SIZE),
            ),
            f,
            indent=4,
        )

    with open(output_dir / "README.md", "w") as f:
        f.write(make_readme())

    global_idx_start = 0
    for ep_idx, (ep_len, success) in enumerate(zip(all_lengths, all_successes, strict=False)):
        # Truncate every stream at the episode's terminal step and convert to numpy.
        ep_actions = task_result["episodes"]["action"][ep_idx][:ep_len]
        ep_actions = ep_actions.to(torch.float32).numpy(force=True)
        ep_images = task_result["episodes"]["observation"]["camera0"][ep_idx][:ep_len]
        ep_images = (ep_images * 255).to(torch.uint8).numpy(force=True)
        ep_wrist_images = task_result["episodes"]["observation"]["camera1"][ep_idx][:ep_len]
        ep_wrist_images = (ep_wrist_images * 255).to(torch.uint8).numpy(force=True)
        ep_states = task_result["episodes"]["observation"]["state"][ep_idx][:ep_len]
        ep_states = ep_states.to(torch.float32).numpy(force=True)

        # The prompt is constant within an episode; take it from the first frame.
        prompt = task_result["episodes"]["observation"]["prompt"][ep_idx][0]
        task_idx = LIBERO_TASK_TO_IDX[prompt]
        df = pd.DataFrame(
            {
                "image": list(ep_images),
                "wrist_image": list(ep_wrist_images),
                "state": list(ep_states),
                "actions": list(ep_actions),
                "timestamp": np.arange(ep_len, dtype=np.float32) / np.float32(LIBERO_FPS),
                "frame_index": np.arange(ep_len, dtype=int),
                "episode_index": np.ones((ep_len,), dtype=int) * ep_idx,
                "index": global_idx_start + np.arange(ep_len, dtype=int),
                "task_index": [task_idx] * len(ep_actions),
            }
        )
        global_idx_start += ep_len

        df.image = _png_column(df.image)
        df.wrist_image = _png_column(df.wrist_image)

        df.actions = df.actions.apply(lambda x: x[:7])  # only keep first 7 dimensions of action
        df.state = df.state.apply(lambda x: x[:8])  # only keep first 8 dimensions of state

        chunk_idx = ep_idx // LIBERO_CHUNK_SIZE
        parquet_path = output_dir / "data" / f"chunk-{chunk_idx:03d}" / f"episode_{ep_idx:06d}.parquet"
        parquet_path.parent.mkdir(parents=True, exist_ok=True)

        # Convert the pandas DataFrame to a PyArrow Table without the pandas index.
        table = pa.Table.from_pandas(df, preserve_index=False)

        # Patch the schema metadata with the precomputed HF feature description.
        schema_metadata = dict(table.schema.metadata or {})
        schema_metadata[b"huggingface"] = json.dumps({"info": {"features": _HF_FEATURES}}).encode("utf-8")
        table = table.replace_schema_metadata(schema_metadata)

        # Finally, write the Parquet file with the patched schema metadata.
        pq.write_table(table, parquet_path)

        with open(meta_dir / "episodes.jsonl", "a") as f:
            ep_entry = json.dumps(
                {
                    "episode_index": ep_idx,
                    "tasks": [prompt],
                    "length": ep_len,
                    # bool() so JSON encoding does not depend on the success value's type.
                    "success": bool(success),
                }
            )
            f.write(ep_entry + "\n")
426
+
427
+
428
def aggregate_task_results(results: list[dict]) -> dict:
    """Aggregate multiple LIBERO task results into a single result dictionary.

    Args:
        results: List of task result dictionaries to aggregate.

    Returns:
        Dictionary containing aggregated episode data with all observations,
        actions, and per-episode metadata combined.
    """
    logging.info(f"Aggregating {len(results)} result(s)")

    per_episode: list = []
    observation: dict[str, list] = {key: [] for key in ("camera0", "camera1", "state", "prompt")}
    actions: list = []

    for task_result in results:
        per_episode.extend(task_result["per_episode"])
        source_obs = task_result["episodes"]["observation"]
        # Concatenate each observation stream in a single pass over the keys.
        for key, values in observation.items():
            values.extend(source_obs[key])
        actions.extend(task_result["episodes"]["action"])

    return {
        "per_episode": per_episode,
        "episodes": {
            "observation": observation,
            "action": actions,
        },
    }
@@ -0,0 +1,180 @@
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ """Utilities for tracking and logging training metrics.
18
+
19
+ This module provides classes for tracking metrics during training, including
20
+ AverageMeter for computing running averages and MetricsTracker for managing
21
+ multiple metrics with step tracking.
22
+ """
23
+
24
+ from typing import Any
25
+
26
+ from opentau.utils.utils import format_big_number
27
+
28
+
29
class AverageMeter:
    """Tracks the most recent value and the running average of a metric.

    Adapted from https://github.com/pytorch/examples/blob/main/imagenet/main.py

    Args:
        name: Name of the metric being tracked.
        fmt: Format string for displaying the average value. Defaults to ":f".
    """

    def __init__(self, name: str, fmt: str = ":f"):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self) -> None:
        """Clear all accumulated statistics back to zero."""
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0.0

    def update(self, val: float, n: int = 1) -> None:
        """Fold a new observation into the running statistics.

        Args:
            val: New value to add.
            n: Number of samples this value represents. Defaults to 1.
        """
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

    def __str__(self):
        # Splice the configured format spec into the display template,
        # e.g. fmt=":.3f" yields "{name}:{avg:.3f}".
        template = "{name}:{avg" + self.fmt + "}"
        return template.format(name=self.name, avg=self.avg)
66
+
67
+
68
class MetricsTracker:
    """
    A helper class to track and log metrics over time.

    Usage pattern::

        # initialize, potentially with non-zero initial step (e.g. if resuming run)
        metrics = {"loss": AverageMeter("loss", ":.3f")}
        train_metrics = MetricsTracker(batch_size, metrics, initial_step=step)

        # update metrics derived from step (samples, episodes, epochs) at each training step
        train_metrics.step()

        # update various metrics
        loss = policy.forward(batch)
        train_metrics.loss = loss

        # display current metrics
        logging.info(train_metrics)

        # export for wandb
        wandb.log(train_metrics.to_dict())

        # reset averages after logging
        train_metrics.reset_averages()
    """

    # Names of the "real" instance attributes. Anything assigned on the instance
    # that is not in this list is routed to the matching AverageMeter (or rejected)
    # by __setattr__ below.
    __keys__ = [
        "_batch_size",
        "metrics",
        "steps",
        "samples",
    ]

    def __init__(
        self,
        batch_size: int,
        metrics: dict[str, AverageMeter],
        initial_step: int = 0,
    ):
        """Initialize the metrics tracker.

        Args:
            batch_size: Number of samples per gradient update.
            metrics: Dictionary of metric names to AverageMeter instances.
            initial_step: Starting step number (useful when resuming training).
                Defaults to 0.
        """
        # Seed the instance dict directly (bypassing our custom __setattr__) so the
        # real attributes exist before __getattr__/__setattr__ dispatch on them;
        # this also prevents __getattr__ recursing on self.metrics during __init__.
        self.__dict__.update(dict.fromkeys(self.__keys__))
        # This is the actual batch size, i.e., number of samples used to compute a gradient update;
        # not to be confused with number of dataloader_batch_size.
        self._batch_size = batch_size
        self.metrics = metrics

        # This is the actual step, i.e., number of gradient updates so far;
        # not to be confused with number of passes (sub-steps)
        self.steps = initial_step
        # A sample is an (observation,action) pair, where observation and action
        # can be on multiple timestamps. In a batch, we have `batch_size` number of samples.
        self.samples = self.steps * self._batch_size

    def __getattr__(self, name: str) -> int | dict[str, AverageMeter] | AverageMeter | Any:
        # Only invoked when normal attribute lookup fails, so in practice this
        # resolves metric names (e.g. tracker.loss) to their AverageMeter.
        if name in self.__dict__:
            return self.__dict__[name]
        elif name in self.metrics:
            return self.metrics[name]
        else:
            raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")

    def __setattr__(self, name: str, value: Any) -> None:
        # Real attributes (seeded in __init__) are set normally; assigning to a
        # metric name updates its AverageMeter instead of shadowing it.
        if name in self.__dict__:
            super().__setattr__(name, value)
        elif name in self.metrics:
            self.metrics[name].update(value)
        else:
            raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")

    def step(self) -> None:
        """
        Updates metrics that depend on 'step' for one step.
        """
        self.steps += 1
        self.samples += self._batch_size

    def __str__(self) -> str:
        display_list = [
            f"step:{format_big_number(self.steps)}",
            # number of samples seen during training
            f"smpl:{format_big_number(self.samples)}",
            *[str(m) for m in self.metrics.values()],
        ]
        return " ".join(display_list)

    def to_dict(self, use_avg: bool = True) -> dict[str, int | float]:
        """Convert current metrics to a dictionary.

        Args:
            use_avg: If True, use average values; otherwise use current values.
                Defaults to True.

        Returns:
            Dictionary containing steps, samples, and all metric values.
        """
        return {
            "steps": self.steps,
            "samples": self.samples,
            **{k: m.avg if use_avg else m.val for k, m in self.metrics.items()},
        }

    def reset_averages(self) -> None:
        """Resets average meters."""
        for m in self.metrics.values():
            m.reset()