opentau 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opentau/__init__.py +179 -0
- opentau/__version__.py +24 -0
- opentau/configs/__init__.py +19 -0
- opentau/configs/default.py +297 -0
- opentau/configs/libero.py +113 -0
- opentau/configs/parser.py +393 -0
- opentau/configs/policies.py +297 -0
- opentau/configs/reward.py +42 -0
- opentau/configs/train.py +370 -0
- opentau/configs/types.py +76 -0
- opentau/constants.py +52 -0
- opentau/datasets/__init__.py +84 -0
- opentau/datasets/backward_compatibility.py +78 -0
- opentau/datasets/compute_stats.py +333 -0
- opentau/datasets/dataset_mixture.py +460 -0
- opentau/datasets/factory.py +232 -0
- opentau/datasets/grounding/__init__.py +67 -0
- opentau/datasets/grounding/base.py +154 -0
- opentau/datasets/grounding/clevr.py +110 -0
- opentau/datasets/grounding/cocoqa.py +130 -0
- opentau/datasets/grounding/dummy.py +101 -0
- opentau/datasets/grounding/pixmo.py +177 -0
- opentau/datasets/grounding/vsr.py +141 -0
- opentau/datasets/image_writer.py +304 -0
- opentau/datasets/lerobot_dataset.py +1910 -0
- opentau/datasets/online_buffer.py +442 -0
- opentau/datasets/push_dataset_to_hub/utils.py +132 -0
- opentau/datasets/sampler.py +99 -0
- opentau/datasets/standard_data_format_mapping.py +278 -0
- opentau/datasets/transforms.py +330 -0
- opentau/datasets/utils.py +1243 -0
- opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +887 -0
- opentau/datasets/v2/convert_dataset_v1_to_v2.py +829 -0
- opentau/datasets/v21/_remove_language_instruction.py +109 -0
- opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +60 -0
- opentau/datasets/v21/convert_dataset_v20_to_v21.py +183 -0
- opentau/datasets/v21/convert_stats.py +150 -0
- opentau/datasets/video_utils.py +597 -0
- opentau/envs/__init__.py +18 -0
- opentau/envs/configs.py +178 -0
- opentau/envs/factory.py +99 -0
- opentau/envs/libero.py +439 -0
- opentau/envs/utils.py +204 -0
- opentau/optim/__init__.py +16 -0
- opentau/optim/factory.py +43 -0
- opentau/optim/optimizers.py +121 -0
- opentau/optim/schedulers.py +140 -0
- opentau/planner/__init__.py +82 -0
- opentau/planner/high_level_planner.py +366 -0
- opentau/planner/utils/memory.py +64 -0
- opentau/planner/utils/utils.py +65 -0
- opentau/policies/__init__.py +24 -0
- opentau/policies/factory.py +172 -0
- opentau/policies/normalize.py +315 -0
- opentau/policies/pi0/__init__.py +19 -0
- opentau/policies/pi0/configuration_pi0.py +250 -0
- opentau/policies/pi0/modeling_pi0.py +994 -0
- opentau/policies/pi0/paligemma_with_expert.py +516 -0
- opentau/policies/pi05/__init__.py +20 -0
- opentau/policies/pi05/configuration_pi05.py +231 -0
- opentau/policies/pi05/modeling_pi05.py +1257 -0
- opentau/policies/pi05/paligemma_with_expert.py +572 -0
- opentau/policies/pretrained.py +315 -0
- opentau/policies/utils.py +123 -0
- opentau/policies/value/__init__.py +18 -0
- opentau/policies/value/configuration_value.py +170 -0
- opentau/policies/value/modeling_value.py +512 -0
- opentau/policies/value/reward.py +87 -0
- opentau/policies/value/siglip_gemma.py +221 -0
- opentau/scripts/actions_mse_loss.py +89 -0
- opentau/scripts/bin_to_safetensors.py +116 -0
- opentau/scripts/compute_max_token_length.py +111 -0
- opentau/scripts/display_sys_info.py +90 -0
- opentau/scripts/download_libero_benchmarks.py +54 -0
- opentau/scripts/eval.py +877 -0
- opentau/scripts/export_to_onnx.py +180 -0
- opentau/scripts/fake_tensor_training.py +87 -0
- opentau/scripts/get_advantage_and_percentiles.py +220 -0
- opentau/scripts/high_level_planner_inference.py +114 -0
- opentau/scripts/inference.py +70 -0
- opentau/scripts/launch_train.py +63 -0
- opentau/scripts/libero_simulation_parallel.py +356 -0
- opentau/scripts/libero_simulation_sequential.py +122 -0
- opentau/scripts/nav_high_level_planner_inference.py +61 -0
- opentau/scripts/train.py +379 -0
- opentau/scripts/visualize_dataset.py +294 -0
- opentau/scripts/visualize_dataset_html.py +507 -0
- opentau/scripts/zero_to_fp32.py +760 -0
- opentau/utils/__init__.py +20 -0
- opentau/utils/accelerate_utils.py +79 -0
- opentau/utils/benchmark.py +98 -0
- opentau/utils/fake_tensor.py +81 -0
- opentau/utils/hub.py +209 -0
- opentau/utils/import_utils.py +79 -0
- opentau/utils/io_utils.py +137 -0
- opentau/utils/libero.py +214 -0
- opentau/utils/libero_dataset_recorder.py +460 -0
- opentau/utils/logging_utils.py +180 -0
- opentau/utils/monkey_patch.py +278 -0
- opentau/utils/random_utils.py +244 -0
- opentau/utils/train_utils.py +198 -0
- opentau/utils/utils.py +471 -0
- opentau-0.1.0.dist-info/METADATA +161 -0
- opentau-0.1.0.dist-info/RECORD +108 -0
- opentau-0.1.0.dist-info/WHEEL +5 -0
- opentau-0.1.0.dist-info/entry_points.txt +2 -0
- opentau-0.1.0.dist-info/licenses/LICENSE +508 -0
- opentau-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,460 @@
|
|
|
1
|
+
# Copyright 2026 Tensor Auto Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Utilities for recording and converting LIBERO datasets.
|
|
16
|
+
|
|
17
|
+
This module provides functions for converting LIBERO episode data into the standard
|
|
18
|
+
dataset format, including consolidation of task results, aggregation of multiple
|
|
19
|
+
results, and generation of dataset metadata.
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
import datetime
|
|
23
|
+
import io
|
|
24
|
+
import json
|
|
25
|
+
import logging
|
|
26
|
+
import math
|
|
27
|
+
import shutil
|
|
28
|
+
from itertools import count
|
|
29
|
+
from pathlib import Path
|
|
30
|
+
|
|
31
|
+
import numpy as np
|
|
32
|
+
import pandas as pd
|
|
33
|
+
import pyarrow as pa
|
|
34
|
+
import pyarrow.parquet as pq
|
|
35
|
+
import torch
|
|
36
|
+
from PIL import Image
|
|
37
|
+
|
|
38
|
+
# Canonical LIBERO task prompts. The position of each prompt in this list is its
# `task_index`: consolidate_task_result writes it to meta/tasks.jsonl and to every parquet
# row (via LIBERO_TASK_TO_IDX), so reordering, removing, or editing entries changes the
# meaning of indices in already-recorded datasets.
# Episode prompts are looked up by exact string equality; a prompt that is not listed here
# makes consolidate_task_result fail with a KeyError.
# NOTE(review): the list appears to be grouped in blocks of 10 (likely the LIBERO-10/long,
# goal, object and spatial suites, in that order) — confirm against the benchmark definitions.
LIBERO_TASKS = [
    "put the white mug on the left plate and put the yellow and white mug on the right plate",
    "put the white mug on the plate and put the chocolate pudding to the right of the plate",
    "put the yellow and white mug in the microwave and close it",
    "turn on the stove and put the moka pot on it",
    "put both the alphabet soup and the cream cheese box in the basket",
    "put both the alphabet soup and the tomato sauce in the basket",
    "put both moka pots on the stove",
    "put both the cream cheese box and the butter in the basket",
    "put the black bowl in the bottom drawer of the cabinet and close it",
    "pick up the book and place it in the back compartment of the caddy",
    "put the bowl on the plate",
    "put the wine bottle on the rack",
    "open the top drawer and put the bowl inside",
    "put the cream cheese in the bowl",
    "put the wine bottle on top of the cabinet",
    "push the plate to the front of the stove",
    "turn on the stove",
    "put the bowl on the stove",
    "put the bowl on top of the cabinet",
    "open the middle drawer of the cabinet",
    "pick up the orange juice and place it in the basket",
    "pick up the ketchup and place it in the basket",
    "pick up the cream cheese and place it in the basket",
    "pick up the bbq sauce and place it in the basket",
    "pick up the alphabet soup and place it in the basket",
    "pick up the milk and place it in the basket",
    "pick up the salad dressing and place it in the basket",
    "pick up the butter and place it in the basket",
    "pick up the tomato sauce and place it in the basket",
    "pick up the chocolate pudding and place it in the basket",
    "pick up the black bowl next to the cookie box and place it on the plate",
    "pick up the black bowl in the top drawer of the wooden cabinet and place it on the plate",
    "pick up the black bowl on the ramekin and place it on the plate",
    "pick up the black bowl on the stove and place it on the plate",
    "pick up the black bowl between the plate and the ramekin and place it on the plate",
    "pick up the black bowl on the cookie box and place it on the plate",
    "pick up the black bowl next to the plate and place it on the plate",
    "pick up the black bowl next to the ramekin and place it on the plate",
    "pick up the black bowl from table center and place it on the plate",
    "pick up the black bowl on the wooden cabinet and place it on the plate",
]
|
|
80
|
+
|
|
81
|
+
# Exact prompt string -> its position in LIBERO_TASKS (the persisted `task_index`).
LIBERO_TASK_TO_IDX = dict(zip(LIBERO_TASKS, range(len(LIBERO_TASKS))))
# Frames per second: written as `fps` in info.json and used to derive per-frame timestamps.
LIBERO_FPS = 10
# Number of episodes grouped under one `data/chunk-XXX/` directory (`chunks_size` in info.json).
LIBERO_CHUNK_SIZE = 1000
|
|
84
|
+
# Fixed normalization statistics that consolidate_task_result writes verbatim to
# meta/stats.json; they are NOT recomputed from the recorded episodes.
# "state" holds 8 values per statistic and "actions" 7, matching the feature shapes declared
# in make_libero_info.
LIBERO_DEFAULT_STATS = {
    # Per-channel image stats. Each value is nested as [[x]], presumably giving shape
    # (C, 1, 1) so it broadcasts over (C, H, W) images; max/min of 1.0/0.0 suggest pixels
    # scaled to [0, 1] — confirm against the stats loader.
    "image": {
        "mean": [
            [[0.48068472743034363]],
            [[0.4485854208469391]],
            [[0.4106878638267517]],
        ],
        "std": [
            [[0.229267880320549]],
            [[0.22267605364322662]],
            [[0.21524138748645782]],
        ],
        "max": [
            [[1.0]],
            [[1.0]],
            [[1.0]],
        ],
        "min": [
            [[0.0]],
            [[0.0]],
            [[0.0]],
        ],
    },
    "wrist_image": {
        "mean": [
            [[0.5057708024978638]],
            [[0.46449801325798035]],
            [[0.42354270815849304]],
        ],
        "std": [
            [[0.2672027349472046]],
            [[0.25637131929397583]],
            [[0.24345873296260834]],
        ],
        "max": [
            [[1.0]],
            [[1.0]],
            [[1.0]],
        ],
        "min": [
            [[0.0]],
            [[0.0]],
            [[0.0]],
        ],
    },
    "state": {
        "mean": [
            -0.04651879519224167,
            0.03440921753644943,
            0.7645525336265564,
            2.972202777862549,
            -0.22047005593776703,
            -0.1255796253681183,
            0.026914266869425774,
            -0.02719070389866829,
        ],
        "std": [
            0.10494378954172134,
            0.15176637470722198,
            0.3785160183906555,
            0.34427398443222046,
            0.9069469571113586,
            0.3253920078277588,
            0.014175914227962494,
            0.014058894477784634,
        ],
        "max": [
            0.21031762659549713,
            0.39128610491752625,
            1.3660105466842651,
            3.6714255809783936,
            3.560650587081909,
            1.386339545249939,
            0.04233968257904053,
            0.0013633022317662835,
        ],
        "min": [
            -0.4828203022480011,
            -0.3255046010017395,
            0.008128180168569088,
            0.35277295112609863,
            -3.641430377960205,
            -1.842738389968872,
            -0.0013586411951109767,
            -0.042040832340717316,
        ],
    },
    "actions": {
        "mean": [
            0.06278137117624283,
            0.0868409126996994,
            -0.09037282317876816,
            0.0005407406715676188,
            0.005643361248075962,
            -0.005229088477790356,
            -0.04964079707860947,
        ],
        "std": [
            0.3355240225791931,
            0.3784470558166504,
            0.44472837448120117,
            0.03924351558089256,
            0.06339313089847565,
            0.07797032594680786,
            0.9987710118293762,
        ],
        "max": [0.9375, 0.9375, 0.9375, 0.3557142913341522, 0.375, 0.375, 1.0],
        "min": [-0.9375, -0.9375, -0.9375, -0.2582142949104309, -0.375, -0.3675000071525574, -1.0],
    },
}
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def make_readme() -> str:
    """Build the README text stored at the root of an automatically generated dataset.

    Returns:
        A single sentence naming this module's source file and the current UTC time
        (ISO 8601, timezone-aware) as the dataset's provenance.
    """
    generated_at = datetime.datetime.now(tz=datetime.timezone.utc).isoformat()
    return f"This dataset was automatically generated by {__file__} on {generated_at}"
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
def make_libero_info(total_episodes: int, total_frames: int, total_chunks: int) -> dict:
    """Build the ``meta/info.json`` payload describing a recorded LIBERO dataset.

    Args:
        total_episodes: Number of episodes; the ``train`` split covers ``0:total_episodes``.
        total_frames: Number of frames summed over all episodes.
        total_chunks: Number of ``data/chunk-XXX`` directories.

    Returns:
        A JSON-serializable dict with the codebase version (``"v2.0"``), robot type,
        episode/frame/task counts, chunking and fps settings, data/video path templates,
        and the per-column feature specification.
    """

    def camera_feature() -> dict:
        # 256x256 RGB frames, declared channel-last.
        return {"dtype": "image", "shape": [256, 256, 3], "names": ["height", "width", "channel"]}

    def scalar_feature(dtype: str) -> dict:
        # A single unnamed value per frame.
        return {"dtype": dtype, "shape": [1], "names": None}

    # Insertion order matters: it is the key order written to info.json.
    features = {
        "image": camera_feature(),
        "wrist_image": camera_feature(),
        "state": {"dtype": "float32", "shape": [8], "names": ["state"]},
        "actions": {"dtype": "float32", "shape": [7], "names": ["actions"]},
        "timestamp": scalar_feature("float32"),
    }
    features.update(
        {key: scalar_feature("int64") for key in ("frame_index", "episode_index", "index", "task_index")}
    )

    return dict(
        codebase_version="v2.0",
        robot_type="panda",
        total_episodes=total_episodes,
        total_frames=total_frames,
        total_tasks=len(LIBERO_TASKS),
        # Frames are stored as encoded images inside the parquet files, not as videos.
        total_videos=0,
        total_chunks=total_chunks,
        chunks_size=LIBERO_CHUNK_SIZE,
        fps=LIBERO_FPS,
        splits={"train": f"0:{total_episodes}"},
        # Plain (non f-string) templates: the placeholders are filled in by dataset readers.
        data_path="data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet",
        video_path="videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4",
        features=features,
    )
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
def uint8_arr_to_png_bytes(arr: np.ndarray) -> bytes:
    """Encode a channel-first RGB ``uint8`` image as PNG.

    Args:
        arr: Image array in CHW format, i.e. shape ``(3, height, width)``, with dtype ``uint8``.

    Returns:
        PNG-encoded image bytes.

    Raises:
        ValueError: If ``arr`` is not a 3-channel ``(3, H, W)`` array.
        TypeError: If ``arr`` does not have dtype ``uint8``.
    """
    arr = np.asarray(arr)
    if arr.ndim != 3 or arr.shape[0] != 3:
        raise ValueError(f"Expected a (3, H, W) image array, got shape {arr.shape}")
    if arr.dtype != np.uint8:
        raise TypeError(f"Expected a uint8 image array, got dtype {arr.dtype}")

    # PIL expects channel-last (H, W, 3). The transpose is a strided view, so make it
    # contiguous before handing it over. PIL infers the "RGB" mode from a uint8 (H, W, 3)
    # array; the explicit `mode=` argument of `Image.fromarray` is deprecated (Pillow 11.3).
    arr_hwc = np.ascontiguousarray(np.transpose(arr, (1, 2, 0)))
    img = Image.fromarray(arr_hwc)
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    return buf.getvalue()
|
|
283
|
+
|
|
284
|
+
|
|
285
|
+
# Hugging Face `datasets` feature spec injected into each episode parquet file's schema
# metadata (under b"huggingface") so `datasets` reconstructs the image columns as Image
# features and state/actions as fixed-length float32 sequences. Built once, reused per episode.
_LIBERO_HF_FEATURES = {
    "image": {"_type": "Image"},
    "wrist_image": {"_type": "Image"},
    "state": {
        "feature": {"dtype": "float32", "_type": "Value"},
        "length": 8,
        "_type": "Sequence",
    },
    "actions": {
        "feature": {"dtype": "float32", "_type": "Value"},
        "length": 7,
        "_type": "Sequence",
    },
    "timestamp": {"dtype": "float32", "_type": "Value"},
    "frame_index": {"dtype": "int64", "_type": "Value"},
    "episode_index": {"dtype": "int64", "_type": "Value"},
    "index": {"dtype": "int64", "_type": "Value"},
    "task_index": {"dtype": "int64", "_type": "Value"},
}


