dalla-data-processing 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dalla/__init__.py +27 -0
- dalla/cli.py +453 -0
- dalla/core/__init__.py +6 -0
- dalla/core/dataset.py +387 -0
- dalla/core/parallel.py +279 -0
- dalla/deduplication/__init__.py +370 -0
- dalla/deduplication/bin/.gitignore +1 -0
- dalla/deduplication/bin/onion-linux-x86_64 +0 -0
- dalla/deduplication/onion/COPYING +24 -0
- dalla/deduplication/onion/Makefile +21 -0
- dalla/deduplication/onion/Makefile.config +3 -0
- dalla/deduplication/onion/README.md +21 -0
- dalla/deduplication/onion/src/Makefile +22 -0
- dalla/deduplication/onion/src/Makefile.g +23 -0
- dalla/deduplication/onion/src/buzhash.c +325 -0
- dalla/deduplication/onion/src/buzhash.h +30 -0
- dalla/deduplication/onion/src/hashdup.c +172 -0
- dalla/deduplication/onion/src/hashgen.c +206 -0
- dalla/deduplication/onion/src/onion +0 -0
- dalla/deduplication/onion/src/onion.c +799 -0
- dalla/deduplication/onion/src/onion_dup.c +824 -0
- dalla/deduplication/onion/src/version.c +17 -0
- dalla/deduplication/onion/src/version.h +10 -0
- dalla/deduplication/onion/src_sc/Makefile +22 -0
- dalla/deduplication/onion/src_sc/Makefile.g +23 -0
- dalla/deduplication/onion/src_sc/buzhash.c +325 -0
- dalla/deduplication/onion/src_sc/buzhash.h +30 -0
- dalla/deduplication/onion/src_sc/hashdup +0 -0
- dalla/deduplication/onion/src_sc/hashdup.c +172 -0
- dalla/deduplication/onion/src_sc/hashgen +0 -0
- dalla/deduplication/onion/src_sc/hashgen.c +206 -0
- dalla/deduplication/onion/src_sc/onion.c +854 -0
- dalla/deduplication/onion/src_sc/onion_dup.c +824 -0
- dalla/deduplication/onion/src_sc/version.c +17 -0
- dalla/deduplication/onion/src_sc/version.h +10 -0
- dalla/deduplication/onion_wrapper.py +223 -0
- dalla/deduplication/postprocessing.py +216 -0
- dalla/deduplication/preprocessing.py +120 -0
- dalla/quality/__init__.py +5 -0
- dalla/quality/checker.py +354 -0
- dalla/readability/__init__.py +197 -0
- dalla/readability/ranking.py +165 -0
- dalla/readability/scorer.py +148 -0
- dalla/stemming/__init__.py +551 -0
- dalla/stemming/data/words_al.txt +3414 -0
- dalla/stemming/data/words_al_t.txt +885 -0
- dalla/stemming/data/words_t.txt +7 -0
- dalla/utils/__init__.py +10 -0
- dalla/utils/logger.py +128 -0
- dalla/utils/tokenize.py +89 -0
- dalla_data_processing-0.0.1.dist-info/METADATA +393 -0
- dalla_data_processing-0.0.1.dist-info/RECORD +55 -0
- dalla_data_processing-0.0.1.dist-info/WHEEL +5 -0
- dalla_data_processing-0.0.1.dist-info/entry_points.txt +2 -0
- dalla_data_processing-0.0.1.dist-info/top_level.txt +1 -0
dalla/core/dataset.py
ADDED
|
@@ -0,0 +1,387 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Dataset I/O utilities for unified HuggingFace dataset handling.
|
|
3
|
+
|
|
4
|
+
This module provides a consistent interface for loading, saving, and manipulating
|
|
5
|
+
HuggingFace datasets across all dalla-process components.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from collections.abc import Callable
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from typing import Any
|
|
11
|
+
|
|
12
|
+
from datasets import Dataset, DatasetDict, concatenate_datasets, load_from_disk
|
|
13
|
+
|
|
14
|
+
from dalla.utils.logger import get_logger
|
|
15
|
+
|
|
16
|
+
logger = get_logger(__name__)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class DatasetManager:
    """Unified manager for HuggingFace dataset operations.

    Every operation is a static method, so the class acts as a namespace and
    holds no state of its own.
    """

    @staticmethod
    def load(
        path: str | Path,
        split: str | None = None,
        streaming: bool = False,
    ) -> Dataset | DatasetDict:
        """
        Load a HuggingFace dataset from disk.

        Args:
            path: Path to the dataset directory
            split: Optional split name to load (e.g., 'train', 'test')
            streaming: Accepted for interface compatibility; this method
                never reads the value.

        Returns:
            Dataset or DatasetDict depending on the structure

        Raises:
            FileNotFoundError: If ``path`` does not exist.
            ValueError: If ``split`` is given and the loaded object is a
                DatasetDict that does not contain it.

        Example:
            >>> dm = DatasetManager()
            >>> dataset = dm.load("./data/my_dataset")
            >>> train_data = dm.load("./data/my_dataset", split="train")
        """
        path = Path(path)
        if not path.exists():
            raise FileNotFoundError(f"Dataset path does not exist: {path}")

        logger.info(f"Loading dataset from {path}")
        dataset = load_from_disk(str(path))

        if split is not None:
            if isinstance(dataset, DatasetDict):
                if split not in dataset:
                    raise ValueError(
                        f"Split '{split}' not found. Available splits: {list(dataset.keys())}"
                    )
                dataset = dataset[split]
            else:
                # A plain Dataset has nothing to select from; keep it whole.
                logger.warning(f"Split '{split}' specified but dataset has no splits")

        logger.info(f"Loaded dataset with {DatasetManager.get_size(dataset)} examples")
        return dataset

    @staticmethod
    def save(
        dataset: Dataset | DatasetDict,
        path: str | Path,
        overwrite: bool = False,
    ) -> None:
        """
        Save a HuggingFace dataset to disk.

        Args:
            dataset: Dataset or DatasetDict to save
            path: Path where the dataset will be saved
            overwrite: Allow saving onto an existing path. The existing
                location is not cleared by this method; it is handed to
                ``save_to_disk`` as is.

        Raises:
            FileExistsError: If ``path`` exists and ``overwrite`` is False.

        Example:
            >>> dm = DatasetManager()
            >>> dm.save(processed_dataset, "./data/processed")
        """
        path = Path(path)

        if path.exists() and not overwrite:
            raise FileExistsError(
                f"Dataset path already exists: {path}. Use overwrite=True to replace."
            )

        logger.info(f"Saving dataset to {path}")
        dataset.save_to_disk(str(path))
        logger.info("Dataset saved successfully")

    @staticmethod
    def get_size(dataset: Dataset | DatasetDict) -> int:
        """
        Get the total number of examples in a dataset.

        Args:
            dataset: Dataset or DatasetDict

        Returns:
            Number of examples; for a DatasetDict, the sum over all splits
        """
        if isinstance(dataset, DatasetDict):
            return sum(len(ds) for ds in dataset.values())
        return len(dataset)

    @staticmethod
    def get_column_names(dataset: Dataset | DatasetDict) -> list[str]:
        """
        Get column names from a dataset.

        Args:
            dataset: Dataset or DatasetDict

        Returns:
            List of column names. For a DatasetDict the names are taken from
            its first split; a DatasetDict without splits yields an empty list.
        """
        if isinstance(dataset, DatasetDict):
            # Default of None avoids leaking StopIteration when there are no splits.
            first_split = next(iter(dataset.values()), None)
            if first_split is None:
                return []
            return first_split.column_names
        return dataset.column_names

    @staticmethod
    def add_column(
        dataset: Dataset,
        column_name: str,
        data: list[Any],
    ) -> Dataset:
        """
        Add a new column to a dataset.

        Args:
            dataset: Dataset to modify
            column_name: Name of the new column
            data: List of values for the new column, one per example

        Returns:
            Dataset with the new column added

        Raises:
            ValueError: If ``data`` and ``dataset`` differ in length.

        Example:
            >>> scores = [0.95, 0.87, 0.92, ...]
            >>> dataset = dm.add_column(dataset, "quality_score", scores)
        """
        if len(data) != len(dataset):
            raise ValueError(
                f"Data length ({len(data)}) must match dataset length ({len(dataset)})"
            )

        logger.info(f"Adding column '{column_name}' to dataset")
        return dataset.add_column(column_name, data)

    @staticmethod
    def map_column(
        dataset: Dataset,
        fn: Callable,
        input_column: str,
        output_column: str | None = None,
        batched: bool = False,
        batch_size: int = 1000,
        num_proc: int | None = None,
        desc: str | None = None,
    ) -> Dataset:
        """
        Apply a function to a column in the dataset.

        ``fn`` always receives a single value of ``input_column``. When
        ``batched`` is True it is called once per item of each batch.

        Args:
            dataset: Dataset to process
            fn: Function to apply to each value of the input column
            input_column: Name of the input column
            output_column: Name of the output column (if None or empty,
                replaces input_column)
            batched: Whether to process in batches
            batch_size: Size of batches when batched=True
            num_proc: Number of processes for parallel processing
            desc: Description for progress bar

        Returns:
            Processed dataset

        Raises:
            ValueError: If ``input_column`` is not a column of ``dataset``.

        Example:
            >>> def deduplicate_text(text):
            ...     return text.strip().lower()
            >>> dataset = dm.map_column(
            ...     dataset,
            ...     deduplicate_text,
            ...     "text",
            ...     "cleaned_text",
            ...     num_proc=4
            ... )
        """
        if input_column not in dataset.column_names:
            raise ValueError(f"Column '{input_column}' not found in dataset")

        output_col = output_column or input_column

        def process_fn(examples):
            # In batched mode the column value is a list of items.
            if batched:
                results = [fn(item) for item in examples[input_column]]
            else:
                results = fn(examples[input_column])
            return {output_col: results}

        logger.info(f"Mapping function to column '{input_column}'")
        return dataset.map(
            process_fn,
            batched=batched,
            batch_size=batch_size,
            num_proc=num_proc,
            desc=desc or f"Processing {input_column}",
        )

    @staticmethod
    def filter_dataset(
        dataset: Dataset,
        fn: Callable,
        num_proc: int | None = None,
        desc: str | None = None,
    ) -> Dataset:
        """
        Filter dataset based on a condition.

        Args:
            dataset: Dataset to filter
            fn: Function that returns True for examples to keep
            num_proc: Number of processes for parallel processing
            desc: Description for progress bar

        Returns:
            Filtered dataset

        Example:
            >>> def is_high_quality(example):
            ...     return example['quality_score'] > 0.8
            >>> filtered = dm.filter_dataset(dataset, is_high_quality)
        """
        total = len(dataset)
        logger.info(f"Filtering dataset with {total} examples")
        filtered = dataset.filter(fn, num_proc=num_proc, desc=desc or "Filtering dataset")

        kept = len(filtered)
        # An empty input would otherwise divide by zero in the log message.
        kept_percent = kept / total * 100 if total else 0.0
        logger.info(f"Filtered to {kept} examples ({kept_percent:.1f}%)")
        return filtered

    @staticmethod
    def select_columns(
        dataset: Dataset | DatasetDict,
        columns: list[str],
    ) -> Dataset | DatasetDict:
        """
        Select specific columns from a dataset.

        Args:
            dataset: Dataset or DatasetDict
            columns: List of column names to keep

        Returns:
            Dataset with only the specified columns

        Raises:
            ValueError: If any requested column is not reported by
                ``get_column_names`` (for a DatasetDict that is the first
                split only).
        """
        available_columns = DatasetManager.get_column_names(dataset)
        invalid_columns = set(columns) - set(available_columns)
        if invalid_columns:
            raise ValueError(f"Columns not found: {invalid_columns}")

        logger.info(f"Selecting columns: {columns}")
        if isinstance(dataset, DatasetDict):
            return DatasetDict({split: ds.select_columns(columns) for split, ds in dataset.items()})
        return dataset.select_columns(columns)

    @staticmethod
    def remove_columns(
        dataset: Dataset | DatasetDict,
        columns: list[str],
    ) -> Dataset | DatasetDict:
        """
        Remove specific columns from a dataset.

        Args:
            dataset: Dataset or DatasetDict
            columns: List of column names to remove

        Returns:
            Dataset without the specified columns
        """
        logger.info(f"Removing columns: {columns}")
        if isinstance(dataset, DatasetDict):
            return DatasetDict({split: ds.remove_columns(columns) for split, ds in dataset.items()})
        return dataset.remove_columns(columns)

    @staticmethod
    def concatenate(datasets: list[Dataset]) -> Dataset:
        """
        Concatenate multiple datasets.

        Args:
            datasets: List of datasets to concatenate

        Returns:
            Concatenated dataset

        Raises:
            ValueError: If ``datasets`` is empty.
        """
        if not datasets:
            raise ValueError("Cannot concatenate empty list of datasets")

        logger.info(f"Concatenating {len(datasets)} datasets")
        return concatenate_datasets(datasets)

    @staticmethod
    def train_test_split(
        dataset: Dataset,
        test_size: float = 0.1,
        seed: int = 42,
    ) -> DatasetDict:
        """
        Split dataset into train and test sets.

        Args:
            dataset: Dataset to split
            test_size: Fraction of data to use for testing
            seed: Random seed for reproducibility

        Returns:
            DatasetDict with 'train' and 'test' splits
        """
        logger.info(f"Splitting dataset into train/test (test_size={test_size})")
        return dataset.train_test_split(test_size=test_size, seed=seed)

    @staticmethod
    def get_info(dataset: Dataset | DatasetDict) -> dict[str, Any]:
        """
        Get information about a dataset.

        Args:
            dataset: Dataset or DatasetDict

        Returns:
            Dictionary with dataset information. A DatasetDict yields the keys
            ``type``, ``splits``, ``total_examples`` and ``split_info``; a
            plain Dataset yields ``type``, ``num_examples``, ``columns`` and
            ``features``.
        """
        if isinstance(dataset, DatasetDict):
            return {
                "type": "DatasetDict",
                "splits": list(dataset.keys()),
                "total_examples": DatasetManager.get_size(dataset),
                "split_info": {
                    split: {
                        "num_examples": len(ds),
                        "columns": ds.column_names,
                        "features": str(ds.features),
                    }
                    for split, ds in dataset.items()
                },
            }

        return {
            "type": "Dataset",
            "num_examples": len(dataset),
            "columns": dataset.column_names,
            "features": str(dataset.features),
        }

    @staticmethod
    def print_info(dataset: Dataset | DatasetDict) -> None:
        """
        Print dataset information in a readable format.

        Args:
            dataset: Dataset or DatasetDict
        """
        info = DatasetManager.get_info(dataset)

        if info["type"] == "DatasetDict":
            print(f"\n{'=' * 60}")
            print("Dataset Dictionary")
            print(f"{'=' * 60}")
            print(f"Total examples: {info['total_examples']:,}")
            print(f"Splits: {', '.join(info['splits'])}")
            print()
            for split, split_info in info["split_info"].items():
                print(f" {split}:")
                print(f" Examples: {split_info['num_examples']:,}")
                print(f" Columns: {', '.join(split_info['columns'])}")
            print(f"{'=' * 60}\n")
        else:
            print(f"\n{'=' * 60}")
            print("Dataset")
            print(f"{'=' * 60}")
            print(f"Examples: {info['num_examples']:,}")
            print(f"Columns: {', '.join(info['columns'])}")
            print(f"{'=' * 60}\n")
