mldataforge 0.1.6__tar.gz → 0.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mldataforge-0.1.6 → mldataforge-0.2.0}/PKG-INFO +1 -1
- {mldataforge-0.1.6 → mldataforge-0.2.0}/mldataforge/commands/convert/jsonl.py +6 -2
- {mldataforge-0.1.6 → mldataforge-0.2.0}/mldataforge/commands/convert/mds.py +6 -2
- {mldataforge-0.1.6 → mldataforge-0.2.0}/mldataforge/commands/convert/parquet.py +6 -2
- {mldataforge-0.1.6 → mldataforge-0.2.0}/mldataforge/commands/join.py +8 -3
- {mldataforge-0.1.6 → mldataforge-0.2.0}/mldataforge/commands/split.py +15 -3
- {mldataforge-0.1.6 → mldataforge-0.2.0}/mldataforge/mds.py +135 -4
- {mldataforge-0.1.6 → mldataforge-0.2.0}/mldataforge/options.py +12 -0
- mldataforge-0.2.0/mldataforge/trafos.py +111 -0
- {mldataforge-0.1.6 → mldataforge-0.2.0}/mldataforge/utils.py +20 -8
- {mldataforge-0.1.6 → mldataforge-0.2.0}/pyproject.toml +1 -1
- {mldataforge-0.1.6 → mldataforge-0.2.0}/.gitignore +0 -0
- {mldataforge-0.1.6 → mldataforge-0.2.0}/LICENSE +0 -0
- {mldataforge-0.1.6 → mldataforge-0.2.0}/README.md +0 -0
- {mldataforge-0.1.6 → mldataforge-0.2.0}/mldataforge/__main__.py +0 -0
- {mldataforge-0.1.6 → mldataforge-0.2.0}/mldataforge/brotli.py +0 -0
- {mldataforge-0.1.6 → mldataforge-0.2.0}/mldataforge/commands/__init__.py +0 -0
- {mldataforge-0.1.6 → mldataforge-0.2.0}/mldataforge/commands/convert/__init__.py +0 -0
- {mldataforge-0.1.6 → mldataforge-0.2.0}/mldataforge/compression.py +0 -0
- {mldataforge-0.1.6 → mldataforge-0.2.0}/mldataforge/pigz.py +0 -0
- {mldataforge-0.1.6 → mldataforge-0.2.0}/mldataforge/snappy.py +0 -0
{mldataforge-0.1.6 → mldataforge-0.2.0}/PKG-INFO

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: mldataforge
-Version: 0.1.6
+Version: 0.2.0
 Summary: swiss army knife of scripts for transforming and processing datasets for machine learning.
 Project-URL: Homepage, https://github.com/schneiderkamplab/mldataforge
 Project-URL: Bug Tracker, https://github.com/schneiderkamplab/mldataforge/issues
```
{mldataforge-0.1.6 → mldataforge-0.2.0}/mldataforge/commands/convert/jsonl.py

```diff
@@ -21,9 +21,10 @@ def jsonl():
 @buf_size_option()
 @shard_size_option()
 @no_pigz_option()
+@trafo_option()
 def mds(**kwargs):
     jsonl_to_mds(**kwargs)
-def jsonl_to_mds(output_dir, jsonl_files, compression, processes, overwrite, yes, buf_size, shard_size, no_pigz):
+def jsonl_to_mds(output_dir, jsonl_files, compression, processes, overwrite, yes, buf_size, shard_size, no_pigz, trafo):
     check_arguments(output_dir, overwrite, yes, jsonl_files)
     save_mds(
         load_jsonl_files(jsonl_files),
@@ -33,6 +34,7 @@ def jsonl_to_mds(output_dir, jsonl_files, compression, processes, overwrite, yes
         buf_size=buf_size,
         pigz=use_pigz(compression, no_pigz),
         shard_size=shard_size,
+        trafo=trafo,
     )
 
 @jsonl.command()
@@ -42,13 +44,15 @@ def jsonl_to_mds(output_dir, jsonl_files, compression, processes, overwrite, yes
 @overwrite_option()
 @yes_option()
 @batch_size_option()
+@trafo_option()
 def parquet(**kwargs):
     jsonl_to_parquet(**kwargs)
-def jsonl_to_parquet(output_file, jsonl_files, compression, overwrite, yes, batch_size):
+def jsonl_to_parquet(output_file, jsonl_files, compression, overwrite, yes, batch_size, trafo):
     check_arguments(output_file, overwrite, yes, jsonl_files)
     save_parquet(
         load_jsonl_files(jsonl_files),
         output_file,
         compression=compression,
         batch_size=batch_size,
+        trafo=trafo,
     )
```
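Both conversion commands now accept `--trafo` and pass it through to the writers as a plain string. A minimal sketch of calling the plumbing function directly (module path taken from the file list above; all argument values, including the trafo string, are illustrative assumptions, not examples from this diff):

```python
# Hypothetical direct use of jsonl_to_mds with the new trafo parameter.
from mldataforge.commands.convert.jsonl import jsonl_to_mds

jsonl_to_mds(
    output_dir="out-mds",
    jsonl_files=["data.jsonl.gz"],
    compression=None,          # no compression requested in this sketch
    processes=8,
    overwrite=True,
    yes=True,
    buf_size=2**24,
    shard_size=2**26,
    no_pigz=True,
    trafo="lambda s: {**s, 'text': s['text'].strip()}",  # eval()'d by Trafo downstream
)
```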
{mldataforge-0.1.6 → mldataforge-0.2.0}/mldataforge/commands/convert/mds.py

```diff
@@ -19,15 +19,17 @@ def mds():
 @yes_option()
 @batch_size_option()
 @no_bulk_option()
+@trafo_option()
 def jsonl(**kwargs):
     mds_to_jsonl(**kwargs)
-def mds_to_jsonl(output_file, mds_directories, compression, processes, overwrite, yes, batch_size, no_bulk):
+def mds_to_jsonl(output_file, mds_directories, compression, processes, overwrite, yes, batch_size, no_bulk, trafo):
     check_arguments(output_file, overwrite, yes, mds_directories)
     save_jsonl(
         load_mds_directories(mds_directories, batch_size=batch_size, bulk=not no_bulk),
         output_file,
         compression=compression,
         processes=processes,
+        trafo=trafo,
     )
 
 @mds.command()
@@ -38,13 +40,15 @@ def mds_to_jsonl(output_file, mds_directories, compression, processes, overwrite
 @yes_option()
 @batch_size_option()
 @no_bulk_option()
+@trafo_option()
 def parquet(**kwargs):
     mds_to_parquet(**kwargs)
-def mds_to_parquet(output_file, mds_directories, compression, overwrite, yes, batch_size, no_bulk):
+def mds_to_parquet(output_file, mds_directories, compression, overwrite, yes, batch_size, no_bulk, trafo):
     check_arguments(output_file, overwrite, yes, mds_directories)
     save_parquet(
         load_mds_directories(mds_directories, batch_size=batch_size, bulk=not no_bulk),
         output_file,
         compression=compression,
         batch_size=batch_size,
+        trafo=trafo,
     )
```
{mldataforge-0.1.6 → mldataforge-0.2.0}/mldataforge/commands/convert/parquet.py

```diff
@@ -18,15 +18,17 @@ def parquet():
 @processes_option()
 @overwrite_option()
 @yes_option()
+@trafo_option()
 def jsonl(**kwargs):
     parquet_to_jsonl(**kwargs)
-def parquet_to_jsonl(output_file, parquet_files, compression, processes, overwrite, yes):
+def parquet_to_jsonl(output_file, parquet_files, compression, processes, overwrite, yes, trafo):
     check_arguments(output_file, overwrite, yes, parquet_files)
     save_jsonl(
         load_dataset("parquet", data_files=parquet_files, split="train"),
         output_file,
         compression=compression,
         processes=processes,
+        trafo=trafo,
     )
 
 @parquet.command()
@@ -39,9 +41,10 @@ def parquet_to_jsonl(output_file, parquet_files, compression, processes, overwri
 @buf_size_option()
 @shard_size_option()
 @no_pigz_option()
+@trafo_option()
 def mds(**kwargs):
     parquet_to_mds(**kwargs)
-def parquet_to_mds(output_dir, parquet_files, compression, processes, overwrite, yes, buf_size, shard_size, no_pigz):
+def parquet_to_mds(output_dir, parquet_files, compression, processes, overwrite, yes, buf_size, shard_size, no_pigz, trafo):
     check_arguments(output_dir, overwrite, yes, parquet_files)
     save_mds(
         load_dataset("parquet", data_files=parquet_files, split="train"),
@@ -51,4 +54,5 @@ def parquet_to_mds(output_dir, parquet_files, compression, processes, overwrite,
         buf_size=buf_size,
         pigz=use_pigz(compression, no_pigz=no_pigz),
         shard_size=shard_size,
+        trafo=trafo,
     )
```
{mldataforge-0.1.6 → mldataforge-0.2.0}/mldataforge/commands/join.py

```diff
@@ -18,9 +18,10 @@ def join():
 @processes_option()
 @overwrite_option()
 @yes_option()
+@trafo_option()
 def jsonl(**kwargs):
     join_jsonl(**kwargs)
-def join_jsonl(output_file, jsonl_files, compression, processes, overwrite, yes):
+def join_jsonl(output_file, jsonl_files, compression, processes, overwrite, yes, trafo):
     check_arguments(output_file, overwrite, yes, jsonl_files)
     save_jsonl(
         load_jsonl_files(jsonl_files),
@@ -41,10 +42,11 @@ def join_jsonl(output_file, jsonl_files, compression, processes, overwrite, yes)
 @no_bulk_option()
 @shard_size_option()
 @no_pigz_option()
+@trafo_option()
 def mds(**kwargs):
     print(kwargs)
     join_mds(**kwargs)
-def join_mds(output_dir, mds_directories, compression, processes, overwrite, yes, batch_size, buf_size, no_bulk, shard_size, no_pigz):
+def join_mds(output_dir, mds_directories, compression, processes, overwrite, yes, batch_size, buf_size, no_bulk, shard_size, no_pigz, trafo):
     check_arguments(output_dir, overwrite, yes, mds_directories)
     save_mds(
         load_mds_directories(mds_directories, batch_size=batch_size, bulk=not no_bulk),
@@ -54,6 +56,7 @@ def join_mds(output_dir, mds_directories, compression, processes, overwrite, yes
         buf_size=buf_size,
         shard_size=shard_size,
         pigz=use_pigz(compression, no_pigz),
+        trafo=trafo,
     )
 
 @join.command()
@@ -63,13 +66,15 @@ def join_mds(output_dir, mds_directories, compression, processes, overwrite, yes
 @overwrite_option()
 @yes_option()
 @batch_size_option()
+@trafo_option()
 def parquet(**kwargs):
     join_parquet(**kwargs)
-def join_parquet(output_file, parquet_files, compression, overwrite, yes, batch_size):
+def join_parquet(output_file, parquet_files, compression, overwrite, yes, batch_size, trafo):
     check_arguments(output_file, overwrite, yes, parquet_files)
     save_parquet(
         load_dataset("parquet", data_files=parquet_files, split="train"),
         output_file,
         compression=compression,
         batch_size=batch_size,
+        trafo=trafo,
     )
```
{mldataforge-0.1.6 → mldataforge-0.2.0}/mldataforge/commands/split.py

```diff
@@ -20,7 +20,10 @@ def split():
 @processes_option()
 @overwrite_option()
 @yes_option()
-def jsonl(jsonl_files, prefix, output_dir, size_hint, compression, processes, overwrite, yes):
+@trafo_option()
+def jsonl(*args, **kwargs):
+    split_jsonl(*args, **kwargs)
+def split_jsonl(jsonl_files, prefix, output_dir, size_hint, compression, processes, overwrite, yes, trafo):
     save_jsonl(
         load_jsonl_files(jsonl_files),
         output_file=f"{output_dir}/{prefix}{{part:04d}}.jsonl{extension_compression(compression, jsonl_files[0])}",
@@ -29,6 +32,7 @@ def jsonl(jsonl_files, prefix, output_dir, size_hint, compression, processes, ov
         size_hint=size_hint,
         overwrite=overwrite,
         yes=yes,
+        trafo=trafo,
     )
 
 @split.command()
@@ -45,7 +49,10 @@ def jsonl(jsonl_files, prefix, output_dir, size_hint, compression, processes, ov
 @no_bulk_option()
 @shard_size_option()
 @no_pigz_option()
-def mds(mds_directories, prefix, output_dir, size_hint, compression, processes, overwrite, yes, buf_size, batch_size, no_bulk, shard_size, no_pigz):
+@trafo_option()
+def mds(*args, **kwargs):
+    split_mds(*args, **kwargs)
+def split_mds(mds_directories, prefix, output_dir, size_hint, compression, processes, overwrite, yes, buf_size, batch_size, no_bulk, shard_size, no_pigz, trafo):
     save_mds(
         load_mds_directories(mds_directories, batch_size=batch_size, bulk=not no_bulk),
         output_dir=f"{output_dir}/{prefix}{{part:04d}}",
@@ -57,6 +64,7 @@ def mds(mds_directories, prefix, output_dir, size_hint, compression, processes,
         size_hint=size_hint,
         overwrite=overwrite,
         yes=yes,
+        trafo=trafo,
     )
 
 @split.command()
@@ -68,7 +76,10 @@ def mds(mds_directories, prefix, output_dir, size_hint, compression, processes,
 @overwrite_option()
 @yes_option()
 @batch_size_option()
-def parquet(parquet_files, prefix, output_dir, size_hint, compression, overwrite, yes, batch_size):
+@trafo_option()
+def parquet(*args, **kwargs):
+    split_parquet(*args, **kwargs)
+def split_parquet(parquet_files, prefix, output_dir, size_hint, compression, overwrite, yes, batch_size, trafo):
     save_parquet(
         load_dataset("parquet", data_files=parquet_files, split="train"),
         output_file=f"{output_dir}/{prefix}{{part:04d}}.parquet",
@@ -77,4 +88,5 @@ def parquet(parquet_files, prefix, output_dir, size_hint, compression, overwrite
         size_hint=size_hint,
         overwrite=overwrite,
         yes=yes,
+        trafo=trafo,
     )
```
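The split commands now mirror the convert/join structure: a thin click wrapper (`jsonl`, `mds`, `parquet`) delegates to an importable plumbing function (`split_jsonl`, `split_mds`, `split_parquet`), and each output path is a template with a `{part:04d}` placeholder that the writers fill in per output part. A small sketch of that templating (directory and prefix are made up for illustration):

```python
# How the part-numbered output template used by the split commands expands.
output_dir, prefix = "splits", "train-"
template = f"{output_dir}/{prefix}{{part:04d}}.jsonl"
print(template.format(part=0))   # splits/train-0000.jsonl
print(template.format(part=13))  # splits/train-0013.jsonl
```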
{mldataforge-0.1.6 → mldataforge-0.2.0}/mldataforge/mds.py

```diff
@@ -1,16 +1,26 @@
+from copy import deepcopy
 import json
 import numpy as np
 import os
 import shutil
 from streaming.base.compression import compress, decompress, get_compression_extension, is_compression
+from streaming.base.format import _readers
+from streaming.base.format.base.reader import FileInfo, JointReader
 from streaming.base.format.index import get_index_basename
-from streaming.base.format.mds.encodings import mds_decode, mds_encode, is_mds_encoding, get_mds_encodings, get_mds_encoded_size
+from streaming.base.format.mds.encodings import mds_decode, mds_encode, is_mds_encoding, is_mds_encoding_safe, get_mds_encodings, get_mds_encoded_size
 from streaming.base.hashing import get_hash, is_hash
 from streaming.base.util import bytes_to_int
 from typing import Any, Optional, Generator, Self, Union
 
 from .utils import open_compression
 
+__all__ = [
+    "MDSBulkReader",
+    "MDSBulkShardReader",
+    "MDSReader",
+    "MDSWriter",
+]
+
 class MDSBulkReader:
     def __init__(
         self,
@@ -37,11 +47,11 @@ class MDSBulkReader:
 
     def __iter__(self) -> Generator[dict[str, Any], None, None]:
         for shard in self.shards:
-            with MDSShardReader(**shard) as reader:
+            with MDSBulkShardReader(**shard) as reader:
                 for sample in reader:
                     yield sample
 
-class MDSShardReader:
+class MDSBulkShardReader:
     def __init__(
         self,
         filename: str,
@@ -94,7 +104,7 @@ class MDSShardReader:
         for i in range(self.samples):
             yield self.get_item(i)
 
-    def __enter__(self) -> "MDSShardReader":
+    def __enter__(self) -> "MDSBulkShardReader":
         return self
 
     def __exit__(self, exc_type, exc_value, traceback) -> None:
@@ -315,3 +325,124 @@ class MDSWriter:
 
     def __exit__(self, exc_type, exc, traceback):
         self.finish()
+
+class MDSReader(JointReader):
+
+    def __init__(
+        self,
+        dirname: str,
+        split: Optional[str],
+        column_encodings: list[str],
+        column_names: list[str],
+        column_sizes: list[Optional[int]],
+        compression: Optional[str],
+        hashes: list[str],
+        raw_data: FileInfo,
+        samples: int,
+        size_limit: Optional[Union[int, str]],
+        zip_data: Optional[FileInfo],
+    ) -> None:
+        self.sample_compression = None
+        if compression and compression.startswith("sample::"):
+            compression, self.sample_compression = None, compression.removeprefix("sample::")
+        super().__init__(dirname, split, compression, hashes, raw_data, samples, size_limit,
+                         zip_data)
+        self.column_encodings = column_encodings
+        self.column_names = column_names
+        self.column_sizes = column_sizes
+
+    @classmethod
+    def from_json(cls, dirname: str, split: Optional[str], obj: dict[str, Any]) -> Self:
+        """Initialize from JSON object.
+
+        Args:
+            dirname (str): Local directory containing shards.
+            split (str, optional): Which dataset split to use, if any.
+            obj (Dict[str, Any]): JSON object to load.
+
+        Returns:
+            Self: Loaded MDSReader.
+        """
+        args = deepcopy(obj)
+        args_version = args['version']
+        if args_version != 2:
+            raise ValueError(
+                f'Unsupported streaming data version: {args_version}. Expected version 2.')
+        del args['version']
+        args_format = args['format']
+        if args_format != 'mds':
+            raise ValueError(f'Unsupported data format: {args_format}. Expected to be `mds`.')
+        del args['format']
+        args['dirname'] = dirname
+        args['split'] = split
+        for key in ['raw_data', 'zip_data']:
+            arg = args[key]
+            args[key] = FileInfo(**arg) if arg else None
+        return cls(**args)
+
+    def validate(self, allow_unsafe_types: bool) -> None:
+        """Check whether this shard is acceptable to be part of some Stream.
+
+        Args:
+            allow_unsafe_types (bool): If a shard contains Pickle, which allows arbitrary code
+                execution during deserialization, whether to keep going if ``True`` or raise an
+                error if ``False``.
+        """
+        if not allow_unsafe_types:
+            for column_id, encoding in enumerate(self.column_encodings):
+                if not is_mds_encoding_safe(encoding):
+                    name = self.column_names[column_id]
+                    raise ValueError(f'Column {name} contains an unsafe type: {encoding}. To ' +
+                                     f'proceed anyway, set ``allow_unsafe_types=True``.')
+
+    def decode_sample(self, data: bytes) -> dict[str, Any]:
+        """Decode a sample dict from bytes.
+
+        Args:
+            data (bytes): The sample encoded as bytes.
+
+        Returns:
+            Dict[str, Any]: Sample dict.
+        """
+        sizes = []
+        idx = 0
+        for key, size in zip(self.column_names, self.column_sizes):
+            if size:
+                sizes.append(size)
+            else:
+                size, = np.frombuffer(data[idx:idx + 4], np.uint32)
+                sizes.append(size)
+                idx += 4
+        sample = {}
+        for key, encoding, size in zip(self.column_names, self.column_encodings, sizes):
+            value = data[idx:idx + size]
+            sample[key] = mds_decode(encoding, value)
+            idx += size
+        return sample
+
+    def get_sample_data(self, idx: int) -> bytes:
+        """Get the raw sample data at the index.
+
+        Args:
+            idx (int): Sample index.
+
+        Returns:
+            bytes: Sample data.
+        """
+        filename = os.path.join(self.dirname, self.split, self.raw_data.basename)
+        offset = (1 + idx) * 4
+        with open(filename, 'rb', 0) as fp:
+            fp.seek(offset)
+            pair = fp.read(8)
+            begin, end = np.frombuffer(pair, np.uint32)
+            fp.seek(begin)
+            data = fp.read(end - begin)
+        if not data:
+            raise IndexError(
+                f'Relative sample index {idx} is not present in the {self.raw_data.basename} file.'
+            )
+        if self.sample_compression:
+            data = decompress(self.sample_compression, data)
+        return data
+
+_readers["mds"] = MDSReader
```
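The new `MDSReader` closely mirrors streaming's built-in MDS shard reader but additionally understands a `sample::<algorithm>` compression value: the prefix switches from whole-shard decompression to decompressing each sample individually inside `get_sample_data`, and assigning the class into `_readers["mds"]` apparently makes streaming resolve the `mds` format to this reader instead of its default. A minimal sketch of just the prefix convention (the helper below is hypothetical, not part of mldataforge):

```python
# Illustration of the "sample::" compression convention handled by MDSReader above.
def split_compression(compression):
    if compression and compression.startswith("sample::"):
        return None, compression.removeprefix("sample::")  # (shard-level, per-sample)
    return compression, None

assert split_compression("sample::zstd") == (None, "zstd")
assert split_compression("zstd") == ("zstd", None)
assert split_compression(None) == (None, None)
```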
{mldataforge-0.1.6 → mldataforge-0.2.0}/mldataforge/options.py

```diff
@@ -14,6 +14,7 @@ __all__ = [
     "prefix_option",
     "shard_size_option",
     "size_hint_option",
+    "trafo_option",
     "yes_option",
 ]
 
@@ -129,6 +130,17 @@ def size_hint_option(default=2**26):
         help=f"Size hint for the dataset (default: {default}).",
     )
 
+def trafo_option():
+    """
+    Option for specifying the transformation function.
+    """
+    return click.option(
+        "--trafo",
+        default=None,
+        type=str,
+        help="Transformation function to apply to the dataset.",
+    )
+
 def yes_option():
     """
     Option for specifying whether to assume yes to all prompts.
```
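`--trafo` is declared as a plain string; the `Trafo` wrapper in `trafos.py` (shown next) calls `eval()` on it, so the value can be any Python expression that evaluates to a callable. Two plausible forms, given here as assumptions rather than documented examples:

```python
# Assumed --trafo values; mldataforge itself does not document these exact strings.
drop_empty = "lambda s: {k: v for k, v in s.items() if v not in (None, '', [])}"
flatten = "flatten_json"  # should resolve inside mldataforge.trafos, where eval() runs
```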
mldataforge-0.2.0/mldataforge/trafos.py (new file)

```diff
@@ -0,0 +1,111 @@
+import re
+from typing import Callable
+
+__all__ = ['Trafo', 'flatten_json', 'unflatten_json']
+
+class Trafo:
+    """
+    Base class for transformations.
+    """
+
+    def __init__(self, trafo: Callable | str | None):
+        self.trafo = trafo
+        if isinstance(trafo, str):
+            self.trafo = eval(trafo)
+
+    def __call__(self, obj):
+        return self.trafo(obj) if self.trafo else obj
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}({self.trafo})"
+
+
+def flatten_json(obj, parent_key='', sep='.', escape_char='\\'):
+    items = []
+
+    def escape(key):
+        return key.replace(escape_char, escape_char * 2)\
+                  .replace(sep, escape_char + sep)\
+                  .replace('[', escape_char + '[')\
+                  .replace(']', escape_char + ']')
+
+    if isinstance(obj, dict):
+        if not obj:
+            # explicitly handle empty dict
+            items.append((parent_key, {}))
+        else:
+            for k, v in obj.items():
+                new_key = f"{parent_key}{sep}{escape(k)}" if parent_key else escape(k)
+                items.extend(flatten_json(v, new_key, sep, escape_char).items())
+    elif isinstance(obj, list):
+        if not obj:
+            # explicitly handle empty list
+            items.append((parent_key, []))
+        else:
+            for idx, v in enumerate(obj):
+                new_key = f"{parent_key}[{idx}]"
+                items.extend(flatten_json(v, new_key, sep, escape_char).items())
+    else:
+        items.append((parent_key, obj))
+    return dict(items)
+
+
+def unflatten_json(flat_dict, sep='.', escape_char='\\'):
+
+    def check_flat_json(obj):
+        assert isinstance(obj, dict), "Input must be a dictionary"
+        for k, v in obj.items():
+            assert isinstance(k, str), f"Key {k} is not a string"
+            assert isinstance(v, (str, int, float, bool)), f"Value {v} is not a valid JSON type"
+
+    def parse_key(key):
+        tokens = re.findall(r'(?:[^.\[\]\\]|\\.)+|\[\d+\]', key)
+        parsed = []
+        for token in tokens:
+            if token.startswith('['):
+                parsed.append(int(token[1:-1]))
+            else:
+                parsed.append(token.replace(escape_char + sep, sep)
+                              .replace(escape_char + '[', '[')
+                              .replace(escape_char + ']', ']')
+                              .replace(escape_char*2, escape_char))
+        return parsed
+
+    check_flat_json(flat_dict)
+
+    result = {}
+
+    for compound_key, value in flat_dict.items():
+        keys = parse_key(compound_key)
+        current = result
+        for idx, key in enumerate(keys):
+            if idx == len(keys) - 1:
+                if isinstance(key, int):
+                    if not isinstance(current, list):
+                        current_parent[last_key] = []
+                        current = current_parent[last_key]
+                    while len(current) <= key:
+                        current.append(None)
+                    current[key] = value
+                else:
+                    current[key] = value
+            else:
+                next_key = keys[idx + 1]
+                if isinstance(key, int):
+                    if not isinstance(current, list):
+                        current_parent[last_key] = []
+                        current = current_parent[last_key]
+                    while len(current) <= key:
+                        current.append(None)
+                    if current[key] is None:
+                        current[key] = [] if isinstance(next_key, int) else {}
+                    current_parent = current
+                    current = current[key]
+                else:
+                    if key not in current:
+                        current[key] = [] if isinstance(next_key, int) else {}
+                    current_parent = current
+                    current = current[key]
+            last_key = key
+
+    return result
```
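A round-trip sketch for the helpers added in `trafos.py`; the sample record and the trafo string are made up for illustration:

```python
from mldataforge.trafos import Trafo, flatten_json, unflatten_json

sample = {"meta": {"id": 7, "tags": ["a", "b"]}, "text": "hello"}
flat = flatten_json(sample)
# {'meta.id': 7, 'meta.tags[0]': 'a', 'meta.tags[1]': 'b', 'text': 'hello'}
assert unflatten_json(flat) == sample

# Trafo wraps either a callable or a string that is eval()'d into one.
strip_text = Trafo("lambda s: {**s, 'text': s['text'].strip()}")
assert strip_text({"text": "  hi "})["text"] == "hi"
```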
{mldataforge-0.1.6 → mldataforge-0.2.0}/mldataforge/utils.py

```diff
@@ -12,6 +12,7 @@ from tqdm import tqdm
 from .compression import determine_compression, open_compression, pigz_compress
 from .mds import MDSBulkReader, MDSWriter
 from .pigz import pigz_open
+from .trafos import Trafo
 
 __all__ = [
     "check_arguments",
@@ -23,6 +24,11 @@ __all__ = [
     "save_parquet",
 ]
 
+_NO_PROGESS = False
+def set_progress(value):
+    global _NO_PROGESS
+    _NO_PROGESS = value
+
 def _batch_iterable(iterable, batch_size):
     batch = []
     for item in iterable:
@@ -73,7 +79,7 @@ def _infer_mds_encoding(value):
     return 'pkl'
 
 def _streaming_jsonl(jsonl_files, compressions):
-    for jsonl_file, compression in tqdm(zip(jsonl_files, compressions), desc="Loading JSONL files", unit="file"):
+    for jsonl_file, compression in tqdm(zip(jsonl_files, compressions), desc="Loading JSONL files", unit="file", disable=_NO_PROGESS):
         for line in open_compression(jsonl_file, mode="rt", compression=compression):
             yield json.loads(line)
 
@@ -106,10 +112,12 @@ def load_mds_directories(mds_directories, split='.', batch_size=2**16, bulk=True
         ds = concatenate_datasets(dsets=dss)
     return ds
 
-def save_jsonl(iterable, output_file, compression=None, processes=64, size_hint=None, overwrite=True, yes=True):
+def save_jsonl(iterable, output_file, compression=None, processes=64, size_hint=None, overwrite=True, yes=True, trafo=None):
     f = None
     part = 0
-    for item in tqdm(iterable, desc="Writing to JSONL", unit="sample"):
+    trafo = Trafo(trafo)
+    for item in tqdm(iterable, desc="Writing to JSONL", unit="sample", disable=_NO_PROGESS):
+        item = trafo(item)
         if f is None:
             part_file = output_file.format(part=part)
             check_arguments(part_file, overwrite, yes)
@@ -122,12 +130,14 @@ def save_jsonl(iterable, output_file, compression=None, processes=64, size_hint=
     if f is not None:
         f.close()
 
-def save_mds(it, output_dir, processes=64, compression=None, buf_size=2**24, pigz=True, shard_size=None, size_hint=None, overwrite=True, yes=True):
+def save_mds(it, output_dir, processes=64, compression=None, buf_size=2**24, pigz=True, shard_size=None, size_hint=None, overwrite=True, yes=True, trafo=None):
     compression = determine_compression("mds", output_dir, compression, no_pigz=not pigz)
     writer = None
     part = 0
     files = []
-    for sample in tqdm(it, desc="Writing to MDS", unit="sample"):
+    trafo = Trafo(trafo)
+    for sample in tqdm(it, desc="Writing to MDS", unit="sample", disable=_NO_PROGESS):
+        sample = trafo(sample)
         if writer is None:
             part_dir = output_dir.format(part=part)
             check_arguments(part_dir, overwrite, yes)
@@ -151,7 +161,7 @@ def save_mds(it, output_dir, processes=64, compression=None, buf_size=2**24, pig
         name2info = {shard["raw_data"]["basename"]: shard for shard in index["shards"]}
         file_names = [file for file in os.listdir(output_dir) if file.endswith(".mds")]
         assert set(file_names) == set(name2info.keys())
-        for file_name in tqdm(file_names, desc="Compressing with pigz", unit="file"):
+        for file_name in tqdm(file_names, desc="Compressing with pigz", unit="file", disable=_NO_PROGESS):
             compressed_file_name = file_name + ".gz"
             file_path = os.path.join(output_dir, file_name)
             compressed_file_path = os.path.join(output_dir, compressed_file_name)
@@ -165,12 +175,14 @@ def save_mds(it, output_dir, processes=64, compression=None, buf_size=2**24, pig
         json.dump(index, open(index_path, "wt"))
         print(f"Compressed {output_dir} with pigz")
 
-def save_parquet(it, output_file, compression=None, batch_size=2**16, size_hint=None, overwrite=True, yes=True):
+def save_parquet(it, output_file, compression=None, batch_size=2**16, size_hint=None, overwrite=True, yes=True, trafo=None):
     compression = determine_compression("parquet", output_file, compression)
     writer = None
     part = 0
-    it = tqdm(it, desc="Writing to Parquet", unit="sample")
+    trafo = Trafo(trafo)
+    it = tqdm(it, desc="Writing to Parquet", unit="sample", disable=_NO_PROGESS)
     for batch in _batch_iterable(it, batch_size):
+        batch = [trafo(sample) for sample in batch]
         table = pa.Table.from_pylist(batch)
         if writer is None:
             part_file = output_file.format(part=part)
```
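All three writers now take `trafo=None`, wrap it in `Trafo`, and apply it per sample before writing, while the new module-level `_NO_PROGESS` flag (set via `set_progress`) disables the tqdm bars. A hedged sketch of programmatic use (file name and records are illustrative; how `check_arguments` behaves with these defaults is not shown in this diff):

```python
# Minimal sketch of the updated save_jsonl signature; values are made up.
from mldataforge.utils import save_jsonl, set_progress

set_progress(True)  # sets _NO_PROGESS = True, which disables the tqdm progress bars

records = ({"text": f" sample {i} "} for i in range(3))
save_jsonl(
    records,
    "out.jsonl",                                    # formatted with part=... internally
    compression=None,
    trafo="lambda s: {'text': s['text'].strip()}",  # eval()'d into a callable by Trafo
)
```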
The remaining files listed above with +0 -0 are unchanged between 0.1.6 and 0.2.0.