mcap-data-loader 0.0.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mcap_data_loader-0.0.0/LICENSE +21 -0
- mcap_data_loader-0.0.0/PKG-INFO +19 -0
- mcap_data_loader-0.0.0/mcap_data_loader/__init__.py +0 -0
- mcap_data_loader-0.0.0/mcap_data_loader/datasets/dataset.py +267 -0
- mcap_data_loader-0.0.0/mcap_data_loader/datasets/mcap_dataset.py +315 -0
- mcap_data_loader-0.0.0/mcap_data_loader/schemas/airbot_fbs/FloatArray.py +92 -0
- mcap_data_loader-0.0.0/mcap_data_loader/schemas/airbot_fbs/__init__.py +0 -0
- mcap_data_loader-0.0.0/mcap_data_loader/schemas/airbot_fbs/bfbs/__init__.py +5 -0
- mcap_data_loader-0.0.0/mcap_data_loader/utils/av_coder.py +410 -0
- mcap_data_loader-0.0.0/mcap_data_loader/utils/basic.py +179 -0
- mcap_data_loader-0.0.0/mcap_data_loader/utils/mcap_utils.py +600 -0
- mcap_data_loader-0.0.0/mcap_data_loader.egg-info/PKG-INFO +19 -0
- mcap_data_loader-0.0.0/mcap_data_loader.egg-info/SOURCES.txt +16 -0
- mcap_data_loader-0.0.0/mcap_data_loader.egg-info/dependency_links.txt +1 -0
- mcap_data_loader-0.0.0/mcap_data_loader.egg-info/requires.txt +10 -0
- mcap_data_loader-0.0.0/mcap_data_loader.egg-info/top_level.txt +1 -0
- mcap_data_loader-0.0.0/pyproject.toml +28 -0
- mcap_data_loader-0.0.0/setup.cfg +4 -0
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2025 Ge Haizhou
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
Metadata-Version: 2.1
|
|
2
|
+
Name: mcap-data-loader
|
|
3
|
+
Version: 0.0.0
|
|
4
|
+
Summary: MCAP Data Loader
|
|
5
|
+
Author-email: OpenGHz <your.email@example.com>
|
|
6
|
+
License: MIT
|
|
7
|
+
Requires-Python: >=3.9
|
|
8
|
+
Description-Content-Type: text/markdown
|
|
9
|
+
License-File: LICENSE
|
|
10
|
+
Requires-Dist: pydantic
|
|
11
|
+
Requires-Dist: numpy
|
|
12
|
+
Requires-Dist: more-itertools
|
|
13
|
+
Requires-Dist: typing-extensions
|
|
14
|
+
Requires-Dist: flatbuffers
|
|
15
|
+
Requires-Dist: foxglove-schemas-flatbuffer
|
|
16
|
+
Requires-Dist: mcap
|
|
17
|
+
Requires-Dist: flatten-dict
|
|
18
|
+
Requires-Dist: av
|
|
19
|
+
Requires-Dist: PyTurboJPEG
|
|
File without changes
|
|
@@ -0,0 +1,267 @@
|
|
|
1
|
+
import random
|
|
2
|
+
from typing import (
|
|
3
|
+
Any,
|
|
4
|
+
Callable,
|
|
5
|
+
Generator,
|
|
6
|
+
Iterable,
|
|
7
|
+
Iterator,
|
|
8
|
+
List,
|
|
9
|
+
Optional,
|
|
10
|
+
Dict,
|
|
11
|
+
Union,
|
|
12
|
+
)
|
|
13
|
+
from pydantic import BaseModel, NonNegativeInt, computed_field
|
|
14
|
+
from abc import ABC, abstractmethod
|
|
15
|
+
from functools import cached_property
|
|
16
|
+
from logging import getLogger
|
|
17
|
+
from mcap_data_loader.utils.basic import StrEnum, SlicesType, multi_slices_to_indexes
|
|
18
|
+
from enum import auto
|
|
19
|
+
|
|
20
|
+
try:
    from torch.utils.data import IterableDataset, get_worker_info
