mldataforge 0.2.0__tar.gz → 0.2.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (21)
  1. {mldataforge-0.2.0 → mldataforge-0.2.1}/PKG-INFO +1 -1
  2. {mldataforge-0.2.0 → mldataforge-0.2.1}/mldataforge/trafos.py +83 -29
  3. {mldataforge-0.2.0 → mldataforge-0.2.1}/mldataforge/utils.py +7 -10
  4. {mldataforge-0.2.0 → mldataforge-0.2.1}/pyproject.toml +1 -1
  5. {mldataforge-0.2.0 → mldataforge-0.2.1}/.gitignore +0 -0
  6. {mldataforge-0.2.0 → mldataforge-0.2.1}/LICENSE +0 -0
  7. {mldataforge-0.2.0 → mldataforge-0.2.1}/README.md +0 -0
  8. {mldataforge-0.2.0 → mldataforge-0.2.1}/mldataforge/__main__.py +0 -0
  9. {mldataforge-0.2.0 → mldataforge-0.2.1}/mldataforge/brotli.py +0 -0
  10. {mldataforge-0.2.0 → mldataforge-0.2.1}/mldataforge/commands/__init__.py +0 -0
  11. {mldataforge-0.2.0 → mldataforge-0.2.1}/mldataforge/commands/convert/__init__.py +0 -0
  12. {mldataforge-0.2.0 → mldataforge-0.2.1}/mldataforge/commands/convert/jsonl.py +0 -0
  13. {mldataforge-0.2.0 → mldataforge-0.2.1}/mldataforge/commands/convert/mds.py +0 -0
  14. {mldataforge-0.2.0 → mldataforge-0.2.1}/mldataforge/commands/convert/parquet.py +0 -0
  15. {mldataforge-0.2.0 → mldataforge-0.2.1}/mldataforge/commands/join.py +0 -0
  16. {mldataforge-0.2.0 → mldataforge-0.2.1}/mldataforge/commands/split.py +0 -0
  17. {mldataforge-0.2.0 → mldataforge-0.2.1}/mldataforge/compression.py +0 -0
  18. {mldataforge-0.2.0 → mldataforge-0.2.1}/mldataforge/mds.py +0 -0
  19. {mldataforge-0.2.0 → mldataforge-0.2.1}/mldataforge/options.py +0 -0
  20. {mldataforge-0.2.0 → mldataforge-0.2.1}/mldataforge/pigz.py +0 -0
  21. {mldataforge-0.2.0 → mldataforge-0.2.1}/mldataforge/snappy.py +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: mldataforge
-Version: 0.2.0
+Version: 0.2.1
 Summary: swiss army knife of scripts for transforming and processing datasets for machine learning.
 Project-URL: Homepage, https://github.com/schneiderkamplab/mldataforge
 Project-URL: Bug Tracker, https://github.com/schneiderkamplab/mldataforge/issues
@@ -1,34 +1,95 @@
 import re
-from typing import Callable
-
-__all__ = ['Trafo', 'flatten_json', 'unflatten_json']
-
-class Trafo:
-    """
-    Base class for transformations.
-    """
-
-    def __init__(self, trafo: Callable | str | None):
-        self.trafo = trafo
-        if isinstance(trafo, str):
-            self.trafo = eval(trafo)
-
-    def __call__(self, obj):
-        return self.trafo(obj) if self.trafo else obj
-
-    def __repr__(self):
-        return f"{self.__class__.__name__}({self.trafo})"
 
+__all__ = ['IndexedDatasetView', 'Transformation', 'Transformations', 'flatten_json', 'unflatten_json']
+
+class IndexedDatasetView:
+    def __init__(self, dataset, indices):
+        self.dataset = dataset
+        self.indices = list(indices)  # ensure repeatable access
+
+    def __iter__(self):
+        for idx in self.indices:
+            yield self.dataset[idx]
+
+    def __len__(self):
+        return len(self.indices)
+
+class Transformation:
+    def __init__(self, code: str):
+        self.code = code
+        self._init_context()
+
+    def _init_context(self):
+        self.global_context = {}
+        exec(self.code, self.global_context)
+        if 'process' not in self.global_context or not callable(self.global_context['process']):
+            raise ValueError("code must define a callable named 'process'")
+        self.process = self.global_context['process']
+        self._flushable = hasattr(self.process, 'flushable') and self.process.flushable
+
+    def _normalize_outputs(self, result):
+        if result is None:
+            return []
+        if isinstance(result, (list, tuple, set)):
+            return list(result)
+        return [result]
+
+    def _flush(self):
+        if self._flushable:
+            while True:
+                flushed = self._normalize_outputs(self.process(None))
+                if not flushed:
+                    return
+                yield from flushed
+
+    def __call__(self, iterable):
+        for sample in iterable:
+            results = self._normalize_outputs(self.process(sample))
+            yield from results
+            if not results:
+                yield from self._flush()
+        if self._flushable:
+            yield from self._flush()
+
+    def __len__(self):
+        if self._last_input_len is not None:
+            return self._last_input_len
+        raise TypeError("Length is not available for this transformation.")
+
+
+class Transformations:
+    def __init__(self, codes: list[str], indices=None):
+        self.pipeline = [Transformation(code) for code in codes]
+        self.indices = indices  # Optional index iterable
+
+    def __call__(self, dataset):
+        # Wrap dataset with IndexedDatasetView if indices are provided
+        if self.indices is not None:
+            dataset = IndexedDatasetView(dataset, self.indices)
+
+        result = dataset
+        for transform in self.pipeline:
+            result = transform(result)
+        return result
+
+    def __len__(self):
+        # Return the input length to the pipeline
+        if self.indices is not None:
+            return len(self.indices)
+        elif hasattr(self.pipeline[0], '_last_input_len') and self.pipeline[0]._last_input_len is not None:
+            return self.pipeline[0]._last_input_len
+        raise TypeError("Transformations length is not available until __call__ is used on a sized input.")
+
+def identity(obj):
+    return obj
 
 def flatten_json(obj, parent_key='', sep='.', escape_char='\\'):
-    items = []
-
     def escape(key):
         return key.replace(escape_char, escape_char * 2)\
                   .replace(sep, escape_char + sep)\
                   .replace('[', escape_char + '[')\
                   .replace(']', escape_char + ']')
-
+    items = []
     if isinstance(obj, dict):
         if not obj:
             # explicitly handle empty dict
@@ -49,15 +110,12 @@ def flatten_json(obj, parent_key='', sep='.', escape_char='\\'):
         items.append((parent_key, obj))
     return dict(items)
 
-
 def unflatten_json(flat_dict, sep='.', escape_char='\\'):
-
     def check_flat_json(obj):
         assert isinstance(obj, dict), "Input must be a dictionary"
         for k, v in obj.items():
             assert isinstance(k, str), f"Key {k} is not a string"
             assert isinstance(v, (str, int, float, bool)), f"Value {v} is not a valid JSON type"
-
     def parse_key(key):
         tokens = re.findall(r'(?:[^.\[\]\\]|\\.)+|\[\d+\]', key)
         parsed = []
@@ -70,11 +128,8 @@ def unflatten_json(flat_dict, sep='.', escape_char='\\'):
                     .replace(escape_char + ']', ']')
                     .replace(escape_char*2, escape_char))
         return parsed
-
     check_flat_json(flat_dict)
-
     result = {}
-
     for compound_key, value in flat_dict.items():
         keys = parse_key(compound_key)
         current = result
@@ -107,5 +162,4 @@ def unflatten_json(flat_dict, sep='.', escape_char='\\'):
            current_parent = current
            current = current[key]
            last_key = key
-
    return result
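
The hunk above replaces the eval-based Trafo wrapper with a small pipeline API: Transformation executes a user-supplied code string that must define a callable named process, and Transformations chains several such transformations over an iterable, optionally restricted to selected indices via IndexedDatasetView. A process function may return a sample, a list of samples, or None, so a transformation can also drop or multiply samples. The following is a minimal sketch of how that API appears to be used, inferred only from this diff; the filtering process below is a made-up example, not part of the package:

from mldataforge.trafos import Transformations

code = """
def process(sample):
    # drop samples without a 'text' field; pass the rest through unchanged
    if 'text' not in sample:
        return None
    return sample
"""

trafos = Transformations([code])
for sample in trafos([{'text': 'a'}, {'id': 1}, {'text': 'b'}]):
    print(sample)  # only the two samples containing 'text' are yielded
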
@@ -12,7 +12,7 @@ from tqdm import tqdm
 from .compression import determine_compression, open_compression, pigz_compress
 from .mds import MDSBulkReader, MDSWriter
 from .pigz import pigz_open
-from .trafos import Trafo
+from .trafos import Transformations
 
 __all__ = [
     "check_arguments",
@@ -115,9 +115,8 @@ def load_mds_directories(mds_directories, split='.', batch_size=2**16, bulk=True
 def save_jsonl(iterable, output_file, compression=None, processes=64, size_hint=None, overwrite=True, yes=True, trafo=None):
     f = None
     part = 0
-    trafo = Trafo(trafo)
-    for item in tqdm(iterable, desc="Writing to JSONL", unit="sample", disable=_NO_PROGESS):
-        item = trafo(item)
+    trafo = Transformations([] if trafo is None else [trafo])
+    for item in tqdm(trafo(iterable), desc="Writing to JSONL", unit="sample", disable=_NO_PROGESS):
         if f is None:
             part_file = output_file.format(part=part)
             check_arguments(part_file, overwrite, yes)
@@ -135,9 +134,8 @@ def save_mds(it, output_dir, processes=64, compression=None, buf_size=2**24, pig
     writer = None
     part = 0
     files = []
-    trafo = Trafo(trafo)
-    for sample in tqdm(it, desc="Writing to MDS", unit="sample", disable=_NO_PROGESS):
-        sample = trafo(sample)
+    trafo = Transformations([] if trafo is None else [trafo])
+    for sample in tqdm(trafo(it), desc="Writing to MDS", unit="sample", disable=_NO_PROGESS):
         if writer is None:
             part_dir = output_dir.format(part=part)
             check_arguments(part_dir, overwrite, yes)
@@ -179,10 +177,9 @@ def save_parquet(it, output_file, compression=None, batch_size=2**16, size_hint=
     compression = determine_compression("parquet", output_file, compression)
     writer = None
     part = 0
-    trafo = Trafo(trafo)
+    trafo = Transformations([] if trafo is None else [trafo])
     it = tqdm(it, desc="Writing to Parquet", unit="sample", disable=_NO_PROGESS)
-    for batch in _batch_iterable(it, batch_size):
-        batch = [trafo(sample) for sample in batch]
+    for batch in _batch_iterable(trafo(it), batch_size):
         table = pa.Table.from_pylist(batch)
         if writer is None:
             part_file = output_file.format(part=part)
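
In utils.py the three writers no longer apply Trafo to each item themselves; they wrap the optional trafo argument, now a code string rather than a callable, in Transformations and iterate the transformed stream directly, so a transformation can drop or emit multiple samples per input. A hedged sketch of a call, based only on the save_jsonl signature and formatting behavior visible in this diff (the sample data and transformation string are hypothetical):

from mldataforge.utils import save_jsonl

samples = [{"text": "Hello"}, {"text": "World"}]

# code string defining the required 'process' callable
code = """
def process(sample):
    sample['text'] = sample['text'].lower()
    return sample
"""

# '{part}' is expanded via output_file.format(part=part) when the output is split into parts
save_jsonl(samples, "out-{part}.jsonl", trafo=code)
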
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "mldataforge"
-version = "0.2.0"
+version = "0.2.1"
 authors = [
   { name = "Peter Schneider-Kamp" }
 ]