dalla-data-processing 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. dalla/__init__.py +27 -0
  2. dalla/cli.py +453 -0
  3. dalla/core/__init__.py +6 -0
  4. dalla/core/dataset.py +387 -0
  5. dalla/core/parallel.py +279 -0
  6. dalla/deduplication/__init__.py +370 -0
  7. dalla/deduplication/bin/.gitignore +1 -0
  8. dalla/deduplication/bin/onion-linux-x86_64 +0 -0
  9. dalla/deduplication/onion/COPYING +24 -0
  10. dalla/deduplication/onion/Makefile +21 -0
  11. dalla/deduplication/onion/Makefile.config +3 -0
  12. dalla/deduplication/onion/README.md +21 -0
  13. dalla/deduplication/onion/src/Makefile +22 -0
  14. dalla/deduplication/onion/src/Makefile.g +23 -0
  15. dalla/deduplication/onion/src/buzhash.c +325 -0
  16. dalla/deduplication/onion/src/buzhash.h +30 -0
  17. dalla/deduplication/onion/src/hashdup.c +172 -0
  18. dalla/deduplication/onion/src/hashgen.c +206 -0
  19. dalla/deduplication/onion/src/onion +0 -0
  20. dalla/deduplication/onion/src/onion.c +799 -0
  21. dalla/deduplication/onion/src/onion_dup.c +824 -0
  22. dalla/deduplication/onion/src/version.c +17 -0
  23. dalla/deduplication/onion/src/version.h +10 -0
  24. dalla/deduplication/onion/src_sc/Makefile +22 -0
  25. dalla/deduplication/onion/src_sc/Makefile.g +23 -0
  26. dalla/deduplication/onion/src_sc/buzhash.c +325 -0
  27. dalla/deduplication/onion/src_sc/buzhash.h +30 -0
  28. dalla/deduplication/onion/src_sc/hashdup +0 -0
  29. dalla/deduplication/onion/src_sc/hashdup.c +172 -0
  30. dalla/deduplication/onion/src_sc/hashgen +0 -0
  31. dalla/deduplication/onion/src_sc/hashgen.c +206 -0
  32. dalla/deduplication/onion/src_sc/onion.c +854 -0
  33. dalla/deduplication/onion/src_sc/onion_dup.c +824 -0
  34. dalla/deduplication/onion/src_sc/version.c +17 -0
  35. dalla/deduplication/onion/src_sc/version.h +10 -0
  36. dalla/deduplication/onion_wrapper.py +223 -0
  37. dalla/deduplication/postprocessing.py +216 -0
  38. dalla/deduplication/preprocessing.py +120 -0
  39. dalla/quality/__init__.py +5 -0
  40. dalla/quality/checker.py +354 -0
  41. dalla/readability/__init__.py +197 -0
  42. dalla/readability/ranking.py +165 -0
  43. dalla/readability/scorer.py +148 -0
  44. dalla/stemming/__init__.py +551 -0
  45. dalla/stemming/data/words_al.txt +3414 -0
  46. dalla/stemming/data/words_al_t.txt +885 -0
  47. dalla/stemming/data/words_t.txt +7 -0
  48. dalla/utils/__init__.py +10 -0
  49. dalla/utils/logger.py +128 -0
  50. dalla/utils/tokenize.py +89 -0
  51. dalla_data_processing-0.0.1.dist-info/METADATA +393 -0
  52. dalla_data_processing-0.0.1.dist-info/RECORD +55 -0
  53. dalla_data_processing-0.0.1.dist-info/WHEEL +5 -0
  54. dalla_data_processing-0.0.1.dist-info/entry_points.txt +2 -0
  55. dalla_data_processing-0.0.1.dist-info/top_level.txt +1 -0
dalla/core/dataset.py ADDED
@@ -0,0 +1,387 @@
1
"""
Dataset I/O helpers for working with HuggingFace datasets.

Gives every dalla-process component one consistent way to load, save, and
manipulate HuggingFace datasets.
"""

from collections.abc import Callable
from pathlib import Path
from typing import Any

from datasets import Dataset, DatasetDict, concatenate_datasets, load_from_disk

from dalla.utils.logger import get_logger