except ImportError:
    # PyTorch is an optional dependency: provide minimal stand-ins so that this
    # module can still be imported and the datasets used in a single process.

    class IterableDataset:
        """Placeholder base class used when `torch.utils.data` cannot be imported."""

    def get_worker_info():
        """Stand-in for `torch.utils.data.get_worker_info`; always returns None."""
        return None

    getLogger(__name__).warning(
        "torch.utils.data is not available, some features may not work. "
        "Please install PyTorch to use these features."
    )
|
|
33
|
+
|
|
34
|
+
# A slicing specification: either a single `SlicesType`, or a dict mapping a
# string key to its own `SlicesType`.
DictableSlicesType = Union[Dict[str, SlicesType], SlicesType]
# The expanded counterpart of `DictableSlicesType`: explicit integer indexes,
# either as one flat list or as a dict of lists keyed by string.
DictableIndexesType = Union[Dict[str, List[int]], List[int]]
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class RearrangeType(StrEnum):
    """Strategies for reordering a collection of items (samples, episodes or datasets)."""

    # Values are produced by `auto()` on the project's `StrEnum`; presumably the
    # lowercased member name (e.g. "sort") - TODO confirm against utils.basic.
    NONE = auto()  # keep the original order
    SORT = auto()  # sort in ascending order
    SHUFFLE = auto()  # shuffle randomly
    REVERSE = auto()  # reverse the order
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class DataSlicesConfig(BaseModel):
    """Configuration for slicing data.
    This class defines how to slice samples, episodes, and datasets.
    Args:
        sample: Consider a flattened dict sample {'key1': [1, 2, 3], 'key2': [4, 5, 6]},
            given the dict slices: {'key1': (0, 2), 'key2': (1, 3)}, the result will be:
            {'key1': [1, 2], 'key2': [5, 6]}.
        episode: Consider a flattened dataset: {'/path1/episode0': [point1, point2, point3],
            '/path2/episode1': [point1, point2, point3]}, given the dict slices: {'/path1/episode0': (0, 2),
            '/path2/episode1': (1, 3)}, the result will be {'/path1/episode0': [point1, point2],
            '/path2/episode1': [point2, point3]}
        dataset: Consider a flattened dataset with multiple sub-datasets:
            {'dataset1': ['episode1', 'episode2', 'episode3'], 'dataset2': ['episode1', 'episode2', 'episode3']},
            given the dict slices: {'dataset1': (0, 2), 'dataset2': (1, 3)}, the result will be:
            {'dataset1': ['episode1', 'episode2'], 'dataset2': ['episode2', 'episode3']}
    """

    sample: DictableSlicesType = {}
    episode: DictableSlicesType = {}
    dataset: DictableSlicesType = {}

    @staticmethod
    def _slices_to_indexes(slices: DictableSlicesType) -> DictableIndexesType:
        """
        Convert slices to indexes.

        Args:
            slices: Either a dict mapping a key to its slicing spec, or a single
                slicing spec.
        Returns:
            A dict of index lists (one per key) when `slices` is a dict, otherwise
            the indexes produced by `multi_slices_to_indexes` for the bare spec.
        """
        if isinstance(slices, dict):
            return {key: multi_slices_to_indexes(spec) for key, spec in slices.items()}
        # Every non-dict value is a bare slicing spec. It is converted the same way
        # as the dict values above instead of silently returning None when the spec
        # is not a `list` (e.g. a tuple).
        return multi_slices_to_indexes(slices)

    @computed_field
    @cached_property
    def sample_indexes(self) -> DictableIndexesType:
        """Indexes expanded from `sample`; computed once and cached."""
        return self._slices_to_indexes(self.sample)

    @computed_field
    @cached_property
    def episode_indexes(self) -> DictableIndexesType:
        """Indexes expanded from `episode`; computed once and cached."""
        return self._slices_to_indexes(self.episode)

    @computed_field
    @cached_property
    def dataset_indexes(self) -> DictableIndexesType:
        """Indexes expanded from `dataset`; computed once and cached."""
        return self._slices_to_indexes(self.dataset)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
