robocandywrapper 0.2.2__tar.gz → 0.2.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. {robocandywrapper-0.2.2/robocandywrapper.egg-info → robocandywrapper-0.2.3}/PKG-INFO +1 -1
  2. {robocandywrapper-0.2.2 → robocandywrapper-0.2.3}/pyproject.toml +1 -1
  3. {robocandywrapper-0.2.2 → robocandywrapper-0.2.3}/robocandywrapper/__init__.py +1 -1
  4. {robocandywrapper-0.2.2 → robocandywrapper-0.2.3}/robocandywrapper/dataformats/lerobot_21/dataset.py +11 -1
  5. {robocandywrapper-0.2.2 → robocandywrapper-0.2.3}/robocandywrapper/factory.py +20 -0
  6. {robocandywrapper-0.2.2 → robocandywrapper-0.2.3}/robocandywrapper/wrapper.py +118 -17
  7. {robocandywrapper-0.2.2 → robocandywrapper-0.2.3/robocandywrapper.egg-info}/PKG-INFO +1 -1
  8. {robocandywrapper-0.2.2 → robocandywrapper-0.2.3}/setup.py +1 -1
  9. {robocandywrapper-0.2.2 → robocandywrapper-0.2.3}/LICENSE +0 -0
  10. {robocandywrapper-0.2.2 → robocandywrapper-0.2.3}/MANIFEST.in +0 -0
  11. {robocandywrapper-0.2.2 → robocandywrapper-0.2.3}/README.md +0 -0
  12. {robocandywrapper-0.2.2 → robocandywrapper-0.2.3}/robocandywrapper/constants.py +0 -0
  13. {robocandywrapper-0.2.2 → robocandywrapper-0.2.3}/robocandywrapper/dataformats/__init__.py +0 -0
  14. {robocandywrapper-0.2.2 → robocandywrapper-0.2.3}/robocandywrapper/dataformats/lerobot_21/__init__.py +0 -0
  15. {robocandywrapper-0.2.2 → robocandywrapper-0.2.3}/robocandywrapper/dataformats/lerobot_21/utils.py +0 -0
  16. {robocandywrapper-0.2.2 → robocandywrapper-0.2.3}/robocandywrapper/metadata_view.py +0 -0
  17. {robocandywrapper-0.2.2 → robocandywrapper-0.2.3}/robocandywrapper/plugin.py +0 -0
  18. {robocandywrapper-0.2.2 → robocandywrapper-0.2.3}/robocandywrapper/plugins/__init__.py +0 -0
  19. {robocandywrapper-0.2.2 → robocandywrapper-0.2.3}/robocandywrapper/plugins/affordance.py +0 -0
  20. {robocandywrapper-0.2.2 → robocandywrapper-0.2.3}/robocandywrapper/plugins/episode_outcome.py +0 -0
  21. {robocandywrapper-0.2.2 → robocandywrapper-0.2.3}/robocandywrapper/samplers/__init__.py +0 -0
  22. {robocandywrapper-0.2.2 → robocandywrapper-0.2.3}/robocandywrapper/samplers/config.py +0 -0
  23. {robocandywrapper-0.2.2 → robocandywrapper-0.2.3}/robocandywrapper/samplers/factory.py +0 -0
  24. {robocandywrapper-0.2.2 → robocandywrapper-0.2.3}/robocandywrapper/samplers/weighted.py +0 -0
  25. {robocandywrapper-0.2.2 → robocandywrapper-0.2.3}/robocandywrapper/utils.py +0 -0
  26. {robocandywrapper-0.2.2 → robocandywrapper-0.2.3}/robocandywrapper.egg-info/SOURCES.txt +0 -0
  27. {robocandywrapper-0.2.2 → robocandywrapper-0.2.3}/robocandywrapper.egg-info/dependency_links.txt +0 -0
  28. {robocandywrapper-0.2.2 → robocandywrapper-0.2.3}/robocandywrapper.egg-info/requires.txt +0 -0
  29. {robocandywrapper-0.2.2 → robocandywrapper-0.2.3}/robocandywrapper.egg-info/top_level.txt +0 -0
  30. {robocandywrapper-0.2.2 → robocandywrapper-0.2.3}/setup.cfg +0 -0
  31. {robocandywrapper-0.2.2 → robocandywrapper-0.2.3}/tests/test_dataset_weights_integration.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: robocandywrapper