|
dalla/core/parallel.py
ADDED
|
@@ -0,0 +1,279 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Parallel processing utilities for efficient dataset operations.
|
|
3
|
+
|
|
4
|
+
This module provides utilities for parallel processing of datasets,
|
|
5
|
+
including batch processing, multiprocessing, and progress tracking.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import multiprocessing
|
|
9
|
+
from collections.abc import Callable
|
|
10
|
+
from typing import Any
|
|
11
|
+
|
|
12
|
+
from datasets import Dataset
|
|
13
|
+
from tqdm import tqdm
|
|
14
|
+
|
|
15
|
+
from dalla.utils.logger import get_logger
|
|
16
|
+
|
|
17
|
+
logger = get_logger(__name__)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class ParallelProcessor:
    """Utility class for parallel dataset processing.

    All helpers are static methods; the class holds no state.
    """

    @staticmethod
    def get_optimal_num_workers(num_workers: int | None = None) -> int:
        """
        Get optimal number of workers for parallel processing.

        Args:
            num_workers: Requested number of workers (None for auto)

        Returns:
            Worker count, always at least 1. With ``None`` this is the CPU
            count minus one; otherwise the request is capped at the CPU count.
        """
        cpu_count = multiprocessing.cpu_count()
        if num_workers is None:
            return max(1, cpu_count - 1)
        # Clamp from below too: zero or negative requests must not reach
        # multiprocessing.Pool or Dataset.map.
        return max(1, min(num_workers, cpu_count))

    @staticmethod
    def process_dataset_parallel(
        dataset: Dataset,
        process_fn: Callable,
        num_proc: int | None = None,
        batched: bool = False,
        batch_size: int = 1000,
        desc: str | None = None,
        remove_columns: list[str] | None = None,
        **map_kwargs,
    ) -> Dataset:
        """
        Process a dataset in parallel using the map function.

        Args:
            dataset: Dataset to process
            process_fn: Function to apply to each example/batch
            num_proc: Number of processes (None for auto)
            batched: Whether to process in batches
            batch_size: Batch size when batched=True
            desc: Description for progress bar
            remove_columns: Columns forwarded to ``dataset.map`` for removal
            **map_kwargs: Additional arguments for dataset.map()

        Returns:
            Processed dataset

        Example:
            >>> def process_text(example):
            ...     example['processed'] = example['text'].lower()
            ...     return example
            >>> processed = ParallelProcessor.process_dataset_parallel(
            ...     dataset, process_text, num_proc=4
            ... )
        """
        num_workers = ParallelProcessor.get_optimal_num_workers(num_proc)

        logger.info(f"Processing dataset with {num_workers} workers")
        logger.info(f"Batched: {batched}, Batch size: {batch_size if batched else 'N/A'}")

        return dataset.map(
            process_fn,
            num_proc=num_workers,
            batched=batched,
            batch_size=batch_size,
            desc=desc or "Processing dataset",
            remove_columns=remove_columns,
            **map_kwargs,
        )

    @staticmethod
    def process_in_batches(
        dataset: Dataset,
        process_fn: Callable[[list[dict[str, Any]]], list[dict[str, Any]]],
        batch_size: int = 1000,
        desc: str | None = None,
    ) -> Dataset:
        """
        Process dataset in batches with custom function.

        Batches are handled one after another in the calling process. Each
        batch is turned into a list of per-example dicts before it is handed
        to ``process_fn``.

        Args:
            dataset: Dataset to process
            process_fn: Function that takes a list of examples and returns processed list
            batch_size: Size of batches; must be positive
            desc: Description for progress bar

        Returns:
            New dataset built from all examples returned by ``process_fn``

        Raises:
            ValueError: If ``batch_size`` is not positive.

        Example:
            >>> def batch_process(batch):
            ...     # Process batch of examples
            ...     return [{'text': ex['text'].upper()} for ex in batch]
            >>> result = ParallelProcessor.process_in_batches(
            ...     dataset, batch_process, batch_size=100
            ... )
        """
        if batch_size <= 0:
            raise ValueError("batch_size must be positive")

        logger.info(f"Processing dataset in batches of {batch_size}")

        processed_examples = []
        batch_starts = range(0, len(dataset), batch_size)

        # tqdm takes the total from len(batch_starts), i.e. the batch count.
        for start in tqdm(batch_starts, desc=desc or "Processing batches"):
            batch = dataset[start : start + batch_size]

            # Slicing yields column -> values; regroup into one dict per row.
            columns = list(batch)
            batch_list = [
                dict(zip(columns, row_values))
                for row_values in zip(*(batch[column] for column in columns))
            ]

            processed_examples.extend(process_fn(batch_list))

        return Dataset.from_list(processed_examples)

    @staticmethod
    def create_shards(
        dataset: Dataset,
        num_shards: int,
    ) -> list[Dataset]:
        """
        Split dataset into shards for parallel processing.

        Shards are contiguous and sized by ceiling division, so fewer than
        ``num_shards`` shards are returned when the dataset is too small
        (none at all for an empty dataset).

        Args:
            dataset: Dataset to shard
            num_shards: Number of shards to create

        Returns:
            List of dataset shards

        Raises:
            ValueError: If ``num_shards`` is not positive.

        Example:
            >>> shards = ParallelProcessor.create_shards(dataset, 4)
            >>> # Process each shard independently
        """
        if num_shards <= 0:
            raise ValueError("num_shards must be positive")

        total_size = len(dataset)
        shard_size = (total_size + num_shards - 1) // num_shards

        shards = []
        for i in range(num_shards):
            start_idx = i * shard_size
            if start_idx >= total_size:
                # Start offsets only grow, so no later shard can be non-empty.
                break
            end_idx = min(start_idx + shard_size, total_size)
            # Pass the range itself instead of materialising an index list.
            shards.append(dataset.select(range(start_idx, end_idx)))

        logger.info(f"Created {len(shards)} shards from dataset of {total_size} examples")
        return shards

    @staticmethod
    def process_with_multiprocessing(
        items: list[Any],
        process_fn: Callable,
        num_workers: int | None = None,
        desc: str | None = None,
    ) -> list[Any]:
        """
        Process a list of items using multiprocessing.

        Args:
            items: List of items to process
            process_fn: Function to apply to each item
            num_workers: Number of worker processes
            desc: Description for progress bar

        Returns:
            List of processed items, in input order

        Example:
            >>> def process_item(x):
            ...     return x * 2
            >>> results = ParallelProcessor.process_with_multiprocessing(
            ...     [1, 2, 3, 4], process_item, num_workers=2
            ... )
        """
        num_workers = ParallelProcessor.get_optimal_num_workers(num_workers)

        logger.info(f"Processing {len(items)} items with {num_workers} workers")

        if num_workers == 1:
            # A single worker runs in-process; no pool is created.
            return [process_fn(item) for item in tqdm(items, desc=desc or "Processing items")]

        with multiprocessing.Pool(processes=num_workers) as pool:
            results = list(
                tqdm(
                    pool.imap(process_fn, items),
                    total=len(items),
                    desc=desc or "Processing items",
                )
            )

        return results
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
class ProgressTracker:
    """Progress bar wrapper that also keeps a running count of items.

    Can be used as a context manager; leaving the ``with`` block closes
    the bar.
    """

    def __init__(self, total: int, desc: str | None = None):
        """
        Create the progress bar and reset the counter.

        Args:
            total: Total number of items to track
            desc: Label for the progress bar; "Processing" when not given
        """
        label = desc if desc else "Processing"
        self.pbar = tqdm(total=total, desc=label)
        self.current = 0

    def update(self, n: int = 1) -> None:
        """Advance the bar and the running count by ``n`` items."""
        self.pbar.update(n)
        self.current = self.current + n

    def set_description(self, desc: str) -> None:
        """Replace the label shown on the progress bar."""
        self.pbar.set_description(desc)

    def close(self) -> None:
        """Close the underlying progress bar."""
        self.pbar.close()

    def __enter__(self):
        """Enter the context and hand back this tracker."""
        return self

    def __exit__(self, *exc_info):
        """Leave the context, closing the bar; exceptions are not suppressed."""
        self.close()
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
def batch_iterator(iterable, batch_size: int):
    """
    Yield batches from an iterable.

    Items are gathered in order into lists of ``batch_size`` elements; the
    final list may be shorter. Because this is a generator, an invalid
    ``batch_size`` is reported when iteration starts, not at call time.

    Args:
        iterable: Any iterable
        batch_size: Size of each batch; must be positive

    Yields:
        Batches of items

    Raises:
        ValueError: If ``batch_size`` is not positive.

    Example:
        >>> for batch in batch_iterator(range(10), batch_size=3):
        ...     print(batch)
        [0, 1, 2]
        [3, 4, 5]
        [6, 7, 8]
        [9]
    """
    # A non-positive size would never be reached by the length check below,
    # so the whole input would silently come back as one batch.
    if batch_size < 1:
        raise ValueError("batch_size must be positive")

    pending = []
    for item in iterable:
        pending.append(item)
        if len(pending) >= batch_size:
            yield pending
            pending = []

    # Emit the trailing partial batch, if any.
    if pending:
        yield pending
|