mldataforge 0.1.5__tar.gz → 0.1.6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mldataforge-0.1.5 → mldataforge-0.1.6}/PKG-INFO +1 -1
- {mldataforge-0.1.5 → mldataforge-0.1.6}/mldataforge/commands/join.py +1 -1
- {mldataforge-0.1.5 → mldataforge-0.1.6}/mldataforge/compression.py +5 -1
- mldataforge-0.1.6/mldataforge/mds.py +317 -0
- {mldataforge-0.1.5 → mldataforge-0.1.6}/mldataforge/snappy.py +0 -54
- {mldataforge-0.1.5 → mldataforge-0.1.6}/mldataforge/utils.py +2 -2
- {mldataforge-0.1.5 → mldataforge-0.1.6}/pyproject.toml +1 -1
- mldataforge-0.1.5/mldataforge/mds.py +0 -95
- {mldataforge-0.1.5 → mldataforge-0.1.6}/.gitignore +0 -0
- {mldataforge-0.1.5 → mldataforge-0.1.6}/LICENSE +0 -0
- {mldataforge-0.1.5 → mldataforge-0.1.6}/README.md +0 -0
- {mldataforge-0.1.5 → mldataforge-0.1.6}/mldataforge/__main__.py +0 -0
- {mldataforge-0.1.5 → mldataforge-0.1.6}/mldataforge/brotli.py +0 -0
- {mldataforge-0.1.5 → mldataforge-0.1.6}/mldataforge/commands/__init__.py +0 -0
- {mldataforge-0.1.5 → mldataforge-0.1.6}/mldataforge/commands/convert/__init__.py +0 -0
- {mldataforge-0.1.5 → mldataforge-0.1.6}/mldataforge/commands/convert/jsonl.py +0 -0
- {mldataforge-0.1.5 → mldataforge-0.1.6}/mldataforge/commands/convert/mds.py +0 -0
- {mldataforge-0.1.5 → mldataforge-0.1.6}/mldataforge/commands/convert/parquet.py +0 -0
- {mldataforge-0.1.5 → mldataforge-0.1.6}/mldataforge/commands/split.py +0 -0
- {mldataforge-0.1.5 → mldataforge-0.1.6}/mldataforge/options.py +0 -0
- {mldataforge-0.1.5 → mldataforge-0.1.6}/mldataforge/pigz.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: mldataforge
|
3
|
-
Version: 0.1.5
|
3
|
+
Version: 0.1.6
|
4
4
|
Summary: swiss army knife of scripts for transforming and processing datasets for machine learning.
|
5
5
|
Project-URL: Homepage, https://github.com/schneiderkamplab/mldataforge
|
6
6
|
Project-URL: Bug Tracker, https://github.com/schneiderkamplab/mldataforge/issues
|
@@ -53,7 +53,7 @@ def join_mds(output_dir, mds_directories, compression, processes, overwrite, yes
|
|
53
53
|
compression=compression,
|
54
54
|
buf_size=buf_size,
|
55
55
|
shard_size=shard_size,
|
56
|
-
pigz=use_pigz(compression, no_pigz)
|
56
|
+
pigz=use_pigz(compression, no_pigz),
|
57
57
|
)
|
58
58
|
|
59
59
|
@join.command()
|
@@ -30,7 +30,7 @@ JSONL_COMPRESSIONS = dict(
|
|
30
30
|
)
|
31
31
|
MDS_COMPRESSIONS = dict(
|
32
32
|
default=None,
|
33
|
-
choices=["none", "brotli", "bz2", "gzip", "pigz", "snappy", "zstd"],
|
33
|
+
choices=["none", "brotli", "bz2", "gzip", "pigz", "snappy", "zstd", "sample::brotli", "sample::bz2", "sample::gzip", "sample::snappy", "sample::zstd"],
|
34
34
|
)
|
35
35
|
PARQUET_COMPRESSIONS = dict(
|
36
36
|
default="snappy",
|
@@ -55,6 +55,10 @@ def determine_compression(fmt, file_path, compression="infer", no_pigz=False):
|
|
55
55
|
return "gz"
|
56
56
|
if compression == "brotli":
|
57
57
|
return "br"
|
58
|
+
if compression == "sample::gzip":
|
59
|
+
return "gz"
|
60
|
+
if compression == "sample::brotli":
|
61
|
+
return "br"
|
58
62
|
return compression
|
59
63
|
if fmt == "parquet":
|
60
64
|
return compression
|
@@ -0,0 +1,317 @@
|
|
1
|
+
import json
|
2
|
+
import numpy as np
|
3
|
+
import os
|
4
|
+
import shutil
|
5
|
+
from streaming.base.compression import compress, decompress, get_compression_extension, is_compression
|
6
|
+
from streaming.base.format.index import get_index_basename
|
7
|
+
from streaming.base.format.mds.encodings import mds_decode, mds_encode, is_mds_encoding, get_mds_encodings, get_mds_encoded_size
|
8
|
+
from streaming.base.hashing import get_hash, is_hash
|
9
|
+
from streaming.base.util import bytes_to_int
|
10
|
+
from typing import Any, Optional, Generator, Self, Union
|
11
|
+
|
12
|
+
from .utils import open_compression
|
13
|
+
|
14
|
+
class MDSBulkReader:
|
15
|
+
def __init__(
|
16
|
+
self,
|
17
|
+
dirnames: list[str],
|
18
|
+
split: Optional[str],
|
19
|
+
) -> None:
|
20
|
+
self.shards = []
|
21
|
+
self.samples = 0
|
22
|
+
for dirname in dirnames:
|
23
|
+
if split is not None:
|
24
|
+
dirname = os.path.join(dirname, split)
|
25
|
+
index = json.load(open(os.path.join(dirname, "index.json"), 'rt'))
|
26
|
+
for shard in index["shards"]:
|
27
|
+
basename = shard['raw_data']['basename'] if shard['zip_data'] is None else shard['zip_data']['basename']
|
28
|
+
filename = os.path.join(dirname, basename)
|
29
|
+
self.shards.append({
|
30
|
+
"filename": filename,
|
31
|
+
"compression": shard['compression'],
|
32
|
+
})
|
33
|
+
self.samples += shard['samples']
|
34
|
+
|
35
|
+
def __len__(self) -> int:
|
36
|
+
return self.samples
|
37
|
+
|
38
|
+
def __iter__(self) -> Generator[dict[str, Any], None, None]:
|
39
|
+
for shard in self.shards:
|
40
|
+
with MDSShardReader(**shard) as reader:
|
41
|
+
for sample in reader:
|
42
|
+
yield sample
|
43
|
+
|
44
|
+
class MDSShardReader:
|
45
|
+
def __init__(
|
46
|
+
self,
|
47
|
+
filename: str,
|
48
|
+
compression: Optional[str],
|
49
|
+
) -> None:
|
50
|
+
self.sample_compression = None
|
51
|
+
if compression is not None and compression.startswith("sample::"):
|
52
|
+
compression, self.sample_compression = None, compression.removeprefix("sample::")
|
53
|
+
self.fp = open_compression(filename, "rb", compression=compression)
|
54
|
+
self.samples = np.frombuffer(self.fp.read(4), np.uint32)[0]
|
55
|
+
self.index = np.frombuffer(self.fp.read((1+self.samples)*4), np.uint32)
|
56
|
+
info = json.loads(self.fp.read(self.index[0]-self.fp.tell()))
|
57
|
+
self.column_encodings = info["column_encodings"]
|
58
|
+
self.column_names = info["column_names"]
|
59
|
+
self.column_sizes = info["column_sizes"]
|
60
|
+
assert self.fp.tell() == self.index[0]
|
61
|
+
|
62
|
+
def decode_sample(self, data: bytes) -> dict[str, Any]:
|
63
|
+
sizes = []
|
64
|
+
idx = 0
|
65
|
+
for key, size in zip(self.column_names, self.column_sizes):
|
66
|
+
if size:
|
67
|
+
sizes.append(size)
|
68
|
+
else:
|
69
|
+
size, = np.frombuffer(data[idx:idx + 4], np.uint32)
|
70
|
+
sizes.append(size)
|
71
|
+
idx += 4
|
72
|
+
sample = {}
|
73
|
+
for key, encoding, size in zip(self.column_names, self.column_encodings, sizes):
|
74
|
+
value = data[idx:idx + size]
|
75
|
+
sample[key] = mds_decode(encoding, value)
|
76
|
+
idx += size
|
77
|
+
return sample
|
78
|
+
|
79
|
+
def get_sample_data(self, idx: int) -> bytes:
|
80
|
+
begin, end = self.index[idx:idx+2]
|
81
|
+
assert self.fp.tell() == begin
|
82
|
+
data = self.fp.read(end - begin)
|
83
|
+
assert self.fp.tell() == end
|
84
|
+
assert data
|
85
|
+
return data
|
86
|
+
|
87
|
+
def get_item(self, idx: int) -> dict[str, Any]:
|
88
|
+
data = self.get_sample_data(idx)
|
89
|
+
if self.sample_compression is not None:
|
90
|
+
data = decompress(self.sample_compression, data)
|
91
|
+
return self.decode_sample(data)
|
92
|
+
|
93
|
+
def __iter__(self) -> Generator[dict[str, Any], None, None]:
|
94
|
+
for i in range(self.samples):
|
95
|
+
yield self.get_item(i)
|
96
|
+
|
97
|
+
def __enter__(self) -> "MDSShardReader":
|
98
|
+
return self
|
99
|
+
|
100
|
+
def __exit__(self, exc_type, exc_value, traceback) -> None:
|
101
|
+
self.fp.close()
|
102
|
+
|
103
|
+
|
104
|
+
class MDSWriter:
|
105
|
+
|
106
|
+
format = 'mds'
|
107
|
+
extra_bytes_per_sample = 4
|
108
|
+
|
109
|
+
def __init__(self,
|
110
|
+
*,
|
111
|
+
columns: dict[str, str],
|
112
|
+
out: Union[str, tuple[str, str]],
|
113
|
+
compression: Optional[str] = None,
|
114
|
+
hashes: Optional[list[str]] = None,
|
115
|
+
size_limit: Optional[Union[int, str]] = 1 << 26,
|
116
|
+
**kwargs: Any) -> None:
|
117
|
+
compression = compression or None
|
118
|
+
sample_compression = None
|
119
|
+
if compression is not None and compression.startswith("sample::"):
|
120
|
+
compression, sample_compression = None, compression.removeprefix("sample::")
|
121
|
+
if compression:
|
122
|
+
if not is_compression(compression):
|
123
|
+
raise ValueError(f'Invalid compression: {compression}.')
|
124
|
+
if sample_compression:
|
125
|
+
if not is_compression(sample_compression):
|
126
|
+
raise ValueError(f'Invalid sample compression: {sample_compression}.')
|
127
|
+
hashes = hashes or []
|
128
|
+
if list(hashes) != sorted(hashes):
|
129
|
+
raise ValueError('Hashes must be unique and in sorted order.')
|
130
|
+
for algo in hashes:
|
131
|
+
if not is_hash(algo):
|
132
|
+
raise ValueError(f'Invalid hash: {algo}.')
|
133
|
+
|
134
|
+
size_limit_value = None
|
135
|
+
if size_limit:
|
136
|
+
size_limit_value = bytes_to_int(size_limit)
|
137
|
+
if size_limit_value < 0:
|
138
|
+
raise ValueError(f'`size_limit` must be greater than zero, instead, ' +
|
139
|
+
f'found as {size_limit_value}.')
|
140
|
+
if size_limit_value >= 2**32:
|
141
|
+
raise ValueError(f'`size_limit` must be less than 2**32, instead, ' +
|
142
|
+
f'found as {size_limit_value}. This is because sample ' +
|
143
|
+
f'byte offsets are stored with uint32.')
|
144
|
+
|
145
|
+
# Validate keyword arguments
|
146
|
+
invalid_kwargs = [
|
147
|
+
arg for arg in kwargs.keys()
|
148
|
+
if arg not in ('progress_bar', 'exist_ok')
|
149
|
+
]
|
150
|
+
if invalid_kwargs:
|
151
|
+
raise ValueError(f'Invalid Writer argument(s): {invalid_kwargs} ')
|
152
|
+
|
153
|
+
self.compression = compression
|
154
|
+
self.sample_compression = sample_compression
|
155
|
+
self.hashes = hashes
|
156
|
+
self.size_limit = size_limit_value
|
157
|
+
self.new_samples: list[bytes]
|
158
|
+
self.new_shard_size: int
|
159
|
+
|
160
|
+
self.shards = []
|
161
|
+
|
162
|
+
# Remove local directory if requested prior to creating writer
|
163
|
+
self.local = os.path.expanduser(out)
|
164
|
+
if os.path.exists(self.local) and len(os.listdir(self.local)) != 0:
|
165
|
+
if kwargs.get('exist_ok', False):
|
166
|
+
raise FileExistsError(f'Directory is not empty: {self.local}')
|
167
|
+
shutil.rmtree(self.local)
|
168
|
+
os.makedirs(self.local, exist_ok=True)
|
169
|
+
|
170
|
+
self.columns = columns
|
171
|
+
self.column_names = []
|
172
|
+
self.column_encodings = []
|
173
|
+
self.column_sizes = []
|
174
|
+
for name in sorted(columns):
|
175
|
+
encoding = columns[name]
|
176
|
+
if not is_mds_encoding(encoding):
|
177
|
+
raise TypeError(f'MDSWriter passed column `{name}` with encoding `{encoding}` ' +
|
178
|
+
f'is unsupported. Supported encodings are {get_mds_encodings()}')
|
179
|
+
size = get_mds_encoded_size(encoding)
|
180
|
+
self.column_names.append(name)
|
181
|
+
self.column_encodings.append(encoding)
|
182
|
+
self.column_sizes.append(size)
|
183
|
+
|
184
|
+
obj = self.get_config()
|
185
|
+
text = json.dumps(obj, sort_keys=True)
|
186
|
+
self.config_data = text.encode('utf-8')
|
187
|
+
self.extra_bytes_per_shard = 4 + 4 + len(self.config_data)
|
188
|
+
self._reset_cache()
|
189
|
+
|
190
|
+
def encode_sample(self, sample: dict[str, Any]) -> bytes:
|
191
|
+
sizes = []
|
192
|
+
data = []
|
193
|
+
for key, encoding, size in zip(self.column_names, self.column_encodings,
|
194
|
+
self.column_sizes):
|
195
|
+
value = sample[key]
|
196
|
+
datum = mds_encode(encoding, value)
|
197
|
+
if size is None:
|
198
|
+
size = len(datum)
|
199
|
+
sizes.append(size)
|
200
|
+
else:
|
201
|
+
if size != len(datum):
|
202
|
+
raise KeyError(f'Unexpected data size; was this data typed with the correct ' +
|
203
|
+
f'encoding ({encoding})?')
|
204
|
+
data.append(datum)
|
205
|
+
head = np.array(sizes, np.uint32).tobytes()
|
206
|
+
body = b''.join(data)
|
207
|
+
sample_data = head + body
|
208
|
+
if self.sample_compression:
|
209
|
+
sample_data = compress(self.sample_compression, sample_data)
|
210
|
+
return sample_data
|
211
|
+
|
212
|
+
def encode_joint_shard(self) -> bytes:
|
213
|
+
num_samples = np.uint32(len(self.new_samples))
|
214
|
+
sizes = list(map(len, self.new_samples))
|
215
|
+
offsets = np.array([0] + sizes).cumsum().astype(np.uint32)
|
216
|
+
offsets += len(num_samples.tobytes()) + len(offsets.tobytes()) + len(self.config_data)
|
217
|
+
sample_data = b''.join(self.new_samples)
|
218
|
+
return num_samples.tobytes() + offsets.tobytes() + self.config_data + sample_data
|
219
|
+
|
220
|
+
def flush_shard(self) -> None:
|
221
|
+
raw_data_basename, zip_data_basename = self._name_next_shard()
|
222
|
+
raw_data = self.encode_joint_shard()
|
223
|
+
raw_data_info, zip_data_info = self._process_file(raw_data, raw_data_basename,
|
224
|
+
zip_data_basename)
|
225
|
+
obj = {
|
226
|
+
'samples': len(self.new_samples),
|
227
|
+
'raw_data': raw_data_info,
|
228
|
+
'zip_data': zip_data_info
|
229
|
+
}
|
230
|
+
obj.update(self.get_config())
|
231
|
+
self.shards.append(obj)
|
232
|
+
|
233
|
+
def _reset_cache(self) -> None:
|
234
|
+
self.new_samples = []
|
235
|
+
self.new_shard_size = self.extra_bytes_per_shard
|
236
|
+
|
237
|
+
def _name_next_shard(self, extension: Optional[str] = None) -> tuple[str, Optional[str]]:
|
238
|
+
shard = len(self.shards)
|
239
|
+
parts = ['shard', f'{shard:05}', self.format]
|
240
|
+
if extension:
|
241
|
+
parts.append(extension)
|
242
|
+
raw_basename = '.'.join(parts)
|
243
|
+
if self.compression:
|
244
|
+
ext = get_compression_extension(self.compression)
|
245
|
+
parts.append(ext)
|
246
|
+
zip_basename = '.'.join(parts)
|
247
|
+
else:
|
248
|
+
zip_basename = None
|
249
|
+
return raw_basename, zip_basename
|
250
|
+
|
251
|
+
def _hash(self, data: bytes, basename: str) -> dict[str, Any]:
|
252
|
+
hashes = {}
|
253
|
+
for algo in self.hashes:
|
254
|
+
hashes[algo] = get_hash(algo, data)
|
255
|
+
return {'basename': basename, 'bytes': len(data), 'hashes': hashes}
|
256
|
+
|
257
|
+
def _process_file(self, raw_data: bytes, raw_basename: str,
|
258
|
+
zip_basename: Optional[str]) -> tuple[dict, Optional[dict]]:
|
259
|
+
raw_info = self._hash(raw_data, raw_basename)
|
260
|
+
if zip_basename:
|
261
|
+
zip_data = compress(self.compression, raw_data)
|
262
|
+
zip_info = self._hash(zip_data, zip_basename)
|
263
|
+
data = zip_data
|
264
|
+
basename = zip_basename
|
265
|
+
else:
|
266
|
+
zip_info = None
|
267
|
+
data = raw_data
|
268
|
+
basename = raw_basename
|
269
|
+
filename = os.path.join(self.local, basename)
|
270
|
+
with open(filename, 'wb') as out:
|
271
|
+
out.write(data)
|
272
|
+
return raw_info, zip_info
|
273
|
+
|
274
|
+
def get_config(self) -> dict[str, Any]:
|
275
|
+
return {
|
276
|
+
'version': 2,
|
277
|
+
'format': self.format,
|
278
|
+
'compression': self.compression if self.sample_compression is None else f"sample::{self.sample_compression}",
|
279
|
+
'hashes': self.hashes,
|
280
|
+
'size_limit': self.size_limit,
|
281
|
+
'column_names': self.column_names,
|
282
|
+
'column_encodings': self.column_encodings,
|
283
|
+
'column_sizes': self.column_sizes,
|
284
|
+
}
|
285
|
+
|
286
|
+
def write(self, sample: dict[str, Any]) -> None:
|
287
|
+
new_sample = self.encode_sample(sample)
|
288
|
+
new_sample_size = len(new_sample) + self.extra_bytes_per_sample
|
289
|
+
if self.size_limit and self.size_limit < self.new_shard_size + new_sample_size:
|
290
|
+
self.flush_shard()
|
291
|
+
self._reset_cache()
|
292
|
+
self.new_samples.append(new_sample)
|
293
|
+
self.new_shard_size += new_sample_size
|
294
|
+
|
295
|
+
def _write_index(self) -> None:
|
296
|
+
if self.new_samples:
|
297
|
+
raise RuntimeError('Internal error: not all samples have been written.')
|
298
|
+
basename = get_index_basename()
|
299
|
+
filename = os.path.join(self.local, basename)
|
300
|
+
obj = {
|
301
|
+
'version': 2,
|
302
|
+
'shards': self.shards,
|
303
|
+
}
|
304
|
+
with open(filename, 'w') as out:
|
305
|
+
json.dump(obj, out, sort_keys=True)
|
306
|
+
|
307
|
+
def finish(self) -> None:
|
308
|
+
if self.new_samples:
|
309
|
+
self.flush_shard()
|
310
|
+
self._reset_cache()
|
311
|
+
self._write_index()
|
312
|
+
|
313
|
+
def __enter__(self) -> Self:
|
314
|
+
return self
|
315
|
+
|
316
|
+
def __exit__(self, exc_type, exc, traceback):
|
317
|
+
self.finish()
|
@@ -36,60 +36,6 @@ class _SnappyWriteWrapper(io.RawIOBase):
|
|
36
36
|
def writable(self):
|
37
37
|
return True
|
38
38
|
|
39
|
-
|
40
|
-
# class _SnappyReadWrapper(io.RawIOBase):
|
41
|
-
# def __init__(self, fileobj):
|
42
|
-
# self.fileobj = fileobj
|
43
|
-
# self.buffer = io.BytesIO()
|
44
|
-
# self.eof = False
|
45
|
-
|
46
|
-
# def _fill_buffer_if_needed(self, min_bytes):
|
47
|
-
# self.buffer.seek(0, io.SEEK_END)
|
48
|
-
# while not self.eof and self.buffer.tell() < min_bytes:
|
49
|
-
# length_bytes = self.fileobj.read(4)
|
50
|
-
# if not length_bytes:
|
51
|
-
# self.eof = True
|
52
|
-
# break
|
53
|
-
# if len(length_bytes) < 4:
|
54
|
-
# self.eof = True # mark as EOF even if last chunk is malformed
|
55
|
-
# break
|
56
|
-
|
57
|
-
# try:
|
58
|
-
# length = struct.unpack(">I", length_bytes)[0]
|
59
|
-
# compressed = self.fileobj.read(length)
|
60
|
-
# if len(compressed) < length:
|
61
|
-
# self.eof = True
|
62
|
-
# break
|
63
|
-
|
64
|
-
# decompressed = snappy.decompress(compressed)
|
65
|
-
# self.buffer.write(decompressed)
|
66
|
-
# except Exception:
|
67
|
-
# self.eof = True
|
68
|
-
# break
|
69
|
-
|
70
|
-
# self.buffer.seek(0)
|
71
|
-
|
72
|
-
# def read(self, size=-1):
|
73
|
-
# if size == -1:
|
74
|
-
# while not self.eof:
|
75
|
-
# self._fill_buffer_if_needed(_CHUNK_SIZE)
|
76
|
-
# result = self.buffer.read()
|
77
|
-
# self.buffer = io.BytesIO()
|
78
|
-
# return result
|
79
|
-
|
80
|
-
# self._fill_buffer_if_needed(size)
|
81
|
-
# data = self.buffer.read(size)
|
82
|
-
# rest = self.buffer.read()
|
83
|
-
# self.buffer = io.BytesIO()
|
84
|
-
# self.buffer.write(rest)
|
85
|
-
# return data
|
86
|
-
|
87
|
-
# def readable(self):
|
88
|
-
# return True
|
89
|
-
|
90
|
-
# def close(self):
|
91
|
-
# self.fileobj.close()
|
92
|
-
|
93
39
|
class _SnappyReadWrapper(io.RawIOBase):
|
94
40
|
def __init__(self, fileobj):
|
95
41
|
self.fileobj = fileobj
|
@@ -6,11 +6,11 @@ import pyarrow as pa
|
|
6
6
|
import pyarrow.parquet as pq
|
7
7
|
import os
|
8
8
|
import shutil
|
9
|
-
from streaming import MDSWriter, StreamingDataset
|
9
|
+
from streaming import StreamingDataset
|
10
10
|
from tqdm import tqdm
|
11
11
|
|
12
12
|
from .compression import determine_compression, open_compression, pigz_compress
|
13
|
-
from .mds import MDSBulkReader
|
13
|
+
from .mds import MDSBulkReader, MDSWriter
|
14
14
|
from .pigz import pigz_open
|
15
15
|
|
16
16
|
__all__ = [
|
@@ -1,95 +0,0 @@
|
|
1
|
-
import gzip
|
2
|
-
import json
|
3
|
-
from mltiming import timing
|
4
|
-
import numpy as np
|
5
|
-
import os
|
6
|
-
import snappy
|
7
|
-
from streaming.base.format.mds.encodings import mds_decode
|
8
|
-
from typing import Any, Optional, Generator
|
9
|
-
|
10
|
-
from .options import MDS_COMPRESSIONS
|
11
|
-
from .utils import open_compression
|
12
|
-
|
13
|
-
class MDSBulkReader:
|
14
|
-
def __init__(
|
15
|
-
self,
|
16
|
-
dirnames: list[str],
|
17
|
-
split: Optional[str],
|
18
|
-
) -> None:
|
19
|
-
self.shards = []
|
20
|
-
self.samples = 0
|
21
|
-
for dirname in dirnames:
|
22
|
-
if split is not None:
|
23
|
-
dirname = os.path.join(dirname, split)
|
24
|
-
index = json.load(open(os.path.join(dirname, "index.json"), 'rt'))
|
25
|
-
for shard in index["shards"]:
|
26
|
-
basename = shard['raw_data']['basename'] if shard['zip_data'] is None else shard['zip_data']['basename']
|
27
|
-
filename = os.path.join(dirname, basename)
|
28
|
-
self.shards.append({
|
29
|
-
"filename": filename,
|
30
|
-
"compression": shard['compression'],
|
31
|
-
})
|
32
|
-
self.samples += shard['samples']
|
33
|
-
|
34
|
-
def __len__(self) -> int:
|
35
|
-
return self.samples
|
36
|
-
|
37
|
-
def __iter__(self) -> Generator[dict[str, Any], None, None]:
|
38
|
-
for shard in self.shards:
|
39
|
-
with MDSShardReader(**shard) as reader:
|
40
|
-
for sample in reader:
|
41
|
-
yield sample
|
42
|
-
|
43
|
-
class MDSShardReader:
|
44
|
-
def __init__(
|
45
|
-
self,
|
46
|
-
filename: str,
|
47
|
-
compression: Optional[str],
|
48
|
-
) -> None:
|
49
|
-
self.fp = open_compression(filename, "rb", compression=compression)
|
50
|
-
self.samples = np.frombuffer(self.fp.read(4), np.uint32)[0]
|
51
|
-
self.index = np.frombuffer(self.fp.read((1+self.samples)*4), np.uint32)
|
52
|
-
info = json.loads(self.fp.read(self.index[0]-self.fp.tell()))
|
53
|
-
self.column_encodings = info["column_encodings"]
|
54
|
-
self.column_names = info["column_names"]
|
55
|
-
self.column_sizes = info["column_sizes"]
|
56
|
-
assert self.fp.tell() == self.index[0]
|
57
|
-
|
58
|
-
def decode_sample(self, data: bytes) -> dict[str, Any]:
|
59
|
-
sizes = []
|
60
|
-
idx = 0
|
61
|
-
for key, size in zip(self.column_names, self.column_sizes):
|
62
|
-
if size:
|
63
|
-
sizes.append(size)
|
64
|
-
else:
|
65
|
-
size, = np.frombuffer(data[idx:idx + 4], np.uint32)
|
66
|
-
sizes.append(size)
|
67
|
-
idx += 4
|
68
|
-
sample = {}
|
69
|
-
for key, encoding, size in zip(self.column_names, self.column_encodings, sizes):
|
70
|
-
value = data[idx:idx + size]
|
71
|
-
sample[key] = mds_decode(encoding, value)
|
72
|
-
idx += size
|
73
|
-
return sample
|
74
|
-
|
75
|
-
def get_sample_data(self, idx: int) -> bytes:
|
76
|
-
begin, end = self.index[idx:idx+2]
|
77
|
-
assert self.fp.tell() == begin
|
78
|
-
data = self.fp.read(end - begin)
|
79
|
-
assert self.fp.tell() == end
|
80
|
-
assert data
|
81
|
-
return data
|
82
|
-
|
83
|
-
def get_item(self, idx: int) -> dict[str, Any]:
|
84
|
-
data = self.get_sample_data(idx)
|
85
|
-
return self.decode_sample(data)
|
86
|
-
|
87
|
-
def __iter__(self) -> Generator[dict[str, Any], None, None]:
|
88
|
-
for i in range(self.samples):
|
89
|
-
yield self.get_item(i)
|
90
|
-
|
91
|
-
def __enter__(self) -> "MDSShardReader":
|
92
|
-
return self
|
93
|
-
|
94
|
-
def __exit__(self, exc_type, exc_value, traceback) -> None:
|
95
|
-
self.fp.close()
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|