class DataRearrangeConfig(BaseModel):
    """Configuration describing how data should be reordered.

    One strategy can be chosen independently for each level.
    Args:
        sample: Strategy applied inside a single sample (rarely used).
        episode: Strategy applied to each episode (e.g. reverse a trajectory).
        dataset: Strategy applied to the dataset as a whole.
    """

    sample: RearrangeType = RearrangeType.NONE
    episode: RearrangeType = RearrangeType.NONE
    dataset: RearrangeType = RearrangeType.NONE

    @staticmethod
    def rearrange(
        data: List[Any],
        strategy: RearrangeType,
        random_generator: Optional[random.Random] = None,
    ) -> None:
        """
        Reorder `data` in place according to `strategy`.

        Args:
            data (List[Any]): The list to reorder; it is mutated, nothing is returned.
            strategy (RearrangeType): Which reordering to apply.
            random_generator (Optional[random.Random]): Source of randomness for
                shuffling; the module-level `random` functions are used when omitted.
        Raises:
            ValueError: If the strategy is not one of NONE, SORT or SHUFFLE
                (this currently includes REVERSE).
        Description:
            - "none": `data` is left untouched.
            - "sort": `data` is sorted in ascending order.
            - "shuffle": `data` is shuffled randomly.
        """
        if strategy == RearrangeType.SORT:
            data.sort()
            return
        if strategy == RearrangeType.SHUFFLE:
            shuffler = random if random_generator is None else random_generator
            shuffler.shuffle(data)
            return
        if strategy == RearrangeType.NONE:
            return
        raise ValueError(f"Unsupported rearrangement strategy: {strategy}")
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
class IterableDatasetConfig(BaseModel):
    """Generic iterable Dataset configuration.
    Contains data root directory, random seed, multi-process configuration, etc.
    Subclasses can extend this configuration class to add specific parameters.
    Args:
        data_root (Union[str, List[str]]): Raw data root directory/file path(s)
        shuffle_buffer_size (NonNegativeInt): Buffer size for streaming shuffle
        seed (Optional[int]): Random seed; None means not fixed
        world_size (NonNegativeInt): Total number of processes (for distributed training)
        rank (NonNegativeInt): Current process rank
        resume_from_sample (NonNegativeInt): Resume from the Nth sample
        transform (Optional[Callable[[Any], Any]]): Sample-level transform function
        filter_fn (Optional[Callable[[Any], bool]]): Filter function
        slices (DataSlicesConfig): Slicing configuration for samples, episodes, and datasets
        rearrange (DataRearrangeConfig): Rearrangement strategies for samples, episodes, and datasets.
            Each dataset is processed separately.
    Description:
        - `data_root` can be file path, URL or other data source prefix
        - `shuffle_buffer_size` of 0 means no shuffle
        - `seed` controls randomness, None means different each run
        - `world_size` and `rank` for distributed training, ensuring each sample is processed only once
        - `resume_from_sample` for checkpoint resumption, starting from specified sample
        - `transform` and `filter_fn` for sample-level transformation and filtering
    """

    data_root: Union[str, List[str]]
    shuffle_buffer_size: NonNegativeInt = 0
    seed: Optional[int] = None
    # NOTE(review): the type allows 0, but the sharding code takes `idx % world_size`,
    # so a value of 0 would presumably fail at iteration time - confirm intent.
    world_size: NonNegativeInt = 1
    rank: NonNegativeInt = 0
    resume_from_sample: NonNegativeInt = 0
    transform: Optional[Callable[[Any], Any]] = None
    filter_fn: Optional[Callable[[Any], bool]] = None
    slices: DataSlicesConfig = DataSlicesConfig()
    rearrange: DataRearrangeConfig = DataRearrangeConfig()
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
class IterableDatasetABC(IterableDataset, ABC):
    """
    Base class for streaming datasets.

    A subclass provides the raw samples through `_read_stream()`; iteration then
    applies sharding, resume-skipping, filtering, transformation and buffered
    shuffling as configured.
    """

    def __init__(self, config: IterableDatasetConfig) -> None:
        super().__init__()
        self.cfg = config
        # Dataset-private random generator, seeded from the configuration.
        self._rng = random.Random(self.cfg.seed)

    def load(self):
        """
        Hook for loading the dataset into memory or preparing it for streaming.
        The base implementation does nothing.
        """

    @abstractmethod
    def _read_stream(self) -> Iterable[Any]:
        """
        Produce the raw samples as an iterable, one element per sample.
        Subclasses decide how `data_root` is read (files, databases, network, ...).
        """
        raise NotImplementedError

    def __iter__(self) -> Iterator[Any]:
        # -> Generator[Any, None, None] only for >py39
        # TODO: really consider how to handle multi-process/multi-node sharding
        cfg = self.cfg

        # Raw stream -> shard for this worker/rank -> drop already-consumed samples.
        pipeline = self._skip_samples(self._shard_stream(self._read_stream()))

        # Optional sample-level filter, then transform.
        if cfg.filter_fn is not None:
            pipeline = filter(cfg.filter_fn, pipeline)
        if cfg.transform is not None:
            pipeline = map(cfg.transform, pipeline)

        # Optional streaming shuffle (a buffer size of 0 disables it).
        if cfg.shuffle_buffer_size > 0:
            pipeline = self._shuffle_stream(pipeline)

        yield from pipeline

    def _shard_stream(self, stream: Iterable[Any]) -> Generator[Any, None, None]:
        """
        Keep only the samples belonging to this rank/worker so that every sample
        is handled by exactly one consumer.
        """
        num_parts = self.cfg.world_size
        my_part = self.cfg.rank

        # Inside a DataLoader worker each rank is further split across its workers.
        info = get_worker_info()
        if info is not None:
            my_part = my_part * info.num_workers + info.id
            num_parts = num_parts * info.num_workers

        position = 0
        for sample in stream:
            # Round-robin assignment by stream position.
            if position % num_parts == my_part:
                yield sample
            position += 1

    def _skip_samples(self, stream: Iterable[Any]) -> Generator[Any, None, None]:
        """
        Drop the first `resume_from_sample` samples and pass the rest through.
        """
        remaining = self.cfg.resume_from_sample
        for sample in stream:
            if remaining > 0:
                remaining -= 1
                continue
            yield sample

    def _shuffle_stream(self, stream: Iterable[Any]) -> Generator[Any, None, None]:
        """
        Shuffle a stream approximately using a bounded buffer.
        """
        capacity = self.cfg.shuffle_buffer_size
        pool: List[Any] = []
        for sample in stream:
            pool.append(sample)
            if len(pool) < capacity:
                continue
            # Buffer is full: emit one randomly chosen element.
            yield pool.pop(self._rng.randrange(len(pool)))
        # Stream exhausted: emit whatever is left in random order.
        self._rng.shuffle(pool)
        yield from pool

    def get_logger(self):
        """Return a logger named after the concrete class."""
        return getLogger(type(self).__name__)
