mldataforge 0.1.7__tar.gz → 0.2.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (21)
  1. {mldataforge-0.1.7 → mldataforge-0.2.1}/PKG-INFO +1 -1
  2. {mldataforge-0.1.7 → mldataforge-0.2.1}/mldataforge/commands/convert/jsonl.py +6 -2
  3. {mldataforge-0.1.7 → mldataforge-0.2.1}/mldataforge/commands/convert/mds.py +6 -2
  4. {mldataforge-0.1.7 → mldataforge-0.2.1}/mldataforge/commands/convert/parquet.py +6 -2
  5. {mldataforge-0.1.7 → mldataforge-0.2.1}/mldataforge/commands/join.py +8 -3
  6. {mldataforge-0.1.7 → mldataforge-0.2.1}/mldataforge/commands/split.py +15 -3
  7. {mldataforge-0.1.7 → mldataforge-0.2.1}/mldataforge/options.py +12 -0
  8. mldataforge-0.2.1/mldataforge/trafos.py +165 -0
  9. {mldataforge-0.1.7 → mldataforge-0.2.1}/mldataforge/utils.py +10 -6
  10. {mldataforge-0.1.7 → mldataforge-0.2.1}/pyproject.toml +1 -1
  11. {mldataforge-0.1.7 → mldataforge-0.2.1}/.gitignore +0 -0
  12. {mldataforge-0.1.7 → mldataforge-0.2.1}/LICENSE +0 -0
  13. {mldataforge-0.1.7 → mldataforge-0.2.1}/README.md +0 -0
  14. {mldataforge-0.1.7 → mldataforge-0.2.1}/mldataforge/__main__.py +0 -0
  15. {mldataforge-0.1.7 → mldataforge-0.2.1}/mldataforge/brotli.py +0 -0
  16. {mldataforge-0.1.7 → mldataforge-0.2.1}/mldataforge/commands/__init__.py +0 -0
  17. {mldataforge-0.1.7 → mldataforge-0.2.1}/mldataforge/commands/convert/__init__.py +0 -0
  18. {mldataforge-0.1.7 → mldataforge-0.2.1}/mldataforge/compression.py +0 -0
  19. {mldataforge-0.1.7 → mldataforge-0.2.1}/mldataforge/mds.py +0 -0
  20. {mldataforge-0.1.7 → mldataforge-0.2.1}/mldataforge/pigz.py +0 -0
  21. {mldataforge-0.1.7 → mldataforge-0.2.1}/mldataforge/snappy.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: mldataforge
- Version: 0.1.7
+ Version: 0.2.1
  Summary: swiss army knife of scripts for transforming and processing datasets for machine learning.
  Project-URL: Homepage, https://github.com/schneiderkamplab/mldataforge
  Project-URL: Bug Tracker, https://github.com/schneiderkamplab/mldataforge/issues
mldataforge/commands/convert/jsonl.py
@@ -21,9 +21,10 @@ def jsonl():
  @buf_size_option()
  @shard_size_option()
  @no_pigz_option()
+ @trafo_option()
  def mds(**kwargs):
      jsonl_to_mds(**kwargs)
- def jsonl_to_mds(output_dir, jsonl_files, compression, processes, overwrite, yes, buf_size, shard_size, no_pigz):
+ def jsonl_to_mds(output_dir, jsonl_files, compression, processes, overwrite, yes, buf_size, shard_size, no_pigz, trafo):
      check_arguments(output_dir, overwrite, yes, jsonl_files)
      save_mds(
          load_jsonl_files(jsonl_files),
@@ -33,6 +34,7 @@ def jsonl_to_mds(output_dir, jsonl_files, compression, processes, overwrite, yes
          buf_size=buf_size,
          pigz=use_pigz(compression, no_pigz),
          shard_size=shard_size,
+         trafo=trafo,
      )

  @jsonl.command()
@@ -42,13 +44,15 @@ def jsonl_to_mds(output_dir, jsonl_files, compression, processes, overwrite, yes
  @overwrite_option()
  @yes_option()
  @batch_size_option()
+ @trafo_option()
  def parquet(**kwargs):
      jsonl_to_parquet(**kwargs)
- def jsonl_to_parquet(output_file, jsonl_files, compression, overwrite, yes, batch_size):
+ def jsonl_to_parquet(output_file, jsonl_files, compression, overwrite, yes, batch_size, trafo):
      check_arguments(output_file, overwrite, yes, jsonl_files)
      save_parquet(
          load_jsonl_files(jsonl_files),
          output_file,
          compression=compression,
          batch_size=batch_size,
+         trafo=trafo,
      )
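The `--trafo` value that these commands now accept and forward to `save_mds`/`save_parquet` is a string of Python source; the new `trafos.py` further down `exec`s it and requires it to define a callable named `process`. As a minimal, hypothetical sketch (field names are illustrative, not from the package), a payload that annotates every sample could look like:

```python
# Hypothetical --trafo payload: Transformation exec's this source and looks up `process`.
# `sample` is one dataset record (a dict); returning it keeps it in the output stream.
def process(sample):
    sample["n_chars"] = len(sample.get("text", ""))  # illustrative field names
    return sample
```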
mldataforge/commands/convert/mds.py
@@ -19,15 +19,17 @@ def mds():
  @yes_option()
  @batch_size_option()
  @no_bulk_option()
+ @trafo_option()
  def jsonl(**kwargs):
      mds_to_jsonl(**kwargs)
- def mds_to_jsonl(output_file, mds_directories, compression, processes, overwrite, yes, batch_size, no_bulk):
+ def mds_to_jsonl(output_file, mds_directories, compression, processes, overwrite, yes, batch_size, no_bulk, trafo):
      check_arguments(output_file, overwrite, yes, mds_directories)
      save_jsonl(
          load_mds_directories(mds_directories, batch_size=batch_size, bulk=not no_bulk),
          output_file,
          compression=compression,
          processes=processes,
+         trafo=trafo,
      )

  @mds.command()
@@ -38,13 +40,15 @@ def mds_to_jsonl(output_file, mds_directories, compression, processes, overwrite
  @yes_option()
  @batch_size_option()
  @no_bulk_option()
+ @trafo_option()
  def parquet(**kwargs):
      mds_to_parquet(**kwargs)
- def mds_to_parquet(output_file, mds_directories, compression, overwrite, yes, batch_size, no_bulk):
+ def mds_to_parquet(output_file, mds_directories, compression, overwrite, yes, batch_size, no_bulk, trafo):
      check_arguments(output_file, overwrite, yes, mds_directories)
      save_parquet(
          load_mds_directories(mds_directories, batch_size=batch_size, bulk=not no_bulk),
          output_file,
          compression=compression,
          batch_size=batch_size,
+         trafo=trafo,
      )
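Because `_normalize_outputs` in `trafos.py` turns a `None` return into an empty list, a `process` that returns `None` simply drops the sample. A hedged sketch of a filtering trafo for these MDS conversions, using an invented quality field:

```python
# Hypothetical filter: returning None produces no output for this sample,
# since Transformation._normalize_outputs(None) evaluates to [].
def process(sample):
    if sample.get("quality", 0.0) < 0.5:  # invented field and threshold
        return None
    return sample
```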
mldataforge/commands/convert/parquet.py
@@ -18,15 +18,17 @@ def parquet():
  @processes_option()
  @overwrite_option()
  @yes_option()
+ @trafo_option()
  def jsonl(**kwargs):
      parquet_to_jsonl(**kwargs)
- def parquet_to_jsonl(output_file, parquet_files, compression, processes, overwrite, yes):
+ def parquet_to_jsonl(output_file, parquet_files, compression, processes, overwrite, yes, trafo):
      check_arguments(output_file, overwrite, yes, parquet_files)
      save_jsonl(
          load_dataset("parquet", data_files=parquet_files, split="train"),
          output_file,
          compression=compression,
          processes=processes,
+         trafo=trafo,
      )

  @parquet.command()
@@ -39,9 +41,10 @@ def parquet_to_jsonl(output_file, parquet_files, compression, processes, overwri
  @buf_size_option()
  @shard_size_option()
  @no_pigz_option()
+ @trafo_option()
  def mds(**kwargs):
      parquet_to_mds(**kwargs)
- def parquet_to_mds(output_dir, parquet_files, compression, processes, overwrite, yes, buf_size, shard_size, no_pigz):
+ def parquet_to_mds(output_dir, parquet_files, compression, processes, overwrite, yes, buf_size, shard_size, no_pigz, trafo):
      check_arguments(output_dir, overwrite, yes, parquet_files)
      save_mds(
          load_dataset("parquet", data_files=parquet_files, split="train"),
@@ -51,4 +54,5 @@ def parquet_to_mds(output_dir, parquet_files, compression, processes, overwrite,
          buf_size=buf_size,
          pigz=use_pigz(compression, no_pigz=no_pigz),
          shard_size=shard_size,
+         trafo=trafo,
      )
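`_normalize_outputs` also accepts a list, tuple, or set, so one input sample may expand into several output rows. A sketch of a one-to-many trafo (the `text` field and chunk size are assumptions):

```python
# Hypothetical one-to-many trafo: a returned list is flattened into multiple samples.
def process(sample):
    text = sample.get("text", "")  # assumed field name
    chunks = [text[i:i + 1000] for i in range(0, len(text), 1000)] or [""]
    return [{**sample, "text": chunk, "chunk": i} for i, chunk in enumerate(chunks)]
```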
mldataforge/commands/join.py
@@ -18,9 +18,10 @@ def join():
  @processes_option()
  @overwrite_option()
  @yes_option()
+ @trafo_option()
  def jsonl(**kwargs):
      join_jsonl(**kwargs)
- def join_jsonl(output_file, jsonl_files, compression, processes, overwrite, yes):
+ def join_jsonl(output_file, jsonl_files, compression, processes, overwrite, yes, trafo):
      check_arguments(output_file, overwrite, yes, jsonl_files)
      save_jsonl(
          load_jsonl_files(jsonl_files),
@@ -41,10 +42,11 @@ def join_jsonl(output_file, jsonl_files, compression, processes, overwrite, yes)
  @no_bulk_option()
  @shard_size_option()
  @no_pigz_option()
+ @trafo_option()
  def mds(**kwargs):
      print(kwargs)
      join_mds(**kwargs)
- def join_mds(output_dir, mds_directories, compression, processes, overwrite, yes, batch_size, buf_size, no_bulk, shard_size, no_pigz):
+ def join_mds(output_dir, mds_directories, compression, processes, overwrite, yes, batch_size, buf_size, no_bulk, shard_size, no_pigz, trafo):
      check_arguments(output_dir, overwrite, yes, mds_directories)
      save_mds(
          load_mds_directories(mds_directories, batch_size=batch_size, bulk=not no_bulk),
@@ -54,6 +56,7 @@ def join_mds(output_dir, mds_directories, compression, processes, overwrite, yes
          buf_size=buf_size,
          shard_size=shard_size,
          pigz=use_pigz(compression, no_pigz),
+         trafo=trafo,
      )

  @join.command()
@@ -63,13 +66,15 @@ def join_mds(output_dir, mds_directories, compression, processes, overwrite, yes
  @overwrite_option()
  @yes_option()
  @batch_size_option()
+ @trafo_option()
  def parquet(**kwargs):
      join_parquet(**kwargs)
- def join_parquet(output_file, parquet_files, compression, overwrite, yes, batch_size):
+ def join_parquet(output_file, parquet_files, compression, overwrite, yes, batch_size, trafo):
      check_arguments(output_file, overwrite, yes, parquet_files)
      save_parquet(
          load_dataset("parquet", data_files=parquet_files, split="train"),
          output_file,
          compression=compression,
          batch_size=batch_size,
+         trafo=trafo,
      )
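The CLI hands a single code string to `Transformations([...])`, but the class itself takes a list of codes and chains them, so several steps can be composed when driving the library from Python. A sketch assuming `Transformations` is imported from `mldataforge.trafos`, as `utils.py` does below:

```python
# Hedged sketch: chaining two trafo code strings into one pipeline.
from mldataforge.trafos import Transformations

add_len = "def process(sample):\n    sample['n_chars'] = len(sample.get('text', ''))\n    return sample"
drop_empty = "def process(sample):\n    return sample if sample.get('n_chars', 0) > 0 else None"

pipeline = Transformations([add_len, drop_empty])
samples = [{"text": "hello"}, {"text": ""}]
print(list(pipeline(samples)))  # expected: [{'text': 'hello', 'n_chars': 5}]
```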
mldataforge/commands/split.py
@@ -20,7 +20,10 @@ def split():
  @processes_option()
  @overwrite_option()
  @yes_option()
- def jsonl(jsonl_files, prefix, output_dir, size_hint, compression, processes, overwrite, yes):
+ @trafo_option()
+ def jsonl(*args, **kwargs):
+     split_jsonl(*args, **kwargs)
+ def split_jsonl(jsonl_files, prefix, output_dir, size_hint, compression, processes, overwrite, yes, trafo):
      save_jsonl(
          load_jsonl_files(jsonl_files),
          output_file=f"{output_dir}/{prefix}{{part:04d}}.jsonl{extension_compression(compression, jsonl_files[0])}",
@@ -29,6 +32,7 @@ def jsonl(jsonl_files, prefix, output_dir, size_hint, compression, processes, ov
          size_hint=size_hint,
          overwrite=overwrite,
          yes=yes,
+         trafo=trafo,
      )

  @split.command()
@@ -45,7 +49,10 @@ def jsonl(jsonl_files, prefix, output_dir, size_hint, compression, processes, ov
  @no_bulk_option()
  @shard_size_option()
  @no_pigz_option()
- def mds(mds_directories, prefix, output_dir, size_hint, compression, processes, overwrite, yes, buf_size, batch_size, no_bulk, shard_size, no_pigz):
+ @trafo_option()
+ def mds(*args, **kwargs):
+     split_mds(*args, **kwargs)
+ def split_mds(mds_directories, prefix, output_dir, size_hint, compression, processes, overwrite, yes, buf_size, batch_size, no_bulk, shard_size, no_pigz, trafo):
      save_mds(
          load_mds_directories(mds_directories, batch_size=batch_size, bulk=not no_bulk),
          output_dir=f"{output_dir}/{prefix}{{part:04d}}",
@@ -57,6 +64,7 @@ def mds(mds_directories, prefix, output_dir, size_hint, compression, processes,
          size_hint=size_hint,
          overwrite=overwrite,
          yes=yes,
+         trafo=trafo,
      )

  @split.command()
@@ -68,7 +76,10 @@ def mds(mds_directories, prefix, output_dir, size_hint, compression, processes,
  @overwrite_option()
  @yes_option()
  @batch_size_option()
- def parquet(parquet_files, prefix, output_dir, size_hint, compression, overwrite, yes, batch_size):
+ @trafo_option()
+ def parquet(*args, **kwargs):
+     split_parquet(*args, **kwargs)
+ def split_parquet(parquet_files, prefix, output_dir, size_hint, compression, overwrite, yes, batch_size, trafo):
      save_parquet(
          load_dataset("parquet", data_files=parquet_files, split="train"),
          output_file=f"{output_dir}/{prefix}{{part:04d}}.parquet",
@@ -77,4 +88,5 @@ def parquet(parquet_files, prefix, output_dir, size_hint, compression, overwrite
          size_hint=size_hint,
          overwrite=overwrite,
          yes=yes,
+         trafo=trafo,
      )
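The split commands build part-numbered output templates with a doubled `{{part:04d}}` placeholder, which `save_jsonl`/`save_mds`/`save_parquet` later fill via `str.format` (see `utils.py` below). A small illustration with made-up values; the real JSONL template additionally appends a compression-dependent extension via `extension_compression`:

```python
# Illustrative values only; shows how the {{part:04d}} template resolves.
output_dir, prefix = "out", "train-"
template = f"{output_dir}/{prefix}{{part:04d}}.jsonl"  # -> "out/train-{part:04d}.jsonl"
print(template.format(part=3))                         # -> "out/train-0003.jsonl"
```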
mldataforge/options.py
@@ -14,6 +14,7 @@ __all__ = [
      "prefix_option",
      "shard_size_option",
      "size_hint_option",
+     "trafo_option",
      "yes_option",
  ]

@@ -129,6 +130,17 @@ def size_hint_option(default=2**26):
          help=f"Size hint for the dataset (default: {default}).",
      )

+ def trafo_option():
+     """
+     Option for specifying the transformation function.
+     """
+     return click.option(
+         "--trafo",
+         default=None,
+         type=str,
+         help="Transformation function to apply to the dataset.",
+     )
+
  def yes_option():
      """
      Option for specifying whether to assume yes to all prompts.
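The option itself is just a string; it is `Transformation` (new `trafos.py` below) that `exec`s the value and enforces the contract that a callable named `process` must be defined. A quick sketch of that contract, assuming `Transformation` is importable from `mldataforge.trafos`:

```python
from mldataforge.trafos import Transformation

Transformation("def process(sample):\n    return sample")  # accepted: defines a callable `process`
try:
    Transformation("x = 1")  # rejected: no `process` defined
except ValueError as exc:
    print(exc)  # "code must define a callable named 'process'"
```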
mldataforge-0.2.1/mldataforge/trafos.py (new file)
@@ -0,0 +1,165 @@
+ import re
+
+ __all__ = ['IndexedDatasetView', 'Transformation', 'Transformations', 'flatten_json', 'unflatten_json']
+
+ class IndexedDatasetView:
+     def __init__(self, dataset, indices):
+         self.dataset = dataset
+         self.indices = list(indices)  # ensure repeatable access
+
+     def __iter__(self):
+         for idx in self.indices:
+             yield self.dataset[idx]
+
+     def __len__(self):
+         return len(self.indices)
+
+ class Transformation:
+     def __init__(self, code: str):
+         self.code = code
+         self._init_context()
+
+     def _init_context(self):
+         self.global_context = {}
+         exec(self.code, self.global_context)
+         if 'process' not in self.global_context or not callable(self.global_context['process']):
+             raise ValueError("code must define a callable named 'process'")
+         self.process = self.global_context['process']
+         self._flushable = hasattr(self.process, 'flushable') and self.process.flushable
+
+     def _normalize_outputs(self, result):
+         if result is None:
+             return []
+         if isinstance(result, (list, tuple, set)):
+             return list(result)
+         return [result]
+
+     def _flush(self):
+         if self._flushable:
+             while True:
+                 flushed = self._normalize_outputs(self.process(None))
+                 if not flushed:
+                     return
+                 yield from flushed
+
+     def __call__(self, iterable):
+         for sample in iterable:
+             results = self._normalize_outputs(self.process(sample))
+             yield from results
+             if not results:
+                 yield from self._flush()
+         if self._flushable:
+             yield from self._flush()
+
+     def __len__(self):
+         if self._last_input_len is not None:
+             return self._last_input_len
+         raise TypeError("Length is not available for this transformation.")
+
+
+ class Transformations:
+     def __init__(self, codes: list[str], indices=None):
+         self.pipeline = [Transformation(code) for code in codes]
+         self.indices = indices  # Optional index iterable
+
+     def __call__(self, dataset):
+         # Wrap dataset with IndexedDatasetView if indices are provided
+         if self.indices is not None:
+             dataset = IndexedDatasetView(dataset, self.indices)
+
+         result = dataset
+         for transform in self.pipeline:
+             result = transform(result)
+         return result
+
+     def __len__(self):
+         # Return the input length to the pipeline
+         if self.indices is not None:
+             return len(self.indices)
+         elif hasattr(self.pipeline[0], '_last_input_len') and self.pipeline[0]._last_input_len is not None:
+             return self.pipeline[0]._last_input_len
+         raise TypeError("Transformations length is not available until __call__ is used on a sized input.")
+
+ def identity(obj):
+     return obj
+
+ def flatten_json(obj, parent_key='', sep='.', escape_char='\\'):
+     def escape(key):
+         return key.replace(escape_char, escape_char * 2)\
+                   .replace(sep, escape_char + sep)\
+                   .replace('[', escape_char + '[')\
+                   .replace(']', escape_char + ']')
+     items = []
+     if isinstance(obj, dict):
+         if not obj:
+             # explicitly handle empty dict
+             items.append((parent_key, {}))
+         else:
+             for k, v in obj.items():
+                 new_key = f"{parent_key}{sep}{escape(k)}" if parent_key else escape(k)
+                 items.extend(flatten_json(v, new_key, sep, escape_char).items())
+     elif isinstance(obj, list):
+         if not obj:
+             # explicitly handle empty list
+             items.append((parent_key, []))
+         else:
+             for idx, v in enumerate(obj):
+                 new_key = f"{parent_key}[{idx}]"
+                 items.extend(flatten_json(v, new_key, sep, escape_char).items())
+     else:
+         items.append((parent_key, obj))
+     return dict(items)
+
+ def unflatten_json(flat_dict, sep='.', escape_char='\\'):
+     def check_flat_json(obj):
+         assert isinstance(obj, dict), "Input must be a dictionary"
+         for k, v in obj.items():
+             assert isinstance(k, str), f"Key {k} is not a string"
+             assert isinstance(v, (str, int, float, bool)), f"Value {v} is not a valid JSON type"
+     def parse_key(key):
+         tokens = re.findall(r'(?:[^.\[\]\\]|\\.)+|\[\d+\]', key)
+         parsed = []
+         for token in tokens:
+             if token.startswith('['):
+                 parsed.append(int(token[1:-1]))
+             else:
+                 parsed.append(token.replace(escape_char + sep, sep)
+                                    .replace(escape_char + '[', '[')
+                                    .replace(escape_char + ']', ']')
+                                    .replace(escape_char*2, escape_char))
+         return parsed
+     check_flat_json(flat_dict)
+     result = {}
+     for compound_key, value in flat_dict.items():
+         keys = parse_key(compound_key)
+         current = result
+         for idx, key in enumerate(keys):
+             if idx == len(keys) - 1:
+                 if isinstance(key, int):
+                     if not isinstance(current, list):
+                         current_parent[last_key] = []
+                         current = current_parent[last_key]
+                     while len(current) <= key:
+                         current.append(None)
+                     current[key] = value
+                 else:
+                     current[key] = value
+             else:
+                 next_key = keys[idx + 1]
+                 if isinstance(key, int):
+                     if not isinstance(current, list):
+                         current_parent[last_key] = []
+                         current = current_parent[last_key]
+                     while len(current) <= key:
+                         current.append(None)
+                     if current[key] is None:
+                         current[key] = [] if isinstance(next_key, int) else {}
+                     current_parent = current
+                     current = current[key]
+                 else:
+                     if key not in current:
+                         current[key] = [] if isinstance(next_key, int) else {}
+                     current_parent = current
+                     current = current[key]
+             last_key = key
+     return result
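Besides the pipeline classes, the module ships `flatten_json`/`unflatten_json` helpers that use dotted keys and `[i]` suffixes for list indices. Based on the code above, a round trip should behave as follows (the sample document is illustrative):

```python
from mldataforge.trafos import flatten_json, unflatten_json

doc = {"meta": {"tags": ["a", "b"], "score": 1}}
flat = flatten_json(doc)
print(flat)                         # {'meta.tags[0]': 'a', 'meta.tags[1]': 'b', 'meta.score': 1}
print(unflatten_json(flat) == doc)  # True
```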
mldataforge/utils.py
@@ -12,6 +12,7 @@ from tqdm import tqdm
  from .compression import determine_compression, open_compression, pigz_compress
  from .mds import MDSBulkReader, MDSWriter
  from .pigz import pigz_open
+ from .trafos import Transformations

  __all__ = [
      "check_arguments",
@@ -111,10 +112,11 @@ def load_mds_directories(mds_directories, split='.', batch_size=2**16, bulk=True
      ds = concatenate_datasets(dsets=dss)
      return ds

- def save_jsonl(iterable, output_file, compression=None, processes=64, size_hint=None, overwrite=True, yes=True):
+ def save_jsonl(iterable, output_file, compression=None, processes=64, size_hint=None, overwrite=True, yes=True, trafo=None):
      f = None
      part = 0
-     for item in tqdm(iterable, desc="Writing to JSONL", unit="sample", disable=_NO_PROGESS):
+     trafo = Transformations([] if trafo is None else [trafo])
+     for item in tqdm(trafo(iterable), desc="Writing to JSONL", unit="sample", disable=_NO_PROGESS):
          if f is None:
              part_file = output_file.format(part=part)
              check_arguments(part_file, overwrite, yes)
@@ -127,12 +129,13 @@ def save_jsonl(iterable, output_file, compression=None, processes=64, size_hint=
      if f is not None:
          f.close()

- def save_mds(it, output_dir, processes=64, compression=None, buf_size=2**24, pigz=True, shard_size=None, size_hint=None, overwrite=True, yes=True):
+ def save_mds(it, output_dir, processes=64, compression=None, buf_size=2**24, pigz=True, shard_size=None, size_hint=None, overwrite=True, yes=True, trafo=None):
      compression = determine_compression("mds", output_dir, compression, no_pigz=not pigz)
      writer = None
      part = 0
      files = []
-     for sample in tqdm(it, desc="Writing to MDS", unit="sample", disable=_NO_PROGESS):
+     trafo = Transformations([] if trafo is None else [trafo])
+     for sample in tqdm(trafo(it), desc="Writing to MDS", unit="sample", disable=_NO_PROGESS):
          if writer is None:
              part_dir = output_dir.format(part=part)
              check_arguments(part_dir, overwrite, yes)
@@ -170,12 +173,13 @@ def save_mds(it, output_dir, processes=64, compression=None, buf_size=2**24, pig
          json.dump(index, open(index_path, "wt"))
          print(f"Compressed {output_dir} with pigz")

- def save_parquet(it, output_file, compression=None, batch_size=2**16, size_hint=None, overwrite=True, yes=True):
+ def save_parquet(it, output_file, compression=None, batch_size=2**16, size_hint=None, overwrite=True, yes=True, trafo=None):
      compression = determine_compression("parquet", output_file, compression)
      writer = None
      part = 0
+     trafo = Transformations([] if trafo is None else [trafo])
      it = tqdm(it, desc="Writing to Parquet", unit="sample", disable=_NO_PROGESS)
-     for batch in _batch_iterable(it, batch_size):
+     for batch in _batch_iterable(trafo(it), batch_size):
          table = pa.Table.from_pylist(batch)
          if writer is None:
              part_file = output_file.format(part=part)
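With the new `trafo=` parameter, the `save_*` helpers can be driven directly from Python as well; the code string is wrapped in `Transformations` exactly as shown above. A minimal sketch, assuming `save_jsonl` is importable from `mldataforge.utils` and that a plain `.jsonl` path needs no special compression handling:

```python
from mldataforge.utils import save_jsonl  # assumed import path

samples = [{"text": "hello"}, {"text": "world"}]
add_len = "def process(sample):\n    sample['n_chars'] = len(sample['text'])\n    return sample"

# Writes out.jsonl with the trafo applied to every sample (overwrite/yes defaults apply).
save_jsonl(samples, "out.jsonl", trafo=add_len)
```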
pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"

  [project]
  name = "mldataforge"
- version = "0.1.7"
+ version = "0.2.1"
  authors = [
    { name = "Peter Schneider-Kamp" }
  ]