mldataforge 0.1.6__tar.gz → 0.1.7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: mldataforge
-Version: 0.1.6
+Version: 0.1.7
 Summary: swiss army knife of scripts for transforming and processing datasets for machine learning.
 Project-URL: Homepage, https://github.com/schneiderkamplab/mldataforge
 Project-URL: Bug Tracker, https://github.com/schneiderkamplab/mldataforge/issues
@@ -1,16 +1,26 @@
+from copy import deepcopy
 import json
 import numpy as np
 import os
 import shutil
 from streaming.base.compression import compress, decompress, get_compression_extension, is_compression
+from streaming.base.format import _readers
+from streaming.base.format.base.reader import FileInfo, JointReader
 from streaming.base.format.index import get_index_basename
-from streaming.base.format.mds.encodings import mds_decode, mds_encode, is_mds_encoding, get_mds_encodings, get_mds_encoded_size
+from streaming.base.format.mds.encodings import mds_decode, mds_encode, is_mds_encoding, is_mds_encoding_safe, get_mds_encodings, get_mds_encoded_size
 from streaming.base.hashing import get_hash, is_hash
 from streaming.base.util import bytes_to_int
 from typing import Any, Optional, Generator, Self, Union
 
 from .utils import open_compression
 
+__all__ = [
+    "MDSBulkReader",
+    "MDSBulkShardReader",
+    "MDSReader",
+    "MDSWriter",
+]
+
 class MDSBulkReader:
     def __init__(
         self,
@@ -37,11 +47,11 @@ class MDSBulkReader:
 
     def __iter__(self) -> Generator[dict[str, Any], None, None]:
         for shard in self.shards:
-            with MDSShardReader(**shard) as reader:
+            with MDSBulkShardReader(**shard) as reader:
                 for sample in reader:
                     yield sample
 
-class MDSShardReader:
+class MDSBulkShardReader:
     def __init__(
         self,
         filename: str,
@@ -94,7 +104,7 @@ class MDSShardReader:
         for i in range(self.samples):
             yield self.get_item(i)
 
-    def __enter__(self) -> "MDSShardReader":
+    def __enter__(self) -> "MDSBulkShardReader":
         return self
 
     def __exit__(self, exc_type, exc_value, traceback) -> None:
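
Note (not part of the diff): the shard-level reader has been renamed from MDSShardReader to MDSBulkShardReader, and the public classes are now exported via __all__. A minimal, hedged sketch of the bulk reading path follows; the constructor arguments are assumptions inferred from the load_mds_directories(mds_directories, split='.', bulk=True, ...) signature further down, and only the iteration protocol (one dict per sample) is shown directly by this diff.

# Hedged sketch, not part of the released code.
reader = MDSBulkReader(["dataset/mds"], split=".")  # hypothetical arguments
for sample in reader:                               # MDSBulkReader.__iter__ yields dicts,
    print(sample)                                   # one per sample, shard by shard
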
@@ -315,3 +325,124 @@ class MDSWriter:
 
     def __exit__(self, exc_type, exc, traceback):
         self.finish()
+
+class MDSReader(JointReader):
+
+    def __init__(
+        self,
+        dirname: str,
+        split: Optional[str],
+        column_encodings: list[str],
+        column_names: list[str],
+        column_sizes: list[Optional[int]],
+        compression: Optional[str],
+        hashes: list[str],
+        raw_data: FileInfo,
+        samples: int,
+        size_limit: Optional[Union[int, str]],
+        zip_data: Optional[FileInfo],
+    ) -> None:
+        self.sample_compression = None
+        if compression and compression.startswith("sample::"):
+            compression, self.sample_compression = None, compression.removeprefix("sample::")
+        super().__init__(dirname, split, compression, hashes, raw_data, samples, size_limit,
+                         zip_data)
+        self.column_encodings = column_encodings
+        self.column_names = column_names
+        self.column_sizes = column_sizes
+
+    @classmethod
+    def from_json(cls, dirname: str, split: Optional[str], obj: dict[str, Any]) -> Self:
+        """Initialize from JSON object.
+
+        Args:
+            dirname (str): Local directory containing shards.
+            split (str, optional): Which dataset split to use, if any.
+            obj (Dict[str, Any]): JSON object to load.
+
+        Returns:
+            Self: Loaded MDSReader.
+        """
+        args = deepcopy(obj)
+        args_version = args['version']
+        if args_version != 2:
+            raise ValueError(
+                f'Unsupported streaming data version: {args_version}. Expected version 2.')
+        del args['version']
+        args_format = args['format']
+        if args_format != 'mds':
+            raise ValueError(f'Unsupported data format: {args_format}. Expected to be `mds`.')
+        del args['format']
+        args['dirname'] = dirname
+        args['split'] = split
+        for key in ['raw_data', 'zip_data']:
+            arg = args[key]
+            args[key] = FileInfo(**arg) if arg else None
+        return cls(**args)
+
+    def validate(self, allow_unsafe_types: bool) -> None:
+        """Check whether this shard is acceptable to be part of some Stream.
+
+        Args:
+            allow_unsafe_types (bool): If a shard contains Pickle, which allows arbitrary code
+                execution during deserialization, whether to keep going if ``True`` or raise an
+                error if ``False``.
+        """
+        if not allow_unsafe_types:
+            for column_id, encoding in enumerate(self.column_encodings):
+                if not is_mds_encoding_safe(encoding):
+                    name = self.column_names[column_id]
+                    raise ValueError(f'Column {name} contains an unsafe type: {encoding}. To ' +
+                                     f'proceed anyway, set ``allow_unsafe_types=True``.')
+
+    def decode_sample(self, data: bytes) -> dict[str, Any]:
+        """Decode a sample dict from bytes.
+
+        Args:
+            data (bytes): The sample encoded as bytes.
+
+        Returns:
+            Dict[str, Any]: Sample dict.
+        """
+        sizes = []
+        idx = 0
+        for key, size in zip(self.column_names, self.column_sizes):
+            if size:
+                sizes.append(size)
+            else:
+                size, = np.frombuffer(data[idx:idx + 4], np.uint32)
+                sizes.append(size)
+                idx += 4
+        sample = {}
+        for key, encoding, size in zip(self.column_names, self.column_encodings, sizes):
+            value = data[idx:idx + size]
+            sample[key] = mds_decode(encoding, value)
+            idx += size
+        return sample
+
+    def get_sample_data(self, idx: int) -> bytes:
+        """Get the raw sample data at the index.
+
+        Args:
+            idx (int): Sample index.
+
+        Returns:
+            bytes: Sample data.
+        """
+        filename = os.path.join(self.dirname, self.split, self.raw_data.basename)
+        offset = (1 + idx) * 4
+        with open(filename, 'rb', 0) as fp:
+            fp.seek(offset)
+            pair = fp.read(8)
+            begin, end = np.frombuffer(pair, np.uint32)
+            fp.seek(begin)
+            data = fp.read(end - begin)
+        if not data:
+            raise IndexError(
+                f'Relative sample index {idx} is not present in the {self.raw_data.basename} file.'
+            )
+        if self.sample_compression:
+            data = decompress(self.sample_compression, data)
+        return data
+
+_readers["mds"] = MDSReader
@@ -23,6 +23,11 @@ __all__ = [
     "save_parquet",
 ]
 
+_NO_PROGESS = False
+def set_progress(value):
+    global _NO_PROGESS
+    _NO_PROGESS = value
+
 def _batch_iterable(iterable, batch_size):
     batch = []
     for item in iterable:
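
Note (not part of the diff): the new module-level flag _NO_PROGESS is fed to every tqdm call below via disable=, and set_progress toggles it. Mind the inverted sense: set_progress(True) switches the progress bars off. A hedged usage sketch follows; the module path is an assumption, since the diff does not show file names.

# Hedged usage sketch, not part of the released code.
from mldataforge.utils import save_jsonl, set_progress  # module path is an assumption

samples = [{"text": "hello"}, {"text": "world"}]         # toy samples
set_progress(True)                                       # True disables the tqdm bars
save_jsonl(samples, "out-{part:04d}.jsonl")              # {part} is filled in by save_jsonl
set_progress(False)                                      # restore progress output
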
@@ -73,7 +78,7 @@ def _infer_mds_encoding(value):
     return 'pkl'
 
 def _streaming_jsonl(jsonl_files, compressions):
-    for jsonl_file, compression in tqdm(zip(jsonl_files, compressions), desc="Loading JSONL files", unit="file"):
+    for jsonl_file, compression in tqdm(zip(jsonl_files, compressions), desc="Loading JSONL files", unit="file", disable=_NO_PROGESS):
         for line in open_compression(jsonl_file, mode="rt", compression=compression):
             yield json.loads(line)
 
@@ -109,7 +114,7 @@ def load_mds_directories(mds_directories, split='.', batch_size=2**16, bulk=True
 def save_jsonl(iterable, output_file, compression=None, processes=64, size_hint=None, overwrite=True, yes=True):
     f = None
     part = 0
-    for item in tqdm(iterable, desc="Writing to JSONL", unit="sample"):
+    for item in tqdm(iterable, desc="Writing to JSONL", unit="sample", disable=_NO_PROGESS):
         if f is None:
             part_file = output_file.format(part=part)
             check_arguments(part_file, overwrite, yes)
@@ -127,7 +132,7 @@ def save_mds(it, output_dir, processes=64, compression=None, buf_size=2**24, pig
     writer = None
     part = 0
     files = []
-    for sample in tqdm(it, desc="Writing to MDS", unit="sample"):
+    for sample in tqdm(it, desc="Writing to MDS", unit="sample", disable=_NO_PROGESS):
         if writer is None:
             part_dir = output_dir.format(part=part)
             check_arguments(part_dir, overwrite, yes)
@@ -151,7 +156,7 @@ def save_mds(it, output_dir, processes=64, compression=None, buf_size=2**24, pig
         name2info = {shard["raw_data"]["basename"]: shard for shard in index["shards"]}
         file_names = [file for file in os.listdir(output_dir) if file.endswith(".mds")]
         assert set(file_names) == set(name2info.keys())
-        for file_name in tqdm(file_names, desc="Compressing with pigz", unit="file"):
+        for file_name in tqdm(file_names, desc="Compressing with pigz", unit="file", disable=_NO_PROGESS):
             compressed_file_name = file_name + ".gz"
             file_path = os.path.join(output_dir, file_name)
             compressed_file_path = os.path.join(output_dir, compressed_file_name)
@@ -169,7 +174,7 @@ def save_parquet(it, output_file, compression=None, batch_size=2**16, size_hint=
     compression = determine_compression("parquet", output_file, compression)
     writer = None
     part = 0
-    it = tqdm(it, desc="Writing to Parquet", unit="sample")
+    it = tqdm(it, desc="Writing to Parquet", unit="sample", disable=_NO_PROGESS)
     for batch in _batch_iterable(it, batch_size):
         table = pa.Table.from_pylist(batch)
         if writer is None:
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "mldataforge"
-version = "0.1.6"
+version = "0.1.7"
 authors = [
     { name = "Peter Schneider-Kamp" }
 ]
File without changes
File without changes
File without changes