|
|
@@ -0,0 +1,315 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from typing import Any, Generator, Iterable, Iterator, List, Optional, Dict
|
|
3
|
+
from pydantic import field_validator
|
|
4
|
+
from functools import cache
|
|
5
|
+
import numpy as np
|
|
6
|
+
from more_itertools import peekable, nth
|
|
7
|
+
from mcap_data_loader.utils.mcap_utils import McapFlatbufferReader
|
|
8
|
+
from mcap_data_loader.utils.basic import (
|
|
9
|
+
get_items_by_ext,
|
|
10
|
+
zip,
|
|
11
|
+
# DictableSlicesType,
|
|
12
|
+
# DictableIndexesType,
|
|
13
|
+
)
|
|
14
|
+
from mcap_data_loader.datasets.dataset import (
|
|
15
|
+
IterableDatasetABC,
|
|
16
|
+
IterableDatasetConfig,
|
|
17
|
+
DataRearrangeConfig,
|
|
18
|
+
RearrangeType,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class McapDatasetConfig(IterableDatasetConfig):
    """
    Configuration for a dataset backed by a single MCAP file.
    """

    keys: List[str] = []
    topics: Optional[List[str]] = []
    attachments: Optional[List[str]] = []
    cache_items: bool = True
    cache_iters: bool = False

    @field_validator("data_root")
    def validate_data_root(cls, v) -> str:
        """Normalise `data_root` to one `.mcap` path string, or raise ValueError."""
        if isinstance(v, str):
            path = v
        elif len(v) == 1:
            # A one-element sequence is unwrapped to its only path.
            path = v[0]
        else:
            raise ValueError(f"data_root {v} must be a string path to a MCAP file")
        if path.endswith(".mcap"):
            return path
        raise ValueError(f"data_root {path} must be a `.mcap` file")

    def model_post_init(self, context):
        """Reject option combinations that this dataset does not support."""
        slices, rearrange = self.slices, self.rearrange
        assert not slices.sample, "not implemented yet"
        assert not slices.episode, "not implemented yet"
        assert isinstance(slices.dataset, dict), "dataset slices must be a dict"
        assert not self.cache_iters, "iters now are not cached"
        assert (
            rearrange.sample == RearrangeType.NONE
        ), "sample rearrangement is not supported"
        assert rearrange.episode in {
            RearrangeType.NONE,
            RearrangeType.REVERSE,
        }, "episode rearrangement must be NONE or REVERSE"
        assert (
            rearrange.dataset == RearrangeType.NONE
        ), "dataset rearrangement is not supported"
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class McapFlatbufferSampleDataset(IterableDatasetABC):
    """
    Iterable dataset for reading a MCAP file.

    `load()` opens the file and creates the reader, so it must be called before
    iterating, indexing or calling `len()`.
    """

    cfg: McapDatasetConfig

    def load(self):
        """Open the reader and, if `cache_items` is set, prepare the indexable stream."""
        self._init_reader()
        if self.cfg.cache_items:
            # `peekable` supports indexing and keeps the items it has already
            # pulled, so repeated `dataset[i]` lookups reuse one pass over the data.
            self._indexed_stream = peekable(self._flatten_iter())

    def _flatten_iter(self):
        """Return an iterable over individual samples (this dataset is already flat)."""
        return self

    def _init_reader(self):
        """
        Initialize the MCAP reader.
        This is called from `load()`, not from the constructor; the file stays
        open until the dataset is garbage collected.
        """
        self.reader = McapFlatbufferReader(open(self.cfg.data_root, "rb"))

    def _read_stream(self) -> Generator[Dict[str, Any], None, None]:
        """
        Read MCAP file and return message stream.
        """
        return self._iter_a_file_samples(self.reader)

    def _iter_a_file_samples(
        self, reader: McapFlatbufferReader
    ) -> Generator[Dict[str, Any], None, None]:
        """Yield the samples of one reader using the configured keys/topics/attachments."""
        yield from reader.iter_samples(
            keys=self.cfg.keys,
            topics=self.cfg.topics,
            attachments=self.cfg.attachments,
            reverse=self.cfg.rearrange.episode == RearrangeType.REVERSE,
        )

    def __del__(self):
        # `reader` only exists after `load()`. Compare against None instead of
        # testing truthiness: the reader defines `__len__`, so a file with zero
        # messages would be falsy and its handle would never be closed.
        reader = getattr(self, "reader", None)
        if reader is not None:
            reader.file_io.close()

    def __len__(self) -> int:
        """Get the total number of messages in the MCAP file."""
        return len(self.reader)

    def __getitem__(self, index: int):
        """
        Get a specific sample by index.
        This is not efficient for large datasets, use with caution.

        Negative indexes are offset by `len(self)`.
        """
        # TODO: should support 2-dim indexing, e.g.
        # dataset[episode_index][sample_index] or
        # dataset[episode_index, sample_index]?
        # This may be configurable in the future.
        if index < 0:
            index += len(self)
        if self.cfg.cache_items:
            return self._indexed_stream[index]
        # Without the cache every lookup walks a fresh iterator up to `index`.
        return nth(self._flatten_iter(), index)

    def __iter__(self) -> Iterator[Dict[str, np.ndarray]]:
        return super().__iter__()
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
class McapFlatbufferEpisodeDatasetConfig(McapDatasetConfig):
    """
    Configuration for an episodic dataset built from directories of MCAP files.
    """

    @field_validator("data_root")
    def validate_data_root(cls, v) -> List[str]:
        """Normalise `data_root` to a list and require every entry to be a directory."""
        roots = [v] if isinstance(v, str) else v
        for directory in roots:
            if os.path.isdir(directory):
                continue
            raise ValueError(
                f"data_root {os.path.abspath(directory)} must be a directory containing MCAP files"
            )
        return roots

    def model_post_init(self, context):
        """Reject option combinations that the episodic dataset does not support."""
        slices = self.slices
        assert not slices.sample, "not implemented yet"
        assert not slices.episode, "not implemented yet"
        assert isinstance(slices.dataset, dict), "dataset slices must be a dict"
        assert not self.cache_iters, "iters now are not cached"
        assert (
            self.rearrange.sample == RearrangeType.NONE
        ), "sample rearrangement is not supported"
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
class McapFlatbufferEpisodeDataset(McapFlatbufferSampleDataset):
    """
    Episodic dataset for reading MCAP files.

    Each MCAP file found under the configured root directories is one episode.
    Iterating the dataset yields one sample iterator per episode.
    """

    cfg: McapFlatbufferEpisodeDatasetConfig

    def __init__(self, config):
        """Collect, reorder and slice the episode files of every root directory.

        No file is opened here; readers are created by `load()`.
        """
        super().__init__(config)
        # Maps the full file path to its reader; filled by `_init_reader()`.
        self.reader: Dict[str, McapFlatbufferReader] = {}
        dataset_files = {}
        # Reorder the root directories themselves (mutates cfg.data_root in place).
        DataRearrangeConfig.rearrange(
            self.cfg.data_root, self.cfg.rearrange.dataset, self._rng
        )
        for root in self.cfg.data_root:
            files = get_items_by_ext(root, ".mcap")
            # Reorder the episode files inside this root.
            DataRearrangeConfig.rearrange(files, self.cfg.rearrange.episode, self._rng)
            indexes = self.cfg.slices.dataset_indexes.get(root, None)
            if indexes:
                # slice the files by indexes
                files = np.array(files)[indexes].tolist()
            dataset_files[root] = files
        self._dataset_files = dataset_files

    def _flatten_iter(self):
        """Yield every sample of every episode as one flat stream."""
        for episode in self:
            for sample in episode:
                yield sample

    def _init_reader(self):
        """Open one reader per selected episode file."""
        for dataset, file_paths in self._dataset_files.items():
            for file_path in file_paths:
                full_path = os.path.join(dataset, file_path)
                assert full_path not in self.reader, f"Duplicate file path: {full_path}"
                self.reader[full_path] = McapFlatbufferReader(open(full_path, "rb"))

    def _read_stream(self) -> Generator[Iterable[dict[str, Any]], None, None]:
        """
        Read MCAP files and return episodic message stream.
        Each episode corresponds to one MCAP file.
        """
        for file_path, reader in self.reader.items():
            # Remember which file the episode being handed out comes from.
            self._current_file = file_path
            yield self._iter_a_file_samples(reader)

    @property
    def current_file(self) -> str:
        """Path of the file whose episode was most recently produced by iteration."""
        return self._current_file

    @property
    def all_files(self) -> Dict[str, List[str]]:
        """Selected episode files, keyed by root directory."""
        return self._dataset_files

    def __del__(self):
        # Guard against a partially constructed instance (no `reader` attribute yet).
        for reader in getattr(self, "reader", {}).values():
            reader.file_io.close()

    def __len__(self) -> int:
        """Get the total number of messages in all MCAP files.

        The total is memoised per instance and recomputed whenever the number of
        open readers changes. A module-level `functools.cache` on this method
        would keep every dataset alive forever (so its files would never be
        closed) and would freeze a stale 0 if `len()` were called before `load()`.
        """
        readers = self.reader
        memo = getattr(self, "_len_memo", None)
        if memo is not None and memo[0] == len(readers):
            return memo[1]
        total_count = sum(len(reader) for reader in readers.values())
        self._len_memo = (len(readers), total_count)
        return total_count

    def __iter__(self) -> Iterator[Iterator[Dict[str, np.ndarray]]]:
        return super().__iter__()

    def __getitem__(self, index) -> Dict[str, np.ndarray]:
        return super().__getitem__(index)
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
if __name__ == "__main__":
    # Manual smoke test / timing demo; expects MCAP files under `data/arm1-001`.
    from pprint import pprint
    import time
    from more_itertools import batched
    import logging
    from mcap_data_loader.datasets.dataset import DataSlicesConfig

    logging.basicConfig(level=logging.INFO)

    root_dir = "data/arm1-001"
    # data_root = "0.mcap"
    data_root = root_dir
    keys = [
        "/left/follow/arm/joint_state/position",
        "/left/follow/eef/joint_state/position",
        "/left/lead/arm/joint_state/position",
        "/left/lead/eef/joint_state/position",
        "/env_camera/env/color/image_raw",
    ]
    # keys = (
    #     [
    #         # "/follow/arm/joint_state/position",
    #         # "/follow/eef/joint_state/position",
    #     ]
    #     + [
    #         "/env_camera/color/image_raw",
    #         # "/follow_camera/color/image_raw",
    #         # discoverse camera keys
    #         # "/cam_0/color/image_raw",
    #         # "/cam_1/color/image_raw",
    #         "log_stamps",
    #     ]
    # )

    # dataset = McapFlatbufferDataset(
    #     McapFlatbufferDatasetConfig(
    #         data_root=data_root,
    #         keys=keys,
    #     )
    # )
    # start = time.perf_counter()
    # for sample in dataset:
    #     print(time.perf_counter() - start)
    #     # pprint(sample)
    #     start = time.perf_counter()
    #     # break  # Only print the first sample

    dataset = McapFlatbufferEpisodeDataset(
        McapFlatbufferEpisodeDatasetConfig(
            data_root=data_root,
            keys=keys,
            slices=DataSlicesConfig(dataset={root_dir: (0, 1)}),
            rearrange=DataRearrangeConfig(
                episode="sort",
            ),
            cache_items=True,
        )
    )
    dataset.load()
    print(dataset.all_files)
    print(f"Dataset length: {len(dataset)}")
    pprint(dataset[0].keys())
    # Indexing the same position twice must return identical data.
    for v1, v2 in zip(dataset[0].values(), dataset[0].values()):
        assert np.array_equal(v1, v2), f"{v1=} != {v2=}"
    # Two different positions are expected to differ in at least one value.
    for v1, v2 in zip(dataset[0].values(), dataset[1].values()):
        if not np.array_equal(v1, v2):
            print("OK: Samples are not equal")
            break
    else:
        raise ValueError("Samples are equal")

    for file_path, reader in dataset.reader.items():
        print(f"File: {file_path}, Messages: {len(reader)}")
    start = time.perf_counter()
    batch_size = 10
    steps = 1
    for episode in dataset:
        next(episode)  # Skip the first sample
        start = time.perf_counter()
        # `episode` is a generator and has no len(), so count the samples seen.
        processed = 0
        for step, batch in enumerate(batched(episode, batch_size, strict=True)):
            processed += len(batch)
            print(f"{step=}", batch[0].keys())
            if step + 1 >= steps:
                break
        else:
            print(f"Processed {processed} samples in episode {dataset.current_file}")
        total_time = time.perf_counter() - start
        avg_time = total_time / batch_size
        print(f"Average time per sample: {avg_time:.5f} seconds")
        print(f"Total time taken for {batch_size=}: {total_time:.5f} seconds")
        break  # Only process the first episode
|