opentau 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opentau/__init__.py +179 -0
- opentau/__version__.py +24 -0
- opentau/configs/__init__.py +19 -0
- opentau/configs/default.py +297 -0
- opentau/configs/libero.py +113 -0
- opentau/configs/parser.py +393 -0
- opentau/configs/policies.py +297 -0
- opentau/configs/reward.py +42 -0
- opentau/configs/train.py +370 -0
- opentau/configs/types.py +76 -0
- opentau/constants.py +52 -0
- opentau/datasets/__init__.py +84 -0
- opentau/datasets/backward_compatibility.py +78 -0
- opentau/datasets/compute_stats.py +333 -0
- opentau/datasets/dataset_mixture.py +460 -0
- opentau/datasets/factory.py +232 -0
- opentau/datasets/grounding/__init__.py +67 -0
- opentau/datasets/grounding/base.py +154 -0
- opentau/datasets/grounding/clevr.py +110 -0
- opentau/datasets/grounding/cocoqa.py +130 -0
- opentau/datasets/grounding/dummy.py +101 -0
- opentau/datasets/grounding/pixmo.py +177 -0
- opentau/datasets/grounding/vsr.py +141 -0
- opentau/datasets/image_writer.py +304 -0
- opentau/datasets/lerobot_dataset.py +1910 -0
- opentau/datasets/online_buffer.py +442 -0
- opentau/datasets/push_dataset_to_hub/utils.py +132 -0
- opentau/datasets/sampler.py +99 -0
- opentau/datasets/standard_data_format_mapping.py +278 -0
- opentau/datasets/transforms.py +330 -0
- opentau/datasets/utils.py +1243 -0
- opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +887 -0
- opentau/datasets/v2/convert_dataset_v1_to_v2.py +829 -0
- opentau/datasets/v21/_remove_language_instruction.py +109 -0
- opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +60 -0
- opentau/datasets/v21/convert_dataset_v20_to_v21.py +183 -0
- opentau/datasets/v21/convert_stats.py +150 -0
- opentau/datasets/video_utils.py +597 -0
- opentau/envs/__init__.py +18 -0
- opentau/envs/configs.py +178 -0
- opentau/envs/factory.py +99 -0
- opentau/envs/libero.py +439 -0
- opentau/envs/utils.py +204 -0
- opentau/optim/__init__.py +16 -0
- opentau/optim/factory.py +43 -0
- opentau/optim/optimizers.py +121 -0
- opentau/optim/schedulers.py +140 -0
- opentau/planner/__init__.py +82 -0
- opentau/planner/high_level_planner.py +366 -0
- opentau/planner/utils/memory.py +64 -0
- opentau/planner/utils/utils.py +65 -0
- opentau/policies/__init__.py +24 -0
- opentau/policies/factory.py +172 -0
- opentau/policies/normalize.py +315 -0
- opentau/policies/pi0/__init__.py +19 -0
- opentau/policies/pi0/configuration_pi0.py +250 -0
- opentau/policies/pi0/modeling_pi0.py +994 -0
- opentau/policies/pi0/paligemma_with_expert.py +516 -0
- opentau/policies/pi05/__init__.py +20 -0
- opentau/policies/pi05/configuration_pi05.py +231 -0
- opentau/policies/pi05/modeling_pi05.py +1257 -0
- opentau/policies/pi05/paligemma_with_expert.py +572 -0
- opentau/policies/pretrained.py +315 -0
- opentau/policies/utils.py +123 -0
- opentau/policies/value/__init__.py +18 -0
- opentau/policies/value/configuration_value.py +170 -0
- opentau/policies/value/modeling_value.py +512 -0
- opentau/policies/value/reward.py +87 -0
- opentau/policies/value/siglip_gemma.py +221 -0
- opentau/scripts/actions_mse_loss.py +89 -0
- opentau/scripts/bin_to_safetensors.py +116 -0
- opentau/scripts/compute_max_token_length.py +111 -0
- opentau/scripts/display_sys_info.py +90 -0
- opentau/scripts/download_libero_benchmarks.py +54 -0
- opentau/scripts/eval.py +877 -0
- opentau/scripts/export_to_onnx.py +180 -0
- opentau/scripts/fake_tensor_training.py +87 -0
- opentau/scripts/get_advantage_and_percentiles.py +220 -0
- opentau/scripts/high_level_planner_inference.py +114 -0
- opentau/scripts/inference.py +70 -0
- opentau/scripts/launch_train.py +63 -0
- opentau/scripts/libero_simulation_parallel.py +356 -0
- opentau/scripts/libero_simulation_sequential.py +122 -0
- opentau/scripts/nav_high_level_planner_inference.py +61 -0
- opentau/scripts/train.py +379 -0
- opentau/scripts/visualize_dataset.py +294 -0
- opentau/scripts/visualize_dataset_html.py +507 -0
- opentau/scripts/zero_to_fp32.py +760 -0
- opentau/utils/__init__.py +20 -0
- opentau/utils/accelerate_utils.py +79 -0
- opentau/utils/benchmark.py +98 -0
- opentau/utils/fake_tensor.py +81 -0
- opentau/utils/hub.py +209 -0
- opentau/utils/import_utils.py +79 -0
- opentau/utils/io_utils.py +137 -0
- opentau/utils/libero.py +214 -0
- opentau/utils/libero_dataset_recorder.py +460 -0
- opentau/utils/logging_utils.py +180 -0
- opentau/utils/monkey_patch.py +278 -0
- opentau/utils/random_utils.py +244 -0
- opentau/utils/train_utils.py +198 -0
- opentau/utils/utils.py +471 -0
- opentau-0.1.0.dist-info/METADATA +161 -0
- opentau-0.1.0.dist-info/RECORD +108 -0
- opentau-0.1.0.dist-info/WHEEL +5 -0
- opentau-0.1.0.dist-info/entry_points.txt +2 -0
- opentau-0.1.0.dist-info/licenses/LICENSE +508 -0
- opentau-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
2
|
+
# Copyright 2026 Tensor Auto Inc. All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
"""Backward compatibility error handling for dataset format versions.
|
|
17
|
+
|
|
18
|
+
This module provides exception classes and error messages for handling dataset
|
|
19
|
+
format version incompatibilities. It supports detection and user-friendly
|
|
20
|
+
reporting of backward and forward compatibility issues when loading datasets
|
|
21
|
+
that use older or newer format versions than the current OpenTau version
|
|
22
|
+
supports.
|
|
23
|
+
"""
|
|
24
|
+
|
|
25
|
+
import packaging.version
|
|
26
|
+
|
|
27
|
+
V2_MESSAGE = """
|
|
28
|
+
The dataset you requested ({repo_id}) is in {version} format.
|
|
29
|
+
|
|
30
|
+
We introduced a new format since v2.0 which is not backward compatible with v1.x.
|
|
31
|
+
Please, use our conversion script. Modify the following command with your own task description:
|
|
32
|
+
```
|
|
33
|
+
python src/opentau/datasets/v2/convert_dataset_v1_to_v2.py \\
|
|
34
|
+
--repo-id {repo_id} \\
|
|
35
|
+
--single-task "TASK DESCRIPTION." # <---- /!\\ Replace TASK DESCRIPTION /!\\
|
|
36
|
+
```
|
|
37
|
+
|
|
38
|
+
A few examples to replace TASK DESCRIPTION: "Pick up the blue cube and place it into the bin.", "Insert the
|
|
39
|
+
peg into the socket.", "Slide open the ziploc bag.", "Take the elevator to the 1st floor.", "Open the top
|
|
40
|
+
cabinet, store the pot inside it then close the cabinet.", "Push the T-shaped block onto the T-shaped
|
|
41
|
+
target.", "Grab the spray paint on the shelf and place it in the bin on top of the robot dog.", "Fold the
|
|
42
|
+
sweatshirt.", ...
|
|
43
|
+
|
|
44
|
+
If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
|
|
45
|
+
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
|
|
46
|
+
"""
|
|
47
|
+
|
|
48
|
+
V21_MESSAGE = """
|
|
49
|
+
The dataset you requested ({repo_id}) is in {version} format.
|
|
50
|
+
While current version of LeRobot is backward-compatible with it, the version of your dataset still uses global
|
|
51
|
+
stats instead of per-episode stats. Update your dataset stats to the new format using this command:
|
|
52
|
+
```
|
|
53
|
+
python src/opentau/datasets/v21/convert_dataset_v20_to_v21.py --repo-id={repo_id}
|
|
54
|
+
```
|
|
55
|
+
|
|
56
|
+
If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
|
|
57
|
+
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
|
|
58
|
+
"""
|
|
59
|
+
|
|
60
|
+
FUTURE_MESSAGE = """
|
|
61
|
+
The dataset you requested ({repo_id}) is only available in {version} format.
|
|
62
|
+
As we cannot ensure forward compatibility with it, please update your current version of opentau.
|
|
63
|
+
"""
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class CompatibilityError(Exception):
    """Common base class for the dataset format compatibility errors in this module."""
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class BackwardCompatibilityError(CompatibilityError):
    """Compatibility error whose text is ``V2_MESSAGE`` filled in for one dataset.

    Args:
        repo_id: Identifier of the requested dataset; substituted for
            ``{repo_id}`` in the message template.
        version: Format version of the dataset; substituted for ``{version}``.
    """

    def __init__(self, repo_id: str, version: packaging.version.Version):
        placeholders = {"repo_id": repo_id, "version": version}
        super().__init__(V2_MESSAGE.format(**placeholders))
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class ForwardCompatibilityError(CompatibilityError):
    """Compatibility error whose text is ``FUTURE_MESSAGE`` filled in for one dataset.

    Args:
        repo_id: Identifier of the requested dataset; substituted for
            ``{repo_id}`` in the message template.
        version: Format version of the dataset; substituted for ``{version}``.
    """

    def __init__(self, repo_id: str, version: packaging.version.Version):
        placeholders = {"repo_id": repo_id, "version": version}
        super().__init__(FUTURE_MESSAGE.format(**placeholders))
