robocandywrapper 0.2.2__tar.gz → 0.2.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {robocandywrapper-0.2.2/robocandywrapper.egg-info → robocandywrapper-0.2.4}/PKG-INFO +1 -1
- {robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/pyproject.toml +1 -1
- {robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/robocandywrapper/__init__.py +1 -1
- {robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/robocandywrapper/dataformats/lerobot_21/dataset.py +11 -1
- {robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/robocandywrapper/factory.py +21 -1
- {robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/robocandywrapper/metadata_view.py +46 -6
- {robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/robocandywrapper/utils.py +0 -4
- {robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/robocandywrapper/wrapper.py +124 -19
- {robocandywrapper-0.2.2 → robocandywrapper-0.2.4/robocandywrapper.egg-info}/PKG-INFO +1 -1
- {robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/robocandywrapper.egg-info/SOURCES.txt +2 -1
- {robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/setup.py +1 -1
- robocandywrapper-0.2.4/tests/test_key_rename_stats.py +394 -0
- {robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/LICENSE +0 -0
- {robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/MANIFEST.in +0 -0
- {robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/README.md +0 -0
- {robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/robocandywrapper/constants.py +0 -0
- {robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/robocandywrapper/dataformats/__init__.py +0 -0
- {robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/robocandywrapper/dataformats/lerobot_21/__init__.py +0 -0
- {robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/robocandywrapper/dataformats/lerobot_21/utils.py +0 -0
- {robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/robocandywrapper/plugin.py +0 -0
- {robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/robocandywrapper/plugins/__init__.py +0 -0
- {robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/robocandywrapper/plugins/affordance.py +0 -0
- {robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/robocandywrapper/plugins/episode_outcome.py +0 -0
- {robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/robocandywrapper/samplers/__init__.py +0 -0
- {robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/robocandywrapper/samplers/config.py +0 -0
- {robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/robocandywrapper/samplers/factory.py +0 -0
- {robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/robocandywrapper/samplers/weighted.py +0 -0
- {robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/robocandywrapper.egg-info/dependency_links.txt +0 -0
- {robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/robocandywrapper.egg-info/requires.txt +0 -0
- {robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/robocandywrapper.egg-info/top_level.txt +0 -0
- {robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/setup.cfg +0 -0
- {robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/tests/test_dataset_weights_integration.py +0 -0
{robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/robocandywrapper/dataformats/lerobot_21/dataset.py
RENAMED
|
@@ -93,6 +93,7 @@ class LeRobot21DatasetMetadata:
|
|
|
93
93
|
try:
|
|
94
94
|
if force_cache_sync:
|
|
95
95
|
raise FileNotFoundError
|
|
96
|
+
self.pull_from_repo()
|
|
96
97
|
self.load_metadata()
|
|
97
98
|
except (FileNotFoundError, NotADirectoryError):
|
|
98
99
|
if is_valid_version(self.revision):
|
|
@@ -728,7 +729,11 @@ class LeRobot21Dataset(torch.utils.data.Dataset):
|
|
|
728
729
|
item = {}
|
|
729
730
|
for vid_key, query_ts in query_timestamps.items():
|
|
730
731
|
video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
|
|
731
|
-
|
|
732
|
+
try:
|
|
733
|
+
frames = decode_video_frames(video_path, query_ts, self.tolerance_s, self.video_backend)
|
|
734
|
+
except Exception as e:
|
|
735
|
+
# fall back to trying to decode with pyav
|
|
736
|
+
frames = decode_video_frames(video_path, query_ts, self.tolerance_s, "pyav")
|
|
732
737
|
item[vid_key] = frames.squeeze(0)
|
|
733
738
|
|
|
734
739
|
return item
|
|
@@ -768,6 +773,11 @@ class LeRobot21Dataset(torch.utils.data.Dataset):
|
|
|
768
773
|
task_idx = item["task_index"].item()
|
|
769
774
|
item["task"] = self.meta.tasks[task_idx]
|
|
770
775
|
|
|
776
|
+
# Hack - add gripper position to end
|
|
777
|
+
# only applies to a specific dataset
|
|
778
|
+
# if "observation.eef_6d_pose" in item and item["observation.eef_6d_pose"].shape[0] == 6:
|
|
779
|
+
# item["observation.eef_6d_pose"] = torch.cat([item["observation.eef_6d_pose"], item["observation.state"][-1:]], dim=0)
|
|
780
|
+
|
|
771
781
|
return item
|
|
772
782
|
|
|
773
783
|
def __repr__(self):
|
|
@@ -59,7 +59,7 @@ def resolve_delta_timestamps(
|
|
|
59
59
|
for key in ds_meta.features:
|
|
60
60
|
if key == REWARD and reward_indices is not None:
|
|
61
61
|
delta_timestamps[key] = _indices_to_times(reward_indices, ds_meta.fps)
|
|
62
|
-
if key
|
|
62
|
+
if key.startswith(ACTION) and action_indices is not None:
|
|
63
63
|
delta_timestamps[key] = _indices_to_times(action_indices, ds_meta.fps)
|
|
64
64
|
if key.startswith("observation.") and observation_indices is not None:
|
|
65
65
|
delta_timestamps[key] = _indices_to_times(observation_indices, ds_meta.fps)
|
|
@@ -76,6 +76,7 @@ def _create_datasets(
|
|
|
76
76
|
observation_delta_indices: Optional[List] = None,
|
|
77
77
|
reward_delta_indices: Optional[List] = None,
|
|
78
78
|
use_imagenet_stats: bool = True,
|
|
79
|
+
load_videos: bool = True,
|
|
79
80
|
) -> List[LeRobotDataset | LeRobot21Dataset]:
|
|
80
81
|
"""Private helper to create dataset instances from a list of repo IDs.
|
|
81
82
|
|
|
@@ -92,6 +93,8 @@ def _create_datasets(
|
|
|
92
93
|
observation_delta_indices: Frame indices for observations.
|
|
93
94
|
reward_delta_indices: Frame indices for rewards.
|
|
94
95
|
use_imagenet_stats: Whether to apply ImageNet normalization stats.
|
|
96
|
+
load_videos: Whether to download and load video files (default: True).
|
|
97
|
+
Set to False to skip video downloads when not needed.
|
|
95
98
|
|
|
96
99
|
Returns:
|
|
97
100
|
List of dataset instances.
|
|
@@ -170,6 +173,7 @@ def _create_datasets(
|
|
|
170
173
|
image_transforms=None, # Will be applied by WrappedRobotDataset
|
|
171
174
|
revision=revision,
|
|
172
175
|
video_backend=video_backend,
|
|
176
|
+
download_videos=load_videos,
|
|
173
177
|
)
|
|
174
178
|
|
|
175
179
|
# Apply ImageNet stats if needed
|
|
@@ -189,12 +193,18 @@ def _create_datasets(
|
|
|
189
193
|
def make_dataset(
|
|
190
194
|
cfg: TrainPipelineConfig,
|
|
191
195
|
plugins: Optional[list[DatasetPlugin]] = None,
|
|
196
|
+
key_rename_map: Optional[dict[str, str]] = None,
|
|
197
|
+
load_videos: bool = True,
|
|
192
198
|
) -> WrappedRobotDataset:
|
|
193
199
|
"""Handles the logic of setting up delta timestamps and image transforms before creating a dataset.
|
|
194
200
|
|
|
195
201
|
Args:
|
|
196
202
|
cfg (TrainPipelineConfig): A TrainPipelineConfig config which contains a DatasetConfig and a PreTrainedConfig.
|
|
197
203
|
plugins (Optional[list[DatasetPlugin]]): Optional list of plugins to attach to the dataset(s).
|
|
204
|
+
key_rename_map (Optional[dict[str, str]]): Optional mapping from source keys to target keys
|
|
205
|
+
for unifying datasets with different naming conventions. Example: {"action.pos": "action"}
|
|
206
|
+
load_videos (bool): Whether to download and load video files (default: True).
|
|
207
|
+
Set to False to skip video downloads when not needed.
|
|
198
208
|
|
|
199
209
|
Returns:
|
|
200
210
|
WrappedRobotDataset: A wrapped dataset with plugin support.
|
|
@@ -221,6 +231,7 @@ def make_dataset(
|
|
|
221
231
|
observation_delta_indices=cfg.policy.observation_delta_indices,
|
|
222
232
|
reward_delta_indices=cfg.policy.reward_delta_indices,
|
|
223
233
|
use_imagenet_stats=cfg.dataset.use_imagenet_stats,
|
|
234
|
+
load_videos=load_videos,
|
|
224
235
|
)
|
|
225
236
|
|
|
226
237
|
# Wrap in WrappedRobotDataset with plugins
|
|
@@ -228,6 +239,7 @@ def make_dataset(
|
|
|
228
239
|
datasets=datasets,
|
|
229
240
|
plugins=plugins,
|
|
230
241
|
image_transforms=image_transforms,
|
|
242
|
+
key_rename_map=key_rename_map,
|
|
231
243
|
)
|
|
232
244
|
|
|
233
245
|
return wrapped_dataset
|
|
@@ -243,6 +255,8 @@ def make_dataset_without_config(
|
|
|
243
255
|
revision: str | None = None,
|
|
244
256
|
use_imagenet_stats: bool = True,
|
|
245
257
|
plugins: Optional[list[DatasetPlugin]] = None,
|
|
258
|
+
key_rename_map: Optional[dict[str, str]] = None,
|
|
259
|
+
load_videos: bool = True,
|
|
246
260
|
) -> WrappedRobotDataset:
|
|
247
261
|
"""Handles the logic of setting up delta timestamps and image transforms before creating a dataset.
|
|
248
262
|
|
|
@@ -259,6 +273,10 @@ def make_dataset_without_config(
|
|
|
259
273
|
revision (str, optional): Dataset revision
|
|
260
274
|
use_imagenet_stats (bool): Whether to use ImageNet normalization stats (default: True)
|
|
261
275
|
plugins (Optional[list[DatasetPlugin]]): Optional list of plugins to attach to the dataset(s)
|
|
276
|
+
key_rename_map (Optional[dict[str, str]]): Optional mapping from source keys to target keys
|
|
277
|
+
for unifying datasets with different naming conventions. Example: {"action.pos": "action"}
|
|
278
|
+
load_videos (bool): Whether to download and load video files (default: True).
|
|
279
|
+
Set to False to skip video downloads when not needed.
|
|
262
280
|
|
|
263
281
|
Returns:
|
|
264
282
|
WrappedRobotDataset: A wrapped dataset with plugin support.
|
|
@@ -283,12 +301,14 @@ def make_dataset_without_config(
|
|
|
283
301
|
action_delta_indices=action_delta_indices,
|
|
284
302
|
observation_delta_indices=observation_delta_indices,
|
|
285
303
|
use_imagenet_stats=use_imagenet_stats,
|
|
304
|
+
load_videos=load_videos,
|
|
286
305
|
)
|
|
287
306
|
|
|
288
307
|
# Wrap in WrappedRobotDataset with plugins
|
|
289
308
|
wrapped_dataset = WrappedRobotDataset(
|
|
290
309
|
datasets=datasets,
|
|
291
310
|
plugins=plugins,
|
|
311
|
+
key_rename_map=key_rename_map,
|
|
292
312
|
)
|
|
293
313
|
|
|
294
314
|
return wrapped_dataset
|
|
@@ -102,6 +102,7 @@ class WrappedRobotDatasetMetadataView:
|
|
|
102
102
|
datasets: list,
|
|
103
103
|
plugin_instances: list[list],
|
|
104
104
|
dataset_weights: Optional[dict[str, float]] = None,
|
|
105
|
+
dataset_renames: Optional[list[dict[str, str]]] = None,
|
|
105
106
|
):
|
|
106
107
|
"""
|
|
107
108
|
Initialize metadata view.
|
|
@@ -110,15 +111,32 @@ class WrappedRobotDatasetMetadataView:
|
|
|
110
111
|
datasets: List of LeRobotDataset instances
|
|
111
112
|
plugin_instances: List of plugin instances for each dataset
|
|
112
113
|
dataset_weights: Optional weights for each dataset (for weighted stats)
|
|
114
|
+
dataset_renames: Optional list of rename dicts for each dataset,
|
|
115
|
+
mapping source_key -> target_key. Used to unify keys across
|
|
116
|
+
datasets with different naming conventions.
|
|
113
117
|
"""
|
|
114
118
|
self._datasets = datasets
|
|
115
119
|
self._plugin_instances = plugin_instances
|
|
116
120
|
self._dataset_weights = dataset_weights or {}
|
|
121
|
+
self._dataset_renames = dataset_renames or [{} for _ in datasets]
|
|
117
122
|
|
|
118
123
|
# Cache computed properties
|
|
119
124
|
self._features = None
|
|
120
125
|
self._stats = None
|
|
121
126
|
|
|
127
|
+
def _get_renamed_features(self, dataset_idx: int) -> dict[str, dict]:
|
|
128
|
+
"""Get features from a dataset with key renames applied."""
|
|
129
|
+
dataset = self._datasets[dataset_idx]
|
|
130
|
+
renames = self._dataset_renames[dataset_idx]
|
|
131
|
+
|
|
132
|
+
renamed_features = {}
|
|
133
|
+
for key, value in dataset.meta.features.items():
|
|
134
|
+
# Apply rename if applicable
|
|
135
|
+
effective_key = renames.get(key, key)
|
|
136
|
+
renamed_features[effective_key] = value
|
|
137
|
+
|
|
138
|
+
return renamed_features
|
|
139
|
+
|
|
122
140
|
@property
|
|
123
141
|
def features(self) -> dict[str, dict]:
|
|
124
142
|
"""
|
|
@@ -127,6 +145,9 @@ class WrappedRobotDatasetMetadataView:
|
|
|
127
145
|
Returns intersection of:
|
|
128
146
|
1. Features from all datasets (taking intersection, not union)
|
|
129
147
|
2. Features provided by plugins (added to intersection)
|
|
148
|
+
|
|
149
|
+
Key renames are applied before computing the intersection, allowing
|
|
150
|
+
datasets with different naming conventions to be unified.
|
|
130
151
|
"""
|
|
131
152
|
if self._features is not None:
|
|
132
153
|
return self._features
|
|
@@ -135,12 +156,13 @@ class WrappedRobotDatasetMetadataView:
|
|
|
135
156
|
if not self._datasets:
|
|
136
157
|
all_features = {}
|
|
137
158
|
else:
|
|
138
|
-
# Start with all features from first dataset
|
|
139
|
-
all_features =
|
|
159
|
+
# Start with all features from first dataset (with renames applied)
|
|
160
|
+
all_features = self._get_renamed_features(0)
|
|
140
161
|
|
|
141
162
|
# Intersect with features from other datasets
|
|
142
|
-
for
|
|
143
|
-
|
|
163
|
+
for i in range(1, len(self._datasets)):
|
|
164
|
+
dataset_features = self._get_renamed_features(i)
|
|
165
|
+
dataset_feature_keys = set(dataset_features.keys())
|
|
144
166
|
all_feature_keys = set(all_features.keys())
|
|
145
167
|
|
|
146
168
|
# Keep only features that exist in both
|
|
@@ -166,6 +188,19 @@ class WrappedRobotDatasetMetadataView:
|
|
|
166
188
|
self._features = all_features
|
|
167
189
|
return self._features
|
|
168
190
|
|
|
191
|
+
def _get_renamed_stats(self, dataset_idx: int) -> dict[str, dict]:
|
|
192
|
+
"""Get stats from a dataset with key renames applied."""
|
|
193
|
+
dataset = self._datasets[dataset_idx]
|
|
194
|
+
renames = self._dataset_renames[dataset_idx]
|
|
195
|
+
|
|
196
|
+
renamed_stats = {}
|
|
197
|
+
for key, value in dataset.meta.stats.items():
|
|
198
|
+
# Apply rename if applicable
|
|
199
|
+
effective_key = renames.get(key, key)
|
|
200
|
+
renamed_stats[effective_key] = value
|
|
201
|
+
|
|
202
|
+
return renamed_stats
|
|
203
|
+
|
|
169
204
|
@property
|
|
170
205
|
def stats(self) -> dict:
|
|
171
206
|
"""
|
|
@@ -174,12 +209,17 @@ class WrappedRobotDatasetMetadataView:
|
|
|
174
209
|
If dataset_weights are provided, stats are computed as a weighted
|
|
175
210
|
average based on effective dataset sizes (size * weight).
|
|
176
211
|
Uses the correct statistical formula for combining variances.
|
|
212
|
+
|
|
213
|
+
Key renames are applied before aggregation, so different source keys
|
|
214
|
+
(e.g., "action.pos" and "trajectory") that map to the same target key
|
|
215
|
+
(e.g., "action") will have their stats combined as if they were the
|
|
216
|
+
same key across all datasets.
|
|
177
217
|
"""
|
|
178
218
|
if self._stats is not None:
|
|
179
219
|
return self._stats
|
|
180
220
|
|
|
181
|
-
# Collect stats and weights for each dataset
|
|
182
|
-
stats_list = [
|
|
221
|
+
# Collect stats (with renames applied) and weights for each dataset
|
|
222
|
+
stats_list = [self._get_renamed_stats(i) for i in range(len(self._datasets))]
|
|
183
223
|
|
|
184
224
|
# Get weight multiplier for each dataset
|
|
185
225
|
weights = []
|
|
@@ -10,11 +10,7 @@ from typing import Optional
|
|
|
10
10
|
from glob import glob
|
|
11
11
|
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
|
12
12
|
from lerobot.utils.constants import PRETRAINED_MODEL_DIR
|
|
13
|
-
from lerobot.configs.policies import PreTrainedConfig
|
|
14
13
|
from lerobot.configs.train import TrainPipelineConfig
|
|
15
|
-
from lerobot.configs.types import FeatureType
|
|
16
|
-
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
|
17
|
-
from lerobot.datasets.utils import dataset_to_policy_features
|
|
18
14
|
from termcolor import colored
|
|
19
15
|
|
|
20
16
|
|
|
@@ -23,6 +23,7 @@ class WrappedRobotDataset(torch.utils.data.Dataset):
|
|
|
23
23
|
warn_on_key_conflicts: bool = True,
|
|
24
24
|
error_on_key_conflicts: bool = True,
|
|
25
25
|
dataset_weights: Optional[dict[str, float]] = None,
|
|
26
|
+
key_rename_map: Optional[dict[str, str]] = None,
|
|
26
27
|
**kwargs
|
|
27
28
|
):
|
|
28
29
|
"""
|
|
@@ -35,6 +36,14 @@ class WrappedRobotDataset(torch.utils.data.Dataset):
|
|
|
35
36
|
warn_on_key_conflicts: Warn when plugins have overlapping keys (if not raising errors)
|
|
36
37
|
error_on_key_conflicts: Raise error on key conflicts (default: True)
|
|
37
38
|
dataset_weights: Optional weights for computing weighted stats (e.g., {"dataset_id": 2.0})
|
|
39
|
+
key_rename_map: Optional mapping from source keys to target keys for unifying
|
|
40
|
+
datasets with different naming conventions. Keys are renamed before the
|
|
41
|
+
intersection logic runs, allowing datasets with different key names to be
|
|
42
|
+
combined. Example: {"action.pos": "action", "trajectory": "action"}
|
|
43
|
+
|
|
44
|
+
Note: When a key is renamed, any corresponding "_is_pad" key (added by
|
|
45
|
+
LeRobot when using delta_timestamps) is automatically renamed as well.
|
|
46
|
+
E.g., "action.pos" -> "action" also renames "action.pos_is_pad" -> "action_is_pad".
|
|
38
47
|
"""
|
|
39
48
|
super().__init__()
|
|
40
49
|
|
|
@@ -64,6 +73,10 @@ class WrappedRobotDataset(torch.utils.data.Dataset):
|
|
|
64
73
|
self._cumulative_lengths.append(self._cumulative_lengths[-1] + length)
|
|
65
74
|
self._total_length = self._cumulative_lengths[-1]
|
|
66
75
|
|
|
76
|
+
# Key rename mapping: unify differently-named keys across datasets
|
|
77
|
+
self.key_rename_map = key_rename_map or {}
|
|
78
|
+
self._dataset_renames = self._compute_dataset_renames()
|
|
79
|
+
|
|
67
80
|
# Plugin management: one plugin class, many instances (one per dataset)
|
|
68
81
|
self._plugins: list[DatasetPlugin] = plugins or []
|
|
69
82
|
self._plugin_instances: list[list[PluginInstance]] = []
|
|
@@ -84,6 +97,7 @@ class WrappedRobotDataset(torch.utils.data.Dataset):
|
|
|
84
97
|
datasets=self._datasets,
|
|
85
98
|
plugin_instances=self._plugin_instances,
|
|
86
99
|
dataset_weights=dataset_weights,
|
|
100
|
+
dataset_renames=self._dataset_renames,
|
|
87
101
|
)
|
|
88
102
|
|
|
89
103
|
# ** MATCHING LeRobot MULTI-DATASET API DESIGN **
|
|
@@ -91,41 +105,51 @@ class WrappedRobotDataset(torch.utils.data.Dataset):
|
|
|
91
105
|
# Disable any data keys that are not common across all of the datasets. Note: we may relax this
|
|
92
106
|
# restriction in future iterations of this class. For now, this is necessary at least for being able
|
|
93
107
|
# to use PyTorch's default DataLoader collate function.
|
|
108
|
+
#
|
|
109
|
+
# Key rename mapping is applied first (conceptually), so intersection is computed on
|
|
110
|
+
# "effective" features (post-rename). This allows datasets with different key names to be
|
|
111
|
+
# unified before the intersection check.
|
|
94
112
|
self.disabled_features = set()
|
|
95
|
-
intersection_features =
|
|
96
|
-
for
|
|
97
|
-
intersection_features.intersection_update(
|
|
113
|
+
intersection_features = self._get_effective_features(0)
|
|
114
|
+
for i in range(len(self._datasets)):
|
|
115
|
+
intersection_features.intersection_update(self._get_effective_features(i))
|
|
98
116
|
if len(intersection_features) == 0:
|
|
99
117
|
raise RuntimeError(
|
|
100
118
|
"Multiple datasets were provided but they had no keys common to all of them. "
|
|
101
119
|
"The multi-dataset functionality currently only keeps common keys."
|
|
102
120
|
)
|
|
103
|
-
for
|
|
104
|
-
|
|
121
|
+
for i, repo_id in enumerate(self.repo_ids):
|
|
122
|
+
effective_keys = self._get_effective_features(i)
|
|
123
|
+
extra_keys = effective_keys.difference(intersection_features)
|
|
105
124
|
if extra_keys:
|
|
106
125
|
logging.warning(
|
|
107
|
-
|
|
108
|
-
|
|
126
|
+
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
|
|
127
|
+
"other datasets."
|
|
109
128
|
)
|
|
110
129
|
self.disabled_features.update(extra_keys)
|
|
111
130
|
|
|
112
131
|
# Validate that common features have compatible shapes
|
|
132
|
+
# Note: We need to look up the original key name for renamed keys
|
|
113
133
|
for key in intersection_features:
|
|
114
134
|
shapes = []
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
135
|
+
shape_details = []
|
|
136
|
+
for i, ds in enumerate(self._datasets):
|
|
137
|
+
# Find the original key (may be renamed)
|
|
138
|
+
renames = self._dataset_renames[i]
|
|
139
|
+
reverse_renames = {v: k for k, v in renames.items()}
|
|
140
|
+
original_key = reverse_renames.get(key, key)
|
|
141
|
+
|
|
142
|
+
if original_key in ds.meta.features:
|
|
143
|
+
feature_shape = ds.meta.features[original_key].get('shape', [])
|
|
118
144
|
shapes.append(tuple(feature_shape))
|
|
145
|
+
if original_key != key:
|
|
146
|
+
shape_details.append(f"{ds.repo_id}: {feature_shape} (from '{original_key}')")
|
|
147
|
+
else:
|
|
148
|
+
shape_details.append(f"{ds.repo_id}: {feature_shape}")
|
|
119
149
|
|
|
120
150
|
# Check if all shapes are the same
|
|
121
151
|
unique_shapes = set(shapes)
|
|
122
152
|
if len(unique_shapes) > 1:
|
|
123
|
-
shape_details = []
|
|
124
|
-
for ds in self._datasets:
|
|
125
|
-
if key in ds.meta.features:
|
|
126
|
-
shape = ds.meta.features[key].get('shape', [])
|
|
127
|
-
shape_details.append(f"{ds.repo_id}: {shape}")
|
|
128
|
-
|
|
129
153
|
raise ValueError(
|
|
130
154
|
f"Incompatible shapes for feature '{key}' across datasets:\n" +
|
|
131
155
|
"\n".join(f" - {detail}" for detail in shape_details) +
|
|
@@ -208,8 +232,11 @@ class WrappedRobotDataset(torch.utils.data.Dataset):
|
|
|
208
232
|
plugin_only_features = {}
|
|
209
233
|
for key, value in self._meta.features.items():
|
|
210
234
|
if key not in base_features:
|
|
211
|
-
|
|
212
|
-
|
|
235
|
+
if 'action' in key:
|
|
236
|
+
plugin_only_features[key] = PolicyFeature(type=FeatureType.ACTION, shape=value['shape'])
|
|
237
|
+
else:
|
|
238
|
+
plugin_only_features[key] = PolicyFeature(type=FeatureType.STATE, shape=value['shape'])
|
|
239
|
+
|
|
213
240
|
return plugin_only_features
|
|
214
241
|
|
|
215
242
|
@property
|
|
@@ -296,6 +323,77 @@ class WrappedRobotDataset(torch.utils.data.Dataset):
|
|
|
296
323
|
# Also update the cached stats property
|
|
297
324
|
self.stats = self._meta.stats
|
|
298
325
|
|
|
326
|
+
def _compute_dataset_renames(self) -> list[dict[str, str]]:
|
|
327
|
+
"""
|
|
328
|
+
Pre-compute which key renames apply to each dataset.
|
|
329
|
+
|
|
330
|
+
For each dataset, determines which source keys from key_rename_map exist
|
|
331
|
+
and can be renamed (i.e., target key doesn't already exist).
|
|
332
|
+
|
|
333
|
+
Also automatically handles derived _is_pad keys that LeRobot adds when
|
|
334
|
+
delta_timestamps are used. For example, if renaming "action.pos" -> "action",
|
|
335
|
+
this will also rename "action.pos_is_pad" -> "action_is_pad".
|
|
336
|
+
|
|
337
|
+
Returns:
|
|
338
|
+
List of dicts mapping source_key -> target_key for each dataset
|
|
339
|
+
"""
|
|
340
|
+
dataset_renames = []
|
|
341
|
+
for dataset in self._datasets:
|
|
342
|
+
ds_renames = {}
|
|
343
|
+
ds_keys = set(dataset.features)
|
|
344
|
+
|
|
345
|
+
for source, target in self.key_rename_map.items():
|
|
346
|
+
if source in ds_keys:
|
|
347
|
+
if target in ds_keys:
|
|
348
|
+
# Target already exists in this dataset - skip rename to avoid conflict
|
|
349
|
+
logging.warning(
|
|
350
|
+
f"Skipping rename '{source}' -> '{target}' for {dataset.repo_id}: "
|
|
351
|
+
f"target key already exists in dataset"
|
|
352
|
+
)
|
|
353
|
+
else:
|
|
354
|
+
ds_renames[source] = target
|
|
355
|
+
|
|
356
|
+
# Also handle the _is_pad suffix that LeRobot adds for delta_timestamps
|
|
357
|
+
# These keys are dynamically added during __getitem__ and may not be in
|
|
358
|
+
# dataset.features, but we still want to rename them consistently
|
|
359
|
+
is_pad_source = f"{source}_is_pad"
|
|
360
|
+
is_pad_target = f"{target}_is_pad"
|
|
361
|
+
|
|
362
|
+
# Check for conflicts on the _is_pad key as well
|
|
363
|
+
if is_pad_target in ds_keys:
|
|
364
|
+
logging.warning(
|
|
365
|
+
f"Skipping derived rename '{is_pad_source}' -> '{is_pad_target}' "
|
|
366
|
+
f"for {dataset.repo_id}: target key already exists in dataset"
|
|
367
|
+
)
|
|
368
|
+
else:
|
|
369
|
+
ds_renames[is_pad_source] = is_pad_target
|
|
370
|
+
|
|
371
|
+
dataset_renames.append(ds_renames)
|
|
372
|
+
|
|
373
|
+
return dataset_renames
|
|
374
|
+
|
|
375
|
+
def _get_effective_features(self, dataset_idx: int) -> set[str]:
|
|
376
|
+
"""
|
|
377
|
+
Get the effective feature keys for a dataset after applying renames.
|
|
378
|
+
|
|
379
|
+
Args:
|
|
380
|
+
dataset_idx: Index of the dataset
|
|
381
|
+
|
|
382
|
+
Returns:
|
|
383
|
+
Set of feature keys that would exist after renaming
|
|
384
|
+
"""
|
|
385
|
+
ds = self._datasets[dataset_idx]
|
|
386
|
+
renames = self._dataset_renames[dataset_idx]
|
|
387
|
+
|
|
388
|
+
effective = set()
|
|
389
|
+
for key in ds.features:
|
|
390
|
+
if key in renames:
|
|
391
|
+
effective.add(renames[key])
|
|
392
|
+
else:
|
|
393
|
+
effective.add(key)
|
|
394
|
+
|
|
395
|
+
return effective
|
|
396
|
+
|
|
299
397
|
def _validate_plugin_keys(self):
|
|
300
398
|
"""
|
|
301
399
|
Check for key conflicts between plugins.
|
|
@@ -483,7 +581,14 @@ class WrappedRobotDataset(torch.utils.data.Dataset):
|
|
|
483
581
|
# Add dataset index
|
|
484
582
|
item["dataset_index"] = torch.tensor(dataset_idx)
|
|
485
583
|
|
|
486
|
-
#
|
|
584
|
+
# Apply key renaming for this dataset (before filtering disabled features)
|
|
585
|
+
# This unifies differently-named keys across datasets
|
|
586
|
+
renames = self._dataset_renames[dataset_idx]
|
|
587
|
+
for source, target in renames.items():
|
|
588
|
+
if source in item:
|
|
589
|
+
item[target] = item.pop(source)
|
|
590
|
+
|
|
591
|
+
# Remove disabled features (now operates on effective/renamed key names)
|
|
487
592
|
for data_key in self.disabled_features:
|
|
488
593
|
if data_key in item:
|
|
489
594
|
del item[data_key]
|
|
@@ -26,4 +26,5 @@ robocandywrapper/samplers/__init__.py
|
|
|
26
26
|
robocandywrapper/samplers/config.py
|
|
27
27
|
robocandywrapper/samplers/factory.py
|
|
28
28
|
robocandywrapper/samplers/weighted.py
|
|
29
|
-
tests/test_dataset_weights_integration.py
|
|
29
|
+
tests/test_dataset_weights_integration.py
|
|
30
|
+
tests/test_key_rename_stats.py
|
|
@@ -9,7 +9,7 @@ long_description = readme_file.read_text(encoding="utf-8") if readme_file.exists
|
|
|
9
9
|
|
|
10
10
|
setup(
|
|
11
11
|
name="robocandywrapper",
|
|
12
|
-
version="0.2.2",
|
|
12
|
+
version="0.2.4",
|
|
13
13
|
description="Sweet wrappers for extending and remixing LeRobot Datasets",
|
|
14
14
|
long_description=long_description,
|
|
15
15
|
long_description_content_type="text/markdown",
|
|
@@ -0,0 +1,394 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Test for key_rename_map functionality in stats aggregation.
|
|
3
|
+
|
|
4
|
+
Tests that when datasets have differently-named keys that map to the same
|
|
5
|
+
target key, their stats are properly combined as if they were the same key.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import sys
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
|
|
11
|
+
# Add the local package to path before imports
|
|
12
|
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
13
|
+
|
|
14
|
+
import numpy as np
|
|
15
|
+
from robocandywrapper.wrapper import WrappedRobotDataset
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class MockLeRobotDataset:
|
|
19
|
+
"""Mock dataset for testing."""
|
|
20
|
+
|
|
21
|
+
def __init__(self, repo_id, fps, features, num_frames, stats=None):
|
|
22
|
+
self.repo_id = repo_id
|
|
23
|
+
self._fps = fps
|
|
24
|
+
self._features = features
|
|
25
|
+
self._num_frames = num_frames
|
|
26
|
+
|
|
27
|
+
# Create mock metadata
|
|
28
|
+
self.meta = MockMetadata(repo_id, fps, features, stats)
|
|
29
|
+
self.hf_features = features
|
|
30
|
+
self.features = features
|
|
31
|
+
|
|
32
|
+
def __len__(self):
|
|
33
|
+
return self._num_frames
|
|
34
|
+
|
|
35
|
+
def __getitem__(self, idx):
|
|
36
|
+
# Minimal mock for dataset access
|
|
37
|
+
return {"action": np.array([0.0, 0.0])}
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class MockMetadata:
|
|
41
|
+
"""Mock metadata object."""
|
|
42
|
+
|
|
43
|
+
def __init__(self, repo_id, fps, features, stats=None):
|
|
44
|
+
self.repo_id = repo_id
|
|
45
|
+
self._fps = fps
|
|
46
|
+
self._features = features
|
|
47
|
+
self.info = {"fps": fps}
|
|
48
|
+
|
|
49
|
+
# Use provided stats or default
|
|
50
|
+
if stats is None:
|
|
51
|
+
self.stats = {
|
|
52
|
+
"action": {
|
|
53
|
+
"mean": np.array([0.0, 0.0]),
|
|
54
|
+
"std": np.array([1.0, 1.0]),
|
|
55
|
+
"min": np.array([-1.0, -1.0]),
|
|
56
|
+
"max": np.array([1.0, 1.0]),
|
|
57
|
+
"count": np.array([1000]),
|
|
58
|
+
}
|
|
59
|
+
}
|
|
60
|
+
else:
|
|
61
|
+
self.stats = stats
|
|
62
|
+
|
|
63
|
+
self.camera_keys = [k for k in features if "image" in k or "video" in k]
|
|
64
|
+
self.image_keys = [k for k in features if "image" in k]
|
|
65
|
+
self.video_keys = [k for k in features if "video" in k]
|
|
66
|
+
|
|
67
|
+
@property
|
|
68
|
+
def fps(self):
|
|
69
|
+
return self._fps
|
|
70
|
+
|
|
71
|
+
@property
|
|
72
|
+
def features(self):
|
|
73
|
+
return self._features
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def test_key_rename_stats_aggregation():
|
|
77
|
+
"""
|
|
78
|
+
Test that keys are properly renamed in stats aggregation.
|
|
79
|
+
|
|
80
|
+
Scenario:
|
|
81
|
+
- Dataset 1 has "action.pos" key with certain stats
|
|
82
|
+
- Dataset 2 has "trajectory" key with certain stats
|
|
83
|
+
- key_rename_map maps both to "action"
|
|
84
|
+
- Result should have "action" stats that combine both sources
|
|
85
|
+
"""
|
|
86
|
+
print("\n" + "="*60)
|
|
87
|
+
print("Test: Key Rename Stats Aggregation")
|
|
88
|
+
print("="*60)
|
|
89
|
+
|
|
90
|
+
# Dataset 1: has "action.pos" key
|
|
91
|
+
stats1 = {
|
|
92
|
+
"action.pos": {
|
|
93
|
+
"mean": np.array([1.0, 2.0]),
|
|
94
|
+
"std": np.array([0.5, 0.5]),
|
|
95
|
+
"min": np.array([-1.0, -1.0]),
|
|
96
|
+
"max": np.array([3.0, 4.0]),
|
|
97
|
+
"count": np.array([1000]), # 1000 samples
|
|
98
|
+
}
|
|
99
|
+
}
|
|
100
|
+
|
|
101
|
+
# Dataset 2: has "trajectory" key
|
|
102
|
+
stats2 = {
|
|
103
|
+
"trajectory": {
|
|
104
|
+
"mean": np.array([5.0, 6.0]),
|
|
105
|
+
"std": np.array([1.0, 1.0]),
|
|
106
|
+
"min": np.array([0.0, 0.0]),
|
|
107
|
+
"max": np.array([10.0, 12.0]),
|
|
108
|
+
"count": np.array([1000]), # Same count for simpler math
|
|
109
|
+
}
|
|
110
|
+
}
|
|
111
|
+
|
|
112
|
+
dataset1 = MockLeRobotDataset(
|
|
113
|
+
repo_id="dataset_with_action_pos",
|
|
114
|
+
fps=20,
|
|
115
|
+
features={"action.pos": {"shape": [2]}},
|
|
116
|
+
num_frames=1000,
|
|
117
|
+
stats=stats1
|
|
118
|
+
)
|
|
119
|
+
|
|
120
|
+
dataset2 = MockLeRobotDataset(
|
|
121
|
+
repo_id="dataset_with_trajectory",
|
|
122
|
+
fps=20,
|
|
123
|
+
features={"trajectory": {"shape": [2]}},
|
|
124
|
+
num_frames=1000,
|
|
125
|
+
stats=stats2
|
|
126
|
+
)
|
|
127
|
+
|
|
128
|
+
# Create wrapped dataset with key rename map
|
|
129
|
+
key_rename_map = {
|
|
130
|
+
"action.pos": "action",
|
|
131
|
+
"trajectory": "action",
|
|
132
|
+
}
|
|
133
|
+
|
|
134
|
+
print(f"\n1. Creating wrapped dataset with key_rename_map: {key_rename_map}")
|
|
135
|
+
|
|
136
|
+
wrapped_dataset = WrappedRobotDataset(
|
|
137
|
+
datasets=[dataset1, dataset2],
|
|
138
|
+
plugins=None,
|
|
139
|
+
key_rename_map=key_rename_map,
|
|
140
|
+
)
|
|
141
|
+
|
|
142
|
+
# Check that features were renamed
|
|
143
|
+
print("\n2. Checking features")
|
|
144
|
+
assert "action" in wrapped_dataset.meta.features, \
|
|
145
|
+
"Renamed 'action' key should be in features"
|
|
146
|
+
assert "action.pos" not in wrapped_dataset.meta.features, \
|
|
147
|
+
"Original 'action.pos' key should not be in features"
|
|
148
|
+
assert "trajectory" not in wrapped_dataset.meta.features, \
|
|
149
|
+
"Original 'trajectory' key should not be in features"
|
|
150
|
+
print(" ✅ Features correctly renamed")
|
|
151
|
+
|
|
152
|
+
# Check that stats were combined
|
|
153
|
+
print("\n3. Checking stats aggregation")
|
|
154
|
+
combined_stats = wrapped_dataset.meta.stats
|
|
155
|
+
|
|
156
|
+
assert "action" in combined_stats, \
|
|
157
|
+
"Combined stats should have 'action' key"
|
|
158
|
+
assert "action.pos" not in combined_stats, \
|
|
159
|
+
"Original 'action.pos' should not be in combined stats"
|
|
160
|
+
assert "trajectory" not in combined_stats, \
|
|
161
|
+
"Original 'trajectory' should not be in combined stats"
|
|
162
|
+
|
|
163
|
+
# With equal counts (1000 each), the combined mean should be the average
|
|
164
|
+
# mean = (1000 * [1.0, 2.0] + 1000 * [5.0, 6.0]) / 2000 = [3.0, 4.0]
|
|
165
|
+
expected_mean = np.array([3.0, 4.0])
|
|
166
|
+
|
|
167
|
+
np.testing.assert_allclose(
|
|
168
|
+
combined_stats["action"]["mean"],
|
|
169
|
+
expected_mean,
|
|
170
|
+
rtol=1e-5,
|
|
171
|
+
err_msg="Combined mean should be average of both datasets' means"
|
|
172
|
+
)
|
|
173
|
+
print(f" Combined mean: {combined_stats['action']['mean']}")
|
|
174
|
+
print(f" Expected mean: {expected_mean}")
|
|
175
|
+
print(" ✅ Stats correctly combined")
|
|
176
|
+
|
|
177
|
+
# Check min/max
|
|
178
|
+
expected_min = np.array([-1.0, -1.0]) # min of both datasets
|
|
179
|
+
expected_max = np.array([10.0, 12.0]) # max of both datasets
|
|
180
|
+
|
|
181
|
+
np.testing.assert_allclose(
|
|
182
|
+
combined_stats["action"]["min"],
|
|
183
|
+
expected_min,
|
|
184
|
+
rtol=1e-5,
|
|
185
|
+
err_msg="Combined min should be minimum across both datasets"
|
|
186
|
+
)
|
|
187
|
+
np.testing.assert_allclose(
|
|
188
|
+
combined_stats["action"]["max"],
|
|
189
|
+
expected_max,
|
|
190
|
+
rtol=1e-5,
|
|
191
|
+
err_msg="Combined max should be maximum across both datasets"
|
|
192
|
+
)
|
|
193
|
+
print(f" Combined min: {combined_stats['action']['min']}, expected: {expected_min}")
|
|
194
|
+
print(f" Combined max: {combined_stats['action']['max']}, expected: {expected_max}")
|
|
195
|
+
print(" ✅ Min/max correctly combined")
|
|
196
|
+
|
|
197
|
+
# Check count
|
|
198
|
+
assert combined_stats["action"]["count"] == 2000, \
|
|
199
|
+
f"Combined count should be 2000, got {combined_stats['action']['count']}"
|
|
200
|
+
print(f" Combined count: {combined_stats['action']['count']}")
|
|
201
|
+
print(" ✅ Count correctly combined")
|
|
202
|
+
|
|
203
|
+
print("\n" + "="*60)
|
|
204
|
+
print("✅ KEY RENAME STATS TEST PASSED!")
|
|
205
|
+
print("="*60 + "\n")
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
def test_key_rename_with_different_counts():
    """
    Test key rename with datasets having different sample counts.

    Two mock datasets expose one quantity under different keys: "pos"
    (1000 samples) and "position" (3000 samples). Both keys are mapped to
    "state" through ``key_rename_map``. The combined stats for "state" are
    expected to be weighted by sample count (mean), to span both datasets
    (min/max), and to report the summed count.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("Test: Key Rename Stats with Different Counts")
    print(banner)

    # Dataset 1: 1000 samples with "pos"
    stats1 = {
        "pos": {
            "mean": np.array([0.0]),
            "std": np.array([1.0]),
            "min": np.array([-3.0]),
            "max": np.array([3.0]),
            "count": np.array([1000]),
        },
    }

    # Dataset 2: 3000 samples with "position"
    stats2 = {
        "position": {
            "mean": np.array([4.0]),
            "std": np.array([2.0]),
            "min": np.array([-2.0]),
            "max": np.array([10.0]),
            "count": np.array([3000]),
        },
    }

    dataset1 = MockLeRobotDataset(
        repo_id="ds1",
        fps=20,
        features={"pos": {"shape": [1]}},
        num_frames=1000,
        stats=stats1,
    )

    dataset2 = MockLeRobotDataset(
        repo_id="ds2",
        fps=20,
        features={"position": {"shape": [1]}},
        num_frames=3000,
        stats=stats2,
    )

    # Both source keys collapse onto the same target key.
    key_rename_map = {
        "pos": "state",
        "position": "state",
    }

    print("\n1. Dataset 1: 1000 samples, mean=0.0")
    print(" Dataset 2: 3000 samples, mean=4.0")
    print(f" key_rename_map: {key_rename_map}")

    wrapped_dataset = WrappedRobotDataset(
        datasets=[dataset1, dataset2],
        plugins=None,
        key_rename_map=key_rename_map,
    )

    combined_stats = wrapped_dataset.meta.stats

    # Expected weighted mean: (1000 * 0.0 + 3000 * 4.0) / 4000 = 3.0
    expected_mean = np.array([3.0])
    # Min/max are expected to span both datasets, as in the
    # equal-count aggregation test in this file.
    expected_min = np.array([-3.0])
    expected_max = np.array([10.0])

    print("\n2. Combined stats for 'state':")
    print(f" Mean: {combined_stats['state']['mean']} (expected: {expected_mean})")

    np.testing.assert_allclose(
        combined_stats["state"]["mean"],
        expected_mean,
        rtol=1e-5,
        err_msg="Weighted mean should account for different counts",
    )
    np.testing.assert_allclose(
        combined_stats["state"]["min"],
        expected_min,
        rtol=1e-5,
        err_msg="Combined min should be minimum across both datasets",
    )
    np.testing.assert_allclose(
        combined_stats["state"]["max"],
        expected_max,
        rtol=1e-5,
        err_msg="Combined max should be maximum across both datasets",
    )

    assert combined_stats["state"]["count"] == 4000, \
        f"Total count should be 4000, got {combined_stats['state']['count']}"
    print(f" Count: {combined_stats['state']['count']} (expected: 4000)")

    print("\n" + banner)
    print("✅ KEY RENAME WITH DIFFERENT COUNTS TEST PASSED!")
    print(banner + "\n")
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
def test_key_rename_partial_rename():
    """
    Test that keys pass through unchanged when nothing is renamed.

    Both mock datasets use identical keys ("action" and
    "observation.state") and the wrapper is built with
    ``key_rename_map=None``. The combined stats must keep both keys under
    their original names and aggregate them normally.

    NOTE(review): despite the function name, no rename map is supplied
    here, so the case where only some keys are renamed is not exercised.
    TODO: add a case with a map that renames a subset of keys.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("Test: Partial Key Rename")
    print(banner)

    # Both datasets share the same keys; only the "action" values differ.
    stats1 = {
        "action": {
            "mean": np.array([1.0]),
            "std": np.array([1.0]),
            "min": np.array([0.0]),
            "max": np.array([2.0]),
            "count": np.array([1000]),
        },
        "observation.state": {
            "mean": np.array([0.5]),
            "std": np.array([0.1]),
            "min": np.array([0.0]),
            "max": np.array([1.0]),
            "count": np.array([1000]),
        },
    }

    stats2 = {
        "action": {
            "mean": np.array([3.0]),
            "std": np.array([1.0]),
            "min": np.array([2.0]),
            "max": np.array([4.0]),
            "count": np.array([1000]),
        },
        "observation.state": {
            "mean": np.array([0.5]),
            "std": np.array([0.1]),
            "min": np.array([0.0]),
            "max": np.array([1.0]),
            "count": np.array([1000]),
        },
    }

    dataset1 = MockLeRobotDataset(
        repo_id="ds1",
        fps=20,
        features={"action": {"shape": [1]}, "observation.state": {"shape": [1]}},
        num_frames=1000,
        stats=stats1,
    )

    dataset2 = MockLeRobotDataset(
        repo_id="ds2",
        fps=20,
        features={"action": {"shape": [1]}, "observation.state": {"shape": [1]}},
        num_frames=1000,
        stats=stats2,
    )

    # No key rename map - should work normally
    print("\n1. Creating wrapped dataset without key_rename_map")

    wrapped_dataset = WrappedRobotDataset(
        datasets=[dataset1, dataset2],
        plugins=None,
        key_rename_map=None,
    )

    combined_stats = wrapped_dataset.meta.stats

    # "action" should be combined normally
    assert "action" in combined_stats, "action key should be present"
    expected_action_mean = np.array([2.0])  # (1.0 + 3.0) / 2

    np.testing.assert_allclose(
        combined_stats["action"]["mean"],
        expected_action_mean,
        rtol=1e-5,
    )
    print(f" action mean: {combined_stats['action']['mean']} (expected: {expected_action_mean})")

    # "observation.state" should also be present under its original name,
    # and since both inputs have mean 0.5 with equal counts the combined
    # mean must stay 0.5.
    assert "observation.state" in combined_stats, "observation.state should be present"
    np.testing.assert_allclose(
        combined_stats["observation.state"]["mean"],
        np.array([0.5]),
        rtol=1e-5,
        err_msg="observation.state mean should be unchanged when both inputs agree",
    )
    print(f" observation.state mean: {combined_stats['observation.state']['mean']}")

    print("\n" + banner)
    print("✅ PARTIAL KEY RENAME TEST PASSED!")
    print(banner + "\n")
|
|
385
|
+
|
|
386
|
+
|
|
387
|
+
if __name__ == "__main__":
    # Allow running this file directly: execute every test in order, then
    # print a closing summary banner once all of them have returned.
    for run_test in (
        test_key_rename_stats_aggregation,
        test_key_rename_with_different_counts,
        test_key_rename_partial_rename,
    ):
        run_test()

    separator = "=" * 60
    print("\n" + separator)
    print("ALL KEY RENAME STATS TESTS PASSED!")
    print(separator + "\n")
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/robocandywrapper/dataformats/lerobot_21/utils.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/robocandywrapper/plugins/episode_outcome.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/robocandywrapper.egg-info/dependency_links.txt
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|