mldataforge 0.1.3__tar.gz → 0.1.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: mldataforge
- Version: 0.1.3
+ Version: 0.1.5
  Summary: swiss army knife of scripts for transforming and processing datasets for machine learning.
  Project-URL: Homepage, https://github.com/schneiderkamplab/mldataforge
  Project-URL: Bug Tracker, https://github.com/schneiderkamplab/mldataforge/issues
@@ -10,10 +10,15 @@ Classifier: License :: OSI Approved :: MIT License
  Classifier: Operating System :: OS Independent
  Classifier: Programming Language :: Python :: 3
  Requires-Python: >=3.12
+ Requires-Dist: brotlicffi
  Requires-Dist: click
  Requires-Dist: datasets
+ Requires-Dist: isal
+ Requires-Dist: lz4
  Requires-Dist: mltiming
  Requires-Dist: mosaicml-streaming
+ Requires-Dist: python-snappy
+ Requires-Dist: zstandard
  Provides-Extra: all
  Requires-Dist: build; extra == 'all'
  Requires-Dist: pytest; extra == 'all'
@@ -0,0 +1,82 @@
+ import brotlicffi as brotli
+ import io
+
+ __all__ = ["brotli_open"]
+
+ def brotli_open(filename, mode='rb', encoding='utf-8', compress_level=11):
+     return BrotliFile(filename, mode=mode, encoding=encoding, compress_level=compress_level)
+
+ import brotlicffi as brotli
+ import io
+
+ __all__ = ["brotli_open"]
+
+ class BrotliFile:
+     def __init__(self, filename, mode='rb', encoding='utf-8', compress_level=11):
+         self.filename = filename
+         self.mode = mode
+         self.encoding = encoding
+         self.compress_level = compress_level
+
+         self.binary = 'b' in mode
+         file_mode = mode.replace('t', 'b')
+         self.file = open(filename, file_mode)
+
+         if 'r' in mode:
+             self._decompressor = brotli.Decompressor()
+             self._stream = self._wrap_reader()
+         elif 'w' in mode:
+             self._compressor = brotli.Compressor(quality=compress_level)
+             self._stream = self._wrap_writer()
+         else:
+             raise ValueError("Unsupported mode (use 'rb', 'wb', 'rt', or 'wt')")
+
+     def _wrap_reader(self):
+         buffer = io.BytesIO()
+         while True:
+             chunk = self.file.read(8192)
+             if not chunk:
+                 break
+             buffer.write(self._decompressor.process(chunk))
+         buffer.seek(0)
+         return buffer if self.binary else io.TextIOWrapper(buffer, encoding=self.encoding)
+
+     def _wrap_writer(self):
+         return self if self.binary else io.TextIOWrapper(self, encoding=self.encoding)
+
+     def write(self, data):
+         if isinstance(data, str):
+             data = data.encode(self.encoding)
+         compressed = self._compressor.process(data)
+         self.file.write(compressed)
+         return len(data)
+
+     def flush(self):
+         if hasattr(self, '_compressor'):
+             self.file.write(self._compressor.finish())
+         self.file.flush()
+
+     def read(self, *args, **kwargs):
+         return self._stream.read(*args, **kwargs)
+
+     def readline(self, *args, **kwargs):
+         return self._stream.readline(*args, **kwargs)
+
+     def __iter__(self):
+         return iter(self._stream)
+
+     def close(self):
+         try:
+             if hasattr(self._stream, 'flush'):
+                 self._stream.flush()
+         finally:
+             self.file.close()
+
+     def __enter__(self):
+         return self
+
+     def __exit__(self, exc_type, exc_val, exc_tb):
+         self.close()
+
+     def tell(self):
+         return self._stream.tell()
@@ -1,10 +1,11 @@
  import click
  from datasets import load_dataset

+ from ...compression import *
  from ...options import *
  from ...utils import *

- __all__ = ["jsonl"]
+ __all__ = ["jsonl_to_mds", "jsonl_to_parquet"]

  @click.group()
  def jsonl():
@@ -13,35 +14,40 @@ def jsonl():
  @jsonl.command()
  @click.argument('output_dir', type=click.Path(exists=False))
  @click.argument('jsonl_files', nargs=-1, type=click.Path(exists=True))
- @compression_option(None, ['none', 'br', 'bz2', 'gzip', 'pigz', 'snappy', 'zstd'])
+ @compression_option(MDS_COMPRESSIONS)
  @overwrite_option()
  @yes_option()
  @processes_option()
  @buf_size_option()
  @shard_size_option()
- def mds(output_dir, jsonl_files, compression, processes, overwrite, yes, buf_size, shard_size):
+ @no_pigz_option()
+ def mds(**kwargs):
+     jsonl_to_mds(**kwargs)
+ def jsonl_to_mds(output_dir, jsonl_files, compression, processes, overwrite, yes, buf_size, shard_size, no_pigz):
      check_arguments(output_dir, overwrite, yes, jsonl_files)
      save_mds(
-         load_dataset("json", data_files=jsonl_files, split="train"),
+         load_jsonl_files(jsonl_files),
          output_dir,
          processes=processes,
          compression=compression,
          buf_size=buf_size,
-         pigz=use_pigz(compression),
+         pigz=use_pigz(compression, no_pigz),
          shard_size=shard_size,
      )

  @jsonl.command()
  @click.argument('output_file', type=click.Path(exists=False))
  @click.argument('jsonl_files', nargs=-1, type=click.Path(exists=True))
- @compression_option("snappy", ["snappy", "gzip", "zstd"])
+ @compression_option(PARQUET_COMPRESSIONS)
  @overwrite_option()
  @yes_option()
  @batch_size_option()
- def parquet(output_file, jsonl_files, compression, overwrite, yes, batch_size):
+ def parquet(**kwargs):
+     jsonl_to_parquet(**kwargs)
+ def jsonl_to_parquet(output_file, jsonl_files, compression, overwrite, yes, batch_size):
      check_arguments(output_file, overwrite, yes, jsonl_files)
      save_parquet(
-         load_dataset("json", data_files=jsonl_files, split="train"),
+         load_jsonl_files(jsonl_files),
          output_file,
          compression=compression,
          batch_size=batch_size,
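The command bodies are now exposed as plain functions (jsonl_to_mds, jsonl_to_parquet), so they can be called programmatically as well as through the CLI. A small sketch, assuming the module path mldataforge.commands.convert.jsonl (inferred from the three-level relative imports) and illustrative file names:

    # Hypothetical programmatic call to the refactored converter.
    from mldataforge.commands.convert.jsonl import jsonl_to_parquet

    jsonl_to_parquet(
        output_file="corpus.parquet",
        jsonl_files=["part-0000.jsonl.gz", "part-0001.jsonl.zst"],
        compression="snappy",   # default from PARQUET_COMPRESSIONS
        overwrite=True,
        yes=True,
        batch_size=2**16,
    )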
@@ -1,9 +1,10 @@
  import click

+ from ...compression import *
  from ...options import *
  from ...utils import *

- __all__ = ["mds"]
+ __all__ = ["mds_to_jsonl", "mds_to_parquet"]

  @click.group()
  def mds():
@@ -12,13 +13,15 @@ def mds():
  @mds.command()
  @click.argument("output_file", type=click.Path(exists=False), required=True)
  @click.argument("mds_directories", type=click.Path(exists=True), required=True, nargs=-1)
- @compression_option("infer", ["none", "infer", "pigz", "gzip", "bz2", "xz"])
+ @compression_option(JSONL_COMPRESSIONS)
  @processes_option()
  @overwrite_option()
  @yes_option()
  @batch_size_option()
  @no_bulk_option()
- def jsonl(output_file, mds_directories, compression, processes, overwrite, yes, batch_size, no_bulk):
+ def jsonl(**kwargs):
+     mds_to_jsonl(**kwargs)
+ def mds_to_jsonl(output_file, mds_directories, compression, processes, overwrite, yes, batch_size, no_bulk):
      check_arguments(output_file, overwrite, yes, mds_directories)
      save_jsonl(
          load_mds_directories(mds_directories, batch_size=batch_size, bulk=not no_bulk),
@@ -30,12 +33,14 @@ def jsonl(output_file, mds_directories, compression, processes, overwrite, yes,
  @mds.command()
  @click.argument("output_file", type=click.Path(exists=False), required=True)
  @click.argument("mds_directories", type=click.Path(exists=True), required=True, nargs=-1)
- @compression_option("snappy", ["snappy", "gzip", "zstd"])
+ @compression_option(PARQUET_COMPRESSIONS)
  @overwrite_option()
  @yes_option()
  @batch_size_option()
  @no_bulk_option()
- def parquet(output_file, mds_directories, compression, overwrite, yes, batch_size, no_bulk):
+ def parquet(**kwargs):
+     mds_to_parquet(**kwargs)
+ def mds_to_parquet(output_file, mds_directories, compression, overwrite, yes, batch_size, no_bulk):
      check_arguments(output_file, overwrite, yes, mds_directories)
      save_parquet(
          load_mds_directories(mds_directories, batch_size=batch_size, bulk=not no_bulk),
@@ -1,10 +1,11 @@
  import click
  from datasets import load_dataset

+ from ...compression import *
  from ...options import *
  from ...utils import *

- __all__ = ["parquet"]
+ __all__ = ["parquet_to_jsonl", "parquet_to_mds"]

  @click.group()
  def parquet():
@@ -13,11 +14,13 @@ def parquet():
  @parquet.command()
  @click.argument("output_file", type=click.Path(exists=False), required=True)
  @click.argument("parquet_files", type=click.Path(exists=True), required=True, nargs=-1)
- @compression_option("infer", ["none", "infer", "pigz", "gzip", "bz2", "xz"])
+ @compression_option(JSONL_COMPRESSIONS)
  @processes_option()
  @overwrite_option()
  @yes_option()
- def jsonl(output_file, parquet_files, compression, processes, overwrite, yes):
+ def jsonl(**kwargs):
+     parquet_to_jsonl(**kwargs)
+ def parquet_to_jsonl(output_file, parquet_files, compression, processes, overwrite, yes):
      check_arguments(output_file, overwrite, yes, parquet_files)
      save_jsonl(
          load_dataset("parquet", data_files=parquet_files, split="train"),
@@ -29,13 +32,16 @@ def jsonl(output_file, parquet_files, compression, processes, overwrite, yes):
  @parquet.command()
  @click.argument('output_dir', type=click.Path(exists=False))
  @click.argument('parquet_files', nargs=-1, type=click.Path(exists=True))
- @compression_option(None, ['none', 'br', 'bz2', 'gzip', 'pigz', 'snappy', 'zstd'])
+ @compression_option(MDS_COMPRESSIONS)
  @processes_option()
  @overwrite_option()
  @yes_option()
  @buf_size_option()
  @shard_size_option()
- def mds(output_dir, parquet_files, compression, processes, overwrite, yes, buf_size, shard_size):
+ @no_pigz_option()
+ def mds(**kwargs):
+     parquet_to_mds(**kwargs)
+ def parquet_to_mds(output_dir, parquet_files, compression, processes, overwrite, yes, buf_size, shard_size, no_pigz):
      check_arguments(output_dir, overwrite, yes, parquet_files)
      save_mds(
          load_dataset("parquet", data_files=parquet_files, split="train"),
@@ -43,6 +49,6 @@ def mds(output_dir, parquet_files, compression, processes, overwrite, yes, buf_s
          processes=processes,
          compression=compression,
          buf_size=buf_size,
-         pigz=use_pigz(compression),
+         pigz=use_pigz(compression, no_pigz=no_pigz),
          shard_size=shard_size,
      )
@@ -1,10 +1,11 @@
  import click
  from datasets import load_dataset

+ from ..compression import *
  from ..options import *
  from ..utils import *

- __all__ = ["join"]
+ __all__ = ["join_jsonl", "join_mds", "join_parquet"]

  @click.group()
  def join():
@@ -13,14 +14,16 @@ def join():
  @join.command()
  @click.argument("output_file", type=click.Path(exists=False), required=True)
  @click.argument("jsonl_files", type=click.Path(exists=True), required=True, nargs=-1)
- @compression_option("infer", ["none", "infer", "pigz", "gzip", "bz2", "xz"])
+ @compression_option(JSONL_COMPRESSIONS)
  @processes_option()
  @overwrite_option()
  @yes_option()
- def jsonl(output_file, jsonl_files, compression, processes, overwrite, yes):
+ def jsonl(**kwargs):
+     join_jsonl(**kwargs)
+ def join_jsonl(output_file, jsonl_files, compression, processes, overwrite, yes):
      check_arguments(output_file, overwrite, yes, jsonl_files)
      save_jsonl(
-         load_dataset("json", data_files=jsonl_files, split="train"),
+         load_jsonl_files(jsonl_files),
          output_file,
          compression=compression,
          processes=processes,
@@ -29,14 +32,19 @@ def jsonl(output_file, jsonl_files, compression, processes, overwrite, yes):
  @join.command()
  @click.argument("output_dir", type=click.Path(exists=False), required=True)
  @click.argument("mds_directories", type=click.Path(exists=True), required=True, nargs=-1)
- @compression_option(None, ['none', 'br', 'bz2', 'gzip', 'pigz', 'snappy', 'zstd'])
+ @compression_option(MDS_COMPRESSIONS)
  @processes_option()
  @overwrite_option()
  @yes_option()
  @batch_size_option()
  @buf_size_option()
  @no_bulk_option()
- def mds(output_dir, mds_directories, compression, processes, overwrite, yes, batch_size, buf_size, no_bulk):
+ @shard_size_option()
+ @no_pigz_option()
+ def mds(**kwargs):
+     print(kwargs)
+     join_mds(**kwargs)
+ def join_mds(output_dir, mds_directories, compression, processes, overwrite, yes, batch_size, buf_size, no_bulk, shard_size, no_pigz):
      check_arguments(output_dir, overwrite, yes, mds_directories)
      save_mds(
          load_mds_directories(mds_directories, batch_size=batch_size, bulk=not no_bulk),
@@ -44,17 +52,20 @@ def mds(output_dir, mds_directories, compression, processes, overwrite, yes, bat
          processes=processes,
          compression=compression,
          buf_size=buf_size,
-         pigz=use_pigz(compression),
+         shard_size=shard_size,
+         pigz=use_pigz(compression, no_pigz)
      )

  @join.command()
  @click.argument("output_file", type=click.Path(exists=False), required=True)
  @click.argument("parquet_files", type=click.Path(exists=True), required=True, nargs=-1)
- @compression_option("snappy", ["snappy", "gzip", "zstd"])
+ @compression_option(PARQUET_COMPRESSIONS)
  @overwrite_option()
  @yes_option()
  @batch_size_option()
- def parquet(output_file, parquet_files, compression, overwrite, yes, batch_size):
+ def parquet(**kwargs):
+     join_parquet(**kwargs)
+ def join_parquet(output_file, parquet_files, compression, overwrite, yes, batch_size):
      check_arguments(output_file, overwrite, yes, parquet_files)
      save_parquet(
          load_dataset("parquet", data_files=parquet_files, split="train"),
@@ -1,10 +1,11 @@
  import click
  from datasets import load_dataset

+ from ..compression import *
  from ..options import *
  from ..utils import *

- __all__ = ["split"]
+ __all__ = ["split_jsonl", "split_mds", "split_parquet"]

  @click.group()
  def split():
@@ -15,14 +16,14 @@ def split():
  @prefix_option()
  @output_dir_option()
  @size_hint_option()
- @compression_option("infer", ["none", "infer", "pigz", "gzip", "bz2", "xz"])
+ @compression_option(JSONL_COMPRESSIONS)
  @processes_option()
  @overwrite_option()
  @yes_option()
  def jsonl(jsonl_files, prefix, output_dir, size_hint, compression, processes, overwrite, yes):
      save_jsonl(
-         load_dataset("json", data_files=jsonl_files, split="train"),
-         output_file=f"{output_dir}/{prefix}{{part:04d}}.jsonl{extension(compression, jsonl_files[0])}",
+         load_jsonl_files(jsonl_files),
+         output_file=f"{output_dir}/{prefix}{{part:04d}}.jsonl{extension_compression(compression, jsonl_files[0])}",
          compression=compression,
          processes=processes,
          size_hint=size_hint,
@@ -35,7 +36,7 @@ def jsonl(jsonl_files, prefix, output_dir, size_hint, compression, processes, ov
  @prefix_option()
  @output_dir_option()
  @size_hint_option()
- @compression_option(None, ['none', 'br', 'bz2', 'gzip', 'pigz', 'snappy', 'zstd'])
+ @compression_option(MDS_COMPRESSIONS)
  @processes_option()
  @overwrite_option()
  @yes_option()
@@ -43,16 +44,37 @@ def jsonl(jsonl_files, prefix, output_dir, size_hint, compression, processes, ov
  @batch_size_option()
  @no_bulk_option()
  @shard_size_option()
- def mds(mds_directories, prefix, output_dir, size_hint, compression, processes, overwrite, yes, buf_size, batch_size, no_bulk, shard_size):
+ @no_pigz_option()
+ def mds(mds_directories, prefix, output_dir, size_hint, compression, processes, overwrite, yes, buf_size, batch_size, no_bulk, shard_size, no_pigz):
      save_mds(
          load_mds_directories(mds_directories, batch_size=batch_size, bulk=not no_bulk),
          output_dir=f"{output_dir}/{prefix}{{part:04d}}",
          processes=processes,
          compression=compression,
          buf_size=buf_size,
-         pigz=use_pigz(compression),
+         pigz=use_pigz(compression, no_pigz),
          shard_size=shard_size,
          size_hint=size_hint,
          overwrite=overwrite,
          yes=yes,
      )
+
+ @split.command()
+ @click.argument("parquet_files", type=click.Path(exists=True), required=True, nargs=-1)
+ @prefix_option()
+ @output_dir_option()
+ @size_hint_option()
+ @compression_option(PARQUET_COMPRESSIONS)
+ @overwrite_option()
+ @yes_option()
+ @batch_size_option()
+ def parquet(parquet_files, prefix, output_dir, size_hint, compression, overwrite, yes, batch_size):
+     save_parquet(
+         load_dataset("parquet", data_files=parquet_files, split="train"),
+         output_file=f"{output_dir}/{prefix}{{part:04d}}.parquet",
+         compression=compression,
+         batch_size=batch_size,
+         size_hint=size_hint,
+         overwrite=overwrite,
+         yes=yes,
+     )
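The new `split parquet` command can be exercised without a shell via click's test runner. A sketch, assuming the split group is importable as mldataforge.commands.split and that corpus.parquet exists; both are assumptions for illustration:

    # Hypothetical invocation of the new `split parquet` command.
    from click.testing import CliRunner
    from mldataforge.commands.split import split

    runner = CliRunner()
    # Uses the default prefix/output-dir/size-hint options and zstd column compression.
    result = runner.invoke(split, ["parquet", "corpus.parquet", "--compression", "zstd"])
    print(result.exit_code, result.output)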
@@ -0,0 +1,158 @@
+ import bz2
+ from isal import igzip as gzip
+ import lz4.frame
+ import lzma
+ import os
+ import shutil
+ from tqdm import tqdm
+ import zstandard
+
+ from .brotli import brotli_open
+ from .pigz import pigz_open
+ from .snappy import snappy_open
+
+ __all__ = [
+     "JSONL_COMPRESSIONS",
+     "MDS_COMPRESSIONS",
+     "PARQUET_COMPRESSIONS",
+     "determine_compression",
+     "extension_compression",
+     "infer_compression",
+     "open_compression",
+     "pigz_available",
+     "pigz_compress",
+     "use_pigz",
+ ]
+
+ JSONL_COMPRESSIONS = dict(
+     default="infer",
+     choices=["infer", "none", "bz2", "gzip", "lz4", "lzma", "pigz", "snappy", "xz", "zstd"],
+ )
+ MDS_COMPRESSIONS = dict(
+     default=None,
+     choices=["none", "brotli", "bz2", "gzip", "pigz", "snappy", "zstd"],
+ )
+ PARQUET_COMPRESSIONS = dict(
+     default="snappy",
+     choices=["snappy", "brotli", "gzip", "lz4", "zstd"],
+ )
+
+ def determine_compression(fmt, file_path, compression="infer", no_pigz=False):
+     if compression == "none":
+         return None
+     if fmt == "jsonl":
+         if compression == "infer":
+             compression = infer_compression(file_path)
+         if compression == "brotli":
+             return "br"
+         return compression
+     if fmt == "mds":
+         if compression == "infer":
+             raise ValueError()
+         if compression == "pigz" or (not no_pigz and compression == "gzip" and pigz_available()):
+             return None
+         if compression == "gzip":
+             return "gz"
+         if compression == "brotli":
+             return "br"
+         return compression
+     if fmt == "parquet":
+         return compression
+     raise ValueError(f"Unsupported format: {fmt}")
+
+ def extension_compression(compression, file_path):
+     """Get the file extension for the given compression type."""
+     if compression == "infer":
+         compression = infer_compression(file_path)
+     if compression == "brotli":
+         return ".br"
+     if compression == "bz2":
+         return ".bz2"
+     if compression in ("gzip", "pigz"):
+         return ".gz"
+     if compression == "lz4":
+         return ".lz4"
+     if compression == "lzma":
+         return ".lzma"
+     if compression == "snappy":
+         return ".snappy"
+     if compression == "xz":
+         return ".xz"
+     if compression == "zstd":
+         return ".zst"
+     if compression is None or compression == "none":
+         return ""
+     raise ValueError(f"Unsupported compression type: {compression}")
+
+ def infer_compression(file_path, pigz=True):
+     """Infer the compression type from the file extension."""
+     extension = os.path.splitext(file_path)[1]
+     if extension.endswith('.br'):
+         return 'brotli'
+     if extension.endswith('.bz2'):
+         return 'bz2'
+     if extension.endswith('.gz'):
+         if pigz and pigz_available():
+             return 'pigz'
+         return 'gzip'
+     if extension.endswith('.lz4'):
+         return 'lz4'
+     if extension.endswith('.lzma'):
+         return 'lzma'
+     if extension.endswith('.snappy'):
+         return 'snappy'
+     if extension.endswith('.xz'):
+         return 'xz'
+     if extension.endswith('.zip'):
+         return 'zip'
+     if extension.endswith('.zst'):
+         return 'zstd'
+     return None
+
+ def open_compression(file_path, mode="rt", compression="infer", processes=64):
+     """Open a file, handling compression if necessary."""
+     if compression == "infer":
+         compression = infer_compression(file_path)
+     if compression in ("brotli", "br"):
+         return brotli_open(file_path, mode)
+     if compression in ("gzip", "gz"):
+         return gzip.open(file_path, mode)
+     if compression == "pigz":
+         return pigz_open(file_path, mode, processes=processes) if mode[0] == "w" else gzip.open(file_path, mode)
+     if compression == "bz2":
+         return bz2.open(file_path, mode)
+     if compression == "lz4":
+         return lz4.frame.open(file_path, mode)
+     if compression in ("lzma", "xz"):
+         return lzma.open(file_path, mode)
+     if compression == "snappy":
+         return snappy_open(file_path, mode)
+     if compression == "zstd":
+         return zstandard.open(file_path, mode)
+     if compression is None or compression == "none":
+         return open(file_path, mode)
+     raise ValueError(f"Unsupported compression type: {compression}")
+
+ def pigz_available():
+     """Check if pigz is available on the system."""
+     return shutil.which("pigz") is not None
+
+ def pigz_compress(input_file, output_file, processes=64, buf_size=2**24, keep=False, quiet=False):
+     """Compress a file using pigz."""
+     size = os.stat(input_file).st_size
+     num_blocks = (size+buf_size-1) // buf_size
+     with open(input_file, "rb") as f_in, pigz_open(output_file, "wb", processes=processes) as f_out:
+         for _ in tqdm(range(num_blocks), desc="Compressing with pigz", unit="block", disable=quiet):
+             buf = f_in.read(buf_size)
+             assert buf
+             f_out.write(buf)
+         buf = f_in.read()
+         assert not buf
+     if not keep:
+         os.remove(input_file)
+         if not quiet:
+             print(f"Removed {input_file}")
+
+ def use_pigz(compression, no_pigz=False):
+     """Determine if pigz should be used based on the compression type."""
+     return compression == "pigz" or (not no_pigz and compression == "gzip" and pigz_available())
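A brief sketch of how the new helpers compose, assuming the module is importable as mldataforge.compression (path inferred from the relative imports in this diff) and an illustrative file name:

    # Hypothetical use of the new compression helpers.
    from mldataforge.compression import (
        determine_compression, extension_compression, infer_compression,
        open_compression, use_pigz,
    )

    path = "data/part-0000.jsonl.zst"

    print(infer_compression(path))               # 'zstd' (from the .zst extension)
    print(extension_compression("infer", path))  # '.zst'
    print(determine_compression("jsonl", path))  # 'zstd'
    print(use_pigz("gzip"))                      # True only if the pigz binary is on PATH

    # Transparently read the compressed JSONL file line by line.
    with open_compression(path, mode="rt") as f:
        for line in f:
            print(line, end="")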
@@ -3,9 +3,13 @@ import json
  from mltiming import timing
  import numpy as np
  import os
+ import snappy
  from streaming.base.format.mds.encodings import mds_decode
  from typing import Any, Optional, Generator

+ from .options import MDS_COMPRESSIONS
+ from .utils import open_compression
+
  class MDSBulkReader:
      def __init__(
          self,
@@ -42,13 +46,7 @@ class MDSShardReader:
          filename: str,
          compression: Optional[str],
      ) -> None:
-         if compression is None:
-             _open = open
-         elif compression == 'gz':
-             _open = gzip.open
-         else:
-             raise ValueError(f'Unsupported compression type: {compression}. Supported types: None, gzip.')
-         self.fp = _open(filename, "rb")
+         self.fp = open_compression(filename, "rb", compression=compression)
          self.samples = np.frombuffer(self.fp.read(4), np.uint32)[0]
          self.index = np.frombuffer(self.fp.read((1+self.samples)*4), np.uint32)
          info = json.loads(self.fp.read(self.index[0]-self.fp.tell()))
@@ -1,11 +1,19 @@
  import click

- __alll__ = [
+ from .compression import JSONL_COMPRESSIONS, MDS_COMPRESSIONS, PARQUET_COMPRESSIONS
+
+ __all__ = [
      "batch_size_option",
      "buf_size_option",
      "compression_option",
+     "no_bulk_option",
+     "no_pigz_option",
+     "output_dir_option",
      "overwrite_option",
      "processes_option",
+     "prefix_option",
+     "shard_size_option",
+     "size_hint_option",
      "yes_option",
  ]

@@ -39,15 +47,25 @@ def no_bulk_option():
          help="Use a custom space and time-efficient bulk reader (only gzip and no compression).",
      )

- def compression_option(default, choices):
+ def no_pigz_option():
+     """
+     Option for specifying whether to use pigz compression.
+     """
+     return click.option(
+         "--no-pigz",
+         is_flag=True,
+         help="Do not use pigz compression.",
+     )
+
+ def compression_option(args):
      """
      Option for specifying the compression type.
      """
      return click.option(
          "--compression",
-         default=default,
-         type=click.Choice(choices, case_sensitive=False),
-         help=f"Compress the output file (default: {default}).",
+         default=args["default"],
+         type=click.Choice(args["choices"], case_sensitive=False),
+         help=f'Compress the output file (default: {args["default"]}).',
      )

  def output_dir_option(default="."):
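compression_option now takes a single dict bundling the default value and the allowed choices. A small sketch of how a command might use it; the example command itself is hypothetical, only the decorator signature comes from this diff:

    import click
    from mldataforge.options import compression_option
    from mldataforge.compression import JSONL_COMPRESSIONS

    @click.command()
    @compression_option(JSONL_COMPRESSIONS)   # adds --compression with default "infer"
    def example(compression):
        click.echo(f"selected compression: {compression}")

    if __name__ == "__main__":
        example()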
@@ -0,0 +1,226 @@
+ import snappy
+ import struct
+ import io
+
+ __all__ = ["snappy_open"]
+
+ _CHUNK_SIZE = 8192  # default read block size
+
+ def snappy_open(filename, mode='rb', encoding='utf-8'):
+     return SnappyFile(filename, mode=mode, encoding=encoding)
+
+ class _SnappyWriteWrapper(io.RawIOBase):
+     def __init__(self, fileobj):
+         self.fileobj = fileobj
+         self.buffer = io.BytesIO()
+
+     def write(self, b):
+         if not isinstance(b, (bytes, bytearray)):
+             raise TypeError("Expected bytes")
+         self.buffer.write(b)
+         return len(b)
+
+     def flush(self):
+         data = self.buffer.getvalue()
+         if data:
+             compressed = snappy.compress(data)
+             length = struct.pack(">I", len(compressed))
+             self.fileobj.write(length + compressed)
+             self.buffer = io.BytesIO()
+         self.fileobj.flush()
+
+     def close(self):
+         self.flush()
+         self.fileobj.close()
+
+     def writable(self):
+         return True
+
+
+ # class _SnappyReadWrapper(io.RawIOBase):
+ #     def __init__(self, fileobj):
+ #         self.fileobj = fileobj
+ #         self.buffer = io.BytesIO()
+ #         self.eof = False
+
+ #     def _fill_buffer_if_needed(self, min_bytes):
+ #         self.buffer.seek(0, io.SEEK_END)
+ #         while not self.eof and self.buffer.tell() < min_bytes:
+ #             length_bytes = self.fileobj.read(4)
+ #             if not length_bytes:
+ #                 self.eof = True
+ #                 break
+ #             if len(length_bytes) < 4:
+ #                 self.eof = True  # mark as EOF even if last chunk is malformed
+ #                 break
+
+ #             try:
+ #                 length = struct.unpack(">I", length_bytes)[0]
+ #                 compressed = self.fileobj.read(length)
+ #                 if len(compressed) < length:
+ #                     self.eof = True
+ #                     break
+
+ #                 decompressed = snappy.decompress(compressed)
+ #                 self.buffer.write(decompressed)
+ #             except Exception:
+ #                 self.eof = True
+ #                 break
+
+ #         self.buffer.seek(0)
+
+ #     def read(self, size=-1):
+ #         if size == -1:
+ #             while not self.eof:
+ #                 self._fill_buffer_if_needed(_CHUNK_SIZE)
+ #             result = self.buffer.read()
+ #             self.buffer = io.BytesIO()
+ #             return result
+
+ #         self._fill_buffer_if_needed(size)
+ #         data = self.buffer.read(size)
+ #         rest = self.buffer.read()
+ #         self.buffer = io.BytesIO()
+ #         self.buffer.write(rest)
+ #         return data
+
+ #     def readable(self):
+ #         return True
+
+ #     def close(self):
+ #         self.fileobj.close()
+
+ class _SnappyReadWrapper(io.RawIOBase):
+     def __init__(self, fileobj):
+         self.fileobj = fileobj
+         self.buffer = io.BytesIO()
+         self.eof = False
+         self._autodetect_format()
+
+     def _autodetect_format(self):
+         self.fileobj.seek(0)
+         preview = self.fileobj.read()
+         try:
+             self._raw_decompressed = snappy.decompress(preview)
+             self._mode = "raw"
+             self.buffer = io.BytesIO(self._raw_decompressed)
+         except Exception:
+             self.fileobj.seek(0)
+             self._mode = "framed"
+
+     def _fill_buffer_if_needed(self, min_bytes):
+         self.buffer.seek(0, io.SEEK_END)
+         while not self.eof and self.buffer.tell() < min_bytes:
+             length_bytes = self.fileobj.read(4)
+             if not length_bytes:
+                 self.eof = True
+                 break
+             if len(length_bytes) < 4:
+                 self.eof = True
+                 break
+             try:
+                 length = struct.unpack(">I", length_bytes)[0]
+                 compressed = self.fileobj.read(length)
+                 if len(compressed) < length:
+                     self.eof = True
+                     break
+                 decompressed = snappy.decompress(compressed)
+                 self.buffer.write(decompressed)
+             except Exception:
+                 self.eof = True
+                 break
+         self.buffer.seek(0)
+
+     def read(self, size=-1):
+         if self._mode == "raw":
+             return self.buffer.read(size)
+         else:
+             if size == -1:
+                 while not self.eof:
+                     self._fill_buffer_if_needed(_CHUNK_SIZE)
+                 result = self.buffer.read()
+                 self.buffer = io.BytesIO()
+                 return result
+             else:
+                 self._fill_buffer_if_needed(size)
+                 data = self.buffer.read(size)
+                 rest = self.buffer.read()
+                 self.buffer = io.BytesIO()
+                 self.buffer.write(rest)
+                 return data
+
+     def readable(self):
+         return True
+
+     def close(self):
+         self.fileobj.close()
+
+     def tell(self):
+         return self.buffer.tell()
+
+     def seek(self, offset, whence=io.SEEK_SET):
+         return self.buffer.seek(offset, whence)
+
+ class SnappyFile:
+     def __init__(self, filename, mode='rb', encoding='utf-8'):
+         self.filename = filename
+         self.mode = mode
+         self.encoding = encoding
+         self.binary = 'b' in mode
+         raw_mode = mode.replace('t', 'b')
+         self.fileobj = open(filename, raw_mode)
+
+         if 'r' in mode:
+             self._stream = self._reader() if self.binary else io.TextIOWrapper(self._reader(), encoding=encoding)
+         elif 'w' in mode:
+             self._stream = self._writer() if self.binary else io.TextIOWrapper(self._writer(), encoding=encoding)
+         else:
+             raise ValueError("Unsupported mode: use 'rb', 'wb', 'rt', or 'wt'")
+
+     def _reader(self):
+         return _SnappyReadWrapper(self.fileobj)
+
+     def _writer(self):
+         return _SnappyWriteWrapper(self.fileobj)
+
+     def __enter__(self):
+         return self
+
+     def __exit__(self, exc_type, exc_val, exc_tb):
+         self.close()
+
+     def close(self):
+         if hasattr(self._stream, 'flush'):
+             self._stream.flush()
+         self._stream.close()
+
+     def flush(self):
+         if hasattr(self._stream, 'flush'):
+             self._stream.flush()
+
+     def read(self, *args, **kwargs):
+         return self._stream.read(*args, **kwargs)
+
+     def write(self, *args, **kwargs):
+         return self._stream.write(*args, **kwargs)
+
+     def readline(self, *args, **kwargs):
+         return self._stream.readline(*args, **kwargs)
+
+     def tell(self):
+         return self._stream.tell()
+
+     def seek(self, offset, whence=io.SEEK_SET):
+         return self._stream.seek(offset, whence)
+
+     def readable(self):
+         return hasattr(self._stream, "read")
+
+     def writable(self):
+         return hasattr(self._stream, "write")
+
+     def seekable(self):
+         return hasattr(self._stream, "seek")
+
+     def __iter__(self):
+         return iter(self._stream)
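A minimal round-trip sketch for the snappy helper, assuming it is importable as mldataforge.snappy; the file name is illustrative:

    # Hypothetical round-trip through snappy_open.
    from mldataforge.snappy import snappy_open

    # Write length-prefixed snappy blocks in binary mode.
    with snappy_open("sample.jsonl.snappy", mode="wb") as f:
        f.write(b'{"text": "hello"}\n')

    # Read the file back; the reader auto-detects raw vs. block-framed data.
    with snappy_open("sample.jsonl.snappy", mode="rb") as f:
        print(f.read())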
@@ -1,9 +1,6 @@
- import bz2
  import click
- from datasets import concatenate_datasets
- import gzip
+ from datasets import concatenate_datasets, load_dataset
  import json
- import lzma
  from mltiming import timing
  import pyarrow as pa
  import pyarrow.parquet as pq
@@ -12,22 +9,21 @@ import shutil
  from streaming import MDSWriter, StreamingDataset
  from tqdm import tqdm

+ from .compression import determine_compression, open_compression, pigz_compress
  from .mds import MDSBulkReader
  from .pigz import pigz_open

  __all__ = [
-     "batch_iterable",
      "check_arguments",
      "confirm_overwrite",
-     "extension",
+     "load_jsonl_files",
      "load_mds_directories",
      "save_jsonl",
      "save_mds",
      "save_parquet",
-     "use_pigz",
  ]

- def batch_iterable(iterable, batch_size):
+ def _batch_iterable(iterable, batch_size):
      batch = []
      for item in iterable:
          batch.append(item)
@@ -64,31 +60,6 @@ def confirm_overwrite(message):
      if response.lower() != 'yes':
          raise click.Abort()

- def _determine_compression(file_path, compression="infer"):
-     if compression == "infer":
-         compression = _infer_compression(file_path)
-     if compression == "none":
-         compression = None
-     return compression
-
- def extension(compression, file_path):
-     """Get the file extension for the given compression type."""
-     if compression == "infer":
-         compression = _infer_compression(file_path)
-     if compression in ("gzip", "pigz"):
-         return ".gz"
-     if compression == "bz2":
-         return ".bz2"
-     if compression == "xz":
-         return ".xz"
-     if compression == "zip":
-         return ".zip"
-     if compression == "zstd":
-         return ".zst"
-     if compression is None:
-         return ""
-     raise ValueError(f"Unsupported compression type: {compression}")
-
  def _infer_mds_encoding(value):
      """Determine the MDS encoding for a given value."""
      if isinstance(value, str):
@@ -101,22 +72,16 @@ def _infer_mds_encoding(value):
          return 'bool'
      return 'pkl'

- def _infer_compression(file_path):
-     """Infer the compression type from the file extension."""
-     extension = os.path.splitext(file_path)[1]
-     if extension.endswith('.gz'):
-         if _pigz_available():
-             return 'pigz'
-         return 'gzip'
-     if extension.endswith('.bz2'):
-         return 'bz2'
-     if extension.endswith('.xz'):
-         return 'xz'
-     if extension.endswith('.zip'):
-         return 'zip'
-     if extension.endswith('.zst'):
-         return 'zstd'
-     return None
+ def _streaming_jsonl(jsonl_files, compressions):
+     for jsonl_file, compression in tqdm(zip(jsonl_files, compressions), desc="Loading JSONL files", unit="file"):
+         for line in open_compression(jsonl_file, mode="rt", compression=compression):
+             yield json.loads(line)
+
+ def load_jsonl_files(jsonl_files):
+     compressions = [determine_compression("jsonl", jsonl_file) for jsonl_file in jsonl_files]
+     if "br" in compressions or "snappy" in compressions:
+         return _streaming_jsonl(jsonl_files, compressions)
+     return load_dataset("json", data_files=jsonl_files, split="train")

  def load_mds_directories(mds_directories, split='.', batch_size=2**16, bulk=True):
      if bulk:
@@ -141,50 +106,14 @@ def load_mds_directories(mds_directories, split='.', batch_size=2**16, bulk=True
      ds = concatenate_datasets(dsets=dss)
      return ds

- def _open_jsonl(file_path, mode="rt", compression="infer", processes=64):
-     """Open a JSONL file, handling gzip compression if necessary."""
-     compression = _determine_compression(file_path, compression)
-     if compression == "gzip":
-         return gzip.open(file_path, mode)
-     if compression == "pigz":
-         return pigz_open(file_path, mode, processes=processes) if mode[0] == "w" else gzip.open(file_path, mode)
-     if compression == "bz2":
-         return bz2.open(file_path, mode)
-     if compression == "xz":
-         return lzma.open(file_path, mode)
-     if compression is None:
-         return open(file_path, mode)
-     raise ValueError(f"Unsupported compression type: {compression}")
-
- def _pigz_available():
-     """Check if pigz is available on the system."""
-     return shutil.which("pigz") is not None
-
- def _pigz_compress(input_file, output_file, processes=64, buf_size=2**24, keep=False, quiet=False):
-     """Compress a file using pigz."""
-     size = os.stat(input_file).st_size
-     num_blocks = (size+buf_size-1) // buf_size
-     with open(input_file, "rb") as f_in, pigz_open(output_file, "wb", processes=processes) as f_out:
-         for _ in tqdm(range(num_blocks), desc="Compressing with pigz", unit="block", disable=quiet):
-             buf = f_in.read(buf_size)
-             assert buf
-             f_out.write(buf)
-         buf = f_in.read()
-         assert not buf
-     if not keep:
-         os.remove(input_file)
-         if not quiet:
-             print(f"Removed {input_file}")
-
  def save_jsonl(iterable, output_file, compression=None, processes=64, size_hint=None, overwrite=True, yes=True):
-     compression = _determine_compression(output_file, compression)
      f = None
      part = 0
      for item in tqdm(iterable, desc="Writing to JSONL", unit="sample"):
          if f is None:
              part_file = output_file.format(part=part)
              check_arguments(part_file, overwrite, yes)
-             f= _open_jsonl(part_file, mode="wb", compression=compression, processes=processes)
+             f = open_compression(part_file, mode="wb", compression=compression, processes=processes)
          f.write(f"{json.dumps(item)}\n".encode("utf-8"))
          if size_hint is not None and f.tell() >= size_hint:
              f.close()
@@ -193,11 +122,8 @@ def save_jsonl(iterable, output_file, compression=None, processes=64, size_hint=
      if f is not None:
          f.close()

- def save_mds(it, output_dir, processes=64, compression=None, buf_size=2**24, pigz=False, shard_size=None, size_hint=None, overwrite=True, yes=True):
-     if compression == "none" or pigz:
-         compression = None
-     if compression == "gzip":
-         compression = "gz"
+ def save_mds(it, output_dir, processes=64, compression=None, buf_size=2**24, pigz=True, shard_size=None, size_hint=None, overwrite=True, yes=True):
+     compression = determine_compression("mds", output_dir, compression, no_pigz=not pigz)
      writer = None
      part = 0
      files = []
@@ -216,7 +142,8 @@ def save_mds(it, output_dir, processes=64, compression=None, buf_size=2**24, pig
              writer.finish()
              part += 1
              writer = None
-     writer.finish()
+     if writer is not None:
+         writer.finish()
      if pigz:
          for output_dir in files:
              index_path = os.path.join(output_dir, "index.json")
@@ -228,7 +155,7 @@ def save_mds(it, output_dir, processes=64, compression=None, buf_size=2**24, pig
                  compressed_file_name = file_name + ".gz"
                  file_path = os.path.join(output_dir, file_name)
                  compressed_file_path = os.path.join(output_dir, compressed_file_name)
-                 _pigz_compress(file_path, compressed_file_path, processes, buf_size=buf_size, keep=False, quiet=True)
+                 pigz_compress(file_path, compressed_file_path, processes, buf_size=buf_size, keep=False, quiet=True)
                  name2info[file_name]["compression"] = "gz"
                  name2info[file_name]["zip_data"] = {
                      "basename": compressed_file_name,
@@ -238,16 +165,23 @@ def save_mds(it, output_dir, processes=64, compression=None, buf_size=2**24, pig
          json.dump(index, open(index_path, "wt"))
          print(f"Compressed {output_dir} with pigz")

- def save_parquet(it, output_file, compression=None, batch_size=2**16):
+ def save_parquet(it, output_file, compression=None, batch_size=2**16, size_hint=None, overwrite=True, yes=True):
+     compression = determine_compression("parquet", output_file, compression)
      writer = None
+     part = 0
      it = tqdm(it, desc="Writing to Parquet", unit="sample")
-     for batch in batch_iterable(it, batch_size):
+     for batch in _batch_iterable(it, batch_size):
          table = pa.Table.from_pylist(batch)
          if writer is None:
-             writer = pq.ParquetWriter(output_file, table.schema, compression=compression)
+             part_file = output_file.format(part=part)
+             check_arguments(part_file, overwrite, yes)
+             writer = pq.ParquetWriter(part_file, table.schema, compression=compression)
+             offset = 0
          writer.write_table(table)
-     writer.close()
-
- def use_pigz(compression):
-     """Determine if pigz should be used based on the compression type."""
-     return compression == "pigz" or (compression == "gzip" and _pigz_available())
+         offset += table.nbytes
+         if size_hint is not None and offset >= size_hint:
+             writer.close()
+             part += 1
+             writer = None
+     if writer is not None:
+         writer.close()
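A short sketch of the reworked I/O helpers, assuming they are importable from mldataforge.utils (as suggested by the module's __all__); all file names are made up:

    # Hypothetical use of load_jsonl_files and save_jsonl.
    from mldataforge.utils import load_jsonl_files, save_jsonl

    # Brotli/snappy inputs are streamed line by line; other formats go through
    # datasets.load_dataset("json", ...).
    samples = load_jsonl_files(["shard-00.jsonl.br", "shard-01.jsonl.gz"])

    # The {part:04d} placeholder starts a new output file whenever size_hint is
    # exceeded, e.g. out-0000.jsonl.zst, out-0001.jsonl.zst, ...
    save_jsonl(
        samples,
        output_file="out-{part:04d}.jsonl.zst",
        compression="zstd",
        size_hint=2**28,
        overwrite=True,
        yes=True,
    )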
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"

  [project]
  name = "mldataforge"
- version = "0.1.3"
+ version = "0.1.5"
  authors = [
      { name = "Peter Schneider-Kamp" }
  ]
@@ -19,10 +19,15 @@ classifiers = [
  ]

  dependencies = [
+     'brotlicffi',
      'click',
      'datasets',
+     'isal',
+     'lz4',
      'mltiming',
-     'mosaicml-streaming'
+     'mosaicml-streaming',
+     'python-snappy',
+     'zstandard'
  ]

  [project.optional-dependencies]
3 files without changes