mldataforge 0.0.3__tar.gz → 0.0.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: mldataforge
-Version: 0.0.3
+Version: 0.0.5
 Summary: swiss army knife of scripts for transforming and processing datasets for machine learning.
 Project-URL: Homepage, https://github.com/schneiderkamplab/mldataforge
 Project-URL: Bug Tracker, https://github.com/schneiderkamplab/mldataforge/issues
@@ -1,6 +1,7 @@
 import click

 from .jsonl import jsonl
+from .mds import mds
 from .parquet import parquet

 __all__ = ["convert"]
@@ -9,5 +10,6 @@ __all__ = ["convert"]
 def convert():
     pass

-convert.add_command(parquet)
 convert.add_command(jsonl)
+convert.add_command(mds)
+convert.add_command(parquet)
@@ -1,6 +1,7 @@
 import click

 from .mds import mds
+from .parquet import parquet

 __all__ = ["jsonl"]

@@ -9,3 +10,4 @@ def jsonl():
     pass

 jsonl.add_command(mds)
+jsonl.add_command(parquet)
@@ -1,4 +1,3 @@
-#!/usr/bin/env python
 import click
 import json
 import os
@@ -40,11 +39,21 @@ def mds(output_dir, jsonl_files, processes, compression, overwrite, yes, buf_siz
                lines += 1
    print(f"Wrote {lines} lines from {len(jsonl_files)} files to MDS files in {output_dir}")
    if pigz:
-        file_paths = []
-        for file in os.listdir(output_dir):
-            if file.endswith(".mds"):
-                file_paths.append(os.path.join(output_dir, file))
-        for file_path in tqdm(file_paths, desc="Compressing with pigz", unit="file"):
-            pigz_compress(file_path, file_path + ".gz", processes, buf_size=buf_size, keep=False, quiet=True)
-        output_dir
+        index_path = os.path.join(output_dir, "index.json")
+        index = json.load(open(index_path, "rt"))
+        name2info = {shard["raw_data"]["basename"]: shard for shard in index["shards"]}
+        file_names = [file for file in os.listdir(output_dir) if file.endswith(".mds")]
+        assert set(file_names) == set(name2info.keys())
+        for file_name in tqdm(file_names, desc="Compressing with pigz", unit="file"):
+            compressed_file_name = file_name + ".gz"
+            file_path = os.path.join(output_dir, file_name)
+            compressed_file_path = os.path.join(output_dir, compressed_file_name)
+            pigz_compress(file_path, compressed_file_path, processes, buf_size=buf_size, keep=False, quiet=True)
+            name2info[file_name]["compression"] = "gz"
+            name2info[file_name]["zip_data"] = {
+                "basename": compressed_file_name,
+                "bytes": os.stat(compressed_file_path).st_size,
+                "hashes": {},
+            }
+        json.dump(index, open(index_path, "wt"))
    print(f"Compressed {output_dir} with pigz")
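For orientation (not part of the diff): judging only from the keys read and written in the block above, a shard entry in index.json would end up with roughly the following shape after compression. The full MDS index schema is not shown in this diff, so treat this as an assumption, with illustrative file names and sizes.

    # Hypothetical shape of one entry in index["shards"] after the loop above,
    # inferred solely from the keys the new code touches.
    shard = {
        "raw_data": {"basename": "shard.00000.mds"},  # basename assumed for illustration
        "compression": "gz",                          # set by the new code
        "zip_data": {
            "basename": "shard.00000.mds.gz",
            "bytes": 123456,                          # size of the compressed file (illustrative)
            "hashes": {},
        },
        # other fields written by the MDS writer are omitted here
    }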
@@ -0,0 +1,39 @@
+import click
+import json
+import pyarrow as pa
+import pyarrow.parquet as pq
+from tqdm import tqdm
+
+from ....utils import batch_iterable, check_overwrite, open_jsonl
+
+def _iterate(jsonl_files):
+    lines = 0
+    for jsonl_file in tqdm(jsonl_files, desc="Processing JSONL files", unit="file"):
+        with open_jsonl(jsonl_file, compression="infer") as f:
+            for line_num, line in enumerate(f, start=1):
+                try:
+                    item = json.loads(line)
+                    yield item
+                except json.JSONDecodeError as e:
+                    print(f"Skipping line {line_num} in {jsonl_file} due to JSON error: {e}")
+                lines += 1
+    print(f"Wrote {lines} lines from {len(jsonl_files)} files")
+
+@click.command()
+@click.argument('output_file', type=click.Path(exists=False))
+@click.argument('jsonl_files', nargs=-1, type=click.Path(exists=True))
+@click.option("--compression", default="snappy", type=click.Choice(["snappy", "gzip", "zstd"]), help="Compress the Parquet file (default: snappy).")
+@click.option("--overwrite", is_flag=True, help="Overwrite existing MDS directory.")
+@click.option("--yes", is_flag=True, help="Assume yes to all prompts. Use with caution as it will remove entire directory trees without confirmation.")
+@click.option("--batch-size", default=2**16, help="Batch size for loading MDS directories and writing Parquet files (default: 65536).")
+def parquet(output_file, jsonl_files, compression, overwrite, yes, batch_size):
+    check_overwrite(output_file, overwrite, yes)
+    if not jsonl_files:
+        raise click.BadArgumentUsage("No JSONL files provided.")
+    writer = None
+    for batch in batch_iterable(_iterate(jsonl_files), batch_size):
+        table = pa.Table.from_pylist(batch)
+        if writer is None:
+            writer = pq.ParquetWriter(output_file, table.schema, compression=compression)
+        writer.write_table(table)
+    writer.close()
@@ -0,0 +1,13 @@
+import click
+
+from .jsonl import jsonl
+from .parquet import parquet
+
+__all__ = ["mds"]
+
+@click.group()
+def mds():
+    pass
+
+mds.add_command(jsonl)
+mds.add_command(parquet)
@@ -0,0 +1,23 @@
+import click
+import json
+from tqdm import tqdm
+
+from ....utils import check_overwrite, create_temp_file, determine_compression, load_mds_directories, open_jsonl
+
+@click.command()
+@click.argument("output_file", type=click.Path(exists=False), required=True)
+@click.argument("mds_directories", type=click.Path(exists=True), required=True, nargs=-1)
+@click.option("--compression", default="infer", type=click.Choice(["none", "infer", "pigz", "gzip", "bz2", "xz"]), help="Compress the output JSONL file (default: infer; pigz for parallel gzip).")
+@click.option("--processes", default=64, help="Number of processes to use for pigz compression (default: 64).")
+@click.option("--overwrite", is_flag=True, help="Overwrite existing JSONL files.")
+@click.option("--yes", is_flag=True, help="Assume yes to all prompts. Use with caution as it will remove files without confirmation.")
+@click.option("--batch-size", default=2**16, help="Batch size for loading MDS directories (default: 65536).")
+def jsonl(output_file, mds_directories, compression, processes, overwrite, yes, batch_size):
+    check_overwrite(output_file, overwrite, yes)
+    if not mds_directories:
+        raise click.BadArgumentUsage("No MDS files provided.")
+    ds = load_mds_directories(mds_directories, batch_size=batch_size)
+    compression = determine_compression(output_file, compression)
+    with open_jsonl(output_file, mode="wb", compression=compression, processes=processes) as f:
+        for item in tqdm(ds, desc="Writing to JSONL", unit="line"):
+            f.write(f"{json.dumps(item)}\n".encode("utf-8"))
@@ -0,0 +1,26 @@
+import click
+import pyarrow as pa
+import pyarrow.parquet as pq
+from tqdm import tqdm
+
+from ....utils import batch_iterable, check_overwrite, load_mds_directories
+
+@click.command()
+@click.argument("output_file", type=click.Path(exists=False), required=True)
+@click.argument("mds_directories", type=click.Path(exists=True), required=True, nargs=-1)
+@click.option("--compression", default="snappy", type=click.Choice(["snappy", "gzip", "zstd"]), help="Compress the Parquet file (default: snappy).")
+@click.option("--overwrite", is_flag=True, help="Overwrite existing Parquet files.")
+@click.option("--yes", is_flag=True, help="Assume yes to all prompts. Use with caution as it will remove files without confirmation.")
+@click.option("--batch-size", default=2**16, help="Batch size for loading MDS directories and writing Parquet files (default: 65536).")
+def parquet(output_file, mds_directories, compression, overwrite, yes, batch_size):
+    check_overwrite(output_file, overwrite, yes)
+    if not mds_directories:
+        raise click.BadArgumentUsage("No MDS files provided.")
+    ds = load_mds_directories(mds_directories, batch_size=batch_size)
+    writer = None
+    for batch in tqdm(batch_iterable(ds, batch_size), desc="Writing to Parquet", unit="batch", total=(len(ds)+batch_size-1) // batch_size):
+        table = pa.Table.from_pylist(batch)
+        if writer is None:
+            writer = pq.ParquetWriter(output_file, table.schema, compression=compression)
+        writer.write_table(table)
+    writer.close()
@@ -1,4 +1,3 @@
-#!/usr/bin/env python
 import click
 from mltiming import timing

@@ -1,4 +1,3 @@
-#!/usr/bin/env python
 import click
 import json
 import os
@@ -0,0 +1,60 @@
+import subprocess
+
+__all__ = ["pigz_open"]
+
+def pigz_open(path, mode="rt", processes=64, encoding=None):
+    return PigzFile(path, mode=mode, processes=processes, encoding=encoding)
+
+class PigzFile(object):
+    """A wrapper for pigz to handle gzip compression and decompression."""
+    def __init__(self, path, mode="rt", processes=4, encoding="utf-8"):
+        assert mode in ("rt", "wt", "rb", "wb")
+        self.path = path
+        self.is_read = mode[0] == "r"
+        self.is_text = mode[1] == "t"
+        self.processes = processes
+        self.encoding = encoding if self.is_text else None
+        self._process = None
+        self._fw = None
+        args = ["pigz", "-p", str(self.processes), "-c"]
+        if self.is_read:
+            args.extend(("-d", self.path))
+            self._process = subprocess.Popen(args, stdout=subprocess.PIPE, encoding=self.encoding, text=self.is_text)
+        else:
+            self._fw = open(self.path, "w+")
+            self._process = subprocess.Popen(args, stdout=self._fw, stdin=subprocess.PIPE, encoding=self.encoding, text=self.is_text)
+
+    def __iter__(self):
+        assert self.is_read
+        for line in self._process.stdout:
+            assert isinstance(line, str) if self.is_text else isinstance(line, bytes)
+            yield line
+        self._process.wait()
+        assert self._process.returncode == 0
+        self._process.stdout.close()
+        self._process = None
+
+    def write(self, line):
+        assert not self.is_read
+        assert self._fw is not None
+        assert isinstance(line, str) if self.is_text else isinstance(line, bytes)
+        self._process.stdin.write(line)
+
+    def close(self):
+        if self._process:
+            if self.is_read:
+                self._process.kill()
+                self._process.stdout.close()
+                self._process = None
+            else:
+                self._process.stdin.close()
+                self._process.wait()
+                self._process = None
+                self._fw.close()
+                self._fw = None
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.close()
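For illustration (not part of the diff), a minimal sketch of how the pigz_open wrapper added above might be used. It assumes the pigz binary is installed and on PATH, and that the module is importable as mldataforge.pigz; both are assumptions, not something the diff confirms.

    from mldataforge.pigz import pigz_open  # import path assumed

    # Write a gzip-compressed text file; data is piped through `pigz -p 8 -c`.
    with pigz_open("example.jsonl.gz", mode="wt", processes=8) as f:
        f.write('{"text": "hello"}\n')

    # Stream it back line by line; reading pipes through `pigz -d -c`.
    with pigz_open("example.jsonl.gz", mode="rt", processes=8) as f:
        for line in f:
            print(line, end="")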
@@ -7,10 +7,12 @@ import lzma
 from mltiming import timing
 import os
 import shutil
-import subprocess
+from streaming import StreamingDataset
 import tempfile
 from tqdm import tqdm

+from .pigz import pigz_open
+
 __all__ = [
     "check_overwrite",
     "create_temp_file",
@@ -18,71 +20,22 @@ __all__ = [
     "infer_mds_encoding",
     "infer_compression",
     "load_parquet_files",
+    "load_mds_directories",
     "open_jsonl",
     "pigz_available",
     "pigz_compress",
     "use_pigz",
 ]

-class PigzFile(object):
-    """A wrapper for pigz to handle gzip compression and decompression."""
-    def __init__(self, path, mode="rt", processes=4, encoding=None):
-        if mode not in ("rt", "wt", "rb", "wb"):
-            raise ValueError("Mode must be one of rt, wt, rb, or wb.")
-        self.path = path
-        self.mode = mode
-        self.processes = processes
-        self.encoding = "latin1" if mode[1] == "b" else ("utf-8" if encoding is None else encoding)
-        self._process = None
-        self._fw = None
-        if self.mode[0] == "r":
-            args = ["pigz", "-d", "-c", "-p", str(self.processes), self.path]
-            self._process = subprocess.Popen(
-                args,
-                stdout=subprocess.PIPE,
-                encoding=encoding,
-            )
-        elif self.mode[0] == "w":
-            args = ["pigz", "-p", str(self.processes), "-c"]
-            self._fw = open(self.path, "w+")
-            self._process = subprocess.Popen(
-                args,
-                stdout=self._fw,
-                stdin=subprocess.PIPE,
-                encoding=encoding,
-            )
-
-    def __iter__(self):
-        assert self.mode[0] == "r"
-        for line in self._process.stdout:
-            yield line
-        self._process.wait()
-        assert self._process.returncode == 0
-        self._process.stdout.close()
-        self._process = None
-
-    def write(self, line):
-        assert self.mode[0] == "w"
-        self._process.stdin.write(line if self.mode[1] == "t" else line.encode(self.encoding))
-
-    def close(self):
-        if self._process:
-            if self.mode[0] == "r":
-                self._process.kill()
-                self._process.stdout.close()
-                self._process = None
-            elif self.mode[1] == "w":
-                self._process.stdin.close()
-                self._process.wait()
-                self._process = None
-                self._fw.close()
-                self._fw = None
-
-    def __enter__(self):
-        return self
-
-    def __exit__(self, exc_type, exc_value, traceback):
-        self.close()
+def batch_iterable(iterable, batch_size):
+    batch = []
+    for item in iterable:
+        batch.append(item)
+        if len(batch) == batch_size:
+            yield batch
+            batch.clear()
+    if batch:
+        yield batch

 def check_overwrite(output_path, overwrite, yes):
     if os.path.exists(output_path):
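As an aside (not part of the diff): the new batch_iterable helper yields the same list object and clears it after each yield, so each batch should be consumed before the loop advances. A minimal usage sketch, assuming the module is importable as mldataforge.utils:

    from mldataforge.utils import batch_iterable  # import path assumed

    # Consume each batch immediately; the yielded list is reused between yields.
    for batch in batch_iterable(range(10), 4):
        print(batch)  # [0, 1, 2, 3], then [4, 5, 6, 7], then [8, 9]
    # If a batch must outlive its loop iteration, copy it first, e.g. list(batch).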
@@ -162,6 +115,27 @@ def infer_compression(file_path):
        return 'zstd'
    return None

+def load_mds_directories(mds_directories, split='.', batch_size=2**16):
+    dss = []
+    for mds_directory in tqdm(mds_directories, desc="Loading MDS directories", unit="directory"):
+        ds = StreamingDataset(
+            local=mds_directory,
+            remote=None,
+            split=split,
+            shuffle=False,
+            allow_unsafe_types=True,
+            batch_size=batch_size,
+            download_retry=1,
+            validate_hash=False,
+        )
+        dss.append(ds)
+    if len(dss) == 1:
+        ds = dss[0]
+    else:
+        with timing(message=f"Concatenating {len(dss)} datasets"):
+            ds = concatenate_datasets(dsets=dss)
+    return ds
+
 def load_parquet_files(parquet_files):
     dss = []
     for parquet_file in tqdm(parquet_files, desc="Loading parquet files", unit="file"):
@@ -174,11 +148,13 @@ def load_parquet_files(parquet_files):
            ds = concatenate_datasets(dsets=dss)
    return ds

-def open_jsonl(file_path, mode="rt", compression="infer"):
+def open_jsonl(file_path, mode="rt", compression="infer", processes=64):
     """Open a JSONL file, handling gzip compression if necessary."""
     compression = determine_compression(file_path, compression)
-    if compression in ("gzip", "pigz"):
+    if compression == "gzip":
         return gzip.open(file_path, mode)
+    if compression == "pigz":
+        return pigz_open(file_path, mode, processes=processes) if mode[0] == "w" else gzip.open(file_path, mode)
     if compression == "bz2":
         return bz2.open(file_path, mode)
     if compression == "xz":
@@ -195,7 +171,7 @@ def pigz_compress(input_file, output_file, processes=64, buf_size=2**24, keep=Fa
     """Compress a file using pigz."""
     size = os.stat(input_file).st_size
     num_blocks = (size+buf_size-1) // buf_size
-    with open(input_file, "rt", encoding="latin1") as f_in, PigzFile(output_file, "wb", processes=processes) as f_out:
+    with open(input_file, "rb") as f_in, pigz_open(output_file, "wb", processes=processes) as f_out:
         for _ in tqdm(range(num_blocks), desc="Compressing with pigz", unit="block", disable=quiet):
             buf = f_in.read(buf_size)
             assert buf
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"

 [project]
 name = "mldataforge"
-version = "0.0.3"
+version = "0.0.5"
 authors = [
   { name = "Peter Schneider-Kamp" }
 ]