robocandywrapper 0.2.2__tar.gz → 0.2.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32) hide show
  1. {robocandywrapper-0.2.2/robocandywrapper.egg-info → robocandywrapper-0.2.4}/PKG-INFO +1 -1
  2. {robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/pyproject.toml +1 -1
  3. {robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/robocandywrapper/__init__.py +1 -1
  4. {robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/robocandywrapper/dataformats/lerobot_21/dataset.py +11 -1
  5. {robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/robocandywrapper/factory.py +21 -1
  6. {robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/robocandywrapper/metadata_view.py +46 -6
  7. {robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/robocandywrapper/utils.py +0 -4
  8. {robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/robocandywrapper/wrapper.py +124 -19
  9. {robocandywrapper-0.2.2 → robocandywrapper-0.2.4/robocandywrapper.egg-info}/PKG-INFO +1 -1
  10. {robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/robocandywrapper.egg-info/SOURCES.txt +2 -1
  11. {robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/setup.py +1 -1
  12. robocandywrapper-0.2.4/tests/test_key_rename_stats.py +394 -0
  13. {robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/LICENSE +0 -0
  14. {robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/MANIFEST.in +0 -0
  15. {robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/README.md +0 -0
  16. {robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/robocandywrapper/constants.py +0 -0
  17. {robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/robocandywrapper/dataformats/__init__.py +0 -0
  18. {robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/robocandywrapper/dataformats/lerobot_21/__init__.py +0 -0
  19. {robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/robocandywrapper/dataformats/lerobot_21/utils.py +0 -0
  20. {robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/robocandywrapper/plugin.py +0 -0
  21. {robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/robocandywrapper/plugins/__init__.py +0 -0
  22. {robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/robocandywrapper/plugins/affordance.py +0 -0
  23. {robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/robocandywrapper/plugins/episode_outcome.py +0 -0
  24. {robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/robocandywrapper/samplers/__init__.py +0 -0
  25. {robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/robocandywrapper/samplers/config.py +0 -0
  26. {robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/robocandywrapper/samplers/factory.py +0 -0
  27. {robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/robocandywrapper/samplers/weighted.py +0 -0
  28. {robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/robocandywrapper.egg-info/dependency_links.txt +0 -0
  29. {robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/robocandywrapper.egg-info/requires.txt +0 -0
  30. {robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/robocandywrapper.egg-info/top_level.txt +0 -0
  31. {robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/setup.cfg +0 -0
  32. {robocandywrapper-0.2.2 → robocandywrapper-0.2.4}/tests/test_dataset_weights_integration.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: robocandywrapper
3
- Version: 0.2.2
3
+ Version: 0.2.4
4
4
  Summary: Sweet wrappers for extending and remixing LeRobot Datasets
5
5
  Author: RoboCandyWrapper Contributors
6
6
  License: MIT License
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "robocandywrapper"
7
- version = "0.2.2"
7
+ version = "0.2.4"
8
8
  description = "Sweet wrappers for extending and remixing LeRobot Datasets"
9
9
  readme = "README.md"
10
10
  requires-python = ">=3.10"
@@ -17,7 +17,7 @@ from robocandywrapper.constants import (
17
17
  EPISODE_OUTCOME_PLUGIN_NAME,
18
18
  )
19
19
 
20
- __version__ = "0.2.1"
20
+ __version__ = "0.2.4"
21
21
 
22
22
  __all__ = [
23
23
  "DatasetPlugin",
@@ -93,6 +93,7 @@ class LeRobot21DatasetMetadata:
93
93
  try:
94
94
  if force_cache_sync:
95
95
  raise FileNotFoundError
96
+ self.pull_from_repo()
96
97
  self.load_metadata()
97
98
  except (FileNotFoundError, NotADirectoryError):
98
99
  if is_valid_version(self.revision):
@@ -728,7 +729,11 @@ class LeRobot21Dataset(torch.utils.data.Dataset):
728
729
  item = {}
729
730
  for vid_key, query_ts in query_timestamps.items():
730
731
  video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
731
- frames = decode_video_frames(video_path, query_ts, self.tolerance_s, self.video_backend)
732
+ try:
733
+ frames = decode_video_frames(video_path, query_ts, self.tolerance_s, self.video_backend)
734
+ except Exception as e:
735
+ # fall back to trying to decode with pyav
736
+ frames = decode_video_frames(video_path, query_ts, self.tolerance_s, "pyav")
732
737
  item[vid_key] = frames.squeeze(0)
733
738
 
734
739
  return item
@@ -768,6 +773,11 @@ class LeRobot21Dataset(torch.utils.data.Dataset):
768
773
  task_idx = item["task_index"].item()
769
774
  item["task"] = self.meta.tasks[task_idx]
770
775
 
776
+ # Hack - add gripper position to end
777
+ # only applies to a specific dataset
778
+ # if "observation.eef_6d_pose" in item and item["observation.eef_6d_pose"].shape[0] == 6:
779
+ # item["observation.eef_6d_pose"] = torch.cat([item["observation.eef_6d_pose"], item["observation.state"][-1:]], dim=0)
780
+
771
781
  return item
772
782
 
773
783
  def __repr__(self):
@@ -59,7 +59,7 @@ def resolve_delta_timestamps(
59
59
  for key in ds_meta.features:
60
60
  if key == REWARD and reward_indices is not None:
61
61
  delta_timestamps[key] = _indices_to_times(reward_indices, ds_meta.fps)
62
- if key == ACTION and action_indices is not None:
62
+ if key.startswith(ACTION) and action_indices is not None:
63
63
  delta_timestamps[key] = _indices_to_times(action_indices, ds_meta.fps)
64
64
  if key.startswith("observation.") and observation_indices is not None:
65
65
  delta_timestamps[key] = _indices_to_times(observation_indices, ds_meta.fps)
@@ -76,6 +76,7 @@ def _create_datasets(
76
76
  observation_delta_indices: Optional[List] = None,
77
77
  reward_delta_indices: Optional[List] = None,
78
78
  use_imagenet_stats: bool = True,
79
+ load_videos: bool = True,
79
80
  ) -> List[LeRobotDataset | LeRobot21Dataset]:
80
81
  """Private helper to create dataset instances from a list of repo IDs.
81
82
 
@@ -92,6 +93,8 @@ def _create_datasets(
92
93
  observation_delta_indices: Frame indices for observations.
93
94
  reward_delta_indices: Frame indices for rewards.
94
95
  use_imagenet_stats: Whether to apply ImageNet normalization stats.
96
+ load_videos: Whether to download and load video files (default: True).
97
+ Set to False to skip video downloads when not needed.
95
98
 
96
99
  Returns:
97
100
  List of dataset instances.
@@ -170,6 +173,7 @@ def _create_datasets(
170
173
  image_transforms=None, # Will be applied by WrappedRobotDataset
171
174
  revision=revision,
172
175
  video_backend=video_backend,
176
+ download_videos=load_videos,
173
177
  )
174
178
 
175
179
  # Apply ImageNet stats if needed
@@ -189,12 +193,18 @@ def _create_datasets(
189
193
  def make_dataset(
190
194
  cfg: TrainPipelineConfig,
191
195
  plugins: Optional[list[DatasetPlugin]] = None,
196
+ key_rename_map: Optional[dict[str, str]] = None,
197
+ load_videos: bool = True,
192
198
  ) -> WrappedRobotDataset:
193
199
  """Handles the logic of setting up delta timestamps and image transforms before creating a dataset.
194
200
 
195
201
  Args:
196
202
  cfg (TrainPipelineConfig): A TrainPipelineConfig config which contains a DatasetConfig and a PreTrainedConfig.
197
203
  plugins (Optional[list[DatasetPlugin]]): Optional list of plugins to attach to the dataset(s).
204
+ key_rename_map (Optional[dict[str, str]]): Optional mapping from source keys to target keys
205
+ for unifying datasets with different naming conventions. Example: {"action.pos": "action"}
206
+ load_videos (bool): Whether to download and load video files (default: True).
207
+ Set to False to skip video downloads when not needed.
198
208
 
199
209
  Returns:
200
210
  WrappedRobotDataset: A wrapped dataset with plugin support.
@@ -221,6 +231,7 @@ def make_dataset(
221
231
  observation_delta_indices=cfg.policy.observation_delta_indices,
222
232
  reward_delta_indices=cfg.policy.reward_delta_indices,
223
233
  use_imagenet_stats=cfg.dataset.use_imagenet_stats,
234
+ load_videos=load_videos,
224
235
  )
225
236
 
226
237
  # Wrap in WrappedRobotDataset with plugins
@@ -228,6 +239,7 @@ def make_dataset(
228
239
  datasets=datasets,
229
240
  plugins=plugins,
230
241
  image_transforms=image_transforms,
242
+ key_rename_map=key_rename_map,
231
243
  )
232
244
 
233
245
  return wrapped_dataset
@@ -243,6 +255,8 @@ def make_dataset_without_config(
243
255
  revision: str | None = None,
244
256
  use_imagenet_stats: bool = True,
245
257
  plugins: Optional[list[DatasetPlugin]] = None,
258
+ key_rename_map: Optional[dict[str, str]] = None,
259
+ load_videos: bool = True,
246
260
  ) -> WrappedRobotDataset:
247
261
  """Handles the logic of setting up delta timestamps and image transforms before creating a dataset.
248
262
 
@@ -259,6 +273,10 @@ def make_dataset_without_config(
259
273
  revision (str, optional): Dataset revision
260
274
  use_imagenet_stats (bool): Whether to use ImageNet normalization stats (default: True)
261
275
  plugins (Optional[list[DatasetPlugin]]): Optional list of plugins to attach to the dataset(s)
276
+ key_rename_map (Optional[dict[str, str]]): Optional mapping from source keys to target keys
277
+ for unifying datasets with different naming conventions. Example: {"action.pos": "action"}
278
+ load_videos (bool): Whether to download and load video files (default: True).
279
+ Set to False to skip video downloads when not needed.
262
280
 
263
281
  Returns:
264
282
  WrappedRobotDataset: A wrapped dataset with plugin support.
@@ -283,12 +301,14 @@ def make_dataset_without_config(
283
301
  action_delta_indices=action_delta_indices,
284
302
  observation_delta_indices=observation_delta_indices,
285
303
  use_imagenet_stats=use_imagenet_stats,
304
+ load_videos=load_videos,
286
305
  )
287
306
 
288
307
  # Wrap in WrappedRobotDataset with plugins
289
308
  wrapped_dataset = WrappedRobotDataset(
290
309
  datasets=datasets,
291
310
  plugins=plugins,
311
+ key_rename_map=key_rename_map,
292
312
  )
293
313
 
294
314
  return wrapped_dataset
@@ -102,6 +102,7 @@ class WrappedRobotDatasetMetadataView:
102
102
  datasets: list,
103
103
  plugin_instances: list[list],
104
104
  dataset_weights: Optional[dict[str, float]] = None,
105
+ dataset_renames: Optional[list[dict[str, str]]] = None,
105
106
  ):
106
107
  """
107
108
  Initialize metadata view.
@@ -110,15 +111,32 @@ class WrappedRobotDatasetMetadataView:
110
111
  datasets: List of LeRobotDataset instances
111
112
  plugin_instances: List of plugin instances for each dataset
112
113
  dataset_weights: Optional weights for each dataset (for weighted stats)
114
+ dataset_renames: Optional list of rename dicts for each dataset,
115
+ mapping source_key -> target_key. Used to unify keys across
116
+ datasets with different naming conventions.
113
117
  """
114
118
  self._datasets = datasets
115
119
  self._plugin_instances = plugin_instances
116
120
  self._dataset_weights = dataset_weights or {}
121
+ self._dataset_renames = dataset_renames or [{} for _ in datasets]
117
122
 
118
123
  # Cache computed properties
119
124
  self._features = None
120
125
  self._stats = None
121
126
 
127
+ def _get_renamed_features(self, dataset_idx: int) -> dict[str, dict]:
128
+ """Get features from a dataset with key renames applied."""
129
+ dataset = self._datasets[dataset_idx]
130
+ renames = self._dataset_renames[dataset_idx]
131
+
132
+ renamed_features = {}
133
+ for key, value in dataset.meta.features.items():
134
+ # Apply rename if applicable
135
+ effective_key = renames.get(key, key)
136
+ renamed_features[effective_key] = value
137
+
138
+ return renamed_features
139
+
122
140
  @property
123
141
  def features(self) -> dict[str, dict]:
124
142
  """
@@ -127,6 +145,9 @@ class WrappedRobotDatasetMetadataView:
127
145
  Returns intersection of:
128
146
  1. Features from all datasets (taking intersection, not union)
129
147
  2. Features provided by plugins (added to intersection)
148
+
149
+ Key renames are applied before computing the intersection, allowing
150
+ datasets with different naming conventions to be unified.
130
151
  """
131
152
  if self._features is not None:
132
153
  return self._features
@@ -135,12 +156,13 @@ class WrappedRobotDatasetMetadataView:
135
156
  if not self._datasets:
136
157
  all_features = {}
137
158
  else:
138
- # Start with all features from first dataset
139
- all_features = dict(self._datasets[0].meta.features)
159
+ # Start with all features from first dataset (with renames applied)
160
+ all_features = self._get_renamed_features(0)
140
161
 
141
162
  # Intersect with features from other datasets
142
- for dataset in self._datasets[1:]:
143
- dataset_feature_keys = set(dataset.meta.features.keys())
163
+ for i in range(1, len(self._datasets)):
164
+ dataset_features = self._get_renamed_features(i)
165
+ dataset_feature_keys = set(dataset_features.keys())
144
166
  all_feature_keys = set(all_features.keys())
145
167
 
146
168
  # Keep only features that exist in both
@@ -166,6 +188,19 @@ class WrappedRobotDatasetMetadataView:
166
188
  self._features = all_features
167
189
  return self._features
168
190
 
191
+ def _get_renamed_stats(self, dataset_idx: int) -> dict[str, dict]:
192
+ """Get stats from a dataset with key renames applied."""
193
+ dataset = self._datasets[dataset_idx]
194
+ renames = self._dataset_renames[dataset_idx]
195
+
196
+ renamed_stats = {}
197
+ for key, value in dataset.meta.stats.items():
198
+ # Apply rename if applicable
199
+ effective_key = renames.get(key, key)
200
+ renamed_stats[effective_key] = value
201
+
202
+ return renamed_stats
203
+
169
204
  @property
170
205
  def stats(self) -> dict:
171
206
  """
@@ -174,12 +209,17 @@ class WrappedRobotDatasetMetadataView:
174
209
  If dataset_weights are provided, stats are computed as a weighted
175
210
  average based on effective dataset sizes (size * weight).
176
211
  Uses the correct statistical formula for combining variances.
212
+
213
+ Key renames are applied before aggregation, so different source keys
214
+ (e.g., "action.pos" and "trajectory") that map to the same target key
215
+ (e.g., "action") will have their stats combined as if they were the
216
+ same key across all datasets.
177
217
  """
178
218
  if self._stats is not None:
179
219
  return self._stats
180
220
 
181
- # Collect stats and weights for each dataset
182
- stats_list = [dataset.meta.stats for dataset in self._datasets]
221
+ # Collect stats (with renames applied) and weights for each dataset
222
+ stats_list = [self._get_renamed_stats(i) for i in range(len(self._datasets))]
183
223
 
184
224
  # Get weight multiplier for each dataset
185
225
  weights = []
@@ -10,11 +10,7 @@ from typing import Optional
10
10
  from glob import glob
11
11
  from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
12
12
  from lerobot.utils.constants import PRETRAINED_MODEL_DIR
13
- from lerobot.configs.policies import PreTrainedConfig
14
13
  from lerobot.configs.train import TrainPipelineConfig
15
- from lerobot.configs.types import FeatureType
16
- from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
17
- from lerobot.datasets.utils import dataset_to_policy_features
18
14
  from termcolor import colored
19
15
 
20
16
 
@@ -23,6 +23,7 @@ class WrappedRobotDataset(torch.utils.data.Dataset):
23
23
  warn_on_key_conflicts: bool = True,
24
24
  error_on_key_conflicts: bool = True,
25
25
  dataset_weights: Optional[dict[str, float]] = None,
26
+ key_rename_map: Optional[dict[str, str]] = None,
26
27
  **kwargs
27
28
  ):
28
29
  """
@@ -35,6 +36,14 @@ class WrappedRobotDataset(torch.utils.data.Dataset):
35
36
  warn_on_key_conflicts: Warn when plugins have overlapping keys (if not raising errors)
36
37
  error_on_key_conflicts: Raise error on key conflicts (default: True)
37
38
  dataset_weights: Optional weights for computing weighted stats (e.g., {"dataset_id": 2.0})
39
+ key_rename_map: Optional mapping from source keys to target keys for unifying
40
+ datasets with different naming conventions. Keys are renamed before the
41
+ intersection logic runs, allowing datasets with different key names to be
42
+ combined. Example: {"action.pos": "action", "trajectory": "action"}
43
+
44
+ Note: When a key is renamed, any corresponding "_is_pad" key (added by
45
+ LeRobot when using delta_timestamps) is automatically renamed as well.
46
+ E.g., "action.pos" -> "action" also renames "action.pos_is_pad" -> "action_is_pad".
38
47
  """
39
48
  super().__init__()
40
49
 
@@ -64,6 +73,10 @@ class WrappedRobotDataset(torch.utils.data.Dataset):
64
73
  self._cumulative_lengths.append(self._cumulative_lengths[-1] + length)
65
74
  self._total_length = self._cumulative_lengths[-1]
66
75
 
76
+ # Key rename mapping: unify differently-named keys across datasets
77
+ self.key_rename_map = key_rename_map or {}
78
+ self._dataset_renames = self._compute_dataset_renames()
79
+
67
80
  # Plugin management: one plugin class, many instances (one per dataset)
68
81
  self._plugins: list[DatasetPlugin] = plugins or []
69
82
  self._plugin_instances: list[list[PluginInstance]] = []
@@ -84,6 +97,7 @@ class WrappedRobotDataset(torch.utils.data.Dataset):
84
97
  datasets=self._datasets,
85
98
  plugin_instances=self._plugin_instances,
86
99
  dataset_weights=dataset_weights,
100
+ dataset_renames=self._dataset_renames,
87
101
  )
88
102
 
89
103
  # ** MATCHING LeRobot MULTI-DATASET API DESIGN **
@@ -91,41 +105,51 @@ class WrappedRobotDataset(torch.utils.data.Dataset):
91
105
  # Disable any data keys that are not common across all of the datasets. Note: we may relax this
92
106
  # restriction in future iterations of this class. For now, this is necessary at least for being able
93
107
  # to use PyTorch's default DataLoader collate function.
108
+ #
109
+ # Key rename mapping is applied first (conceptually), so intersection is computed on
110
+ # "effective" features (post-rename). This allows datasets with different key names to be
111
+ # unified before the intersection check.
94
112
  self.disabled_features = set()
95
- intersection_features = set(self._datasets[0].features)
96
- for ds in self._datasets:
97
- intersection_features.intersection_update(ds.features)
113
+ intersection_features = self._get_effective_features(0)
114
+ for i in range(len(self._datasets)):
115
+ intersection_features.intersection_update(self._get_effective_features(i))
98
116
  if len(intersection_features) == 0:
99
117
  raise RuntimeError(
100
118
  "Multiple datasets were provided but they had no keys common to all of them. "
101
119
  "The multi-dataset functionality currently only keeps common keys."
102
120
  )
103
- for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True):
104
- extra_keys = set(ds.features).difference(intersection_features)
121
+ for i, repo_id in enumerate(self.repo_ids):
122
+ effective_keys = self._get_effective_features(i)
123
+ extra_keys = effective_keys.difference(intersection_features)
105
124
  if extra_keys:
106
125
  logging.warning(
107
- f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
108
- "other datasets."
126
+ f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
127
+ "other datasets."
109
128
  )
110
129
  self.disabled_features.update(extra_keys)
111
130
 
112
131
  # Validate that common features have compatible shapes
132
+ # Note: We need to look up the original key name for renamed keys
113
133
  for key in intersection_features:
114
134
  shapes = []
115
- for ds in self._datasets:
116
- if key in ds.meta.features:
117
- feature_shape = ds.meta.features[key].get('shape', [])
135
+ shape_details = []
136
+ for i, ds in enumerate(self._datasets):
137
+ # Find the original key (may be renamed)
138
+ renames = self._dataset_renames[i]
139
+ reverse_renames = {v: k for k, v in renames.items()}
140
+ original_key = reverse_renames.get(key, key)
141
+
142
+ if original_key in ds.meta.features:
143
+ feature_shape = ds.meta.features[original_key].get('shape', [])
118
144
  shapes.append(tuple(feature_shape))
145
+ if original_key != key:
146
+ shape_details.append(f"{ds.repo_id}: {feature_shape} (from '{original_key}')")
147
+ else:
148
+ shape_details.append(f"{ds.repo_id}: {feature_shape}")
119
149
 
120
150
  # Check if all shapes are the same
121
151
  unique_shapes = set(shapes)
122
152
  if len(unique_shapes) > 1:
123
- shape_details = []
124
- for ds in self._datasets:
125
- if key in ds.meta.features:
126
- shape = ds.meta.features[key].get('shape', [])
127
- shape_details.append(f"{ds.repo_id}: {shape}")
128
-
129
153
  raise ValueError(
130
154
  f"Incompatible shapes for feature '{key}' across datasets:\n" +
131
155
  "\n".join(f" - {detail}" for detail in shape_details) +
@@ -208,8 +232,11 @@ class WrappedRobotDataset(torch.utils.data.Dataset):
208
232
  plugin_only_features = {}
209
233
  for key, value in self._meta.features.items():
210
234
  if key not in base_features:
211
- plugin_only_features[key] = PolicyFeature(type=FeatureType.STATE, shape=value['shape'])
212
-
235
+ if 'action' in key:
236
+ plugin_only_features[key] = PolicyFeature(type=FeatureType.ACTION, shape=value['shape'])
237
+ else:
238
+ plugin_only_features[key] = PolicyFeature(type=FeatureType.STATE, shape=value['shape'])
239
+
213
240
  return plugin_only_features
214
241
 
215
242
  @property
@@ -296,6 +323,77 @@ class WrappedRobotDataset(torch.utils.data.Dataset):
296
323
  # Also update the cached stats property
297
324
  self.stats = self._meta.stats
298
325
 
326
+ def _compute_dataset_renames(self) -> list[dict[str, str]]:
327
+ """
328
+ Pre-compute which key renames apply to each dataset.
329
+
330
+ For each dataset, determines which source keys from key_rename_map exist
331
+ and can be renamed (i.e., target key doesn't already exist).
332
+
333
+ Also automatically handles derived _is_pad keys that LeRobot adds when
334
+ delta_timestamps are used. For example, if renaming "action.pos" -> "action",
335
+ this will also rename "action.pos_is_pad" -> "action_is_pad".
336
+
337
+ Returns:
338
+ List of dicts mapping source_key -> target_key for each dataset
339
+ """
340
+ dataset_renames = []
341
+ for dataset in self._datasets:
342
+ ds_renames = {}
343
+ ds_keys = set(dataset.features)
344
+
345
+ for source, target in self.key_rename_map.items():
346
+ if source in ds_keys:
347
+ if target in ds_keys:
348
+ # Target already exists in this dataset - skip rename to avoid conflict
349
+ logging.warning(
350
+ f"Skipping rename '{source}' -> '{target}' for {dataset.repo_id}: "
351
+ f"target key already exists in dataset"
352
+ )
353
+ else:
354
+ ds_renames[source] = target
355
+
356
+ # Also handle the _is_pad suffix that LeRobot adds for delta_timestamps
357
+ # These keys are dynamically added during __getitem__ and may not be in
358
+ # dataset.features, but we still want to rename them consistently
359
+ is_pad_source = f"{source}_is_pad"
360
+ is_pad_target = f"{target}_is_pad"
361
+
362
+ # Check for conflicts on the _is_pad key as well
363
+ if is_pad_target in ds_keys:
364
+ logging.warning(
365
+ f"Skipping derived rename '{is_pad_source}' -> '{is_pad_target}' "
366
+ f"for {dataset.repo_id}: target key already exists in dataset"
367
+ )
368
+ else:
369
+ ds_renames[is_pad_source] = is_pad_target
370
+
371
+ dataset_renames.append(ds_renames)
372
+
373
+ return dataset_renames
374
+
375
+ def _get_effective_features(self, dataset_idx: int) -> set[str]:
376
+ """
377
+ Get the effective feature keys for a dataset after applying renames.
378
+
379
+ Args:
380
+ dataset_idx: Index of the dataset
381
+
382
+ Returns:
383
+ Set of feature keys that would exist after renaming
384
+ """
385
+ ds = self._datasets[dataset_idx]
386
+ renames = self._dataset_renames[dataset_idx]
387
+
388
+ effective = set()
389
+ for key in ds.features:
390
+ if key in renames:
391
+ effective.add(renames[key])
392
+ else:
393
+ effective.add(key)
394
+
395
+ return effective
396
+
299
397
  def _validate_plugin_keys(self):
300
398
  """
301
399
  Check for key conflicts between plugins.
@@ -483,7 +581,14 @@ class WrappedRobotDataset(torch.utils.data.Dataset):
483
581
  # Add dataset index
484
582
  item["dataset_index"] = torch.tensor(dataset_idx)
485
583
 
486
- # Remove disabled features
584
+ # Apply key renaming for this dataset (before filtering disabled features)
585
+ # This unifies differently-named keys across datasets
586
+ renames = self._dataset_renames[dataset_idx]
587
+ for source, target in renames.items():
588
+ if source in item:
589
+ item[target] = item.pop(source)
590
+
591
+ # Remove disabled features (now operates on effective/renamed key names)
487
592
  for data_key in self.disabled_features:
488
593
  if data_key in item:
489
594
  del item[data_key]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: robocandywrapper
3
- Version: 0.2.2
3
+ Version: 0.2.4
4
4
  Summary: Sweet wrappers for extending and remixing LeRobot Datasets
5
5
  Author: RoboCandyWrapper Contributors
6
6
  License: MIT License
@@ -26,4 +26,5 @@ robocandywrapper/samplers/__init__.py
26
26
  robocandywrapper/samplers/config.py
27
27
  robocandywrapper/samplers/factory.py
28
28
  robocandywrapper/samplers/weighted.py
29
- tests/test_dataset_weights_integration.py
29
+ tests/test_dataset_weights_integration.py
30
+ tests/test_key_rename_stats.py
@@ -9,7 +9,7 @@ long_description = readme_file.read_text(encoding="utf-8") if readme_file.exists
9
9
 
10
10
  setup(
11
11
  name="robocandywrapper",
12
- version="0.2.1",
12
+ version="0.2.4",
13
13
  description="Sweet wrappers for extending and remixing LeRobot Datasets",
14
14
  long_description=long_description,
15
15
  long_description_content_type="text/markdown",
@@ -0,0 +1,394 @@
1
+ """
2
+ Test for key_rename_map functionality in stats aggregation.
3
+
4
+ Tests that when datasets have differently-named keys that map to the same
5
+ target key, their stats are properly combined as if they were the same key.
6
+ """
7
+
8
+ import sys
9
+ from pathlib import Path
10
+
11
+ # Add the local package to path before imports
12
+ sys.path.insert(0, str(Path(__file__).parent.parent))
13
+
14
+ import numpy as np
15
+ from robocandywrapper.wrapper import WrappedRobotDataset
16
+
17
+
18
+ class MockLeRobotDataset:
19
+ """Mock dataset for testing."""
20
+
21
+ def __init__(self, repo_id, fps, features, num_frames, stats=None):
22
+ self.repo_id = repo_id
23
+ self._fps = fps
24
+ self._features = features
25
+ self._num_frames = num_frames
26
+
27
+ # Create mock metadata
28
+ self.meta = MockMetadata(repo_id, fps, features, stats)
29
+ self.hf_features = features
30
+ self.features = features
31
+
32
+ def __len__(self):
33
+ return self._num_frames
34
+
35
+ def __getitem__(self, idx):
36
+ # Minimal mock for dataset access
37
+ return {"action": np.array([0.0, 0.0])}
38
+
39
+
40
+ class MockMetadata:
41
+ """Mock metadata object."""
42
+
43
+ def __init__(self, repo_id, fps, features, stats=None):
44
+ self.repo_id = repo_id
45
+ self._fps = fps
46
+ self._features = features
47
+ self.info = {"fps": fps}
48
+
49
+ # Use provided stats or default
50
+ if stats is None:
51
+ self.stats = {
52
+ "action": {
53
+ "mean": np.array([0.0, 0.0]),
54
+ "std": np.array([1.0, 1.0]),
55
+ "min": np.array([-1.0, -1.0]),
56
+ "max": np.array([1.0, 1.0]),
57
+ "count": np.array([1000]),
58
+ }
59
+ }
60
+ else:
61
+ self.stats = stats
62
+
63
+ self.camera_keys = [k for k in features if "image" in k or "video" in k]
64
+ self.image_keys = [k for k in features if "image" in k]
65
+ self.video_keys = [k for k in features if "video" in k]
66
+
67
+ @property
68
+ def fps(self):
69
+ return self._fps
70
+
71
+ @property
72
+ def features(self):
73
+ return self._features
74
+
75
+
76
+ def test_key_rename_stats_aggregation():
77
+ """
78
+ Test that keys are properly renamed in stats aggregation.
79
+
80
+ Scenario:
81
+ - Dataset 1 has "action.pos" key with certain stats
82
+ - Dataset 2 has "trajectory" key with certain stats
83
+ - key_rename_map maps both to "action"
84
+ - Result should have "action" stats that combine both sources
85
+ """
86
+ print("\n" + "="*60)
87
+ print("Test: Key Rename Stats Aggregation")
88
+ print("="*60)
89
+
90
+ # Dataset 1: has "action.pos" key
91
+ stats1 = {
92
+ "action.pos": {
93
+ "mean": np.array([1.0, 2.0]),
94
+ "std": np.array([0.5, 0.5]),
95
+ "min": np.array([-1.0, -1.0]),
96
+ "max": np.array([3.0, 4.0]),
97
+ "count": np.array([1000]), # 1000 samples
98
+ }
99
+ }
100
+
101
+ # Dataset 2: has "trajectory" key
102
+ stats2 = {
103
+ "trajectory": {
104
+ "mean": np.array([5.0, 6.0]),
105
+ "std": np.array([1.0, 1.0]),
106
+ "min": np.array([0.0, 0.0]),
107
+ "max": np.array([10.0, 12.0]),
108
+ "count": np.array([1000]), # Same count for simpler math
109
+ }
110
+ }
111
+
112
+ dataset1 = MockLeRobotDataset(
113
+ repo_id="dataset_with_action_pos",
114
+ fps=20,
115
+ features={"action.pos": {"shape": [2]}},
116
+ num_frames=1000,
117
+ stats=stats1
118
+ )
119
+
120
+ dataset2 = MockLeRobotDataset(
121
+ repo_id="dataset_with_trajectory",
122
+ fps=20,
123
+ features={"trajectory": {"shape": [2]}},
124
+ num_frames=1000,
125
+ stats=stats2
126
+ )
127
+
128
+ # Create wrapped dataset with key rename map
129
+ key_rename_map = {
130
+ "action.pos": "action",
131
+ "trajectory": "action",
132
+ }
133
+
134
+ print(f"\n1. Creating wrapped dataset with key_rename_map: {key_rename_map}")
135
+
136
+ wrapped_dataset = WrappedRobotDataset(
137
+ datasets=[dataset1, dataset2],
138
+ plugins=None,
139
+ key_rename_map=key_rename_map,
140
+ )
141
+
142
+ # Check that features were renamed
143
+ print("\n2. Checking features")
144
+ assert "action" in wrapped_dataset.meta.features, \
145
+ "Renamed 'action' key should be in features"
146
+ assert "action.pos" not in wrapped_dataset.meta.features, \
147
+ "Original 'action.pos' key should not be in features"
148
+ assert "trajectory" not in wrapped_dataset.meta.features, \
149
+ "Original 'trajectory' key should not be in features"
150
+ print(" ✅ Features correctly renamed")
151
+
152
+ # Check that stats were combined
153
+ print("\n3. Checking stats aggregation")
154
+ combined_stats = wrapped_dataset.meta.stats
155
+
156
+ assert "action" in combined_stats, \
157
+ "Combined stats should have 'action' key"
158
+ assert "action.pos" not in combined_stats, \
159
+ "Original 'action.pos' should not be in combined stats"
160
+ assert "trajectory" not in combined_stats, \
161
+ "Original 'trajectory' should not be in combined stats"
162
+
163
+ # With equal counts (1000 each), the combined mean should be the average
164
+ # mean = (1000 * [1.0, 2.0] + 1000 * [5.0, 6.0]) / 2000 = [3.0, 4.0]
165
+ expected_mean = np.array([3.0, 4.0])
166
+
167
+ np.testing.assert_allclose(
168
+ combined_stats["action"]["mean"],
169
+ expected_mean,
170
+ rtol=1e-5,
171
+ err_msg="Combined mean should be average of both datasets' means"
172
+ )
173
+ print(f" Combined mean: {combined_stats['action']['mean']}")
174
+ print(f" Expected mean: {expected_mean}")
175
+ print(" ✅ Stats correctly combined")
176
+
177
+ # Check min/max
178
+ expected_min = np.array([-1.0, -1.0]) # min of both datasets
179
+ expected_max = np.array([10.0, 12.0]) # max of both datasets
180
+
181
+ np.testing.assert_allclose(
182
+ combined_stats["action"]["min"],
183
+ expected_min,
184
+ rtol=1e-5,
185
+ err_msg="Combined min should be minimum across both datasets"
186
+ )
187
+ np.testing.assert_allclose(
188
+ combined_stats["action"]["max"],
189
+ expected_max,
190
+ rtol=1e-5,
191
+ err_msg="Combined max should be maximum across both datasets"
192
+ )
193
+ print(f" Combined min: {combined_stats['action']['min']}, expected: {expected_min}")
194
+ print(f" Combined max: {combined_stats['action']['max']}, expected: {expected_max}")
195
+ print(" ✅ Min/max correctly combined")
196
+
197
+ # Check count
198
+ assert combined_stats["action"]["count"] == 2000, \
199
+ f"Combined count should be 2000, got {combined_stats['action']['count']}"
200
+ print(f" Combined count: {combined_stats['action']['count']}")
201
+ print(" ✅ Count correctly combined")
202
+
203
+ print("\n" + "="*60)
204
+ print("✅ KEY RENAME STATS TEST PASSED!")
205
+ print("="*60 + "\n")
206
+
207
+
208
def test_key_rename_with_different_counts():
    """
    Test key rename with datasets having different sample counts.

    Dataset 1 stores the feature as "pos" (1000 frames) and dataset 2 stores
    it as "position" (3000 frames); ``key_rename_map`` folds both into
    "state". The merged stats must weight each source by its count, take the
    global min/max, and the original keys must not leak into the wrapped
    dataset's features or stats.
    """
    print("\n" + "="*60)
    print("Test: Key Rename Stats with Different Counts")
    print("="*60)

    # Dataset 1: 1000 samples with "pos"
    stats1 = {
        "pos": {
            "mean": np.array([0.0]),
            "std": np.array([1.0]),
            "min": np.array([-3.0]),
            "max": np.array([3.0]),
            "count": np.array([1000]),
        }
    }

    # Dataset 2: 3000 samples with "position"
    stats2 = {
        "position": {
            "mean": np.array([4.0]),
            "std": np.array([2.0]),
            "min": np.array([-2.0]),
            "max": np.array([10.0]),
            "count": np.array([3000]),
        }
    }

    dataset1 = MockLeRobotDataset(
        repo_id="ds1",
        fps=20,
        features={"pos": {"shape": [1]}},
        num_frames=1000,
        stats=stats1,
    )

    dataset2 = MockLeRobotDataset(
        repo_id="ds2",
        fps=20,
        features={"position": {"shape": [1]}},
        num_frames=3000,
        stats=stats2,
    )

    key_rename_map = {
        "pos": "state",
        "position": "state",
    }

    # No placeholders in the first two messages, so plain strings suffice.
    print("\n1. Dataset 1: 1000 samples, mean=0.0")
    print(" Dataset 2: 3000 samples, mean=4.0")
    print(f" key_rename_map: {key_rename_map}")

    wrapped_dataset = WrappedRobotDataset(
        datasets=[dataset1, dataset2],
        plugins=None,
        key_rename_map=key_rename_map,
    )

    combined_stats = wrapped_dataset.meta.stats

    # Both source keys must have been folded into "state" and must not
    # survive under their original names.
    assert "state" in wrapped_dataset.meta.features, \
        "Renamed 'state' key should be in features"
    assert "state" in combined_stats, \
        "Combined stats should have 'state' key"
    for old_key in ("pos", "position"):
        assert old_key not in wrapped_dataset.meta.features, \
            f"Original '{old_key}' key should not be in features"
        assert old_key not in combined_stats, \
            f"Original '{old_key}' key should not be in combined stats"

    # Expected weighted mean: (1000 * 0.0 + 3000 * 4.0) / 4000 = 3.0
    expected_mean = np.array([3.0])

    print("\n2. Combined stats for 'state':")
    print(f" Mean: {combined_stats['state']['mean']} (expected: {expected_mean})")

    np.testing.assert_allclose(
        combined_stats["state"]["mean"],
        expected_mean,
        rtol=1e-5,
        err_msg="Weighted mean should account for different counts"
    )

    # Min/max are independent of counts: global extremes across both sources.
    expected_min = np.array([-3.0])  # min(-3.0, -2.0)
    expected_max = np.array([10.0])  # max(3.0, 10.0)
    np.testing.assert_allclose(
        combined_stats["state"]["min"],
        expected_min,
        rtol=1e-5,
        err_msg="Combined min should be minimum across both datasets"
    )
    np.testing.assert_allclose(
        combined_stats["state"]["max"],
        expected_max,
        rtol=1e-5,
        err_msg="Combined max should be maximum across both datasets"
    )
    print(f" Min: {combined_stats['state']['min']} (expected: {expected_min})")
    print(f" Max: {combined_stats['state']['max']} (expected: {expected_max})")

    assert combined_stats["state"]["count"] == 4000, \
        f"Total count should be 4000, got {combined_stats['state']['count']}"
    print(f" Count: {combined_stats['state']['count']} (expected: 4000)")

    print("\n" + "="*60)
    print("✅ KEY RENAME WITH DIFFERENT COUNTS TEST PASSED!")
    print("="*60 + "\n")
293
+
294
+
295
def test_key_rename_partial_rename():
    """
    Test that keys that don't need renaming are preserved.

    Only keys in key_rename_map should be renamed; others should pass through.
    Dataset 2 stores its state under "state", which the map renames to
    "observation.state" (the name dataset 1 already uses). "action" is not in
    the map, so it must be combined under its original name in both datasets.
    """
    print("\n" + "="*60)
    print("Test: Partial Key Rename")
    print("="*60)

    # Dataset 1 already uses the target names; nothing here should be renamed.
    stats1 = {
        "action": {
            "mean": np.array([1.0]),
            "std": np.array([1.0]),
            "min": np.array([0.0]),
            "max": np.array([2.0]),
            "count": np.array([1000]),
        },
        "observation.state": {
            "mean": np.array([0.5]),
            "std": np.array([0.1]),
            "min": np.array([0.0]),
            "max": np.array([1.0]),
            "count": np.array([1000]),
        }
    }

    # Dataset 2 shares "action" but names its state feature "state".
    stats2 = {
        "action": {
            "mean": np.array([3.0]),
            "std": np.array([1.0]),
            "min": np.array([2.0]),
            "max": np.array([4.0]),
            "count": np.array([1000]),
        },
        "state": {
            "mean": np.array([0.5]),
            "std": np.array([0.1]),
            "min": np.array([0.0]),
            "max": np.array([1.0]),
            "count": np.array([1000]),
        }
    }

    dataset1 = MockLeRobotDataset(
        repo_id="ds1",
        fps=20,
        features={"action": {"shape": [1]}, "observation.state": {"shape": [1]}},
        num_frames=1000,
        stats=stats1,
    )

    dataset2 = MockLeRobotDataset(
        repo_id="ds2",
        fps=20,
        features={"action": {"shape": [1]}, "state": {"shape": [1]}},
        num_frames=1000,
        stats=stats2,
    )

    # Only "state" is mapped; "action" must pass through untouched.
    key_rename_map = {"state": "observation.state"}
    print(f"\n1. Creating wrapped dataset with partial key_rename_map: {key_rename_map}")

    wrapped_dataset = WrappedRobotDataset(
        datasets=[dataset1, dataset2],
        plugins=None,
        key_rename_map=key_rename_map,
    )

    combined_stats = wrapped_dataset.meta.stats

    # "action" is not in the map and should be combined normally
    assert "action" in combined_stats, "action key should be present"
    expected_action_mean = np.array([2.0])  # (1.0 + 3.0) / 2

    np.testing.assert_allclose(
        combined_stats["action"]["mean"],
        expected_action_mean,
        rtol=1e-5,
        err_msg="Unmapped 'action' key should be combined under its own name"
    )
    print(f" action mean: {combined_stats['action']['mean']} (expected: {expected_action_mean})")

    # "observation.state" merges dataset 1's native key with dataset 2's
    # renamed one; the source name "state" must be gone.
    assert "observation.state" in combined_stats, "observation.state should be present"
    assert "state" not in combined_stats, \
        "Original 'state' key should not be in combined stats"
    assert "state" not in wrapped_dataset.meta.features, \
        "Original 'state' key should not be in features"

    expected_state_mean = np.array([0.5])  # both sources have mean 0.5
    np.testing.assert_allclose(
        combined_stats["observation.state"]["mean"],
        expected_state_mean,
        rtol=1e-5,
        err_msg="Renamed key should be merged with the existing key"
    )
    assert combined_stats["observation.state"]["count"] == 2000, \
        f"Combined count should be 2000, got {combined_stats['observation.state']['count']}"
    print(f" observation.state mean: {combined_stats['observation.state']['mean']}")

    print("\n" + "="*60)
    print("✅ PARTIAL KEY RENAME TEST PASSED!")
    print("="*60 + "\n")
385
+
386
+
387
if __name__ == "__main__":
    # Allow running this module directly as a script, outside pytest.
    # The tests run in the same order as the original explicit calls.
    for test_fn in (
        test_key_rename_stats_aggregation,
        test_key_rename_with_different_counts,
        test_key_rename_partial_rename,
    ):
        test_fn()

    separator = "=" * 60
    print("\n" + separator)
    print("ALL KEY RENAME STATS TESTS PASSED!")
    print(separator + "\n")