mldataforge 0.1.7__tar.gz → 0.2.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mldataforge-0.1.7 → mldataforge-0.2.1}/PKG-INFO +1 -1
- {mldataforge-0.1.7 → mldataforge-0.2.1}/mldataforge/commands/convert/jsonl.py +6 -2
- {mldataforge-0.1.7 → mldataforge-0.2.1}/mldataforge/commands/convert/mds.py +6 -2
- {mldataforge-0.1.7 → mldataforge-0.2.1}/mldataforge/commands/convert/parquet.py +6 -2
- {mldataforge-0.1.7 → mldataforge-0.2.1}/mldataforge/commands/join.py +8 -3
- {mldataforge-0.1.7 → mldataforge-0.2.1}/mldataforge/commands/split.py +15 -3
- {mldataforge-0.1.7 → mldataforge-0.2.1}/mldataforge/options.py +12 -0
- mldataforge-0.2.1/mldataforge/trafos.py +165 -0
- {mldataforge-0.1.7 → mldataforge-0.2.1}/mldataforge/utils.py +10 -6
- {mldataforge-0.1.7 → mldataforge-0.2.1}/pyproject.toml +1 -1
- {mldataforge-0.1.7 → mldataforge-0.2.1}/.gitignore +0 -0
- {mldataforge-0.1.7 → mldataforge-0.2.1}/LICENSE +0 -0
- {mldataforge-0.1.7 → mldataforge-0.2.1}/README.md +0 -0
- {mldataforge-0.1.7 → mldataforge-0.2.1}/mldataforge/__main__.py +0 -0
- {mldataforge-0.1.7 → mldataforge-0.2.1}/mldataforge/brotli.py +0 -0
- {mldataforge-0.1.7 → mldataforge-0.2.1}/mldataforge/commands/__init__.py +0 -0
- {mldataforge-0.1.7 → mldataforge-0.2.1}/mldataforge/commands/convert/__init__.py +0 -0
- {mldataforge-0.1.7 → mldataforge-0.2.1}/mldataforge/compression.py +0 -0
- {mldataforge-0.1.7 → mldataforge-0.2.1}/mldataforge/mds.py +0 -0
- {mldataforge-0.1.7 → mldataforge-0.2.1}/mldataforge/pigz.py +0 -0
- {mldataforge-0.1.7 → mldataforge-0.2.1}/mldataforge/snappy.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: mldataforge
|
3
|
-
Version: 0.1
|
3
|
+
Version: 0.2.1
|
4
4
|
Summary: swiss army knife of scripts for transforming and processing datasets for machine learning.
|
5
5
|
Project-URL: Homepage, https://github.com/schneiderkamplab/mldataforge
|
6
6
|
Project-URL: Bug Tracker, https://github.com/schneiderkamplab/mldataforge/issues
|
@@ -21,9 +21,10 @@ def jsonl():
|
|
21
21
|
@buf_size_option()
|
22
22
|
@shard_size_option()
|
23
23
|
@no_pigz_option()
|
24
|
+
@trafo_option()
|
24
25
|
def mds(**kwargs):
|
25
26
|
jsonl_to_mds(**kwargs)
|
26
|
-
def jsonl_to_mds(output_dir, jsonl_files, compression, processes, overwrite, yes, buf_size, shard_size, no_pigz):
|
27
|
+
def jsonl_to_mds(output_dir, jsonl_files, compression, processes, overwrite, yes, buf_size, shard_size, no_pigz, trafo):
|
27
28
|
check_arguments(output_dir, overwrite, yes, jsonl_files)
|
28
29
|
save_mds(
|
29
30
|
load_jsonl_files(jsonl_files),
|
@@ -33,6 +34,7 @@ def jsonl_to_mds(output_dir, jsonl_files, compression, processes, overwrite, yes
|
|
33
34
|
buf_size=buf_size,
|
34
35
|
pigz=use_pigz(compression, no_pigz),
|
35
36
|
shard_size=shard_size,
|
37
|
+
trafo=trafo,
|
36
38
|
)
|
37
39
|
|
38
40
|
@jsonl.command()
|
@@ -42,13 +44,15 @@ def jsonl_to_mds(output_dir, jsonl_files, compression, processes, overwrite, yes
|
|
42
44
|
@overwrite_option()
|
43
45
|
@yes_option()
|
44
46
|
@batch_size_option()
|
47
|
+
@trafo_option()
|
45
48
|
def parquet(**kwargs):
|
46
49
|
jsonl_to_parquet(**kwargs)
|
47
|
-
def jsonl_to_parquet(output_file, jsonl_files, compression, overwrite, yes, batch_size):
|
50
|
+
def jsonl_to_parquet(output_file, jsonl_files, compression, overwrite, yes, batch_size, trafo):
|
48
51
|
check_arguments(output_file, overwrite, yes, jsonl_files)
|
49
52
|
save_parquet(
|
50
53
|
load_jsonl_files(jsonl_files),
|
51
54
|
output_file,
|
52
55
|
compression=compression,
|
53
56
|
batch_size=batch_size,
|
57
|
+
trafo=trafo,
|
54
58
|
)
|
@@ -19,15 +19,17 @@ def mds():
|
|
19
19
|
@yes_option()
|
20
20
|
@batch_size_option()
|
21
21
|
@no_bulk_option()
|
22
|
+
@trafo_option()
|
22
23
|
def jsonl(**kwargs):
|
23
24
|
mds_to_jsonl(**kwargs)
|
24
|
-
def mds_to_jsonl(output_file, mds_directories, compression, processes, overwrite, yes, batch_size, no_bulk):
|
25
|
+
def mds_to_jsonl(output_file, mds_directories, compression, processes, overwrite, yes, batch_size, no_bulk, trafo):
|
25
26
|
check_arguments(output_file, overwrite, yes, mds_directories)
|
26
27
|
save_jsonl(
|
27
28
|
load_mds_directories(mds_directories, batch_size=batch_size, bulk=not no_bulk),
|
28
29
|
output_file,
|
29
30
|
compression=compression,
|
30
31
|
processes=processes,
|
32
|
+
trafo=trafo,
|
31
33
|
)
|
32
34
|
|
33
35
|
@mds.command()
|
@@ -38,13 +40,15 @@ def mds_to_jsonl(output_file, mds_directories, compression, processes, overwrite
|
|
38
40
|
@yes_option()
|
39
41
|
@batch_size_option()
|
40
42
|
@no_bulk_option()
|
43
|
+
@trafo_option()
|
41
44
|
def parquet(**kwargs):
|
42
45
|
mds_to_parquet(**kwargs)
|
43
|
-
def mds_to_parquet(output_file, mds_directories, compression, overwrite, yes, batch_size, no_bulk):
|
46
|
+
def mds_to_parquet(output_file, mds_directories, compression, overwrite, yes, batch_size, no_bulk, trafo):
|
44
47
|
check_arguments(output_file, overwrite, yes, mds_directories)
|
45
48
|
save_parquet(
|
46
49
|
load_mds_directories(mds_directories, batch_size=batch_size, bulk=not no_bulk),
|
47
50
|
output_file,
|
48
51
|
compression=compression,
|
49
52
|
batch_size=batch_size,
|
53
|
+
trafo=trafo,
|
50
54
|
)
|
@@ -18,15 +18,17 @@ def parquet():
|
|
18
18
|
@processes_option()
|
19
19
|
@overwrite_option()
|
20
20
|
@yes_option()
|
21
|
+
@trafo_option()
|
21
22
|
def jsonl(**kwargs):
|
22
23
|
parquet_to_jsonl(**kwargs)
|
23
|
-
def parquet_to_jsonl(output_file, parquet_files, compression, processes, overwrite, yes):
|
24
|
+
def parquet_to_jsonl(output_file, parquet_files, compression, processes, overwrite, yes, trafo):
|
24
25
|
check_arguments(output_file, overwrite, yes, parquet_files)
|
25
26
|
save_jsonl(
|
26
27
|
load_dataset("parquet", data_files=parquet_files, split="train"),
|
27
28
|
output_file,
|
28
29
|
compression=compression,
|
29
30
|
processes=processes,
|
31
|
+
trafo=trafo,
|
30
32
|
)
|
31
33
|
|
32
34
|
@parquet.command()
|
@@ -39,9 +41,10 @@ def parquet_to_jsonl(output_file, parquet_files, compression, processes, overwri
|
|
39
41
|
@buf_size_option()
|
40
42
|
@shard_size_option()
|
41
43
|
@no_pigz_option()
|
44
|
+
@trafo_option()
|
42
45
|
def mds(**kwargs):
|
43
46
|
parquet_to_mds(**kwargs)
|
44
|
-
def parquet_to_mds(output_dir, parquet_files, compression, processes, overwrite, yes, buf_size, shard_size, no_pigz):
|
47
|
+
def parquet_to_mds(output_dir, parquet_files, compression, processes, overwrite, yes, buf_size, shard_size, no_pigz, trafo):
|
45
48
|
check_arguments(output_dir, overwrite, yes, parquet_files)
|
46
49
|
save_mds(
|
47
50
|
load_dataset("parquet", data_files=parquet_files, split="train"),
|
@@ -51,4 +54,5 @@ def parquet_to_mds(output_dir, parquet_files, compression, processes, overwrite,
|
|
51
54
|
buf_size=buf_size,
|
52
55
|
pigz=use_pigz(compression, no_pigz=no_pigz),
|
53
56
|
shard_size=shard_size,
|
57
|
+
trafo=trafo,
|
54
58
|
)
|
@@ -18,9 +18,10 @@ def join():
|
|
18
18
|
@processes_option()
|
19
19
|
@overwrite_option()
|
20
20
|
@yes_option()
|
21
|
+
@trafo_option()
|
21
22
|
def jsonl(**kwargs):
|
22
23
|
join_jsonl(**kwargs)
|
23
|
-
def join_jsonl(output_file, jsonl_files, compression, processes, overwrite, yes):
|
24
|
+
def join_jsonl(output_file, jsonl_files, compression, processes, overwrite, yes, trafo):
|
24
25
|
check_arguments(output_file, overwrite, yes, jsonl_files)
|
25
26
|
save_jsonl(
|
26
27
|
load_jsonl_files(jsonl_files),
|
@@ -41,10 +42,11 @@ def join_jsonl(output_file, jsonl_files, compression, processes, overwrite, yes)
|
|
41
42
|
@no_bulk_option()
|
42
43
|
@shard_size_option()
|
43
44
|
@no_pigz_option()
|
45
|
+
@trafo_option()
|
44
46
|
def mds(**kwargs):
|
45
47
|
print(kwargs)
|
46
48
|
join_mds(**kwargs)
|
47
|
-
def join_mds(output_dir, mds_directories, compression, processes, overwrite, yes, batch_size, buf_size, no_bulk, shard_size, no_pigz):
|
49
|
+
def join_mds(output_dir, mds_directories, compression, processes, overwrite, yes, batch_size, buf_size, no_bulk, shard_size, no_pigz, trafo):
|
48
50
|
check_arguments(output_dir, overwrite, yes, mds_directories)
|
49
51
|
save_mds(
|
50
52
|
load_mds_directories(mds_directories, batch_size=batch_size, bulk=not no_bulk),
|
@@ -54,6 +56,7 @@ def join_mds(output_dir, mds_directories, compression, processes, overwrite, yes
|
|
54
56
|
buf_size=buf_size,
|
55
57
|
shard_size=shard_size,
|
56
58
|
pigz=use_pigz(compression, no_pigz),
|
59
|
+
trafo=trafo,
|
57
60
|
)
|
58
61
|
|
59
62
|
@join.command()
|
@@ -63,13 +66,15 @@ def join_mds(output_dir, mds_directories, compression, processes, overwrite, yes
|
|
63
66
|
@overwrite_option()
|
64
67
|
@yes_option()
|
65
68
|
@batch_size_option()
|
69
|
+
@trafo_option()
|
66
70
|
def parquet(**kwargs):
|
67
71
|
join_parquet(**kwargs)
|
68
|
-
def join_parquet(output_file, parquet_files, compression, overwrite, yes, batch_size):
|
72
|
+
def join_parquet(output_file, parquet_files, compression, overwrite, yes, batch_size, trafo):
|
69
73
|
check_arguments(output_file, overwrite, yes, parquet_files)
|
70
74
|
save_parquet(
|
71
75
|
load_dataset("parquet", data_files=parquet_files, split="train"),
|
72
76
|
output_file,
|
73
77
|
compression=compression,
|
74
78
|
batch_size=batch_size,
|
79
|
+
trafo=trafo,
|
75
80
|
)
|
@@ -20,7 +20,10 @@ def split():
|
|
20
20
|
@processes_option()
|
21
21
|
@overwrite_option()
|
22
22
|
@yes_option()
|
23
|
-
|
23
|
+
@trafo_option()
|
24
|
+
def jsonl(*args, **kwargs):
|
25
|
+
split_jsonl(*args, **kwargs)
|
26
|
+
def split_jsonl(jsonl_files, prefix, output_dir, size_hint, compression, processes, overwrite, yes, trafo):
|
24
27
|
save_jsonl(
|
25
28
|
load_jsonl_files(jsonl_files),
|
26
29
|
output_file=f"{output_dir}/{prefix}{{part:04d}}.jsonl{extension_compression(compression, jsonl_files[0])}",
|
@@ -29,6 +32,7 @@ def jsonl(jsonl_files, prefix, output_dir, size_hint, compression, processes, ov
|
|
29
32
|
size_hint=size_hint,
|
30
33
|
overwrite=overwrite,
|
31
34
|
yes=yes,
|
35
|
+
trafo=trafo,
|
32
36
|
)
|
33
37
|
|
34
38
|
@split.command()
|
@@ -45,7 +49,10 @@ def jsonl(jsonl_files, prefix, output_dir, size_hint, compression, processes, ov
|
|
45
49
|
@no_bulk_option()
|
46
50
|
@shard_size_option()
|
47
51
|
@no_pigz_option()
|
48
|
-
|
52
|
+
@trafo_option()
|
53
|
+
def mds(*args, **kwargs):
|
54
|
+
split_mds(*args, **kwargs)
|
55
|
+
def split_mds(mds_directories, prefix, output_dir, size_hint, compression, processes, overwrite, yes, buf_size, batch_size, no_bulk, shard_size, no_pigz, trafo):
|
49
56
|
save_mds(
|
50
57
|
load_mds_directories(mds_directories, batch_size=batch_size, bulk=not no_bulk),
|
51
58
|
output_dir=f"{output_dir}/{prefix}{{part:04d}}",
|
@@ -57,6 +64,7 @@ def mds(mds_directories, prefix, output_dir, size_hint, compression, processes,
|
|
57
64
|
size_hint=size_hint,
|
58
65
|
overwrite=overwrite,
|
59
66
|
yes=yes,
|
67
|
+
trafo=trafo,
|
60
68
|
)
|
61
69
|
|
62
70
|
@split.command()
|
@@ -68,7 +76,10 @@ def mds(mds_directories, prefix, output_dir, size_hint, compression, processes,
|
|
68
76
|
@overwrite_option()
|
69
77
|
@yes_option()
|
70
78
|
@batch_size_option()
|
71
|
-
|
79
|
+
@trafo_option()
|
80
|
+
def parquet(*args, **kwargs):
|
81
|
+
split_parquet(*args, **kwargs)
|
82
|
+
def split_parquet(parquet_files, prefix, output_dir, size_hint, compression, overwrite, yes, batch_size, trafo):
|
72
83
|
save_parquet(
|
73
84
|
load_dataset("parquet", data_files=parquet_files, split="train"),
|
74
85
|
output_file=f"{output_dir}/{prefix}{{part:04d}}.parquet",
|
@@ -77,4 +88,5 @@ def parquet(parquet_files, prefix, output_dir, size_hint, compression, overwrite
|
|
77
88
|
size_hint=size_hint,
|
78
89
|
overwrite=overwrite,
|
79
90
|
yes=yes,
|
91
|
+
trafo=trafo,
|
80
92
|
)
|
@@ -14,6 +14,7 @@ __all__ = [
|
|
14
14
|
"prefix_option",
|
15
15
|
"shard_size_option",
|
16
16
|
"size_hint_option",
|
17
|
+
"trafo_option",
|
17
18
|
"yes_option",
|
18
19
|
]
|
19
20
|
|
@@ -129,6 +130,17 @@ def size_hint_option(default=2**26):
|
|
129
130
|
help=f"Size hint for the dataset (default: {default}).",
|
130
131
|
)
|
131
132
|
|
133
|
+
def trafo_option():
    """
    Option for specifying the transformation applied to every sample.

    The value is Python source code (not a file path or a function name) that
    must define a callable named ``process``. The code is executed as-is via
    ``exec``, so only pass code you trust.
    """
    return click.option(
        "--trafo",
        default=None,
        type=str,
        help="Python code defining a callable named 'process' that is applied to every sample before writing.",
    )
|
143
|
+
|
132
144
|
def yes_option():
|
133
145
|
"""
|
134
146
|
Option for specifying whether to assume yes to all prompts.
|
@@ -0,0 +1,165 @@
|
|
1
|
+
import re
|
2
|
+
|
3
|
+
# Names exported by ``from mldataforge.trafos import *``; ``identity`` is defined below but not listed here.
__all__ = ['IndexedDatasetView', 'Transformation', 'Transformations', 'flatten_json', 'unflatten_json']
|
4
|
+
|
5
|
+
class IndexedDatasetView:
    """Read-only view that yields ``dataset[i]`` for each index in ``indices``.

    ``indices`` is copied into a list when the view is built, so the view can
    be iterated repeatedly and its length is known even if a one-shot iterator
    was supplied.
    """

    def __init__(self, dataset, indices):
        self.dataset = dataset
        # Materialize once; repeated iteration and len() both rely on this.
        self.indices = [*indices]

    def __iter__(self):
        yield from (self.dataset[position] for position in self.indices)

    def __len__(self):
        """Number of selected indices (not the size of the underlying dataset)."""
        return len(self.indices)
|
16
|
+
|
17
|
+
class Transformation:
    """Apply a user-supplied ``process`` function to every sample of an iterable.

    ``code`` is Python source executed with ``exec`` in a fresh, private global
    namespace; it must define a callable named ``process``. The source runs with
    the full privileges of this process, so only pass trusted code.

    Output contract of ``process(sample)``:
      * ``None`` -> no output for this sample;
      * ``list`` / ``tuple`` / ``set`` -> each element is one output;
      * anything else -> that object is the single output.

    If ``process`` has a truthy ``flushable`` attribute, ``process(None)`` is
    called repeatedly, until it returns no outputs, in two situations: after
    every sample that produced no output, and once more after the input is
    exhausted.
    """

    def __init__(self, code: str):
        self.code = code
        # Length of the most recent input given to __call__, or None when that
        # input was unsized (or __call__ has not been used yet). Read by __len__.
        self._last_input_len = None
        self._init_context()

    def _init_context(self):
        """Execute ``self.code`` and bind its ``process`` callable.

        Raises:
            ValueError: if the code does not define a callable named ``process``.
        """
        self.global_context = {}
        # SECURITY: arbitrary code execution by design -- ``code`` must be trusted.
        exec(self.code, self.global_context)
        process = self.global_context.get('process')
        if not callable(process):
            raise ValueError("code must define a callable named 'process'")
        self.process = process
        self._flushable = bool(getattr(process, 'flushable', False))

    def _normalize_outputs(self, result):
        """Turn a ``process`` return value into a list of outputs."""
        if result is None:
            return []
        if isinstance(result, (list, tuple, set)):
            return list(result)
        return [result]

    def _flush(self):
        """Yield outputs from ``process(None)`` until it returns nothing (flushable only)."""
        if self._flushable:
            while True:
                flushed = self._normalize_outputs(self.process(None))
                if not flushed:
                    return
                yield from flushed

    def __call__(self, iterable):
        """Return an iterator over the transformed samples of ``iterable``.

        The input length is recorded eagerly (when ``iterable`` supports
        ``len``) so that ``len(self)`` works before iteration starts; the samples
        themselves are processed lazily.
        """
        try:
            self._last_input_len = len(iterable)
        except TypeError:
            self._last_input_len = None
        return self._apply(iterable)

    def _apply(self, iterable):
        """Lazily run ``process`` over ``iterable`` (generator behind ``__call__``)."""
        for sample in iterable:
            results = self._normalize_outputs(self.process(sample))
            yield from results
            if not results:
                # An empty result triggers a drain of ``process``; no-op unless flushable.
                yield from self._flush()
        if self._flushable:
            yield from self._flush()

    def __len__(self):
        """Return the length of the most recent input passed to ``__call__``.

        Raises:
            TypeError: if ``__call__`` has not been used, or its last input was unsized.
        """
        if self._last_input_len is not None:
            return self._last_input_len
        raise TypeError("Length is not available for this transformation.")
|
58
|
+
|
59
|
+
|
60
|
+
class Transformations:
    """Apply a sequence of ``Transformation`` steps to a dataset.

    Args:
        codes: Python sources, one per step; each must define ``process``.
            Steps run in the given order, each consuming the previous step's
            output. An empty list makes ``__call__`` return its input unchanged
            (unless ``indices`` is given).
        indices: optional iterable of positions. When given, only
            ``dataset[i]`` for these positions (in this order) is fed to the
            pipeline. It is materialized into a list once, so a one-shot
            iterator can be reused across calls and measured with ``len``.
    """

    def __init__(self, codes: list[str], indices=None):
        self.pipeline = [Transformation(code) for code in codes]
        self.indices = None if indices is None else list(indices)
        # Length of the most recent (possibly index-restricted) input to
        # __call__, or None if it was unsized / not called yet. Read by __len__.
        self._last_input_len = None

    def __call__(self, dataset):
        """Run ``dataset`` through every step and return the resulting iterable."""
        # Restrict to the requested positions first, if any.
        if self.indices is not None:
            dataset = IndexedDatasetView(dataset, self.indices)
        try:
            self._last_input_len = len(dataset)
        except TypeError:
            self._last_input_len = None
        result = dataset
        for transform in self.pipeline:
            result = transform(result)
        return result

    def __len__(self):
        """Return the input length of the pipeline.

        This is ``len(indices)`` when indices were given, otherwise the length
        of the last sized input passed to ``__call__``.

        Raises:
            TypeError: if no length is known yet.
        """
        if self.indices is not None:
            return len(self.indices)
        if self._last_input_len is not None:
            return self._last_input_len
        raise TypeError("Transformations length is not available until __call__ is used on a sized input.")
|
82
|
+
|
83
|
+
def identity(obj):
    """Return *obj* itself, unchanged."""
    return obj
|
85
|
+
|
86
|
+
def flatten_json(obj, parent_key='', sep='.', escape_char='\\'):
    """Flatten nested dicts/lists into a single-level ``{path: leaf}`` dict.

    Dict keys are joined with ``sep`` and list positions are written as ``[i]``.
    Inside each dict key, ``escape_char``, ``sep``, ``[`` and ``]`` are prefixed
    with ``escape_char`` so they cannot be mistaken for path syntax. Empty dicts
    and empty lists are kept as leaf values, as is every object that is neither
    a dict nor a list. ``parent_key`` is used as the prefix of every path.
    """
    # Escape order matters: the escape character itself goes first, so the
    # escapes added for the later characters are not doubled.
    specials = (escape_char, sep, '[', ']')

    def escape(key):
        for ch in specials:
            key = key.replace(ch, escape_char + ch)
        return key

    flat = {}

    def walk(node, path):
        if isinstance(node, dict):
            if not node:
                flat[path] = {}
                return
            for name, child in node.items():
                escaped = escape(name)
                walk(child, f"{path}{sep}{escaped}" if path else escaped)
        elif isinstance(node, list):
            if not node:
                flat[path] = []
                return
            for position, child in enumerate(node):
                walk(child, f"{path}[{position}]")
        else:
            flat[path] = node

    walk(obj, parent_key)
    return flat
|
112
|
+
|
113
|
+
def unflatten_json(flat_dict, sep='.', escape_char='\\'):
    """Rebuild a nested structure from a flat ``{path: value}`` dict.

    This is the inverse of ``flatten_json``. Path segments are separated by
    ``sep``, ``[i]`` denotes a list index, and ``escape_char`` escapes ``sep``,
    ``[``, ``]`` and itself inside dict keys. Missing list positions are filled
    with ``None``. Keys that contain no path segment (such as ``""``) are
    ignored.

    Returns:
        A dict, or a list when the first non-empty path starts with a list
        index (e.g. ``{"[0]": 1}`` -> ``[1]``). Mixing dict-key and
        list-index roots is not supported.

    Raises:
        AssertionError: if ``flat_dict`` is not a dict, has a non-string key,
            or has a value other than a JSON scalar (str, int, float, bool,
            None) or an empty dict/list.
    """
    # Private slot under which the top-level container is built. This lets a
    # leading list index have a real parent to replace, instead of an unbound one.
    root_slot = object()
    esc = re.escape(escape_char)
    # A segment is either a run of (escaped char | non-bracket char that does
    # not start ``sep`` or ``escape_char``) or a ``[digits]`` list index.
    # Built from the actual ``sep``/``escape_char`` rather than hard-coded ones.
    segment_re = re.compile(
        rf"(?:{esc}.|(?!{re.escape(sep)}|{esc})[^\[\]])+|\[\d+\]",
        re.DOTALL,
    )

    def check_flat_json(obj):
        # Explicit raises (not ``assert``) so validation survives ``python -O``;
        # AssertionError is kept for compatibility with existing callers.
        if not isinstance(obj, dict):
            raise AssertionError("Input must be a dictionary")
        for k, v in obj.items():
            if not isinstance(k, str):
                raise AssertionError(f"Key {k} is not a string")
            # None and empty containers are legitimate leaves emitted by flatten_json.
            is_scalar = v is None or isinstance(v, (str, int, float, bool))
            is_empty_container = isinstance(v, (dict, list)) and not v
            if not (is_scalar or is_empty_container):
                raise AssertionError(f"Value {v} is not a valid JSON type")

    def unescape(token):
        return (token.replace(escape_char + sep, sep)
                .replace(escape_char + '[', '[')
                .replace(escape_char + ']', ']')
                .replace(escape_char * 2, escape_char))

    def parse_key(key):
        return [int(tok[1:-1]) if tok.startswith('[') else unescape(tok)
                for tok in segment_re.findall(key)]

    def new_container(next_key):
        return [] if isinstance(next_key, int) else {}

    check_flat_json(flat_dict)
    holder = {}
    for compound_key, value in flat_dict.items():
        keys = parse_key(compound_key)
        if not keys:
            continue
        if isinstance(value, (dict, list)):
            # Fresh empty container so the result does not alias the caller's object.
            value = {} if isinstance(value, dict) else []
        path = [root_slot, *keys]
        last = len(path) - 1
        current, parent, parent_key = holder, None, None
        for idx, key in enumerate(path):
            if isinstance(key, int):
                if not isinstance(current, list):
                    # A list index on a non-list replaces that slot with a new list.
                    parent[parent_key] = []
                    current = parent[parent_key]
                if len(current) <= key:
                    current.extend([None] * (key + 1 - len(current)))
                if idx == last:
                    current[key] = value
                elif current[key] is None:
                    current[key] = new_container(path[idx + 1])
            elif idx == last:
                current[key] = value
            elif key not in current:
                current[key] = new_container(path[idx + 1])
            if idx != last:
                parent, parent_key = current, key
                current = current[key]
    return holder.get(root_slot, {})
|
@@ -12,6 +12,7 @@ from tqdm import tqdm
|
|
12
12
|
from .compression import determine_compression, open_compression, pigz_compress
|
13
13
|
from .mds import MDSBulkReader, MDSWriter
|
14
14
|
from .pigz import pigz_open
|
15
|
+
from .trafos import Transformations
|
15
16
|
|
16
17
|
__all__ = [
|
17
18
|
"check_arguments",
|
@@ -111,10 +112,11 @@ def load_mds_directories(mds_directories, split='.', batch_size=2**16, bulk=True
|
|
111
112
|
ds = concatenate_datasets(dsets=dss)
|
112
113
|
return ds
|
113
114
|
|
114
|
-
def save_jsonl(iterable, output_file, compression=None, processes=64, size_hint=None, overwrite=True, yes=True):
|
115
|
+
def save_jsonl(iterable, output_file, compression=None, processes=64, size_hint=None, overwrite=True, yes=True, trafo=None):
|
115
116
|
f = None
|
116
117
|
part = 0
|
117
|
-
|
118
|
+
trafo = Transformations([] if trafo is None else [trafo])
|
119
|
+
for item in tqdm(trafo(iterable), desc="Writing to JSONL", unit="sample", disable=_NO_PROGESS):
|
118
120
|
if f is None:
|
119
121
|
part_file = output_file.format(part=part)
|
120
122
|
check_arguments(part_file, overwrite, yes)
|
@@ -127,12 +129,13 @@ def save_jsonl(iterable, output_file, compression=None, processes=64, size_hint=
|
|
127
129
|
if f is not None:
|
128
130
|
f.close()
|
129
131
|
|
130
|
-
def save_mds(it, output_dir, processes=64, compression=None, buf_size=2**24, pigz=True, shard_size=None, size_hint=None, overwrite=True, yes=True):
|
132
|
+
def save_mds(it, output_dir, processes=64, compression=None, buf_size=2**24, pigz=True, shard_size=None, size_hint=None, overwrite=True, yes=True, trafo=None):
|
131
133
|
compression = determine_compression("mds", output_dir, compression, no_pigz=not pigz)
|
132
134
|
writer = None
|
133
135
|
part = 0
|
134
136
|
files = []
|
135
|
-
|
137
|
+
trafo = Transformations([] if trafo is None else [trafo])
|
138
|
+
for sample in tqdm(trafo(it), desc="Writing to MDS", unit="sample", disable=_NO_PROGESS):
|
136
139
|
if writer is None:
|
137
140
|
part_dir = output_dir.format(part=part)
|
138
141
|
check_arguments(part_dir, overwrite, yes)
|
@@ -170,12 +173,13 @@ def save_mds(it, output_dir, processes=64, compression=None, buf_size=2**24, pig
|
|
170
173
|
json.dump(index, open(index_path, "wt"))
|
171
174
|
print(f"Compressed {output_dir} with pigz")
|
172
175
|
|
173
|
-
def save_parquet(it, output_file, compression=None, batch_size=2**16, size_hint=None, overwrite=True, yes=True):
|
176
|
+
def save_parquet(it, output_file, compression=None, batch_size=2**16, size_hint=None, overwrite=True, yes=True, trafo=None):
|
174
177
|
compression = determine_compression("parquet", output_file, compression)
|
175
178
|
writer = None
|
176
179
|
part = 0
|
180
|
+
trafo = Transformations([] if trafo is None else [trafo])
|
177
181
|
it = tqdm(it, desc="Writing to Parquet", unit="sample", disable=_NO_PROGESS)
|
178
|
-
for batch in _batch_iterable(it, batch_size):
|
182
|
+
for batch in _batch_iterable(trafo(it), batch_size):
|
179
183
|
table = pa.Table.from_pylist(batch)
|
180
184
|
if writer is None:
|
181
185
|
part_file = output_file.format(part=part)
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|