# Logger shared by every DatasetManager method in this module.
logger = get_logger(__name__)
17
+
18
+
19
class DatasetManager:
    """Unified manager for HuggingFace dataset operations.

    Every method is a ``@staticmethod``, so the class works as a namespace:
    call ``DatasetManager.load(...)`` directly or through an instance.
    """

    @staticmethod
    def load(
        path: str | Path,
        split: str | None = None,
        streaming: bool = False,
    ) -> Dataset | DatasetDict:
        """
        Load a HuggingFace dataset from disk.

        Args:
            path: Path to the dataset directory
            split: Optional split name to load (e.g., 'train', 'test'). Only
                applied when the saved object is a DatasetDict; for a plain
                Dataset a warning is logged and the whole dataset is returned.
            streaming: Accepted for API compatibility only. ``load_from_disk``
                has no streaming mode, so when True a warning is logged and
                the dataset is loaded normally.

        Returns:
            Dataset or DatasetDict depending on the structure

        Raises:
            FileNotFoundError: If ``path`` does not exist.
            ValueError: If ``split`` is not one of the DatasetDict's splits.

        Example:
            >>> dm = DatasetManager()
            >>> dataset = dm.load("./data/my_dataset")
            >>> train_data = dm.load("./data/my_dataset", split="train")
        """
        path = Path(path)
        if not path.exists():
            raise FileNotFoundError(f"Dataset path does not exist: {path}")

        if streaming:
            # Previously this flag was dropped silently; make that visible.
            logger.warning(
                "streaming=True is not supported by load_from_disk; "
                "loading the full dataset instead"
            )

        logger.info(f"Loading dataset from {path}")
        dataset = load_from_disk(str(path))

        if split is not None:
            if isinstance(dataset, DatasetDict):
                if split not in dataset:
                    raise ValueError(
                        f"Split '{split}' not found. Available splits: {list(dataset.keys())}"
                    )
                dataset = dataset[split]
            else:
                logger.warning(f"Split '{split}' specified but dataset has no splits")

        logger.info(f"Loaded dataset with {DatasetManager.get_size(dataset)} examples")
        return dataset

    @staticmethod
    def save(
        dataset: Dataset | DatasetDict,
        path: str | Path,
        overwrite: bool = False,
    ) -> None:
        """
        Save a HuggingFace dataset to disk.

        Args:
            dataset: Dataset or DatasetDict to save
            path: Path where the dataset will be saved
            overwrite: Whether to allow saving to a path that already exists.
                The existing path is not removed first; ``save_to_disk``
                writes into it.

        Raises:
            FileExistsError: If ``path`` exists and ``overwrite`` is False.

        Example:
            >>> dm = DatasetManager()
            >>> dm.save(processed_dataset, "./data/processed")
        """
        path = Path(path)

        if path.exists() and not overwrite:
            raise FileExistsError(
                f"Dataset path already exists: {path}. Use overwrite=True to replace."
            )

        logger.info(f"Saving dataset to {path}")
        dataset.save_to_disk(str(path))
        logger.info("Dataset saved successfully")

    @staticmethod
    def get_size(dataset: Dataset | DatasetDict) -> int:
        """
        Get the total number of examples in a dataset.

        Args:
            dataset: Dataset or DatasetDict

        Returns:
            Total number of examples (summed across splits for a DatasetDict)
        """
        if isinstance(dataset, DatasetDict):
            return sum(len(ds) for ds in dataset.values())
        return len(dataset)

    @staticmethod
    def get_column_names(dataset: Dataset | DatasetDict) -> list[str]:
        """
        Get column names from a dataset.

        Args:
            dataset: Dataset or DatasetDict

        Returns:
            List of column names. For a DatasetDict the first split's columns
            are returned; an empty DatasetDict yields an empty list.
        """
        if isinstance(dataset, DatasetDict):
            # Use the first split as representative; default to None so an
            # empty DatasetDict does not raise StopIteration.
            first_split = next(iter(dataset.values()), None)
            if first_split is None:
                return []
            return first_split.column_names
        return dataset.column_names

    @staticmethod
    def add_column(
        dataset: Dataset,
        column_name: str,
        data: list[Any],
    ) -> Dataset:
        """
        Add a new column to a dataset.

        Args:
            dataset: Dataset to modify
            column_name: Name of the new column
            data: List of values for the new column (one per example)

        Returns:
            Dataset with the new column added

        Raises:
            ValueError: If ``len(data)`` differs from the dataset length.

        Example:
            >>> scores = [0.95, 0.87, 0.92, ...]
            >>> dataset = dm.add_column(dataset, "quality_score", scores)
        """
        if len(data) != len(dataset):
            raise ValueError(
                f"Data length ({len(data)}) must match dataset length ({len(dataset)})"
            )

        logger.info(f"Adding column '{column_name}' to dataset")
        return dataset.add_column(column_name, data)

    @staticmethod
    def map_column(
        dataset: Dataset,
        fn: Callable,
        input_column: str,
        output_column: str | None = None,
        batched: bool = False,
        batch_size: int = 1000,
        num_proc: int | None = None,
        desc: str | None = None,
    ) -> Dataset:
        """
        Apply a function to a column in the dataset.

        ``fn`` always receives a single column value. With ``batched=True``,
        ``fn`` is applied element-wise to every value in each batch, so the
        same ``fn`` works in both modes.

        Args:
            dataset: Dataset to process
            fn: Function applied to each value of ``input_column``
            input_column: Name of the input column
            output_column: Name of the output column (if None, replaces input_column)
            batched: Whether to process in batches
            batch_size: Size of batches when batched=True
            num_proc: Number of processes for parallel processing
            desc: Description for progress bar

        Returns:
            Processed dataset

        Raises:
            ValueError: If ``input_column`` is not a column of ``dataset``.

        Example:
            >>> def deduplicate_text(text):
            ...     return text.strip().lower()
            >>> dataset = dm.map_column(
            ...     dataset,
            ...     deduplicate_text,
            ...     "text",
            ...     "cleaned_text",
            ...     num_proc=4
            ... )
        """
        if input_column not in dataset.column_names:
            raise ValueError(f"Column '{input_column}' not found in dataset")

        output_col = output_column or input_column

        def process_fn(examples):
            if batched:
                # In batched mode examples[input_column] is a list of values.
                results = [fn(item) for item in examples[input_column]]
            else:
                results = fn(examples[input_column])
            return {output_col: results}

        logger.info(f"Mapping function to column '{input_column}'")
        return dataset.map(
            process_fn,
            batched=batched,
            batch_size=batch_size,
            num_proc=num_proc,
            desc=desc or f"Processing {input_column}",
        )

    @staticmethod
    def filter_dataset(
        dataset: Dataset,
        fn: Callable,
        num_proc: int | None = None,
        desc: str | None = None,
    ) -> Dataset:
        """
        Filter dataset based on a condition.

        Args:
            dataset: Dataset to filter
            fn: Function that returns True for examples to keep
            num_proc: Number of processes for parallel processing
            desc: Description for progress bar

        Returns:
            Filtered dataset

        Example:
            >>> def is_high_quality(example):
            ...     return example['quality_score'] > 0.8
            >>> filtered = dm.filter_dataset(dataset, is_high_quality)
        """
        original_size = len(dataset)
        logger.info(f"Filtering dataset with {original_size} examples")
        filtered = dataset.filter(fn, num_proc=num_proc, desc=desc or "Filtering dataset")
        kept = len(filtered)
        # An empty input would otherwise raise ZeroDivisionError here.
        kept_pct = kept / original_size * 100 if original_size else 0.0
        logger.info(f"Filtered to {kept} examples ({kept_pct:.1f}%)")
        return filtered

    @staticmethod
    def select_columns(
        dataset: Dataset | DatasetDict,
        columns: list[str],
    ) -> Dataset | DatasetDict:
        """
        Select specific columns from a dataset.

        Args:
            dataset: Dataset or DatasetDict
            columns: List of column names to keep

        Returns:
            Dataset with only the specified columns

        Raises:
            ValueError: If any requested column is missing. For a DatasetDict,
                the check runs against the first split's columns.
        """
        available_columns = DatasetManager.get_column_names(dataset)
        invalid_columns = set(columns) - set(available_columns)
        if invalid_columns:
            raise ValueError(f"Columns not found: {invalid_columns}")

        logger.info(f"Selecting columns: {columns}")
        if isinstance(dataset, DatasetDict):
            return DatasetDict({split: ds.select_columns(columns) for split, ds in dataset.items()})
        return dataset.select_columns(columns)

    @staticmethod
    def remove_columns(
        dataset: Dataset | DatasetDict,
        columns: list[str],
    ) -> Dataset | DatasetDict:
        """
        Remove specific columns from a dataset.

        Missing columns are not validated here; that is left to the
        underlying ``remove_columns`` call.

        Args:
            dataset: Dataset or DatasetDict
            columns: List of column names to remove

        Returns:
            Dataset without the specified columns
        """
        logger.info(f"Removing columns: {columns}")
        if isinstance(dataset, DatasetDict):
            return DatasetDict({split: ds.remove_columns(columns) for split, ds in dataset.items()})
        return dataset.remove_columns(columns)

    @staticmethod
    def concatenate(datasets: list[Dataset]) -> Dataset:
        """
        Concatenate multiple datasets.

        Args:
            datasets: List of datasets to concatenate

        Returns:
            Concatenated dataset

        Raises:
            ValueError: If ``datasets`` is empty.
        """
        if not datasets:
            raise ValueError("Cannot concatenate empty list of datasets")

        logger.info(f"Concatenating {len(datasets)} datasets")
        return concatenate_datasets(datasets)

    @staticmethod
    def train_test_split(
        dataset: Dataset,
        test_size: float = 0.1,
        seed: int = 42,
    ) -> DatasetDict:
        """
        Split dataset into train and test sets.

        Args:
            dataset: Dataset to split
            test_size: Fraction of data to use for testing
            seed: Random seed for reproducibility

        Returns:
            DatasetDict with 'train' and 'test' splits
        """
        logger.info(f"Splitting dataset into train/test (test_size={test_size})")
        return dataset.train_test_split(test_size=test_size, seed=seed)

    @staticmethod
    def get_info(dataset: Dataset | DatasetDict) -> dict[str, Any]:
        """
        Get information about a dataset.

        Args:
            dataset: Dataset or DatasetDict

        Returns:
            Dictionary with dataset information. A DatasetDict yields keys
            "type", "splits", "total_examples", and "split_info". A Dataset
            yields "type", "num_examples", "columns", and "features".
        """
        if isinstance(dataset, DatasetDict):
            return {
                "type": "DatasetDict",
                "splits": list(dataset.keys()),
                "total_examples": DatasetManager.get_size(dataset),
                "split_info": {
                    split: {
                        "num_examples": len(ds),
                        "columns": ds.column_names,
                        "features": str(ds.features),
                    }
                    for split, ds in dataset.items()
                },
            }
        else:
            return {
                "type": "Dataset",
                "num_examples": len(dataset),
                "columns": dataset.column_names,
                "features": str(dataset.features),
            }

    @staticmethod
    def print_info(dataset: Dataset | DatasetDict) -> None:
        """
        Print dataset information in a readable format.

        Args:
            dataset: Dataset or DatasetDict
        """
        info = DatasetManager.get_info(dataset)

        if info["type"] == "DatasetDict":
            print(f"\n{'=' * 60}")
            print("Dataset Dictionary")
            print(f"{'=' * 60}")
            print(f"Total examples: {info['total_examples']:,}")
            print(f"Splits: {', '.join(info['splits'])}")
            print()
            for split, split_info in info["split_info"].items():
                print(f"  {split}:")
                print(f"    Examples: {split_info['num_examples']:,}")
                print(f"    Columns: {', '.join(split_info['columns'])}")
            print(f"{'=' * 60}\n")
        else:
            print(f"\n{'=' * 60}")
            print("Dataset")
            print(f"{'=' * 60}")
            print(f"Examples: {info['num_examples']:,}")
            print(f"Columns: {', '.join(info['columns'])}")
            print(f"{'=' * 60}\n")
