mldataforge 0.1.6__tar.gz → 0.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (21)
  1. {mldataforge-0.1.6 → mldataforge-0.2.0}/PKG-INFO +1 -1
  2. {mldataforge-0.1.6 → mldataforge-0.2.0}/mldataforge/commands/convert/jsonl.py +6 -2
  3. {mldataforge-0.1.6 → mldataforge-0.2.0}/mldataforge/commands/convert/mds.py +6 -2
  4. {mldataforge-0.1.6 → mldataforge-0.2.0}/mldataforge/commands/convert/parquet.py +6 -2
  5. {mldataforge-0.1.6 → mldataforge-0.2.0}/mldataforge/commands/join.py +8 -3
  6. {mldataforge-0.1.6 → mldataforge-0.2.0}/mldataforge/commands/split.py +15 -3
  7. {mldataforge-0.1.6 → mldataforge-0.2.0}/mldataforge/mds.py +135 -4
  8. {mldataforge-0.1.6 → mldataforge-0.2.0}/mldataforge/options.py +12 -0
  9. mldataforge-0.2.0/mldataforge/trafos.py +111 -0
  10. {mldataforge-0.1.6 → mldataforge-0.2.0}/mldataforge/utils.py +20 -8
  11. {mldataforge-0.1.6 → mldataforge-0.2.0}/pyproject.toml +1 -1
  12. {mldataforge-0.1.6 → mldataforge-0.2.0}/.gitignore +0 -0
  13. {mldataforge-0.1.6 → mldataforge-0.2.0}/LICENSE +0 -0
  14. {mldataforge-0.1.6 → mldataforge-0.2.0}/README.md +0 -0
  15. {mldataforge-0.1.6 → mldataforge-0.2.0}/mldataforge/__main__.py +0 -0
  16. {mldataforge-0.1.6 → mldataforge-0.2.0}/mldataforge/brotli.py +0 -0
  17. {mldataforge-0.1.6 → mldataforge-0.2.0}/mldataforge/commands/__init__.py +0 -0
  18. {mldataforge-0.1.6 → mldataforge-0.2.0}/mldataforge/commands/convert/__init__.py +0 -0
  19. {mldataforge-0.1.6 → mldataforge-0.2.0}/mldataforge/compression.py +0 -0
  20. {mldataforge-0.1.6 → mldataforge-0.2.0}/mldataforge/pigz.py +0 -0
  21. {mldataforge-0.1.6 → mldataforge-0.2.0}/mldataforge/snappy.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: mldataforge
- Version: 0.1.6
+ Version: 0.2.0
  Summary: swiss army knife of scripts for transforming and processing datasets for machine learning.
  Project-URL: Homepage, https://github.com/schneiderkamplab/mldataforge
  Project-URL: Bug Tracker, https://github.com/schneiderkamplab/mldataforge/issues
mldataforge/commands/convert/jsonl.py
@@ -21,9 +21,10 @@ def jsonl():
  @buf_size_option()
  @shard_size_option()
  @no_pigz_option()
+ @trafo_option()
  def mds(**kwargs):
  jsonl_to_mds(**kwargs)
- def jsonl_to_mds(output_dir, jsonl_files, compression, processes, overwrite, yes, buf_size, shard_size, no_pigz):
+ def jsonl_to_mds(output_dir, jsonl_files, compression, processes, overwrite, yes, buf_size, shard_size, no_pigz, trafo):
  check_arguments(output_dir, overwrite, yes, jsonl_files)
  save_mds(
  load_jsonl_files(jsonl_files),
@@ -33,6 +34,7 @@ def jsonl_to_mds(output_dir, jsonl_files, compression, processes, overwrite, yes
  buf_size=buf_size,
  pigz=use_pigz(compression, no_pigz),
  shard_size=shard_size,
+ trafo=trafo,
  )

  @jsonl.command()
@@ -42,13 +44,15 @@ def jsonl_to_mds(output_dir, jsonl_files, compression, processes, overwrite, yes
  @overwrite_option()
  @yes_option()
  @batch_size_option()
+ @trafo_option()
  def parquet(**kwargs):
  jsonl_to_parquet(**kwargs)
- def jsonl_to_parquet(output_file, jsonl_files, compression, overwrite, yes, batch_size):
+ def jsonl_to_parquet(output_file, jsonl_files, compression, overwrite, yes, batch_size, trafo):
  check_arguments(output_file, overwrite, yes, jsonl_files)
  save_parquet(
  load_jsonl_files(jsonl_files),
  output_file,
  compression=compression,
  batch_size=batch_size,
+ trafo=trafo,
  )
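
The pattern above — stack `@trafo_option()` on the Click command, add a `trafo` parameter to the plain module-level function, and forward `trafo=trafo` into the `save_*` call — repeats in the other convert commands and in the join and split commands below. A minimal sketch of calling the non-Click function directly with the new keyword; the argument values are illustrative only, not package defaults, and the input file is hypothetical:

    from mldataforge.commands.convert.jsonl import jsonl_to_mds

    # The trafo string is forwarded untouched here; it only becomes a
    # callable later, inside utils.save_mds().
    jsonl_to_mds(
        output_dir="out-mds",
        jsonl_files=["data.jsonl.gz"],
        compression=None,
        processes=4,
        overwrite=True,
        yes=True,
        buf_size=2**24,
        shard_size=None,
        no_pigz=True,
        trafo="lambda s: {**s, 'n_chars': len(s.get('text', ''))}",
    )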
mldataforge/commands/convert/mds.py
@@ -19,15 +19,17 @@ def mds():
  @yes_option()
  @batch_size_option()
  @no_bulk_option()
+ @trafo_option()
  def jsonl(**kwargs):
  mds_to_jsonl(**kwargs)
- def mds_to_jsonl(output_file, mds_directories, compression, processes, overwrite, yes, batch_size, no_bulk):
+ def mds_to_jsonl(output_file, mds_directories, compression, processes, overwrite, yes, batch_size, no_bulk, trafo):
  check_arguments(output_file, overwrite, yes, mds_directories)
  save_jsonl(
  load_mds_directories(mds_directories, batch_size=batch_size, bulk=not no_bulk),
  output_file,
  compression=compression,
  processes=processes,
+ trafo=trafo,
  )

  @mds.command()
@@ -38,13 +40,15 @@ def mds_to_jsonl(output_file, mds_directories, compression, processes, overwrite
  @yes_option()
  @batch_size_option()
  @no_bulk_option()
+ @trafo_option()
  def parquet(**kwargs):
  mds_to_parquet(**kwargs)
- def mds_to_parquet(output_file, mds_directories, compression, overwrite, yes, batch_size, no_bulk):
+ def mds_to_parquet(output_file, mds_directories, compression, overwrite, yes, batch_size, no_bulk, trafo):
  check_arguments(output_file, overwrite, yes, mds_directories)
  save_parquet(
  load_mds_directories(mds_directories, batch_size=batch_size, bulk=not no_bulk),
  output_file,
  compression=compression,
  batch_size=batch_size,
+ trafo=trafo,
  )
mldataforge/commands/convert/parquet.py
@@ -18,15 +18,17 @@ def parquet():
  @processes_option()
  @overwrite_option()
  @yes_option()
+ @trafo_option()
  def jsonl(**kwargs):
  parquet_to_jsonl(**kwargs)
- def parquet_to_jsonl(output_file, parquet_files, compression, processes, overwrite, yes):
+ def parquet_to_jsonl(output_file, parquet_files, compression, processes, overwrite, yes, trafo):
  check_arguments(output_file, overwrite, yes, parquet_files)
  save_jsonl(
  load_dataset("parquet", data_files=parquet_files, split="train"),
  output_file,
  compression=compression,
  processes=processes,
+ trafo=trafo,
  )

  @parquet.command()
@@ -39,9 +41,10 @@ def parquet_to_jsonl(output_file, parquet_files, compression, processes, overwri
  @buf_size_option()
  @shard_size_option()
  @no_pigz_option()
+ @trafo_option()
  def mds(**kwargs):
  parquet_to_mds(**kwargs)
- def parquet_to_mds(output_dir, parquet_files, compression, processes, overwrite, yes, buf_size, shard_size, no_pigz):
+ def parquet_to_mds(output_dir, parquet_files, compression, processes, overwrite, yes, buf_size, shard_size, no_pigz, trafo):
  check_arguments(output_dir, overwrite, yes, parquet_files)
  save_mds(
  load_dataset("parquet", data_files=parquet_files, split="train"),
@@ -51,4 +54,5 @@ def parquet_to_mds(output_dir, parquet_files, compression, processes, overwrite,
  buf_size=buf_size,
  pigz=use_pigz(compression, no_pigz=no_pigz),
  shard_size=shard_size,
+ trafo=trafo,
  )
mldataforge/commands/join.py
@@ -18,9 +18,10 @@ def join():
  @processes_option()
  @overwrite_option()
  @yes_option()
+ @trafo_option()
  def jsonl(**kwargs):
  join_jsonl(**kwargs)
- def join_jsonl(output_file, jsonl_files, compression, processes, overwrite, yes):
+ def join_jsonl(output_file, jsonl_files, compression, processes, overwrite, yes, trafo):
  check_arguments(output_file, overwrite, yes, jsonl_files)
  save_jsonl(
  load_jsonl_files(jsonl_files),
@@ -41,10 +42,11 @@ def join_jsonl(output_file, jsonl_files, compression, processes, overwrite, yes)
  @no_bulk_option()
  @shard_size_option()
  @no_pigz_option()
+ @trafo_option()
  def mds(**kwargs):
  print(kwargs)
  join_mds(**kwargs)
- def join_mds(output_dir, mds_directories, compression, processes, overwrite, yes, batch_size, buf_size, no_bulk, shard_size, no_pigz):
+ def join_mds(output_dir, mds_directories, compression, processes, overwrite, yes, batch_size, buf_size, no_bulk, shard_size, no_pigz, trafo):
  check_arguments(output_dir, overwrite, yes, mds_directories)
  save_mds(
  load_mds_directories(mds_directories, batch_size=batch_size, bulk=not no_bulk),
@@ -54,6 +56,7 @@ def join_mds(output_dir, mds_directories, compression, processes, overwrite, yes
  buf_size=buf_size,
  shard_size=shard_size,
  pigz=use_pigz(compression, no_pigz),
+ trafo=trafo,
  )

  @join.command()
@@ -63,13 +66,15 @@ def join_mds(output_dir, mds_directories, compression, processes, overwrite, yes
  @overwrite_option()
  @yes_option()
  @batch_size_option()
+ @trafo_option()
  def parquet(**kwargs):
  join_parquet(**kwargs)
- def join_parquet(output_file, parquet_files, compression, overwrite, yes, batch_size):
+ def join_parquet(output_file, parquet_files, compression, overwrite, yes, batch_size, trafo):
  check_arguments(output_file, overwrite, yes, parquet_files)
  save_parquet(
  load_dataset("parquet", data_files=parquet_files, split="train"),
  output_file,
  compression=compression,
  batch_size=batch_size,
+ trafo=trafo,
  )
mldataforge/commands/split.py
@@ -20,7 +20,10 @@ def split():
  @processes_option()
  @overwrite_option()
  @yes_option()
- def jsonl(jsonl_files, prefix, output_dir, size_hint, compression, processes, overwrite, yes):
+ @trafo_option()
+ def jsonl(*args, **kwargs):
+ split_jsonl(*args, **kwargs)
+ def split_jsonl(jsonl_files, prefix, output_dir, size_hint, compression, processes, overwrite, yes, trafo):
  save_jsonl(
  load_jsonl_files(jsonl_files),
  output_file=f"{output_dir}/{prefix}{{part:04d}}.jsonl{extension_compression(compression, jsonl_files[0])}",
@@ -29,6 +32,7 @@ def jsonl(jsonl_files, prefix, output_dir, size_hint, compression, processes, ov
  size_hint=size_hint,
  overwrite=overwrite,
  yes=yes,
+ trafo=trafo,
  )

  @split.command()
@@ -45,7 +49,10 @@ def jsonl(jsonl_files, prefix, output_dir, size_hint, compression, processes, ov
  @no_bulk_option()
  @shard_size_option()
  @no_pigz_option()
- def mds(mds_directories, prefix, output_dir, size_hint, compression, processes, overwrite, yes, buf_size, batch_size, no_bulk, shard_size, no_pigz):
+ @trafo_option()
+ def mds(*args, **kwargs):
+ split_mds(*args, **kwargs)
+ def split_mds(mds_directories, prefix, output_dir, size_hint, compression, processes, overwrite, yes, buf_size, batch_size, no_bulk, shard_size, no_pigz, trafo):
  save_mds(
  load_mds_directories(mds_directories, batch_size=batch_size, bulk=not no_bulk),
  output_dir=f"{output_dir}/{prefix}{{part:04d}}",
@@ -57,6 +64,7 @@ def mds(mds_directories, prefix, output_dir, size_hint, compression, processes,
  size_hint=size_hint,
  overwrite=overwrite,
  yes=yes,
+ trafo=trafo,
  )

  @split.command()
@@ -68,7 +76,10 @@ def mds(mds_directories, prefix, output_dir, size_hint, compression, processes,
  @overwrite_option()
  @yes_option()
  @batch_size_option()
- def parquet(parquet_files, prefix, output_dir, size_hint, compression, overwrite, yes, batch_size):
+ @trafo_option()
+ def parquet(*args, **kwargs):
+ split_parquet(*args, **kwargs)
+ def split_parquet(parquet_files, prefix, output_dir, size_hint, compression, overwrite, yes, batch_size, trafo):
  save_parquet(
  load_dataset("parquet", data_files=parquet_files, split="train"),
  output_file=f"{output_dir}/{prefix}{{part:04d}}.parquet",
@@ -77,4 +88,5 @@ def parquet(parquet_files, prefix, output_dir, size_hint, compression, overwrite
  size_hint=size_hint,
  overwrite=overwrite,
  yes=yes,
+ trafo=trafo,
  )
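
Besides adding `--trafo`, split.py no longer puts the implementation directly in the Click callbacks: each command becomes a thin `jsonl`/`mds`/`parquet` wrapper plus an importable `split_jsonl`/`split_mds`/`split_parquet` function, matching the style already used in the convert and join modules. A sketch of that wrapper pattern in isolation; the `demo` command and its options are made up for illustration:

    import click

    @click.command()
    @click.argument("files", nargs=-1)
    @click.option("--trafo", default=None, type=str)
    def demo(*args, **kwargs):
        # thin Click-facing wrapper: all option parsing stays here ...
        demo_impl(*args, **kwargs)

    def demo_impl(files, trafo):
        # ... while the plain function can also be imported and called directly
        for f in files:
            print("would process", f, "with trafo", trafo)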
mldataforge/mds.py
@@ -1,16 +1,26 @@
+ from copy import deepcopy
  import json
  import numpy as np
  import os
  import shutil
  from streaming.base.compression import compress, decompress, get_compression_extension, is_compression
+ from streaming.base.format import _readers
+ from streaming.base.format.base.reader import FileInfo, JointReader
  from streaming.base.format.index import get_index_basename
- from streaming.base.format.mds.encodings import mds_decode, mds_encode, is_mds_encoding, get_mds_encodings, get_mds_encoded_size
+ from streaming.base.format.mds.encodings import mds_decode, mds_encode, is_mds_encoding, is_mds_encoding_safe, get_mds_encodings, get_mds_encoded_size
  from streaming.base.hashing import get_hash, is_hash
  from streaming.base.util import bytes_to_int
  from typing import Any, Optional, Generator, Self, Union

  from .utils import open_compression

+ __all__ = [
+ "MDSBulkReader",
+ "MDSBulkShardReader",
+ "MDSReader",
+ "MDSWriter",
+ ]
+
  class MDSBulkReader:
  def __init__(
  self,
@@ -37,11 +47,11 @@ class MDSBulkReader:

  def __iter__(self) -> Generator[dict[str, Any], None, None]:
  for shard in self.shards:
- with MDSShardReader(**shard) as reader:
+ with MDSBulkShardReader(**shard) as reader:
  for sample in reader:
  yield sample

- class MDSShardReader:
+ class MDSBulkShardReader:
  def __init__(
  self,
  filename: str,
@@ -94,7 +104,7 @@ class MDSShardReader:
  for i in range(self.samples):
  yield self.get_item(i)

- def __enter__(self) -> "MDSShardReader":
+ def __enter__(self) -> "MDSBulkShardReader":
  return self

  def __exit__(self, exc_type, exc_value, traceback) -> None:
@@ -315,3 +325,124 @@ class MDSWriter:

  def __exit__(self, exc_type, exc, traceback):
  self.finish()
+
+ class MDSReader(JointReader):
+
+ def __init__(
+ self,
+ dirname: str,
+ split: Optional[str],
+ column_encodings: list[str],
+ column_names: list[str],
+ column_sizes: list[Optional[int]],
+ compression: Optional[str],
+ hashes: list[str],
+ raw_data: FileInfo,
+ samples: int,
+ size_limit: Optional[Union[int, str]],
+ zip_data: Optional[FileInfo],
+ ) -> None:
+ self.sample_compression = None
+ if compression and compression.startswith("sample::"):
+ compression, self.sample_compression = None, compression.removeprefix("sample::")
+ super().__init__(dirname, split, compression, hashes, raw_data, samples, size_limit,
+ zip_data)
+ self.column_encodings = column_encodings
+ self.column_names = column_names
+ self.column_sizes = column_sizes
+
+ @classmethod
+ def from_json(cls, dirname: str, split: Optional[str], obj: dict[str, Any]) -> Self:
+ """Initialize from JSON object.
+
+ Args:
+ dirname (str): Local directory containing shards.
+ split (str, optional): Which dataset split to use, if any.
+ obj (Dict[str, Any]): JSON object to load.
+
+ Returns:
+ Self: Loaded MDSReader.
+ """
+ args = deepcopy(obj)
+ args_version = args['version']
+ if args_version != 2:
+ raise ValueError(
+ f'Unsupported streaming data version: {args_version}. Expected version 2.')
+ del args['version']
+ args_format = args['format']
+ if args_format != 'mds':
+ raise ValueError(f'Unsupported data format: {args_format}. Expected to be `mds`.')
+ del args['format']
+ args['dirname'] = dirname
+ args['split'] = split
+ for key in ['raw_data', 'zip_data']:
+ arg = args[key]
+ args[key] = FileInfo(**arg) if arg else None
+ return cls(**args)
+
+ def validate(self, allow_unsafe_types: bool) -> None:
+ """Check whether this shard is acceptable to be part of some Stream.
+
+ Args:
+ allow_unsafe_types (bool): If a shard contains Pickle, which allows arbitrary code
+ execution during deserialization, whether to keep going if ``True`` or raise an
+ error if ``False``.
+ """
+ if not allow_unsafe_types:
+ for column_id, encoding in enumerate(self.column_encodings):
+ if not is_mds_encoding_safe(encoding):
+ name = self.column_names[column_id]
+ raise ValueError(f'Column {name} contains an unsafe type: {encoding}. To ' +
+ f'proceed anyway, set ``allow_unsafe_types=True``.')
+
+ def decode_sample(self, data: bytes) -> dict[str, Any]:
+ """Decode a sample dict from bytes.
+
+ Args:
+ data (bytes): The sample encoded as bytes.
+
+ Returns:
+ Dict[str, Any]: Sample dict.
+ """
+ sizes = []
+ idx = 0
+ for key, size in zip(self.column_names, self.column_sizes):
+ if size:
+ sizes.append(size)
+ else:
+ size, = np.frombuffer(data[idx:idx + 4], np.uint32)
+ sizes.append(size)
+ idx += 4
+ sample = {}
+ for key, encoding, size in zip(self.column_names, self.column_encodings, sizes):
+ value = data[idx:idx + size]
+ sample[key] = mds_decode(encoding, value)
+ idx += size
+ return sample
+
+ def get_sample_data(self, idx: int) -> bytes:
+ """Get the raw sample data at the index.
+
+ Args:
+ idx (int): Sample index.
+
+ Returns:
+ bytes: Sample data.
+ """
+ filename = os.path.join(self.dirname, self.split, self.raw_data.basename)
+ offset = (1 + idx) * 4
+ with open(filename, 'rb', 0) as fp:
+ fp.seek(offset)
+ pair = fp.read(8)
+ begin, end = np.frombuffer(pair, np.uint32)
+ fp.seek(begin)
+ data = fp.read(end - begin)
+ if not data:
+ raise IndexError(
+ f'Relative sample index {idx} is not present in the {self.raw_data.basename} file.'
+ )
+ if self.sample_compression:
+ data = decompress(self.sample_compression, data)
+ return data
+
+ _readers["mds"] = MDSReader
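
The new MDSReader is registered as the handler for the `mds` shard format and introduces a `sample::<codec>` compression prefix: shard-level decompression is skipped, and each sample is decompressed individually in `get_sample_data`. A small standalone sketch of just that prefix convention, mirroring the constructor logic above:

    def split_compression(compression):
        """Mirror of MDSReader.__init__: peel a per-sample codec off the prefix."""
        sample_compression = None
        if compression and compression.startswith("sample::"):
            compression, sample_compression = None, compression.removeprefix("sample::")
        return compression, sample_compression

    print(split_compression("sample::zstd"))  # (None, 'zstd')  -> decompress per sample
    print(split_compression("zstd"))          # ('zstd', None)  -> shard-level handling as before
    print(split_compression(None))            # (None, None)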
mldataforge/options.py
@@ -14,6 +14,7 @@ __all__ = [
  "prefix_option",
  "shard_size_option",
  "size_hint_option",
+ "trafo_option",
  "yes_option",
  ]

@@ -129,6 +130,17 @@ def size_hint_option(default=2**26):
  help=f"Size hint for the dataset (default: {default}).",
  )

+ def trafo_option():
+ """
+ Option for specifying the transformation function.
+ """
+ return click.option(
+ "--trafo",
+ default=None,
+ type=str,
+ help="Transformation function to apply to the dataset.",
+ )
+
  def yes_option():
  """
  Option for specifying whether to assume yes to all prompts.
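
`--trafo` is deliberately a plain string option: Click never evaluates it, and it only becomes a callable once `utils.save_*` wraps it in `Trafo`. A hedged sketch of attaching the exported `trafo_option()` to a throwaway command; the `demo` command itself is hypothetical:

    import click
    from mldataforge.options import trafo_option  # exported via __all__ above

    @click.command()
    @trafo_option()
    def demo(trafo):
        # arrives as a raw string (or None); nothing is eval'd at this point
        click.echo(f"received trafo string: {trafo!r}")

    # hypothetical programmatic invocation:
    # demo(["--trafo", "lambda s: s"], standalone_mode=False)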
mldataforge/trafos.py (new file)
@@ -0,0 +1,111 @@
+ import re
+ from typing import Callable
+
+ __all__ = ['Trafo', 'flatten_json', 'unflatten_json']
+
+ class Trafo:
+ """
+ Base class for transformations.
+ """
+
+ def __init__(self, trafo: Callable | str | None):
+ self.trafo = trafo
+ if isinstance(trafo, str):
+ self.trafo = eval(trafo)
+
+ def __call__(self, obj):
+ return self.trafo(obj) if self.trafo else obj
+
+ def __repr__(self):
+ return f"{self.__class__.__name__}({self.trafo})"
+
+
+ def flatten_json(obj, parent_key='', sep='.', escape_char='\\'):
+ items = []
+
+ def escape(key):
+ return key.replace(escape_char, escape_char * 2)\
+ .replace(sep, escape_char + sep)\
+ .replace('[', escape_char + '[')\
+ .replace(']', escape_char + ']')
+
+ if isinstance(obj, dict):
+ if not obj:
+ # explicitly handle empty dict
+ items.append((parent_key, {}))
+ else:
+ for k, v in obj.items():
+ new_key = f"{parent_key}{sep}{escape(k)}" if parent_key else escape(k)
+ items.extend(flatten_json(v, new_key, sep, escape_char).items())
+ elif isinstance(obj, list):
+ if not obj:
+ # explicitly handle empty list
+ items.append((parent_key, []))
+ else:
+ for idx, v in enumerate(obj):
+ new_key = f"{parent_key}[{idx}]"
+ items.extend(flatten_json(v, new_key, sep, escape_char).items())
+ else:
+ items.append((parent_key, obj))
+ return dict(items)
+
+
+ def unflatten_json(flat_dict, sep='.', escape_char='\\'):
+
+ def check_flat_json(obj):
+ assert isinstance(obj, dict), "Input must be a dictionary"
+ for k, v in obj.items():
+ assert isinstance(k, str), f"Key {k} is not a string"
+ assert isinstance(v, (str, int, float, bool)), f"Value {v} is not a valid JSON type"
+
+ def parse_key(key):
+ tokens = re.findall(r'(?:[^.\[\]\\]|\\.)+|\[\d+\]', key)
+ parsed = []
+ for token in tokens:
+ if token.startswith('['):
+ parsed.append(int(token[1:-1]))
+ else:
+ parsed.append(token.replace(escape_char + sep, sep)
+ .replace(escape_char + '[', '[')
+ .replace(escape_char + ']', ']')
+ .replace(escape_char*2, escape_char))
+ return parsed
+
+ check_flat_json(flat_dict)
+
+ result = {}
+
+ for compound_key, value in flat_dict.items():
+ keys = parse_key(compound_key)
+ current = result
+ for idx, key in enumerate(keys):
+ if idx == len(keys) - 1:
+ if isinstance(key, int):
+ if not isinstance(current, list):
+ current_parent[last_key] = []
+ current = current_parent[last_key]
+ while len(current) <= key:
+ current.append(None)
+ current[key] = value
+ else:
+ current[key] = value
+ else:
+ next_key = keys[idx + 1]
+ if isinstance(key, int):
+ if not isinstance(current, list):
+ current_parent[last_key] = []
+ current = current_parent[last_key]
+ while len(current) <= key:
+ current.append(None)
+ if current[key] is None:
+ current[key] = [] if isinstance(next_key, int) else {}
+ current_parent = current
+ current = current[key]
+ else:
+ if key not in current:
+ current[key] = [] if isinstance(next_key, int) else {}
+ current_parent = current
+ current = current[key]
+ last_key = key
+
+ return result
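
trafos.py carries the new functionality of this release: `Trafo` turns a string into a callable via `eval` (so a `--trafo` given on the command line is arbitrary code, which is convenient but worth keeping in mind), and `flatten_json`/`unflatten_json` are inverse helpers for nesting-free row formats. A quick sketch of how they behave, based only on the code above:

    from mldataforge.trafos import Trafo, flatten_json, unflatten_json

    sample = {"meta": {"id": 7}, "tags": ["a", "b"]}
    flat = flatten_json(sample)
    print(flat)                            # {'meta.id': 7, 'tags[0]': 'a', 'tags[1]': 'b'}
    assert unflatten_json(flat) == sample  # round trip

    # Strings are eval'd into callables; a lambda over the sample dict is typical,
    # and since the eval runs inside trafos.py, its own helpers appear to be in scope too.
    strip_text = Trafo("lambda s: {k: v.strip() if isinstance(v, str) else v for k, v in s.items()}")
    print(strip_text({"text": "  hi  ", "n": 1}))  # {'text': 'hi', 'n': 1}
    print(Trafo("flatten_json")(sample))           # same as `flat`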
mldataforge/utils.py
@@ -12,6 +12,7 @@ from tqdm import tqdm
  from .compression import determine_compression, open_compression, pigz_compress
  from .mds import MDSBulkReader, MDSWriter
  from .pigz import pigz_open
+ from .trafos import Trafo

  __all__ = [
  "check_arguments",
@@ -23,6 +24,11 @@ __all__ = [
  "save_parquet",
  ]

+ _NO_PROGESS = False
+ def set_progress(value):
+ global _NO_PROGESS
+ _NO_PROGESS = value
+
  def _batch_iterable(iterable, batch_size):
  batch = []
  for item in iterable:
@@ -73,7 +79,7 @@ def _infer_mds_encoding(value):
  return 'pkl'

  def _streaming_jsonl(jsonl_files, compressions):
- for jsonl_file, compression in tqdm(zip(jsonl_files, compressions), desc="Loading JSONL files", unit="file"):
+ for jsonl_file, compression in tqdm(zip(jsonl_files, compressions), desc="Loading JSONL files", unit="file", disable=_NO_PROGESS):
  for line in open_compression(jsonl_file, mode="rt", compression=compression):
  yield json.loads(line)

@@ -106,10 +112,12 @@ def load_mds_directories(mds_directories, split='.', batch_size=2**16, bulk=True
  ds = concatenate_datasets(dsets=dss)
  return ds

- def save_jsonl(iterable, output_file, compression=None, processes=64, size_hint=None, overwrite=True, yes=True):
+ def save_jsonl(iterable, output_file, compression=None, processes=64, size_hint=None, overwrite=True, yes=True, trafo=None):
  f = None
  part = 0
- for item in tqdm(iterable, desc="Writing to JSONL", unit="sample"):
+ trafo = Trafo(trafo)
+ for item in tqdm(iterable, desc="Writing to JSONL", unit="sample", disable=_NO_PROGESS):
+ item = trafo(item)
  if f is None:
  part_file = output_file.format(part=part)
  check_arguments(part_file, overwrite, yes)
@@ -122,12 +130,14 @@ def save_jsonl(iterable, output_file, compression=None, processes=64, size_hint=
  if f is not None:
  f.close()

- def save_mds(it, output_dir, processes=64, compression=None, buf_size=2**24, pigz=True, shard_size=None, size_hint=None, overwrite=True, yes=True):
+ def save_mds(it, output_dir, processes=64, compression=None, buf_size=2**24, pigz=True, shard_size=None, size_hint=None, overwrite=True, yes=True, trafo=None):
  compression = determine_compression("mds", output_dir, compression, no_pigz=not pigz)
  writer = None
  part = 0
  files = []
- for sample in tqdm(it, desc="Writing to MDS", unit="sample"):
+ trafo = Trafo(trafo)
+ for sample in tqdm(it, desc="Writing to MDS", unit="sample", disable=_NO_PROGESS):
+ sample = trafo(sample)
  if writer is None:
  part_dir = output_dir.format(part=part)
  check_arguments(part_dir, overwrite, yes)
@@ -151,7 +161,7 @@ def save_mds(it, output_dir, processes=64, compression=None, buf_size=2**24, pig
  name2info = {shard["raw_data"]["basename"]: shard for shard in index["shards"]}
  file_names = [file for file in os.listdir(output_dir) if file.endswith(".mds")]
  assert set(file_names) == set(name2info.keys())
- for file_name in tqdm(file_names, desc="Compressing with pigz", unit="file"):
+ for file_name in tqdm(file_names, desc="Compressing with pigz", unit="file", disable=_NO_PROGESS):
  compressed_file_name = file_name + ".gz"
  file_path = os.path.join(output_dir, file_name)
  compressed_file_path = os.path.join(output_dir, compressed_file_name)
@@ -165,12 +175,14 @@ def save_mds(it, output_dir, processes=64, compression=None, buf_size=2**24, pig
  json.dump(index, open(index_path, "wt"))
  print(f"Compressed {output_dir} with pigz")

- def save_parquet(it, output_file, compression=None, batch_size=2**16, size_hint=None, overwrite=True, yes=True):
+ def save_parquet(it, output_file, compression=None, batch_size=2**16, size_hint=None, overwrite=True, yes=True, trafo=None):
  compression = determine_compression("parquet", output_file, compression)
  writer = None
  part = 0
- it = tqdm(it, desc="Writing to Parquet", unit="sample")
+ trafo = Trafo(trafo)
+ it = tqdm(it, desc="Writing to Parquet", unit="sample", disable=_NO_PROGESS)
  for batch in _batch_iterable(it, batch_size):
+ batch = [trafo(sample) for sample in batch]
  table = pa.Table.from_pylist(batch)
  if writer is None:
  part_file = output_file.format(part=part)
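
Every `save_*` writer now accepts `trafo=` and applies it per sample (per batch element for Parquet), and the new module-level `_NO_PROGESS` flag (spelling as in the source) lets `set_progress(True)` silence the tqdm bars. A hedged end-to-end sketch; the output name and samples are illustrative, and the call is not verified against the parts of utils.py outside this diff:

    from mldataforge.utils import save_jsonl, set_progress

    set_progress(True)   # sets _NO_PROGESS = True, i.e. disables the progress bars

    samples = [{"text": "  hello "}, {"text": "world  "}]
    save_jsonl(
        samples,
        "demo-{part:04d}.jsonl",   # output_file is formatted with a {part} counter
        compression=None,
        trafo="lambda s: {'text': s['text'].strip()}",   # applied to every item before writing
    )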
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
4
4
 
5
5
  [project]
6
6
  name = "mldataforge"
7
- version = "0.1.6"
7
+ version = "0.2.0"
8
8
  authors = [
9
9
  { name = "Peter Schneider-Kamp" }
10
10
  ]