def _to_python_scalar(value):
    """Unwrap tensors / NumPy scalars (anything with ``.item()``) so ``json`` can serialize them."""
    return value.item() if hasattr(value, "item") else value


def _float_images_to_uint8(images: torch.Tensor) -> np.ndarray:
    """Convert float images (presumably scaled to [0, 1]) to a uint8 NumPy array.

    Values are rounded rather than truncated: ``x * 255`` computed in floating point can land
    just below an integer (e.g. 199.99998), which truncation would turn into 199. They are
    also clamped to [0, 255], because casting out-of-range floats to uint8 is undefined.
    """
    return (images * 255).round().clamp(0, 255).to(torch.uint8).numpy(force=True)


def _png_column(frames: np.ndarray) -> list[dict]:
    """Encode each CHW uint8 frame as a Hugging Face Image struct (``{"bytes", "path"}``)."""
    return [
        {"bytes": uint8_arr_to_png_bytes(frame), "path": f"frame_{frame_idx:06d}.png"}
        for frame_idx, frame in enumerate(frames)
    ]


def _write_episode_parquet(df: pd.DataFrame, parquet_path: Path) -> None:
    """Write one episode's frames to ``parquet_path`` with HF ``datasets`` feature metadata."""
    # Convert without the pandas index, then attach the HF feature spec. Target metadata:
    # {b'huggingface': b'{"info": {"features": {"image": {"_type": "Image"}, ...}}}'}
    table = pa.Table.from_pandas(df, preserve_index=False)
    schema_metadata = dict(table.schema.metadata or {})
    schema_metadata[b"huggingface"] = json.dumps({"info": {"features": _LIBERO_HF_FEATURES}}).encode("utf-8")
    table = table.replace_schema_metadata(schema_metadata)
    parquet_path.parent.mkdir(parents=True, exist_ok=True)
    pq.write_table(table, parquet_path)