dalla/core/parallel.py ADDED
@@ -0,0 +1,279 @@
1
"""
Helpers for processing datasets in parallel.

Covers batched processing, multiprocess execution, and tqdm-based progress
reporting for dataset workloads.
"""

import multiprocessing
from collections.abc import Callable
from typing import Any

from datasets import Dataset
from tqdm import tqdm

from dalla.utils.logger import get_logger

# Logger shared by the helpers defined in this module.
logger = get_logger(__name__)
18
+
19
+
20
class ParallelProcessor:
    """Utility class for parallel dataset processing (static helpers only)."""

    @staticmethod
    def get_optimal_num_workers(num_workers: int | None = None) -> int:
        """
        Get optimal number of workers for parallel processing.

        Args:
            num_workers: Requested number of workers (None for auto)

        Returns:
            Number of workers, always in the range [1, cpu_count]. When
            ``num_workers`` is None, all but one CPU is used (minimum 1).
        """
        cpu_count = multiprocessing.cpu_count()
        if num_workers is None:
            return max(1, cpu_count - 1)
        # Clamp into [1, cpu_count]. A request of 0 or a negative number would
        # otherwise be passed to multiprocessing.Pool / Dataset.map and fail.
        return max(1, min(num_workers, cpu_count))

    @staticmethod
    def process_dataset_parallel(
        dataset: Dataset,
        process_fn: Callable,
        num_proc: int | None = None,
        batched: bool = False,
        batch_size: int = 1000,
        desc: str | None = None,
        remove_columns: list[str] | None = None,
        **map_kwargs,
    ) -> Dataset:
        """
        Process a dataset in parallel using the map function.

        Args:
            dataset: Dataset to process
            process_fn: Function to apply to each example/batch
            num_proc: Number of processes (None for auto, see get_optimal_num_workers)
            batched: Whether to process in batches
            batch_size: Batch size when batched=True
            desc: Description for progress bar
            remove_columns: Columns to remove after processing
            **map_kwargs: Additional arguments for dataset.map()

        Returns:
            Processed dataset

        Example:
            >>> def process_text(example):
            ...     example['processed'] = example['text'].lower()
            ...     return example
            >>> processed = ParallelProcessor.process_dataset_parallel(
            ...     dataset, process_text, num_proc=4
            ... )
        """
        num_workers = ParallelProcessor.get_optimal_num_workers(num_proc)

        logger.info(f"Processing dataset with {num_workers} workers")
        logger.info(f"Batched: {batched}, Batch size: {batch_size if batched else 'N/A'}")

        return dataset.map(
            process_fn,
            num_proc=num_workers,
            batched=batched,
            batch_size=batch_size,
            desc=desc or "Processing dataset",
            remove_columns=remove_columns,
            **map_kwargs,
        )

    @staticmethod
    def process_in_batches(
        dataset: Dataset,
        process_fn: Callable[[list[dict[str, Any]]], list[dict[str, Any]]],
        batch_size: int = 1000,
        desc: str | None = None,
    ) -> Dataset:
        """
        Process dataset in batches with custom function (single process).

        Args:
            dataset: Dataset to process
            process_fn: Function that takes a list of examples (one dict per
                row) and returns a processed list of dicts
            batch_size: Size of batches; must be >= 1
            desc: Description for progress bar

        Returns:
            Processed dataset built from the concatenated outputs of ``process_fn``

        Raises:
            ValueError: If ``batch_size`` is less than 1.

        Example:
            >>> def batch_process(batch):
            ...     # Process batch of examples
            ...     return [{'text': ex['text'].upper()} for ex in batch]
            >>> result = ParallelProcessor.process_in_batches(
            ...     dataset, batch_process, batch_size=100
            ... )
        """
        if batch_size < 1:
            # 0 used to hit ZeroDivisionError. Negative values silently
            # produced an empty result.
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

        logger.info(f"Processing dataset in batches of {batch_size}")

        total_rows = len(dataset)
        processed_examples = []
        total_batches = (total_rows + batch_size - 1) // batch_size

        with tqdm(total=total_batches, desc=desc or "Processing batches") as pbar:
            for start in range(0, total_rows, batch_size):
                # Slicing a Dataset returns a columnar dict {column: [values]}.
                batch = dataset[start : start + batch_size]
                # The row count is known from the slice bounds, so there is no
                # need to inspect a column (which fails when there are none).
                num_rows = min(batch_size, total_rows - start)
                batch_list = [{key: batch[key][j] for key in batch} for j in range(num_rows)]

                processed_examples.extend(process_fn(batch_list))
                pbar.update(1)

        return Dataset.from_list(processed_examples)

    @staticmethod
    def create_shards(
        dataset: Dataset,
        num_shards: int,
    ) -> list[Dataset]:
        """
        Split dataset into contiguous shards for parallel processing.

        Each shard holds ``ceil(len(dataset) / num_shards)`` examples, except
        possibly the last. As a result, fewer than ``num_shards`` shards may be
        returned, and none for an empty dataset.

        Args:
            dataset: Dataset to shard
            num_shards: Number of shards to create

        Returns:
            List of dataset shards

        Raises:
            ValueError: If ``num_shards`` is not positive.

        Example:
            >>> shards = ParallelProcessor.create_shards(dataset, 4)
            >>> # Process each shard independently
        """
        if num_shards <= 0:
            raise ValueError("num_shards must be positive")

        total_size = len(dataset)
        shard_size = (total_size + num_shards - 1) // num_shards

        shards = []
        for i in range(num_shards):
            start_idx = i * shard_size
            if start_idx >= total_size:
                break
            end_idx = min(start_idx + shard_size, total_size)
            # Pass a range rather than a materialized list of indices. This
            # avoids building one Python list per shard.
            shards.append(dataset.select(range(start_idx, end_idx)))

        logger.info(f"Created {len(shards)} shards from dataset of {total_size} examples")
        return shards

    @staticmethod
    def process_with_multiprocessing(
        items: list[Any],
        process_fn: Callable,
        num_workers: int | None = None,
        desc: str | None = None,
    ) -> list[Any]:
        """
        Process a list of items using multiprocessing.

        With a single worker the items are processed in-process. Otherwise a
        ``multiprocessing.Pool`` is used, so ``process_fn`` and the items must
        be picklable. Results are returned in input order (``Pool.imap``).

        Args:
            items: List of items to process
            process_fn: Function to apply to each item
            num_workers: Number of worker processes (None for auto)
            desc: Description for progress bar

        Returns:
            List of processed items

        Example:
            >>> def process_item(x):
            ...     return x * 2
            >>> results = ParallelProcessor.process_with_multiprocessing(
            ...     [1, 2, 3, 4], process_item, num_workers=2
            ... )
        """
        num_workers = ParallelProcessor.get_optimal_num_workers(num_workers)

        logger.info(f"Processing {len(items)} items with {num_workers} workers")

        if num_workers == 1:
            return [process_fn(item) for item in tqdm(items, desc=desc or "Processing items")]

        with multiprocessing.Pool(processes=num_workers) as pool:
            results = list(
                tqdm(
                    pool.imap(process_fn, items),
                    total=len(items),
                    desc=desc or "Processing items",
                )
            )

        return results
