guidellm-0.3.1-py3-none-any.whl → guidellm-0.6.0a5-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- guidellm/__init__.py +5 -2
- guidellm/__main__.py +524 -255
- guidellm/backends/__init__.py +33 -0
- guidellm/backends/backend.py +109 -0
- guidellm/backends/openai.py +340 -0
- guidellm/backends/response_handlers.py +428 -0
- guidellm/benchmark/__init__.py +69 -39
- guidellm/benchmark/benchmarker.py +160 -316
- guidellm/benchmark/entrypoints.py +560 -127
- guidellm/benchmark/outputs/__init__.py +24 -0
- guidellm/benchmark/outputs/console.py +633 -0
- guidellm/benchmark/outputs/csv.py +721 -0
- guidellm/benchmark/outputs/html.py +473 -0
- guidellm/benchmark/outputs/output.py +169 -0
- guidellm/benchmark/outputs/serialized.py +69 -0
- guidellm/benchmark/profiles.py +718 -0
- guidellm/benchmark/progress.py +553 -556
- guidellm/benchmark/scenarios/__init__.py +40 -0
- guidellm/benchmark/scenarios/chat.json +6 -0
- guidellm/benchmark/scenarios/rag.json +6 -0
- guidellm/benchmark/schemas/__init__.py +66 -0
- guidellm/benchmark/schemas/base.py +402 -0
- guidellm/benchmark/schemas/generative/__init__.py +55 -0
- guidellm/benchmark/schemas/generative/accumulator.py +841 -0
- guidellm/benchmark/schemas/generative/benchmark.py +163 -0
- guidellm/benchmark/schemas/generative/entrypoints.py +381 -0
- guidellm/benchmark/schemas/generative/metrics.py +927 -0
- guidellm/benchmark/schemas/generative/report.py +158 -0
- guidellm/data/__init__.py +34 -4
- guidellm/data/builders.py +541 -0
- guidellm/data/collators.py +16 -0
- guidellm/data/config.py +120 -0
- guidellm/data/deserializers/__init__.py +49 -0
- guidellm/data/deserializers/deserializer.py +141 -0
- guidellm/data/deserializers/file.py +223 -0
- guidellm/data/deserializers/huggingface.py +94 -0
- guidellm/data/deserializers/memory.py +194 -0
- guidellm/data/deserializers/synthetic.py +246 -0
- guidellm/data/entrypoints.py +52 -0
- guidellm/data/loaders.py +190 -0
- guidellm/data/preprocessors/__init__.py +27 -0
- guidellm/data/preprocessors/formatters.py +410 -0
- guidellm/data/preprocessors/mappers.py +196 -0
- guidellm/data/preprocessors/preprocessor.py +30 -0
- guidellm/data/processor.py +29 -0
- guidellm/data/schemas.py +175 -0
- guidellm/data/utils/__init__.py +6 -0
- guidellm/data/utils/dataset.py +94 -0
- guidellm/extras/__init__.py +4 -0
- guidellm/extras/audio.py +220 -0
- guidellm/extras/vision.py +242 -0
- guidellm/logger.py +2 -2
- guidellm/mock_server/__init__.py +8 -0
- guidellm/mock_server/config.py +84 -0
- guidellm/mock_server/handlers/__init__.py +17 -0
- guidellm/mock_server/handlers/chat_completions.py +280 -0
- guidellm/mock_server/handlers/completions.py +280 -0
- guidellm/mock_server/handlers/tokenizer.py +142 -0
- guidellm/mock_server/models.py +510 -0
- guidellm/mock_server/server.py +238 -0
- guidellm/mock_server/utils.py +302 -0
- guidellm/scheduler/__init__.py +69 -26
- guidellm/scheduler/constraints/__init__.py +49 -0
- guidellm/scheduler/constraints/constraint.py +325 -0
- guidellm/scheduler/constraints/error.py +411 -0
- guidellm/scheduler/constraints/factory.py +182 -0
- guidellm/scheduler/constraints/request.py +312 -0
- guidellm/scheduler/constraints/saturation.py +722 -0
- guidellm/scheduler/environments.py +252 -0
- guidellm/scheduler/scheduler.py +137 -368
- guidellm/scheduler/schemas.py +358 -0
- guidellm/scheduler/strategies.py +617 -0
- guidellm/scheduler/worker.py +413 -419
- guidellm/scheduler/worker_group.py +712 -0
- guidellm/schemas/__init__.py +65 -0
- guidellm/schemas/base.py +417 -0
- guidellm/schemas/info.py +188 -0
- guidellm/schemas/request.py +235 -0
- guidellm/schemas/request_stats.py +349 -0
- guidellm/schemas/response.py +124 -0
- guidellm/schemas/statistics.py +1018 -0
- guidellm/{config.py → settings.py} +31 -24
- guidellm/utils/__init__.py +71 -8
- guidellm/utils/auto_importer.py +98 -0
- guidellm/utils/cli.py +132 -5
- guidellm/utils/console.py +566 -0
- guidellm/utils/encoding.py +778 -0
- guidellm/utils/functions.py +159 -0
- guidellm/utils/hf_datasets.py +1 -2
- guidellm/utils/hf_transformers.py +4 -4
- guidellm/utils/imports.py +9 -0
- guidellm/utils/messaging.py +1118 -0
- guidellm/utils/mixins.py +115 -0
- guidellm/utils/random.py +3 -4
- guidellm/utils/registry.py +220 -0
- guidellm/utils/singleton.py +133 -0
- guidellm/utils/synchronous.py +159 -0
- guidellm/utils/text.py +163 -50
- guidellm/utils/typing.py +41 -0
- guidellm/version.py +2 -2
- guidellm-0.6.0a5.dist-info/METADATA +364 -0
- guidellm-0.6.0a5.dist-info/RECORD +109 -0
- guidellm/backend/__init__.py +0 -23
- guidellm/backend/backend.py +0 -259
- guidellm/backend/openai.py +0 -708
- guidellm/backend/response.py +0 -136
- guidellm/benchmark/aggregator.py +0 -760
- guidellm/benchmark/benchmark.py +0 -837
- guidellm/benchmark/output.py +0 -997
- guidellm/benchmark/profile.py +0 -409
- guidellm/benchmark/scenario.py +0 -104
- guidellm/data/prideandprejudice.txt.gz +0 -0
- guidellm/dataset/__init__.py +0 -22
- guidellm/dataset/creator.py +0 -213
- guidellm/dataset/entrypoints.py +0 -42
- guidellm/dataset/file.py +0 -92
- guidellm/dataset/hf_datasets.py +0 -62
- guidellm/dataset/in_memory.py +0 -132
- guidellm/dataset/synthetic.py +0 -287
- guidellm/objects/__init__.py +0 -18
- guidellm/objects/pydantic.py +0 -89
- guidellm/objects/statistics.py +0 -953
- guidellm/preprocess/__init__.py +0 -3
- guidellm/preprocess/dataset.py +0 -374
- guidellm/presentation/__init__.py +0 -28
- guidellm/presentation/builder.py +0 -27
- guidellm/presentation/data_models.py +0 -232
- guidellm/presentation/injector.py +0 -66
- guidellm/request/__init__.py +0 -18
- guidellm/request/loader.py +0 -284
- guidellm/request/request.py +0 -79
- guidellm/request/types.py +0 -10
- guidellm/scheduler/queues.py +0 -25
- guidellm/scheduler/result.py +0 -155
- guidellm/scheduler/strategy.py +0 -495
- guidellm-0.3.1.dist-info/METADATA +0 -329
- guidellm-0.3.1.dist-info/RECORD +0 -62
- {guidellm-0.3.1.dist-info → guidellm-0.6.0a5.dist-info}/WHEEL +0 -0
- {guidellm-0.3.1.dist-info → guidellm-0.6.0a5.dist-info}/entry_points.txt +0 -0
- {guidellm-0.3.1.dist-info → guidellm-0.6.0a5.dist-info}/licenses/LICENSE +0 -0
- {guidellm-0.3.1.dist-info → guidellm-0.6.0a5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,541 @@
import os
from collections.abc import Callable, Iterator
from enum import Enum
from pathlib import Path
from typing import Any, cast

from datasets import Dataset
from loguru import logger
from transformers import PreTrainedTokenizerBase

from guidellm.data.config import load_config
from guidellm.data.deserializers import (
    DatasetDeserializerFactory,
)
from guidellm.data.preprocessors import GenerativeColumnMapper
from guidellm.data.schemas import PreprocessDatasetConfig
from guidellm.utils import IntegerRangeSampler, check_load_processor
from guidellm.utils.hf_datasets import SUPPORTED_TYPES, save_dataset_to_file


class PromptTooShortError(Exception):
    pass


class ShortPromptStrategy(str, Enum):
    IGNORE = "ignore"
    CONCATENATE = "concatenate"
    PAD = "pad"
    ERROR = "error"


class ShortPromptStrategyHandler:
    """Handler class for short prompt strategies."""

    @staticmethod
    def handle_ignore(
        current_prompt: str,
        min_prompt_tokens: int,
        tokenizer: PreTrainedTokenizerBase,
        **_kwargs,
    ) -> str | None:
        """
        Ignores prompts that are shorter than the required minimum token length.

        :param current_prompt: The input prompt string.
        :param min_prompt_tokens: Minimum required token count.
        :param tokenizer: Tokenizer used to count tokens.
        :return: The prompt if it meets the length, otherwise None.
        """

        if len(tokenizer.encode(current_prompt)) < min_prompt_tokens:
            logger.warning("Prompt too short, ignoring")
            return None
        return current_prompt

    @staticmethod
    def handle_concatenate(
        current_prompt: str,
        min_prompt_tokens: int,
        dataset_iterator: Iterator[dict[str, Any]],
        prompt_column: str,
        tokenizer: PreTrainedTokenizerBase,
        concat_delimiter: str,
        **_kwargs,
    ) -> str | None:
        """
        Concatenates prompts until the minimum token requirement is met.

        :param current_prompt: The initial prompt.
        :param min_prompt_tokens: Target minimum token length.
        :param dataset_iterator: Iterator to fetch more prompts.
        :param prompt_column: Column key for prompt extraction.
        :param tokenizer: Tokenizer used to count tokens.
        :param concat_delimiter: Delimiter to use between prompts.
        :return: Concatenated prompt or None if not enough data.
        """

        tokens_len = len(tokenizer.encode(current_prompt))
        while tokens_len < min_prompt_tokens:
            try:
                next_row = next(dataset_iterator)
            except StopIteration:
                logger.warning(
                    "Could not concatenate enough prompts to reach minimum "
                    "length, ignoring"
                )
                return None
            current_prompt += concat_delimiter + next_row[prompt_column]
            tokens_len = len(tokenizer.encode(current_prompt))
        return current_prompt

    @staticmethod
    def handle_pad(
        current_prompt: str,
        min_prompt_tokens: int,
        tokenizer: PreTrainedTokenizerBase,
        pad_char: str,
        pad_multiplier: int = 2,
        **_kwargs,
    ) -> str:
        """
        Pads the prompt with a character until it reaches the minimum token length.

        :param current_prompt: The input prompt.
        :param min_prompt_tokens: Desired minimum token count.
        :param tokenizer: Tokenizer used to count tokens.
        :param pad_char: Character used for padding.
        :param pad_multiplier: Multiplier for padding character length.
        :return: Padded prompt string.
        """
        tokens = tokenizer.encode(current_prompt)
        pad_count = 1
        prompt = current_prompt
        while len(tokens) < min_prompt_tokens:
            prompt += pad_char * pad_count
            tokens = tokenizer.encode(prompt)
            pad_count *= pad_multiplier
        return prompt

    @staticmethod
    def handle_error(
        current_prompt: str,
        min_prompt_tokens: int,
        tokenizer: PreTrainedTokenizerBase,
        **_kwargs,
    ) -> str | None:
        """
        Raises an error if the prompt is too short.

        :param current_prompt: The input prompt.
        :param min_prompt_tokens: Required token count.
        :param tokenizer: Tokenizer used to count tokens.
        :return: The input prompt if valid.
        :raises PromptTooShortError: If the prompt is too short.
        """

        prompt_len = len(tokenizer.encode(current_prompt))
        if prompt_len < min_prompt_tokens:
            raise PromptTooShortError(
                f"Found too short prompt: {current_prompt}, with length: {prompt_len}. "
                f"Minimum length required: {min_prompt_tokens}.",
            )
        return current_prompt

    @classmethod
    def get_strategy_handler(cls, strategy: ShortPromptStrategy) -> Callable[..., Any]:
        """
        Get the handler for a specific strategy.

        :param strategy: The short prompt strategy to get the handler for.
        :return: The handler callable for the specified strategy.
        """
        return cast("Callable[..., Any]", STRATEGY_HANDLERS[strategy])


# Initialize STRATEGY_HANDLERS after class definition to allow method references
STRATEGY_HANDLERS = {
    ShortPromptStrategy.IGNORE: ShortPromptStrategyHandler.handle_ignore,
    ShortPromptStrategy.CONCATENATE: ShortPromptStrategyHandler.handle_concatenate,
    ShortPromptStrategy.PAD: ShortPromptStrategyHandler.handle_pad,
    ShortPromptStrategy.ERROR: ShortPromptStrategyHandler.handle_error,
}


def _validate_output_suffix(output_path: str | Path) -> None:
    output_path = Path(output_path)
    suffix = output_path.suffix.lower()
    if suffix not in SUPPORTED_TYPES:
        raise ValueError(
            f"Unsupported file suffix '{suffix}' in output_path '{output_path}'. "
            f"Only {SUPPORTED_TYPES} are supported."
        )


def parse_synthetic_config(
    config_input: str | Path,
) -> PreprocessDatasetConfig:
    """
    Parse PreprocessDatasetConfig from string or file path.

    Reuses SyntheticTextDatasetDeserializer's parsing logic to support:
    - JSON strings
    - Key=value pairs
    - File paths (.json, .yaml, .yml, .config)

    :param config_input: String or path to config.
    :return: Parsed PreprocessDatasetConfig instance.
    :raises ValueError: If the format is not recognized or parsing fails.
    """
    config = load_config(config_input, PreprocessDatasetConfig)

    if config is not None:
        return config

    raise ValueError(
        f"Could not parse config from input: {config_input}. "
        "Expected JSON string, key=value pairs, or file path "
        "(.json, .yaml, .yml, .config)"
    )


def process_dataset(
    data: str | Path,
    output_path: str | Path,
    processor: str | Path | PreTrainedTokenizerBase,
    config: str | Path,
    processor_args: dict[str, Any] | None,
    data_args: dict[str, Any] | None,
    data_column_mapper: dict[str, str] | None,
    short_prompt_strategy: ShortPromptStrategy,
    pad_char: str | None,
    concat_delimiter: str | None,
    include_prefix_in_token_count: bool,
    push_to_hub: bool,
    hub_dataset_id: str | None,
    random_seed: int,
) -> None:
    """
    Main method to process and save a dataset with sampled prompt/output token counts.
    """
    _validate_output_suffix(output_path)
    logger.info(
        f"Starting dataset conversion | Input: {data} | Output: {output_path}"
    )

    # Parse config
    config_obj = parse_synthetic_config(config)

    # Load tokenizer
    tokenizer = check_load_processor(
        processor,
        processor_args,
        "dataset conversion.",
    )

    # Load dataset
    dataset = DatasetDeserializerFactory.deserialize(
        data=data,
        processor_factory=lambda: tokenizer,
        random_seed=random_seed,
        **(data_args or {}),
    )

    # Setup column mapper
    column_mapper = GenerativeColumnMapper(
        column_mappings=data_column_mapper  # type: ignore[arg-type]
    )
    column_mapper.setup_data(
        datasets=[dataset],
        data_args=[data_args or {}],
    )

    # Extract column names from mapper
    prompt_column, prefix_column, output_column = _extract_column_names(column_mapper)

    # Create token samplers
    prompt_token_sampler, output_token_sampler, prefix_tokens_max = (
        _create_token_samplers(
            config_obj,
            random_seed,
        )
    )

    # Process dataset
    dataset_iterator = iter(dataset)
    processed_prompts = []
    prompt_handler = ShortPromptStrategyHandler.get_strategy_handler(
        short_prompt_strategy
    )

    for row in dataset_iterator:
        processed_row = _process_single_row(
            row=row,
            prompt_column=prompt_column,
            prefix_column=prefix_column,
            prompt_token_sampler=prompt_token_sampler,
            output_token_sampler=output_token_sampler,
            tokenizer=tokenizer,
            prompt_handler=prompt_handler,
            dataset_iterator=dataset_iterator,
            include_prefix_in_token_count=include_prefix_in_token_count,
            pad_char=pad_char,
            concat_delimiter=concat_delimiter,
            output_column=output_column,
            prefix_tokens_max=prefix_tokens_max,
        )
        if processed_row is not None:
            processed_prompts.append(processed_row)

    # Finalize
    _finalize_processed_dataset(
        processed_prompts,
        output_path,
        push_to_hub,
        hub_dataset_id,
    )


def _extract_column_names(
    column_mapper: GenerativeColumnMapper,
) -> tuple[str, str | None, str]:
    """
    Extract column names for prompt, prefix, and output from column mapper.

    :param column_mapper: Initialized column mapper.
    :return: Tuple of (prompt_column, prefix_column, output_column).
    :raises ValueError: If column mapper is not properly initialized.
    """
    if column_mapper.datasets_column_mappings is None:
        raise ValueError("Column mapper not properly initialized")

    text_mappings = column_mapper.datasets_column_mappings.get("text_column", [])
    if not text_mappings:
        raise ValueError("Could not find text column in dataset")
    prompt_column = text_mappings[0][1]

    prefix_mappings = column_mapper.datasets_column_mappings.get("prefix_column", [])
    prefix_column = prefix_mappings[0][1] if prefix_mappings else None

    output_mappings = column_mapper.datasets_column_mappings.get(
        "output_tokens_count_column", []
    )
    output_column = (
        output_mappings[0][1] if output_mappings else "output_tokens_count"
    )

    return prompt_column, prefix_column, output_column


def _create_token_samplers(
    config_obj: PreprocessDatasetConfig,
    random_seed: int,
) -> tuple[Iterator[int], Iterator[int], int | None]:
    """
    Create token samplers for prompt, output, and prefix tokens.

    :param config_obj: Configuration object with token settings.
    :param prefix_tokens: Optional single prefix token count.
    :param random_seed: Seed for random sampling.
    :return: Tuple of (prompt_sampler, output_sampler, prefix_tokens_max).
        prefix_sampler is None when prefix_tokens is not provided.
        prefix_tokens_max is the maximum prefix token limit from config.
    """
    prompt_token_sampler = iter(
        IntegerRangeSampler(
            average=config_obj.prompt_tokens,
            variance=config_obj.prompt_tokens_stdev,
            min_value=config_obj.prompt_tokens_min,
            max_value=config_obj.prompt_tokens_max,
            random_seed=random_seed,
        )
    )

    output_token_sampler = iter(
        IntegerRangeSampler(
            average=config_obj.output_tokens,
            variance=config_obj.output_tokens_stdev,
            min_value=config_obj.output_tokens_min,
            max_value=config_obj.output_tokens_max,
            random_seed=random_seed,
        )
    )

    return prompt_token_sampler, output_token_sampler, config_obj.prefix_tokens_max


def _process_dataset_row(
    row: dict[str, Any],
    prompt_column: str,
    prefix_column: str | None,
    output_column: str,
    target_output_len: int,
    prompt_text: str,
    prefix_text: str | None,
    tokens: list[int],
) -> dict[str, Any]:
    """
    Create a processed row from the processed prompt/prefix data.

    :param row: Original dataset row.
    :param prompt_column: Name of prompt column.
    :param prefix_column: Name of prefix column or None.
    :param output_column: Name of output tokens count column.
    :param target_prompt_len: Target prompt token length.
    :param target_output_len: Target output token length.
    :param prompt_text: Processed prompt text.
    :param prefix_text: Processed prefix text or None.
    :param tokens: Tokenized prompt.
    :return: Processed row dictionary.
    """
    processed_row = row.copy()
    processed_row[prompt_column] = prompt_text
    if prefix_column and prefix_text:
        processed_row[prefix_column] = prefix_text
    processed_row["prompt_tokens_count"] = len(tokens)
    processed_row[output_column] = target_output_len
    return processed_row


def _process_single_row(
    row: dict[str, Any],
    prompt_column: str,
    prefix_column: str | None,
    prompt_token_sampler: Iterator[int],
    output_token_sampler: Iterator[int],
    tokenizer: PreTrainedTokenizerBase,
    prompt_handler: Callable,
    dataset_iterator: Iterator[dict[str, Any]],
    include_prefix_in_token_count: bool,
    pad_char: str | None,
    concat_delimiter: str | None,
    output_column: str,
    prefix_tokens_max: int | None,
) -> dict[str, Any] | None:
    """
    Process a single row from the dataset.

    :param include_prefix_in_token_count: When True, includes prefix tokens in the
        prompt token count calculation. When False, prefix tokens are not counted
        toward prompt tokens.
    :param prefix_tokens_max: Maximum prefix token limit. If set, the prefix will be
        trimmed if it exceeds this limit.
    :return: Processed row dictionary or None if row should be skipped.
    """
    # Extract prompt and prefix
    prompt_text = row.get(prompt_column, "")
    prefix_text = row.get(prefix_column) if prefix_column else None

    # Sample target prompt token count
    target_prompt_len = next(prompt_token_sampler)
    count_adjustment = 0

    # Handle prefix
    if prefix_text:
        # Apply prefix_tokens_max limit if set (strict maximum)
        if prefix_tokens_max is not None:
            prefix_tokens_list = tokenizer.encode(prefix_text)
            if len(prefix_tokens_list) > prefix_tokens_max:
                prefix_text = tokenizer.decode(
                    prefix_tokens_list[:prefix_tokens_max]
                )

        # Count prefix tokens toward prompt if enabled
        if include_prefix_in_token_count:
            count_adjustment = len(tokenizer.encode(prefix_text))

    if target_prompt_len == 0:
        logger.warning("zero prompt size requested; skipping row")
        return None
    elif count_adjustment > 0:
        adjusted_prompt_len = target_prompt_len - count_adjustment
        if adjusted_prompt_len <= 0:
            logger.warning("The prefix exceeds target output length with "
                "--include-prefix-in-token-count enabled; Using prompt size"
                "of 1; skipping row")
            return None
        target_prompt_len = adjusted_prompt_len

    # Handle short prompts
    prompt_text = prompt_handler(
        current_prompt=prompt_text,
        min_prompt_tokens=target_prompt_len,
        dataset_iterator=dataset_iterator,
        prompt_column=prompt_column,
        tokenizer=tokenizer,
        pad_char=pad_char,
        concat_delimiter=concat_delimiter,
    )
    if prompt_text is None:
        return None

    # Trim long prompts
    tokens = tokenizer.encode(prompt_text)
    if len(tokens) > target_prompt_len:
        prompt_text = tokenizer.decode(tokens[:target_prompt_len])
        tokens = tokenizer.encode(prompt_text)

    # Sample output token count
    target_output_len = next(output_token_sampler)

    # Create processed row
    return _process_dataset_row(
        row=row,
        prompt_column=prompt_column,
        prefix_column=prefix_column,
        output_column=output_column,
        target_output_len=target_output_len,
        prompt_text=prompt_text,
        prefix_text=prefix_text,
        tokens=tokens,
    )


def _finalize_processed_dataset(
    processed_prompts: list[dict[str, Any]],
    output_path: str | Path,
    push_to_hub: bool,
    hub_dataset_id: str | None,
) -> None:
    """
    Finalize the processed dataset by saving and optionally pushing to hub.

    :param processed_prompts: List of processed row dictionaries.
    :param output_path: Path to save the dataset.
    :param push_to_hub: Whether to push to Hugging Face Hub.
    :param hub_dataset_id: Dataset ID on Hugging Face Hub.
    """
    if not processed_prompts:
        logger.error("No prompts remained after processing")
        return

    logger.info(f"Generated processed dataset with {len(processed_prompts)} prompts")

    processed_dataset = Dataset.from_list(processed_prompts)
    save_dataset_to_file(processed_dataset, output_path)
    logger.info(f"Conversion completed. Dataset saved to: {output_path}")

    if push_to_hub:
        push_dataset_to_hub(hub_dataset_id, processed_dataset)
        logger.info(f"Pushed dataset to: {hub_dataset_id}")


def push_dataset_to_hub(
    hub_dataset_id: str | None,
    processed_dataset: Dataset,
) -> None:
    """
    Pushes the processed dataset to Hugging Face Hub using HF_TOKEN.

    :param hub_dataset_id: Identifier on the Hub to push to.
    :param processed_dataset: HuggingFace Dataset object.
    :raises ValueError: If hub_dataset_id or HF_TOKEN is not available.
    """

    hf_token = os.environ.get("HF_TOKEN")
    if not hub_dataset_id or not hf_token:
        raise ValueError(
            "hub_dataset_id and HF_TOKEN env var must be provided when push_to_hub"
            " is True"
        )
    processed_dataset.push_to_hub(hub_dataset_id, token=hf_token)
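A minimal usage sketch for the short-prompt handlers above (illustrative only: the tokenizer name, token count, and pad string are placeholders rather than values from this diff, and it assumes the module above is importable):

    from transformers import AutoTokenizer

    # Placeholder tokenizer; any PreTrainedTokenizerBase works here.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    handler = ShortPromptStrategyHandler.get_strategy_handler(ShortPromptStrategy.PAD)
    padded = handler(
        current_prompt="Hello",
        min_prompt_tokens=16,
        tokenizer=tokenizer,
        pad_char=" x",
    )
    # handle_pad keeps appending pad_char until the encoded length reaches the minimum.
    assert len(tokenizer.encode(padded)) >= 16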
@@ -0,0 +1,16 @@
from __future__ import annotations

from guidellm.schemas import GenerationRequest

__all__ = ["GenerativeRequestCollator"]


class GenerativeRequestCollator:
    def __call__(self, batch: list) -> GenerationRequest:
        if len(batch) != 1:
            raise NotImplementedError(
                f"Batch size greater than 1 is not currently supported. "
                f"Got batch size: {len(batch)}"
            )

        return batch[0]
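The collator above only unwraps single-item batches; a sketch of that behavior (illustrative only, using a plain object since the collator never inspects the item):

    collator = GenerativeRequestCollator()
    item = object()  # stands in for a GenerationRequest
    assert collator([item]) is item  # single-item batches pass through unchanged
    # collator([item, item]) raises NotImplementedError: batch sizes other than 1 are rejected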
guidellm/data/config.py ADDED
@@ -0,0 +1,120 @@
from __future__ import annotations

from pathlib import Path
from typing import Any, TypeVar

import yaml
from pydantic import ValidationError

from guidellm.data.schemas import DataConfig, DataNotSupportedError

ConfigT = TypeVar("ConfigT", bound=DataConfig)


def load_config(config: Any, config_class: type[ConfigT]) -> ConfigT | None:
    # Try file path first
    if (loaded_config := _load_config_file(config, config_class)) is not None:
        return loaded_config

    # Try dict parsing next
    if (loaded_config := _load_config_dict(config, config_class)) is not None:
        return loaded_config

    # Try string parsing
    if (loaded_config := _load_config_str(config, config_class)) is not None:
        return loaded_config

    return None


def _load_config_dict(data: Any, config_class: type[ConfigT]) -> ConfigT | None:
    if not isinstance(data, dict | list):
        return None

    try:
        return config_class.model_validate(data)
    except ValidationError:
        return None


def _load_config_file(data: Any, config_class: type[ConfigT]) -> ConfigT | None:
    if (not isinstance(data, str) and not isinstance(data, Path)) or (
        not Path(data).is_file()
    ):
        return None

    data_path = Path(data) if isinstance(data, str) else data
    error = None

    if Path(data).is_file() and data_path.suffix.lower() == ".json":
        try:
            return config_class.model_validate_json(
                data_path.read_text()
            )
        except Exception as err:  # noqa: BLE001
            error = err

    if Path(data).is_file() and data_path.suffix.lower() in {
        ".yaml",
        ".yml",
        ".config",
    }:
        try:
            return config_class.model_validate(
                yaml.safe_load(data_path.read_text())
            )
        except Exception as err:  # noqa: BLE001
            error = err

    err_message = (
        f"Unsupported file {data_path} for "
        f"{config_class.__name__}, expected .json, "
        f".yaml, .yml, or .config"
    )

    if error is not None:
        err_message += f" with error: {error}"
        raise DataNotSupportedError(err_message) from error
    raise DataNotSupportedError(err_message)


def _load_config_str(data: str, config_class: type[ConfigT]) -> ConfigT | None:
    if not isinstance(data, str):
        return None

    data_str = data.strip()
    error = None

    if (data_str.startswith("{") and data_str.endswith("}")) or (
        data_str.startswith("[") and data_str.endswith("]")
    ):
        try:
            return config_class.model_validate_json(data_str)
        except Exception as err:  # noqa: BLE001
            error = err

    if data_str.count("=") > 1:
        # key=value pairs separated by commas
        try:
            config_dict = {}
            items = data_str.split(",")
            for item in items:
                key, value = item.split("=")
                config_dict[key.strip()] = (
                    int(value.strip())
                    if value.strip().isnumeric()
                    else value.strip()
                )

            return config_class.model_validate(config_dict)
        except Exception as err:  # noqa: BLE001
            error = err

    err_message = (
        f"Unsupported string data for {config_class.__name__}, "
        f"expected JSON or key-value pairs, got {data}"
    )
    if error is not None:
        err_message += f" with error: {error}"
        raise DataNotSupportedError(err_message) from error
    raise DataNotSupportedError(err_message)
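A sketch of how load_config above resolves its three input forms (illustrative only: the field names come from the preprocessing code earlier in this diff, and whether these particular values satisfy PreprocessDatasetConfig's validators depends on the schema in guidellm.data.schemas):

    from guidellm.data.config import load_config
    from guidellm.data.schemas import PreprocessDatasetConfig

    # Key=value string (needs at least two '=' to hit the key=value branch)
    cfg = load_config("prompt_tokens=256,output_tokens=128", PreprocessDatasetConfig)
    # Dict or list input is validated directly
    cfg = load_config({"prompt_tokens": 256, "output_tokens": 128}, PreprocessDatasetConfig)
    # Existing file paths ending in .json, .yaml, .yml, or .config are read from disk
    cfg = load_config("synthetic.yaml", PreprocessDatasetConfig)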