guidellm 0.3.1__py3-none-any.whl → 0.6.0a5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (141)
  1. guidellm/__init__.py +5 -2
  2. guidellm/__main__.py +524 -255
  3. guidellm/backends/__init__.py +33 -0
  4. guidellm/backends/backend.py +109 -0
  5. guidellm/backends/openai.py +340 -0
  6. guidellm/backends/response_handlers.py +428 -0
  7. guidellm/benchmark/__init__.py +69 -39
  8. guidellm/benchmark/benchmarker.py +160 -316
  9. guidellm/benchmark/entrypoints.py +560 -127
  10. guidellm/benchmark/outputs/__init__.py +24 -0
  11. guidellm/benchmark/outputs/console.py +633 -0
  12. guidellm/benchmark/outputs/csv.py +721 -0
  13. guidellm/benchmark/outputs/html.py +473 -0
  14. guidellm/benchmark/outputs/output.py +169 -0
  15. guidellm/benchmark/outputs/serialized.py +69 -0
  16. guidellm/benchmark/profiles.py +718 -0
  17. guidellm/benchmark/progress.py +553 -556
  18. guidellm/benchmark/scenarios/__init__.py +40 -0
  19. guidellm/benchmark/scenarios/chat.json +6 -0
  20. guidellm/benchmark/scenarios/rag.json +6 -0
  21. guidellm/benchmark/schemas/__init__.py +66 -0
  22. guidellm/benchmark/schemas/base.py +402 -0
  23. guidellm/benchmark/schemas/generative/__init__.py +55 -0
  24. guidellm/benchmark/schemas/generative/accumulator.py +841 -0
  25. guidellm/benchmark/schemas/generative/benchmark.py +163 -0
  26. guidellm/benchmark/schemas/generative/entrypoints.py +381 -0
  27. guidellm/benchmark/schemas/generative/metrics.py +927 -0
  28. guidellm/benchmark/schemas/generative/report.py +158 -0
  29. guidellm/data/__init__.py +34 -4
  30. guidellm/data/builders.py +541 -0
  31. guidellm/data/collators.py +16 -0
  32. guidellm/data/config.py +120 -0
  33. guidellm/data/deserializers/__init__.py +49 -0
  34. guidellm/data/deserializers/deserializer.py +141 -0
  35. guidellm/data/deserializers/file.py +223 -0
  36. guidellm/data/deserializers/huggingface.py +94 -0
  37. guidellm/data/deserializers/memory.py +194 -0
  38. guidellm/data/deserializers/synthetic.py +246 -0
  39. guidellm/data/entrypoints.py +52 -0
  40. guidellm/data/loaders.py +190 -0
  41. guidellm/data/preprocessors/__init__.py +27 -0
  42. guidellm/data/preprocessors/formatters.py +410 -0
  43. guidellm/data/preprocessors/mappers.py +196 -0
  44. guidellm/data/preprocessors/preprocessor.py +30 -0
  45. guidellm/data/processor.py +29 -0
  46. guidellm/data/schemas.py +175 -0
  47. guidellm/data/utils/__init__.py +6 -0
  48. guidellm/data/utils/dataset.py +94 -0
  49. guidellm/extras/__init__.py +4 -0
  50. guidellm/extras/audio.py +220 -0
  51. guidellm/extras/vision.py +242 -0
  52. guidellm/logger.py +2 -2
  53. guidellm/mock_server/__init__.py +8 -0
  54. guidellm/mock_server/config.py +84 -0
  55. guidellm/mock_server/handlers/__init__.py +17 -0
  56. guidellm/mock_server/handlers/chat_completions.py +280 -0
  57. guidellm/mock_server/handlers/completions.py +280 -0
  58. guidellm/mock_server/handlers/tokenizer.py +142 -0
  59. guidellm/mock_server/models.py +510 -0
  60. guidellm/mock_server/server.py +238 -0
  61. guidellm/mock_server/utils.py +302 -0
  62. guidellm/scheduler/__init__.py +69 -26
  63. guidellm/scheduler/constraints/__init__.py +49 -0
  64. guidellm/scheduler/constraints/constraint.py +325 -0
  65. guidellm/scheduler/constraints/error.py +411 -0
  66. guidellm/scheduler/constraints/factory.py +182 -0
  67. guidellm/scheduler/constraints/request.py +312 -0
  68. guidellm/scheduler/constraints/saturation.py +722 -0
  69. guidellm/scheduler/environments.py +252 -0
  70. guidellm/scheduler/scheduler.py +137 -368
  71. guidellm/scheduler/schemas.py +358 -0
  72. guidellm/scheduler/strategies.py +617 -0
  73. guidellm/scheduler/worker.py +413 -419
  74. guidellm/scheduler/worker_group.py +712 -0
  75. guidellm/schemas/__init__.py +65 -0
  76. guidellm/schemas/base.py +417 -0
  77. guidellm/schemas/info.py +188 -0
  78. guidellm/schemas/request.py +235 -0
  79. guidellm/schemas/request_stats.py +349 -0
  80. guidellm/schemas/response.py +124 -0
  81. guidellm/schemas/statistics.py +1018 -0
  82. guidellm/{config.py → settings.py} +31 -24
  83. guidellm/utils/__init__.py +71 -8
  84. guidellm/utils/auto_importer.py +98 -0
  85. guidellm/utils/cli.py +132 -5
  86. guidellm/utils/console.py +566 -0
  87. guidellm/utils/encoding.py +778 -0
  88. guidellm/utils/functions.py +159 -0
  89. guidellm/utils/hf_datasets.py +1 -2
  90. guidellm/utils/hf_transformers.py +4 -4
  91. guidellm/utils/imports.py +9 -0
  92. guidellm/utils/messaging.py +1118 -0
  93. guidellm/utils/mixins.py +115 -0
  94. guidellm/utils/random.py +3 -4
  95. guidellm/utils/registry.py +220 -0
  96. guidellm/utils/singleton.py +133 -0
  97. guidellm/utils/synchronous.py +159 -0
  98. guidellm/utils/text.py +163 -50
  99. guidellm/utils/typing.py +41 -0
  100. guidellm/version.py +2 -2
  101. guidellm-0.6.0a5.dist-info/METADATA +364 -0
  102. guidellm-0.6.0a5.dist-info/RECORD +109 -0
  103. guidellm/backend/__init__.py +0 -23
  104. guidellm/backend/backend.py +0 -259
  105. guidellm/backend/openai.py +0 -708
  106. guidellm/backend/response.py +0 -136
  107. guidellm/benchmark/aggregator.py +0 -760
  108. guidellm/benchmark/benchmark.py +0 -837
  109. guidellm/benchmark/output.py +0 -997
  110. guidellm/benchmark/profile.py +0 -409
  111. guidellm/benchmark/scenario.py +0 -104
  112. guidellm/data/prideandprejudice.txt.gz +0 -0
  113. guidellm/dataset/__init__.py +0 -22
  114. guidellm/dataset/creator.py +0 -213
  115. guidellm/dataset/entrypoints.py +0 -42
  116. guidellm/dataset/file.py +0 -92
  117. guidellm/dataset/hf_datasets.py +0 -62
  118. guidellm/dataset/in_memory.py +0 -132
  119. guidellm/dataset/synthetic.py +0 -287
  120. guidellm/objects/__init__.py +0 -18
  121. guidellm/objects/pydantic.py +0 -89
  122. guidellm/objects/statistics.py +0 -953
  123. guidellm/preprocess/__init__.py +0 -3
  124. guidellm/preprocess/dataset.py +0 -374
  125. guidellm/presentation/__init__.py +0 -28
  126. guidellm/presentation/builder.py +0 -27
  127. guidellm/presentation/data_models.py +0 -232
  128. guidellm/presentation/injector.py +0 -66
  129. guidellm/request/__init__.py +0 -18
  130. guidellm/request/loader.py +0 -284
  131. guidellm/request/request.py +0 -79
  132. guidellm/request/types.py +0 -10
  133. guidellm/scheduler/queues.py +0 -25
  134. guidellm/scheduler/result.py +0 -155
  135. guidellm/scheduler/strategy.py +0 -495
  136. guidellm-0.3.1.dist-info/METADATA +0 -329
  137. guidellm-0.3.1.dist-info/RECORD +0 -62
  138. {guidellm-0.3.1.dist-info → guidellm-0.6.0a5.dist-info}/WHEEL +0 -0
  139. {guidellm-0.3.1.dist-info → guidellm-0.6.0a5.dist-info}/entry_points.txt +0 -0
  140. {guidellm-0.3.1.dist-info → guidellm-0.6.0a5.dist-info}/licenses/LICENSE +0 -0
  141. {guidellm-0.3.1.dist-info → guidellm-0.6.0a5.dist-info}/top_level.txt +0 -0
guidellm/data/preprocessors/formatters.py
@@ -0,0 +1,410 @@
+ from __future__ import annotations
+
+ from typing import Any
+
+ from guidellm.data.preprocessors.preprocessor import (
+     DatasetPreprocessor,
+     PreprocessorRegistry,
+ )
+ from guidellm.schemas import GenerationRequest, GenerationRequestArguments, UsageMetrics
+
+ __all__ = [
+     "GenerativeAudioTranscriptionRequestFormatter",
+     "GenerativeAudioTranslationRequestFormatter",
+     "GenerativeChatCompletionsRequestFormatter",
+     "GenerativeTextCompletionsRequestFormatter",
+     "RequestFormatter",
+ ]
+
+
+ class RequestFormatter(DatasetPreprocessor):
+     def __init__(self, model: str, **_kwargs):
+         self.model = model
+
+     @staticmethod
+     def encode_audio(*args, **kwargs):
+         from guidellm.extras.audio import encode_audio
+
+         return encode_audio(*args, **kwargs)
+
+     @staticmethod
+     def encode_image(*args, **kwargs):
+         from guidellm.extras.vision import encode_image
+
+         return encode_image(*args, **kwargs)
+
+     @staticmethod
+     def encode_video(*args, **kwargs):
+         from guidellm.extras.vision import encode_video
+
+         return encode_video(*args, **kwargs)
+
+
+ @PreprocessorRegistry.register("text_completions")
+ class GenerativeTextCompletionsRequestFormatter(RequestFormatter):
+     def __init__(
+         self,
+         model: str,
+         extras: dict[str, Any] | GenerationRequestArguments | None = None,
+         stream: bool = True,
+         max_tokens: int | None = None,
+         max_completion_tokens: int | None = None,
+     ):
+         self.model: str = model
+         self.extras = (
+             GenerationRequestArguments(**extras)
+             if extras and isinstance(extras, dict)
+             else extras
+         )
+         self.stream: bool = stream
+         self.max_tokens: int | None = max_tokens or max_completion_tokens
+
+     def __call__(self, columns: dict[str, list[Any]]) -> GenerationRequest:
+         """
+         :param columns: A dict of GenerativeDatasetColumnType to Any
+         """
+         arguments: GenerationRequestArguments = GenerationRequestArguments()
+         arguments.body = {}  # The type checker works better setting this field here
+         input_metrics = UsageMetrics()
+         output_metrics = UsageMetrics()
+
+         # Add model
+         if self.model is not None:
+             arguments.body["model"] = self.model
+
+         # Configure streaming
+         if self.stream:
+             arguments.stream = True
+             arguments.body["stream"] = True
+             arguments.body["stream_options"] = {
+                 "include_usage": True,
+                 "continuous_usage_stats": True,
+             }
+
+         # Handle output tokens
+         if output_tokens := sum(
+             count for count in columns.get("output_tokens_count_column", []) if count
+         ):
+             output_metrics.text_tokens = output_tokens
+             arguments.body["max_tokens"] = output_tokens
+             arguments.body["stop"] = None
+             arguments.body["ignore_eos"] = True
+         elif self.max_tokens is not None:
+             arguments.body["max_tokens"] = self.max_tokens
+
+         # Handle prompt tokens
+         if prompt_tokens := sum(
+             count for count in columns.get("prompt_tokens_count_column", []) if count
+         ):
+             input_metrics.text_tokens = prompt_tokens
+
+         # Apply extra arguments
+         if self.extras:
+             arguments.model_combine(self.extras)
+
+         # Build prompt
+         prefix = "".join(pre for pre in columns.get("prefix_column", []) if pre)
+         text = "".join(txt for txt in columns.get("text_column", []) if txt)
+         if prefix or text:
+             prompt = prefix + text
+             arguments.body["prompt"] = prompt
+             input_metrics.add_text_metrics(prompt)
+
+         return GenerationRequest(
+             request_type="text_completions",
+             arguments=arguments,
+             input_metrics=input_metrics,
+             output_metrics=output_metrics,
+         )
+
+
+ @PreprocessorRegistry.register("chat_completions")
+ class GenerativeChatCompletionsRequestFormatter(RequestFormatter):
+     def __init__(
+         self,
+         model: str,
+         extras: dict[str, Any] | GenerationRequestArguments | None = None,
+         stream: bool = True,
+         max_tokens: int | None = None,
+         max_completion_tokens: int | None = None,
+         encode_kwargs: dict[str, Any] | None = None,
+     ):
+         self.model = model
+         self.extras = (
+             GenerationRequestArguments(**extras)
+             if extras and isinstance(extras, dict)
+             else extras
+         )
+         self.stream = stream
+         self.max_completion_tokens = max_tokens or max_completion_tokens
+         self.encode_image_kwargs = (
+             encode_kwargs.get("image", {}) if encode_kwargs else {}
+         )
+         self.encode_video_kwargs = (
+             encode_kwargs.get("video", {}) if encode_kwargs else {}
+         )
+         self.encode_audio_kwargs = (
+             encode_kwargs.get("audio", {}) if encode_kwargs else {}
+         )
+
+     def __call__(  # noqa: C901, PLR0912, PLR0915
+         self, columns: dict[str, list[Any]]
+     ) -> GenerationRequest:
+         """
+         :param columns: A dict of GenerativeDatasetColumnType to Any
+         """
+         arguments = GenerationRequestArguments()
+         arguments.body = {}  # The type checker works best with body assigned here
+         input_metrics = UsageMetrics()
+         output_metrics = UsageMetrics()
+
+         # Add model
+         if self.model is not None:
+             arguments.body["model"] = self.model
+
+         # Configure streaming
+         if self.stream:
+             arguments.stream = True
+             arguments.body["stream"] = True
+             arguments.body["stream_options"] = {
+                 "include_usage": True,
+                 "continuous_usage_stats": True,
+             }
+
+         # Handle output tokens
+         if output_tokens := sum(
+             count for count in columns.get("output_tokens_count_column", []) if count
+         ):
+             output_metrics.text_tokens = output_tokens
+             arguments.body.update(
+                 {
+                     "max_completion_tokens": output_tokens,
+                     "stop": None,
+                     "ignore_eos": True,
+                 }
+             )
+         elif self.max_completion_tokens is not None:
+             arguments.body["max_completion_tokens"] = self.max_completion_tokens
+
+         # Handle prompt tokens
+         if prompt_tokens := sum(
+             count for count in columns.get("prompt_tokens_count_column", []) if count
+         ):
+             input_metrics.text_tokens = prompt_tokens
+
+         # Apply extra arguments
+         if self.extras:
+             arguments.model_combine(self.extras)
+
+         # Build messages
+         arguments.body["messages"] = []
+
+         for prefix in columns.get("prefix_column", []):
+             if not prefix:
+                 continue
+
+             input_metrics.add_text_metrics(prefix)
+             arguments.body["messages"].append({"role": "system", "content": prefix})
+
+         for text in columns.get("text_column", []):
+             if not text:
+                 continue
+
+             input_metrics.add_text_metrics(text)
+
+             arguments.body["messages"].append(
+                 {"role": "user", "content": [{"type": "text", "text": text}]}
+             )
+
+         for image in columns.get("image_column", []):
+             if not image:
+                 continue
+
+             image_dict = self.encode_image(image, **self.encode_image_kwargs)
+             if (image_pixels := image_dict.get("image_pixels")) is not None:
+                 input_metrics.image_pixels = (
+                     input_metrics.image_pixels or 0
+                 ) + image_pixels
+             if (image_bytes := image_dict.get("image_bytes")) is not None:
+                 input_metrics.image_bytes = (
+                     input_metrics.image_bytes or 0
+                 ) + image_bytes
+
+             arguments.body["messages"].append(
+                 {
+                     "role": "user",
+                     "content": [
+                         {
+                             "type": "image_url",
+                             "image_url": {"url": image_dict.get("image")},
+                         }
+                     ],
+                 }
+             )
+
+         for video in columns.get("video_column", []):
+             if not video:
+                 continue
+
+             video_dict = self.encode_video(video, **self.encode_video_kwargs)
+             if (video_frames := video_dict.get("video_frames")) is not None:
+                 input_metrics.video_frames = (
+                     input_metrics.video_frames or 0
+                 ) + video_frames
+             if (video_seconds := video_dict.get("video_seconds")) is not None:
+                 input_metrics.video_seconds = (
+                     input_metrics.video_seconds or 0.0
+                 ) + video_seconds
+             if (video_bytes := video_dict.get("video_bytes")) is not None:
+                 input_metrics.video_bytes = (
+                     input_metrics.video_bytes or 0
+                 ) + video_bytes
+
+             arguments.body["messages"].append(
+                 {
+                     "role": "user",
+                     "content": [
+                         {
+                             "type": "video_url",
+                             "video_url": {"url": video_dict.get("video")},
+                         }
+                     ],
+                 }
+             )
+
+         for audio in columns.get("audio_column", []):
+             if not audio:
+                 continue
+
+             audio_dict = self.encode_audio(
+                 audio, b64encode=True, **self.encode_audio_kwargs
+             )
+             if (audio_samples := audio_dict.get("audio_samples")) is not None:
+                 input_metrics.audio_samples = (
+                     input_metrics.audio_samples or 0
+                 ) + audio_samples
+             if (audio_seconds := audio_dict.get("audio_seconds")) is not None:
+                 input_metrics.audio_seconds = (
+                     input_metrics.audio_seconds or 0.0
+                 ) + audio_seconds
+             if (audio_bytes := audio_dict.get("audio_bytes")) is not None:
+                 input_metrics.audio_bytes = (
+                     input_metrics.audio_bytes or 0
+                 ) + audio_bytes
+
+             arguments.body["messages"].append(
+                 {
+                     "role": "user",
+                     "content": [
+                         {
+                             "type": "input_audio",
+                             "input_audio": {
+                                 "data": audio_dict.get("audio"),
+                                 "format": audio_dict.get("format"),
+                             },
+                         }
+                     ],
+                 }
+             )
+
+         return GenerationRequest(
+             request_type="chat_completions",
+             arguments=arguments,
+             input_metrics=input_metrics,
+             output_metrics=output_metrics,
+         )
+
+
+ @PreprocessorRegistry.register("audio_transcriptions")
+ class GenerativeAudioTranscriptionRequestFormatter(RequestFormatter):
+     def __init__(
+         self,
+         model: str,
+         extras: dict[str, Any] | GenerationRequestArguments | None = None,
+         stream: bool = True,
+         encode_kwargs: dict[str, Any] | None = None,
+     ):
+         self.model = model
+         self.extras = (
+             GenerationRequestArguments(**extras)
+             if extras and isinstance(extras, dict)
+             else extras
+         )
+         self.stream = stream
+         self.encode_audio_kwargs = encode_kwargs or {}
+
+     def __call__(  # noqa: C901
+         self, columns: dict[str, list[Any]]
+     ) -> GenerationRequest:
+         arguments = GenerationRequestArguments(files={})
+         arguments.body = {}  # The type checker works best with body assigned here
+         input_metrics = UsageMetrics()
+         output_metrics = UsageMetrics()
+
+         # Add model
+         if self.model is not None:
+             arguments.body["model"] = self.model
+
+         # Configure streaming
+         if self.stream:
+             arguments.stream = True
+             arguments.body["stream"] = True
+             # NOTE: File upload endpoints use flattened stream options
+             arguments.body["stream_include_usage"] = True
+             arguments.body["stream_continuous_usage_stats"] = True
+
+         # Handle output tokens
+         if output_tokens := sum(
+             count for count in columns.get("output_tokens_count_column", []) if count
+         ):
+             output_metrics.text_tokens = output_tokens
+
+         # Handle prompt tokens (for audio duration tracking)
+         if prompt_tokens := sum(
+             count for count in columns.get("prompt_tokens_count_column", []) if count
+         ):
+             input_metrics.text_tokens = prompt_tokens
+
+         # Apply extra arguments
+         if self.extras:
+             arguments.model_combine(self.extras)
+
+         # Build audio input
+         audio_columns = columns.get("audio_column", [])
+         if len(audio_columns) != 1:
+             raise ValueError(
+                 f"GenerativeAudioTranscriptionRequestFormatter expects exactly "
+                 f"one audio column, but got {len(audio_columns)}."
+             )
+
+         audio_dict = self.encode_audio(
+             audio_columns[0], b64encode=False, **self.encode_audio_kwargs
+         )
+         input_metrics.audio_samples = audio_dict.get("audio_samples")
+         input_metrics.audio_seconds = audio_dict.get("audio_seconds")
+         input_metrics.audio_bytes = audio_dict.get("audio_bytes")
+
+         arguments.files = {
+             "file": (
+                 audio_dict.get("file_name", "audio_input"),
+                 audio_dict.get("audio"),
+                 audio_dict.get("mimetype"),
+             )
+         }
+
+         return GenerationRequest(
+             request_type="audio_transcriptions",
+             arguments=arguments,
+             input_metrics=input_metrics,
+             output_metrics=output_metrics,
+         )
+
+
+ @PreprocessorRegistry.register("audio_translations")
+ class GenerativeAudioTranslationRequestFormatter(
+     GenerativeAudioTranscriptionRequestFormatter
+ ):
+     def __call__(self, columns: dict[str, list[Any]]) -> GenerationRequest:
+         result = super().__call__(columns)
+         result.request_type = "audio_translations"
+         return result
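
The formatters in this new module share one pattern: each is registered with PreprocessorRegistry under a request type and converts a dict of mapped dataset columns into a GenerationRequest. A minimal usage sketch follows; it is not part of the diff, the model name and column values are invented, and it assumes the guidellm 0.6.0a5 modules shown above import as written:

from guidellm.data.preprocessors.formatters import (
    GenerativeTextCompletionsRequestFormatter,
)

# Hypothetical inputs; in practice these come from the data loader pipeline.
formatter = GenerativeTextCompletionsRequestFormatter(
    model="example-model",
    stream=True,
    max_tokens=64,
)
request = formatter(
    {
        "prefix_column": ["You are a helpful assistant. "],
        "text_column": ["Summarize the benefits of streaming responses."],
        "output_tokens_count_column": [64],
    }
)
# request.request_type == "text_completions"; request.arguments.body now carries
# the model, prompt, stream options, and max_tokens, while request.input_metrics
# and request.output_metrics record the token accounting used by the benchmarker.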
guidellm/data/preprocessors/mappers.py
@@ -0,0 +1,196 @@
+ from __future__ import annotations
+
+ from collections import defaultdict
+ from typing import Any, ClassVar, cast
+
+ from datasets import Dataset, IterableDataset
+
+ from guidellm.data.preprocessors.preprocessor import (
+     DataDependentPreprocessor,
+     PreprocessorRegistry,
+ )
+ from guidellm.data.schemas import GenerativeDatasetColumnType
+
+ __all__ = ["GenerativeColumnMapper"]
+
+
+ @PreprocessorRegistry.register("generative_column_mapper")
+ class GenerativeColumnMapper(DataDependentPreprocessor):
+     defaults: ClassVar[dict[str, list[str]]] = {
+         "prompt_tokens_count_column": ["prompt_tokens_count", "input_tokens_count"],
+         "output_tokens_count_column": [
+             "output_tokens_count",
+             "completion_tokens_count",
+         ],
+         "prefix_column": [
+             "system_prompt",
+             "system",
+             "prefix",
+         ],
+         "text_column": [
+             "prompt",
+             "instruction",
+             "question",
+             "input",
+             "context",
+             "content",
+             "conversation",
+             "turn",
+             "text",
+         ],
+         "image_column": [
+             "image",
+             "picture",
+             "photo",
+             "img",
+         ],
+         "video_column": [
+             "video",
+             "clip",
+             "movie",
+             "footage",
+             "mp4",
+             "mov",
+             "avi",
+         ],
+         "audio_column": [
+             "audio",
+             "sound",
+             "voice",
+             "speech",
+             "wav",
+             "mp3",
+         ],
+     }
+
+     @classmethod
+     def datasets_default_mappings(
+         cls, datasets: list[Dataset | IterableDataset]
+     ) -> dict[GenerativeDatasetColumnType, list[tuple[int, str]]]:
+         mappings: dict[GenerativeDatasetColumnType, list[tuple[int, str]]] = (
+             defaultdict(list)
+         )
+
+         for index, dataset in enumerate(datasets):
+             dataset_columns = dataset.column_names or list(next(iter(dataset)).keys())
+
+             for column_type in cls.defaults:
+                 if column_type in mappings:
+                     continue
+
+                 type_names = [
+                     variant
+                     for name in cls.defaults.get(column_type, [])
+                     for plural in [name, f"{name}s", f"{name}es"]
+                     for variant in [
+                         plural,
+                         plural.lower(),
+                         plural.upper(),
+                         plural.capitalize(),
+                     ]
+                 ]
+
+                 for name in type_names:
+                     if name in dataset_columns:
+                         key = cast("GenerativeDatasetColumnType", column_type)
+                         mappings[key].append((index, name))
+                         break
+
+         return mappings
+
+     @classmethod
+     def datasets_mappings(
+         cls,
+         datasets: list[Dataset | IterableDataset],
+         input_mappings: dict[GenerativeDatasetColumnType, str | list[str]],
+     ) -> dict[GenerativeDatasetColumnType, list[tuple[int, str]]]:
+         mappings: dict[GenerativeDatasetColumnType, list[tuple[int, str]]] = (
+             defaultdict(list)
+         )
+         datasets_named_indices = {
+             (
+                 dataset.info.dataset_name
+                 if dataset.info and dataset.info.dataset_name
+                 else index
+             ): index
+             for index, dataset in enumerate(datasets)
+         }
+         datasets_columns = {
+             index: dataset.column_names or list(next(iter(dataset)).keys())
+             for index, dataset in enumerate(datasets)
+         }
+
+         # Parse out user mappings that were passed in and validate them
+         # Must be in the format of:
+         #     {<column_type>: [<column_names>]}
+         # where <column_names> can be a single string or list of strings
+         # and each string can be any of:
+         #   - a column name (assumes the first dataset was intended)
+         #   - <int>.<column_name> where <int> is the dataset index
+         #   - <str>.<column_name> where <str> is the dataset name
+         for column_type, names in input_mappings.items():
+             mappings[column_type] = []
+             for name in names if isinstance(names, list) else [names]:
+                 if "." in name:
+                     dataset, column_name = name.split(".", 1)
+                     dataset_index = (
+                         int(dataset)
+                         if dataset.isdigit()
+                         else datasets_named_indices.get(dataset)
+                     )
+                 else:
+                     dataset_index = 0
+                     column_name = name
+
+                 if dataset_index is None or dataset_index >= len(datasets):
+                     raise ValueError(
+                         f"Dataset '{name}' not found in datasets: "
+                         f"{datasets_named_indices}."
+                     )
+                 if column_name not in datasets_columns[dataset_index]:
+                     raise ValueError(
+                         f"Column '{column_name}' not found in dataset "
+                         f"'{datasets[dataset_index]}' "
+                         f"columns: {datasets_columns[dataset_index]}."
+                     )
+                 mappings[column_type].append((dataset_index, column_name))
+
+         return mappings
+
+     def __init__(
+         self,
+         column_mappings: dict[GenerativeDatasetColumnType, str | list[str]]
+         | None = None,
+     ):
+         self.input_mappings = column_mappings
+         self.datasets_column_mappings: (
+             dict[GenerativeDatasetColumnType, list[tuple[int, str]]] | None
+         )
+
+     def __call__(self, row: dict[str, Any]) -> dict[str, list[Any]]:
+         if self.datasets_column_mappings is None:
+             raise ValueError("DefaultGenerativeColumnMapper not setup with data.")
+
+         items = cast("dict[int, dict[str, Any]]", row.pop("items"))
+         mapped: dict[str, Any] = defaultdict(list)
+
+         for column_type, column_mappings in self.datasets_column_mappings.items():
+             for (
+                 dataset_index,
+                 dataset_column,
+             ) in column_mappings:
+                 mapped[column_type].append(items[dataset_index][dataset_column])
+
+         return dict(mapped)
+
+     def setup_data(
+         self,
+         datasets: list[Dataset | IterableDataset],
+         data_args: list[dict[str, Any]],
+     ):
+         _ = data_args  # Unused for this mapper
+         self.datasets_column_mappings = (
+             self.datasets_default_mappings(datasets)
+             if self.input_mappings is None
+             else self.datasets_mappings(datasets, self.input_mappings)
+         )
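
To make the mapping rules above concrete, here is a short sketch, not part of the diff, using an invented in-memory dataset. It shows the default column detection driven by the defaults table, and the explicit <dataset>.<column> override format documented in the comments:

from datasets import Dataset

from guidellm.data.preprocessors.mappers import GenerativeColumnMapper

ds = Dataset.from_dict(
    {"question": ["What does GuideLLM measure?"], "output_tokens_count": [32]}
)

# Default detection: "question" matches text_column and "output_tokens_count"
# matches output_tokens_count_column.
mapper = GenerativeColumnMapper()
mapper.setup_data(datasets=[ds], data_args=[{}])
row = {"items": {0: {"question": "What does GuideLLM measure?", "output_tokens_count": 32}}}
print(mapper(row))
# -> {'output_tokens_count_column': [32], 'text_column': ['What does GuideLLM measure?']}

# Explicit override: "0.question" selects column "question" from dataset index 0.
explicit = GenerativeColumnMapper(column_mappings={"text_column": "0.question"})
explicit.setup_data(datasets=[ds], data_args=[{}])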
guidellm/data/preprocessors/preprocessor.py
@@ -0,0 +1,30 @@
+ from __future__ import annotations
+
+ from typing import Any, Protocol, runtime_checkable
+
+ from datasets import Dataset, IterableDataset
+
+ from guidellm.schemas import GenerationRequest
+ from guidellm.utils import RegistryMixin
+
+ __all__ = ["DataDependentPreprocessor", "DatasetPreprocessor", "PreprocessorRegistry"]
+
+
+ @runtime_checkable
+ class DatasetPreprocessor(Protocol):
+     def __call__(self, item: dict[str, Any]) -> GenerationRequest | dict[str, Any]: ...
+
+
+ @runtime_checkable
+ class DataDependentPreprocessor(DatasetPreprocessor, Protocol):
+     def setup_data(
+         self,
+         datasets: list[Dataset | IterableDataset],
+         data_args: list[dict[str, Any]],
+     ): ...
+
+
+ class PreprocessorRegistry(
+     RegistryMixin[type[DatasetPreprocessor] | type[DataDependentPreprocessor]]
+ ):
+     pass
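
These protocols and the registry are the extension point that the formatters and mappers in this diff plug into via the @PreprocessorRegistry.register(...) decorator. A sketch of a custom preprocessor follows; the class name and registry key are hypothetical and not part of the diff:

from typing import Any

from guidellm.data.preprocessors.preprocessor import (
    DatasetPreprocessor,
    PreprocessorRegistry,
)


@PreprocessorRegistry.register("uppercase_text")  # hypothetical key, not in guidellm
class UppercaseTextPreprocessor(DatasetPreprocessor):
    # Toy preprocessor: uppercases every mapped text column value and returns the
    # row unchanged otherwise, matching the DatasetPreprocessor protocol above.
    def __call__(self, item: dict[str, Any]) -> dict[str, Any]:
        item["text_column"] = [text.upper() for text in item.get("text_column", [])]
        return item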
guidellm/data/processor.py
@@ -0,0 +1,29 @@
+ from __future__ import annotations
+
+ from pathlib import Path
+ from typing import Any
+
+ from transformers import AutoTokenizer, PreTrainedTokenizerBase  # type: ignore[import]
+
+ __all__ = ["ProcessorFactory"]
+
+
+ class ProcessorFactory:
+     def __init__(
+         self,
+         processor: str | Path | PreTrainedTokenizerBase,
+         processor_args: dict[str, Any] | None = None,
+     ) -> None:
+         self.processor = processor
+         self.processor_args = processor_args or {}
+
+     def __call__(self) -> PreTrainedTokenizerBase:
+         if isinstance(self.processor, PreTrainedTokenizerBase):
+             return self.processor
+         else:
+             from_pretrained = AutoTokenizer.from_pretrained(
+                 self.processor,
+                 **(self.processor_args or {}),
+             )
+             self.processor = from_pretrained
+             return from_pretrained
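
ProcessorFactory defers tokenizer loading until the factory is first called and then caches the loaded instance on self.processor, so later calls return the same object. A usage sketch, with an illustrative model id that is not part of the diff:

from guidellm.data.processor import ProcessorFactory

factory = ProcessorFactory("gpt2", processor_args={"use_fast": True})
tokenizer = factory()  # loads via AutoTokenizer.from_pretrained on the first call
same = factory()       # the cached PreTrainedTokenizerBase is returned afterwards
assert tokenizer is same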