215
+
216
+
217
class ProgressTracker:
    """Keep a running count of processed items alongside a tqdm progress bar.

    Works as a context manager: leaving the ``with`` block closes the bar.
    Exceptions raised inside the block are not suppressed.
    """

    def __init__(self, total: int, desc: str | None = None):
        """
        Create the progress bar and reset the counter.

        Args:
            total: How many items are expected overall
            desc: Label for the bar; "Processing" is used when it is empty or None
        """
        label = desc if desc else "Processing"
        self.pbar = tqdm(total=total, desc=label)
        self.current = 0

    def update(self, n: int = 1):
        """Advance both the bar and ``self.current`` by ``n``."""
        self.pbar.update(n)
        self.current = self.current + n

    def set_description(self, desc: str):
        """Replace the label shown on the progress bar."""
        self.pbar.set_description(desc)

    def close(self):
        """Close the underlying progress bar."""
        self.pbar.close()

    def __enter__(self):
        """Enter the context and hand back this tracker."""
        return self

    def __exit__(self, *exc_info):
        """Close the bar on exit; returns None so exceptions propagate."""
        self.close()
251
+
252
+
253
def batch_iterator(iterable, batch_size: int):
    """
    Yield batches from an iterable.

    Items are consumed lazily. Every batch is a list of ``batch_size`` items,
    except the last, which holds whatever remains. An empty iterable yields
    nothing.

    Args:
        iterable: Any iterable
        batch_size: Size of each batch; must be >= 1

    Yields:
        Batches of items (as lists)

    Raises:
        ValueError: If ``batch_size`` is less than 1. Because this is a
            generator, the error surfaces on the first iteration.

    Example:
        >>> for batch in batch_iterator(range(10), batch_size=3):
        ...     print(batch)
        [0, 1, 2]
        [3, 4, 5]
        [6, 7, 8]
        [9]
    """
    if batch_size < 1:
        # Without this check a batch of size 0 can never be "full", so the
        # whole iterable would silently come back as one batch.
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    batch = []
    for item in iterable:
        batch.append(item)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch