mldataforge 0.1.6__tar.gz → 0.1.7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mldataforge-0.1.6 → mldataforge-0.1.7}/PKG-INFO +1 -1
- {mldataforge-0.1.6 → mldataforge-0.1.7}/mldataforge/mds.py +135 -4
- {mldataforge-0.1.6 → mldataforge-0.1.7}/mldataforge/utils.py +10 -5
- {mldataforge-0.1.6 → mldataforge-0.1.7}/pyproject.toml +1 -1
- {mldataforge-0.1.6 → mldataforge-0.1.7}/.gitignore +0 -0
- {mldataforge-0.1.6 → mldataforge-0.1.7}/LICENSE +0 -0
- {mldataforge-0.1.6 → mldataforge-0.1.7}/README.md +0 -0
- {mldataforge-0.1.6 → mldataforge-0.1.7}/mldataforge/__main__.py +0 -0
- {mldataforge-0.1.6 → mldataforge-0.1.7}/mldataforge/brotli.py +0 -0
- {mldataforge-0.1.6 → mldataforge-0.1.7}/mldataforge/commands/__init__.py +0 -0
- {mldataforge-0.1.6 → mldataforge-0.1.7}/mldataforge/commands/convert/__init__.py +0 -0
- {mldataforge-0.1.6 → mldataforge-0.1.7}/mldataforge/commands/convert/jsonl.py +0 -0
- {mldataforge-0.1.6 → mldataforge-0.1.7}/mldataforge/commands/convert/mds.py +0 -0
- {mldataforge-0.1.6 → mldataforge-0.1.7}/mldataforge/commands/convert/parquet.py +0 -0
- {mldataforge-0.1.6 → mldataforge-0.1.7}/mldataforge/commands/join.py +0 -0
- {mldataforge-0.1.6 → mldataforge-0.1.7}/mldataforge/commands/split.py +0 -0
- {mldataforge-0.1.6 → mldataforge-0.1.7}/mldataforge/compression.py +0 -0
- {mldataforge-0.1.6 → mldataforge-0.1.7}/mldataforge/options.py +0 -0
- {mldataforge-0.1.6 → mldataforge-0.1.7}/mldataforge/pigz.py +0 -0
- {mldataforge-0.1.6 → mldataforge-0.1.7}/mldataforge/snappy.py +0 -0
{mldataforge-0.1.6 → mldataforge-0.1.7}/PKG-INFO

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: mldataforge
-Version: 0.1.6
+Version: 0.1.7
 Summary: swiss army knife of scripts for transforming and processing datasets for machine learning.
 Project-URL: Homepage, https://github.com/schneiderkamplab/mldataforge
 Project-URL: Bug Tracker, https://github.com/schneiderkamplab/mldataforge/issues
```
{mldataforge-0.1.6 → mldataforge-0.1.7}/mldataforge/mds.py

```diff
@@ -1,16 +1,26 @@
+from copy import deepcopy
 import json
 import numpy as np
 import os
 import shutil
 from streaming.base.compression import compress, decompress, get_compression_extension, is_compression
+from streaming.base.format import _readers
+from streaming.base.format.base.reader import FileInfo, JointReader
 from streaming.base.format.index import get_index_basename
-from streaming.base.format.mds.encodings import mds_decode, mds_encode, is_mds_encoding, get_mds_encodings, get_mds_encoded_size
+from streaming.base.format.mds.encodings import mds_decode, mds_encode, is_mds_encoding, is_mds_encoding_safe, get_mds_encodings, get_mds_encoded_size
 from streaming.base.hashing import get_hash, is_hash
 from streaming.base.util import bytes_to_int
 from typing import Any, Optional, Generator, Self, Union

 from .utils import open_compression

+__all__ = [
+    "MDSBulkReader",
+    "MDSBulkShardReader",
+    "MDSReader",
+    "MDSWriter",
+]
+
 class MDSBulkReader:
     def __init__(
         self,
```
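Several of the new imports (`FileInfo`, `JointReader`, and the private `_readers` registry) exist to support the `MDSReader` class added at the end of this file. With `__all__` now defined, the module's public surface is explicit; a minimal import sketch, assuming mldataforge 0.1.7 is installed:

```python
# The four names exported by mldataforge.mds as of 0.1.7.
from mldataforge.mds import MDSBulkReader, MDSBulkShardReader, MDSReader, MDSWriter
```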
```diff
@@ -37,11 +47,11 @@ class MDSBulkReader:

     def __iter__(self) -> Generator[dict[str, Any], None, None]:
         for shard in self.shards:
-            with MDSShardReader(**shard) as reader:
+            with MDSBulkShardReader(**shard) as reader:
                 for sample in reader:
                     yield sample

-class MDSShardReader:
+class MDSBulkShardReader:
     def __init__(
         self,
         filename: str,
```
```diff
@@ -94,7 +104,7 @@ class MDSShardReader:
         for i in range(self.samples):
             yield self.get_item(i)

-    def __enter__(self) -> "MDSShardReader":
+    def __enter__(self) -> "MDSBulkShardReader":
         return self

     def __exit__(self, exc_type, exc_value, traceback) -> None:
```
```diff
@@ -315,3 +325,124 @@ class MDSWriter:

     def __exit__(self, exc_type, exc, traceback):
         self.finish()
+
+class MDSReader(JointReader):
+
+    def __init__(
+        self,
+        dirname: str,
+        split: Optional[str],
+        column_encodings: list[str],
+        column_names: list[str],
+        column_sizes: list[Optional[int]],
+        compression: Optional[str],
+        hashes: list[str],
+        raw_data: FileInfo,
+        samples: int,
+        size_limit: Optional[Union[int, str]],
+        zip_data: Optional[FileInfo],
+    ) -> None:
+        self.sample_compression = None
+        if compression and compression.startswith("sample::"):
+            compression, self.sample_compression = None, compression.removeprefix("sample::")
+        super().__init__(dirname, split, compression, hashes, raw_data, samples, size_limit,
+                         zip_data)
+        self.column_encodings = column_encodings
+        self.column_names = column_names
+        self.column_sizes = column_sizes
+
+    @classmethod
+    def from_json(cls, dirname: str, split: Optional[str], obj: dict[str, Any]) -> Self:
+        """Initialize from JSON object.
+
+        Args:
+            dirname (str): Local directory containing shards.
+            split (str, optional): Which dataset split to use, if any.
+            obj (Dict[str, Any]): JSON object to load.
+
+        Returns:
+            Self: Loaded MDSReader.
+        """
+        args = deepcopy(obj)
+        args_version = args['version']
+        if args_version != 2:
+            raise ValueError(
+                f'Unsupported streaming data version: {args_version}. Expected version 2.')
+        del args['version']
+        args_format = args['format']
+        if args_format != 'mds':
+            raise ValueError(f'Unsupported data format: {args_format}. Expected to be `mds`.')
+        del args['format']
+        args['dirname'] = dirname
+        args['split'] = split
+        for key in ['raw_data', 'zip_data']:
+            arg = args[key]
+            args[key] = FileInfo(**arg) if arg else None
+        return cls(**args)
+
+    def validate(self, allow_unsafe_types: bool) -> None:
+        """Check whether this shard is acceptable to be part of some Stream.
+
+        Args:
+            allow_unsafe_types (bool): If a shard contains Pickle, which allows arbitrary code
+                execution during deserialization, whether to keep going if ``True`` or raise an
+                error if ``False``.
+        """
+        if not allow_unsafe_types:
+            for column_id, encoding in enumerate(self.column_encodings):
+                if not is_mds_encoding_safe(encoding):
+                    name = self.column_names[column_id]
+                    raise ValueError(f'Column {name} contains an unsafe type: {encoding}. To ' +
+                                     f'proceed anyway, set ``allow_unsafe_types=True``.')
+
+    def decode_sample(self, data: bytes) -> dict[str, Any]:
+        """Decode a sample dict from bytes.
+
+        Args:
+            data (bytes): The sample encoded as bytes.
+
+        Returns:
+            Dict[str, Any]: Sample dict.
+        """
+        sizes = []
+        idx = 0
+        for key, size in zip(self.column_names, self.column_sizes):
+            if size:
+                sizes.append(size)
+            else:
+                size, = np.frombuffer(data[idx:idx + 4], np.uint32)
+                sizes.append(size)
+                idx += 4
+        sample = {}
+        for key, encoding, size in zip(self.column_names, self.column_encodings, sizes):
+            value = data[idx:idx + size]
+            sample[key] = mds_decode(encoding, value)
+            idx += size
+        return sample
+
+    def get_sample_data(self, idx: int) -> bytes:
+        """Get the raw sample data at the index.
+
+        Args:
+            idx (int): Sample index.
+
+        Returns:
+            bytes: Sample data.
+        """
+        filename = os.path.join(self.dirname, self.split, self.raw_data.basename)
+        offset = (1 + idx) * 4
+        with open(filename, 'rb', 0) as fp:
+            fp.seek(offset)
+            pair = fp.read(8)
+            begin, end = np.frombuffer(pair, np.uint32)
+            fp.seek(begin)
+            data = fp.read(end - begin)
+        if not data:
+            raise IndexError(
+                f'Relative sample index {idx} is not present in the {self.raw_data.basename} file.'
+            )
+        if self.sample_compression:
+            data = decompress(self.sample_compression, data)
+        return data
+
+_readers["mds"] = MDSReader
```
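Two details of the new reader are worth calling out. First, a compression value of the form `sample::<algo>` now selects per-sample compression: the prefix is stripped before the superclass sees it, and `get_sample_data` decompresses each record individually. Second, `decode_sample` implies a simple wire format for a sample: uint32 length prefixes for the variable-size columns come first, followed by every column's payload in declaration order. Below is a self-contained sketch of that layout; it does not use mldataforge itself, and the column names and values are invented for illustration.

```python
import numpy as np

# Two columns: "id" has a fixed encoded size (8 bytes); "text" is variable-size,
# so its byte length is stored as a uint32 prefix ahead of the payloads.
column_names = ["id", "text"]
column_sizes = [8, None]  # None marks the variable-size column

id_bytes = np.int64(42).tobytes()         # 8 payload bytes
text_bytes = "hello".encode("utf-8")      # 5 payload bytes
sample = (
    np.uint32(len(text_bytes)).tobytes()  # size prefix for the variable column
    + id_bytes                            # payloads follow in column order
    + text_bytes
)

# Mirror decode_sample: read the uint32 sizes of variable-size columns first...
sizes, idx = [], 0
for size in column_sizes:
    if size:
        sizes.append(size)
    else:
        size, = np.frombuffer(sample[idx:idx + 4], np.uint32)
        sizes.append(int(size))
        idx += 4

# ...then slice each column's payload in order.
values = {}
for name, size in zip(column_names, sizes):
    values[name] = sample[idx:idx + size]
    idx += size

assert np.frombuffer(values["id"], np.int64)[0] == 42
assert values["text"].decode("utf-8") == "hello"
```

The final added line, `_readers["mds"] = MDSReader`, registers this class in the streaming library's format-to-reader table, so code that resolves readers through that registry should pick up the `sample::` support without changes.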
{mldataforge-0.1.6 → mldataforge-0.1.7}/mldataforge/utils.py

```diff
@@ -23,6 +23,11 @@ __all__ = [
     "save_parquet",
 ]

+_NO_PROGESS = False
+def set_progress(value):
+    global _NO_PROGESS
+    _NO_PROGESS = value
+
 def _batch_iterable(iterable, batch_size):
     batch = []
     for item in iterable:
```
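This module-level flag gates every `tqdm` call in utils.py; the remaining hunks simply thread `disable=_NO_PROGESS` through them (the misspelling `_NO_PROGESS` is the identifier as released). Note the inverted semantics: passing `True` turns progress reporting off. A minimal usage sketch:

```python
from mldataforge.utils import set_progress

set_progress(True)   # True means "no progress": suppress all tqdm bars
set_progress(False)  # back to the default, with progress bars shown
```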
```diff
@@ -73,7 +78,7 @@ def _infer_mds_encoding(value):
     return 'pkl'

 def _streaming_jsonl(jsonl_files, compressions):
-    for jsonl_file, compression in tqdm(zip(jsonl_files, compressions), desc="Loading JSONL files", unit="file"):
+    for jsonl_file, compression in tqdm(zip(jsonl_files, compressions), desc="Loading JSONL files", unit="file", disable=_NO_PROGESS):
         for line in open_compression(jsonl_file, mode="rt", compression=compression):
             yield json.loads(line)
```
```diff
@@ -109,7 +114,7 @@ def load_mds_directories(mds_directories, split='.', batch_size=2**16, bulk=True
 def save_jsonl(iterable, output_file, compression=None, processes=64, size_hint=None, overwrite=True, yes=True):
     f = None
     part = 0
-    for item in tqdm(iterable, desc="Writing to JSONL", unit="sample"):
+    for item in tqdm(iterable, desc="Writing to JSONL", unit="sample", disable=_NO_PROGESS):
         if f is None:
             part_file = output_file.format(part=part)
             check_arguments(part_file, overwrite, yes)
```
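The signature visible in this hunk is enough to drive `save_jsonl` directly. A hedged sketch with invented sample data; that gzip is inferred from the `.gz` suffix when `compression` is left at `None` is an assumption, not something this diff shows:

```python
from mldataforge.utils import save_jsonl, set_progress

set_progress(True)  # write quietly, using the new flag
samples = ({"id": i, "text": f"example {i}"} for i in range(1_000))
# {part} is substituted via output_file.format(part=part) as parts roll over.
save_jsonl(samples, "out-{part}.jsonl.gz")
```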
```diff
@@ -127,7 +132,7 @@ def save_mds(it, output_dir, processes=64, compression=None, buf_size=2**24, pig
     writer = None
     part = 0
     files = []
-    for sample in tqdm(it, desc="Writing to MDS", unit="sample"):
+    for sample in tqdm(it, desc="Writing to MDS", unit="sample", disable=_NO_PROGESS):
         if writer is None:
             part_dir = output_dir.format(part=part)
             check_arguments(part_dir, overwrite, yes)
```
```diff
@@ -151,7 +156,7 @@ def save_mds(it, output_dir, processes=64, compression=None, buf_size=2**24, pig
     name2info = {shard["raw_data"]["basename"]: shard for shard in index["shards"]}
     file_names = [file for file in os.listdir(output_dir) if file.endswith(".mds")]
     assert set(file_names) == set(name2info.keys())
-    for file_name in tqdm(file_names, desc="Compressing with pigz", unit="file"):
+    for file_name in tqdm(file_names, desc="Compressing with pigz", unit="file", disable=_NO_PROGESS):
         compressed_file_name = file_name + ".gz"
         file_path = os.path.join(output_dir, file_name)
         compressed_file_path = os.path.join(output_dir, compressed_file_name)
```
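A matching sketch for `save_mds`, using only parameters visible in the hunk headers; the output directory name and data are invented:

```python
from mldataforge.utils import save_mds

samples = ({"id": i, "text": f"example {i}"} for i in range(1_000))
save_mds(samples, "mds-out-{part}", compression=None)  # default: uncompressed shards
```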
```diff
@@ -169,7 +174,7 @@ def save_parquet(it, output_file, compression=None, batch_size=2**16, size_hint=
     compression = determine_compression("parquet", output_file, compression)
     writer = None
     part = 0
-    it = tqdm(it, desc="Writing to Parquet", unit="sample")
+    it = tqdm(it, desc="Writing to Parquet", unit="sample", disable=_NO_PROGESS)
     for batch in _batch_iterable(it, batch_size):
         table = pa.Table.from_pylist(batch)
         if writer is None:
```
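And the same pattern for `save_parquet`; the first hunk line shows the codec being derived from `output_file` via `determine_compression` when `compression` is `None`:

```python
from mldataforge.utils import save_parquet

samples = ({"id": i, "text": f"example {i}"} for i in range(1_000))
save_parquet(samples, "out-{part}.parquet", batch_size=2**16)
```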