guidellm 0.3.1__py3-none-any.whl → 0.6.0a5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (141)
  1. guidellm/__init__.py +5 -2
  2. guidellm/__main__.py +524 -255
  3. guidellm/backends/__init__.py +33 -0
  4. guidellm/backends/backend.py +109 -0
  5. guidellm/backends/openai.py +340 -0
  6. guidellm/backends/response_handlers.py +428 -0
  7. guidellm/benchmark/__init__.py +69 -39
  8. guidellm/benchmark/benchmarker.py +160 -316
  9. guidellm/benchmark/entrypoints.py +560 -127
  10. guidellm/benchmark/outputs/__init__.py +24 -0
  11. guidellm/benchmark/outputs/console.py +633 -0
  12. guidellm/benchmark/outputs/csv.py +721 -0
  13. guidellm/benchmark/outputs/html.py +473 -0
  14. guidellm/benchmark/outputs/output.py +169 -0
  15. guidellm/benchmark/outputs/serialized.py +69 -0
  16. guidellm/benchmark/profiles.py +718 -0
  17. guidellm/benchmark/progress.py +553 -556
  18. guidellm/benchmark/scenarios/__init__.py +40 -0
  19. guidellm/benchmark/scenarios/chat.json +6 -0
  20. guidellm/benchmark/scenarios/rag.json +6 -0
  21. guidellm/benchmark/schemas/__init__.py +66 -0
  22. guidellm/benchmark/schemas/base.py +402 -0
  23. guidellm/benchmark/schemas/generative/__init__.py +55 -0
  24. guidellm/benchmark/schemas/generative/accumulator.py +841 -0
  25. guidellm/benchmark/schemas/generative/benchmark.py +163 -0
  26. guidellm/benchmark/schemas/generative/entrypoints.py +381 -0
  27. guidellm/benchmark/schemas/generative/metrics.py +927 -0
  28. guidellm/benchmark/schemas/generative/report.py +158 -0
  29. guidellm/data/__init__.py +34 -4
  30. guidellm/data/builders.py +541 -0
  31. guidellm/data/collators.py +16 -0
  32. guidellm/data/config.py +120 -0
  33. guidellm/data/deserializers/__init__.py +49 -0
  34. guidellm/data/deserializers/deserializer.py +141 -0
  35. guidellm/data/deserializers/file.py +223 -0
  36. guidellm/data/deserializers/huggingface.py +94 -0
  37. guidellm/data/deserializers/memory.py +194 -0
  38. guidellm/data/deserializers/synthetic.py +246 -0
  39. guidellm/data/entrypoints.py +52 -0
  40. guidellm/data/loaders.py +190 -0
  41. guidellm/data/preprocessors/__init__.py +27 -0
  42. guidellm/data/preprocessors/formatters.py +410 -0
  43. guidellm/data/preprocessors/mappers.py +196 -0
  44. guidellm/data/preprocessors/preprocessor.py +30 -0
  45. guidellm/data/processor.py +29 -0
  46. guidellm/data/schemas.py +175 -0
  47. guidellm/data/utils/__init__.py +6 -0
  48. guidellm/data/utils/dataset.py +94 -0
  49. guidellm/extras/__init__.py +4 -0
  50. guidellm/extras/audio.py +220 -0
  51. guidellm/extras/vision.py +242 -0
  52. guidellm/logger.py +2 -2
  53. guidellm/mock_server/__init__.py +8 -0
  54. guidellm/mock_server/config.py +84 -0
  55. guidellm/mock_server/handlers/__init__.py +17 -0
  56. guidellm/mock_server/handlers/chat_completions.py +280 -0
  57. guidellm/mock_server/handlers/completions.py +280 -0
  58. guidellm/mock_server/handlers/tokenizer.py +142 -0
  59. guidellm/mock_server/models.py +510 -0
  60. guidellm/mock_server/server.py +238 -0
  61. guidellm/mock_server/utils.py +302 -0
  62. guidellm/scheduler/__init__.py +69 -26
  63. guidellm/scheduler/constraints/__init__.py +49 -0
  64. guidellm/scheduler/constraints/constraint.py +325 -0
  65. guidellm/scheduler/constraints/error.py +411 -0
  66. guidellm/scheduler/constraints/factory.py +182 -0
  67. guidellm/scheduler/constraints/request.py +312 -0
  68. guidellm/scheduler/constraints/saturation.py +722 -0
  69. guidellm/scheduler/environments.py +252 -0
  70. guidellm/scheduler/scheduler.py +137 -368
  71. guidellm/scheduler/schemas.py +358 -0
  72. guidellm/scheduler/strategies.py +617 -0
  73. guidellm/scheduler/worker.py +413 -419
  74. guidellm/scheduler/worker_group.py +712 -0
  75. guidellm/schemas/__init__.py +65 -0
  76. guidellm/schemas/base.py +417 -0
  77. guidellm/schemas/info.py +188 -0
  78. guidellm/schemas/request.py +235 -0
  79. guidellm/schemas/request_stats.py +349 -0
  80. guidellm/schemas/response.py +124 -0
  81. guidellm/schemas/statistics.py +1018 -0
  82. guidellm/{config.py → settings.py} +31 -24
  83. guidellm/utils/__init__.py +71 -8
  84. guidellm/utils/auto_importer.py +98 -0
  85. guidellm/utils/cli.py +132 -5
  86. guidellm/utils/console.py +566 -0
  87. guidellm/utils/encoding.py +778 -0
  88. guidellm/utils/functions.py +159 -0
  89. guidellm/utils/hf_datasets.py +1 -2
  90. guidellm/utils/hf_transformers.py +4 -4
  91. guidellm/utils/imports.py +9 -0
  92. guidellm/utils/messaging.py +1118 -0
  93. guidellm/utils/mixins.py +115 -0
  94. guidellm/utils/random.py +3 -4
  95. guidellm/utils/registry.py +220 -0
  96. guidellm/utils/singleton.py +133 -0
  97. guidellm/utils/synchronous.py +159 -0
  98. guidellm/utils/text.py +163 -50
  99. guidellm/utils/typing.py +41 -0
  100. guidellm/version.py +2 -2
  101. guidellm-0.6.0a5.dist-info/METADATA +364 -0
  102. guidellm-0.6.0a5.dist-info/RECORD +109 -0
  103. guidellm/backend/__init__.py +0 -23
  104. guidellm/backend/backend.py +0 -259
  105. guidellm/backend/openai.py +0 -708
  106. guidellm/backend/response.py +0 -136
  107. guidellm/benchmark/aggregator.py +0 -760
  108. guidellm/benchmark/benchmark.py +0 -837
  109. guidellm/benchmark/output.py +0 -997
  110. guidellm/benchmark/profile.py +0 -409
  111. guidellm/benchmark/scenario.py +0 -104
  112. guidellm/data/prideandprejudice.txt.gz +0 -0
  113. guidellm/dataset/__init__.py +0 -22
  114. guidellm/dataset/creator.py +0 -213
  115. guidellm/dataset/entrypoints.py +0 -42
  116. guidellm/dataset/file.py +0 -92
  117. guidellm/dataset/hf_datasets.py +0 -62
  118. guidellm/dataset/in_memory.py +0 -132
  119. guidellm/dataset/synthetic.py +0 -287
  120. guidellm/objects/__init__.py +0 -18
  121. guidellm/objects/pydantic.py +0 -89
  122. guidellm/objects/statistics.py +0 -953
  123. guidellm/preprocess/__init__.py +0 -3
  124. guidellm/preprocess/dataset.py +0 -374
  125. guidellm/presentation/__init__.py +0 -28
  126. guidellm/presentation/builder.py +0 -27
  127. guidellm/presentation/data_models.py +0 -232
  128. guidellm/presentation/injector.py +0 -66
  129. guidellm/request/__init__.py +0 -18
  130. guidellm/request/loader.py +0 -284
  131. guidellm/request/request.py +0 -79
  132. guidellm/request/types.py +0 -10
  133. guidellm/scheduler/queues.py +0 -25
  134. guidellm/scheduler/result.py +0 -155
  135. guidellm/scheduler/strategy.py +0 -495
  136. guidellm-0.3.1.dist-info/METADATA +0 -329
  137. guidellm-0.3.1.dist-info/RECORD +0 -62
  138. {guidellm-0.3.1.dist-info → guidellm-0.6.0a5.dist-info}/WHEEL +0 -0
  139. {guidellm-0.3.1.dist-info → guidellm-0.6.0a5.dist-info}/entry_points.txt +0 -0
  140. {guidellm-0.3.1.dist-info → guidellm-0.6.0a5.dist-info}/licenses/LICENSE +0 -0
  141. {guidellm-0.3.1.dist-info → guidellm-0.6.0a5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,541 @@
+ import os
+ from collections.abc import Callable, Iterator
+ from enum import Enum
+ from pathlib import Path
+ from typing import Any, cast
+
+ from datasets import Dataset
+ from loguru import logger
+ from transformers import PreTrainedTokenizerBase
+
+ from guidellm.data.config import load_config
+ from guidellm.data.deserializers import (
+     DatasetDeserializerFactory,
+ )
+ from guidellm.data.preprocessors import GenerativeColumnMapper
+ from guidellm.data.schemas import PreprocessDatasetConfig
+ from guidellm.utils import IntegerRangeSampler, check_load_processor
+ from guidellm.utils.hf_datasets import SUPPORTED_TYPES, save_dataset_to_file
+
+
+ class PromptTooShortError(Exception):
+     pass
+
+
+ class ShortPromptStrategy(str, Enum):
+     IGNORE = "ignore"
+     CONCATENATE = "concatenate"
+     PAD = "pad"
+     ERROR = "error"
+
+
+ class ShortPromptStrategyHandler:
+     """Handler class for short prompt strategies."""
+
+     @staticmethod
+     def handle_ignore(
+         current_prompt: str,
+         min_prompt_tokens: int,
+         tokenizer: PreTrainedTokenizerBase,
+         **_kwargs,
+     ) -> str | None:
+         """
+         Ignores prompts that are shorter than the required minimum token length.
+
+         :param current_prompt: The input prompt string.
+         :param min_prompt_tokens: Minimum required token count.
+         :param tokenizer: Tokenizer used to count tokens.
+         :return: The prompt if it meets the length, otherwise None.
+         """
+
+         if len(tokenizer.encode(current_prompt)) < min_prompt_tokens:
+             logger.warning("Prompt too short, ignoring")
+             return None
+         return current_prompt
+
+     @staticmethod
+     def handle_concatenate(
+         current_prompt: str,
+         min_prompt_tokens: int,
+         dataset_iterator: Iterator[dict[str, Any]],
+         prompt_column: str,
+         tokenizer: PreTrainedTokenizerBase,
+         concat_delimiter: str,
+         **_kwargs,
+     ) -> str | None:
+         """
+         Concatenates prompts until the minimum token requirement is met.
+
+         :param current_prompt: The initial prompt.
+         :param min_prompt_tokens: Target minimum token length.
+         :param dataset_iterator: Iterator to fetch more prompts.
+         :param prompt_column: Column key for prompt extraction.
+         :param tokenizer: Tokenizer used to count tokens.
+         :param concat_delimiter: Delimiter to use between prompts.
+         :return: Concatenated prompt or None if not enough data.
+         """
+
+         tokens_len = len(tokenizer.encode(current_prompt))
+         while tokens_len < min_prompt_tokens:
+             try:
+                 next_row = next(dataset_iterator)
+             except StopIteration:
+                 logger.warning(
+                     "Could not concatenate enough prompts to reach minimum "
+                     "length, ignoring"
+                 )
+                 return None
+             current_prompt += concat_delimiter + next_row[prompt_column]
+             tokens_len = len(tokenizer.encode(current_prompt))
+         return current_prompt
+
+     @staticmethod
+     def handle_pad(
+         current_prompt: str,
+         min_prompt_tokens: int,
+         tokenizer: PreTrainedTokenizerBase,
+         pad_char: str,
+         pad_multiplier: int = 2,
+         **_kwargs,
+     ) -> str:
+         """
+         Pads the prompt with a character until it reaches the minimum token length.
+
+         :param current_prompt: The input prompt.
+         :param min_prompt_tokens: Desired minimum token count.
+         :param tokenizer: Tokenizer used to count tokens.
+         :param pad_char: Character used for padding.
+         :param pad_multiplier: Multiplier for padding character length.
+         :return: Padded prompt string.
+         """
+         tokens = tokenizer.encode(current_prompt)
+         pad_count = 1
+         prompt = current_prompt
+         while len(tokens) < min_prompt_tokens:
+             prompt += pad_char * pad_count
+             tokens = tokenizer.encode(prompt)
+             pad_count *= pad_multiplier
+         return prompt
+
+     @staticmethod
+     def handle_error(
+         current_prompt: str,
+         min_prompt_tokens: int,
+         tokenizer: PreTrainedTokenizerBase,
+         **_kwargs,
+     ) -> str | None:
+         """
+         Raises an error if the prompt is too short.
+
+         :param current_prompt: The input prompt.
+         :param min_prompt_tokens: Required token count.
+         :param tokenizer: Tokenizer used to count tokens.
+         :return: The input prompt if valid.
+         :raises PromptTooShortError: If the prompt is too short.
+         """
+
+         prompt_len = len(tokenizer.encode(current_prompt))
+         if prompt_len < min_prompt_tokens:
+             raise PromptTooShortError(
+                 f"Found too short prompt: {current_prompt}, with length: {prompt_len}. "
+                 f"Minimum length required: {min_prompt_tokens}.",
+             )
+         return current_prompt
+
+     @classmethod
+     def get_strategy_handler(cls, strategy: ShortPromptStrategy) -> Callable[..., Any]:
+         """
+         Get the handler for a specific strategy.
+
+         :param strategy: The short prompt strategy to get the handler for.
+         :return: The handler callable for the specified strategy.
+         """
+         return cast("Callable[..., Any]", STRATEGY_HANDLERS[strategy])
+
+
+ # Initialize STRATEGY_HANDLERS after class definition to allow method references
+ STRATEGY_HANDLERS = {
+     ShortPromptStrategy.IGNORE: ShortPromptStrategyHandler.handle_ignore,
+     ShortPromptStrategy.CONCATENATE: ShortPromptStrategyHandler.handle_concatenate,
+     ShortPromptStrategy.PAD: ShortPromptStrategyHandler.handle_pad,
+     ShortPromptStrategy.ERROR: ShortPromptStrategyHandler.handle_error,
+ }
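
Note (not part of the diff): the enum-to-handler dispatch above can be exercised directly. A minimal sketch, assuming this module is importable as guidellm.data.builders (path guessed from the +541 entry in the file list) and using a stand-in object that exposes only the .encode() call the handlers rely on:

from guidellm.data.builders import ShortPromptStrategy, ShortPromptStrategyHandler

class CharTokenizer:
    # Toy stand-in for a real tokenizer: one "token" per character.
    def encode(self, text: str) -> list[str]:
        return list(text)

handler = ShortPromptStrategyHandler.get_strategy_handler(ShortPromptStrategy.PAD)
padded = handler(
    current_prompt="hi",
    min_prompt_tokens=20,
    tokenizer=CharTokenizer(),
    pad_char="x",
)
assert len(CharTokenizer().encode(padded)) >= 20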
+
+
+ def _validate_output_suffix(output_path: str | Path) -> None:
+     output_path = Path(output_path)
+     suffix = output_path.suffix.lower()
+     if suffix not in SUPPORTED_TYPES:
+         raise ValueError(
+             f"Unsupported file suffix '{suffix}' in output_path '{output_path}'. "
+             f"Only {SUPPORTED_TYPES} are supported."
+         )
+
+
+ def parse_synthetic_config(
+     config_input: str | Path,
+ ) -> PreprocessDatasetConfig:
+     """
+     Parse PreprocessDatasetConfig from string or file path.
+
+     Reuses SyntheticTextDatasetDeserializer's parsing logic to support:
+     - JSON strings
+     - Key=value pairs
+     - File paths (.json, .yaml, .yml, .config)
+
+     :param config_input: String or path to config.
+     :return: Parsed PreprocessDatasetConfig instance.
+     :raises ValueError: If the format is not recognized or parsing fails.
+     """
+     config = load_config(config_input, PreprocessDatasetConfig)
+
+     if config is not None:
+         return config
+
+     raise ValueError(
+         f"Could not parse config from input: {config_input}. "
+         "Expected JSON string, key=value pairs, or file path "
+         "(.json, .yaml, .yml, .config)"
+     )
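
Note (not part of the diff): the three accepted config forms listed in the docstring would look roughly like this; prompt_tokens and output_tokens are taken from the sampler code further down, and any other required fields are not shown in this diff:

parse_synthetic_config('{"prompt_tokens": 256, "output_tokens": 128}')  # JSON string
parse_synthetic_config("prompt_tokens=256,output_tokens=128")  # key=value pairs
parse_synthetic_config("prompt_config.yaml")  # path to an existing config file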
+
+
+ def process_dataset(
+     data: str | Path,
+     output_path: str | Path,
+     processor: str | Path | PreTrainedTokenizerBase,
+     config: str | Path,
+     processor_args: dict[str, Any] | None,
+     data_args: dict[str, Any] | None,
+     data_column_mapper: dict[str, str] | None,
+     short_prompt_strategy: ShortPromptStrategy,
+     pad_char: str | None,
+     concat_delimiter: str | None,
+     include_prefix_in_token_count: bool,
+     push_to_hub: bool,
+     hub_dataset_id: str | None,
+     random_seed: int,
+ ) -> None:
+     """
+     Main method to process and save a dataset with sampled prompt/output token counts.
+     """
+     _validate_output_suffix(output_path)
+     logger.info(
+         f"Starting dataset conversion | Input: {data} | Output: {output_path}"
+     )
+
+     # Parse config
+     config_obj = parse_synthetic_config(config)
+
+     # Load tokenizer
+     tokenizer = check_load_processor(
+         processor,
+         processor_args,
+         "dataset conversion.",
+     )
+
+     # Load dataset
+     dataset = DatasetDeserializerFactory.deserialize(
+         data=data,
+         processor_factory=lambda: tokenizer,
+         random_seed=random_seed,
+         **(data_args or {}),
+     )
+
+     # Setup column mapper
+     column_mapper = GenerativeColumnMapper(
+         column_mappings=data_column_mapper  # type: ignore[arg-type]
+     )
+     column_mapper.setup_data(
+         datasets=[dataset],
+         data_args=[data_args or {}],
+     )
+
+     # Extract column names from mapper
+     prompt_column, prefix_column, output_column = _extract_column_names(column_mapper)
+
+     # Create token samplers
+     prompt_token_sampler, output_token_sampler, prefix_tokens_max = (
+         _create_token_samplers(
+             config_obj,
+             random_seed,
+         )
+     )
+
+     # Process dataset
+     dataset_iterator = iter(dataset)
+     processed_prompts = []
+     prompt_handler = ShortPromptStrategyHandler.get_strategy_handler(
+         short_prompt_strategy
+     )
+
+     for row in dataset_iterator:
+         processed_row = _process_single_row(
+             row=row,
+             prompt_column=prompt_column,
+             prefix_column=prefix_column,
+             prompt_token_sampler=prompt_token_sampler,
+             output_token_sampler=output_token_sampler,
+             tokenizer=tokenizer,
+             prompt_handler=prompt_handler,
+             dataset_iterator=dataset_iterator,
+             include_prefix_in_token_count=include_prefix_in_token_count,
+             pad_char=pad_char,
+             concat_delimiter=concat_delimiter,
+             output_column=output_column,
+             prefix_tokens_max=prefix_tokens_max,
+         )
+         if processed_row is not None:
+             processed_prompts.append(processed_row)
+
+     # Finalize
+     _finalize_processed_dataset(
+         processed_prompts,
+         output_path,
+         push_to_hub,
+         hub_dataset_id,
+     )
+
+
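Note (not part of the diff): a hypothetical direct call to process_dataset, mirroring the signature above; the file paths, tokenizer name, and config values are placeholders, and the output suffix must be one of SUPPORTED_TYPES:

process_dataset(
    data="prompts.jsonl",
    output_path="converted.jsonl",
    processor="gpt2",
    config='{"prompt_tokens": 256, "output_tokens": 128}',
    processor_args=None,
    data_args=None,
    data_column_mapper=None,
    short_prompt_strategy=ShortPromptStrategy.CONCATENATE,
    pad_char=None,
    concat_delimiter="\n",
    include_prefix_in_token_count=False,
    push_to_hub=False,
    hub_dataset_id=None,
    random_seed=42,
)
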
+ def _extract_column_names(
+     column_mapper: GenerativeColumnMapper,
+ ) -> tuple[str, str | None, str]:
+     """
+     Extract column names for prompt, prefix, and output from column mapper.
+
+     :param column_mapper: Initialized column mapper.
+     :return: Tuple of (prompt_column, prefix_column, output_column).
+     :raises ValueError: If column mapper is not properly initialized.
+     """
+     if column_mapper.datasets_column_mappings is None:
+         raise ValueError("Column mapper not properly initialized")
+
+     text_mappings = column_mapper.datasets_column_mappings.get("text_column", [])
+     if not text_mappings:
+         raise ValueError("Could not find text column in dataset")
+     prompt_column = text_mappings[0][1]
+
+     prefix_mappings = column_mapper.datasets_column_mappings.get("prefix_column", [])
+     prefix_column = prefix_mappings[0][1] if prefix_mappings else None
+
+     output_mappings = column_mapper.datasets_column_mappings.get(
+         "output_tokens_count_column", []
+     )
+     output_column = (
+         output_mappings[0][1] if output_mappings else "output_tokens_count"
+     )
+
+     return prompt_column, prefix_column, output_column
+
+
+ def _create_token_samplers(
+     config_obj: PreprocessDatasetConfig,
+     random_seed: int,
+ ) -> tuple[Iterator[int], Iterator[int], int | None]:
+     """
+     Create token samplers for prompt and output tokens.
+
+     :param config_obj: Configuration object with token settings.
+     :param random_seed: Seed for random sampling.
+     :return: Tuple of (prompt_sampler, output_sampler, prefix_tokens_max), where
+         prefix_tokens_max is the maximum prefix token limit from the config.
+     """
+     prompt_token_sampler = iter(
+         IntegerRangeSampler(
+             average=config_obj.prompt_tokens,
+             variance=config_obj.prompt_tokens_stdev,
+             min_value=config_obj.prompt_tokens_min,
+             max_value=config_obj.prompt_tokens_max,
+             random_seed=random_seed,
+         )
+     )
+
+     output_token_sampler = iter(
+         IntegerRangeSampler(
+             average=config_obj.output_tokens,
+             variance=config_obj.output_tokens_stdev,
+             min_value=config_obj.output_tokens_min,
+             max_value=config_obj.output_tokens_max,
+             random_seed=random_seed,
+         )
+     )
+
+     return prompt_token_sampler, output_token_sampler, config_obj.prefix_tokens_max
+
+
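Note (not part of the diff): both returned samplers are plain iterators, so per-row token-count targets are drawn with next(). A sketch, given a parsed config_obj:

prompt_sampler, output_sampler, prefix_max = _create_token_samplers(config_obj, random_seed=42)
targets = [(next(prompt_sampler), next(output_sampler)) for _ in range(3)]  # targets for three rows
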
+ def _process_dataset_row(
+     row: dict[str, Any],
+     prompt_column: str,
+     prefix_column: str | None,
+     output_column: str,
+     target_output_len: int,
+     prompt_text: str,
+     prefix_text: str | None,
+     tokens: list[int],
+ ) -> dict[str, Any]:
+     """
+     Create a processed row from the processed prompt/prefix data.
+
+     :param row: Original dataset row.
+     :param prompt_column: Name of prompt column.
+     :param prefix_column: Name of prefix column or None.
+     :param output_column: Name of output tokens count column.
+     :param target_output_len: Target output token length.
+     :param prompt_text: Processed prompt text.
+     :param prefix_text: Processed prefix text or None.
+     :param tokens: Tokenized prompt.
+     :return: Processed row dictionary.
+     """
+     processed_row = row.copy()
+     processed_row[prompt_column] = prompt_text
+     if prefix_column and prefix_text:
+         processed_row[prefix_column] = prefix_text
+     processed_row["prompt_tokens_count"] = len(tokens)
+     processed_row[output_column] = target_output_len
+     return processed_row
+
+
+ def _process_single_row(
+     row: dict[str, Any],
+     prompt_column: str,
+     prefix_column: str | None,
+     prompt_token_sampler: Iterator[int],
+     output_token_sampler: Iterator[int],
+     tokenizer: PreTrainedTokenizerBase,
+     prompt_handler: Callable,
+     dataset_iterator: Iterator[dict[str, Any]],
+     include_prefix_in_token_count: bool,
+     pad_char: str | None,
+     concat_delimiter: str | None,
+     output_column: str,
+     prefix_tokens_max: int | None,
+ ) -> dict[str, Any] | None:
+     """
+     Process a single row from the dataset.
+
+     :param include_prefix_in_token_count: When True, includes prefix tokens in the
+         prompt token count calculation. When False, prefix tokens are not counted
+         toward prompt tokens.
+     :param prefix_tokens_max: Maximum prefix token limit. If set, the prefix will be
+         trimmed if it exceeds this limit.
+     :return: Processed row dictionary or None if row should be skipped.
+     """
+     # Extract prompt and prefix
+     prompt_text = row.get(prompt_column, "")
+     prefix_text = row.get(prefix_column) if prefix_column else None
+
+     # Sample target prompt token count
+     target_prompt_len = next(prompt_token_sampler)
+     count_adjustment = 0
+
+     # Handle prefix
+     if prefix_text:
+         # Apply prefix_tokens_max limit if set (strict maximum)
+         if prefix_tokens_max is not None:
+             prefix_tokens_list = tokenizer.encode(prefix_text)
+             if len(prefix_tokens_list) > prefix_tokens_max:
+                 prefix_text = tokenizer.decode(
+                     prefix_tokens_list[:prefix_tokens_max]
+                 )
+
+         # Count prefix tokens toward prompt if enabled
+         if include_prefix_in_token_count:
+             count_adjustment = len(tokenizer.encode(prefix_text))
+
+     if target_prompt_len == 0:
+         logger.warning("Zero prompt size requested; skipping row")
+         return None
+     elif count_adjustment > 0:
+         adjusted_prompt_len = target_prompt_len - count_adjustment
+         if adjusted_prompt_len <= 0:
+             logger.warning(
+                 "The prefix meets or exceeds the target prompt length with "
+                 "--include-prefix-in-token-count enabled; skipping row"
+             )
+             return None
+         target_prompt_len = adjusted_prompt_len
+
+     # Handle short prompts
+     prompt_text = prompt_handler(
+         current_prompt=prompt_text,
+         min_prompt_tokens=target_prompt_len,
+         dataset_iterator=dataset_iterator,
+         prompt_column=prompt_column,
+         tokenizer=tokenizer,
+         pad_char=pad_char,
+         concat_delimiter=concat_delimiter,
+     )
+     if prompt_text is None:
+         return None
+
+     # Trim long prompts
+     tokens = tokenizer.encode(prompt_text)
+     if len(tokens) > target_prompt_len:
+         prompt_text = tokenizer.decode(tokens[:target_prompt_len])
+         tokens = tokenizer.encode(prompt_text)
+
+     # Sample output token count
+     target_output_len = next(output_token_sampler)
+
+     # Create processed row
+     return _process_dataset_row(
+         row=row,
+         prompt_column=prompt_column,
+         prefix_column=prefix_column,
+         output_column=output_column,
+         target_output_len=target_output_len,
+         prompt_text=prompt_text,
+         prefix_text=prefix_text,
+         tokens=tokens,
+     )
+
+
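Note (not part of the diff): as a worked example of the prefix accounting above, with a sampled target_prompt_len of 256 and a 64-token prefix, include_prefix_in_token_count=True lowers the prompt target to 256 - 64 = 192 tokens; if the prefix alone were 256 tokens or more, adjusted_prompt_len would fall to zero or below and the row would be skipped.
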
+ def _finalize_processed_dataset(
+     processed_prompts: list[dict[str, Any]],
+     output_path: str | Path,
+     push_to_hub: bool,
+     hub_dataset_id: str | None,
+ ) -> None:
+     """
+     Finalize the processed dataset by saving and optionally pushing to hub.
+
+     :param processed_prompts: List of processed row dictionaries.
+     :param output_path: Path to save the dataset.
+     :param push_to_hub: Whether to push to Hugging Face Hub.
+     :param hub_dataset_id: Dataset ID on Hugging Face Hub.
+     """
+     if not processed_prompts:
+         logger.error("No prompts remained after processing")
+         return
+
+     logger.info(f"Generated processed dataset with {len(processed_prompts)} prompts")
+
+     processed_dataset = Dataset.from_list(processed_prompts)
+     save_dataset_to_file(processed_dataset, output_path)
+     logger.info(f"Conversion completed. Dataset saved to: {output_path}")
+
+     if push_to_hub:
+         push_dataset_to_hub(hub_dataset_id, processed_dataset)
+         logger.info(f"Pushed dataset to: {hub_dataset_id}")
+
+
+ def push_dataset_to_hub(
+     hub_dataset_id: str | None,
+     processed_dataset: Dataset,
+ ) -> None:
+     """
+     Pushes the processed dataset to Hugging Face Hub using HF_TOKEN.
+
+     :param hub_dataset_id: Identifier on the Hub to push to.
+     :param processed_dataset: HuggingFace Dataset object.
+     :raises ValueError: If hub_dataset_id or HF_TOKEN is not available.
+     """
+
+     hf_token = os.environ.get("HF_TOKEN")
+     if not hub_dataset_id or not hf_token:
+         raise ValueError(
+             "hub_dataset_id and HF_TOKEN env var must be provided when push_to_hub"
+             " is True"
+         )
+     processed_dataset.push_to_hub(hub_dataset_id, token=hf_token)
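
Note (not part of the diff): a minimal sketch of the precondition push_dataset_to_hub enforces; the Hub dataset id is a placeholder and processed_dataset stands for a datasets.Dataset produced above:

import os

assert os.environ.get("HF_TOKEN"), "HF_TOKEN must be set, otherwise a ValueError is raised"
push_dataset_to_hub("your-org/processed-prompts", processed_dataset)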
@@ -0,0 +1,16 @@
+ from __future__ import annotations
+
+ from guidellm.schemas import GenerationRequest
+
+ __all__ = ["GenerativeRequestCollator"]
+
+
+ class GenerativeRequestCollator:
+     def __call__(self, batch: list) -> GenerationRequest:
+         if len(batch) != 1:
+             raise NotImplementedError(
+                 f"Batch size greater than 1 is not currently supported. "
+                 f"Got batch size: {len(batch)}"
+             )
+
+         return batch[0]
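
Note (not part of the diff): the collator simply unwraps a one-element batch; a sketch with a hypothetical GenerationRequest instance named request:

collator = GenerativeRequestCollator()
single = collator([request])      # returns request unchanged
# collator([request, request])    # would raise NotImplementedError (batch size 2)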
@@ -0,0 +1,120 @@
+ from __future__ import annotations
+
+ from pathlib import Path
+ from typing import Any, TypeVar
+
+ import yaml
+ from pydantic import ValidationError
+
+ from guidellm.data.schemas import DataConfig, DataNotSupportedError
+
+ ConfigT = TypeVar("ConfigT", bound=DataConfig)
+
+
+ def load_config(config: Any, config_class: type[ConfigT]) -> ConfigT | None:
+     # Try file path first
+     if (loaded_config := _load_config_file(config, config_class)) is not None:
+         return loaded_config
+
+     # Try dict parsing next
+     if (loaded_config := _load_config_dict(config, config_class)) is not None:
+         return loaded_config
+
+     # Try string parsing
+     if (loaded_config := _load_config_str(config, config_class)) is not None:
+         return loaded_config
+
+     return None
+
+
+ def _load_config_dict(data: Any, config_class: type[ConfigT]) -> ConfigT | None:
+     if not isinstance(data, dict | list):
+         return None
+
+     try:
+         return config_class.model_validate(data)
+     except ValidationError:
+         return None
+
+
+ def _load_config_file(data: Any, config_class: type[ConfigT]) -> ConfigT | None:
+     if (not isinstance(data, str) and not isinstance(data, Path)) or (
+         not Path(data).is_file()
+     ):
+         return None
+
+     data_path = Path(data) if isinstance(data, str) else data
+     error = None
+
+     if Path(data).is_file() and data_path.suffix.lower() == ".json":
+         try:
+             return config_class.model_validate_json(
+                 data_path.read_text()
+             )
+         except Exception as err:  # noqa: BLE001
+             error = err
+
+     if Path(data).is_file() and data_path.suffix.lower() in {
+         ".yaml",
+         ".yml",
+         ".config",
+     }:
+         try:
+             return config_class.model_validate(
+                 yaml.safe_load(data_path.read_text())
+             )
+         except Exception as err:  # noqa: BLE001
+             error = err
+
+     err_message = (
+         f"Unsupported file {data_path} for "
+         f"{config_class.__name__}, expected .json, "
+         f".yaml, .yml, or .config"
+     )
+
+     if error is not None:
+         err_message += f" with error: {error}"
+         raise DataNotSupportedError(err_message) from error
+     raise DataNotSupportedError(err_message)
+
+
+ def _load_config_str(data: str, config_class: type[ConfigT]) -> ConfigT | None:
+     if not isinstance(data, str):
+         return None
+
+     data_str = data.strip()
+     error = None
+
+     if (data_str.startswith("{") and data_str.endswith("}")) or (
+         data_str.startswith("[") and data_str.endswith("]")
+     ):
+         try:
+             return config_class.model_validate_json(data_str)
+         except Exception as err:  # noqa: BLE001
+             error = err
+
+     if data_str.count("=") > 1:
+         # key=value pairs separated by commas
+         try:
+             config_dict = {}
+             items = data_str.split(",")
+             for item in items:
+                 key, value = item.split("=")
+                 config_dict[key.strip()] = (
+                     int(value.strip())
+                     if value.strip().isnumeric()
+                     else value.strip()
+                 )
+
+             return config_class.model_validate(config_dict)
+         except Exception as err:  # noqa: BLE001
+             error = err
+
+     err_message = (
+         f"Unsupported string data for {config_class.__name__}, "
+         f"expected JSON or key-value pairs, got {data}"
+     )
+     if error is not None:
+         err_message += f" with error: {error}"
+         raise DataNotSupportedError(err_message) from error
+     raise DataNotSupportedError(err_message)
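
Note (not part of the diff): load_config resolves its input in the order shown above: an existing file path first, then a dict, then string parsing. A sketch using PreprocessDatasetConfig, with field names assumed as in the earlier examples:

from guidellm.data.config import load_config
from guidellm.data.schemas import PreprocessDatasetConfig

cfg = load_config({"prompt_tokens": 256, "output_tokens": 128}, PreprocessDatasetConfig)
cfg = load_config("prompt_tokens=256,output_tokens=128", PreprocessDatasetConfig)
cfg = load_config("prompt_config.yaml", PreprocessDatasetConfig)  # file must exist, otherwise the string parser raises DataNotSupportedError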