mldataforge 0.0.2__tar.gz → 0.0.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: mldataforge
3
- Version: 0.0.2
3
+ Version: 0.0.4
4
4
  Summary: swiss army knife of scripts for transforming and processing datasets for machine learning.
5
5
  Project-URL: Homepage, https://github.com/schneiderkamplab/mldataforge
6
6
  Project-URL: Bug Tracker, https://github.com/schneiderkamplab/mldataforge/issues
@@ -1,6 +1,7 @@
1
1
  import click
2
2
 
3
3
  from .jsonl import jsonl
4
+ from .mds import mds
4
5
  from .parquet import parquet
5
6
 
6
7
  __all__ = ["convert"]
@@ -9,5 +10,6 @@ __all__ = ["convert"]
9
10
def convert():
    # Group callback intentionally does nothing; each registered subcommand does the work.
    ...


# Attach the format-specific subcommand groups in the same order as before.
for _subcommand in (jsonl, mds, parquet):
    convert.add_command(_subcommand)
del _subcommand
@@ -1,4 +1,3 @@
1
- #!/usr/bin/env python
2
1
  import click
3
2
  import json
4
3
  import os
@@ -40,11 +39,21 @@ def mds(output_dir, jsonl_files, processes, compression, overwrite, yes, buf_siz
40
39
  lines += 1
41
40
  print(f"Wrote {lines} lines from {len(jsonl_files)} files to MDS files in {output_dir}")
42
41
  if pigz:
43
- file_paths = []
44
- for file in os.listdir(output_dir):
45
- if file.endswith(".mds"):
46
- file_paths.append(os.path.join(output_dir, file))
47
- for file_path in tqdm(file_paths, desc="Compressing with pigz", unit="file"):
48
- pigz_compress(file_path, file_path + ".gz", processes, buf_size=buf_size, keep=False, quiet=True)
49
- output_dir
42
+ index_path = os.path.join(output_dir, "index.json")
43
+ index = json.load(open(index_path, "rt"))
44
+ name2info = {shard["raw_data"]["basename"]: shard for shard in index["shards"]}
45
+ file_names = [file for file in os.listdir(output_dir) if file.endswith(".mds")]
46
+ assert set(file_names) == set(name2info.keys())
47
+ for file_name in tqdm(file_names, desc="Compressing with pigz", unit="file"):
48
+ compressed_file_name = file_name + ".gz"
49
+ file_path = os.path.join(output_dir, file_name)
50
+ compressed_file_path = os.path.join(output_dir, compressed_file_name)
51
+ pigz_compress(file_path, compressed_file_path, processes, buf_size=buf_size, keep=False, quiet=True)
52
+ name2info[file_name]["compression"] = "gz"
53
+ name2info[file_name]["zip_data"] = {
54
+ "basename": compressed_file_name,
55
+ "bytes": os.stat(compressed_file_path).st_size,
56
+ "hashes": {},
57
+ }
58
+ json.dump(index, open(index_path, "wt"))
50
59
  print(f"Compressed {output_dir} with pigz")
@@ -0,0 +1,11 @@
1
+ import click
2
+
3
+ from .jsonl import jsonl
4
+
5
+ __all__ = ["mds"]
6
+
7
@click.group()
def mds():
    # Group callback intentionally does nothing; the attached subcommands do the work.
    # (No docstring on purpose: click would turn it into --help text.)
    ...


# Passing the command's own name is exactly what add_command does when name is omitted.
mds.add_command(jsonl, name=jsonl.name)
@@ -0,0 +1,23 @@
1
+ import click
2
+ import json
3
+ from tqdm import tqdm
4
+
5
+ from ....utils import check_overwrite, create_temp_file, determine_compression, load_mds_directories, open_jsonl
6
+
7
@click.command()
@click.argument("output_file", type=click.Path(exists=False), required=True)
@click.argument("mds_directories", type=click.Path(exists=True), required=True, nargs=-1)
@click.option("--compression", default="infer", type=click.Choice(["none", "infer", "pigz", "gzip", "bz2", "xz"]), help="Compress the output JSONL file (default: infer; pigz for parallel gzip).")
@click.option("--processes", default=64, help="Number of processes to use for pigz compression (default: 64).")
@click.option("--overwrite", is_flag=True, help="Overwrite existing JSONL files.")
@click.option("--yes", is_flag=True, help="Assume yes to all prompts. Use with caution as it will remove files without confirmation.")
@click.option("--batch-size", default=2**16, help="Batch size for loading MDS directories (default: 65536).")
def jsonl(output_file, mds_directories, compression, processes, overwrite, yes, batch_size):
    """Write the samples of one or more MDS directories to a single JSONL file."""
    # Validate the arguments first: check_overwrite may prompt for and remove an
    # existing output file, which must not happen before a usage error.
    if not mds_directories:
        raise click.BadArgumentUsage("No MDS directories provided.")
    check_overwrite(output_file, overwrite, yes)
    ds = load_mds_directories(mds_directories, batch_size=batch_size)
    compression = determine_compression(output_file, compression)
    # The file is opened in binary mode, so every line is encoded to UTF-8 explicitly.
    with open_jsonl(output_file, mode="wb", compression=compression, processes=processes) as f:
        for item in tqdm(ds, desc="Writing to JSONL", unit="line"):
            # NOTE(review): MDS samples may decode to numpy values, which json.dumps
            # cannot serialize -- confirm for the column encodings in use.
            f.write(f"{json.dumps(item)}\n".encode("utf-8"))
@@ -1,6 +1,7 @@
1
1
  import click
2
2
 
3
3
  from .jsonl import jsonl
4
+ from .mds import mds
4
5
 
5
6
  __all__ = ["parquet"]
6
7
 
@@ -9,3 +10,4 @@ def parquet():
9
10
  pass
10
11
 
11
12
# Attach the parquet -> JSONL and parquet -> MDS subcommands (same order as before).
for _subcommand in (jsonl, mds):
    parquet.add_command(_subcommand)
del _subcommand
@@ -1,4 +1,3 @@
1
- #!/usr/bin/env python
2
1
  import click
3
2
  from mltiming import timing
4
3
 
@@ -0,0 +1,43 @@
1
+ import click
2
+ import json
3
+ import os
4
+ from streaming import MDSWriter
5
+ from tqdm import tqdm
6
+
7
+ from ....utils import check_overwrite, infer_mds_encoding, load_parquet_files, pigz_compress, use_pigz
8
+
9
def _pigz_compress_shards(output_dir, processes, buf_size):
    """Compress every ``.mds`` shard in *output_dir* with pigz and record it in ``index.json``.

    ``pigz_compress`` is called with ``keep=False``, so the raw shards are removed.
    The index must therefore be rewritten to reference the ``.gz`` files;
    otherwise the dataset points at shards that no longer exist.

    Raises:
        RuntimeError: If the ``.mds`` files on disk do not match the shards listed in the index.
    """
    index_path = os.path.join(output_dir, "index.json")
    with open(index_path, "rt") as f:
        index = json.load(f)
    name2info = {shard["raw_data"]["basename"]: shard for shard in index["shards"]}
    file_names = sorted(file for file in os.listdir(output_dir) if file.endswith(".mds"))
    if set(file_names) != set(name2info):
        raise RuntimeError(f"MDS shard files in {output_dir} do not match the shards listed in {index_path}")
    for file_name in tqdm(file_names, desc="Compressing with pigz", unit="file"):
        compressed_file_name = file_name + ".gz"
        file_path = os.path.join(output_dir, file_name)
        compressed_file_path = os.path.join(output_dir, compressed_file_name)
        pigz_compress(file_path, compressed_file_path, processes, buf_size=buf_size, keep=False, quiet=True)
        name2info[file_name]["compression"] = "gz"
        name2info[file_name]["zip_data"] = {
            "basename": compressed_file_name,
            "bytes": os.stat(compressed_file_path).st_size,
            "hashes": {},
        }
    with open(index_path, "wt") as f:
        json.dump(index, f)


@click.command()
@click.argument('output_dir', type=click.Path(exists=False))
@click.argument('parquet_files', nargs=-1, type=click.Path(exists=True))
@click.option('--compression', type=click.Choice(['none', 'br', 'bz2', 'gzip', 'pigz', 'snappy', 'zstd'], case_sensitive=False), default=None, help='Compression type for the output dataset (default: None).')
@click.option("--processes", default=64, help="Number of processes to use for pigz compression (default: 64).")
@click.option("--overwrite", is_flag=True, help="Overwrite existing MDS directory.")
@click.option("--yes", is_flag=True, help="Assume yes to all prompts. Use with caution as it will remove entire directory trees without confirmation.")
@click.option("--buf-size", default=2**24, help=f"Buffer size for pigz compression (default: {2**24}).")
def mds(output_dir, parquet_files, processes, compression, overwrite, yes, buf_size):
    """Convert one or more parquet files into an MDS dataset directory."""
    # parquet_files is not `required`, so check it before check_overwrite can
    # remove an existing output directory.
    if not parquet_files:
        raise click.BadArgumentUsage("No parquet files provided.")
    check_overwrite(output_dir, overwrite, yes)
    ds = load_parquet_files(parquet_files)
    pigz = use_pigz(compression)
    sample = ds[0]
    # pigz compression happens after writing, so MDSWriter itself writes uncompressed shards.
    if compression == "none" or pigz:
        compression = None
    if compression == "gzip":
        compression = "gz"
    # Column encodings are inferred from the first sample only.
    columns = {key: infer_mds_encoding(value) for key, value in sample.items()}
    lines = 0
    with MDSWriter(out=output_dir, columns=columns, compression=compression) as writer:
        for item in tqdm(ds, desc="Processing samples", unit="sample"):
            writer.write(item)
            lines += 1
    print(f"Wrote {lines} lines from {len(parquet_files)} files to MDS files in {output_dir}")
    if pigz:
        _pigz_compress_shards(output_dir, processes, buf_size)
        print(f"Compressed {output_dir} with pigz")
@@ -0,0 +1,60 @@
1
+ import subprocess
2
+
3
+ __all__ = ["pigz_open"]
4
+
5
def pigz_open(path, mode="rt", processes=64, encoding=None):
    """Open *path* for streaming gzip (de)compression through an external ``pigz``.

    Args:
        path: File to decompress (read modes) or to create (write modes).
        mode: One of ``"rt"``, ``"wt"``, ``"rb"``, ``"wb"``.
        processes: Passed to pigz as ``-p``.
        encoding: Text encoding used in text modes; ``None`` selects UTF-8.

    Returns:
        PigzFile: a file-like object usable as a context manager.

    Raises:
        ValueError: If *mode* is not supported.
    """
    return PigzFile(path, mode=mode, processes=processes, encoding=encoding)


class PigzFile(object):
    """A wrapper for pigz to handle gzip compression and decompression.

    Read modes run ``pigz -p N -c -d path``; iterating the object yields the
    decompressed lines (``str`` in text mode, ``bytes`` in binary mode).
    Write modes run ``pigz -p N -c`` with its standard output redirected into
    *path*; :meth:`write` feeds data to pigz's standard input.
    A non-zero pigz exit status is raised as :class:`subprocess.CalledProcessError`.
    """

    def __init__(self, path, mode="rt", processes=4, encoding="utf-8"):
        if mode not in ("rt", "wt", "rb", "wb"):
            raise ValueError("Mode must be one of rt, wt, rb, or wb.")
        self.path = path
        self.is_read = mode[0] == "r"
        self.is_text = mode[1] == "t"
        self.processes = processes
        # pigz_open forwards encoding=None; treat it as UTF-8 instead of letting
        # subprocess fall back to the locale encoding.
        if self.is_text:
            self.encoding = "utf-8" if encoding is None else encoding
        else:
            self.encoding = None
        self._process = None
        self._fw = None
        args = ["pigz", "-p", str(self.processes), "-c"]
        if self.is_read:
            args.extend(("-d", self.path))
            self._args = args
            self._process = subprocess.Popen(args, stdout=subprocess.PIPE, encoding=self.encoding, text=self.is_text)
        else:
            self._args = args
            # pigz writes compressed bytes straight into this descriptor, so open it in binary mode.
            self._fw = open(self.path, "wb")
            try:
                self._process = subprocess.Popen(args, stdout=self._fw, stdin=subprocess.PIPE, encoding=self.encoding, text=self.is_text)
            except BaseException:
                # e.g. pigz not installed: do not leak the output file handle.
                self._fw.close()
                self._fw = None
                raise

    def __iter__(self):
        if not self.is_read:
            raise ValueError("PigzFile not opened for reading")
        if self._process is None:
            raise ValueError("I/O operation on closed PigzFile")
        process = self._process
        for line in process.stdout:
            yield line
        returncode = process.wait()
        process.stdout.close()
        self._process = None
        if returncode != 0:
            raise subprocess.CalledProcessError(returncode, self._args)

    def write(self, line):
        """Send *line* (``str`` in text mode, bytes-like in binary mode) to pigz.

        Returns:
            The value returned by the underlying pipe's ``write``.
        """
        if self.is_read or self._process is None:
            raise ValueError("PigzFile not open for writing")
        expected = str if self.is_text else (bytes, bytearray, memoryview)
        if not isinstance(line, expected):
            kind = "str" if self.is_text else "a bytes-like object"
            raise TypeError(f"write() argument must be {kind}, not {type(line).__name__}")
        return self._process.stdin.write(line)

    def close(self):
        """Stop reading, or flush and finish writing; safe to call more than once.

        Raises:
            subprocess.CalledProcessError: If pigz exited with a non-zero status while writing.
        """
        process = self._process
        if process is None:
            return
        self._process = None
        if self.is_read:
            process.kill()
            process.stdout.close()
            # Reap the killed child so it does not linger as a zombie.
            process.wait()
            return
        try:
            try:
                # Closing stdin signals end of input so pigz can finish the gzip stream.
                process.stdin.close()
            finally:
                returncode = process.wait()
        finally:
            self._fw.close()
            self._fw = None
        if returncode != 0:
            raise subprocess.CalledProcessError(returncode, self._args)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
@@ -7,10 +7,12 @@ import lzma
7
7
  from mltiming import timing
8
8
  import os
9
9
  import shutil
10
- import subprocess
10
+ from streaming import StreamingDataset
11
11
  import tempfile
12
12
  from tqdm import tqdm
13
13
 
14
+ from .pigz import pigz_open
15
+
14
16
  __all__ = [
15
17
  "check_overwrite",
16
18
  "create_temp_file",
@@ -18,72 +20,13 @@ __all__ = [
18
20
  "infer_mds_encoding",
19
21
  "infer_compression",
20
22
  "load_parquet_files",
23
+ "load_mds_directories",
21
24
  "open_jsonl",
22
25
  "pigz_available",
23
26
  "pigz_compress",
24
27
  "use_pigz",
25
28
  ]
26
29
 
27
- class PigzFile(object):
28
- """A wrapper for pigz to handle gzip compression and decompression."""
29
- def __init__(self, path, mode="rt", processes=4, encoding=None):
30
- if mode not in ("rt", "wt", "rb", "wb"):
31
- raise ValueError("Mode must be one of rt, wt, rb, or wb.")
32
- self.path = path
33
- self.mode = mode
34
- self.processes = processes
35
- self.encoding = "latin1" if mode[1] == "b" else ("utf-8" if encoding is None else encoding)
36
- self._process = None
37
- self._fw = None
38
- if self.mode[0] == "r":
39
- args = ["pigz", "-d", "-c", "-p", str(self.processes), self.path]
40
- self._process = subprocess.Popen(
41
- args,
42
- stdout=subprocess.PIPE,
43
- encoding=encoding,
44
- )
45
- elif self.mode[0] == "w":
46
- args = ["pigz", "-p", str(self.processes), "-c"]
47
- self._fw = open(self.path, "w+")
48
- self._process = subprocess.Popen(
49
- args,
50
- stdout=self._fw,
51
- stdin=subprocess.PIPE,
52
- encoding=encoding,
53
- )
54
-
55
- def __iter__(self):
56
- assert self.mode[0] == "r"
57
- for line in self._process.stdout:
58
- yield line
59
- self._process.wait()
60
- assert self._process.returncode == 0
61
- self._process.stdout.close()
62
- self._process = None
63
-
64
- def write(self, line):
65
- assert self.mode[0] == "w"
66
- self._process.stdin.write(line if self.mode[1] == "t" else line.encode(self.encoding))
67
-
68
- def close(self):
69
- if self._process:
70
- if self.mode[0] == "r":
71
- self._process.kill()
72
- self._process.stdout.close()
73
- self._process = None
74
- elif self.mode[1] == "w":
75
- self._process.stdin.close()
76
- self._process.wait()
77
- self._process = None
78
- self._fw.close()
79
- self._fw = None
80
-
81
- def __enter__(self):
82
- return self
83
-
84
- def __exit__(self, exc_type, exc_value, traceback):
85
- self.close()
86
-
87
30
  def check_overwrite(output_path, overwrite, yes):
88
31
  if os.path.exists(output_path):
89
32
  if os.path.isfile(output_path):
@@ -162,6 +105,27 @@ def infer_compression(file_path):
162
105
  return 'zstd'
163
106
  return None
164
107
 
108
def load_mds_directories(mds_directories, split='.', batch_size=2**16):
    """Load one or more local MDS directories as a single iterable dataset.

    Args:
        mds_directories: Paths of local MDS dataset directories.
        split: Split passed to ``StreamingDataset`` (default ``'.'``).
        batch_size: Batch size passed to ``StreamingDataset``.

    Returns:
        For a single directory, the ``StreamingDataset`` itself. For several
        directories, an object that yields their samples back to back and whose
        ``len()`` is the sum of the parts.

    Raises:
        ValueError: If no directories are given.
    """
    mds_directories = list(mds_directories)
    if not mds_directories:
        raise ValueError("No MDS directories provided.")
    dss = []
    for mds_directory in tqdm(mds_directories, desc="Loading MDS directories", unit="directory"):
        ds = StreamingDataset(
            local=mds_directory,
            remote=None,
            split=split,
            shuffle=False,
            allow_unsafe_types=True,
            batch_size=batch_size,
            download_retry=1,
            validate_hash=False,
        )
        dss.append(ds)
    if len(dss) == 1:
        return dss[0]
    # datasets.concatenate_datasets only accepts Hugging Face Dataset/IterableDataset
    # objects and rejects StreamingDataset, so chain the streams ourselves.
    return _ChainedDatasets(dss)


class _ChainedDatasets:
    """Iterate several datasets one after another as if they were one."""

    def __init__(self, datasets):
        self.datasets = list(datasets)

    def __len__(self):
        return sum(len(ds) for ds in self.datasets)

    def __iter__(self):
        for ds in self.datasets:
            yield from ds
128
+
165
129
  def load_parquet_files(parquet_files):
166
130
  dss = []
167
131
  for parquet_file in tqdm(parquet_files, desc="Loading parquet files", unit="file"):
@@ -174,11 +138,13 @@ def load_parquet_files(parquet_files):
174
138
  ds = concatenate_datasets(dsets=dss)
175
139
  return ds
176
140
 
177
- def open_jsonl(file_path, mode="rt", compression="infer"):
141
+ def open_jsonl(file_path, mode="rt", compression="infer", processes=64):
178
142
  """Open a JSONL file, handling gzip compression if necessary."""
179
143
  compression = determine_compression(file_path, compression)
180
- if compression in ("gzip", "pigz"):
144
+ if compression == "gzip":
181
145
  return gzip.open(file_path, mode)
146
+ if compression == "pigz":
147
+ return pigz_open(file_path, mode, processes=processes) if mode[0] == "w" else gzip.open(file_path, mode)
182
148
  if compression == "bz2":
183
149
  return bz2.open(file_path, mode)
184
150
  if compression == "xz":
@@ -195,7 +161,7 @@ def pigz_compress(input_file, output_file, processes=64, buf_size=2**24, keep=Fa
195
161
  """Compress a file using pigz."""
196
162
  size = os.stat(input_file).st_size
197
163
  num_blocks = (size+buf_size-1) // buf_size
198
- with open(input_file, "rt", encoding="latin1") as f_in, PigzFile(output_file, "wb", processes=processes) as f_out:
164
+ with open(input_file, "rb") as f_in, pigz_open(output_file, "wb", processes=processes) as f_out:
199
165
  for _ in tqdm(range(num_blocks), desc="Compressing with pigz", unit="block", disable=quiet):
200
166
  buf = f_in.read(buf_size)
201
167
  assert buf
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
4
4
 
5
5
  [project]
6
6
  name = "mldataforge"
7
- version = "0.0.2"
7
+ version = "0.0.4"
8
8
  authors = [
9
9
  { name = "Peter Schneider-Kamp" }
10
10
  ]
File without changes
File without changes
File without changes