mldataforge 0.1.5__tar.gz → 0.1.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: mldataforge
- Version: 0.1.5
+ Version: 0.1.6
  Summary: swiss army knife of scripts for transforming and processing datasets for machine learning.
  Project-URL: Homepage, https://github.com/schneiderkamplab/mldataforge
  Project-URL: Bug Tracker, https://github.com/schneiderkamplab/mldataforge/issues
@@ -53,7 +53,7 @@ def join_mds(output_dir, mds_directories, compression, processes, overwrite, yes
  compression=compression,
  buf_size=buf_size,
  shard_size=shard_size,
- pigz=use_pigz(compression, no_pigz)
+ pigz=use_pigz(compression, no_pigz),
  )

  @join.command()
@@ -30,7 +30,7 @@ JSONL_COMPRESSIONS = dict(
  )
  MDS_COMPRESSIONS = dict(
  default=None,
- choices=["none", "brotli", "bz2", "gzip", "pigz", "snappy", "zstd"],
+ choices=["none", "brotli", "bz2", "gzip", "pigz", "snappy", "zstd", "sample::brotli", "sample::bz2", "sample::gzip", "sample::snappy", "sample::zstd"],
  )
  PARQUET_COMPRESSIONS = dict(
  default="snappy",
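Note: the new sample:: choices select per-sample compression (each sample is compressed individually) instead of per-shard compression. A minimal sketch of how such a value splits into the two modes, mirroring the removeprefix("sample::") handling in the reader and writer added further down this diff (the helper name is illustrative, not part of the package):

def split_compression(value):
    # "sample::zstd" -> (None, "zstd"): compress each sample, leave the shard file uncompressed
    # "zstd"         -> ("zstd", None): compress the whole shard file
    if value is not None and value.startswith("sample::"):
        return None, value.removeprefix("sample::")
    return value, None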
@@ -55,6 +55,10 @@ def determine_compression(fmt, file_path, compression="infer", no_pigz=False):
  return "gz"
  if compression == "brotli":
  return "br"
+ if compression == "sample::gzip":
+ return "gz"
+ if compression == "sample::brotli":
+ return "br"
  return compression
  if fmt == "parquet":
  return compression
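For file naming, the two added branches map the sample-level variants to the same suffixes as their shard-level codecs; anything else falls through via the existing return compression. A rough illustration of the resulting mapping (hypothetical helper, not part of the package):

_EXTENSION = {"gzip": "gz", "sample::gzip": "gz", "brotli": "br", "sample::brotli": "br"}

def extension_for(compression):
    # Values not listed here are returned unchanged, as in the function above.
    return _EXTENSION.get(compression, compression)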
@@ -0,0 +1,317 @@
+ import json
+ import numpy as np
+ import os
+ import shutil
+ from streaming.base.compression import compress, decompress, get_compression_extension, is_compression
+ from streaming.base.format.index import get_index_basename
+ from streaming.base.format.mds.encodings import mds_decode, mds_encode, is_mds_encoding, get_mds_encodings, get_mds_encoded_size
+ from streaming.base.hashing import get_hash, is_hash
+ from streaming.base.util import bytes_to_int
+ from typing import Any, Optional, Generator, Self, Union
+
+ from .utils import open_compression
+
+ class MDSBulkReader:
+ def __init__(
+ self,
+ dirnames: list[str],
+ split: Optional[str],
+ ) -> None:
+ self.shards = []
+ self.samples = 0
+ for dirname in dirnames:
+ if split is not None:
+ dirname = os.path.join(dirname, split)
+ index = json.load(open(os.path.join(dirname, "index.json"), 'rt'))
+ for shard in index["shards"]:
+ basename = shard['raw_data']['basename'] if shard['zip_data'] is None else shard['zip_data']['basename']
+ filename = os.path.join(dirname, basename)
+ self.shards.append({
+ "filename": filename,
+ "compression": shard['compression'],
+ })
+ self.samples += shard['samples']
+
+ def __len__(self) -> int:
+ return self.samples
+
+ def __iter__(self) -> Generator[dict[str, Any], None, None]:
+ for shard in self.shards:
+ with MDSShardReader(**shard) as reader:
+ for sample in reader:
+ yield sample
+
+ class MDSShardReader:
+ def __init__(
+ self,
+ filename: str,
+ compression: Optional[str],
+ ) -> None:
+ self.sample_compression = None
+ if compression is not None and compression.startswith("sample::"):
+ compression, self.sample_compression = None, compression.removeprefix("sample::")
+ self.fp = open_compression(filename, "rb", compression=compression)
+ self.samples = np.frombuffer(self.fp.read(4), np.uint32)[0]
+ self.index = np.frombuffer(self.fp.read((1+self.samples)*4), np.uint32)
+ info = json.loads(self.fp.read(self.index[0]-self.fp.tell()))
+ self.column_encodings = info["column_encodings"]
+ self.column_names = info["column_names"]
+ self.column_sizes = info["column_sizes"]
+ assert self.fp.tell() == self.index[0]
+
+ def decode_sample(self, data: bytes) -> dict[str, Any]:
+ sizes = []
+ idx = 0
+ for key, size in zip(self.column_names, self.column_sizes):
+ if size:
+ sizes.append(size)
+ else:
+ size, = np.frombuffer(data[idx:idx + 4], np.uint32)
+ sizes.append(size)
+ idx += 4
+ sample = {}
+ for key, encoding, size in zip(self.column_names, self.column_encodings, sizes):
+ value = data[idx:idx + size]
+ sample[key] = mds_decode(encoding, value)
+ idx += size
+ return sample
+
+ def get_sample_data(self, idx: int) -> bytes:
+ begin, end = self.index[idx:idx+2]
+ assert self.fp.tell() == begin
+ data = self.fp.read(end - begin)
+ assert self.fp.tell() == end
+ assert data
+ return data
+
+ def get_item(self, idx: int) -> dict[str, Any]:
+ data = self.get_sample_data(idx)
+ if self.sample_compression is not None:
+ data = decompress(self.sample_compression, data)
+ return self.decode_sample(data)
+
+ def __iter__(self) -> Generator[dict[str, Any], None, None]:
+ for i in range(self.samples):
+ yield self.get_item(i)
+
+ def __enter__(self) -> "MDSShardReader":
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback) -> None:
+ self.fp.close()
+
+
+ class MDSWriter:
+
+ format = 'mds'
+ extra_bytes_per_sample = 4
+
+ def __init__(self,
+ *,
+ columns: dict[str, str],
+ out: Union[str, tuple[str, str]],
+ compression: Optional[str] = None,
+ hashes: Optional[list[str]] = None,
+ size_limit: Optional[Union[int, str]] = 1 << 26,
+ **kwargs: Any) -> None:
+ compression = compression or None
+ sample_compression = None
+ if compression is not None and compression.startswith("sample::"):
+ compression, sample_compression = None, compression.removeprefix("sample::")
+ if compression:
+ if not is_compression(compression):
+ raise ValueError(f'Invalid compression: {compression}.')
+ if sample_compression:
+ if not is_compression(sample_compression):
+ raise ValueError(f'Invalid sample compression: {sample_compression}.')
+ hashes = hashes or []
+ if list(hashes) != sorted(hashes):
+ raise ValueError('Hashes must be unique and in sorted order.')
+ for algo in hashes:
+ if not is_hash(algo):
+ raise ValueError(f'Invalid hash: {algo}.')
+
+ size_limit_value = None
+ if size_limit:
+ size_limit_value = bytes_to_int(size_limit)
+ if size_limit_value < 0:
+ raise ValueError(f'`size_limit` must be greater than zero, instead, ' +
+ f'found as {size_limit_value}.')
+ if size_limit_value >= 2**32:
+ raise ValueError(f'`size_limit` must be less than 2**32, instead, ' +
+ f'found as {size_limit_value}. This is because sample ' +
+ f'byte offsets are stored with uint32.')
+
+ # Validate keyword arguments
+ invalid_kwargs = [
+ arg for arg in kwargs.keys()
+ if arg not in ('progress_bar', 'exist_ok')
+ ]
+ if invalid_kwargs:
+ raise ValueError(f'Invalid Writer argument(s): {invalid_kwargs} ')
+
+ self.compression = compression
+ self.sample_compression = sample_compression
+ self.hashes = hashes
+ self.size_limit = size_limit_value
+ self.new_samples: list[bytes]
+ self.new_shard_size: int
+
+ self.shards = []
+
+ # Remove local directory if requested prior to creating writer
+ self.local = os.path.expanduser(out)
+ if os.path.exists(self.local) and len(os.listdir(self.local)) != 0:
+ if kwargs.get('exist_ok', False):
+ raise FileExistsError(f'Directory is not empty: {self.local}')
+ shutil.rmtree(self.local)
+ os.makedirs(self.local, exist_ok=True)
+
+ self.columns = columns
+ self.column_names = []
+ self.column_encodings = []
+ self.column_sizes = []
+ for name in sorted(columns):
+ encoding = columns[name]
+ if not is_mds_encoding(encoding):
+ raise TypeError(f'MDSWriter passed column `{name}` with encoding `{encoding}` ' +
+ f'is unsupported. Supported encodings are {get_mds_encodings()}')
+ size = get_mds_encoded_size(encoding)
+ self.column_names.append(name)
+ self.column_encodings.append(encoding)
+ self.column_sizes.append(size)
+
+ obj = self.get_config()
+ text = json.dumps(obj, sort_keys=True)
+ self.config_data = text.encode('utf-8')
+ self.extra_bytes_per_shard = 4 + 4 + len(self.config_data)
+ self._reset_cache()
+
+ def encode_sample(self, sample: dict[str, Any]) -> bytes:
+ sizes = []
+ data = []
+ for key, encoding, size in zip(self.column_names, self.column_encodings,
+ self.column_sizes):
+ value = sample[key]
+ datum = mds_encode(encoding, value)
+ if size is None:
+ size = len(datum)
+ sizes.append(size)
+ else:
+ if size != len(datum):
+ raise KeyError(f'Unexpected data size; was this data typed with the correct ' +
+ f'encoding ({encoding})?')
+ data.append(datum)
+ head = np.array(sizes, np.uint32).tobytes()
+ body = b''.join(data)
+ sample_data = head + body
+ if self.sample_compression:
+ sample_data = compress(self.sample_compression, sample_data)
+ return sample_data
+
+ def encode_joint_shard(self) -> bytes:
+ num_samples = np.uint32(len(self.new_samples))
+ sizes = list(map(len, self.new_samples))
+ offsets = np.array([0] + sizes).cumsum().astype(np.uint32)
+ offsets += len(num_samples.tobytes()) + len(offsets.tobytes()) + len(self.config_data)
+ sample_data = b''.join(self.new_samples)
+ return num_samples.tobytes() + offsets.tobytes() + self.config_data + sample_data
+
+ def flush_shard(self) -> None:
+ raw_data_basename, zip_data_basename = self._name_next_shard()
+ raw_data = self.encode_joint_shard()
+ raw_data_info, zip_data_info = self._process_file(raw_data, raw_data_basename,
+ zip_data_basename)
+ obj = {
+ 'samples': len(self.new_samples),
+ 'raw_data': raw_data_info,
+ 'zip_data': zip_data_info
+ }
+ obj.update(self.get_config())
+ self.shards.append(obj)
+
+ def _reset_cache(self) -> None:
+ self.new_samples = []
+ self.new_shard_size = self.extra_bytes_per_shard
+
+ def _name_next_shard(self, extension: Optional[str] = None) -> tuple[str, Optional[str]]:
+ shard = len(self.shards)
+ parts = ['shard', f'{shard:05}', self.format]
+ if extension:
+ parts.append(extension)
+ raw_basename = '.'.join(parts)
+ if self.compression:
+ ext = get_compression_extension(self.compression)
+ parts.append(ext)
+ zip_basename = '.'.join(parts)
+ else:
+ zip_basename = None
+ return raw_basename, zip_basename
+
+ def _hash(self, data: bytes, basename: str) -> dict[str, Any]:
+ hashes = {}
+ for algo in self.hashes:
+ hashes[algo] = get_hash(algo, data)
+ return {'basename': basename, 'bytes': len(data), 'hashes': hashes}
+
+ def _process_file(self, raw_data: bytes, raw_basename: str,
+ zip_basename: Optional[str]) -> tuple[dict, Optional[dict]]:
+ raw_info = self._hash(raw_data, raw_basename)
+ if zip_basename:
+ zip_data = compress(self.compression, raw_data)
+ zip_info = self._hash(zip_data, zip_basename)
+ data = zip_data
+ basename = zip_basename
+ else:
+ zip_info = None
+ data = raw_data
+ basename = raw_basename
+ filename = os.path.join(self.local, basename)
+ with open(filename, 'wb') as out:
+ out.write(data)
+ return raw_info, zip_info
+
+ def get_config(self) -> dict[str, Any]:
+ return {
+ 'version': 2,
+ 'format': self.format,
+ 'compression': self.compression if self.sample_compression is None else f"sample::{self.sample_compression}",
+ 'hashes': self.hashes,
+ 'size_limit': self.size_limit,
+ 'column_names': self.column_names,
+ 'column_encodings': self.column_encodings,
+ 'column_sizes': self.column_sizes,
+ }
+
+ def write(self, sample: dict[str, Any]) -> None:
+ new_sample = self.encode_sample(sample)
+ new_sample_size = len(new_sample) + self.extra_bytes_per_sample
+ if self.size_limit and self.size_limit < self.new_shard_size + new_sample_size:
+ self.flush_shard()
+ self._reset_cache()
+ self.new_samples.append(new_sample)
+ self.new_shard_size += new_sample_size
+
+ def _write_index(self) -> None:
+ if self.new_samples:
+ raise RuntimeError('Internal error: not all samples have been written.')
+ basename = get_index_basename()
+ filename = os.path.join(self.local, basename)
+ obj = {
+ 'version': 2,
+ 'shards': self.shards,
+ }
+ with open(filename, 'w') as out:
+ json.dump(obj, out, sort_keys=True)
+
+ def finish(self) -> None:
+ if self.new_samples:
+ self.flush_shard()
+ self._reset_cache()
+ self._write_index()
+
+ def __enter__(self) -> Self:
+ return self
+
+ def __exit__(self, exc_type, exc, traceback):
+ self.finish()
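Taken together, this new module lets mldataforge both write and bulk-read MDS datasets with per-sample compression. A hedged usage sketch follows; the directory name, column name, and values are made up for illustration, and the import path is inferred from the relative imports in this diff:

from mldataforge.mds import MDSBulkReader, MDSWriter

# Write a tiny dataset where each sample is compressed individually with zstd.
with MDSWriter(out="out_dir", columns={"text": "str"}, compression="sample::zstd") as writer:
    for i in range(10):
        writer.write({"text": f"sample {i}"})

# Stream every sample back, shard by shard, using the bulk reader.
reader = MDSBulkReader(["out_dir"], split=None)
for sample in reader:
    print(sample["text"])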
@@ -36,60 +36,6 @@ class _SnappyWriteWrapper(io.RawIOBase):
  def writable(self):
  return True

-
- # class _SnappyReadWrapper(io.RawIOBase):
- # def __init__(self, fileobj):
- # self.fileobj = fileobj
- # self.buffer = io.BytesIO()
- # self.eof = False
-
- # def _fill_buffer_if_needed(self, min_bytes):
- # self.buffer.seek(0, io.SEEK_END)
- # while not self.eof and self.buffer.tell() < min_bytes:
- # length_bytes = self.fileobj.read(4)
- # if not length_bytes:
- # self.eof = True
- # break
- # if len(length_bytes) < 4:
- # self.eof = True # mark as EOF even if last chunk is malformed
- # break
-
- # try:
- # length = struct.unpack(">I", length_bytes)[0]
- # compressed = self.fileobj.read(length)
- # if len(compressed) < length:
- # self.eof = True
- # break
-
- # decompressed = snappy.decompress(compressed)
- # self.buffer.write(decompressed)
- # except Exception:
- # self.eof = True
- # break
-
- # self.buffer.seek(0)
-
- # def read(self, size=-1):
- # if size == -1:
- # while not self.eof:
- # self._fill_buffer_if_needed(_CHUNK_SIZE)
- # result = self.buffer.read()
- # self.buffer = io.BytesIO()
- # return result
-
- # self._fill_buffer_if_needed(size)
- # data = self.buffer.read(size)
- # rest = self.buffer.read()
- # self.buffer = io.BytesIO()
- # self.buffer.write(rest)
- # return data
-
- # def readable(self):
- # return True
-
- # def close(self):
- # self.fileobj.close()
-
  class _SnappyReadWrapper(io.RawIOBase):
  def __init__(self, fileobj):
  self.fileobj = fileobj
@@ -6,11 +6,11 @@ import pyarrow as pa
  import pyarrow.parquet as pq
  import os
  import shutil
- from streaming import MDSWriter, StreamingDataset
+ from streaming import StreamingDataset
  from tqdm import tqdm

  from .compression import determine_compression, open_compression, pigz_compress
- from .mds import MDSBulkReader
+ from .mds import MDSBulkReader, MDSWriter
  from .pigz import pigz_open

  __all__ = [
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"

  [project]
  name = "mldataforge"
- version = "0.1.5"
+ version = "0.1.6"
  authors = [
  { name = "Peter Schneider-Kamp" }
  ]
@@ -1,95 +0,0 @@
- import gzip
- import json
- from mltiming import timing
- import numpy as np
- import os
- import snappy
- from streaming.base.format.mds.encodings import mds_decode
- from typing import Any, Optional, Generator
-
- from .options import MDS_COMPRESSIONS
- from .utils import open_compression
-
- class MDSBulkReader:
- def __init__(
- self,
- dirnames: list[str],
- split: Optional[str],
- ) -> None:
- self.shards = []
- self.samples = 0
- for dirname in dirnames:
- if split is not None:
- dirname = os.path.join(dirname, split)
- index = json.load(open(os.path.join(dirname, "index.json"), 'rt'))
- for shard in index["shards"]:
- basename = shard['raw_data']['basename'] if shard['zip_data'] is None else shard['zip_data']['basename']
- filename = os.path.join(dirname, basename)
- self.shards.append({
- "filename": filename,
- "compression": shard['compression'],
- })
- self.samples += shard['samples']
-
- def __len__(self) -> int:
- return self.samples
-
- def __iter__(self) -> Generator[dict[str, Any], None, None]:
- for shard in self.shards:
- with MDSShardReader(**shard) as reader:
- for sample in reader:
- yield sample
-
- class MDSShardReader:
- def __init__(
- self,
- filename: str,
- compression: Optional[str],
- ) -> None:
- self.fp = open_compression(filename, "rb", compression=compression)
- self.samples = np.frombuffer(self.fp.read(4), np.uint32)[0]
- self.index = np.frombuffer(self.fp.read((1+self.samples)*4), np.uint32)
- info = json.loads(self.fp.read(self.index[0]-self.fp.tell()))
- self.column_encodings = info["column_encodings"]
- self.column_names = info["column_names"]
- self.column_sizes = info["column_sizes"]
- assert self.fp.tell() == self.index[0]
-
- def decode_sample(self, data: bytes) -> dict[str, Any]:
- sizes = []
- idx = 0
- for key, size in zip(self.column_names, self.column_sizes):
- if size:
- sizes.append(size)
- else:
- size, = np.frombuffer(data[idx:idx + 4], np.uint32)
- sizes.append(size)
- idx += 4
- sample = {}
- for key, encoding, size in zip(self.column_names, self.column_encodings, sizes):
- value = data[idx:idx + size]
- sample[key] = mds_decode(encoding, value)
- idx += size
- return sample
-
- def get_sample_data(self, idx: int) -> bytes:
- begin, end = self.index[idx:idx+2]
- assert self.fp.tell() == begin
- data = self.fp.read(end - begin)
- assert self.fp.tell() == end
- assert data
- return data
-
- def get_item(self, idx: int) -> dict[str, Any]:
- data = self.get_sample_data(idx)
- return self.decode_sample(data)
-
- def __iter__(self) -> Generator[dict[str, Any], None, None]:
- for i in range(self.samples):
- yield self.get_item(i)
-
- def __enter__(self) -> "MDSShardReader":
- return self
-
- def __exit__(self, exc_type, exc_value, traceback) -> None:
- self.fp.close()