mldataforge 0.2.0__tar.gz → 0.2.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (22)
  1. {mldataforge-0.2.0 → mldataforge-0.2.2}/PKG-INFO +1 -1
  2. {mldataforge-0.2.0 → mldataforge-0.2.2}/mldataforge/commands/convert/mds.py +6 -4
  3. {mldataforge-0.2.0 → mldataforge-0.2.2}/mldataforge/commands/join.py +3 -2
  4. {mldataforge-0.2.0 → mldataforge-0.2.2}/mldataforge/commands/split.py +3 -2
  5. mldataforge-0.2.2/mldataforge/indexing.py +25 -0
  6. {mldataforge-0.2.0 → mldataforge-0.2.2}/mldataforge/options.py +12 -0
  7. {mldataforge-0.2.0 → mldataforge-0.2.2}/mldataforge/trafos.py +63 -28
  8. {mldataforge-0.2.0 → mldataforge-0.2.2}/mldataforge/utils.py +17 -11
  9. {mldataforge-0.2.0 → mldataforge-0.2.2}/pyproject.toml +1 -1
  10. {mldataforge-0.2.0 → mldataforge-0.2.2}/.gitignore +0 -0
  11. {mldataforge-0.2.0 → mldataforge-0.2.2}/LICENSE +0 -0
  12. {mldataforge-0.2.0 → mldataforge-0.2.2}/README.md +0 -0
  13. {mldataforge-0.2.0 → mldataforge-0.2.2}/mldataforge/__main__.py +0 -0
  14. {mldataforge-0.2.0 → mldataforge-0.2.2}/mldataforge/brotli.py +0 -0
  15. {mldataforge-0.2.0 → mldataforge-0.2.2}/mldataforge/commands/__init__.py +0 -0
  16. {mldataforge-0.2.0 → mldataforge-0.2.2}/mldataforge/commands/convert/__init__.py +0 -0
  17. {mldataforge-0.2.0 → mldataforge-0.2.2}/mldataforge/commands/convert/jsonl.py +0 -0
  18. {mldataforge-0.2.0 → mldataforge-0.2.2}/mldataforge/commands/convert/parquet.py +0 -0
  19. {mldataforge-0.2.0 → mldataforge-0.2.2}/mldataforge/compression.py +0 -0
  20. {mldataforge-0.2.0 → mldataforge-0.2.2}/mldataforge/mds.py +0 -0
  21. {mldataforge-0.2.0 → mldataforge-0.2.2}/mldataforge/pigz.py +0 -0
  22. {mldataforge-0.2.0 → mldataforge-0.2.2}/mldataforge/snappy.py +0 -0

{mldataforge-0.2.0 → mldataforge-0.2.2}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: mldataforge
-Version: 0.2.0
+Version: 0.2.2
 Summary: swiss army knife of scripts for transforming and processing datasets for machine learning.
 Project-URL: Homepage, https://github.com/schneiderkamplab/mldataforge
 Project-URL: Bug Tracker, https://github.com/schneiderkamplab/mldataforge/issues

{mldataforge-0.2.0 → mldataforge-0.2.2}/mldataforge/commands/convert/mds.py
@@ -20,12 +20,13 @@ def mds():
 @batch_size_option()
 @no_bulk_option()
 @trafo_option()
+@shuffle_option()
 def jsonl(**kwargs):
     mds_to_jsonl(**kwargs)
-def mds_to_jsonl(output_file, mds_directories, compression, processes, overwrite, yes, batch_size, no_bulk, trafo):
+def mds_to_jsonl(output_file, mds_directories, compression, processes, overwrite, yes, batch_size, no_bulk, trafo, shuffle):
     check_arguments(output_file, overwrite, yes, mds_directories)
     save_jsonl(
-        load_mds_directories(mds_directories, batch_size=batch_size, bulk=not no_bulk),
+        load_mds_directories(mds_directories, batch_size=batch_size, bulk=not no_bulk, shuffle=shuffle),
         output_file,
         compression=compression,
         processes=processes,
@@ -41,12 +42,13 @@ def mds_to_jsonl(output_file, mds_directories, compression, processes, overwrite
 @batch_size_option()
 @no_bulk_option()
 @trafo_option()
+@shuffle_option()
 def parquet(**kwargs):
     mds_to_parquet(**kwargs)
-def mds_to_parquet(output_file, mds_directories, compression, overwrite, yes, batch_size, no_bulk, trafo):
+def mds_to_parquet(output_file, mds_directories, compression, overwrite, yes, batch_size, no_bulk, trafo, shuffle):
     check_arguments(output_file, overwrite, yes, mds_directories)
     save_parquet(
-        load_mds_directories(mds_directories, batch_size=batch_size, bulk=not no_bulk),
+        load_mds_directories(mds_directories, batch_size=batch_size, bulk=not no_bulk, shuffle=shuffle),
         output_file,
         compression=compression,
         batch_size=batch_size,

{mldataforge-0.2.0 → mldataforge-0.2.2}/mldataforge/commands/join.py
@@ -43,13 +43,14 @@ def join_jsonl(output_file, jsonl_files, compression, processes, overwrite, yes,
 @shard_size_option()
 @no_pigz_option()
 @trafo_option()
+@shuffle_option()
 def mds(**kwargs):
     print(kwargs)
     join_mds(**kwargs)
-def join_mds(output_dir, mds_directories, compression, processes, overwrite, yes, batch_size, buf_size, no_bulk, shard_size, no_pigz, trafo):
+def join_mds(output_dir, mds_directories, compression, processes, overwrite, yes, batch_size, buf_size, no_bulk, shard_size, no_pigz, trafo, shuffle):
     check_arguments(output_dir, overwrite, yes, mds_directories)
     save_mds(
-        load_mds_directories(mds_directories, batch_size=batch_size, bulk=not no_bulk),
+        load_mds_directories(mds_directories, batch_size=batch_size, bulk=not no_bulk, shuffle=shuffle),
         output_dir,
         processes=processes,
         compression=compression,

{mldataforge-0.2.0 → mldataforge-0.2.2}/mldataforge/commands/split.py
@@ -50,11 +50,12 @@ def split_jsonl(jsonl_files, prefix, output_dir, size_hint, compression, process
 @shard_size_option()
 @no_pigz_option()
 @trafo_option()
+@shuffle_option()
 def mds(*args, **kwargs):
     split_mds(*args, **kwargs)
-def split_mds(mds_directories, prefix, output_dir, size_hint, compression, processes, overwrite, yes, buf_size, batch_size, no_bulk, shard_size, no_pigz, trafo):
+def split_mds(mds_directories, prefix, output_dir, size_hint, compression, processes, overwrite, yes, buf_size, batch_size, no_bulk, shard_size, no_pigz, trafo, shuffle):
     save_mds(
-        load_mds_directories(mds_directories, batch_size=batch_size, bulk=not no_bulk),
+        load_mds_directories(mds_directories, batch_size=batch_size, bulk=not no_bulk, shuffle=shuffle),
         output_dir=f"{output_dir}/{prefix}{{part:04d}}",
         processes=processes,
         compression=compression,
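
The same @shuffle_option() decorator and shuffle parameter are threaded through the convert, join, and split MDS commands above. A plausible invocation, assuming the python -m entry point from __main__.py and the argument order implied by mds_to_jsonl (output file first, then MDS directories); the paths are placeholders:

python -m mldataforge convert mds jsonl out.jsonl.gz data/mds --shuffle 42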

mldataforge-0.2.2/mldataforge/indexing.py
@@ -0,0 +1,25 @@
+import numpy as np
+
+__all__ = ['IndexedDatasetView', 'shuffle_permutation']
+
+class IndexedDatasetView:
+    def __init__(self, dataset, indices):
+        self.dataset = dataset
+        self.indices = list(indices)  # ensure repeatable access
+
+    def __iter__(self):
+        for idx in self.indices:
+            yield self.dataset[idx]
+
+    def __len__(self):
+        return len(self.indices)
+
+def shuffle_permutation(n, seed=int):
+    rng = np.random.default_rng(seed)
+    return rng.permutation(n)
+
+def reverse_permutation(indices):
+    n = len(indices)
+    reverse_indices = np.empty(n, dtype=int)
+    reverse_indices[indices] = np.arange(n)
+    return reverse_indices
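
For orientation, a minimal sketch of how the new indexing helpers fit together (the list below stands in for a real indexable dataset; names and values are illustrative only):

from mldataforge.indexing import IndexedDatasetView, reverse_permutation, shuffle_permutation

dataset = [{"id": i} for i in range(5)]                # stand-in for an indexable dataset
indices = shuffle_permutation(len(dataset), seed=42)   # deterministic permutation for this seed
shuffled = list(IndexedDatasetView(dataset, indices))  # samples in permuted order

# reverse_permutation is the inverse mapping, so applying it to the shuffled
# samples restores the original order
restored = list(IndexedDatasetView(shuffled, reverse_permutation(indices)))
assert [s["id"] for s in restored] == [0, 1, 2, 3, 4]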

{mldataforge-0.2.0 → mldataforge-0.2.2}/mldataforge/options.py
@@ -13,6 +13,7 @@ __all__ = [
     "processes_option",
     "prefix_option",
     "shard_size_option",
+    "shuffle_option",
     "size_hint_option",
     "trafo_option",
     "yes_option",
@@ -120,6 +121,17 @@ def shard_size_option(default=2**26):
         help=f"Shard size for the dataset (default: {default}).",
     )

+def shuffle_option():
+    """
+    Option for specifying whether to shuffle the dataset by providing a random seed.
+    """
+    return click.option(
+        "--shuffle",
+        default=None,
+        type=int,
+        help="Shuffle the dataset by providing a random seed.",
+    )
+
 def size_hint_option(default=2**26):
     """
     Option for specifying the size hint.
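
As an aside, shuffle_option() composes with click commands exactly like the existing *_option() helpers; a minimal sketch with a made-up command name:

import click
from mldataforge.options import shuffle_option

@click.command()
@shuffle_option()
def demo(shuffle):
    # shuffle is None when --shuffle is omitted, otherwise the integer seed
    click.echo(f"shuffle seed: {shuffle}")

if __name__ == "__main__":
    demo()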

{mldataforge-0.2.0 → mldataforge-0.2.2}/mldataforge/trafos.py
@@ -1,34 +1,73 @@
 import re
-from typing import Callable
-
-__all__ = ['Trafo', 'flatten_json', 'unflatten_json']
-
-class Trafo:
-    """
-    Base class for transformations.
-    """
-
-    def __init__(self, trafo: Callable | str | None):
-        self.trafo = trafo
-        if isinstance(trafo, str):
-            self.trafo = eval(trafo)
-
-    def __call__(self, obj):
-        return self.trafo(obj) if self.trafo else obj
-
-    def __repr__(self):
-        return f"{self.__class__.__name__}({self.trafo})"

+__all__ = ['Transformation', 'Transformations', 'flatten_json', 'identity', 'unflatten_json']
+
+class Transformation:
+    def __init__(self, code: str):
+        self.code = code
+        self._init_context()
+
+    def _init_context(self):
+        self.global_context = {}
+        exec(self.code, self.global_context)
+        if 'process' not in self.global_context or not callable(self.global_context['process']):
+            raise ValueError("code must define a callable named 'process'")
+        self.process = self.global_context['process']
+        self._flushable = hasattr(self.process, 'flushable') and self.process.flushable
+
+    def _normalize_outputs(self, result):
+        if result is None:
+            return []
+        if isinstance(result, (list, tuple, set)):
+            return list(result)
+        return [result]
+
+    def _flush(self):
+        if self._flushable:
+            while True:
+                flushed = self._normalize_outputs(self.process(None))
+                if not flushed:
+                    return
+                yield from flushed
+
+    def __call__(self, iterable):
+        for sample in iterable:
+            results = self._normalize_outputs(self.process(sample))
+            yield from results
+            if not results:
+                yield from self._flush()
+        if self._flushable:
+            yield from self._flush()
+
+    def __len__(self):
+        if self._last_input_len is not None:
+            return self._last_input_len
+        raise TypeError("Length is not available for this transformation.")
+
+class Transformations:
+    def __init__(self, codes: list[str], indices=None):
+        self.pipeline = [Transformation(code) for code in codes]
+
+    def __call__(self, dataset):
+        result = dataset
+        for transform in self.pipeline:
+            result = transform(result)
+        return result
+
+    def __len__(self):
+        if self.indices is not None:
+            return len(self.indices)
+        elif hasattr(self.pipeline[0], '_last_input_len') and self.pipeline[0]._last_input_len is not None:
+            return self.pipeline[0]._last_input_len
+        raise TypeError("Transformations length is not available until __call__ is used on a sized input.")

 def flatten_json(obj, parent_key='', sep='.', escape_char='\\'):
-    items = []
-
     def escape(key):
         return key.replace(escape_char, escape_char * 2)\
                   .replace(sep, escape_char + sep)\
                   .replace('[', escape_char + '[')\
                   .replace(']', escape_char + ']')
-
+    items = []
     if isinstance(obj, dict):
         if not obj:
             # explicitly handle empty dict
@@ -49,15 +88,15 @@ def flatten_json(obj, parent_key='', sep='.', escape_char='\\'):
             items.append((parent_key, obj))
     return dict(items)

+def identity(obj):
+    return obj

 def unflatten_json(flat_dict, sep='.', escape_char='\\'):
-
     def check_flat_json(obj):
         assert isinstance(obj, dict), "Input must be a dictionary"
         for k, v in obj.items():
             assert isinstance(k, str), f"Key {k} is not a string"
             assert isinstance(v, (str, int, float, bool)), f"Value {v} is not a valid JSON type"
-
     def parse_key(key):
         tokens = re.findall(r'(?:[^.\[\]\\]|\\.)+|\[\d+\]', key)
         parsed = []
@@ -70,11 +109,8 @@ def unflatten_json(flat_dict, sep='.', escape_char='\\'):
                     .replace(escape_char + ']', ']')
                     .replace(escape_char*2, escape_char))
         return parsed
-
     check_flat_json(flat_dict)
-
     result = {}
-
     for compound_key, value in flat_dict.items():
         keys = parse_key(compound_key)
         current = result
@@ -107,5 +143,4 @@ def unflatten_json(flat_dict, sep='.', escape_char='\\'):
             current_parent = current
             current = current[key]
             last_key = key
-
     return result
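
The new Transformation contract, as read from the hunk above: a trafo is now a code string that is exec'd and must define a callable named process; returning None drops a sample, a list or tuple fans out into multiple samples, and a process function carrying a truthy flushable attribute is repeatedly called with None at the end of the stream to drain buffered samples. A hedged sketch of a single-stage pipeline (the sample data is invented for illustration):

from mldataforge.trafos import Transformations

code = '''
def process(sample):
    if not sample["text"]:
        return None              # drop empty samples
    sample["text"] = sample["text"].upper()
    return sample
'''

samples = [{"text": "a"}, {"text": ""}, {"text": "b"}]
print(list(Transformations([code])(samples)))
# [{'text': 'A'}, {'text': 'B'}]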

{mldataforge-0.2.0 → mldataforge-0.2.2}/mldataforge/utils.py
@@ -10,9 +10,10 @@ from streaming import StreamingDataset
 from tqdm import tqdm

 from .compression import determine_compression, open_compression, pigz_compress
+from .indexing import IndexedDatasetView, reverse_permutation, shuffle_permutation
 from .mds import MDSBulkReader, MDSWriter
 from .pigz import pigz_open
-from .trafos import Trafo
+from .trafos import Transformations

 __all__ = [
     "check_arguments",
@@ -89,7 +90,9 @@ def load_jsonl_files(jsonl_files):
         return _streaming_jsonl(jsonl_files, compressions)
     return load_dataset("json", data_files=jsonl_files, split="train")

-def load_mds_directories(mds_directories, split='.', batch_size=2**16, bulk=True):
+def load_mds_directories(mds_directories, split='.', batch_size=2**16, bulk=True, shuffle=None):
+    if bulk and shuffle is not None:
+        raise ValueError("Bulk reader does not support shuffling by design.")
     if bulk:
         return MDSBulkReader(mds_directories, split=split)
     dss = []
@@ -110,14 +113,19 @@ def load_mds_directories(mds_directories, split='.', batch_size=2**16, bulk=True
     else:
         with timing(message=f"Concatenating {len(dss)} datasets"):
             ds = concatenate_datasets(dsets=dss)
+    if shuffle is not None:
+        with timing(message="Creating shuffle indices"):
+            indices = shuffle_permutation(len(ds), seed=abs(shuffle))
+            if shuffle < 0:
+                indices = reverse_permutation(indices)
+            ds = IndexedDatasetView(ds, indices)
     return ds

 def save_jsonl(iterable, output_file, compression=None, processes=64, size_hint=None, overwrite=True, yes=True, trafo=None):
     f = None
     part = 0
-    trafo = Trafo(trafo)
-    for item in tqdm(iterable, desc="Writing to JSONL", unit="sample", disable=_NO_PROGESS):
-        item = trafo(item)
+    trafo = Transformations([] if trafo is None else [trafo])
+    for item in tqdm(trafo(iterable), desc="Writing to JSONL", unit="sample", disable=_NO_PROGESS):
         if f is None:
             part_file = output_file.format(part=part)
             check_arguments(part_file, overwrite, yes)
@@ -135,9 +143,8 @@ def save_mds(it, output_dir, processes=64, compression=None, buf_size=2**24, pigz
     writer = None
     part = 0
     files = []
-    trafo = Trafo(trafo)
-    for sample in tqdm(it, desc="Writing to MDS", unit="sample", disable=_NO_PROGESS):
-        sample = trafo(sample)
+    trafo = Transformations([] if trafo is None else [trafo])
+    for sample in tqdm(trafo(it), desc="Writing to MDS", unit="sample", disable=_NO_PROGESS):
         if writer is None:
             part_dir = output_dir.format(part=part)
             check_arguments(part_dir, overwrite, yes)
@@ -179,10 +186,9 @@ def save_parquet(it, output_file, compression=None, batch_size=2**16, size_hint=
     compression = determine_compression("parquet", output_file, compression)
     writer = None
     part = 0
-    trafo = Trafo(trafo)
+    trafo = Transformations([] if trafo is None else [trafo])
     it = tqdm(it, desc="Writing to Parquet", unit="sample", disable=_NO_PROGESS)
-    for batch in _batch_iterable(it, batch_size):
-        batch = [trafo(sample) for sample in batch]
+    for batch in _batch_iterable(trafo(it), batch_size):
         table = pa.Table.from_pylist(batch)
         if writer is None:
             part_file = output_file.format(part=part)
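
The sign of the seed selects the direction: per the abs(shuffle) and shuffle < 0 branch above, a positive seed applies the permutation and the corresponding negative seed applies its inverse, so a shuffled pass can later be undone. A rough sketch of the call shape (the directory path is a placeholder; bulk mode must be disabled, since the new guard raises otherwise):

from mldataforge.utils import load_mds_directories

ds_shuffled = load_mds_directories(["data/mds"], bulk=False, shuffle=42)   # permuted order
ds_inverse = load_mds_directories(["data/mds"], bulk=False, shuffle=-42)   # inverse of that permutation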

{mldataforge-0.2.0 → mldataforge-0.2.2}/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"

 [project]
 name = "mldataforge"
-version = "0.2.0"
+version = "0.2.2"
 authors = [
   { name = "Peter Schneider-Kamp" }
 ]