opentau-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. opentau/__init__.py +179 -0
  2. opentau/__version__.py +24 -0
  3. opentau/configs/__init__.py +19 -0
  4. opentau/configs/default.py +297 -0
  5. opentau/configs/libero.py +113 -0
  6. opentau/configs/parser.py +393 -0
  7. opentau/configs/policies.py +297 -0
  8. opentau/configs/reward.py +42 -0
  9. opentau/configs/train.py +370 -0
  10. opentau/configs/types.py +76 -0
  11. opentau/constants.py +52 -0
  12. opentau/datasets/__init__.py +84 -0
  13. opentau/datasets/backward_compatibility.py +78 -0
  14. opentau/datasets/compute_stats.py +333 -0
  15. opentau/datasets/dataset_mixture.py +460 -0
  16. opentau/datasets/factory.py +232 -0
  17. opentau/datasets/grounding/__init__.py +67 -0
  18. opentau/datasets/grounding/base.py +154 -0
  19. opentau/datasets/grounding/clevr.py +110 -0
  20. opentau/datasets/grounding/cocoqa.py +130 -0
  21. opentau/datasets/grounding/dummy.py +101 -0
  22. opentau/datasets/grounding/pixmo.py +177 -0
  23. opentau/datasets/grounding/vsr.py +141 -0
  24. opentau/datasets/image_writer.py +304 -0
  25. opentau/datasets/lerobot_dataset.py +1910 -0
  26. opentau/datasets/online_buffer.py +442 -0
  27. opentau/datasets/push_dataset_to_hub/utils.py +132 -0
  28. opentau/datasets/sampler.py +99 -0
  29. opentau/datasets/standard_data_format_mapping.py +278 -0
  30. opentau/datasets/transforms.py +330 -0
  31. opentau/datasets/utils.py +1243 -0
  32. opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +887 -0
  33. opentau/datasets/v2/convert_dataset_v1_to_v2.py +829 -0
  34. opentau/datasets/v21/_remove_language_instruction.py +109 -0
  35. opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +60 -0
  36. opentau/datasets/v21/convert_dataset_v20_to_v21.py +183 -0
  37. opentau/datasets/v21/convert_stats.py +150 -0
  38. opentau/datasets/video_utils.py +597 -0
  39. opentau/envs/__init__.py +18 -0
  40. opentau/envs/configs.py +178 -0
  41. opentau/envs/factory.py +99 -0
  42. opentau/envs/libero.py +439 -0
  43. opentau/envs/utils.py +204 -0
  44. opentau/optim/__init__.py +16 -0
  45. opentau/optim/factory.py +43 -0
  46. opentau/optim/optimizers.py +121 -0
  47. opentau/optim/schedulers.py +140 -0
  48. opentau/planner/__init__.py +82 -0
  49. opentau/planner/high_level_planner.py +366 -0
  50. opentau/planner/utils/memory.py +64 -0
  51. opentau/planner/utils/utils.py +65 -0
  52. opentau/policies/__init__.py +24 -0
  53. opentau/policies/factory.py +172 -0
  54. opentau/policies/normalize.py +315 -0
  55. opentau/policies/pi0/__init__.py +19 -0
  56. opentau/policies/pi0/configuration_pi0.py +250 -0
  57. opentau/policies/pi0/modeling_pi0.py +994 -0
  58. opentau/policies/pi0/paligemma_with_expert.py +516 -0
  59. opentau/policies/pi05/__init__.py +20 -0
  60. opentau/policies/pi05/configuration_pi05.py +231 -0
  61. opentau/policies/pi05/modeling_pi05.py +1257 -0
  62. opentau/policies/pi05/paligemma_with_expert.py +572 -0
  63. opentau/policies/pretrained.py +315 -0
  64. opentau/policies/utils.py +123 -0
  65. opentau/policies/value/__init__.py +18 -0
  66. opentau/policies/value/configuration_value.py +170 -0
  67. opentau/policies/value/modeling_value.py +512 -0
  68. opentau/policies/value/reward.py +87 -0
  69. opentau/policies/value/siglip_gemma.py +221 -0
  70. opentau/scripts/actions_mse_loss.py +89 -0
  71. opentau/scripts/bin_to_safetensors.py +116 -0
  72. opentau/scripts/compute_max_token_length.py +111 -0
  73. opentau/scripts/display_sys_info.py +90 -0
  74. opentau/scripts/download_libero_benchmarks.py +54 -0
  75. opentau/scripts/eval.py +877 -0
  76. opentau/scripts/export_to_onnx.py +180 -0
  77. opentau/scripts/fake_tensor_training.py +87 -0
  78. opentau/scripts/get_advantage_and_percentiles.py +220 -0
  79. opentau/scripts/high_level_planner_inference.py +114 -0
  80. opentau/scripts/inference.py +70 -0
  81. opentau/scripts/launch_train.py +63 -0
  82. opentau/scripts/libero_simulation_parallel.py +356 -0
  83. opentau/scripts/libero_simulation_sequential.py +122 -0
  84. opentau/scripts/nav_high_level_planner_inference.py +61 -0
  85. opentau/scripts/train.py +379 -0
  86. opentau/scripts/visualize_dataset.py +294 -0
  87. opentau/scripts/visualize_dataset_html.py +507 -0
  88. opentau/scripts/zero_to_fp32.py +760 -0
  89. opentau/utils/__init__.py +20 -0
  90. opentau/utils/accelerate_utils.py +79 -0
  91. opentau/utils/benchmark.py +98 -0
  92. opentau/utils/fake_tensor.py +81 -0
  93. opentau/utils/hub.py +209 -0
  94. opentau/utils/import_utils.py +79 -0
  95. opentau/utils/io_utils.py +137 -0
  96. opentau/utils/libero.py +214 -0
  97. opentau/utils/libero_dataset_recorder.py +460 -0
  98. opentau/utils/logging_utils.py +180 -0
  99. opentau/utils/monkey_patch.py +278 -0
  100. opentau/utils/random_utils.py +244 -0
  101. opentau/utils/train_utils.py +198 -0
  102. opentau/utils/utils.py +471 -0
  103. opentau-0.1.0.dist-info/METADATA +161 -0
  104. opentau-0.1.0.dist-info/RECORD +108 -0
  105. opentau-0.1.0.dist-info/WHEEL +5 -0
  106. opentau-0.1.0.dist-info/entry_points.txt +2 -0
  107. opentau-0.1.0.dist-info/licenses/LICENSE +508 -0
  108. opentau-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,78 @@
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Backward compatibility error handling for dataset format versions.
+
+ This module provides exception classes and error messages for handling dataset
+ format version incompatibilities. It supports detection and user-friendly
+ reporting of backward and forward compatibility issues when loading datasets
+ that use older or newer format versions than the current OpenTau version
+ supports.
+ """
+
+ import packaging.version
+
+ V2_MESSAGE = """
+ The dataset you requested ({repo_id}) is in {version} format.
+
+ We introduced a new format since v2.0 which is not backward compatible with v1.x.
+ Please, use our conversion script. Modify the following command with your own task description:
+ ```
+ python src/opentau/datasets/v2/convert_dataset_v1_to_v2.py \\
+     --repo-id {repo_id} \\
+     --single-task "TASK DESCRIPTION." # <---- /!\\ Replace TASK DESCRIPTION /!\\
+ ```
+
+ A few examples to replace TASK DESCRIPTION: "Pick up the blue cube and place it into the bin.", "Insert the
+ peg into the socket.", "Slide open the ziploc bag.", "Take the elevator to the 1st floor.", "Open the top
+ cabinet, store the pot inside it then close the cabinet.", "Push the T-shaped block onto the T-shaped
+ target.", "Grab the spray paint on the shelf and place it in the bin on top of the robot dog.", "Fold the
+ sweatshirt.", ...
+
+ If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
+ or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
+ """
+
+ V21_MESSAGE = """
+ The dataset you requested ({repo_id}) is in {version} format.
+ While current version of LeRobot is backward-compatible with it, the version of your dataset still uses global
+ stats instead of per-episode stats. Update your dataset stats to the new format using this command:
+ ```
+ python src/opentau/datasets/v21/convert_dataset_v20_to_v21.py --repo-id={repo_id}
+ ```
+
+ If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
+ or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
+ """
+
+ FUTURE_MESSAGE = """
+ The dataset you requested ({repo_id}) is only available in {version} format.
+ As we cannot ensure forward compatibility with it, please update your current version of opentau.
+ """
+
+
+ class CompatibilityError(Exception): ...
+
+
+ class BackwardCompatibilityError(CompatibilityError):
+     def __init__(self, repo_id: str, version: packaging.version.Version):
+         message = V2_MESSAGE.format(repo_id=repo_id, version=version)
+         super().__init__(message)
+
+
+ class ForwardCompatibilityError(CompatibilityError):
+     def __init__(self, repo_id: str, version: packaging.version.Version):
+         message = FUTURE_MESSAGE.format(repo_id=repo_id, version=version)
+         super().__init__(message)
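For context, here is a minimal sketch (not part of the package) of how a dataset loader might dispatch on a dataset's format version using these exception classes. The `check_dataset_version` helper and the `CURRENT_CODEBASE_VERSION` value are illustrative assumptions, not OpenTau API:

```python
import packaging.version

from opentau.datasets.backward_compatibility import (
    BackwardCompatibilityError,
    ForwardCompatibilityError,
)

# Illustrative value; the real supported version lives elsewhere in the package.
CURRENT_CODEBASE_VERSION = packaging.version.parse("2.1")


def check_dataset_version(repo_id: str, version_str: str) -> None:
    """Hypothetical helper: raise the matching compatibility error, if any."""
    version = packaging.version.parse(version_str)
    if version.major < 2:
        # v1.x datasets must first be converted with the v1 -> v2 script.
        raise BackwardCompatibilityError(repo_id, version)
    if version > CURRENT_CODEBASE_VERSION:
        # Dataset was written by a newer opentau; the user should upgrade.
        raise ForwardCompatibilityError(repo_id, version)
```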
@@ -0,0 +1,333 @@
+ #!/usr/bin/env python
+
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Statistics computation and aggregation for dataset features.
+
+ This module provides functionality to compute statistical measures (min, max,
+ mean, standard deviation, and count) for dataset features, with special
+ handling for image and video data. It supports per-episode statistics
+ computation and aggregation across multiple episodes or datasets using
+ weighted averaging.
+
+ The module handles two main use cases:
+ 1. Computing statistics for individual episodes: Samples images efficiently,
+    downsamples large images to reduce memory usage, and computes statistics
+    for all feature types (images, vectors, etc.).
+ 2. Aggregating statistics across multiple episodes/datasets: Combines
+    statistics using weighted mean and variance computation, taking global
+    min/max values.
+
+ Key Features:
+ - Memory-efficient image sampling: Uses heuristic-based sampling to
+   estimate optimal number of samples based on dataset size.
+ - Automatic image downsampling: Reduces large images (>300px) to ~150px
+   for faster processing.
+ - Weighted aggregation: Supports custom weights or uses episode counts
+   as weights for aggregating statistics.
+ - Parallel variance algorithm: Uses efficient algorithm for computing
+   weighted variance across multiple statistics.
+
+ Functions:
+     estimate_num_samples
+         Heuristic to estimate optimal number of samples based on dataset size.
+     sample_indices
+         Generate evenly spaced sample indices from a dataset.
+     auto_downsample_height_width
+         Automatically downsample large images.
+     sample_images
+         Load and downsample a subset of images from file paths.
+     get_feature_stats
+         Compute statistical measures for an array.
+     compute_episode_stats
+         Compute statistics for a single episode.
+     aggregate_feature_stats
+         Aggregate statistics for a feature across multiple episodes.
+     aggregate_stats
+         Aggregate statistics from multiple episodes/datasets.
+
+ Example:
+     Compute statistics for a single episode:
+     >>> episode_data = {"state": state_array, "camera0": image_paths}
+     >>> features = {"state": {"dtype": "float32"}, "camera0": {"dtype": "image"}}
+     >>> stats = compute_episode_stats(episode_data, features)
+
+     Aggregate statistics across multiple episodes:
+     >>> stats_list = [episode1_stats, episode2_stats, episode3_stats]
+     >>> weights = [100, 200, 150] # Optional: custom weights
+     >>> aggregated = aggregate_stats(stats_list, weights=weights)
+ """
+
+ from typing import Optional
+
+ import numpy as np
+
+ from opentau.datasets.utils import load_image_as_numpy
+
+
+ def estimate_num_samples(
+     dataset_len: int, min_num_samples: int = 100, max_num_samples: int = 10_000, power: float = 0.75
+ ) -> int:
+     """Heuristic to estimate the number of samples based on dataset size.
+     The power controls how the number of samples grows with dataset size;
+     a lower power yields fewer samples.
+
+     For default arguments, we have:
+     - from 1 to ~500, num_samples=100
+     - at 1000, num_samples=177
+     - at 2000, num_samples=299
+     - at 5000, num_samples=594
+     - at 10000, num_samples=1000
+     - at 20000, num_samples=1681
+     """
+     if dataset_len < min_num_samples:
+         min_num_samples = dataset_len
+     return max(min_num_samples, min(int(dataset_len**power), max_num_samples))
+
+
+ def sample_indices(data_len: int) -> list[int]:
+     """Generate evenly spaced sample indices from a dataset.
+
+     Uses estimate_num_samples to determine how many samples to take,
+     then returns evenly spaced indices across the dataset length.
+
+     Args:
+         data_len: Total length of the dataset.
+
+     Returns:
+         List of evenly spaced integer indices.
+     """
+     num_samples = estimate_num_samples(data_len)
+     return np.round(np.linspace(0, data_len - 1, num_samples)).astype(int).tolist()
+
+
+ def auto_downsample_height_width(
+     img: np.ndarray, target_size: int = 150, max_size_threshold: int = 300
+ ) -> np.ndarray:
+     """Automatically downsample an image if it exceeds size threshold.
+
+     If the image's maximum dimension is below the threshold, returns it unchanged.
+     Otherwise, downsamples by an integer factor to bring the larger dimension
+     close to the target size.
+
+     Args:
+         img: Input image array of shape (C, H, W).
+         target_size: Target size for the larger dimension after downsampling.
+             Defaults to 150.
+         max_size_threshold: Maximum size before downsampling is applied.
+             Defaults to 300.
+
+     Returns:
+         Downsampled image array, or original if no downsampling needed.
+     """
+     _, height, width = img.shape
+
+     if max(width, height) < max_size_threshold:
+         # no downsampling needed
+         return img
+
+     downsample_factor = int(width / target_size) if width > height else int(height / target_size)
+     return img[:, ::downsample_factor, ::downsample_factor]
+
+
+ def sample_images(image_paths: list[str]) -> np.ndarray:
+     """Load and downsample a subset of images from file paths.
+
+     Samples images using evenly spaced indices, loads them as uint8 arrays,
+     and automatically downsamples large images to reduce memory usage.
+
+     Args:
+         image_paths: List of file paths to image files.
+
+     Returns:
+         Array of shape (num_samples, C, H, W) containing sampled images as uint8.
+     """
+     sampled_indices = sample_indices(len(image_paths))
+
+     images = None
+     for i, idx in enumerate(sampled_indices):
+         path = image_paths[idx]
+         # we load as uint8 to reduce memory usage
+         img = load_image_as_numpy(path, dtype=np.uint8, channel_first=True)
+         img = auto_downsample_height_width(img)
+
+         if images is None:
+             images = np.empty((len(sampled_indices), *img.shape), dtype=np.uint8)
+
+         images[i] = img
+
+     return images
+
+
+ def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[str, np.ndarray]:
+     """Compute statistical measures (min, max, mean, std, count) for an array.
+
+     Args:
+         array: Input numpy array to compute statistics over.
+         axis: Axes along which to compute statistics.
+         keepdims: Whether to keep reduced dimensions.
+
+     Returns:
+         Dictionary containing 'min', 'max', 'mean', 'std', and 'count' statistics.
+     """
+     return {
+         "min": np.min(array, axis=axis, keepdims=keepdims),
+         "max": np.max(array, axis=axis, keepdims=keepdims),
+         "mean": np.mean(array, axis=axis, keepdims=keepdims),
+         "std": np.std(array, axis=axis, keepdims=keepdims),
+         "count": np.array([len(array)]),
+     }
+
+
+ def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], features: dict) -> dict:
+     """Compute statistics for a single episode.
+
+     For image/video features, samples and downsamples images before computing stats.
+     For other features, computes stats directly on the array data.
+
+     Args:
+         episode_data: Dictionary mapping feature names to their data (arrays or image paths).
+         features: Dictionary of feature specifications with 'dtype' keys.
+
+     Returns:
+         Dictionary mapping feature names to their statistics (min, max, mean, std, count).
+         Image statistics are normalized to [0, 1] range.
+     """
+     ep_stats = {}
+     for key, data in episode_data.items():
+         if features[key]["dtype"] == "string":
+             continue # HACK: we should receive np.arrays of strings
+         elif features[key]["dtype"] in ["image", "video"]:
+             ep_ft_array = sample_images(data) # data is a list of image paths
+             axes_to_reduce = (0, 2, 3) # keep channel dim
+             keepdims = True
+         else:
+             ep_ft_array = data # data is already a np.ndarray
+             axes_to_reduce = 0 # compute stats over the first axis
+             keepdims = data.ndim == 1 # keep as np.array
+
+         ep_stats[key] = get_feature_stats(ep_ft_array, axis=axes_to_reduce, keepdims=keepdims)
+
+         # finally, we normalize and remove batch dim for images
+         if features[key]["dtype"] in ["image", "video"]:
+             ep_stats[key] = {
+                 k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items()
+             }
+
+     return ep_stats
+
+
+ def _assert_type_and_shape(stats_list: list[dict[str, dict]]) -> None:
+     """Validate that statistics dictionaries have correct types and shapes.
+
+     Checks that all values are numpy arrays, have at least 1 dimension,
+     count has shape (1,), and image stats have shape (3, 1, 1).
+
+     Args:
+         stats_list: List of statistics dictionaries to validate.
+
+     Raises:
+         ValueError: If any statistic has incorrect type or shape.
+     """
+     for i in range(len(stats_list)):
+         for fkey in stats_list[i]:
+             for k, v in stats_list[i][fkey].items():
+                 if not isinstance(v, np.ndarray):
+                     raise ValueError(
+                         f"Stats must be composed of numpy array, but key '{k}' of feature '{fkey}' is of type '{type(v)}' instead."
+                     )
+                 if v.ndim == 0:
+                     raise ValueError("Number of dimensions must be at least 1, and is 0 instead.")
+                 if k == "count" and v.shape != (1,):
+                     raise ValueError(f"Shape of 'count' must be (1), but is {v.shape} instead.")
+                 if "image" in fkey and k != "count" and v.shape != (3, 1, 1):
+                     raise ValueError(f"Shape of '{k}' must be (3,1,1), but is {v.shape} instead.")
+
+
+ def aggregate_feature_stats(
+     stats_ft_list: list[dict[str, dict]], weights: Optional[list[float]] = None
+ ) -> dict[str, dict[str, np.ndarray]]:
+     """Aggregate statistics for a single feature across multiple episodes/datasets.
+
+     Computes weighted mean and variance using the parallel algorithm for variance
+     computation. Min and max are taken as the global min/max across all stats.
+
+     Args:
+         stats_ft_list: List of statistics dictionaries for the same feature.
+         weights: Optional weights for each statistics entry. If None, uses
+             count values as weights.
+
+     Returns:
+         Aggregated statistics dictionary with min, max, mean, std, and count.
+     """
+     means = np.stack([s["mean"] for s in stats_ft_list])
+     variances = np.stack([s["std"] ** 2 for s in stats_ft_list])
+
+     # if weights are provided, use them to compute the weighted mean and variance
+     # otherwise, use episode counts as weights
+     if weights is not None:
+         counts = np.stack(weights)
+         total_count = counts.sum(axis=0)
+     else:
+         counts = np.stack([s["count"] for s in stats_ft_list])
+         total_count = counts.sum(axis=0)
+
+     # Prepare weighted mean by matching number of dimensions
+     while counts.ndim < means.ndim:
+         counts = np.expand_dims(counts, axis=-1)
+
+     # Compute the weighted mean
+     weighted_means = means * counts
+     total_mean = weighted_means.sum(axis=0) / total_count
+
+     # Compute the variance using the parallel algorithm
+     delta_means = means - total_mean
+     weighted_variances = (variances + delta_means**2) * counts
+     total_variance = weighted_variances.sum(axis=0) / total_count
+
+     return {
+         "min": np.min(np.stack([s["min"] for s in stats_ft_list]), axis=0),
+         "max": np.max(np.stack([s["max"] for s in stats_ft_list]), axis=0),
+         "mean": total_mean,
+         "std": np.sqrt(total_variance),
+         "count": total_count,
+     }
+
+
+ def aggregate_stats(
+     stats_list: list[dict[str, dict]], weights: Optional[list[float]] = None
+ ) -> dict[str, dict[str, np.ndarray]]:
+     """Aggregate stats from multiple compute_stats outputs into a single set of stats.
+
+     The final stats will have the union of all data keys from each of the stats dicts.
+
+     For instance:
+     - new_min = min(min_dataset_0, min_dataset_1, ...)
+     - new_max = max(max_dataset_0, max_dataset_1, ...)
+     - new_mean = (mean of all data, weighted by counts)
+     - new_std = (std of all data)
+     """
+
+     _assert_type_and_shape(stats_list)
+
+     data_keys = {key for stats in stats_list for key in stats}
+     aggregated_stats = {key: {} for key in data_keys}
+
+     for key in data_keys:
+         stats_with_key = [stats[key] for stats in stats_list if key in stats]
+         aggregated_stats[key] = aggregate_feature_stats(stats_with_key, weights)
+
+     return aggregated_stats
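For context, a minimal sketch (not part of the package) of the aggregation path above: because the weighted/parallel variance formula is exact for population standard deviation (the `np.std` default), aggregating per-episode stats should match stats computed directly on the concatenated data. The synthetic `state` arrays and the import path `opentau.datasets.compute_stats` (per the file list above) are assumptions:

```python
import numpy as np

from opentau.datasets.compute_stats import aggregate_stats, get_feature_stats

rng = np.random.default_rng(seed=0)
ep1 = rng.normal(0.0, 1.0, size=(100, 7))  # episode 1: 100 frames of a 7-dim "state"
ep2 = rng.normal(2.0, 0.5, size=(250, 7))  # episode 2: 250 frames

# Per-episode stats, computed the way compute_episode_stats does for non-image features.
stats1 = {"state": get_feature_stats(ep1, axis=0, keepdims=False)}
stats2 = {"state": get_feature_stats(ep2, axis=0, keepdims=False)}

aggregated = aggregate_stats([stats1, stats2])
direct = get_feature_stats(np.concatenate([ep1, ep2]), axis=0, keepdims=False)

# Weighted aggregation reproduces the stats of the pooled data.
assert np.allclose(aggregated["state"]["mean"], direct["mean"])
assert np.allclose(aggregated["state"]["std"], direct["std"])
assert int(aggregated["state"]["count"][0]) == 350
```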