guidellm 0.3.0a15__py3-none-any.whl → 0.3.0a18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- guidellm/__main__.py +150 -0
- guidellm/preprocess/__init__.py +3 -0
- guidellm/preprocess/dataset.py +374 -0
- guidellm/utils/__init__.py +6 -0
- guidellm/utils/hf_datasets.py +36 -0
- {guidellm-0.3.0a15.dist-info → guidellm-0.3.0a18.dist-info}/METADATA +1 -1
- {guidellm-0.3.0a15.dist-info → guidellm-0.3.0a18.dist-info}/RECORD +11 -8
- {guidellm-0.3.0a15.dist-info → guidellm-0.3.0a18.dist-info}/WHEEL +0 -0
- {guidellm-0.3.0a15.dist-info → guidellm-0.3.0a18.dist-info}/entry_points.txt +0 -0
- {guidellm-0.3.0a15.dist-info → guidellm-0.3.0a18.dist-info}/licenses/LICENSE +0 -0
- {guidellm-0.3.0a15.dist-info → guidellm-0.3.0a18.dist-info}/top_level.txt +0 -0
guidellm/__main__.py CHANGED

@@ -1,4 +1,5 @@
 import asyncio
+import codecs
 import json
 from pathlib import Path
 from typing import get_args
@@ -8,6 +9,7 @@ import click
 from guidellm.backend import BackendType
 from guidellm.benchmark import ProfileType, benchmark_generative_text
 from guidellm.config import print_config
+from guidellm.preprocess.dataset import ShortPromptStrategy, process_dataset
 from guidellm.scheduler import StrategyType
 
 STRATEGY_PROFILE_CHOICES = set(
@@ -280,6 +282,20 @@ def benchmark(
 )
 
 
+def decode_escaped_str(_ctx, _param, value):
+    """
+    Click auto adds characters. For example, when using --pad-char "\n",
+    it parses it as "\\n". This method decodes the string to handle escape
+    sequences correctly.
+    """
+    if value is None:
+        return None
+    try:
+        return codecs.decode(value, "unicode_escape")
+    except Exception as e:
+        raise click.BadParameter(f"Could not decode escape sequences: {e}") from e
+
+
 @cli.command(
     help=(
         "Print out the available configuration settings that can be set "
@@ -290,5 +306,139 @@ def config():
     print_config()
 
 
+@cli.group(help="General preprocessing tools and utilities.")
+def preprocess():
+    pass
+
+
+@preprocess.command(
+    help=(
+        "Convert a dataset to have specific prompt and output token sizes.\n"
+        "DATA: Path to the input dataset or dataset ID.\n"
+        "OUTPUT_PATH: Path to save the converted dataset, including file suffix."
+    )
+)
+@click.argument(
+    "data",
+    type=str,
+    required=True,
+)
+@click.argument(
+    "output_path",
+    type=click.Path(file_okay=True, dir_okay=False, writable=True, resolve_path=True),
+    required=True,
+)
+@click.option(
+    "--processor",
+    type=str,
+    required=True,
+    help=(
+        "The processor or tokenizer to use to calculate token counts for statistics "
+        "and synthetic data generation."
+    ),
+)
+@click.option(
+    "--processor-args",
+    default=None,
+    callback=parse_json,
+    help=(
+        "A JSON string containing any arguments to pass to the processor constructor "
+        "as a dict with **kwargs."
+    ),
+)
+@click.option(
+    "--data-args",
+    callback=parse_json,
+    help=(
+        "A JSON string containing any arguments to pass to the dataset creation "
+        "as a dict with **kwargs."
+    ),
+)
+@click.option(
+    "--short-prompt-strategy",
+    type=click.Choice([s.value for s in ShortPromptStrategy]),
+    default=ShortPromptStrategy.IGNORE.value,
+    show_default=True,
+    help="Strategy to handle prompts shorter than the target length. ",
+)
+@click.option(
+    "--pad-char",
+    type=str,
+    default="",
+    callback=decode_escaped_str,
+    help="The token to pad short prompts with when using the 'pad' strategy.",
+)
+@click.option(
+    "--concat-delimiter",
+    type=str,
+    default="",
+    help=(
+        "The delimiter to use when concatenating prompts that are too short."
+        " Used when strategy is 'concatenate'."
+    ),
+)
+@click.option(
+    "--prompt-tokens",
+    type=str,
+    default=None,
+    help="Prompt tokens config (JSON, YAML file or key=value string)",
+)
+@click.option(
+    "--output-tokens",
+    type=str,
+    default=None,
+    help="Output tokens config (JSON, YAML file or key=value string)",
+)
+@click.option(
+    "--push-to-hub",
+    is_flag=True,
+    help="Set this flag to push the converted dataset to the Hugging Face Hub.",
+)
+@click.option(
+    "--hub-dataset-id",
+    type=str,
+    default=None,
+    help="The Hugging Face Hub dataset ID to push to. "
+    "Required if --push-to-hub is used.",
+)
+@click.option(
+    "--random-seed",
+    type=int,
+    default=42,
+    show_default=True,
+    help="Random seed for prompt token sampling and output tokens sampling.",
+)
+def dataset(
+    data,
+    output_path,
+    processor,
+    processor_args,
+    data_args,
+    short_prompt_strategy,
+    pad_char,
+    concat_delimiter,
+    prompt_tokens,
+    output_tokens,
+    push_to_hub,
+    hub_dataset_id,
+    random_seed,
+):
+    process_dataset(
+        data=data,
+        output_path=output_path,
+        processor=processor,
+        prompt_tokens=prompt_tokens,
+        output_tokens=output_tokens,
+        processor_args=processor_args,
+        data_args=data_args,
+        short_prompt_strategy=short_prompt_strategy,
+        pad_char=pad_char,
+        concat_delimiter=concat_delimiter,
+        push_to_hub=push_to_hub,
+        hub_dataset_id=hub_dataset_id,
+        random_seed=random_seed,
+    )
+
+
 if __name__ == "__main__":
     cli()
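The new decode_escaped_str callback compensates for Click re-escaping backslash sequences typed on the command line. A minimal sketch of the behavior it relies on, using only the standard library (no Click required):

import codecs

# Click hands the callback the two-character string "\\n" when a user types
# --pad-char "\n"; unicode_escape decoding recovers the actual newline.
assert codecs.decode("\\n", "unicode_escape") == "\n"
assert codecs.decode("\\t", "unicode_escape") == "\t"

This is why --pad-char is the one option in the new command wired to decode_escaped_str rather than passed through verbatim.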
guidellm/preprocess/dataset.py ADDED

@@ -0,0 +1,374 @@
+import json
+import os
+from collections.abc import Iterator
+from enum import Enum
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+import yaml
+from datasets import Dataset
+from loguru import logger
+from pydantic import BaseModel, Field
+from transformers import PreTrainedTokenizerBase
+
+from guidellm.dataset import load_dataset as guidellm_load_dataset
+from guidellm.utils import IntegerRangeSampler, check_load_processor
+from guidellm.utils.hf_datasets import SUPPORTED_TYPES, save_dataset_to_file
+
+
+class PromptTooShortError(Exception):
+    pass
+
+
+class ShortPromptStrategy(str, Enum):
+    IGNORE = "ignore"
+    CONCATENATE = "concatenate"
+    PAD = "pad"
+    ERROR = "error"
+
+
+def handle_ignore_strategy(
+    current_prompt: str,
+    min_prompt_tokens: int,
+    tokenizer: PreTrainedTokenizerBase,
+    **_kwargs,
+) -> Optional[str]:
+    """
+    Ignores prompts that are shorter than the required minimum token length.
+
+    :param current_prompt: The input prompt string.
+    :param min_prompt_tokens: Minimum required token count.
+    :param tokenizer: Tokenizer used to count tokens.
+    :return: The prompt if it meets the length, otherwise None.
+    """
+
+    if len(tokenizer.encode(current_prompt)) < min_prompt_tokens:
+        logger.warning("Prompt too short, ignoring")
+        return None
+    return current_prompt
+
+
+def handle_concatenate_strategy(
+    current_prompt: str,
+    min_prompt_tokens: int,
+    dataset_iterator: Iterator[dict[str, Any]],
+    prompt_column: str,
+    tokenizer: PreTrainedTokenizerBase,
+    concat_delimiter: str,
+    **_kwargs,
+) -> Optional[str]:
+    """
+    Concatenates prompts until the minimum token requirement is met.
+
+    :param current_prompt: The initial prompt.
+    :param min_prompt_tokens: Target minimum token length.
+    :param dataset_iterator: Iterator to fetch more prompts.
+    :param prompt_column: Column key for prompt extraction.
+    :param tokenizer: Tokenizer used to count tokens.
+    :param concat_delimiter: Delimiter to use between prompts.
+    :return: Concatenated prompt or None if not enough data.
+    """
+
+    tokens_len = len(tokenizer.encode(current_prompt))
+    while tokens_len < min_prompt_tokens:
+        try:
+            next_row = next(dataset_iterator)
+        except StopIteration:
+            logger.warning(
+                "Could not concatenate enough prompts to reach minimum length, ignoring"
+            )
+            return None
+        current_prompt += concat_delimiter + next_row[prompt_column]
+        tokens_len = len(tokenizer.encode(current_prompt))
+    return current_prompt
+
+
+def handle_pad_strategy(
+    current_prompt: str,
+    min_prompt_tokens: int,
+    tokenizer: PreTrainedTokenizerBase,
+    pad_char: str,
+    pad_multiplier: int = 2,
+    **_kwargs,
+) -> str:
+    """
+    Pads the prompt with a character until it reaches the minimum token length.
+
+    :param current_prompt: The input prompt.
+    :param min_prompt_tokens: Desired minimum token count.
+    :param tokenizer: Tokenizer used to count tokens.
+    :param pad_char: Character used for padding.
+    :param pad_multiplier: Multiplier for padding character length.
+    :return: Padded prompt string.
+    """
+
+    tokens = tokenizer.encode(current_prompt)
+    pad_count = 1
+    prompt = current_prompt
+    while len(tokens) < min_prompt_tokens:
+        prompt += pad_char * pad_count
+        tokens = tokenizer.encode(prompt)
+        pad_count *= pad_multiplier
+    return prompt
+
+
+def handle_error_strategy(
+    current_prompt: str,
+    min_prompt_tokens: int,
+    tokenizer: PreTrainedTokenizerBase,
+    **_kwargs,
+) -> Optional[str]:
+    """
+    Raises an error if the prompt is too short.
+
+    :param current_prompt: The input prompt.
+    :param min_prompt_tokens: Required token count.
+    :param tokenizer: Tokenizer used to count tokens.
+    :return: The input prompt if valid.
+    :raises PromptTooShortError: If the prompt is too short.
+    """
+
+    prompt_len = len(tokenizer.encode(current_prompt))
+    if prompt_len < min_prompt_tokens:
+        raise PromptTooShortError(
+            f"Found too short prompt: {current_prompt}, with length: {prompt_len}. "
+            f"Minimum length required: {min_prompt_tokens}.",
+        )
+    return current_prompt
+
+
+STRATEGY_HANDLERS: dict[ShortPromptStrategy, Callable] = {
+    ShortPromptStrategy.IGNORE: handle_ignore_strategy,
+    ShortPromptStrategy.CONCATENATE: handle_concatenate_strategy,
+    ShortPromptStrategy.PAD: handle_pad_strategy,
+    ShortPromptStrategy.ERROR: handle_error_strategy,
+}
+
+
+class TokensConfig(BaseModel):
+    average: int = Field(
+        description="The average number of tokens.",
+        gt=0,
+    )
+    stdev: Optional[int] = Field(
+        description="The standard deviation of the tokens.",
+        gt=0,
+        default=None,
+    )
+    min: Optional[int] = Field(
+        description="The minimum number of tokens.",
+        gt=0,
+        default=None,
+    )
+    max: Optional[int] = Field(
+        description="The maximum number of tokens.",
+        gt=0,
+        default=None,
+    )
+
+    @staticmethod
+    def parse_str(data: Union[str, Path]) -> "TokensConfig":
+        """
+        Parses a string or path into a TokensConfig object. Supports:
+        - JSON string
+        - key=value pairs
+        - file path to .yaml/.config
+
+        :param data: String or path containing configuration.
+        :return: Parsed TokensConfig instance.
+        :raises ValueError: If the format is not recognized.
+        """
+
+        if (
+            isinstance(data, Path)
+            or data.strip().endswith(".config")
+            or data.strip().endswith(".yaml")
+        ):
+            return TokensConfig.parse_config_file(data)
+
+        if data.strip().startswith("{"):
+            return TokensConfig.parse_json(data)
+
+        if data.count("=") > 1:
+            return TokensConfig.parse_key_value_pairs(data)
+
+        raise ValueError(
+            f"Unsupported data format. Expected JSON or key-value pairs, got {data}"
+        )
+
+    @staticmethod
+    def parse_json(data: str) -> "TokensConfig":
+        config_dict = json.loads(data.strip())
+
+        return TokensConfig(**config_dict)
+
+    @staticmethod
+    def parse_key_value_pairs(data: str) -> "TokensConfig":
+        config_dict = {}
+        items = data.strip().split(",")
+        for item in items:
+            key, value = item.split("=")
+            config_dict[key.strip()] = (
+                int(value.strip()) if value.strip().isnumeric() else value.strip()
+            )
+
+        return TokensConfig(**config_dict)  # type: ignore[arg-type]
+
+    @staticmethod
+    def parse_config_file(data: Union[str, Path]) -> "TokensConfig":
+        with Path(data).open("r") as file:
+            config_dict = yaml.safe_load(file)
+
+        return TokensConfig(**config_dict)
+
+
+def _validate_output_suffix(output_path: Union[str, Path]) -> None:
+    output_path = Path(output_path)
+    suffix = output_path.suffix.lower()
+    if suffix not in SUPPORTED_TYPES:
+        raise ValueError(
+            f"Unsupported file suffix '{suffix}' in output_path '{output_path}'. "
+            f"Only {SUPPORTED_TYPES} are supported."
+        )
+
+
+def process_dataset(
+    data: Union[str, Path],
+    output_path: Union[str, Path],
+    processor: Union[str, Path, PreTrainedTokenizerBase],
+    prompt_tokens: Union[str, Path],
+    output_tokens: Union[str, Path],
+    processor_args: Optional[dict[str, Any]] = None,
+    data_args: Optional[dict[str, Any]] = None,
+    short_prompt_strategy: ShortPromptStrategy = ShortPromptStrategy.IGNORE,
+    pad_char: Optional[str] = None,
+    concat_delimiter: Optional[str] = None,
+    push_to_hub: bool = False,
+    hub_dataset_id: Optional[str] = None,
+    random_seed: int = 42,
+) -> None:
+    """
+    Main method to process and save a dataset with sampled prompt/output token counts.
+
+    :param data: Path or identifier for dataset input.
+    :param output_path: File path to save the processed dataset.
+    :param processor: Tokenizer object or its config.
+    :param prompt_tokens: Prompt token config string or file.
+    :param output_tokens: Output token config string or file.
+    :param processor_args: Optional processor arguments.
+    :param data_args: Optional data loading arguments.
+    :param short_prompt_strategy: Strategy for handling short prompts.
+    :param pad_char: Character used when padding short prompts.
+    :param concat_delimiter: Delimiter for concatenation strategy.
+    :param push_to_hub: Whether to push to Hugging Face Hub.
+    :param hub_dataset_id: Dataset ID on Hugging Face Hub.
+    :param random_seed: Seed for random sampling.
+    :raises ValueError: If output path is invalid or pushing conditions unmet.
+    """
+
+    _validate_output_suffix(output_path)
+    logger.info(
+        f"Starting dataset conversion | Input: {data} | Output directory: {output_path}"
+    )
+
+    dataset, column_mappings = guidellm_load_dataset(
+        data, data_args, processor, processor_args
+    )
+    tokenizer = check_load_processor(
+        processor,
+        processor_args,
+        "dataset conversion.",
+    )
+    prompt_column = column_mappings.get("prompt_column")
+    output_column = column_mappings.get(
+        "output_tokens_count_column", "output_tokens_count"
+    )
+
+    prompt_tokens_cfg = TokensConfig.parse_str(prompt_tokens)
+    output_tokens_cfg = TokensConfig.parse_str(output_tokens)
+
+    prompt_token_sampler = iter(
+        IntegerRangeSampler(
+            average=prompt_tokens_cfg.average,
+            variance=prompt_tokens_cfg.stdev,
+            min_value=prompt_tokens_cfg.min,
+            max_value=prompt_tokens_cfg.max,
+            random_seed=random_seed,
+        )
+    )
+
+    output_token_sampler = iter(
+        IntegerRangeSampler(
+            average=output_tokens_cfg.average,
+            variance=output_tokens_cfg.stdev,
+            min_value=output_tokens_cfg.min,
+            max_value=output_tokens_cfg.max,
+            random_seed=random_seed,
+        )
+    )
+
+    dataset_iterator = iter(dataset)
+    processed_prompts = []
+    prompt_handler = STRATEGY_HANDLERS[short_prompt_strategy]
+
+    for prompt_row in dataset_iterator:
+        prompt_text = prompt_row[prompt_column]
+        target_prompt_len = next(prompt_token_sampler)
+
+        prompt_text = prompt_handler(
+            current_prompt=prompt_text,
+            min_prompt_tokens=target_prompt_len,
+            dataset_iterator=dataset_iterator,
+            prompt_column=prompt_column,
+            tokenizer=tokenizer,
+            pad_char=pad_char,
+            concat_delimiter=concat_delimiter,
+        )
+        if prompt_text is None:
+            continue
+
+        tokens = tokenizer.encode(prompt_text)
+        if len(tokens) > target_prompt_len:
+            prompt_text = tokenizer.decode(tokens[:target_prompt_len])
+
+        processed_prompt = prompt_row.copy()
+        processed_prompt[prompt_column] = prompt_text
+        processed_prompt["prompt_tokens_count"] = target_prompt_len
+        processed_prompt[output_column] = next(output_token_sampler)
+
+        processed_prompts.append(processed_prompt)
+
+    if not processed_prompts:
+        logger.error("No prompts remained after processing")
+        return
+
+    logger.info(f"Generated processed dataset with {len(processed_prompts)} prompts")
+
+    processed_dataset = Dataset.from_list(processed_prompts)
+    save_dataset_to_file(processed_dataset, output_path)
+    logger.info(f"Conversion completed. Dataset saved to: {output_path}")
+
+    if push_to_hub:
+        push_dataset_to_hub(hub_dataset_id, processed_dataset)
+        logger.info(f"Pushed dataset to: {hub_dataset_id}")
+
+
+def push_dataset_to_hub(
+    hub_dataset_id: Optional[str],
+    processed_dataset: Dataset,
+) -> None:
+    """
+    Pushes the processed dataset to Hugging Face Hub using HF_TOKEN.
+
+    :param hub_dataset_id: Identifier on the Hub to push to.
+    :param processed_dataset: HuggingFace Dataset object.
+    :raises ValueError: If hub_dataset_id or HF_TOKEN is not available.
+    """
+
+    hf_token = os.environ.get("HF_TOKEN")
+    if not hub_dataset_id or not hf_token:
+        raise ValueError(
+            "hub_dataset_id and HF_TOKEN env var must be provided when push_to_hub"
+            " is True"
+        )
+    processed_dataset.push_to_hub(hub_dataset_id, token=hf_token)
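TokensConfig.parse_str dispatches on the shape of its input: Path objects and strings ending in .yaml or .config go to the YAML loader, strings starting with "{" go to the JSON parser, and strings containing more than one "=" are split as key=value pairs. A small sketch of the two inline forms, assuming the module imports as shipped:

from guidellm.preprocess.dataset import TokensConfig

# Both inputs should yield equal configs; key=value values that look
# numeric are cast to int before pydantic validation.
cfg_json = TokensConfig.parse_str('{"average": 128, "stdev": 32, "min": 64, "max": 256}')
cfg_kv = TokensConfig.parse_str("average=128,stdev=32,min=64,max=256")
assert cfg_json == cfg_kv

One caveat visible in the code: the key=value branch requires data.count("=") > 1, so a single pair such as "average=128" falls through to the trailing ValueError.

Putting it together, a hedged sketch of driving process_dataset from Python (the input file and tokenizer ID below are illustrative assumptions, not taken from the diff):

from guidellm.preprocess.dataset import ShortPromptStrategy, process_dataset

# Resample prompts from a local JSONL file toward sampled target lengths,
# concatenating rows that come up short, and attach sampled output counts.
process_dataset(
    data="prompts.jsonl",  # hypothetical input file
    output_path="converted.jsonl",
    processor="gpt2",  # hypothetical tokenizer ID
    prompt_tokens="average=128,stdev=32,min=64,max=256",
    output_tokens="average=64,stdev=16,min=16,max=128",
    short_prompt_strategy=ShortPromptStrategy.CONCATENATE,
    concat_delimiter="\n",
)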
guidellm/utils/__init__.py CHANGED

@@ -1,4 +1,8 @@
 from .colors import Colors
+from .hf_datasets import (
+    SUPPORTED_TYPES,
+    save_dataset_to_file,
+)
 from .hf_transformers import (
     check_load_processor,
 )
@@ -14,6 +18,7 @@ from .text import (
 )
 
 __all__ = [
+    "SUPPORTED_TYPES",
     "Colors",
     "EndlessTextCreator",
     "IntegerRangeSampler",
@@ -22,6 +27,7 @@ __all__ = [
     "filter_text",
     "is_puncutation",
     "load_text",
+    "save_dataset_to_file",
     "split_text",
     "split_text_list_by_length",
 ]
guidellm/utils/hf_datasets.py ADDED

@@ -0,0 +1,36 @@
+from pathlib import Path
+from typing import Union
+
+from datasets import Dataset
+
+SUPPORTED_TYPES = {
+    ".json",
+    ".jsonl",
+    ".csv",
+    ".parquet",
+}
+
+
+def save_dataset_to_file(dataset: Dataset, output_path: Union[str, Path]) -> None:
+    """
+    Saves a HuggingFace Dataset to file in a supported format.
+
+    :param dataset: Dataset to save.
+    :param output_path: Output file path (.json, .jsonl, .csv, .parquet).
+    :raises ValueError: If the file extension is not supported.
+    """
+    output_path = Path(output_path)
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+    suffix = output_path.suffix.lower()
+
+    if suffix == ".csv":
+        dataset.to_csv(output_path)
+    elif suffix in {".json", ".jsonl"}:
+        dataset.to_json(output_path)
+    elif suffix == ".parquet":
+        dataset.to_parquet(output_path)
+    else:
+        raise ValueError(
+            f"Unsupported file suffix '{suffix}' in output_path'{output_path}'."
+            f" Only {SUPPORTED_TYPES} are supported."
+        )
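save_dataset_to_file is the writer process_dataset ends with, and it is also re-exported from guidellm.utils (see the __init__.py diff above). A quick usage sketch; the output path is an illustrative assumption:

from datasets import Dataset

from guidellm.utils import save_dataset_to_file

# The helper picks the writer from the suffix (.json/.jsonl -> to_json,
# .csv -> to_csv, .parquet -> to_parquet) and creates parent directories.
ds = Dataset.from_list([{"prompt": "hello", "output_tokens_count": 8}])
save_dataset_to_file(ds, "out/converted.jsonl")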
{guidellm-0.3.0a15.dist-info → guidellm-0.3.0a18.dist-info}/RECORD CHANGED

@@ -1,5 +1,5 @@
 guidellm/__init__.py,sha256=qXCx-HonNByJ2PDKqOUnD7CcbxA7cazNKyqKigwyuyE,1139
-guidellm/__main__.py,sha256=
+guidellm/__main__.py,sha256=QwExTHhseN2JjHoclUPI1H_yspshOS_mtZ84466CnPE,11751
 guidellm/config.py,sha256=-JuirSy1EDkQnXHfLKBV_PeQGFwi5nL088BKeCgh9Xo,6087
 guidellm/logger.py,sha256=O4sU2QKHn_swJIEmayiEt6nIXzGHGmXqZ_Mg8CdIE5Q,2609
 guidellm/version.py,sha256=XZeUwR24DzG1AjuKV0s8cRyc0Xv8cPiqXYSZTt7xQVg,127
@@ -27,6 +27,8 @@ guidellm/dataset/synthetic.py,sha256=PXYj06X9QuSJ98oGR7PmFVQYuqwacbq5OCnt45X2LTs
 guidellm/objects/__init__.py,sha256=FJg5REEBg9hSuMAYnGAud9IxBcASpZuDwYDYEmwdnTY,384
 guidellm/objects/pydantic.py,sha256=oF5SkXlNZLNXpnoGoJb3QHNvqDwFRLSsiPIecf8E-_g,1850
 guidellm/objects/statistics.py,sha256=xf0PyCnJjIcvgOyD_YHgqtvgN4W8-crGi9u3v5u78Tw,36950
+guidellm/preprocess/__init__.py,sha256=6mRs1atYwYkdX4txez_bEVk-_nCDsNt5Wo20eWZ24jA,112
+guidellm/preprocess/dataset.py,sha256=P-TJEebcWnUFGczImc7AOi90vJ4sakZtU8NSYWhJXlM,12318
 guidellm/request/__init__.py,sha256=Y7_O6lcOHwP-ld4VusCvHg_PfqMDyGPwjVHv_LesAL8,347
 guidellm/request/loader.py,sha256=TN3S8gOOnseuTEztSYKRFiUy-13Wl3L64h2PLyxZKWs,9075
 guidellm/request/request.py,sha256=BCf0Ua0h1oOj0y6HCyUh71r5tfFeS0fDrpaoU-UnMj4,3576
@@ -36,14 +38,15 @@ guidellm/scheduler/scheduler.py,sha256=xRrCkLtf_MeWB0iPYVBk1rhJUn-zb4vd2B7YkmnCs
 guidellm/scheduler/strategy.py,sha256=MjSQvyBUK9-JIdPOHVEuTdNzzuIhVIfBeu2_Z98A7ak,18629
 guidellm/scheduler/types.py,sha256=zHZ94-zEYo4LkU3qrfT3BRoZioicDMCQDiY8hYHnkfI,130
 guidellm/scheduler/worker.py,sha256=f1FjI9JJRbz39rTIJRFJ-drcdavJGuPlJC_iQrUa4N0,17629
-guidellm/utils/__init__.py,sha256=
+guidellm/utils/__init__.py,sha256=l1PZxQvk6gYQshUCEPbB8CU42eB2mHn1TVDEKGYIA5c,651
 guidellm/utils/colors.py,sha256=D0IGz8A346-Pt5qgnP3S5uV-VgngJoXbfToVCOna41k,175
+guidellm/utils/hf_datasets.py,sha256=C99cB4StbhjC8XtnzLLGe6A0TYrs63EapQZJQmQr8dI,1023
 guidellm/utils/hf_transformers.py,sha256=3iF40l02VEWOcS8kasO8TSws0Lp3cE-NyiqoB9GnHuA,1021
 guidellm/utils/random.py,sha256=elA8HZ3AIN5T2pa7cgq35OVK__0SQmZVS4IzxJaOpvw,1310
 guidellm/utils/text.py,sha256=Xn6JUWy3B7gi1l0UkBFLizV9fnZ_kM3OQDMLlIZqsgE,6347
-guidellm-0.3.
-guidellm-0.3.
-guidellm-0.3.
-guidellm-0.3.
-guidellm-0.3.
-guidellm-0.3.
+guidellm-0.3.0a18.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+guidellm-0.3.0a18.dist-info/METADATA,sha256=Pf5eKOw0o-2KFVxAxVhm-HLfJ1w16-ewrn7idJV1sL4,18061
+guidellm-0.3.0a18.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+guidellm-0.3.0a18.dist-info/entry_points.txt,sha256=DzLFEg47fF7qY1b-9laPz9jg0KSKJ1_D9TbF93kLz_E,51
+guidellm-0.3.0a18.dist-info/top_level.txt,sha256=EXRGjnvFtL6MeZTe0tnHRMYcEWUW3vEqoG2zO7vFOtk,9
+guidellm-0.3.0a18.dist-info/RECORD,,
{guidellm-0.3.0a15.dist-info → guidellm-0.3.0a18.dist-info}/WHEEL
File without changes

{guidellm-0.3.0a15.dist-info → guidellm-0.3.0a18.dist-info}/entry_points.txt
File without changes

{guidellm-0.3.0a15.dist-info → guidellm-0.3.0a18.dist-info}/licenses/LICENSE
File without changes

{guidellm-0.3.0a15.dist-info → guidellm-0.3.0a18.dist-info}/top_level.txt
File without changes