mldataforge 0.1.7__tar.gz → 0.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mldataforge-0.1.7 → mldataforge-0.2.0}/PKG-INFO +1 -1
- {mldataforge-0.1.7 → mldataforge-0.2.0}/mldataforge/commands/convert/jsonl.py +6 -2
- {mldataforge-0.1.7 → mldataforge-0.2.0}/mldataforge/commands/convert/mds.py +6 -2
- {mldataforge-0.1.7 → mldataforge-0.2.0}/mldataforge/commands/convert/parquet.py +6 -2
- {mldataforge-0.1.7 → mldataforge-0.2.0}/mldataforge/commands/join.py +8 -3
- {mldataforge-0.1.7 → mldataforge-0.2.0}/mldataforge/commands/split.py +15 -3
- {mldataforge-0.1.7 → mldataforge-0.2.0}/mldataforge/options.py +12 -0
- mldataforge-0.2.0/mldataforge/trafos.py +111 -0
- {mldataforge-0.1.7 → mldataforge-0.2.0}/mldataforge/utils.py +10 -3
- {mldataforge-0.1.7 → mldataforge-0.2.0}/pyproject.toml +1 -1
- {mldataforge-0.1.7 → mldataforge-0.2.0}/.gitignore +0 -0
- {mldataforge-0.1.7 → mldataforge-0.2.0}/LICENSE +0 -0
- {mldataforge-0.1.7 → mldataforge-0.2.0}/README.md +0 -0
- {mldataforge-0.1.7 → mldataforge-0.2.0}/mldataforge/__main__.py +0 -0
- {mldataforge-0.1.7 → mldataforge-0.2.0}/mldataforge/brotli.py +0 -0
- {mldataforge-0.1.7 → mldataforge-0.2.0}/mldataforge/commands/__init__.py +0 -0
- {mldataforge-0.1.7 → mldataforge-0.2.0}/mldataforge/commands/convert/__init__.py +0 -0
- {mldataforge-0.1.7 → mldataforge-0.2.0}/mldataforge/compression.py +0 -0
- {mldataforge-0.1.7 → mldataforge-0.2.0}/mldataforge/mds.py +0 -0
- {mldataforge-0.1.7 → mldataforge-0.2.0}/mldataforge/pigz.py +0 -0
- {mldataforge-0.1.7 → mldataforge-0.2.0}/mldataforge/snappy.py +0 -0

PKG-INFO (+1 -1)

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: mldataforge
-Version: 0.1.7
+Version: 0.2.0
 Summary: swiss army knife of scripts for transforming and processing datasets for machine learning.
 Project-URL: Homepage, https://github.com/schneiderkamplab/mldataforge
 Project-URL: Bug Tracker, https://github.com/schneiderkamplab/mldataforge/issues

mldataforge/commands/convert/jsonl.py (+6 -2)

@@ -21,9 +21,10 @@ def jsonl():
 @buf_size_option()
 @shard_size_option()
 @no_pigz_option()
+@trafo_option()
 def mds(**kwargs):
     jsonl_to_mds(**kwargs)
-def jsonl_to_mds(output_dir, jsonl_files, compression, processes, overwrite, yes, buf_size, shard_size, no_pigz):
+def jsonl_to_mds(output_dir, jsonl_files, compression, processes, overwrite, yes, buf_size, shard_size, no_pigz, trafo):
     check_arguments(output_dir, overwrite, yes, jsonl_files)
     save_mds(
         load_jsonl_files(jsonl_files),
@@ -33,6 +34,7 @@ def jsonl_to_mds(output_dir, jsonl_files, compression, processes, overwrite, yes
         buf_size=buf_size,
         pigz=use_pigz(compression, no_pigz),
         shard_size=shard_size,
+        trafo=trafo,
     )

 @jsonl.command()
@@ -42,13 +44,15 @@ def jsonl_to_mds(output_dir, jsonl_files, compression, processes, overwrite, yes
 @overwrite_option()
 @yes_option()
 @batch_size_option()
+@trafo_option()
 def parquet(**kwargs):
     jsonl_to_parquet(**kwargs)
-def jsonl_to_parquet(output_file, jsonl_files, compression, overwrite, yes, batch_size):
+def jsonl_to_parquet(output_file, jsonl_files, compression, overwrite, yes, batch_size, trafo):
     check_arguments(output_file, overwrite, yes, jsonl_files)
     save_parquet(
         load_jsonl_files(jsonl_files),
         output_file,
         compression=compression,
         batch_size=batch_size,
+        trafo=trafo,
     )
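
The diff only shows the Click wiring for the new `--trafo` option, not an invocation. As a minimal sketch, the converter can also be driven directly from Python using the `jsonl_to_mds` signature visible above; all argument values below are illustrative, and the lambda is only an example of a transformation expression.

```python
# Sketch: calling the jsonl -> MDS converter directly with a transformation.
# Signature taken from the diff above; paths, values, and the lambda are hypothetical.
from mldataforge.commands.convert.jsonl import jsonl_to_mds

jsonl_to_mds(
    output_dir="out-mds",
    jsonl_files=["data.jsonl.gz"],   # hypothetical input
    compression=None,
    processes=64,
    overwrite=True,
    yes=True,
    buf_size=2**24,
    shard_size=None,
    no_pigz=True,
    # --trafo takes a string expression that Trafo eval()s into a callable
    trafo="lambda s: {**s, 'n_chars': len(s.get('text', ''))}",
)
```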

mldataforge/commands/convert/mds.py (+6 -2)

@@ -19,15 +19,17 @@ def mds():
 @yes_option()
 @batch_size_option()
 @no_bulk_option()
+@trafo_option()
 def jsonl(**kwargs):
     mds_to_jsonl(**kwargs)
-def mds_to_jsonl(output_file, mds_directories, compression, processes, overwrite, yes, batch_size, no_bulk):
+def mds_to_jsonl(output_file, mds_directories, compression, processes, overwrite, yes, batch_size, no_bulk, trafo):
     check_arguments(output_file, overwrite, yes, mds_directories)
     save_jsonl(
         load_mds_directories(mds_directories, batch_size=batch_size, bulk=not no_bulk),
         output_file,
         compression=compression,
         processes=processes,
+        trafo=trafo,
     )

 @mds.command()
@@ -38,13 +40,15 @@ def mds_to_jsonl(output_file, mds_directories, compression, processes, overwrite
 @yes_option()
 @batch_size_option()
 @no_bulk_option()
+@trafo_option()
 def parquet(**kwargs):
     mds_to_parquet(**kwargs)
-def mds_to_parquet(output_file, mds_directories, compression, overwrite, yes, batch_size, no_bulk):
+def mds_to_parquet(output_file, mds_directories, compression, overwrite, yes, batch_size, no_bulk, trafo):
     check_arguments(output_file, overwrite, yes, mds_directories)
     save_parquet(
         load_mds_directories(mds_directories, batch_size=batch_size, bulk=not no_bulk),
         output_file,
         compression=compression,
         batch_size=batch_size,
+        trafo=trafo,
     )

mldataforge/commands/convert/parquet.py (+6 -2)

@@ -18,15 +18,17 @@ def parquet():
 @processes_option()
 @overwrite_option()
 @yes_option()
+@trafo_option()
 def jsonl(**kwargs):
     parquet_to_jsonl(**kwargs)
-def parquet_to_jsonl(output_file, parquet_files, compression, processes, overwrite, yes):
+def parquet_to_jsonl(output_file, parquet_files, compression, processes, overwrite, yes, trafo):
     check_arguments(output_file, overwrite, yes, parquet_files)
     save_jsonl(
         load_dataset("parquet", data_files=parquet_files, split="train"),
         output_file,
         compression=compression,
         processes=processes,
+        trafo=trafo,
     )

 @parquet.command()
@@ -39,9 +41,10 @@ def parquet_to_jsonl(output_file, parquet_files, compression, processes, overwri
 @buf_size_option()
 @shard_size_option()
 @no_pigz_option()
+@trafo_option()
 def mds(**kwargs):
     parquet_to_mds(**kwargs)
-def parquet_to_mds(output_dir, parquet_files, compression, processes, overwrite, yes, buf_size, shard_size, no_pigz):
+def parquet_to_mds(output_dir, parquet_files, compression, processes, overwrite, yes, buf_size, shard_size, no_pigz, trafo):
     check_arguments(output_dir, overwrite, yes, parquet_files)
     save_mds(
         load_dataset("parquet", data_files=parquet_files, split="train"),
@@ -51,4 +54,5 @@ def parquet_to_mds(output_dir, parquet_files, compression, processes, overwrite,
         buf_size=buf_size,
         pigz=use_pigz(compression, no_pigz=no_pigz),
         shard_size=shard_size,
+        trafo=trafo,
     )

mldataforge/commands/join.py (+8 -3)

@@ -18,9 +18,10 @@ def join():
 @processes_option()
 @overwrite_option()
 @yes_option()
+@trafo_option()
 def jsonl(**kwargs):
     join_jsonl(**kwargs)
-def join_jsonl(output_file, jsonl_files, compression, processes, overwrite, yes):
+def join_jsonl(output_file, jsonl_files, compression, processes, overwrite, yes, trafo):
     check_arguments(output_file, overwrite, yes, jsonl_files)
     save_jsonl(
         load_jsonl_files(jsonl_files),
@@ -41,10 +42,11 @@ def join_jsonl(output_file, jsonl_files, compression, processes, overwrite, yes)
 @no_bulk_option()
 @shard_size_option()
 @no_pigz_option()
+@trafo_option()
 def mds(**kwargs):
     print(kwargs)
     join_mds(**kwargs)
-def join_mds(output_dir, mds_directories, compression, processes, overwrite, yes, batch_size, buf_size, no_bulk, shard_size, no_pigz):
+def join_mds(output_dir, mds_directories, compression, processes, overwrite, yes, batch_size, buf_size, no_bulk, shard_size, no_pigz, trafo):
     check_arguments(output_dir, overwrite, yes, mds_directories)
     save_mds(
         load_mds_directories(mds_directories, batch_size=batch_size, bulk=not no_bulk),
@@ -54,6 +56,7 @@ def join_mds(output_dir, mds_directories, compression, processes, overwrite, yes
         buf_size=buf_size,
         shard_size=shard_size,
         pigz=use_pigz(compression, no_pigz),
+        trafo=trafo,
     )

 @join.command()
@@ -63,13 +66,15 @@ def join_mds(output_dir, mds_directories, compression, processes, overwrite, yes
 @overwrite_option()
 @yes_option()
 @batch_size_option()
+@trafo_option()
 def parquet(**kwargs):
     join_parquet(**kwargs)
-def join_parquet(output_file, parquet_files, compression, overwrite, yes, batch_size):
+def join_parquet(output_file, parquet_files, compression, overwrite, yes, batch_size, trafo):
     check_arguments(output_file, overwrite, yes, parquet_files)
     save_parquet(
         load_dataset("parquet", data_files=parquet_files, split="train"),
         output_file,
         compression=compression,
         batch_size=batch_size,
+        trafo=trafo,
     )

mldataforge/commands/split.py (+15 -3)

@@ -20,7 +20,10 @@ def split():
 @processes_option()
 @overwrite_option()
 @yes_option()
-def jsonl(jsonl_files, prefix, output_dir, size_hint, compression, processes, overwrite, yes):
+@trafo_option()
+def jsonl(*args, **kwargs):
+    split_jsonl(*args, **kwargs)
+def split_jsonl(jsonl_files, prefix, output_dir, size_hint, compression, processes, overwrite, yes, trafo):
     save_jsonl(
         load_jsonl_files(jsonl_files),
         output_file=f"{output_dir}/{prefix}{{part:04d}}.jsonl{extension_compression(compression, jsonl_files[0])}",
@@ -29,6 +32,7 @@ def jsonl(jsonl_files, prefix, output_dir, size_hint, compression, processes, ov
         size_hint=size_hint,
         overwrite=overwrite,
         yes=yes,
+        trafo=trafo,
     )

 @split.command()
@@ -45,7 +49,10 @@ def jsonl(jsonl_files, prefix, output_dir, size_hint, compression, processes, ov
 @no_bulk_option()
 @shard_size_option()
 @no_pigz_option()
-def mds(mds_directories, prefix, output_dir, size_hint, compression, processes, overwrite, yes, buf_size, batch_size, no_bulk, shard_size, no_pigz):
+@trafo_option()
+def mds(*args, **kwargs):
+    split_mds(*args, **kwargs)
+def split_mds(mds_directories, prefix, output_dir, size_hint, compression, processes, overwrite, yes, buf_size, batch_size, no_bulk, shard_size, no_pigz, trafo):
     save_mds(
         load_mds_directories(mds_directories, batch_size=batch_size, bulk=not no_bulk),
         output_dir=f"{output_dir}/{prefix}{{part:04d}}",
@@ -57,6 +64,7 @@ def mds(mds_directories, prefix, output_dir, size_hint, compression, processes,
         size_hint=size_hint,
         overwrite=overwrite,
         yes=yes,
+        trafo=trafo,
     )

 @split.command()
@@ -68,7 +76,10 @@ def mds(mds_directories, prefix, output_dir, size_hint, compression, processes,
 @overwrite_option()
 @yes_option()
 @batch_size_option()
-def parquet(parquet_files, prefix, output_dir, size_hint, compression, overwrite, yes, batch_size):
+@trafo_option()
+def parquet(*args, **kwargs):
+    split_parquet(*args, **kwargs)
+def split_parquet(parquet_files, prefix, output_dir, size_hint, compression, overwrite, yes, batch_size, trafo):
     save_parquet(
         load_dataset("parquet", data_files=parquet_files, split="train"),
         output_file=f"{output_dir}/{prefix}{{part:04d}}.parquet",
@@ -77,4 +88,5 @@ def parquet(parquet_files, prefix, output_dir, size_hint, compression, overwrite
         size_hint=size_hint,
         overwrite=overwrite,
         yes=yes,
+        trafo=trafo,
     )
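
For reference, the doubled braces in the split templates above leave a `{part:04d}` placeholder that the writers in utils.py fill per output part via `str.format`. A small sketch of the expansion, with illustrative directory and prefix values:

```python
# How the split templates expand: the f-string fills output_dir/prefix immediately,
# while {{part:04d}} survives as a literal placeholder for the writer to fill per part.
output_dir, prefix = "splits", "train-"          # illustrative values
template = f"{output_dir}/{prefix}{{part:04d}}.parquet"
print(template)                  # splits/train-{part:04d}.parquet
print(template.format(part=0))   # splits/train-0000.parquet
print(template.format(part=1))   # splits/train-0001.parquet
```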

mldataforge/options.py (+12 -0)

@@ -14,6 +14,7 @@ __all__ = [
     "prefix_option",
     "shard_size_option",
     "size_hint_option",
+    "trafo_option",
     "yes_option",
 ]

@@ -129,6 +130,17 @@ def size_hint_option(default=2**26):
         help=f"Size hint for the dataset (default: {default}).",
     )

+def trafo_option():
+    """
+    Option for specifying the transformation function.
+    """
+    return click.option(
+        "--trafo",
+        default=None,
+        type=str,
+        help="Transformation function to apply to the dataset.",
+    )
+
 def yes_option():
     """
     Option for specifying whether to assume yes to all prompts.
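
Note that the option value is a plain string; as the new trafos.py below shows, `Trafo` eval()s it into a callable. A minimal sketch of that string-to-callable step, with an illustrative expression:

```python
# Sketch: what a --trafo string turns into (the expression here is illustrative).
expr = "lambda s: {k: v for k, v in s.items() if k != 'id'}"   # drop the 'id' field
fn = eval(expr)
print(fn({"id": 7, "text": "hello"}))   # {'text': 'hello'}
```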

mldataforge/trafos.py (new file, +111)

@@ -0,0 +1,111 @@
+import re
+from typing import Callable
+
+__all__ = ['Trafo', 'flatten_json', 'unflatten_json']
+
+class Trafo:
+    """
+    Base class for transformations.
+    """
+
+    def __init__(self, trafo: Callable | str | None):
+        self.trafo = trafo
+        if isinstance(trafo, str):
+            self.trafo = eval(trafo)
+
+    def __call__(self, obj):
+        return self.trafo(obj) if self.trafo else obj
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}({self.trafo})"
+
+
+def flatten_json(obj, parent_key='', sep='.', escape_char='\\'):
+    items = []
+
+    def escape(key):
+        return key.replace(escape_char, escape_char * 2)\
+                  .replace(sep, escape_char + sep)\
+                  .replace('[', escape_char + '[')\
+                  .replace(']', escape_char + ']')
+
+    if isinstance(obj, dict):
+        if not obj:
+            # explicitly handle empty dict
+            items.append((parent_key, {}))
+        else:
+            for k, v in obj.items():
+                new_key = f"{parent_key}{sep}{escape(k)}" if parent_key else escape(k)
+                items.extend(flatten_json(v, new_key, sep, escape_char).items())
+    elif isinstance(obj, list):
+        if not obj:
+            # explicitly handle empty list
+            items.append((parent_key, []))
+        else:
+            for idx, v in enumerate(obj):
+                new_key = f"{parent_key}[{idx}]"
+                items.extend(flatten_json(v, new_key, sep, escape_char).items())
+    else:
+        items.append((parent_key, obj))
+    return dict(items)
+
+
+def unflatten_json(flat_dict, sep='.', escape_char='\\'):
+
+    def check_flat_json(obj):
+        assert isinstance(obj, dict), "Input must be a dictionary"
+        for k, v in obj.items():
+            assert isinstance(k, str), f"Key {k} is not a string"
+            assert isinstance(v, (str, int, float, bool)), f"Value {v} is not a valid JSON type"
+
+    def parse_key(key):
+        tokens = re.findall(r'(?:[^.\[\]\\]|\\.)+|\[\d+\]', key)
+        parsed = []
+        for token in tokens:
+            if token.startswith('['):
+                parsed.append(int(token[1:-1]))
+            else:
+                parsed.append(token.replace(escape_char + sep, sep)
+                                   .replace(escape_char + '[', '[')
+                                   .replace(escape_char + ']', ']')
+                                   .replace(escape_char*2, escape_char))
+        return parsed
+
+    check_flat_json(flat_dict)
+
+    result = {}
+
+    for compound_key, value in flat_dict.items():
+        keys = parse_key(compound_key)
+        current = result
+        for idx, key in enumerate(keys):
+            if idx == len(keys) - 1:
+                if isinstance(key, int):
+                    if not isinstance(current, list):
+                        current_parent[last_key] = []
+                        current = current_parent[last_key]
+                    while len(current) <= key:
+                        current.append(None)
+                    current[key] = value
+                else:
+                    current[key] = value
+            else:
+                next_key = keys[idx + 1]
+                if isinstance(key, int):
+                    if not isinstance(current, list):
+                        current_parent[last_key] = []
+                        current = current_parent[last_key]
+                    while len(current) <= key:
+                        current.append(None)
+                    if current[key] is None:
+                        current[key] = [] if isinstance(next_key, int) else {}
+                    current_parent = current
+                    current = current[key]
+                else:
+                    if key not in current:
+                        current[key] = [] if isinstance(next_key, int) else {}
+                    current_parent = current
+                    current = current[key]
+            last_key = key
+
+    return result
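
Taken together, the new helpers flatten nested JSON into dotted/indexed string keys and invert that mapping later. A quick round-trip sketch with illustrative data (passing the callable directly to `Trafo`; a string expression is also accepted and eval()'d):

```python
# Round-trip sketch using the new helpers (illustrative data).
from mldataforge.trafos import Trafo, flatten_json, unflatten_json

sample = {"meta": {"source": "web", "tags": ["a", "b"]}, "text": "hi"}

flat = flatten_json(sample)
# {'meta.source': 'web', 'meta.tags[0]': 'a', 'meta.tags[1]': 'b', 'text': 'hi'}

restored = unflatten_json(flat)
assert restored == sample

# The same flattening applied through the Trafo wrapper, as the writers in utils.py do:
trafo = Trafo(flatten_json)
print(trafo(sample)["meta.tags[1]"])   # 'b'
```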

mldataforge/utils.py (+10 -3)

@@ -12,6 +12,7 @@ from tqdm import tqdm
 from .compression import determine_compression, open_compression, pigz_compress
 from .mds import MDSBulkReader, MDSWriter
 from .pigz import pigz_open
+from .trafos import Trafo
 
 __all__ = [
     "check_arguments",
@@ -111,10 +112,12 @@ def load_mds_directories(mds_directories, split='.', batch_size=2**16, bulk=True
     ds = concatenate_datasets(dsets=dss)
     return ds
 
-def save_jsonl(iterable, output_file, compression=None, processes=64, size_hint=None, overwrite=True, yes=True):
+def save_jsonl(iterable, output_file, compression=None, processes=64, size_hint=None, overwrite=True, yes=True, trafo=None):
     f = None
     part = 0
+    trafo = Trafo(trafo)
     for item in tqdm(iterable, desc="Writing to JSONL", unit="sample", disable=_NO_PROGESS):
+        item = trafo(item)
         if f is None:
             part_file = output_file.format(part=part)
             check_arguments(part_file, overwrite, yes)
@@ -127,12 +130,14 @@ def save_jsonl(iterable, output_file, compression=None, processes=64, size_hint=
     if f is not None:
         f.close()
 
-def save_mds(it, output_dir, processes=64, compression=None, buf_size=2**24, pigz=True, shard_size=None, size_hint=None, overwrite=True, yes=True):
+def save_mds(it, output_dir, processes=64, compression=None, buf_size=2**24, pigz=True, shard_size=None, size_hint=None, overwrite=True, yes=True, trafo=None):
     compression = determine_compression("mds", output_dir, compression, no_pigz=not pigz)
     writer = None
     part = 0
     files = []
+    trafo = Trafo(trafo)
     for sample in tqdm(it, desc="Writing to MDS", unit="sample", disable=_NO_PROGESS):
+        sample = trafo(sample)
         if writer is None:
             part_dir = output_dir.format(part=part)
             check_arguments(part_dir, overwrite, yes)
@@ -170,12 +175,14 @@ def save_mds(it, output_dir, processes=64, compression=None, buf_size=2**24, pig
     json.dump(index, open(index_path, "wt"))
     print(f"Compressed {output_dir} with pigz")
 
-def save_parquet(it, output_file, compression=None, batch_size=2**16, size_hint=None, overwrite=True, yes=True):
+def save_parquet(it, output_file, compression=None, batch_size=2**16, size_hint=None, overwrite=True, yes=True, trafo=None):
     compression = determine_compression("parquet", output_file, compression)
     writer = None
     part = 0
+    trafo = Trafo(trafo)
     it = tqdm(it, desc="Writing to Parquet", unit="sample", disable=_NO_PROGESS)
     for batch in _batch_iterable(it, batch_size):
+        batch = [trafo(sample) for sample in batch]
         table = pa.Table.from_pylist(batch)
         if writer is None:
             part_file = output_file.format(part=part)
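
The writers construct `Trafo(trafo)` themselves, so either a callable or a string expression can be passed straight through from the CLI layer. A minimal sketch against the `save_jsonl` signature above; the data and output path are illustrative:

```python
# Sketch: save_jsonl applies the trafo to every item before writing.
# Signature taken from the diff above; data and path are illustrative.
from mldataforge.utils import save_jsonl
from mldataforge.trafos import flatten_json

samples = [{"meta": {"lang": "en"}, "text": "hello"},
           {"meta": {"lang": "da"}, "text": "hej"}]

save_jsonl(
    samples,
    "out/flat-{part:04d}.jsonl",   # {part} is filled by the writer
    compression=None,
    trafo=flatten_json,            # a string expression is also accepted (eval'd by Trafo)
)
```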

The remaining eleven files listed above with +0 -0 are unchanged between 0.1.7 and 0.2.0.