def consolidate_task_result(task_result: dict, output_dir: str | Path, allow_overwrite: bool = False) -> None:
    """Consolidate LIBERO task results into a standardized (v2.0, see make_libero_info) dataset.

    Creates ``output_dir`` containing ``meta/tasks.jsonl``, ``meta/stats.json`` (the fixed
    ``LIBERO_DEFAULT_STATS``), ``meta/info.json``, ``meta/episodes.jsonl``, ``README.md`` and
    one parquet file per episode under ``data/chunk-XXX/``. Only frames up to and including
    each episode's ``done_index`` are kept; actions are truncated to their first 7
    dimensions and states to their first 8.

    Args:
        task_result: Rollout results. The keys read are ``per_episode[i]["done_index"]`` and
            ``per_episode[i]["success"]``, ``episodes["action"][i]``, and
            ``episodes["observation"][key][i]`` for ``key`` in ``camera0`` (main image),
            ``camera1`` (wrist image), ``state`` and ``prompt`` (the first prompt of each
            episode is used). Action, state and image entries are torch tensors indexed by
            frame; images are presumably float CHW in [0, 1].
        output_dir: Directory path where the consolidated dataset will be saved.
        allow_overwrite: If True, overwrite existing output directory. Defaults to False.

    Raises:
        FileExistsError: If output_dir already exists and allow_overwrite is False.
        KeyError: If an episode's prompt is not one of ``LIBERO_TASKS``.
    """
    output_dir = Path(output_dir)
    if allow_overwrite and output_dir.is_dir():
        shutil.rmtree(output_dir)
    output_dir.mkdir(parents=True, exist_ok=False)
    meta_dir = output_dir / "meta"
    meta_dir.mkdir(parents=True, exist_ok=False)

    with open(meta_dir / "tasks.jsonl", "w") as f:
        for idx, task in enumerate(LIBERO_TASKS):
            f.write(json.dumps({"task_index": idx, "task": task}) + "\n")

    with open(meta_dir / "stats.json", "w") as f:
        json.dump(LIBERO_DEFAULT_STATS, f, indent=4)

    # Normalize to plain Python ints up front: `done_index` may be an int, a NumPy scalar or a
    # 0-d tensor. The old `sum(...).item()` crashed on plain ints and on an empty episode list.
    per_episode = task_result["per_episode"]
    all_lengths = [int(d["done_index"]) + 1 for d in per_episode]
    all_successes = [_to_python_scalar(d["success"]) for d in per_episode]
    n_episodes = len(all_lengths)

    with open(meta_dir / "info.json", "w") as f:
        json.dump(
            make_libero_info(
                total_episodes=n_episodes,
                total_frames=sum(all_lengths),
                total_chunks=math.ceil(n_episodes / LIBERO_CHUNK_SIZE),
            ),
            f,
            indent=4,
        )

    with open(output_dir / "README.md", "w") as f:
        f.write(make_readme())

    global_idx_start = 0
    with open(meta_dir / "episodes.jsonl", "w") as episodes_file:
        for ep_idx, (ep_len, success) in enumerate(zip(all_lengths, all_successes, strict=True)):
            episodes = task_result["episodes"]
            observations = episodes["observation"]

            # Keep only the first 7 action dims and 8 state dims (matching the declared features).
            ep_actions = episodes["action"][ep_idx][:ep_len].to(torch.float32).numpy(force=True)[:, :7]
            ep_states = observations["state"][ep_idx][:ep_len].to(torch.float32).numpy(force=True)[:, :8]
            ep_images = _float_images_to_uint8(observations["camera0"][ep_idx][:ep_len])
            ep_wrist_images = _float_images_to_uint8(observations["camera1"][ep_idx][:ep_len])

            prompt = observations["prompt"][ep_idx][0]
            try:
                task_idx = LIBERO_TASK_TO_IDX[prompt]
            except KeyError:
                raise KeyError(f"Episode {ep_idx} has a prompt that is not a known LIBERO task: {prompt!r}") from None

            # Explicit int64 (not the platform-dependent `int`) to match the declared schema.
            frame_ids = np.arange(ep_len, dtype=np.int64)
            df = pd.DataFrame(
                {
                    "image": _png_column(ep_images),
                    "wrist_image": _png_column(ep_wrist_images),
                    "state": list(ep_states),
                    "actions": list(ep_actions),
                    "timestamp": np.arange(ep_len, dtype=np.float32) / np.float32(LIBERO_FPS),
                    "frame_index": frame_ids,
                    "episode_index": np.full(ep_len, ep_idx, dtype=np.int64),
                    "index": global_idx_start + frame_ids,
                    "task_index": np.full(ep_len, task_idx, dtype=np.int64),
                }
            )
            global_idx_start += ep_len

            chunk_idx = ep_idx // LIBERO_CHUNK_SIZE
            parquet_path = output_dir / "data" / f"chunk-{chunk_idx:03d}" / f"episode_{ep_idx:06d}.parquet"
            _write_episode_parquet(df, parquet_path)

            ep_entry = json.dumps(
                {
                    "episode_index": ep_idx,
                    "tasks": [prompt],
                    "length": ep_len,
                    "success": success,
                }
            )
            episodes_file.write(ep_entry + "\n")
