robocandywrapper 0.2.8__tar.gz → 0.2.10__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35):
  1. {robocandywrapper-0.2.8/robocandywrapper.egg-info → robocandywrapper-0.2.10}/PKG-INFO +1 -1
  2. {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/pyproject.toml +1 -1
  3. {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/robocandywrapper/__init__.py +1 -1
  4. {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/robocandywrapper/factory.py +4 -1
  5. {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/robocandywrapper/metadata_view.py +19 -5
  6. {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/robocandywrapper/wrapper.py +126 -37
  7. {robocandywrapper-0.2.8 → robocandywrapper-0.2.10/robocandywrapper.egg-info}/PKG-INFO +1 -1
  8. {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/setup.py +1 -1
  9. {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/LICENSE +0 -0
  10. {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/MANIFEST.in +0 -0
  11. {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/README.md +0 -0
  12. {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/robocandywrapper/constants.py +0 -0
  13. {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/robocandywrapper/dataformats/__init__.py +0 -0
  14. {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/robocandywrapper/dataformats/lerobot_21/__init__.py +0 -0
  15. {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/robocandywrapper/dataformats/lerobot_21/convert_v20_to_v21.py +0 -0
  16. {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/robocandywrapper/dataformats/lerobot_21/dataset.py +0 -0
  17. {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/robocandywrapper/dataformats/lerobot_21/utils.py +0 -0
  18. {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/robocandywrapper/plugin.py +0 -0
  19. {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/robocandywrapper/plugins/__init__.py +0 -0
  20. {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/robocandywrapper/plugins/affordance.py +0 -0
  21. {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/robocandywrapper/plugins/control_mode.py +0 -0
  22. {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/robocandywrapper/plugins/episode_outcome.py +0 -0
  23. {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/robocandywrapper/plugins/subtask.py +0 -0
  24. {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/robocandywrapper/samplers/__init__.py +0 -0
  25. {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/robocandywrapper/samplers/config.py +0 -0
  26. {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/robocandywrapper/samplers/factory.py +0 -0
  27. {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/robocandywrapper/samplers/weighted.py +0 -0
  28. {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/robocandywrapper/utils.py +0 -0
  29. {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/robocandywrapper.egg-info/SOURCES.txt +0 -0
  30. {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/robocandywrapper.egg-info/dependency_links.txt +0 -0
  31. {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/robocandywrapper.egg-info/requires.txt +0 -0
  32. {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/robocandywrapper.egg-info/top_level.txt +0 -0
  33. {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/setup.cfg +0 -0
  34. {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/tests/test_dataset_weights_integration.py +0 -0
  35. {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/tests/test_key_rename_stats.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: robocandywrapper
3
- Version: 0.2.8
3
+ Version: 0.2.10
4
4
  Summary: Sweet wrappers for extending and remixing LeRobot Datasets
5
5
  Author: RoboCandyWrapper Contributors
6
6
  License: MIT License
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "robocandywrapper"
7
- version = "0.2.8"
7
+ version = "0.2.10"
8
8
  description = "Sweet wrappers for extending and remixing LeRobot Datasets"
9
9
  readme = "README.md"
10
10
  requires-python = ">=3.10"
@@ -17,7 +17,7 @@ from robocandywrapper.constants import (
17
17
  EPISODE_OUTCOME_PLUGIN_NAME,
18
18
  )
19
19
 
20
- __version__ = "0.2.8"
20
+ __version__ = "0.2.10"
21
21
 
22
22
  __all__ = [
23
23
  "DatasetPlugin",
@@ -257,6 +257,8 @@ def make_dataset_without_config(
257
257
  plugins: Optional[list[DatasetPlugin]] = None,
258
258
  key_rename_map: Optional[dict[str, str]] = None,
259
259
  load_videos: bool = True,
260
+ pad_to_max_dim: bool = False,
261
+ fill_missing_images: str = "disable",
260
262
  ) -> WrappedRobotDataset:
261
263
  """Handles the logic of setting up delta timestamps and image transforms before creating a dataset.
262
264
 
@@ -304,11 +306,12 @@ def make_dataset_without_config(
304
306
  load_videos=load_videos,
305
307
  )
306
308
 
307
- # Wrap in WrappedRobotDataset with plugins
308
309
  wrapped_dataset = WrappedRobotDataset(
309
310
  datasets=datasets,
310
311
  plugins=plugins,
311
312
  key_rename_map=key_rename_map,
313
+ pad_to_max_dim=pad_to_max_dim,
314
+ fill_missing_images=fill_missing_images,
312
315
  )
313
316
 
314
317
  return wrapped_dataset
@@ -79,11 +79,25 @@ def aggregate_stats_weighted(
79
79
  if not stats_with_key:
80
80
  continue
81
81
 
82
- # Extract arrays
83
- means = np.stack([np.array(s["mean"]) for s in stats_with_key])
84
- stds = np.stack([np.array(s["std"]) for s in stats_with_key])
85
- mins = np.stack([np.array(s["min"]) for s in stats_with_key])
86
- maxs = np.stack([np.array(s["max"]) for s in stats_with_key])
82
+ # Extract arrays, padding to max dim if shapes differ
83
+ def _to_arrays(stat_key, pad_value=0.0):
84
+ arrs = [np.array(s[stat_key]) for s in stats_with_key]
85
+ if any(a.shape != arrs[0].shape for a in arrs):
86
+ max_dim = max(a.shape[-1] for a in arrs if a.ndim > 0)
87
+ padded = []
88
+ for a in arrs:
89
+ if a.ndim > 0 and a.shape[-1] < max_dim:
90
+ pad_width = [(0, 0)] * (a.ndim - 1) + [(0, max_dim - a.shape[-1])]
91
+ padded.append(np.pad(a, pad_width, constant_values=pad_value))
92
+ else:
93
+ padded.append(a)
94
+ return np.stack(padded)
95
+ return np.stack(arrs)
96
+
97
+ means = _to_arrays("mean", pad_value=0.0)
98
+ stds = _to_arrays("std", pad_value=1.0)
99
+ mins = _to_arrays("min", pad_value=0.0)
100
+ maxs = _to_arrays("max", pad_value=0.0)
87
101
 
88
102
  # Get counts and apply weight multipliers
89
103
  # Extract scalar count value (handle both scalar and array counts)
@@ -24,6 +24,8 @@ class WrappedRobotDataset(torch.utils.data.Dataset):
24
24
  error_on_key_conflicts: bool = True,
25
25
  dataset_weights: Optional[dict[str, float]] = None,
26
26
  key_rename_map: Optional[dict[str, str]] = None,
27
+ pad_to_max_dim: bool = False,
28
+ fill_missing_images: str = "disable",
27
29
  **kwargs
28
30
  ):
29
31
  """
@@ -44,6 +46,17 @@ class WrappedRobotDataset(torch.utils.data.Dataset):
44
46
  Note: When a key is renamed, any corresponding "_is_pad" key (added by
45
47
  LeRobot when using delta_timestamps) is automatically renamed as well.
46
48
  E.g., "action.pos" -> "action" also renames "action.pos_is_pad" -> "action_is_pad".
49
+ pad_to_max_dim: If True, features with different shapes across datasets
50
+ (e.g. 7-dim vs 14-dim actions) are zero-padded to the max dim instead
51
+ of raising an error. Adds ``action_dim_mask`` (bool tensor, True for
52
+ real dims) to each item so downstream loss functions can ignore the
53
+ padded dimensions.
54
+ fill_missing_images: How to handle image keys not present in all datasets.
55
+ - "disable" (default): remove the key entirely (original behaviour)
56
+ - "zeros": fill with a zero tensor of the same shape as other datasets
57
+ - "noise": fill with random noise (uniform [0, 255] uint8)
58
+ When set to "zeros" or "noise", an ``image_mask`` dict entry is also
59
+ added (True = real image, False = filled placeholder).
47
60
  """
48
61
  super().__init__()
49
62
 
@@ -52,6 +65,8 @@ class WrappedRobotDataset(torch.utils.data.Dataset):
52
65
  self.image_transforms = image_transforms
53
66
  self.warn_on_key_conflicts = warn_on_key_conflicts
54
67
  self.error_on_key_conflicts = error_on_key_conflicts
68
+ self.pad_to_max_dim = pad_to_max_dim
69
+ self.fill_missing_images = fill_missing_images
55
70
 
56
71
  # Calculate dataset boundaries for flat index space
57
72
  self._dataset_lengths = []
@@ -102,60 +117,93 @@ class WrappedRobotDataset(torch.utils.data.Dataset):
102
117
 
103
118
  # ** MATCHING LeRobot MULTI-DATASET API DESIGN **
104
119
 
105
- # Disable any data keys that are not common across all of the datasets. Note: we may relax this
106
- # restriction in future iterations of this class. For now, this is necessary at least for being able
107
- # to use PyTorch's default DataLoader collate function.
108
- #
109
- # Key rename mapping is applied first (conceptually), so intersection is computed on
110
- # "effective" features (post-rename). This allows datasets with different key names to be
111
- # unified before the intersection check.
112
- self.disabled_features = set()
113
- intersection_features = self._get_effective_features(0)
120
+ # Compute feature intersection across datasets (post-rename).
121
+ # Image keys that are missing from some datasets can optionally be
122
+ # filled with zeros/noise instead of being disabled.
123
+ union_features = set()
124
+ all_ds_features = []
114
125
  for i in range(len(self._datasets)):
115
- intersection_features.intersection_update(self._get_effective_features(i))
126
+ ef = self._get_effective_features(i)
127
+ all_ds_features.append(ef)
128
+ union_features.update(ef)
129
+
130
+ intersection_features = set.intersection(*all_ds_features) if all_ds_features else set()
131
+
116
132
  if len(intersection_features) == 0:
117
133
  raise RuntimeError(
118
134
  "Multiple datasets were provided but they had no keys common to all of them. "
119
135
  "The multi-dataset functionality currently only keeps common keys."
120
136
  )
137
+
138
+ # Determine which non-common features to disable vs fill
139
+ self.disabled_features = set()
140
+ self._filled_image_keys: set[str] = set() # image keys that need filling per-item
141
+
121
142
  for i, repo_id in enumerate(self.repo_ids):
122
- effective_keys = self._get_effective_features(i)
123
- extra_keys = effective_keys.difference(intersection_features)
124
- if extra_keys:
125
- logging.warning(
126
- f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
127
- "other datasets."
128
- )
129
- self.disabled_features.update(extra_keys)
143
+ extra_keys = all_ds_features[i].difference(intersection_features)
144
+ for key in extra_keys:
145
+ is_image = "image" in key.lower() or "cam" in key.lower()
146
+ if is_image and self.fill_missing_images != "disable":
147
+ self._filled_image_keys.add(key)
148
+ else:
149
+ if key not in self._filled_image_keys:
150
+ self.disabled_features.add(key)
151
+
152
+ # Promote filled image keys into the effective feature set
153
+ for key in self._filled_image_keys:
154
+ if key in self.disabled_features:
155
+ self.disabled_features.discard(key)
156
+
157
+ if self.disabled_features:
158
+ logging.warning(
159
+ f"Non-common features disabled: {self.disabled_features} "
160
+ f"(not in all datasets and not eligible for filling)"
161
+ )
130
162
 
131
- # Validate that common features have compatible shapes
132
- # Note: We need to look up the original key name for renamed keys
133
- for key in intersection_features:
163
+ active_features = intersection_features | self._filled_image_keys
164
+
165
+ # Validate shapes and compute padding info for active features
166
+ self._feature_max_dims: dict[str, int] = {} # key -> max last-dim across datasets
167
+ self._per_dataset_dims: dict[str, dict[int, int]] = {} # key -> {ds_idx: dim}
168
+
169
+ for key in active_features:
134
170
  shapes = []
135
- shape_details = []
171
+ per_ds = {}
136
172
  for i, ds in enumerate(self._datasets):
137
- # Find the original key (may be renamed)
138
173
  renames = self._dataset_renames[i]
139
174
  reverse_renames = {v: k for k, v in renames.items()}
140
175
  original_key = reverse_renames.get(key, key)
141
-
142
176
  if original_key in ds.meta.features:
143
177
  feature_shape = ds.meta.features[original_key].get('shape', [])
144
178
  shapes.append(tuple(feature_shape))
145
- if original_key != key:
146
- shape_details.append(f"{ds.repo_id}: {feature_shape} (from '{original_key}')")
147
- else:
148
- shape_details.append(f"{ds.repo_id}: {feature_shape}")
149
-
150
- # Check if all shapes are the same
179
+ if feature_shape:
180
+ per_ds[i] = feature_shape[-1]
181
+
151
182
  unique_shapes = set(shapes)
152
183
  if len(unique_shapes) > 1:
153
- raise ValueError(
154
- f"Incompatible shapes for feature '{key}' across datasets:\n" +
155
- "\n".join(f" - {detail}" for detail in shape_details) +
156
- f"\n\nCannot mix datasets with different {key} dimensions. "
157
- f"This typically happens when mixing datasets from different robot configurations."
158
- )
184
+ if self.pad_to_max_dim:
185
+ max_dim = max(s[-1] for s in unique_shapes if s)
186
+ self._feature_max_dims[key] = max_dim
187
+ self._per_dataset_dims[key] = per_ds
188
+ logging.info(
189
+ f"Feature '{key}' has mixed dims {unique_shapes}; "
190
+ f"padding to {max_dim} (pad_to_max_dim=True)"
191
+ )
192
+ else:
193
+ shape_details = []
194
+ for i, ds in enumerate(self._datasets):
195
+ renames = self._dataset_renames[i]
196
+ reverse_renames = {v: k for k, v in renames.items()}
197
+ original_key = reverse_renames.get(key, key)
198
+ if original_key in ds.meta.features:
199
+ feature_shape = ds.meta.features[original_key].get('shape', [])
200
+ shape_details.append(f"{ds.repo_id}: {feature_shape}")
201
+ raise ValueError(
202
+ f"Incompatible shapes for feature '{key}' across datasets:\n" +
203
+ "\n".join(f" - {detail}" for detail in shape_details) +
204
+ f"\n\nCannot mix datasets with different {key} dimensions. "
205
+ f"Use pad_to_max_dim=True to zero-pad smaller dims."
206
+ )
159
207
 
160
208
  # Keep backward compatible stats property
161
209
  self.stats = self._meta.stats
@@ -582,7 +630,6 @@ class WrappedRobotDataset(torch.utils.data.Dataset):
582
630
  item["dataset_index"] = torch.tensor(dataset_idx)
583
631
 
584
632
  # Apply key renaming for this dataset (before filtering disabled features)
585
- # This unifies differently-named keys across datasets
586
633
  renames = self._dataset_renames[dataset_idx]
587
634
  for source, target in renames.items():
588
635
  if source in item:
@@ -592,6 +639,48 @@ class WrappedRobotDataset(torch.utils.data.Dataset):
592
639
  for data_key in self.disabled_features:
593
640
  if data_key in item:
594
641
  del item[data_key]
642
+
643
+ # Fill missing image keys with zeros/noise
644
+ if self._filled_image_keys:
645
+ # Find a reference image shape from an existing image key in this item
646
+ ref_shape = None
647
+ for k, v in item.items():
648
+ if hasattr(v, 'shape') and len(getattr(v, 'shape', ())) >= 3:
649
+ ref_shape = v.shape
650
+ ref_dtype = v.dtype if hasattr(v, 'dtype') else torch.uint8
651
+ break
652
+ for key in self._filled_image_keys:
653
+ if key not in item and ref_shape is not None:
654
+ if self.fill_missing_images == "noise":
655
+ item[key] = torch.randint(0, 256, ref_shape, dtype=torch.uint8)
656
+ else:
657
+ item[key] = torch.zeros(ref_shape, dtype=ref_dtype)
658
+
659
+ # Pad features with mismatched dims to the max dim across datasets
660
+ if self._feature_max_dims:
661
+ import numpy as np
662
+ for key, max_dim in self._feature_max_dims.items():
663
+ if key not in item:
664
+ continue
665
+ val = item[key]
666
+ if not hasattr(val, 'shape'):
667
+ continue
668
+ current_dim = val.shape[-1]
669
+ if current_dim < max_dim:
670
+ pad_size = max_dim - current_dim
671
+ if isinstance(val, torch.Tensor):
672
+ pad = torch.zeros(*val.shape[:-1], pad_size, dtype=val.dtype)
673
+ item[key] = torch.cat([val, pad], dim=-1)
674
+ else:
675
+ pad_widths = [(0, 0)] * (len(val.shape) - 1) + [(0, pad_size)]
676
+ item[key] = np.pad(val, pad_widths, constant_values=0)
677
+
678
+ # Add a dimension mask for action/state features
679
+ if "action" in key.lower() or "state" in key.lower():
680
+ mask_key = f"{key}_dim_mask"
681
+ mask = torch.zeros(max_dim, dtype=torch.bool)
682
+ mask[:current_dim] = True
683
+ item[mask_key] = mask
595
684
 
596
685
  # Execute plugins sequentially, passing accumulated data
597
686
  for plugin_instance in self._plugin_instances[dataset_idx]:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: robocandywrapper
3
- Version: 0.2.8
3
+ Version: 0.2.10
4
4
  Summary: Sweet wrappers for extending and remixing LeRobot Datasets
5
5
  Author: RoboCandyWrapper Contributors
6
6
  License: MIT License
@@ -9,7 +9,7 @@ long_description = readme_file.read_text(encoding="utf-8") if readme_file.exists
9
9
 
10
10
  setup(
11
11
  name="robocandywrapper",
12
- version="0.2.8",
12
+ version="0.2.10",
13
13
  description="Sweet wrappers for extending and remixing LeRobot Datasets",
14
14
  long_description=long_description,
15
15
  long_description_content_type="text/markdown",