speedy-utils 1.1.46__py3-none-any.whl → 1.1.48__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_utils/__init__.py +1 -3
- llm_utils/chat_format/__init__.py +0 -2
- llm_utils/chat_format/display.py +283 -364
- llm_utils/lm/llm.py +62 -22
- speedy_utils/__init__.py +4 -0
- speedy_utils/multi_worker/__init__.py +4 -0
- speedy_utils/multi_worker/_multi_process.py +425 -0
- speedy_utils/multi_worker/_multi_process_ray.py +308 -0
- speedy_utils/multi_worker/common.py +879 -0
- speedy_utils/multi_worker/dataset_sharding.py +203 -0
- speedy_utils/multi_worker/process.py +53 -1234
- speedy_utils/multi_worker/progress.py +71 -1
- speedy_utils/multi_worker/thread.py +45 -0
- speedy_utils/scripts/mpython.py +19 -12
- {speedy_utils-1.1.46.dist-info → speedy_utils-1.1.48.dist-info}/METADATA +1 -1
- {speedy_utils-1.1.46.dist-info → speedy_utils-1.1.48.dist-info}/RECORD +18 -14
- {speedy_utils-1.1.46.dist-info → speedy_utils-1.1.48.dist-info}/WHEEL +0 -0
- {speedy_utils-1.1.46.dist-info → speedy_utils-1.1.48.dist-info}/entry_points.txt +0 -0
--- /dev/null
+++ b/speedy_utils/multi_worker/dataset_sharding.py
@@ -0,0 +1,203 @@
+"""
+Dataset sharding utilities for parallel processing with merge.
+
+This module provides utilities for processing large HuggingFace datasets in parallel
+by sharding them across workers, processing each shard independently, and then
+merging the results back together.
+"""
+
+import os
+import shutil
+import logging
+from pathlib import Path
+from typing import Callable, Optional, Any, Dict
+from datasets import Dataset, concatenate_datasets, load_from_disk
+
+logger = logging.getLogger(__name__)
+
+__all__ = ['multi_process_dataset']
+
+
+def multi_process_dataset(
+    dataset: Dataset,
+    process_func: Callable,
+    output_path: str,
+    process_func_kwargs: Optional[Dict[str, Any]] = None,
+    num_workers: Optional[int] = None,
+    seed: Optional[int] = None,
+    debug: bool = False,
+    debug_size: int = 10000,
+    backend: str = 'ray'
+) -> str:
+    """
+    Process a dataset in parallel using sharding and multiprocessing.
+
+    This function implements the shard-process-merge pattern for large dataset processing:
+    1. Optionally shuffle and truncate dataset (for debugging)
+    2. Shard dataset across workers
+    3. Process each shard in parallel
+    4. Merge results and save to final location
+    5. Clean up temporary shard files
+
+    Args:
+        dataset: The input dataset to process
+        process_func: Function to apply to each shard. Should take dataset as first argument
+            and output_path as second argument, plus any additional kwargs.
+            Must return the path to the processed shard.
+        output_path: Base path for output (without extension or size suffix)
+        process_func_kwargs: Additional keyword arguments to pass to process_func
+        num_workers: Number of parallel workers (default: CPU count - 2)
+        seed: Random seed for shuffling (if None, no shuffling)
+        debug: If True, process only a subset of data for debugging
+        debug_size: Number of examples to use in debug mode
+        backend: Backend for multiprocessing ('ray' or 'process')
+
+    Returns:
+        str: Path to the final merged dataset
+
+    Example:
+        def process_shard(shard_dataset, output_path, tokenizer_path, seq_len):
+            # Process the shard (tokenize, pack, etc.)
+            packer = SFTDatasetPacker(tokenizer_path, seq_len)
+            return packer.pack(shard_dataset, output_path)
+
+        final_path = multi_process_dataset(
+            dataset=dataset,
+            process_func=process_shard,
+            output_path="./data/processed",
+            process_func_kwargs={
+                'tokenizer_path': 'tokenizers/Qwen3-32B',
+                'seq_len': 12288
+            },
+            num_workers=32,
+            seed=42,
+            debug=True
+        )
+    """
+    from ..multi_worker import multi_process
+
+    # Determine number of workers
+    if num_workers is None:
+        num_workers = max(1, (os.cpu_count() or 1) - 2)
+
+    # Shuffle dataset if seed is provided
+    if seed is not None:
+        logger.info(f"Shuffling dataset with seed={seed}")
+        dataset = dataset.shuffle(seed=seed)
+
+    # Debug mode: truncate dataset
+    if debug:
+        dataset = dataset.select(range(min(debug_size, len(dataset))))
+        logger.info(f"Debug mode: using only {len(dataset)} examples")
+
+    # Prepare arguments for each shard
+    list_args = []
+    for shard_idx in range(num_workers):
+        out = f'{output_path}_shard{shard_idx}_of_{num_workers}'.replace('/', '_')
+        dst = f".cache/{out}"
+
+        # Prepare function arguments
+        func_args = {
+            'dataset': dataset,
+            'shard_idx': shard_idx,
+            'total_shards': num_workers,
+            'output_path': dst,
+        }
+
+        # Add the process function
+        func_args['process_func'] = process_func
+
+        # Add additional kwargs
+        if process_func_kwargs:
+            func_args.update(process_func_kwargs)
+
+        list_args.append(func_args)
+
+    # Process shards in parallel
+    total_items = len(dataset)
+    logger.info(f"Processing {total_items:,} examples using {num_workers} workers...")
+
+    # Enable item-level progress tracking for Ray backend
+    multi_process_kwargs = {
+        'workers': num_workers,
+        'backend': backend,
+    }
+    if backend == 'ray':
+        multi_process_kwargs['total_items'] = total_items
+        multi_process_kwargs['desc'] = f"Processing {total_items:,} items"
+        multi_process_kwargs['poll_interval'] = 0.3
+
+    output_paths = multi_process(
+        _process_shard_wrapper,
+        list_args,
+        **multi_process_kwargs
+    )
+
+    # Concatenate shards
+    logger.info("Merging shards...")
+    tmp_shards = [load_from_disk(p) for p in output_paths]
+    merged_dataset = concatenate_datasets(tmp_shards)
+
+    # Save final dataset
+    final_size = len(merged_dataset)
+    final_name = f"{Path(output_path).name}_size{final_size}"
+    final_dst = Path(output_path).parent / final_name
+
+    if final_dst.exists():
+        logger.warning(f"Removing existing dataset: {final_dst}")
+        shutil.rmtree(final_dst)
+
+    logger.info(f"Saving final merged dataset to: {final_dst}")
+    merged_dataset.save_to_disk(str(final_dst))
+
+    # Cleanup temporary shards
+    for p in output_paths:
+        logger.info(f'Removing temporary shard: {p}')
+        shutil.rmtree(p)
+
+    logger.info(f"✅ Successfully processed dataset: {final_dst}")
+    return str(final_dst)
+
+
+def _process_shard_wrapper(args: Dict[str, Any]) -> str:
+    """
+    Wrapper function for processing a single shard.
+
+    This wrapper extracts the dataset, shards it, and passes it to the user-provided
+    process function along with any additional arguments.
+
+    Args:
+        args: Dictionary containing:
+            - dataset: The full dataset
+            - shard_idx: Index of the current shard
+            - total_shards: Total number of shards
+            - output_path: Path to save the processed shard
+            - process_func: The function to apply to the shard
+            - Additional arguments for the process function
+
+    Returns:
+        str: Path to the processed shard
+
+    Note:
+        Progress tracking is automatically available via report_progress()
+        when using Ray backend with item-level tracking enabled.
+    """
+    from datasets import Dataset
+
+    # Extract core parameters
+    dataset = args.pop('dataset')
+    shard_idx = args.pop('shard_idx')
+    total_shards = args.pop('total_shards')
+    output_path = args.pop('output_path')
+    process_func = args.pop('process_func')
+
+    # Remove progress_actor from args (it's in thread-local context now)
+    args.pop('progress_actor', None)
+
+    # Shard the dataset (HF datasets.shard() is memory-efficient)
+    shard = dataset.shard(num_shards=total_shards, index=shard_idx)
+    logger.info(f"Processing shard {shard_idx+1}/{total_shards} with {len(shard)} examples")
+
+    # Process the shard with remaining kwargs
+    # User code can call report_progress() directly for centralized tracking
+    return process_func(shard, output_path, **args)
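For orientation, here is a minimal usage sketch of the new `multi_process_dataset` helper, based only on the signature and docstring in the hunk above. The `count_chars` shard function, the `text` column, the `./data/demo` output path, and the direct module import path are illustrative assumptions, not part of the package.

```python
# Minimal sketch (assumptions: `datasets` installed, speedy-utils 1.1.48, and that
# multi_process_dataset is importable from the module path shown in this diff;
# the shard function and column name below are hypothetical).
from datasets import Dataset
from speedy_utils.multi_worker.dataset_sharding import multi_process_dataset


def count_chars(shard, output_path, text_column="text"):
    # Per-shard work: add a character-count column, then persist the shard
    # with save_to_disk so the merge step can reload it via load_from_disk.
    shard = shard.map(lambda row: {"n_chars": len(row[text_column])})
    shard.save_to_disk(output_path)
    return output_path


if __name__ == "__main__":
    toy = Dataset.from_dict({"text": [f"example {i}" for i in range(1_000)]})
    final_path = multi_process_dataset(
        dataset=toy,
        process_func=count_chars,   # top-level function, so workers can pickle it
        output_path="./data/demo",
        process_func_kwargs={"text_column": "text"},
        num_workers=4,
        seed=0,
        backend="process",          # or "ray" (the default), per the docstring
    )
    print(final_path)               # merged dataset path, e.g. data/demo_size1000
```

In this pattern, `process_func` receives a plain `datasets.Dataset` shard plus a per-shard output path under `.cache/`, and must return a path readable by `load_from_disk`; the merged result is written next to `output_path` with a `_size{N}` suffix and the temporary shards are removed afterwards.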