mldataforge 0.1.0.tar.gz → 0.1.1.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: mldataforge
-Version: 0.1.0
+Version: 0.1.1
 Summary: swiss army knife of scripts for transforming and processing datasets for machine learning.
 Project-URL: Homepage, https://github.com/schneiderkamplab/mldataforge
 Project-URL: Bug Tracker, https://github.com/schneiderkamplab/mldataforge/issues
@@ -17,10 +17,11 @@ def mds():
 @overwrite_option()
 @yes_option()
 @batch_size_option()
-def jsonl(output_file, mds_directories, compression, processes, overwrite, yes, batch_size):
+@no_bulk_option()
+def jsonl(output_file, mds_directories, compression, processes, overwrite, yes, batch_size, no_bulk):
     check_arguments(output_file, overwrite, yes, mds_directories)
     save_jsonl(
-        load_mds_directories(mds_directories, batch_size=batch_size),
+        load_mds_directories(mds_directories, batch_size=batch_size, bulk=not no_bulk),
         output_file,
         compression=compression,
         processes=processes,
@@ -28,15 +29,16 @@ def jsonl(output_file, mds_directories, compression, processes, overwrite, yes,
 
 @mds.command()
 @click.argument("output_file", type=click.Path(exists=False), required=True)
-@click.argument("parquet_files", type=click.Path(exists=True), required=True, nargs=-1)
+@click.argument("mds_directories", type=click.Path(exists=True), required=True, nargs=-1)
 @compression_option("snappy", ["snappy", "gzip", "zstd"])
 @overwrite_option()
 @yes_option()
 @batch_size_option()
-def parquet(output_file, parquet_files, compression, overwrite, yes, batch_size):
-    check_arguments(output_file, overwrite, yes, parquet_files)
+@no_bulk_option()
+def parquet(output_file, mds_directories, compression, overwrite, yes, batch_size, no_bulk):
+    check_arguments(output_file, overwrite, yes, mds_directories)
     save_parquet(
-        load_mds_directories(parquet_files, batch_size=batch_size),
+        load_mds_directories(mds_directories, batch_size=batch_size, bulk=not no_bulk),
         output_file,
         compression=compression,
         batch_size=batch_size,
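
Both converter commands gain a --no-bulk flag, and its value is inverted (bulk=not no_bulk) before reaching load_mds_directories, so the new bulk reader remains the default behavior. A minimal sketch of exercising the flag with click's test runner; the import path below is an assumption, since the diff does not show which module defines the mds group:

from click.testing import CliRunner
from mldataforge.commands.mds import mds  # assumed import path, not shown in this diff

runner = CliRunner()
# Default conversion uses the bulk reader ...
result = runner.invoke(mds, ["jsonl", "out.jsonl.gz", "data/mds-train"])  # placeholder paths
# ... while --no-bulk falls back to the StreamingDataset-based loader.
result = runner.invoke(mds, ["jsonl", "--no-bulk", "out.jsonl.gz", "data/mds-train"])
print(result.exit_code, result.output)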
@@ -0,0 +1,97 @@
+import gzip
+import json
+from mltiming import timing
+import numpy as np
+import os
+from streaming.base.format.mds.encodings import mds_decode
+from typing import Any, Optional, Generator
+
+class MDSBulkReader:
+    def __init__(
+        self,
+        dirnames: list[str],
+        split: Optional[str],
+    ) -> None:
+        self.shards = []
+        self.samples = 0
+        for dirname in dirnames:
+            if split is not None:
+                dirname = os.path.join(dirname, split)
+            index = json.load(open(os.path.join(dirname, "index.json"), 'rt'))
+            for shard in index["shards"]:
+                basename = shard['raw_data']['basename'] if shard['zip_data'] is None else shard['zip_data']['basename']
+                filename = os.path.join(dirname, basename)
+                self.shards.append({
+                    "filename": filename,
+                    "compression": shard['compression'],
+                })
+                self.samples += shard['samples']
+
+    def __len__(self) -> int:
+        return self.samples
+
+    def __iter__(self) -> Generator[dict[str, Any], None, None]:
+        for shard in self.shards:
+            with MDSShardReader(**shard) as reader:
+                for sample in reader:
+                    yield sample
+
+class MDSShardReader:
+    def __init__(
+        self,
+        filename: str,
+        compression: Optional[str],
+    ) -> None:
+        if compression is None:
+            _open = open
+        elif compression == 'gz':
+            _open = gzip.open
+        else:
+            raise ValueError(f'Unsupported compression type: {compression}. Supported types: None, gzip.')
+        self.fp = _open(filename, "rb")
+        self.samples = np.frombuffer(self.fp.read(4), np.uint32)[0]
+        self.index = np.frombuffer(self.fp.read((1+self.samples)*4), np.uint32)
+        info = json.loads(self.fp.read(self.index[0]-self.fp.tell()))
+        self.column_encodings = info["column_encodings"]
+        self.column_names = info["column_names"]
+        self.column_sizes = info["column_sizes"]
+        assert self.fp.tell() == self.index[0]
+
+    def decode_sample(self, data: bytes) -> dict[str, Any]:
+        sizes = []
+        idx = 0
+        for key, size in zip(self.column_names, self.column_sizes):
+            if size:
+                sizes.append(size)
+            else:
+                size, = np.frombuffer(data[idx:idx + 4], np.uint32)
+                sizes.append(size)
+                idx += 4
+        sample = {}
+        for key, encoding, size in zip(self.column_names, self.column_encodings, sizes):
+            value = data[idx:idx + size]
+            sample[key] = mds_decode(encoding, value)
+            idx += size
+        return sample
+
+    def get_sample_data(self, idx: int) -> bytes:
+        begin, end = self.index[idx:idx+2]
+        assert self.fp.tell() == begin
+        data = self.fp.read(end - begin)
+        assert self.fp.tell() == end
+        assert data
+        return data
+
+    def get_item(self, idx: int) -> dict[str, Any]:
+        data = self.get_sample_data(idx)
+        return self.decode_sample(data)
+
+    def __iter__(self) -> Generator[dict[str, Any], None, None]:
+        for i in range(self.samples):
+            yield self.get_item(i)
+
+    def __enter__(self) -> "MDSShardReader":
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback) -> None:
+        self.fp.close()
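
The new module implements a sequential reader for MDS shards: MDSBulkReader collects the shard files listed in each directory's index.json, and MDSShardReader reads a shard's sample count, offset index, and column metadata before decoding one sample at a time with mds_decode. A minimal usage sketch, assuming the file lands in the package as mldataforge/mds.py (consistent with the `from .mds import MDSBulkReader` import further down); the directory names are placeholders:

from mldataforge.mds import MDSBulkReader  # assumed module path

reader = MDSBulkReader(["data/mds-train", "data/mds-val"], split=None)  # placeholder directories
print(f"{len(reader)} samples across {len(reader.shards)} shards")
for sample in reader:     # yields one dict per sample, shard by shard
    print(sample.keys())  # e.g. inspect the column names
    break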
@@ -29,6 +29,16 @@ def buf_size_option(default=2**24):
         help=f"Buffer size for pigz compression (default: {default}).",
     )
 
+def no_bulk_option():
+    """
+    Option for specifying whether to use a custom space and time-efficient bulk reader (only gzip and no compression).
+    """
+    return click.option(
+        "--no-bulk",
+        is_flag=True,
+        help="Use a custom space and time-efficient bulk reader (only gzip and no compression).",
+    )
+
 def compression_option(default, choices):
     """
     Option for specifying the compression type.
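
no_bulk_option() follows the same factory pattern as the surrounding option helpers: it returns a click.option decorator for a boolean flag that defaults to False, so the bulk reader stays enabled unless --no-bulk is passed. A standalone sketch (not part of the package) of how such a flag factory composes onto a command:

import click

def no_bulk_option():
    # Same shape as the factory above: a flag that is False unless --no-bulk is given.
    return click.option("--no-bulk", is_flag=True, help="Disable the bulk reader.")

@click.command()
@no_bulk_option()
def demo(no_bulk):
    click.echo(f"bulk={not no_bulk}")  # mirrors the bulk=not no_bulk call sites above

if __name__ == "__main__":
    demo()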
@@ -12,6 +12,7 @@ import shutil
 from streaming import MDSWriter, StreamingDataset
 from tqdm import tqdm
 
+from .mds import MDSBulkReader
 from .pigz import pigz_open
 
 __all__ = [
@@ -98,7 +99,9 @@ def _infer_compression(file_path):
         return 'zstd'
     return None
 
-def load_mds_directories(mds_directories, split='.', batch_size=2**16):
+def load_mds_directories(mds_directories, split='.', batch_size=2**16, bulk=True):
+    if bulk:
+        return MDSBulkReader(mds_directories, split=split)
     dss = []
     for mds_directory in mds_directories:
         ds = StreamingDataset(
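
load_mds_directories now dispatches on the new bulk parameter: with bulk=True (the default) it returns an MDSBulkReader over the raw shard files, and only bulk=False takes the existing StreamingDataset path. A hedged sketch of the two call paths; the importing module's name is an assumption and the directories are placeholders:

from mldataforge.utils import load_mds_directories  # assumed module path, not shown in this diff

dirs = ["data/mds-a", "data/mds-b"]                 # placeholder directories
bulk_ds = load_mds_directories(dirs)                # MDSBulkReader: sequential, gzip or uncompressed shards only
stream_ds = load_mds_directories(dirs, bulk=False)  # StreamingDataset-based loading, as before
for sample in bulk_ds:                              # the bulk reader is a plain iterable of sample dicts
    break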
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "mldataforge"
-version = "0.1.0"
+version = "0.1.1"
 authors = [
   { name = "Peter Schneider-Kamp" }
 ]