mldataforge 0.1.7__tar.gz → 0.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (21)
  1. {mldataforge-0.1.7 → mldataforge-0.2.0}/PKG-INFO +1 -1
  2. {mldataforge-0.1.7 → mldataforge-0.2.0}/mldataforge/commands/convert/jsonl.py +6 -2
  3. {mldataforge-0.1.7 → mldataforge-0.2.0}/mldataforge/commands/convert/mds.py +6 -2
  4. {mldataforge-0.1.7 → mldataforge-0.2.0}/mldataforge/commands/convert/parquet.py +6 -2
  5. {mldataforge-0.1.7 → mldataforge-0.2.0}/mldataforge/commands/join.py +8 -3
  6. {mldataforge-0.1.7 → mldataforge-0.2.0}/mldataforge/commands/split.py +15 -3
  7. {mldataforge-0.1.7 → mldataforge-0.2.0}/mldataforge/options.py +12 -0
  8. mldataforge-0.2.0/mldataforge/trafos.py +111 -0
  9. {mldataforge-0.1.7 → mldataforge-0.2.0}/mldataforge/utils.py +10 -3
  10. {mldataforge-0.1.7 → mldataforge-0.2.0}/pyproject.toml +1 -1
  11. {mldataforge-0.1.7 → mldataforge-0.2.0}/.gitignore +0 -0
  12. {mldataforge-0.1.7 → mldataforge-0.2.0}/LICENSE +0 -0
  13. {mldataforge-0.1.7 → mldataforge-0.2.0}/README.md +0 -0
  14. {mldataforge-0.1.7 → mldataforge-0.2.0}/mldataforge/__main__.py +0 -0
  15. {mldataforge-0.1.7 → mldataforge-0.2.0}/mldataforge/brotli.py +0 -0
  16. {mldataforge-0.1.7 → mldataforge-0.2.0}/mldataforge/commands/__init__.py +0 -0
  17. {mldataforge-0.1.7 → mldataforge-0.2.0}/mldataforge/commands/convert/__init__.py +0 -0
  18. {mldataforge-0.1.7 → mldataforge-0.2.0}/mldataforge/compression.py +0 -0
  19. {mldataforge-0.1.7 → mldataforge-0.2.0}/mldataforge/mds.py +0 -0
  20. {mldataforge-0.1.7 → mldataforge-0.2.0}/mldataforge/pigz.py +0 -0
  21. {mldataforge-0.1.7 → mldataforge-0.2.0}/mldataforge/snappy.py +0 -0

PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: mldataforge
-Version: 0.1.7
+Version: 0.2.0
 Summary: swiss army knife of scripts for transforming and processing datasets for machine learning.
 Project-URL: Homepage, https://github.com/schneiderkamplab/mldataforge
 Project-URL: Bug Tracker, https://github.com/schneiderkamplab/mldataforge/issues

mldataforge/commands/convert/jsonl.py
@@ -21,9 +21,10 @@ def jsonl():
 @buf_size_option()
 @shard_size_option()
 @no_pigz_option()
+@trafo_option()
 def mds(**kwargs):
     jsonl_to_mds(**kwargs)
-def jsonl_to_mds(output_dir, jsonl_files, compression, processes, overwrite, yes, buf_size, shard_size, no_pigz):
+def jsonl_to_mds(output_dir, jsonl_files, compression, processes, overwrite, yes, buf_size, shard_size, no_pigz, trafo):
     check_arguments(output_dir, overwrite, yes, jsonl_files)
     save_mds(
         load_jsonl_files(jsonl_files),
@@ -33,6 +34,7 @@ def jsonl_to_mds(output_dir, jsonl_files, compression, processes, overwrite, yes
         buf_size=buf_size,
         pigz=use_pigz(compression, no_pigz),
         shard_size=shard_size,
+        trafo=trafo,
     )
 
 @jsonl.command()
@@ -42,13 +44,15 @@ def jsonl_to_mds(output_dir, jsonl_files, compression, processes, overwrite, yes
 @overwrite_option()
 @yes_option()
 @batch_size_option()
+@trafo_option()
 def parquet(**kwargs):
     jsonl_to_parquet(**kwargs)
-def jsonl_to_parquet(output_file, jsonl_files, compression, overwrite, yes, batch_size):
+def jsonl_to_parquet(output_file, jsonl_files, compression, overwrite, yes, batch_size, trafo):
     check_arguments(output_file, overwrite, yes, jsonl_files)
     save_parquet(
         load_jsonl_files(jsonl_files),
         output_file,
         compression=compression,
         batch_size=batch_size,
+        trafo=trafo,
     )
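
Note: all three converters in this file now accept the new --trafo option and forward it to the save_* helpers, which apply the transformation to every sample before it is written. A hedged sketch of an invocation (the exact command path is assumed from the command module layout, and the lambda is illustrative):

    python -m mldataforge convert jsonl mds out-mds/ data.jsonl.gz \
        --trafo "lambda sample: {**sample, 'text': sample['text'].lower()}"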

mldataforge/commands/convert/mds.py
@@ -19,15 +19,17 @@ def mds():
 @yes_option()
 @batch_size_option()
 @no_bulk_option()
+@trafo_option()
 def jsonl(**kwargs):
     mds_to_jsonl(**kwargs)
-def mds_to_jsonl(output_file, mds_directories, compression, processes, overwrite, yes, batch_size, no_bulk):
+def mds_to_jsonl(output_file, mds_directories, compression, processes, overwrite, yes, batch_size, no_bulk, trafo):
     check_arguments(output_file, overwrite, yes, mds_directories)
     save_jsonl(
         load_mds_directories(mds_directories, batch_size=batch_size, bulk=not no_bulk),
         output_file,
         compression=compression,
         processes=processes,
+        trafo=trafo,
     )
 
 @mds.command()
@@ -38,13 +40,15 @@ def mds_to_jsonl(output_file, mds_directories, compression, processes, overwrite
 @yes_option()
 @batch_size_option()
 @no_bulk_option()
+@trafo_option()
 def parquet(**kwargs):
     mds_to_parquet(**kwargs)
-def mds_to_parquet(output_file, mds_directories, compression, overwrite, yes, batch_size, no_bulk):
+def mds_to_parquet(output_file, mds_directories, compression, overwrite, yes, batch_size, no_bulk, trafo):
     check_arguments(output_file, overwrite, yes, mds_directories)
     save_parquet(
         load_mds_directories(mds_directories, batch_size=batch_size, bulk=not no_bulk),
         output_file,
         compression=compression,
         batch_size=batch_size,
+        trafo=trafo,
     )

mldataforge/commands/convert/parquet.py
@@ -18,15 +18,17 @@ def parquet():
 @processes_option()
 @overwrite_option()
 @yes_option()
+@trafo_option()
 def jsonl(**kwargs):
     parquet_to_jsonl(**kwargs)
-def parquet_to_jsonl(output_file, parquet_files, compression, processes, overwrite, yes):
+def parquet_to_jsonl(output_file, parquet_files, compression, processes, overwrite, yes, trafo):
     check_arguments(output_file, overwrite, yes, parquet_files)
     save_jsonl(
         load_dataset("parquet", data_files=parquet_files, split="train"),
         output_file,
         compression=compression,
         processes=processes,
+        trafo=trafo,
     )
 
 @parquet.command()
@@ -39,9 +41,10 @@ def parquet_to_jsonl(output_file, parquet_files, compression, processes, overwri
 @buf_size_option()
 @shard_size_option()
 @no_pigz_option()
+@trafo_option()
 def mds(**kwargs):
     parquet_to_mds(**kwargs)
-def parquet_to_mds(output_dir, parquet_files, compression, processes, overwrite, yes, buf_size, shard_size, no_pigz):
+def parquet_to_mds(output_dir, parquet_files, compression, processes, overwrite, yes, buf_size, shard_size, no_pigz, trafo):
     check_arguments(output_dir, overwrite, yes, parquet_files)
     save_mds(
         load_dataset("parquet", data_files=parquet_files, split="train"),
@@ -51,4 +54,5 @@ def parquet_to_mds(output_dir, parquet_files, compression, processes, overwrite,
         buf_size=buf_size,
         pigz=use_pigz(compression, no_pigz=no_pigz),
         shard_size=shard_size,
+        trafo=trafo,
     )
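
Note: the converter callbacks only unpack **kwargs, so the underlying *_to_* functions remain plain Python and can be driven programmatically with the same trafo hook. A minimal sketch with illustrative arguments (all parameters are positional, matching the signatures above):

    from mldataforge.commands.convert.parquet import parquet_to_jsonl

    parquet_to_jsonl(
        "out.jsonl.gz",   # output_file
        ["in.parquet"],   # parquet_files
        None,             # compression
        64,               # processes
        False,            # overwrite
        False,            # yes
        "lambda s: s",    # trafo (identity, for illustration)
    )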

mldataforge/commands/join.py
@@ -18,9 +18,10 @@ def join():
 @processes_option()
 @overwrite_option()
 @yes_option()
+@trafo_option()
 def jsonl(**kwargs):
     join_jsonl(**kwargs)
-def join_jsonl(output_file, jsonl_files, compression, processes, overwrite, yes):
+def join_jsonl(output_file, jsonl_files, compression, processes, overwrite, yes, trafo):
     check_arguments(output_file, overwrite, yes, jsonl_files)
     save_jsonl(
         load_jsonl_files(jsonl_files),
@@ -41,10 +42,11 @@ def join_jsonl(output_file, jsonl_files, compression, processes, overwrite, yes)
 @no_bulk_option()
 @shard_size_option()
 @no_pigz_option()
+@trafo_option()
 def mds(**kwargs):
     print(kwargs)
     join_mds(**kwargs)
-def join_mds(output_dir, mds_directories, compression, processes, overwrite, yes, batch_size, buf_size, no_bulk, shard_size, no_pigz):
+def join_mds(output_dir, mds_directories, compression, processes, overwrite, yes, batch_size, buf_size, no_bulk, shard_size, no_pigz, trafo):
     check_arguments(output_dir, overwrite, yes, mds_directories)
     save_mds(
         load_mds_directories(mds_directories, batch_size=batch_size, bulk=not no_bulk),
@@ -54,6 +56,7 @@ def join_mds(output_dir, mds_directories, compression, processes, overwrite, yes
         buf_size=buf_size,
         shard_size=shard_size,
         pigz=use_pigz(compression, no_pigz),
+        trafo=trafo,
     )
 
 @join.command()
@@ -63,13 +66,15 @@ def join_mds(output_dir, mds_directories, compression, processes, overwrite, yes
 @overwrite_option()
 @yes_option()
 @batch_size_option()
+@trafo_option()
 def parquet(**kwargs):
     join_parquet(**kwargs)
-def join_parquet(output_file, parquet_files, compression, overwrite, yes, batch_size):
+def join_parquet(output_file, parquet_files, compression, overwrite, yes, batch_size, trafo):
     check_arguments(output_file, overwrite, yes, parquet_files)
     save_parquet(
         load_dataset("parquet", data_files=parquet_files, split="train"),
         output_file,
         compression=compression,
         batch_size=batch_size,
+        trafo=trafo,
     )

mldataforge/commands/split.py
@@ -20,7 +20,10 @@ def split():
 @processes_option()
 @overwrite_option()
 @yes_option()
-def jsonl(jsonl_files, prefix, output_dir, size_hint, compression, processes, overwrite, yes):
+@trafo_option()
+def jsonl(*args, **kwargs):
+    split_jsonl(*args, **kwargs)
+def split_jsonl(jsonl_files, prefix, output_dir, size_hint, compression, processes, overwrite, yes, trafo):
     save_jsonl(
         load_jsonl_files(jsonl_files),
         output_file=f"{output_dir}/{prefix}{{part:04d}}.jsonl{extension_compression(compression, jsonl_files[0])}",
@@ -29,6 +32,7 @@ def jsonl(jsonl_files, prefix, output_dir, size_hint, compression, processes, ov
         size_hint=size_hint,
         overwrite=overwrite,
         yes=yes,
+        trafo=trafo,
     )
 
 @split.command()
@@ -45,7 +49,10 @@ def jsonl(jsonl_files, prefix, output_dir, size_hint, compression, processes, ov
 @no_bulk_option()
 @shard_size_option()
 @no_pigz_option()
-def mds(mds_directories, prefix, output_dir, size_hint, compression, processes, overwrite, yes, buf_size, batch_size, no_bulk, shard_size, no_pigz):
+@trafo_option()
+def mds(*args, **kwargs):
+    split_mds(*args, **kwargs)
+def split_mds(mds_directories, prefix, output_dir, size_hint, compression, processes, overwrite, yes, buf_size, batch_size, no_bulk, shard_size, no_pigz, trafo):
     save_mds(
         load_mds_directories(mds_directories, batch_size=batch_size, bulk=not no_bulk),
         output_dir=f"{output_dir}/{prefix}{{part:04d}}",
@@ -57,6 +64,7 @@ def mds(mds_directories, prefix, output_dir, size_hint, compression, processes,
         size_hint=size_hint,
         overwrite=overwrite,
         yes=yes,
+        trafo=trafo,
     )
 
 @split.command()
@@ -68,7 +76,10 @@ def mds(mds_directories, prefix, output_dir, size_hint, compression, processes,
 @overwrite_option()
 @yes_option()
 @batch_size_option()
-def parquet(parquet_files, prefix, output_dir, size_hint, compression, overwrite, yes, batch_size):
+@trafo_option()
+def parquet(*args, **kwargs):
+    split_parquet(*args, **kwargs)
+def split_parquet(parquet_files, prefix, output_dir, size_hint, compression, overwrite, yes, batch_size, trafo):
     save_parquet(
         load_dataset("parquet", data_files=parquet_files, split="train"),
         output_file=f"{output_dir}/{prefix}{{part:04d}}.parquet",
@@ -77,4 +88,5 @@ def parquet(parquet_files, prefix, output_dir, size_hint, compression, overwrite
         size_hint=size_hint,
         overwrite=overwrite,
         yes=yes,
+        trafo=trafo,
     )
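
Note: besides gaining --trafo, the split commands were refactored so that each click callback merely forwards to a module-level split_jsonl/split_mds/split_parquet function, making the splitting logic importable without the CLI. A sketch with illustrative keyword arguments:

    from mldataforge.commands.split import split_jsonl

    split_jsonl(
        jsonl_files=["big.jsonl.gz"],
        prefix="part-",
        output_dir="splits",
        size_hint=2**26,
        compression=None,
        processes=64,
        overwrite=False,
        yes=False,
        trafo=None,
    )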

mldataforge/options.py
@@ -14,6 +14,7 @@ __all__ = [
     "prefix_option",
     "shard_size_option",
     "size_hint_option",
+    "trafo_option",
     "yes_option",
 ]
 
@@ -129,6 +130,17 @@ def size_hint_option(default=2**26):
         help=f"Size hint for the dataset (default: {default}).",
     )
 
+def trafo_option():
+    """
+    Option for specifying the transformation function.
+    """
+    return click.option(
+        "--trafo",
+        default=None,
+        type=str,
+        help="Transformation function to apply to the dataset.",
+    )
+
 def yes_option():
     """
     Option for specifying whether to assume yes to all prompts.
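
Note: the --trafo value is an arbitrary string that Trafo passes to eval() (see mldataforge/trafos.py below), so it should only come from trusted input. Since the evaluation happens inside the mldataforge.trafos module, names defined there should be directly usable in the expression, e.g. (illustrative):

    --trafo "flatten_json"
    --trafo "lambda s: unflatten_json(s)"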

mldataforge/trafos.py (new file)
@@ -0,0 +1,111 @@
+import re
+from typing import Callable
+
+__all__ = ['Trafo', 'flatten_json', 'unflatten_json']
+
+class Trafo:
+    """
+    Base class for transformations.
+    """
+
+    def __init__(self, trafo: Callable | str | None):
+        self.trafo = trafo
+        if isinstance(trafo, str):
+            self.trafo = eval(trafo)
+
+    def __call__(self, obj):
+        return self.trafo(obj) if self.trafo else obj
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}({self.trafo})"
+
+
+def flatten_json(obj, parent_key='', sep='.', escape_char='\\'):
+    items = []
+
+    def escape(key):
+        return key.replace(escape_char, escape_char * 2)\
+                  .replace(sep, escape_char + sep)\
+                  .replace('[', escape_char + '[')\
+                  .replace(']', escape_char + ']')
+
+    if isinstance(obj, dict):
+        if not obj:
+            # explicitly handle empty dict
+            items.append((parent_key, {}))
+        else:
+            for k, v in obj.items():
+                new_key = f"{parent_key}{sep}{escape(k)}" if parent_key else escape(k)
+                items.extend(flatten_json(v, new_key, sep, escape_char).items())
+    elif isinstance(obj, list):
+        if not obj:
+            # explicitly handle empty list
+            items.append((parent_key, []))
+        else:
+            for idx, v in enumerate(obj):
+                new_key = f"{parent_key}[{idx}]"
+                items.extend(flatten_json(v, new_key, sep, escape_char).items())
+    else:
+        items.append((parent_key, obj))
+    return dict(items)
+
+
+def unflatten_json(flat_dict, sep='.', escape_char='\\'):
+
+    def check_flat_json(obj):
+        assert isinstance(obj, dict), "Input must be a dictionary"
+        for k, v in obj.items():
+            assert isinstance(k, str), f"Key {k} is not a string"
+            assert isinstance(v, (str, int, float, bool)), f"Value {v} is not a valid JSON type"
+
+    def parse_key(key):
+        tokens = re.findall(r'(?:[^.\[\]\\]|\\.)+|\[\d+\]', key)
+        parsed = []
+        for token in tokens:
+            if token.startswith('['):
+                parsed.append(int(token[1:-1]))
+            else:
+                parsed.append(token.replace(escape_char + sep, sep)
+                              .replace(escape_char + '[', '[')
+                              .replace(escape_char + ']', ']')
+                              .replace(escape_char*2, escape_char))
+        return parsed
+
+    check_flat_json(flat_dict)
+
+    result = {}
+
+    for compound_key, value in flat_dict.items():
+        keys = parse_key(compound_key)
+        current = result
+        for idx, key in enumerate(keys):
+            if idx == len(keys) - 1:
+                if isinstance(key, int):
+                    if not isinstance(current, list):
+                        current_parent[last_key] = []
+                        current = current_parent[last_key]
+                    while len(current) <= key:
+                        current.append(None)
+                    current[key] = value
+                else:
+                    current[key] = value
+            else:
+                next_key = keys[idx + 1]
+                if isinstance(key, int):
+                    if not isinstance(current, list):
+                        current_parent[last_key] = []
+                        current = current_parent[last_key]
+                    while len(current) <= key:
+                        current.append(None)
+                    if current[key] is None:
+                        current[key] = [] if isinstance(next_key, int) else {}
+                    current_parent = current
+                    current = current[key]
+                else:
+                    if key not in current:
+                        current[key] = [] if isinstance(next_key, int) else {}
+                    current_parent = current
+                    current = current[key]
+            last_key = key
+
+    return result
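
Note: the new module can be exercised on its own; a small round-trip sketch (values are illustrative):

    from mldataforge.trafos import Trafo, flatten_json, unflatten_json

    nested = {"meta": {"id": 7, "tags": ["a", "b"]}}
    flat = flatten_json(nested)
    # {'meta.id': 7, 'meta.tags[0]': 'a', 'meta.tags[1]': 'b'}
    assert unflatten_json(flat) == nested

    lower = Trafo("lambda s: {k: v.lower() if isinstance(v, str) else v for k, v in s.items()}")
    lower({"text": "Hello"})  # -> {'text': 'hello'}

Be aware that flatten_json preserves empty dicts and lists as values, while unflatten_json's check_flat_json only accepts str/int/float/bool values, so keys with empty containers do not round-trip.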

mldataforge/utils.py
@@ -12,6 +12,7 @@ from tqdm import tqdm
 from .compression import determine_compression, open_compression, pigz_compress
 from .mds import MDSBulkReader, MDSWriter
 from .pigz import pigz_open
+from .trafos import Trafo
 
 __all__ = [
     "check_arguments",
@@ -111,10 +112,12 @@ def load_mds_directories(mds_directories, split='.', batch_size=2**16, bulk=True
     ds = concatenate_datasets(dsets=dss)
     return ds
 
-def save_jsonl(iterable, output_file, compression=None, processes=64, size_hint=None, overwrite=True, yes=True):
+def save_jsonl(iterable, output_file, compression=None, processes=64, size_hint=None, overwrite=True, yes=True, trafo=None):
     f = None
     part = 0
+    trafo = Trafo(trafo)
     for item in tqdm(iterable, desc="Writing to JSONL", unit="sample", disable=_NO_PROGESS):
+        item = trafo(item)
         if f is None:
             part_file = output_file.format(part=part)
             check_arguments(part_file, overwrite, yes)
@@ -127,12 +130,14 @@ def save_jsonl(iterable, output_file, compression=None, processes=64, size_hint=
     if f is not None:
         f.close()
 
-def save_mds(it, output_dir, processes=64, compression=None, buf_size=2**24, pigz=True, shard_size=None, size_hint=None, overwrite=True, yes=True):
+def save_mds(it, output_dir, processes=64, compression=None, buf_size=2**24, pigz=True, shard_size=None, size_hint=None, overwrite=True, yes=True, trafo=None):
     compression = determine_compression("mds", output_dir, compression, no_pigz=not pigz)
     writer = None
     part = 0
     files = []
+    trafo = Trafo(trafo)
     for sample in tqdm(it, desc="Writing to MDS", unit="sample", disable=_NO_PROGESS):
+        sample = trafo(sample)
         if writer is None:
             part_dir = output_dir.format(part=part)
             check_arguments(part_dir, overwrite, yes)
@@ -170,12 +175,14 @@ def save_mds(it, output_dir, processes=64, compression=None, buf_size=2**24, pig
         json.dump(index, open(index_path, "wt"))
         print(f"Compressed {output_dir} with pigz")
 
-def save_parquet(it, output_file, compression=None, batch_size=2**16, size_hint=None, overwrite=True, yes=True):
+def save_parquet(it, output_file, compression=None, batch_size=2**16, size_hint=None, overwrite=True, yes=True, trafo=None):
     compression = determine_compression("parquet", output_file, compression)
     writer = None
     part = 0
+    trafo = Trafo(trafo)
     it = tqdm(it, desc="Writing to Parquet", unit="sample", disable=_NO_PROGESS)
     for batch in _batch_iterable(it, batch_size):
+        batch = [trafo(sample) for sample in batch]
         table = pa.Table.from_pylist(batch)
         if writer is None:
             part_file = output_file.format(part=part)
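
Note: all three writers use the hook the same way: the trafo argument (a callable, an eval-able string, or None) is wrapped in Trafo once, and the wrapper is applied per item in save_jsonl and save_mds, and per batch element in save_parquet; Trafo(None) acts as the identity. A sketch of the wrapping behaviour (values are illustrative):

    from mldataforge.trafos import Trafo

    identity = Trafo(None)
    identity({"a": 1})   # -> {'a': 1}, passed through unchanged

    upper = Trafo("lambda s: {k: str(v).upper() for k, v in s.items()}")
    upper({"a": "x"})    # -> {'a': 'X'}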

pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "mldataforge"
-version = "0.1.7"
+version = "0.2.0"
 authors = [
   { name = "Peter Schneider-Kamp" }
 ]