3
- Version: 0.2.2
3
+ Version: 0.2.3
4
4
  Summary: Sweet wrappers for extending and remixing LeRobot Datasets
5
5
  Author: RoboCandyWrapper Contributors
6
6
  License: MIT License
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "robocandywrapper"
7
- version = "0.2.2"
7
+ version = "0.2.3"
8
8
  description = "Sweet wrappers for extending and remixing LeRobot Datasets"
9
9
  readme = "README.md"
10
10
  requires-python = ">=3.10"
@@ -17,7 +17,7 @@ from robocandywrapper.constants import (
17
17
  EPISODE_OUTCOME_PLUGIN_NAME,
18
18
  )
19
19
 
20
- __version__ = "0.2.1"
20
+ __version__ = "0.2.3"
21
21
 
22
22
  __all__ = [
23
23
  "DatasetPlugin",
@@ -93,6 +93,7 @@ class LeRobot21DatasetMetadata:
93
93
  try:
94
94
  if force_cache_sync:
95
95
  raise FileNotFoundError
96
+ self.pull_from_repo()
96
97
  self.load_metadata()
97
98
  except (FileNotFoundError, NotADirectoryError):
98
99
  if is_valid_version(self.revision):
@@ -728,7 +729,11 @@ class LeRobot21Dataset(torch.utils.data.Dataset):
728
729
  item = {}
729
730
  for vid_key, query_ts in query_timestamps.items():
730
731
  video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
731
- frames = decode_video_frames(video_path, query_ts, self.tolerance_s, self.video_backend)
732
+ try:
733
+ frames = decode_video_frames(video_path, query_ts, self.tolerance_s, self.video_backend)
734
+ except Exception as e:
735
+ # fall back to trying to decode with pyav
736
+ frames = decode_video_frames(video_path, query_ts, self.tolerance_s, "pyav")
732
737
  item[vid_key] = frames.squeeze(0)
733
738
 
734
739
  return item
@@ -768,6 +773,11 @@ class LeRobot21Dataset(torch.utils.data.Dataset):
768
773
  task_idx = item["task_index"].item()
769
774
  item["task"] = self.meta.tasks[task_idx]
770
775
 
776
+ # Hack - add gripper position to end
777
+ # only applies to a specific dataset
778
+ # if "observation.eef_6d_pose" in item and item["observation.eef_6d_pose"].shape[0] == 6:
779
+ # item["observation.eef_6d_pose"] = torch.cat([item["observation.eef_6d_pose"], item["observation.state"][-1:]], dim=0)
780
+
771
781
  return item
772
782
 
773
783
  def __repr__(self):
@@ -76,6 +76,7 @@ def _create_datasets(
76
76
  observation_delta_indices: Optional[List] = None,
77
77
  reward_delta_indices: Optional[List] = None,
78
78
  use_imagenet_stats: bool = True,
79
+ load_videos: bool = True,
79
80
  ) -> List[LeRobotDataset | LeRobot21Dataset]:
80
81
  """Private helper to create dataset instances from a list of repo IDs.
81
82
 
@@ -92,6 +93,8 @@ def _create_datasets(
92
93
  observation_delta_indices: Frame indices for observations.
93
94
  reward_delta_indices: Frame indices for rewards.
94
95
  use_imagenet_stats: Whether to apply ImageNet normalization stats.
96
+ load_videos: Whether to download and load video files (default: True).
97
+ Set to False to skip video downloads when not needed.
95
98
 
96
99
  Returns:
97
100
  List of dataset instances.
@@ -170,6 +173,7 @@ def _create_datasets(
170
173
  image_transforms=None, # Will be applied by WrappedRobotDataset
171
174
  revision=revision,
172
175
  video_backend=video_backend,
176
+ download_videos=load_videos,
173
177
  )
174
178
 
175
179
  # Apply ImageNet stats if needed
@@ -189,12 +193,18 @@ def _create_datasets(
189
193
  def make_dataset(
190
194
  cfg: TrainPipelineConfig,
191
195
  plugins: Optional[list[DatasetPlugin]] = None,
196
+ key_rename_map: Optional[dict[str, str]] = None,
197
+ load_videos: bool = True,
192
198
  ) -> WrappedRobotDataset:
193
199
  """Handles the logic of setting up delta timestamps and image transforms before creating a dataset.
194
200
 
195
201
  Args:
196
202
  cfg (TrainPipelineConfig): A TrainPipelineConfig config which contains a DatasetConfig and a PreTrainedConfig.
197
203
  plugins (Optional[list[DatasetPlugin]]): Optional list of plugins to attach to the dataset(s).
204
+ key_rename_map (Optional[dict[str, str]]): Optional mapping from source keys to target keys
205
+ for unifying datasets with different naming conventions. Example: {"action.pos": "action"}
206
+ load_videos (bool): Whether to download and load video files (default: True).
207
+ Set to False to skip video downloads when not needed.
198
208
 
199
209
  Returns:
200
210
  WrappedRobotDataset: A wrapped dataset with plugin support.
@@ -221,6 +231,7 @@ def make_dataset(
221
231
  observation_delta_indices=cfg.policy.observation_delta_indices,
222
232
  reward_delta_indices=cfg.policy.reward_delta_indices,
223
233
  use_imagenet_stats=cfg.dataset.use_imagenet_stats,
234
+ load_videos=load_videos,
224
235
  )
225
236
 
226
237
  # Wrap in WrappedRobotDataset with plugins
@@ -228,6 +239,7 @@ def make_dataset(
228
239
  datasets=datasets,
229
240
  plugins=plugins,
230
241
  image_transforms=image_transforms,
242
+ key_rename_map=key_rename_map,
231
243
  )
232
244
 
233
245
  return wrapped_dataset
@@ -243,6 +255,8 @@ def make_dataset_without_config(
243
255
  revision: str | None = None,
244
256
  use_imagenet_stats: bool = True,
245
257
  plugins: Optional[list[DatasetPlugin]] = None,
258
+ key_rename_map: Optional[dict[str, str]] = None,
259
+ load_videos: bool = True,
246
260
  ) -> WrappedRobotDataset:
247
261
  """Handles the logic of setting up delta timestamps and image transforms before creating a dataset.
248
262
 
@@ -259,6 +273,10 @@ def make_dataset_without_config(
259
273
  revision (str, optional): Dataset revision
260
274
  use_imagenet_stats (bool): Whether to use ImageNet normalization stats (default: True)
261
275
  plugins (Optional[list[DatasetPlugin]]): Optional list of plugins to attach to the dataset(s)
276
+ key_rename_map (Optional[dict[str, str]]): Optional mapping from source keys to target keys
277
+ for unifying datasets with different naming conventions. Example: {"action.pos": "action"}
278
+ load_videos (bool): Whether to download and load video files (default: True).
279
+ Set to False to skip video downloads when not needed.
262
280
 
263
281
  Returns:
264
282
  WrappedRobotDataset: A wrapped dataset with plugin support.
@@ -283,12 +301,14 @@ def make_dataset_without_config(
283
301
  action_delta_indices=action_delta_indices,
284
302
  observation_delta_indices=observation_delta_indices,
285
303
  use_imagenet_stats=use_imagenet_stats,
304
+ load_videos=load_videos,
286
305
  )
287
306
 
288
307
  # Wrap in WrappedRobotDataset with plugins
289
308
  wrapped_dataset = WrappedRobotDataset(
290
309
  datasets=datasets,
291
310
  plugins=plugins,
311
+ key_rename_map=key_rename_map,
292
312
  )
293
313
 
294
314
  return wrapped_dataset
@@ -23,6 +23,7 @@ class WrappedRobotDataset(torch.utils.data.Dataset):
23
23
  warn_on_key_conflicts: bool = True,
24
24
  error_on_key_conflicts: bool = True,
25
25
  dataset_weights: Optional[dict[str, float]] = None,
26
+ key_rename_map: Optional[dict[str, str]] = None,
26
27
  **kwargs
27
28
  ):
28
29
  """
@@ -35,6 +36,14 @@ class WrappedRobotDataset(torch.utils.data.Dataset):
35
36
  warn_on_key_conflicts: Warn when plugins have overlapping keys (if not raising errors)
36
37
  error_on_key_conflicts: Raise error on key conflicts (default: True)
37
38
  dataset_weights: Optional weights for computing weighted stats (e.g., {"dataset_id": 2.0})
39
+ key_rename_map: Optional mapping from source keys to target keys for unifying
40
+ datasets with different naming conventions. Keys are renamed before the
41
+ intersection logic runs, allowing datasets with different key names to be
42
+ combined. Example: {"action.pos": "action", "trajectory": "action"}
43
+
44
+ Note: When a key is renamed, any corresponding "_is_pad" key (added by
45
+ LeRobot when using delta_timestamps) is automatically renamed as well.
46
+ E.g., "action.pos" -> "action" also renames "action.pos_is_pad" -> "action_is_pad".
38
47
  """
39
48
  super().__init__()
40
49
 
@@ -64,6 +73,10 @@ class WrappedRobotDataset(torch.utils.data.Dataset):
64
73
  self._cumulative_lengths.append(self._cumulative_lengths[-1] + length)
65
74
  self._total_length = self._cumulative_lengths[-1]
66
75
 
76
+ # Key rename mapping: unify differently-named keys across datasets
77
+ self.key_rename_map = key_rename_map or {}
78
+ self._dataset_renames = self._compute_dataset_renames()
79
+
67
80
  # Plugin management: one plugin class, many instances (one per dataset)
68
81
  self._plugins: list[DatasetPlugin] = plugins or []
69
82
  self._plugin_instances: list[list[PluginInstance]] = []
@@ -91,41 +104,51 @@ class WrappedRobotDataset(torch.utils.data.Dataset):
91
104
  # Disable any data keys that are not common across all of the datasets. Note: we may relax this
92
105
  # restriction in future iterations of this class. For now, this is necessary at least for being able
93
106
  # to use PyTorch's default DataLoader collate function.
107
+ #
108
+ # Key rename mapping is applied first (conceptually), so intersection is computed on
109
+ # "effective" features (post-rename). This allows datasets with different key names to be
110
+ # unified before the intersection check.
94
111
  self.disabled_features = set()
95
- intersection_features = set(self._datasets[0].features)
96
- for ds in self._datasets:
97
- intersection_features.intersection_update(ds.features)
112
+ intersection_features = self._get_effective_features(0)
113
+ for i in range(len(self._datasets)):
114
+ intersection_features.intersection_update(self._get_effective_features(i))
98
115
  if len(intersection_features) == 0:
99
116
  raise RuntimeError(
100
117
  "Multiple datasets were provided but they had no keys common to all of them. "
101
118
  "The multi-dataset functionality currently only keeps common keys."
102
119
  )
103
- for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True):
104
- extra_keys = set(ds.features).difference(intersection_features)
120
+ for i, repo_id in enumerate(self.repo_ids):
121
+ effective_keys = self._get_effective_features(i)
122
+ extra_keys = effective_keys.difference(intersection_features)
105
123
  if extra_keys:
106
124
  logging.warning(
107
- f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
108
- "other datasets."
125
+ f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
126
+ "other datasets."
109
127
  )
110
128
  self.disabled_features.update(extra_keys)
111
129
 
112
130
  # Validate that common features have compatible shapes
131
+ # Note: We need to look up the original key name for renamed keys
113
132
  for key in intersection_features:
114
133
  shapes = []
115
- for ds in self._datasets:
116
- if key in ds.meta.features:
117
- feature_shape = ds.meta.features[key].get('shape', [])
134
+ shape_details = []
135
+ for i, ds in enumerate(self._datasets):
136
+ # Find the original key (may be renamed)
137
+ renames = self._dataset_renames[i]
138
+ reverse_renames = {v: k for k, v in renames.items()}
139
+ original_key = reverse_renames.get(key, key)
140
+
141
+ if original_key in ds.meta.features:
142
+ feature_shape = ds.meta.features[original_key].get('shape', [])
118
143
  shapes.append(tuple(feature_shape))
144
+ if original_key != key:
145
+ shape_details.append(f"{ds.repo_id}: {feature_shape} (from '{original_key}')")
146
+ else:
147
+ shape_details.append(f"{ds.repo_id}: {feature_shape}")
119
148
 
120
149
  # Check if all shapes are the same
121
150
  unique_shapes = set(shapes)
122
151
  if len(unique_shapes) > 1:
123
- shape_details = []
124
- for ds in self._datasets:
125
- if key in ds.meta.features:
126
- shape = ds.meta.features[key].get('shape', [])
127
- shape_details.append(f"{ds.repo_id}: {shape}")
128
-
129
152
  raise ValueError(
130
153
  f"Incompatible shapes for feature '{key}' across datasets:\n" +
131
154
  "\n".join(f" - {detail}" for detail in shape_details) +
@@ -296,6 +319,77 @@ class WrappedRobotDataset(torch.utils.data.Dataset):
296
319
  # Also update the cached stats property
297
320
  self.stats = self._meta.stats
298
321
 
322
+ def _compute_dataset_renames(self) -> list[dict[str, str]]:
323
+ """
324
+ Pre-compute which key renames apply to each dataset.
325
+
326
+ For each dataset, determines which source keys from key_rename_map exist
327
+ and can be renamed (i.e., target key doesn't already exist).
328
+
329
+ Also automatically handles derived _is_pad keys that LeRobot adds when
330
+ delta_timestamps are used. For example, if renaming "action.pos" -> "action",
331
+ this will also rename "action.pos_is_pad" -> "action_is_pad".
332
+
333
+ Returns:
334
+ List of dicts mapping source_key -> target_key for each dataset
335
+ """
336
+ dataset_renames = []
337
+ for dataset in self._datasets:
338
+ ds_renames = {}
339
+ ds_keys = set(dataset.features)
340
+
341
+ for source, target in self.key_rename_map.items():
342
+ if source in ds_keys:
343
+ if target in ds_keys:
344
+ # Target already exists in this dataset - skip rename to avoid conflict
345
+ logging.warning(
346
+ f"Skipping rename '{source}' -> '{target}' for {dataset.repo_id}: "
347
+ f"target key already exists in dataset"
348
+ )
349
+ else:
350
+ ds_renames[source] = target
351
+
352
+ # Also handle the _is_pad suffix that LeRobot adds for delta_timestamps
353
+ # These keys are dynamically added during __getitem__ and may not be in
354
+ # dataset.features, but we still want to rename them consistently
355
+ is_pad_source = f"{source}_is_pad"
356
+ is_pad_target = f"{target}_is_pad"
357
+
358
+ # Check for conflicts on the _is_pad key as well
359
+ if is_pad_target in ds_keys:
360
+ logging.warning(
361
+ f"Skipping derived rename '{is_pad_source}' -> '{is_pad_target}' "
362
+ f"for {dataset.repo_id}: target key already exists in dataset"
363
+ )
364
+ else:
365
+ ds_renames[is_pad_source] = is_pad_target
366
+
367
+ dataset_renames.append(ds_renames)
368
+
369
+ return dataset_renames
370
+
371
+ def _get_effective_features(self, dataset_idx: int) -> set[str]:
372
+ """
373
+ Get the effective feature keys for a dataset after applying renames.
374
+
375
+ Args:
376
+ dataset_idx: Index of the dataset
377
+
378
+ Returns:
379
+ Set of feature keys that would exist after renaming
380
+ """
381
+ ds = self._datasets[dataset_idx]
382
+ renames = self._dataset_renames[dataset_idx]
383
+
384
+ effective = set()
385
+ for key in ds.features:
386
+ if key in renames:
387
+ effective.add(renames[key])
388
+ else:
389
+ effective.add(key)
390
+
391
+ return effective
392
+
299
393
  def _validate_plugin_keys(self):
300
394
  """
301
395
  Check for key conflicts between plugins.
@@ -483,7 +577,14 @@ class WrappedRobotDataset(torch.utils.data.Dataset):
483
577
  # Add dataset index
484
578
  item["dataset_index"] = torch.tensor(dataset_idx)
485
579
 
486
- # Remove disabled features
580
+ # Apply key renaming for this dataset (before filtering disabled features)
581
+ # This unifies differently-named keys across datasets
582
+ renames = self._dataset_renames[dataset_idx]
583
+ for source, target in renames.items():
584
+ if source in item:
585
+ item[target] = item.pop(source)
586
+
587
+ # Remove disabled features (now operates on effective/renamed key names)
487
588
  for data_key in self.disabled_features:
488
589
  if data_key in item:
489
590
  del item[data_key]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: robocandywrapper
3
- Version: 0.2.2
3
+ Version: 0.2.3
4
4
  Summary: Sweet wrappers for extending and remixing LeRobot Datasets
5
5
  Author: RoboCandyWrapper Contributors
6
6
  License: MIT License
@@ -9,7 +9,7 @@ long_description = readme_file.read_text(encoding="utf-8") if readme_file.exists
9
9
 
10
10
  setup(
11
11
  name="robocandywrapper",
12
- version="0.2.1",
12
+ version="0.2.3",
13
13
  description="Sweet wrappers for extending and remixing LeRobot Datasets",
14
14
  long_description=long_description,
15
15
  long_description_content_type="text/markdown",