mldataforge 0.0.1__tar.gz → 0.0.3__tar.gz

This diff shows the changes between two publicly released versions of the package, as they appear in the public registry. It is provided for informational purposes only.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: mldataforge
-Version: 0.0.1
+Version: 0.0.3
 Summary: swiss army knife of scripts for transforming and processing datasets for machine learning.
 Project-URL: Homepage, https://github.com/schneiderkamplab/mldataforge
 Project-URL: Bug Tracker, https://github.com/schneiderkamplab/mldataforge/issues
@@ -13,7 +13,7 @@ Requires-Python: >=3.12
 Requires-Dist: click
 Requires-Dist: datasets
 Requires-Dist: mltiming
-Requires-Dist: pygz
+Requires-Dist: mosaicml-streaming
 Description-Content-Type: text/markdown
 
 # mldatasets
@@ -1,5 +1,6 @@
 import click
 
+from .jsonl import jsonl
 from .parquet import parquet
 
 __all__ = ["convert"]
@@ -9,3 +10,4 @@ def convert():
     pass
 
 convert.add_command(parquet)
+convert.add_command(jsonl)
@@ -0,0 +1,11 @@
+import click
+
+from .mds import mds
+
+__all__ = ["jsonl"]
+
+@click.group()
+def jsonl():
+    pass
+
+jsonl.add_command(mds)
@@ -0,0 +1,49 @@
+#!/usr/bin/env python
+import click
+import json
+import os
+from streaming import MDSWriter
+from tqdm import tqdm
+
+from ....utils import check_overwrite, infer_mds_encoding, open_jsonl, pigz_compress, use_pigz
+
+@click.command()
+@click.argument('output_dir', type=click.Path(exists=False))
+@click.argument('jsonl_files', nargs=-1, type=click.Path(exists=True))
+@click.option('--compression', type=click.Choice(['none', 'br', 'bz2', 'gzip', 'pigz', 'snappy', 'zstd'], case_sensitive=False), default=None, help='Compression type for the output dataset (default: None).')
+@click.option("--processes", default=64, help="Number of processes to use for pigz compression (default: 64).")
+@click.option("--overwrite", is_flag=True, help="Overwrite existing MDS directory.")
+@click.option("--yes", is_flag=True, help="Assume yes to all prompts. Use with caution as it will remove entire directory trees without confirmation.")
+@click.option("--buf-size", default=2**24, help=f"Buffer size for pigz compression (default: {2**24}).")
+def mds(output_dir, jsonl_files, processes, compression, overwrite, yes, buf_size):
+    check_overwrite(output_dir, overwrite, yes)
+    if not jsonl_files:
+        raise click.BadArgumentUsage("No JSONL files provided.")
+    with open_jsonl(jsonl_files[0]) as f:
+        sample = json.loads(f.readline())
+    pigz = use_pigz(compression)
+    if compression == "none" or pigz:
+        compression = None
+    if compression == "gzip":
+        compression = "gz"
+    columns = {key: infer_mds_encoding(value) for key, value in sample.items()}
+    lines = 0
+    with MDSWriter(out=output_dir, columns=columns, compression=compression) as writer:
+        for jsonl_file in tqdm(jsonl_files, desc="Processing JSONL files", unit="file"):
+            with open_jsonl(jsonl_file, compression="infer") as f:
+                for line_num, line in enumerate(f, start=1):
+                    try:
+                        item = json.loads(line)
+                        writer.write(item)
+                    except json.JSONDecodeError as e:
+                        print(f"Skipping line {line_num} in {jsonl_file} due to JSON error: {e}")
+                    lines += 1
+    print(f"Wrote {lines} lines from {len(jsonl_files)} files to MDS files in {output_dir}")
+    if pigz:
+        file_paths = []
+        for file in os.listdir(output_dir):
+            if file.endswith(".mds"):
+                file_paths.append(os.path.join(output_dir, file))
+        for file_path in tqdm(file_paths, desc="Compressing with pigz", unit="file"):
+            pigz_compress(file_path, file_path + ".gz", processes, buf_size=buf_size, keep=False, quiet=True)
+        print(f"Compressed {output_dir} with pigz")
@@ -1,6 +1,7 @@
 import click
 
 from .jsonl import jsonl
+from .mds import mds
 
 __all__ = ["parquet"]
 
@@ -9,3 +10,4 @@ def parquet():
     pass
 
 parquet.add_command(jsonl)
+parquet.add_command(mds)
@@ -0,0 +1,26 @@
+#!/usr/bin/env python
+import click
+from mltiming import timing
+
+from ....utils import check_overwrite, create_temp_file, determine_compression, load_parquet_files, pigz_compress
+
+@click.command()
+@click.argument("output_file", type=click.Path(exists=False), required=True)
+@click.argument("parquet_files", type=click.Path(exists=True), required=True, nargs=-1)
+@click.option("--compression", default="infer", type=click.Choice(["none", "infer", "pigz", "gzip", "bz2", "xz"]), help="Compress the output JSONL file (default: infer; pigz for parallel gzip).")
+@click.option("--processes", default=64, help="Number of processes to use for pigz compression (default: 64).")
+@click.option("--overwrite", is_flag=True, help="Overwrite existing JSONL files.")
+@click.option("--yes", is_flag=True, help="Assume yes to all prompts. Use with caution as it will remove files without confirmation.")
+@click.option("--buf-size", default=2**24, help=f"Buffer size for pigz compression (default: {2**24}).")
+def jsonl(output_file, parquet_files, compression, processes, overwrite, yes, buf_size):
+    check_overwrite(output_file, overwrite, yes)
+    if not parquet_files:
+        raise click.BadArgumentUsage("No parquet files provided.")
+    ds = load_parquet_files(parquet_files)
+    compression = determine_compression(output_file, compression)
+    compressed_file = None
+    if compression == "pigz":
+        compression, compressed_file, output_file = None, output_file, create_temp_file()
+    ds.to_json(output_file, num_proc=processes, orient="records", lines=True, compression=compression)
+    if compressed_file is not None:
+        pigz_compress(output_file, compressed_file, processes, buf_size, keep=False)
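The pigz branch above writes uncompressed JSONL to a temporary file and only then compresses it, because parallel gzip runs outside of `Dataset.to_json`. A standalone sketch of the same handoff; the function and file names here are illustrative, not from the package:

    # Sketch: write plain JSONL first, then compress it in parallel with an
    # external pigz process (assumes pigz is installed on the system).
    import shutil, subprocess, tempfile

    def write_then_pigz(ds, out_gz, processes=4):
        tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".jsonl").name
        ds.to_json(tmp, orient="records", lines=True)  # uncompressed JSONL
        if shutil.which("pigz") is None:
            raise RuntimeError("pigz not found on PATH")
        with open(out_gz, "wb") as f_out:
            subprocess.run(["pigz", "-c", "-p", str(processes), tmp],
                           stdout=f_out, check=True)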
@@ -0,0 +1,43 @@
+#!/usr/bin/env python
+import click
+import json
+import os
+from streaming import MDSWriter
+from tqdm import tqdm
+
+from ....utils import check_overwrite, infer_mds_encoding, load_parquet_files, pigz_compress, use_pigz
+
+@click.command()
+@click.argument('output_dir', type=click.Path(exists=False))
+@click.argument('parquet_files', nargs=-1, type=click.Path(exists=True))
+@click.option('--compression', type=click.Choice(['none', 'br', 'bz2', 'gzip', 'pigz', 'snappy', 'zstd'], case_sensitive=False), default=None, help='Compression type for the output dataset (default: None).')
+@click.option("--processes", default=64, help="Number of processes to use for pigz compression (default: 64).")
+@click.option("--overwrite", is_flag=True, help="Overwrite existing MDS directory.")
+@click.option("--yes", is_flag=True, help="Assume yes to all prompts. Use with caution as it will remove entire directory trees without confirmation.")
+@click.option("--buf-size", default=2**24, help=f"Buffer size for pigz compression (default: {2**24}).")
+def mds(output_dir, parquet_files, processes, compression, overwrite, yes, buf_size):
+    check_overwrite(output_dir, overwrite, yes)
+    if not parquet_files:
+        raise click.BadArgumentUsage("No parquet files provided.")
+    ds = load_parquet_files(parquet_files)
+    pigz = use_pigz(compression)
+    sample = ds[0]
+    if compression == "none" or pigz:
+        compression = None
+    if compression == "gzip":
+        compression = "gz"
+    columns = {key: infer_mds_encoding(value) for key, value in sample.items()}
+    lines = 0
+    with MDSWriter(out=output_dir, columns=columns, compression=compression) as writer:
+        for item in tqdm(ds, desc="Processing samples", unit="sample"):
+            writer.write(item)
+            lines += 1
+    print(f"Wrote {lines} lines from {len(parquet_files)} files to MDS files in {output_dir}")
+    if pigz:
+        file_paths = []
+        for file in os.listdir(output_dir):
+            if file.endswith(".mds"):
+                file_paths.append(os.path.join(output_dir, file))
+        for file_path in tqdm(file_paths, desc="Compressing with pigz", unit="file"):
+            pigz_compress(file_path, file_path + ".gz", processes, buf_size=buf_size, keep=False, quiet=True)
+        print(f"Compressed {output_dir} with pigz")
@@ -0,0 +1,212 @@
+import atexit
+import bz2
+import click
+from datasets import concatenate_datasets, load_dataset
+import gzip
+import lzma
+from mltiming import timing
+import os
+import shutil
+import subprocess
+import tempfile
+from tqdm import tqdm
+
+__all__ = [
+    "check_overwrite",
+    "create_temp_file",
+    "determine_compression",
+    "infer_mds_encoding",
+    "infer_compression",
+    "load_parquet_files",
+    "open_jsonl",
+    "pigz_available",
+    "pigz_compress",
+    "use_pigz",
+]
+
+class PigzFile(object):
+    """A wrapper for pigz to handle gzip compression and decompression."""
+    def __init__(self, path, mode="rt", processes=4, encoding=None):
+        if mode not in ("rt", "wt", "rb", "wb"):
+            raise ValueError("Mode must be one of rt, wt, rb, or wb.")
+        self.path = path
+        self.mode = mode
+        self.processes = processes
+        self.encoding = "latin1" if mode[1] == "b" else ("utf-8" if encoding is None else encoding)
+        self._process = None
+        self._fw = None
+        if self.mode[0] == "r":
+            args = ["pigz", "-d", "-c", "-p", str(self.processes), self.path]
+            self._process = subprocess.Popen(
+                args,
+                stdout=subprocess.PIPE,
+                encoding=self.encoding if self.mode[1] == "t" else None,
+            )
+        elif self.mode[0] == "w":
+            args = ["pigz", "-p", str(self.processes), "-c"]
+            self._fw = open(self.path, "w+")
+            self._process = subprocess.Popen(
+                args,
+                stdout=self._fw,
+                stdin=subprocess.PIPE,
+                encoding=self.encoding if self.mode[1] == "t" else None,
+            )
+
+    def __iter__(self):
+        assert self.mode[0] == "r"
+        for line in self._process.stdout:
+            yield line
+        self._process.wait()
+        assert self._process.returncode == 0
+        self._process.stdout.close()
+        self._process = None
+
+    def write(self, line):
+        assert self.mode[0] == "w"
+        self._process.stdin.write(line if self.mode[1] == "t" else line.encode(self.encoding))
+
+    def close(self):
+        if self._process:
+            if self.mode[0] == "r":
+                self._process.kill()
+                self._process.stdout.close()
+                self._process = None
+            elif self.mode[0] == "w":
+                self._process.stdin.close()
+                self._process.wait()
+                self._process = None
+                self._fw.close()
+                self._fw = None
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.close()
+
+def check_overwrite(output_path, overwrite, yes):
+    if os.path.exists(output_path):
+        if os.path.isfile(output_path):
+            if not overwrite:
+                raise click.BadParameter(f"Output file '{output_path}' already exists. Use --overwrite to overwrite.")
+            if not yes:
+                confirm_overwrite(f"Output file '{output_path}' already exists. Do you want to delete it?")
+            with timing(message=f"Deleting existing file '{output_path}'"):
+                os.remove(output_path)
+        elif os.path.isdir(output_path):
+            if not overwrite:
+                raise click.BadParameter(f"Output directory '{output_path}' already exists. Use --overwrite to overwrite.")
+            if not yes:
+                confirm_overwrite(f"Output directory '{output_path}' already exists. Do you want to delete this directory and all its contents?")
+            with timing(message=f"Deleting existing directory '{output_path}'"):
+                shutil.rmtree(output_path)
+        else:
+            raise click.BadParameter(f"Output path '{output_path}' exists but is neither a file nor a directory.")
+
+def confirm_overwrite(message):
+    print(message)
+    response = input("Are you sure you want to proceed? (yes/no): ")
+    if response.lower() != 'yes':
+        raise click.Abort()
+
+def create_temp_file():
+    def _cleanup_file(file_path):
+        try:
+            os.remove(file_path)
+        except OSError:
+            pass
+    # Create a named temp file, don't delete right away
+    temp = tempfile.NamedTemporaryFile(delete=False)
+    temp_name = temp.name
+    # Close so others can open it again without conflicts (especially on Windows)
+    temp.close()
+
+    # Schedule its deletion at exit
+    atexit.register(_cleanup_file, temp_name)
+
+    return temp_name
+
+def determine_compression(file_path, compression="infer"):
+    if compression == "infer":
+        compression = infer_compression(file_path)
+    if compression == "none":
+        compression = None
+    return compression
+
+def infer_mds_encoding(value):
+    """Determine the MDS encoding for a given value."""
+    if isinstance(value, str):
+        return 'str'
+    # bool must be checked before int, since isinstance(True, int) is True
+    if isinstance(value, bool):
+        return 'bool'
+    if isinstance(value, int):
+        return 'int'
+    if isinstance(value, float):
+        return 'float32'
+    return 'pkl'
+
+def infer_compression(file_path):
+    """Infer the compression type from the file extension."""
+    extension = os.path.splitext(file_path)[1]
+    if extension.endswith('.gz'):
+        if pigz_available():
+            return 'pigz'
+        return 'gzip'
+    if extension.endswith('.bz2'):
+        return 'bz2'
+    if extension.endswith('.xz'):
+        return 'xz'
+    if extension.endswith('.zip'):
+        return 'zip'
+    if extension.endswith('.zst'):
+        return 'zstd'
+    return None
+
+def load_parquet_files(parquet_files):
+    dss = []
+    for parquet_file in tqdm(parquet_files, desc="Loading parquet files", unit="file"):
+        ds = load_dataset("parquet", data_files=parquet_file, split="train")
+        dss.append(ds)
+    if len(dss) == 1:
+        ds = dss[0]
+    else:
+        with timing(message=f"Concatenating {len(dss)} datasets"):
+            ds = concatenate_datasets(dsets=dss)
+    return ds
+
+def open_jsonl(file_path, mode="rt", compression="infer"):
+    """Open a JSONL file, handling gzip compression if necessary."""
+    compression = determine_compression(file_path, compression)
+    if compression in ("gzip", "pigz"):
+        return gzip.open(file_path, mode)
+    if compression == "bz2":
+        return bz2.open(file_path, mode)
+    if compression == "xz":
+        return lzma.open(file_path, mode)
+    if compression is None:
+        return open(file_path, mode)
+    raise ValueError(f"Unsupported compression type: {compression}")
+
+def pigz_available():
+    """Check if pigz is available on the system."""
+    return shutil.which("pigz") is not None
+
+def pigz_compress(input_file, output_file, processes=64, buf_size=2**24, keep=False, quiet=False):
+    """Compress a file using pigz."""
+    size = os.stat(input_file).st_size
+    num_blocks = (size+buf_size-1) // buf_size
+    with open(input_file, "rt", encoding="latin1") as f_in, PigzFile(output_file, "wb", processes=processes) as f_out:
+        for _ in tqdm(range(num_blocks), desc="Compressing with pigz", unit="block", disable=quiet):
+            buf = f_in.read(buf_size)
+            assert buf
+            f_out.write(buf)
+        buf = f_in.read()
+        assert not buf
+    if not keep:
+        os.remove(input_file)
+        if not quiet:
+            print(f"Removed {input_file}")
+
+def use_pigz(compression):
+    """Determine if pigz should be used based on the compression type."""
+    return compression == "pigz" or (compression == "gzip" and pigz_available())
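A short illustration of the compression helpers above. The paths are placeholders, and `mldataforge.utils` is the module location implied by the `from ....utils import` statements in this diff:

    # Sketch: extension-based inference and transparent decompression.
    from mldataforge.utils import infer_compression, open_jsonl

    print(infer_compression("data.jsonl.gz"))  # 'pigz' if pigz is on PATH, else 'gzip'
    print(infer_compression("data.jsonl"))     # None (no compression)

    with open_jsonl("data.jsonl.gz") as f:     # gzip.open under the hood
        first_line = f.readline()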
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "mldataforge"
-version = "0.0.1"
+version = "0.0.3"
 authors = [
     { name = "Peter Schneider-Kamp" }
 ]
@@ -22,7 +22,7 @@ dependencies = [
     'click',
     'datasets',
     'mltiming',
-    'pygz'
+    'mosaicml-streaming'
 ]
 
 [project.urls]
@@ -1,55 +0,0 @@
1
- #!/usr/bin/env python
2
- import atexit
3
- import click
4
- from datasets import load_dataset
5
- from mltiming import timing
6
- import os
7
- from pygz import PigzFile
8
- from shutil import which
9
- import tempfile
10
-
11
- def create_temp_file():
12
- # Create a named temp file, don't delete right away
13
- temp = tempfile.NamedTemporaryFile(delete=False)
14
- temp_name = temp.name
15
- # Close so others can open it again without conflicts (especially on Windows)
16
- temp.close()
17
-
18
- # Schedule its deletion at exit
19
- atexit.register(_cleanup_file, temp_name)
20
-
21
- return temp_name
22
-
23
- def _cleanup_file(file_path):
24
- try:
25
- os.remove(file_path)
26
- except OSError:
27
- pass
28
-
29
- @click.command()
30
- @click.argument("parquet_file", type=click.Path(exists=True))
31
- @click.argument("output_file", type=click.Path(exists=False), required=False)
32
- @click.option("--compression", default="infer", type=click.Choice(["none", "infer", "pigz", "gzip", "bz2", "xz"]), help="Compress the output JSONL file (default: infer; pigz for parallel gzip).")
33
- @click.option("--threads", default=64, help="Number of processes to use for pigz compression (default: 64).")
34
- @click.option("--overwrite", is_flag=True, help="Overwrite existing JSONL files.")
35
- def jsonl(parquet_file, output_file, compression, threads, overwrite):
36
- if os.path.exists(output_file) and not overwrite:
37
- raise click.ClickException(f"Output file {output_file} already exists. Use --overwrite to overwrite.")
38
- with timing(message=f"Loading from {parquet_file}"):
39
- ds = load_dataset("parquet", data_files=parquet_file)
40
- orig_output_file = None
41
- if compression == "none":
42
- compression = None
43
- elif compression == "infer":
44
- if output_file.endswith(".gz") and which("pigz") is not None:
45
- compression = "pigz"
46
- if compression == "pigz":
47
- compression = None
48
- orig_output_file = output_file
49
- output_file = create_temp_file()
50
- with timing(message=f"Saving to {output_file} with compression {compression}"):
51
- ds["train"].to_json(output_file, orient="records", lines=True, compression=compression)
52
- if orig_output_file is not None:
53
- with timing(message=f"Compressing {output_file} to {orig_output_file} with pigz using {threads} threads"):
54
- with open(output_file, "rt") as f_in, PigzFile(f"{orig_output_file}.gz", "wt", threads=threads) as f_out:
55
- f_out.write(f_in.read())