|
|
@@ -0,0 +1,333 @@
|
|
|
1
|
+
#!/usr/bin/env python
|
|
2
|
+
|
|
3
|
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
4
|
+
# Copyright 2026 Tensor Auto Inc. All rights reserved.
|
|
5
|
+
#
|
|
6
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
7
|
+
# you may not use this file except in compliance with the License.
|
|
8
|
+
# You may obtain a copy of the License at
|
|
9
|
+
#
|
|
10
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
11
|
+
#
|
|
12
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
13
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
14
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
15
|
+
# See the License for the specific language governing permissions and
|
|
16
|
+
# limitations under the License.
|
|
17
|
+
|
|
18
|
+
"""Statistics computation and aggregation for dataset features.
|
|
19
|
+
|
|
20
|
+
This module provides functionality to compute statistical measures (min, max,
|
|
21
|
+
mean, standard deviation, and count) for dataset features, with special
|
|
22
|
+
handling for image and video data. It supports per-episode statistics
|
|
23
|
+
computation and aggregation across multiple episodes or datasets using
|
|
24
|
+
weighted averaging.
|
|
25
|
+
|
|
26
|
+
The module handles two main use cases:
|
|
27
|
+
1. Computing statistics for individual episodes: Samples images efficiently,
|
|
28
|
+
downsamples large images to reduce memory usage, and computes statistics
|
|
29
|
+
for all feature types (images, vectors, etc.).
|
|
30
|
+
2. Aggregating statistics across multiple episodes/datasets: Combines
|
|
31
|
+
statistics using weighted mean and variance computation, taking global
|
|
32
|
+
min/max values.
|
|
33
|
+
|
|
34
|
+
Key Features:
|
|
35
|
+
- Memory-efficient image sampling: Uses heuristic-based sampling to
|
|
36
|
+
estimate optimal number of samples based on dataset size.
|
|
37
|
+
- Automatic image downsampling: Reduces large images (>300px) to ~150px
|
|
38
|
+
for faster processing.
|
|
39
|
+
- Weighted aggregation: Supports custom weights or uses episode counts
|
|
40
|
+
as weights for aggregating statistics.
|
|
41
|
+
- Parallel variance algorithm: Uses efficient algorithm for computing
|
|
42
|
+
weighted variance across multiple statistics.
|
|
43
|
+
|
|
44
|
+
Functions:
|
|
45
|
+
estimate_num_samples
|
|
46
|
+
Heuristic to estimate optimal number of samples based on dataset size.
|
|
47
|
+
sample_indices
|
|
48
|
+
Generate evenly spaced sample indices from a dataset.
|
|
49
|
+
auto_downsample_height_width
|
|
50
|
+
Automatically downsample large images.
|
|
51
|
+
sample_images
|
|
52
|
+
Load and downsample a subset of images from file paths.
|
|
53
|
+
get_feature_stats
|
|
54
|
+
Compute statistical measures for an array.
|
|
55
|
+
compute_episode_stats
|
|
56
|
+
Compute statistics for a single episode.
|
|
57
|
+
aggregate_feature_stats
|
|
58
|
+
Aggregate statistics for a feature across multiple episodes.
|
|
59
|
+
aggregate_stats
|
|
60
|
+
Aggregate statistics from multiple episodes/datasets.
|
|
61
|
+
|
|
62
|
+
Example:
|
|
63
|
+
Compute statistics for a single episode:
|
|
64
|
+
>>> episode_data = {"state": state_array, "camera0": image_paths}
|
|
65
|
+
>>> features = {"state": {"dtype": "float32"}, "camera0": {"dtype": "image"}}
|
|
66
|
+
>>> stats = compute_episode_stats(episode_data, features)
|
|
67
|
+
|
|
68
|
+
Aggregate statistics across multiple episodes:
|
|
69
|
+
>>> stats_list = [episode1_stats, episode2_stats, episode3_stats]
|
|
70
|
+
>>> weights = [100, 200, 150] # Optional: custom weights
|
|
71
|
+
>>> aggregated = aggregate_stats(stats_list, weights=weights)
|
|
72
|
+
"""
|
|
73
|
+
|
|
74
|
+
from typing import Optional
|
|
75
|
+
|
|
76
|
+
import numpy as np
|
|
77
|
+
|
|
78
|
+
from opentau.datasets.utils import load_image_as_numpy
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def estimate_num_samples(
    dataset_len: int, min_num_samples: int = 100, max_num_samples: int = 10_000, power: float = 0.75
) -> int:
    """Pick how many items to sample from a dataset of the given length.

    The estimate is ``int(dataset_len ** power)``, capped at
    ``max_num_samples`` and raised to at least ``min_num_samples``. When the
    dataset holds fewer than ``min_num_samples`` items, the lower bound is the
    dataset length itself. A smaller ``power`` yields fewer samples.

    With the default arguments:
        - below 100 items, every item is sampled (num_samples=dataset_len)
        - from 100 up to roughly 460 items, num_samples=100
        - at 1000, num_samples=177
        - at 2000, num_samples=299
        - at 5000, num_samples=594
        - at 20000, num_samples=1681

    Args:
        dataset_len: Number of items in the dataset.
        min_num_samples: Lower bound on the result (itself bounded by
            ``dataset_len``).
        max_num_samples: Upper bound applied to the power-law estimate.
        power: Exponent applied to ``dataset_len``.

    Returns:
        The estimated number of samples.
    """
    # A dataset smaller than the floor lowers the floor to its own length.
    floor = min(min_num_samples, dataset_len)

    power_law_estimate = int(dataset_len**power)
    capped = power_law_estimate if power_law_estimate <= max_num_samples else max_num_samples

    return capped if capped > floor else floor
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def sample_indices(data_len: int) -> list[int]:
    """Return evenly spaced integer indices covering ``0 .. data_len - 1``.

    The number of indices comes from ``estimate_num_samples(data_len)``; the
    positions are produced by ``np.linspace`` and rounded to the nearest
    integer.

    Args:
        data_len: Total length of the dataset.

    Returns:
        List of integer indices into the dataset.
    """
    how_many = estimate_num_samples(data_len)
    positions = np.linspace(0, data_len - 1, how_many)
    rounded = np.round(positions)
    return rounded.astype(int).tolist()
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def auto_downsample_height_width(
    img: np.ndarray, target_size: int = 150, max_size_threshold: int = 300
) -> np.ndarray:
    """Downsample an image by integer striding when it exceeds a size threshold.

    If the larger of height and width is below ``max_size_threshold`` the
    input is returned unchanged. Otherwise both spatial axes are subsampled
    with the same integer stride, ``int(larger_side / target_size)``, which
    brings the larger side close to ``target_size``.

    Args:
        img: Input image array of shape (C, H, W).
        target_size: Target size for the larger dimension after downsampling.
            Defaults to 150.
        max_size_threshold: Size of the larger dimension from which
            downsampling is applied. Defaults to 300.

    Returns:
        The strided image, or the original array if no downsampling is needed.
    """
    _, height, width = img.shape
    larger_side = max(width, height)

    if larger_side < max_size_threshold:
        # no downsampling needed
        return img

    # Clamp the stride to 1: when target_size exceeds the larger side the
    # integer factor would be 0, and a slice step of 0 raises ValueError.
    downsample_factor = max(1, int(larger_side / target_size))
    return img[:, ::downsample_factor, ::downsample_factor]
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def sample_images(image_paths: list[str]) -> np.ndarray:
    """Load a subset of the given images into one uint8 batch.

    Indices are chosen with ``sample_indices``; each selected file is loaded
    channel-first as uint8 via ``load_image_as_numpy`` and passed through
    ``auto_downsample_height_width``. The output buffer is allocated from the
    shape of the first processed image, so every sampled image must end up
    with that same shape.

    Args:
        image_paths: List of file paths to image files.

    Returns:
        Array of shape (num_samples, C, H, W) with dtype uint8. If no index is
        sampled, the loop never runs and ``None`` is returned.
    """
    chosen = sample_indices(len(image_paths))
    batch = None

    for slot, path_idx in enumerate(chosen):
        # uint8 keeps the memory footprint of the batch low
        frame = auto_downsample_height_width(
            load_image_as_numpy(image_paths[path_idx], dtype=np.uint8, channel_first=True)
        )

        if batch is None:
            # Lazily size the buffer once the per-image shape is known.
            batch = np.empty((len(chosen), *frame.shape), dtype=np.uint8)

        batch[slot] = frame

    return batch
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[str, np.ndarray]:
    """Reduce an array to its min, max, mean and std, and record its length.

    Args:
        array: Input numpy array to compute statistics over.
        axis: Axis or axes passed to each numpy reduction.
        keepdims: Whether the reduced axes are kept with size one.

    Returns:
        Dictionary with keys 'min', 'max', 'mean', 'std' (the reductions) and
        'count' (a one-element array holding ``len(array)``).
    """
    reducers = {"min": np.min, "max": np.max, "mean": np.mean, "std": np.std}
    stats = {name: reduce(array, axis=axis, keepdims=keepdims) for name, reduce in reducers.items()}
    stats["count"] = np.array([len(array)])
    return stats
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], features: dict) -> dict:
    """Compute per-feature statistics for one episode.

    Features whose dtype is "string" are skipped. For "image" and "video"
    features the data is treated as a list of image paths: images are sampled
    with ``sample_images``, reduced over every axis except the channel axis,
    divided by 255 and stripped of the leading batch axis. Every other feature
    is reduced over its first axis.

    Args:
        episode_data: Dictionary mapping feature names to their data (arrays or image paths).
        features: Dictionary of feature specifications with 'dtype' keys.

    Returns:
        Dictionary mapping feature names to their statistics (min, max, mean, std, count).
        Image statistics are scaled by 1/255; 'count' is left untouched.
    """
    visual_dtypes = ("image", "video")
    ep_stats = {}

    for key, data in episode_data.items():
        dtype = features[key]["dtype"]

        if dtype == "string":
            # HACK: we should receive np.arrays of strings
            continue

        is_visual = dtype in visual_dtypes
        if is_visual:
            values = sample_images(data)  # data is a list of image paths
            reduce_axes = (0, 2, 3)  # keep channel dim
            keepdims = True
        else:
            values = data  # data is already a np.ndarray
            reduce_axes = 0  # compute stats over the first axis
            keepdims = data.ndim == 1  # keep as np.array

        feature_stats = get_feature_stats(values, axis=reduce_axes, keepdims=keepdims)

        if is_visual:
            # normalize and drop the batch dim; the sample count is kept as-is
            feature_stats = {
                name: value if name == "count" else np.squeeze(value / 255.0, axis=0)
                for name, value in feature_stats.items()
            }

        ep_stats[key] = feature_stats

    return ep_stats
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
def _assert_type_and_shape(stats_list: list[dict[str, dict]]) -> None:
|
|
234
|
+
"""Validate that statistics dictionaries have correct types and shapes.
|
|
235
|
+
|
|
236
|
+
Checks that all values are numpy arrays, have at least 1 dimension,
|
|
237
|
+
count has shape (1,), and image stats have shape (3, 1, 1).
|
|
238
|
+
|
|
239
|
+
Args:
|
|
240
|
+
stats_list: List of statistics dictionaries to validate.
|
|
241
|
+
|
|
242
|
+
Raises:
|
|
243
|
+
ValueError: If any statistic has incorrect type or shape.
|
|
244
|
+
"""
|
|
245
|
+
for i in range(len(stats_list)):
|
|
246
|
+
for fkey in stats_list[i]:
|
|
247
|
+
for k, v in stats_list[i][fkey].items():
|
|
248
|
+
if not isinstance(v, np.ndarray):
|
|
249
|
+
raise ValueError(
|
|
250
|
+
f"Stats must be composed of numpy array, but key '{k}' of feature '{fkey}' is of type '{type(v)}' instead."
|
|
251
|
+
)
|
|
252
|
+
if v.ndim == 0:
|
|
253
|
+
raise ValueError("Number of dimensions must be at least 1, and is 0 instead.")
|
|
254
|
+
if k == "count" and v.shape != (1,):
|
|
255
|
+
raise ValueError(f"Shape of 'count' must be (1), but is {v.shape} instead.")
|
|
256
|
+
if "image" in fkey and k != "count" and v.shape != (3, 1, 1):
|
|
257
|
+
raise ValueError(f"Shape of '{k}' must be (3,1,1), but is {v.shape} instead.")
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
def aggregate_feature_stats(
    stats_ft_list: list[dict[str, dict]], weights: Optional[list[float]] = None
) -> dict[str, np.ndarray]:
    """Aggregate statistics for a single feature across multiple episodes/datasets.

    The aggregated mean is the weighted average of the per-entry means. The
    aggregated variance combines each entry's variance with the squared
    distance of its mean from the aggregated mean (parallel variance
    algorithm). Min and max are the element-wise extrema over all entries.

    Args:
        stats_ft_list: List of statistics dictionaries for the same feature,
            each holding 'min', 'max', 'mean', 'std' and 'count' arrays.
        weights: Optional weights, one per entry of ``stats_ft_list``. If
            None, the 'count' values are used as weights.

    Returns:
        Dictionary with 'min', 'max', 'mean', 'std' and 'count'. 'count' is
        the sum of the weights in use and always has at least one dimension.

    Raises:
        ValueError: If ``weights`` is given but its length differs from the
            number of statistics entries.
    """
    # A length mismatch could otherwise broadcast silently (for example a
    # single entry against several weights) and yield wrong statistics.
    if weights is not None and len(weights) != len(stats_ft_list):
        raise ValueError(
            f"Expected one weight per statistics entry, but got {len(weights)} weights "
            f"for {len(stats_ft_list)} entries."
        )

    means = np.stack([s["mean"] for s in stats_ft_list])
    variances = np.stack([s["std"] ** 2 for s in stats_ft_list])

    # if weights are provided, use them to compute the weighted mean and variance
    # otherwise, use episode counts as weights
    if weights is not None:
        counts = np.asarray(weights)
    else:
        counts = np.stack([s["count"] for s in stats_ft_list])
    total_count = counts.sum(axis=0)

    # Append trailing axes so the per-entry weights broadcast against the means.
    while counts.ndim < means.ndim:
        counts = np.expand_dims(counts, axis=-1)

    # Compute the weighted mean
    total_mean = (means * counts).sum(axis=0) / total_count

    # Compute the variance using the parallel algorithm
    delta_means = means - total_mean
    total_variance = ((variances + delta_means**2) * counts).sum(axis=0) / total_count

    return {
        "min": np.min(np.stack([s["min"] for s in stats_ft_list]), axis=0),
        "max": np.max(np.stack([s["max"] for s in stats_ft_list]), axis=0),
        "mean": total_mean,
        "std": np.sqrt(total_variance),
        # With explicit weights the sum is a numpy scalar; promote it so
        # 'count' is an array in both branches, as in the count-based path.
        "count": np.atleast_1d(total_count),
    }
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
def aggregate_stats(
    stats_list: list[dict[str, dict]], weights: Optional[list[float]] = None
) -> dict[str, dict[str, np.ndarray]]:
    """Aggregate stats from multiple compute_stats outputs into a single set of stats.

    The final stats will have the union of all data keys from each of the stats dicts.

    For instance:
    - new_min = min(min_dataset_0, min_dataset_1, ...)
    - new_max = max(max_dataset_0, max_dataset_1, ...)
    - new_mean = (mean of all data, weighted by counts)
    - new_std = (std of all data)

    Args:
        stats_list: One statistics dictionary per episode/dataset, mapping
            feature names to their statistics.
        weights: Optional weights, one per entry of ``stats_list``. If None,
            each feature's 'count' values are used as weights.

    Returns:
        Dictionary mapping every feature name found in ``stats_list`` to its
        aggregated statistics.

    Raises:
        ValueError: If a statistic has an invalid type or shape, or if
            ``weights`` is given but its length differs from ``stats_list``.
    """

    _assert_type_and_shape(stats_list)

    if weights is not None and len(weights) != len(stats_list):
        raise ValueError(
            f"Expected one weight per stats dictionary, but got {len(weights)} weights "
            f"for {len(stats_list)} dictionaries."
        )

    data_keys = {key for stats in stats_list for key in stats}
    aggregated_stats = {}

    for key in data_keys:
        # A feature may be absent from some entries. Select the matching
        # weights together with the stats so the two stay aligned.
        present = [i for i, stats in enumerate(stats_list) if key in stats]
        stats_with_key = [stats_list[i][key] for i in present]
        key_weights = None if weights is None else [weights[i] for i in present]
        aggregated_stats[key] = aggregate_feature_stats(stats_with_key, key_weights)

    return aggregated_stats
|