speedy-utils 1.1.46__py3-none-any.whl → 1.1.48__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in the public registry, and is provided for informational purposes only.
@@ -0,0 +1,203 @@
+ """
+ Dataset sharding utilities for parallel processing with merge.
+
+ This module provides utilities for processing large HuggingFace datasets in parallel
+ by sharding them across workers, processing each shard independently, and then
+ merging the results back together.
+ """
+
+ import os
+ import shutil
+ import logging
+ from pathlib import Path
+ from typing import Callable, Optional, Any, Dict
+ from datasets import Dataset, concatenate_datasets, load_from_disk
+
+ logger = logging.getLogger(__name__)
+
+ __all__ = ['multi_process_dataset']
+
+
+ def multi_process_dataset(
+     dataset: Dataset,
+     process_func: Callable,
+     output_path: str,
+     process_func_kwargs: Optional[Dict[str, Any]] = None,
+     num_workers: Optional[int] = None,
+     seed: Optional[int] = None,
+     debug: bool = False,
+     debug_size: int = 10000,
+     backend: str = 'ray'
+ ) -> str:
+     """
+     Process a dataset in parallel using sharding and multiprocessing.
+
+     This function implements the shard-process-merge pattern for large dataset processing:
+     1. Optionally shuffle and truncate dataset (for debugging)
+     2. Shard dataset across workers
+     3. Process each shard in parallel
+     4. Merge results and save to final location
+     5. Clean up temporary shard files
+
+     Args:
+         dataset: The input dataset to process
+         process_func: Function to apply to each shard. Should take dataset as first argument
+             and output_path as second argument, plus any additional kwargs.
+             Must return the path to the processed shard.
+         output_path: Base path for output (without extension or size suffix)
+         process_func_kwargs: Additional keyword arguments to pass to process_func
+         num_workers: Number of parallel workers (default: CPU count - 2)
+         seed: Random seed for shuffling (if None, no shuffling)
+         debug: If True, process only a subset of data for debugging
+         debug_size: Number of examples to use in debug mode
+         backend: Backend for multiprocessing ('ray' or 'process')
+
+     Returns:
+         str: Path to the final merged dataset
+
+     Example:
+         def process_shard(shard_dataset, output_path, tokenizer_path, seq_len):
+             # Process the shard (tokenize, pack, etc.)
+             packer = SFTDatasetPacker(tokenizer_path, seq_len)
+             return packer.pack(shard_dataset, output_path)
+
+         final_path = multi_process_dataset(
+             dataset=dataset,
+             process_func=process_shard,
+             output_path="./data/processed",
+             process_func_kwargs={
+                 'tokenizer_path': 'tokenizers/Qwen3-32B',
+                 'seq_len': 12288
+             },
+             num_workers=32,
+             seed=42,
+             debug=True
+         )
+     """
+     from ..multi_worker import multi_process
+
+     # Determine number of workers
+     if num_workers is None:
+         num_workers = max(1, (os.cpu_count() or 1) - 2)
+
+     # Shuffle dataset if seed is provided
+     if seed is not None:
+         logger.info(f"Shuffling dataset with seed={seed}")
+         dataset = dataset.shuffle(seed=seed)
+
+     # Debug mode: truncate dataset
+     if debug:
+         dataset = dataset.select(range(min(debug_size, len(dataset))))
+         logger.info(f"Debug mode: using only {len(dataset)} examples")
+
+     # Prepare arguments for each shard
+     list_args = []
+     for shard_idx in range(num_workers):
+         out = f'{output_path}_shard{shard_idx}_of_{num_workers}'.replace('/', '_')
+         dst = f".cache/{out}"
+
+         # Prepare function arguments
+         func_args = {
+             'dataset': dataset,
+             'shard_idx': shard_idx,
+             'total_shards': num_workers,
+             'output_path': dst,
+         }
+
+         # Add the process function
+         func_args['process_func'] = process_func
+
+         # Add additional kwargs
+         if process_func_kwargs:
+             func_args.update(process_func_kwargs)
+
+         list_args.append(func_args)
+
+     # Process shards in parallel
+     total_items = len(dataset)
+     logger.info(f"Processing {total_items:,} examples using {num_workers} workers...")
+
+     # Enable item-level progress tracking for Ray backend
+     multi_process_kwargs = {
+         'workers': num_workers,
+         'backend': backend,
+     }
+     if backend == 'ray':
+         multi_process_kwargs['total_items'] = total_items
+         multi_process_kwargs['desc'] = f"Processing {total_items:,} items"
+         multi_process_kwargs['poll_interval'] = 0.3
+
+     output_paths = multi_process(
+         _process_shard_wrapper,
+         list_args,
+         **multi_process_kwargs
+     )
+
+     # Concatenate shards
+     logger.info("Merging shards...")
+     tmp_shards = [load_from_disk(p) for p in output_paths]
+     merged_dataset = concatenate_datasets(tmp_shards)
+
+     # Save final dataset
+     final_size = len(merged_dataset)
+     final_name = f"{Path(output_path).name}_size{final_size}"
+     final_dst = Path(output_path).parent / final_name
+
+     if final_dst.exists():
+         logger.warning(f"Removing existing dataset: {final_dst}")
+         shutil.rmtree(final_dst)
+
+     logger.info(f"Saving final merged dataset to: {final_dst}")
+     merged_dataset.save_to_disk(str(final_dst))
+
+     # Cleanup temporary shards
+     for p in output_paths:
+         logger.info(f'Removing temporary shard: {p}')
+         shutil.rmtree(p)
+
+     logger.info(f"✅ Successfully processed dataset: {final_dst}")
+     return str(final_dst)
+
+
+ def _process_shard_wrapper(args: Dict[str, Any]) -> str:
+     """
+     Wrapper function for processing a single shard.
+
+     This wrapper extracts the dataset, shards it, and passes it to the user-provided
+     process function along with any additional arguments.
+
+     Args:
+         args: Dictionary containing:
+             - dataset: The full dataset
+             - shard_idx: Index of the current shard
+             - total_shards: Total number of shards
+             - output_path: Path to save the processed shard
+             - process_func: The function to apply to the shard
+             - Additional arguments for the process function
+
+     Returns:
+         str: Path to the processed shard
+
+     Note:
+         Progress tracking is automatically available via report_progress()
+         when using Ray backend with item-level tracking enabled.
+     """
+     from datasets import Dataset
+
+     # Extract core parameters
+     dataset = args.pop('dataset')
+     shard_idx = args.pop('shard_idx')
+     total_shards = args.pop('total_shards')
+     output_path = args.pop('output_path')
+     process_func = args.pop('process_func')
+
+     # Remove progress_actor from args (it's in thread-local context now)
+     args.pop('progress_actor', None)
+
+     # Shard the dataset (HF datasets.shard() is memory-efficient)
+     shard = dataset.shard(num_shards=total_shards, index=shard_idx)
+     logger.info(f"Processing shard {shard_idx+1}/{total_shards} with {len(shard)} examples")
+
+     # Process the shard with remaining kwargs
+     # User code can call report_progress() directly for centralized tracking
+     return process_func(shard, output_path, **args)
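
For orientation, a minimal usage sketch of the new multi_process_dataset helper, based only on the code above. The import path is an assumption (the diff does not show where the module lives inside speedy-utils), and process_shard is a toy stand-in: it maps one column and persists the shard with save_to_disk, which is what the merge step's load_from_disk call expects.

from datasets import Dataset

# Import path is an assumption; the diff does not show the new module's
# location inside the speedy-utils package.
from speedy_utils import multi_process_dataset


def process_shard(shard, output_path, scale=1):
    # Toy per-shard work: scale one column, then persist the shard so the
    # merge step can reload it with load_from_disk().
    shard = shard.map(lambda row: {"value": row["value"] * scale})
    shard.save_to_disk(output_path)
    return output_path  # multi_process_dataset expects the shard path back


dataset = Dataset.from_dict({"value": list(range(100_000))})

final_path = multi_process_dataset(
    dataset=dataset,
    process_func=process_shard,
    output_path="./data/demo",
    process_func_kwargs={"scale": 3},
    num_workers=4,
    seed=42,
    backend="process",  # or the default "ray"
)
print(final_path)  # e.g. data/demo_size100000 (name includes the merged row count)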