|
|
426
|
+
|
|
427
|
+
|
|
428
|
+
def aggregate_task_results(results: list[dict]) -> dict:
    """Concatenate several LIBERO task results into one result dictionary.

    Episodes are appended in the order the results are given. The per-episode metadata list,
    the ``camera0``/``camera1``/``state``/``prompt`` observation lists and the action list
    are each extended element-wise. The inputs are not modified.

    Args:
        results: Task result dictionaries sharing the ``per_episode`` /
            ``episodes.observation`` / ``episodes.action`` layout.

    Returns:
        A new dictionary in the same layout holding every episode from ``results``.
    """
    logging.info(f"Aggregating {len(results)} result(s)")

    observation_keys = ("camera0", "camera1", "state", "prompt")
    merged_observation = {key: [] for key in observation_keys}
    merged_action: list = []
    merged_per_episode: list = []

    for result in results:
        merged_per_episode += result["per_episode"]
        source_observation = result["episodes"]["observation"]
        for key in observation_keys:
            merged_observation[key] += source_observation[key]
        merged_action += result["episodes"]["action"]

    return {
        "per_episode": merged_per_episode,
        "episodes": {"observation": merged_observation, "action": merged_action},
    }
|
|
@@ -0,0 +1,180 @@
|
|
|
1
|
+
#!/usr/bin/env python
|
|
2
|
+
|
|
3
|
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
4
|
+
# Copyright 2026 Tensor Auto Inc. All rights reserved.
|
|
5
|
+
#
|
|
6
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
7
|
+
# you may not use this file except in compliance with the License.
|
|
8
|
+
# You may obtain a copy of the License at
|
|
9
|
+
#
|
|
10
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
11
|
+
#
|
|
12
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
13
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
14
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
15
|
+
# See the License for the specific language governing permissions and
|
|
16
|
+
# limitations under the License.
|
|
17
|
+
"""Utilities for tracking and logging training metrics.
|
|
18
|
+
|
|
19
|
+
This module provides classes for tracking metrics during training, including
|
|
20
|
+
AverageMeter for computing running averages and MetricsTracker for managing
|
|
21
|
+
multiple metrics with step tracking.
|
|
22
|
+
"""
|
|
23
|
+
|
|
24
|
+
from typing import Any
|
|
25
|
+
|
|
26
|
+
from opentau.utils.utils import format_big_number
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class AverageMeter:
    """Track the latest value and the running, sample-weighted average of a metric.

    Adapted from https://github.com/pytorch/examples/blob/main/imagenet/main.py

    Args:
        name: Name of the metric being tracked.
        fmt: Format spec (including the leading colon) used by ``__str__`` to render the
            average, e.g. ``":.3f"``. Defaults to ":f".
    """

    def __init__(self, name: str, fmt: str = ":f"):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self) -> None:
        """Reset all accumulated statistics to zero."""
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0.0

    def update(self, val: float, n: int = 1) -> None:
        """Update the meter with a new value.

        Args:
            val: New value to add.
            n: Number of samples this value represents (its weight in the average).
                Defaults to 1.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard against division by zero when only zero-weight updates have been recorded
        # (e.g. an empty batch reported with n=0); the average stays 0.0 in that case.
        self.avg = self.sum / self.count if self.count else 0.0

    def __str__(self) -> str:
        # Renders as "<name>:<avg formatted with self.fmt>", e.g. "loss:0.123".
        fmtstr = "{name}:{avg" + self.fmt + "}"
        return fmtstr.format(**self.__dict__)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class MetricsTracker:
    """
    A helper class to track and log metrics over time.

    Registered metrics are exposed as attributes: reading ``tracker.<name>`` returns the
    corresponding ``AverageMeter``, and assigning ``tracker.<name> = value`` records a new
    observation via ``AverageMeter.update(value)``. Assigning any other unknown attribute
    raises ``AttributeError``.

    Usage pattern::

        # initialize, potentially with non-zero initial step (e.g. if resuming run)
        metrics = {"loss": AverageMeter("loss", ":.3f")}
        train_metrics = MetricsTracker(batch_size, metrics, initial_step=step)

        # update metrics derived from step (steps, samples) at each training step
        train_metrics.step()

        # update various metrics
        loss = policy.forward(batch)
        train_metrics.loss = loss

        # display current metrics
        logging.info(train_metrics)

        # export for wandb
        wandb.log(train_metrics.to_dict())

        # reset averages after logging
        train_metrics.reset_averages()
    """

    # Plain instance attributes. They are pre-registered in __dict__ by __init__ so that
    # __setattr__ treats them as ordinary attributes rather than metric names.
    __keys__ = [
        "_batch_size",
        "metrics",
        "steps",
        "samples",
    ]

    def __init__(
        self,
        batch_size: int,
        metrics: dict[str, AverageMeter],
        initial_step: int = 0,
    ):
        """Initialize the metrics tracker.

        Args:
            batch_size: Number of samples per gradient update.
            metrics: Dictionary of metric names to AverageMeter instances.
            initial_step: Starting step number (useful when resuming training).
                Defaults to 0.
        """
        self.__dict__.update(dict.fromkeys(self.__keys__))
        # This is the actual batch size, i.e., number of samples used to compute a gradient update;
        # not to be confused with number of dataloader_batch_size.
        self._batch_size = batch_size
        self.metrics = metrics

        # This is the actual step, i.e., number of gradient updates so far;
        # not to be confused with number of passes (sub-steps)
        self.steps = initial_step
        # A sample is an (observation,action) pair, where observation and action
        # can be on multiple timestamps. In a batch, we have `batch_size` number of samples.
        self.samples = self.steps * self._batch_size

    def __getattr__(self, name: str) -> int | dict[str, AverageMeter] | AverageMeter | Any:
        # Only reached when normal lookup fails. Read `metrics` straight from the instance
        # __dict__: evaluating `self.metrics` here would re-enter __getattr__ and recurse
        # forever on instances whose __dict__ is not populated yet (copy.copy, deepcopy and
        # pickle create such instances and then probe attributes like `__setstate__`).
        instance_dict = self.__dict__
        if name in instance_dict:
            return instance_dict[name]
        metrics = instance_dict.get("metrics")
        if metrics is not None and name in metrics:
            return metrics[name]
        raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")

    def __setattr__(self, name: str, value: Any) -> None:
        # Known attributes are rebound normally; assigning to a metric name feeds the value
        # into that metric's AverageMeter instead of replacing the meter.
        metrics = self.__dict__.get("metrics")
        if name in self.__dict__:
            super().__setattr__(name, value)
        elif metrics is not None and name in metrics:
            metrics[name].update(value)
        else:
            raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")

    def step(self) -> None:
        """
        Updates metrics that depend on 'step' for one step.
        """
        self.steps += 1
        self.samples += self._batch_size

    def __str__(self) -> str:
        display_list = [
            f"step:{format_big_number(self.steps)}",
            # number of samples seen during training
            f"smpl:{format_big_number(self.samples)}",
            *[str(m) for m in self.metrics.values()],
        ]
        return " ".join(display_list)

    def to_dict(self, use_avg: bool = True) -> dict[str, int | float]:
        """Convert current metrics to a dictionary.

        Args:
            use_avg: If True, use average values; otherwise use current values.
                Defaults to True.

        Returns:
            Dictionary containing steps, samples, and all metric values.
        """
        return {
            "steps": self.steps,
            "samples": self.samples,
            **{k: m.avg if use_avg else m.val for k, m in self.metrics.items()},
        }

    def reset_averages(self) -> None:
        """Resets average meters."""
        for m in self.metrics.values():
            m.reset()
|