robocandywrapper 0.2.8__tar.gz → 0.2.10__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {robocandywrapper-0.2.8/robocandywrapper.egg-info → robocandywrapper-0.2.10}/PKG-INFO +1 -1
- {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/pyproject.toml +1 -1
- {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/robocandywrapper/__init__.py +1 -1
- {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/robocandywrapper/factory.py +4 -1
- {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/robocandywrapper/metadata_view.py +19 -5
- {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/robocandywrapper/wrapper.py +126 -37
- {robocandywrapper-0.2.8 → robocandywrapper-0.2.10/robocandywrapper.egg-info}/PKG-INFO +1 -1
- {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/setup.py +1 -1
- {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/LICENSE +0 -0
- {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/MANIFEST.in +0 -0
- {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/README.md +0 -0
- {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/robocandywrapper/constants.py +0 -0
- {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/robocandywrapper/dataformats/__init__.py +0 -0
- {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/robocandywrapper/dataformats/lerobot_21/__init__.py +0 -0
- {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/robocandywrapper/dataformats/lerobot_21/convert_v20_to_v21.py +0 -0
- {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/robocandywrapper/dataformats/lerobot_21/dataset.py +0 -0
- {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/robocandywrapper/dataformats/lerobot_21/utils.py +0 -0
- {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/robocandywrapper/plugin.py +0 -0
- {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/robocandywrapper/plugins/__init__.py +0 -0
- {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/robocandywrapper/plugins/affordance.py +0 -0
- {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/robocandywrapper/plugins/control_mode.py +0 -0
- {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/robocandywrapper/plugins/episode_outcome.py +0 -0
- {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/robocandywrapper/plugins/subtask.py +0 -0
- {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/robocandywrapper/samplers/__init__.py +0 -0
- {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/robocandywrapper/samplers/config.py +0 -0
- {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/robocandywrapper/samplers/factory.py +0 -0
- {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/robocandywrapper/samplers/weighted.py +0 -0
- {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/robocandywrapper/utils.py +0 -0
- {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/robocandywrapper.egg-info/SOURCES.txt +0 -0
- {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/robocandywrapper.egg-info/dependency_links.txt +0 -0
- {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/robocandywrapper.egg-info/requires.txt +0 -0
- {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/robocandywrapper.egg-info/top_level.txt +0 -0
- {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/setup.cfg +0 -0
- {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/tests/test_dataset_weights_integration.py +0 -0
- {robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/tests/test_key_rename_stats.py +0 -0
|
@@ -257,6 +257,8 @@ def make_dataset_without_config(
|
|
|
257
257
|
plugins: Optional[list[DatasetPlugin]] = None,
|
|
258
258
|
key_rename_map: Optional[dict[str, str]] = None,
|
|
259
259
|
load_videos: bool = True,
|
|
260
|
+
pad_to_max_dim: bool = False,
|
|
261
|
+
fill_missing_images: str = "disable",
|
|
260
262
|
) -> WrappedRobotDataset:
|
|
261
263
|
"""Handles the logic of setting up delta timestamps and image transforms before creating a dataset.
|
|
262
264
|
|
|
@@ -304,11 +306,12 @@ def make_dataset_without_config(
|
|
|
304
306
|
load_videos=load_videos,
|
|
305
307
|
)
|
|
306
308
|
|
|
307
|
-
# Wrap in WrappedRobotDataset with plugins
|
|
308
309
|
wrapped_dataset = WrappedRobotDataset(
|
|
309
310
|
datasets=datasets,
|
|
310
311
|
plugins=plugins,
|
|
311
312
|
key_rename_map=key_rename_map,
|
|
313
|
+
pad_to_max_dim=pad_to_max_dim,
|
|
314
|
+
fill_missing_images=fill_missing_images,
|
|
312
315
|
)
|
|
313
316
|
|
|
314
317
|
return wrapped_dataset
|
|
@@ -79,11 +79,25 @@ def aggregate_stats_weighted(
|
|
|
79
79
|
if not stats_with_key:
|
|
80
80
|
continue
|
|
81
81
|
|
|
82
|
-
# Extract arrays
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
82
|
+
# Extract arrays, padding to max dim if shapes differ
|
|
83
|
+
def _to_arrays(stat_key, pad_value=0.0):
|
|
84
|
+
arrs = [np.array(s[stat_key]) for s in stats_with_key]
|
|
85
|
+
if any(a.shape != arrs[0].shape for a in arrs):
|
|
86
|
+
max_dim = max(a.shape[-1] for a in arrs if a.ndim > 0)
|
|
87
|
+
padded = []
|
|
88
|
+
for a in arrs:
|
|
89
|
+
if a.ndim > 0 and a.shape[-1] < max_dim:
|
|
90
|
+
pad_width = [(0, 0)] * (a.ndim - 1) + [(0, max_dim - a.shape[-1])]
|
|
91
|
+
padded.append(np.pad(a, pad_width, constant_values=pad_value))
|
|
92
|
+
else:
|
|
93
|
+
padded.append(a)
|
|
94
|
+
return np.stack(padded)
|
|
95
|
+
return np.stack(arrs)
|
|
96
|
+
|
|
97
|
+
means = _to_arrays("mean", pad_value=0.0)
|
|
98
|
+
stds = _to_arrays("std", pad_value=1.0)
|
|
99
|
+
mins = _to_arrays("min", pad_value=0.0)
|
|
100
|
+
maxs = _to_arrays("max", pad_value=0.0)
|
|
87
101
|
|
|
88
102
|
# Get counts and apply weight multipliers
|
|
89
103
|
# Extract scalar count value (handle both scalar and array counts)
|
|
@@ -24,6 +24,8 @@ class WrappedRobotDataset(torch.utils.data.Dataset):
|
|
|
24
24
|
error_on_key_conflicts: bool = True,
|
|
25
25
|
dataset_weights: Optional[dict[str, float]] = None,
|
|
26
26
|
key_rename_map: Optional[dict[str, str]] = None,
|
|
27
|
+
pad_to_max_dim: bool = False,
|
|
28
|
+
fill_missing_images: str = "disable",
|
|
27
29
|
**kwargs
|
|
28
30
|
):
|
|
29
31
|
"""
|
|
@@ -44,6 +46,17 @@ class WrappedRobotDataset(torch.utils.data.Dataset):
|
|
|
44
46
|
Note: When a key is renamed, any corresponding "_is_pad" key (added by
|
|
45
47
|
LeRobot when using delta_timestamps) is automatically renamed as well.
|
|
46
48
|
E.g., "action.pos" -> "action" also renames "action.pos_is_pad" -> "action_is_pad".
|
|
49
|
+
pad_to_max_dim: If True, features with different shapes across datasets
|
|
50
|
+
(e.g. 7-dim vs 14-dim actions) are zero-padded to the max dim instead
|
|
51
|
+
of raising an error. Adds ``action_dim_mask`` (bool tensor, True for
|
|
52
|
+
real dims) to each item so downstream loss functions can ignore the
|
|
53
|
+
padded dimensions.
|
|
54
|
+
fill_missing_images: How to handle image keys not present in all datasets.
|
|
55
|
+
- "disable" (default): remove the key entirely (original behaviour)
|
|
56
|
+
- "zeros": fill with a zero tensor of the same shape as other datasets
|
|
57
|
+
- "noise": fill with random noise (uniform [0, 255] uint8)
|
|
58
|
+
When set to "zeros" or "noise", an ``image_mask`` dict entry is also
|
|
59
|
+
added (True = real image, False = filled placeholder).
|
|
47
60
|
"""
|
|
48
61
|
super().__init__()
|
|
49
62
|
|
|
@@ -52,6 +65,8 @@ class WrappedRobotDataset(torch.utils.data.Dataset):
|
|
|
52
65
|
self.image_transforms = image_transforms
|
|
53
66
|
self.warn_on_key_conflicts = warn_on_key_conflicts
|
|
54
67
|
self.error_on_key_conflicts = error_on_key_conflicts
|
|
68
|
+
self.pad_to_max_dim = pad_to_max_dim
|
|
69
|
+
self.fill_missing_images = fill_missing_images
|
|
55
70
|
|
|
56
71
|
# Calculate dataset boundaries for flat index space
|
|
57
72
|
self._dataset_lengths = []
|
|
@@ -102,60 +117,93 @@ class WrappedRobotDataset(torch.utils.data.Dataset):
|
|
|
102
117
|
|
|
103
118
|
# ** MATCHING LeRobot MULTI-DATASET API DESIGN **
|
|
104
119
|
|
|
105
|
-
#
|
|
106
|
-
#
|
|
107
|
-
#
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
# "effective" features (post-rename). This allows datasets with different key names to be
|
|
111
|
-
# unified before the intersection check.
|
|
112
|
-
self.disabled_features = set()
|
|
113
|
-
intersection_features = self._get_effective_features(0)
|
|
120
|
+
# Compute feature intersection across datasets (post-rename).
|
|
121
|
+
# Image keys that are missing from some datasets can optionally be
|
|
122
|
+
# filled with zeros/noise instead of being disabled.
|
|
123
|
+
union_features = set()
|
|
124
|
+
all_ds_features = []
|
|
114
125
|
for i in range(len(self._datasets)):
|
|
115
|
-
|
|
126
|
+
ef = self._get_effective_features(i)
|
|
127
|
+
all_ds_features.append(ef)
|
|
128
|
+
union_features.update(ef)
|
|
129
|
+
|
|
130
|
+
intersection_features = set.intersection(*all_ds_features) if all_ds_features else set()
|
|
131
|
+
|
|
116
132
|
if len(intersection_features) == 0:
|
|
117
133
|
raise RuntimeError(
|
|
118
134
|
"Multiple datasets were provided but they had no keys common to all of them. "
|
|
119
135
|
"The multi-dataset functionality currently only keeps common keys."
|
|
120
136
|
)
|
|
137
|
+
|
|
138
|
+
# Determine which non-common features to disable vs fill
|
|
139
|
+
self.disabled_features = set()
|
|
140
|
+
self._filled_image_keys: set[str] = set() # image keys that need filling per-item
|
|
141
|
+
|
|
121
142
|
for i, repo_id in enumerate(self.repo_ids):
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
143
|
+
extra_keys = all_ds_features[i].difference(intersection_features)
|
|
144
|
+
for key in extra_keys:
|
|
145
|
+
is_image = "image" in key.lower() or "cam" in key.lower()
|
|
146
|
+
if is_image and self.fill_missing_images != "disable":
|
|
147
|
+
self._filled_image_keys.add(key)
|
|
148
|
+
else:
|
|
149
|
+
if key not in self._filled_image_keys:
|
|
150
|
+
self.disabled_features.add(key)
|
|
151
|
+
|
|
152
|
+
# Promote filled image keys into the effective feature set
|
|
153
|
+
for key in self._filled_image_keys:
|
|
154
|
+
if key in self.disabled_features:
|
|
155
|
+
self.disabled_features.discard(key)
|
|
156
|
+
|
|
157
|
+
if self.disabled_features:
|
|
158
|
+
logging.warning(
|
|
159
|
+
f"Non-common features disabled: {self.disabled_features} "
|
|
160
|
+
f"(not in all datasets and not eligible for filling)"
|
|
161
|
+
)
|
|
130
162
|
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
for
|
|
163
|
+
active_features = intersection_features | self._filled_image_keys
|
|
164
|
+
|
|
165
|
+
# Validate shapes and compute padding info for active features
|
|
166
|
+
self._feature_max_dims: dict[str, int] = {} # key -> max last-dim across datasets
|
|
167
|
+
self._per_dataset_dims: dict[str, dict[int, int]] = {} # key -> {ds_idx: dim}
|
|
168
|
+
|
|
169
|
+
for key in active_features:
|
|
134
170
|
shapes = []
|
|
135
|
-
|
|
171
|
+
per_ds = {}
|
|
136
172
|
for i, ds in enumerate(self._datasets):
|
|
137
|
-
# Find the original key (may be renamed)
|
|
138
173
|
renames = self._dataset_renames[i]
|
|
139
174
|
reverse_renames = {v: k for k, v in renames.items()}
|
|
140
175
|
original_key = reverse_renames.get(key, key)
|
|
141
|
-
|
|
142
176
|
if original_key in ds.meta.features:
|
|
143
177
|
feature_shape = ds.meta.features[original_key].get('shape', [])
|
|
144
178
|
shapes.append(tuple(feature_shape))
|
|
145
|
-
if
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
shape_details.append(f"{ds.repo_id}: {feature_shape}")
|
|
149
|
-
|
|
150
|
-
# Check if all shapes are the same
|
|
179
|
+
if feature_shape:
|
|
180
|
+
per_ds[i] = feature_shape[-1]
|
|
181
|
+
|
|
151
182
|
unique_shapes = set(shapes)
|
|
152
183
|
if len(unique_shapes) > 1:
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
184
|
+
if self.pad_to_max_dim:
|
|
185
|
+
max_dim = max(s[-1] for s in unique_shapes if s)
|
|
186
|
+
self._feature_max_dims[key] = max_dim
|
|
187
|
+
self._per_dataset_dims[key] = per_ds
|
|
188
|
+
logging.info(
|
|
189
|
+
f"Feature '{key}' has mixed dims {unique_shapes}; "
|
|
190
|
+
f"padding to {max_dim} (pad_to_max_dim=True)"
|
|
191
|
+
)
|
|
192
|
+
else:
|
|
193
|
+
shape_details = []
|
|
194
|
+
for i, ds in enumerate(self._datasets):
|
|
195
|
+
renames = self._dataset_renames[i]
|
|
196
|
+
reverse_renames = {v: k for k, v in renames.items()}
|
|
197
|
+
original_key = reverse_renames.get(key, key)
|
|
198
|
+
if original_key in ds.meta.features:
|
|
199
|
+
feature_shape = ds.meta.features[original_key].get('shape', [])
|
|
200
|
+
shape_details.append(f"{ds.repo_id}: {feature_shape}")
|
|
201
|
+
raise ValueError(
|
|
202
|
+
f"Incompatible shapes for feature '{key}' across datasets:\n" +
|
|
203
|
+
"\n".join(f" - {detail}" for detail in shape_details) +
|
|
204
|
+
f"\n\nCannot mix datasets with different {key} dimensions. "
|
|
205
|
+
f"Use pad_to_max_dim=True to zero-pad smaller dims."
|
|
206
|
+
)
|
|
159
207
|
|
|
160
208
|
# Keep backward compatible stats property
|
|
161
209
|
self.stats = self._meta.stats
|
|
@@ -582,7 +630,6 @@ class WrappedRobotDataset(torch.utils.data.Dataset):
|
|
|
582
630
|
item["dataset_index"] = torch.tensor(dataset_idx)
|
|
583
631
|
|
|
584
632
|
# Apply key renaming for this dataset (before filtering disabled features)
|
|
585
|
-
# This unifies differently-named keys across datasets
|
|
586
633
|
renames = self._dataset_renames[dataset_idx]
|
|
587
634
|
for source, target in renames.items():
|
|
588
635
|
if source in item:
|
|
@@ -592,6 +639,48 @@ class WrappedRobotDataset(torch.utils.data.Dataset):
|
|
|
592
639
|
for data_key in self.disabled_features:
|
|
593
640
|
if data_key in item:
|
|
594
641
|
del item[data_key]
|
|
642
|
+
|
|
643
|
+
# Fill missing image keys with zeros/noise
|
|
644
|
+
if self._filled_image_keys:
|
|
645
|
+
# Find a reference image shape from an existing image key in this item
|
|
646
|
+
ref_shape = None
|
|
647
|
+
for k, v in item.items():
|
|
648
|
+
if hasattr(v, 'shape') and len(getattr(v, 'shape', ())) >= 3:
|
|
649
|
+
ref_shape = v.shape
|
|
650
|
+
ref_dtype = v.dtype if hasattr(v, 'dtype') else torch.uint8
|
|
651
|
+
break
|
|
652
|
+
for key in self._filled_image_keys:
|
|
653
|
+
if key not in item and ref_shape is not None:
|
|
654
|
+
if self.fill_missing_images == "noise":
|
|
655
|
+
item[key] = torch.randint(0, 256, ref_shape, dtype=torch.uint8)
|
|
656
|
+
else:
|
|
657
|
+
item[key] = torch.zeros(ref_shape, dtype=ref_dtype)
|
|
658
|
+
|
|
659
|
+
# Pad features with mismatched dims to the max dim across datasets
|
|
660
|
+
if self._feature_max_dims:
|
|
661
|
+
import numpy as np
|
|
662
|
+
for key, max_dim in self._feature_max_dims.items():
|
|
663
|
+
if key not in item:
|
|
664
|
+
continue
|
|
665
|
+
val = item[key]
|
|
666
|
+
if not hasattr(val, 'shape'):
|
|
667
|
+
continue
|
|
668
|
+
current_dim = val.shape[-1]
|
|
669
|
+
if current_dim < max_dim:
|
|
670
|
+
pad_size = max_dim - current_dim
|
|
671
|
+
if isinstance(val, torch.Tensor):
|
|
672
|
+
pad = torch.zeros(*val.shape[:-1], pad_size, dtype=val.dtype)
|
|
673
|
+
item[key] = torch.cat([val, pad], dim=-1)
|
|
674
|
+
else:
|
|
675
|
+
pad_widths = [(0, 0)] * (len(val.shape) - 1) + [(0, pad_size)]
|
|
676
|
+
item[key] = np.pad(val, pad_widths, constant_values=0)
|
|
677
|
+
|
|
678
|
+
# Add a dimension mask for action/state features
|
|
679
|
+
if "action" in key.lower() or "state" in key.lower():
|
|
680
|
+
mask_key = f"{key}_dim_mask"
|
|
681
|
+
mask = torch.zeros(max_dim, dtype=torch.bool)
|
|
682
|
+
mask[:current_dim] = True
|
|
683
|
+
item[mask_key] = mask
|
|
595
684
|
|
|
596
685
|
# Execute plugins sequentially, passing accumulated data
|
|
597
686
|
for plugin_instance in self._plugin_instances[dataset_idx]:
|
|
@@ -9,7 +9,7 @@ long_description = readme_file.read_text(encoding="utf-8") if readme_file.exists
|
|
|
9
9
|
|
|
10
10
|
setup(
|
|
11
11
|
name="robocandywrapper",
|
|
12
|
-
version="0.2.
|
|
12
|
+
version="0.2.10",
|
|
13
13
|
description="Sweet wrappers for extending and remixing LeRobot Datasets",
|
|
14
14
|
long_description=long_description,
|
|
15
15
|
long_description_content_type="text/markdown",
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/robocandywrapper/dataformats/lerobot_21/utils.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/robocandywrapper/plugins/episode_outcome.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/robocandywrapper.egg-info/dependency_links.txt
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{robocandywrapper-0.2.8 → robocandywrapper-0.2.10}/tests/test_dataset_weights_integration.py
RENAMED
|
File without changes
|
|
File without changes
|