mldataforge 0.0.4__tar.gz → 0.0.5__tar.gz

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: mldataforge
-Version: 0.0.4
+Version: 0.0.5
 Summary: swiss army knife of scripts for transforming and processing datasets for machine learning.
 Project-URL: Homepage, https://github.com/schneiderkamplab/mldataforge
 Project-URL: Bug Tracker, https://github.com/schneiderkamplab/mldataforge/issues
@@ -1,6 +1,7 @@
 import click
 
 from .mds import mds
+from .parquet import parquet
 
 __all__ = ["jsonl"]
 
@@ -9,3 +10,4 @@ def jsonl():
     pass
 
 jsonl.add_command(mds)
+jsonl.add_command(parquet)
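For context: both __init__ hunks in this release use click's standard group/subcommand pattern, in which each converter is a @click.command attached to a @click.group via add_command. A minimal self-contained sketch of the pattern (the command body is a placeholder, not the package's code):

import click

@click.group()
def jsonl():
    pass

@click.command()
def parquet():
    click.echo("would convert JSONL to Parquet here")  # placeholder body

# add_command makes the command reachable as: <cli> jsonl parquet
jsonl.add_command(parquet)

if __name__ == "__main__":
    jsonl()

The mds package's __init__ further down registers its parquet exporter the same way.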
@@ -0,0 +1,39 @@
+import click
+import json
+import pyarrow as pa
+import pyarrow.parquet as pq
+from tqdm import tqdm
+
+from ....utils import batch_iterable, check_overwrite, open_jsonl
+
+def _iterate(jsonl_files):
+    lines = 0
+    for jsonl_file in tqdm(jsonl_files, desc="Processing JSONL files", unit="file"):
+        with open_jsonl(jsonl_file, compression="infer") as f:
+            for line_num, line in enumerate(f, start=1):
+                try:
+                    item = json.loads(line)
+                    yield item
+                except json.JSONDecodeError as e:
+                    print(f"Skipping line {line_num} in {jsonl_file} due to JSON error: {e}")
+                lines += 1
+    print(f"Processed {lines} lines from {len(jsonl_files)} files")
+
+@click.command()
+@click.argument('output_file', type=click.Path(exists=False))
+@click.argument('jsonl_files', nargs=-1, type=click.Path(exists=True))
+@click.option("--compression", default="snappy", type=click.Choice(["snappy", "gzip", "zstd"]), help="Compress the Parquet file (default: snappy).")
+@click.option("--overwrite", is_flag=True, help="Overwrite existing Parquet file.")
+@click.option("--yes", is_flag=True, help="Assume yes to all prompts. Use with caution as it will remove files without confirmation.")
+@click.option("--batch-size", default=2**16, help="Batch size for reading JSONL files and writing the Parquet file (default: 65536).")
+def parquet(output_file, jsonl_files, compression, overwrite, yes, batch_size):
+    check_overwrite(output_file, overwrite, yes)
+    if not jsonl_files:
+        raise click.BadArgumentUsage("No JSONL files provided.")
+    writer = None
+    for batch in batch_iterable(_iterate(jsonl_files), batch_size):
+        table = pa.Table.from_pylist(batch)
+        if writer is None:
+            writer = pq.ParquetWriter(output_file, table.schema, compression=compression)
+        writer.write_table(table)
+    writer.close()
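Both new converters use the same incremental write pattern: pyarrow's ParquetWriter is created lazily from the schema of the first batch (inferred by Table.from_pylist), and each subsequent batch is appended as its own row group, so the whole input never has to fit in memory. A minimal standalone sketch of that pattern, with made-up records and an illustrative file name:

import pyarrow as pa
import pyarrow.parquet as pq

records = [{"id": i, "text": f"row {i}"} for i in range(100)]  # stand-in data
batch_size = 32

writer = None
for start in range(0, len(records), batch_size):
    table = pa.Table.from_pylist(records[start:start + batch_size])
    if writer is None:
        # the schema is fixed by the first batch; later batches must match it
        writer = pq.ParquetWriter("example.parquet", table.schema, compression="snappy")
    writer.write_table(table)  # appends one row group per call
if writer is not None:  # guard: no batches means no writer was ever created
    writer.close()

The sketch guards close() for the empty-input case; in the hunk above, writer stays None if the inputs yield no parseable lines, so writer.close() would raise AttributeError there. Schema inference from the first batch also assumes all records share a consistent set of keys and types.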
@@ -1,6 +1,7 @@
 import click
 
 from .jsonl import jsonl
+from .parquet import parquet
 
 __all__ = ["mds"]
 
@@ -9,3 +10,4 @@ def mds():
     pass
 
 mds.add_command(jsonl)
+mds.add_command(parquet)
@@ -0,0 +1,26 @@
+import click
+import pyarrow as pa
+import pyarrow.parquet as pq
+from tqdm import tqdm
+
+from ....utils import batch_iterable, check_overwrite, load_mds_directories
+
+@click.command()
+@click.argument("output_file", type=click.Path(exists=False), required=True)
+@click.argument("mds_directories", type=click.Path(exists=True), required=True, nargs=-1)
+@click.option("--compression", default="snappy", type=click.Choice(["snappy", "gzip", "zstd"]), help="Compress the Parquet file (default: snappy).")
+@click.option("--overwrite", is_flag=True, help="Overwrite existing Parquet file.")
+@click.option("--yes", is_flag=True, help="Assume yes to all prompts. Use with caution as it will remove files without confirmation.")
+@click.option("--batch-size", default=2**16, help="Batch size for loading MDS directories and writing Parquet files (default: 65536).")
+def parquet(output_file, mds_directories, compression, overwrite, yes, batch_size):
+    check_overwrite(output_file, overwrite, yes)
+    if not mds_directories:
+        raise click.BadArgumentUsage("No MDS directories provided.")
+    ds = load_mds_directories(mds_directories, batch_size=batch_size)
+    writer = None
+    for batch in tqdm(batch_iterable(ds, batch_size), desc="Writing to Parquet", unit="batch", total=(len(ds)+batch_size-1) // batch_size):
+        table = pa.Table.from_pylist(batch)
+        if writer is None:
+            writer = pq.ParquetWriter(output_file, table.schema, compression=compression)
+        writer.write_table(table)
+    writer.close()
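The tqdm total here is just the ceiling division of the dataset length by the batch size. To sanity-check the output of either converter, the file can be read back with pyarrow; the row groups correspond one-to-one with the write_table calls above (the path is illustrative):

import pyarrow.parquet as pq

pf = pq.ParquetFile("example.parquet")
print(pf.schema_arrow)              # schema inferred from the first batch
print(pf.metadata.num_rows)         # total rows across all row groups
print(pf.metadata.num_row_groups)   # one row group per write_table call

table = pq.read_table("example.parquet")  # loads the whole file into memory
print(table.num_rows)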
@@ -27,6 +27,16 @@ __all__ = [
     "use_pigz",
 ]
 
+def batch_iterable(iterable, batch_size):
+    batch = []
+    for item in iterable:
+        batch.append(item)
+        if len(batch) == batch_size:
+            yield batch
+            batch.clear()
+    if batch:
+        yield batch
+
 def check_overwrite(output_path, overwrite, yes):
     if os.path.exists(output_path):
         if os.path.isfile(output_path):
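One subtlety in batch_iterable: it yields the same list object every time and clears it in place after each yield, so a consumer must finish with (or copy) a batch before advancing the iterator. Both converters do, since Table.from_pylist copies the data immediately. A small demonstration of the aliasing, using the function exactly as defined above:

def batch_iterable(iterable, batch_size):
    batch = []
    for item in iterable:
        batch.append(item)
        if len(batch) == batch_size:
            yield batch
            batch.clear()   # reuses the same list object for the next batch
    if batch:
        yield batch

safe = [list(b) for b in batch_iterable(range(7), 3)]  # copy each batch
print(safe)    # [[0, 1, 2], [3, 4, 5], [6]]

stale = list(batch_iterable(range(7), 3))              # keeps live references
print(stale)   # [[6], [6], [6]] -- earlier batches were cleared in place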
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "mldataforge"
-version = "0.0.4"
+version = "0.0.5"
 authors = [
   { name = "Peter Schneider-Kamp" }
 ]