guidellm 0.4.0a18__py3-none-any.whl → 0.4.0a155__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of guidellm might be problematic. Click here for more details.

Files changed (116) hide show
  1. guidellm/__init__.py +5 -2
  2. guidellm/__main__.py +451 -252
  3. guidellm/backends/__init__.py +33 -0
  4. guidellm/backends/backend.py +110 -0
  5. guidellm/backends/openai.py +355 -0
  6. guidellm/backends/response_handlers.py +455 -0
  7. guidellm/benchmark/__init__.py +53 -39
  8. guidellm/benchmark/benchmarker.py +148 -317
  9. guidellm/benchmark/entrypoints.py +466 -128
  10. guidellm/benchmark/output.py +517 -771
  11. guidellm/benchmark/profile.py +580 -280
  12. guidellm/benchmark/progress.py +568 -549
  13. guidellm/benchmark/scenarios/__init__.py +40 -0
  14. guidellm/benchmark/scenarios/chat.json +6 -0
  15. guidellm/benchmark/scenarios/rag.json +6 -0
  16. guidellm/benchmark/schemas.py +2085 -0
  17. guidellm/data/__init__.py +28 -4
  18. guidellm/data/collators.py +16 -0
  19. guidellm/data/deserializers/__init__.py +53 -0
  20. guidellm/data/deserializers/deserializer.py +109 -0
  21. guidellm/data/deserializers/file.py +222 -0
  22. guidellm/data/deserializers/huggingface.py +94 -0
  23. guidellm/data/deserializers/memory.py +192 -0
  24. guidellm/data/deserializers/synthetic.py +346 -0
  25. guidellm/data/loaders.py +145 -0
  26. guidellm/data/preprocessors/__init__.py +25 -0
  27. guidellm/data/preprocessors/formatters.py +412 -0
  28. guidellm/data/preprocessors/mappers.py +198 -0
  29. guidellm/data/preprocessors/preprocessor.py +29 -0
  30. guidellm/data/processor.py +30 -0
  31. guidellm/data/schemas.py +13 -0
  32. guidellm/data/utils/__init__.py +10 -0
  33. guidellm/data/utils/dataset.py +94 -0
  34. guidellm/data/utils/functions.py +18 -0
  35. guidellm/extras/__init__.py +4 -0
  36. guidellm/extras/audio.py +215 -0
  37. guidellm/extras/vision.py +242 -0
  38. guidellm/logger.py +2 -2
  39. guidellm/mock_server/__init__.py +8 -0
  40. guidellm/mock_server/config.py +84 -0
  41. guidellm/mock_server/handlers/__init__.py +17 -0
  42. guidellm/mock_server/handlers/chat_completions.py +280 -0
  43. guidellm/mock_server/handlers/completions.py +280 -0
  44. guidellm/mock_server/handlers/tokenizer.py +142 -0
  45. guidellm/mock_server/models.py +510 -0
  46. guidellm/mock_server/server.py +168 -0
  47. guidellm/mock_server/utils.py +302 -0
  48. guidellm/preprocess/dataset.py +23 -26
  49. guidellm/presentation/builder.py +2 -2
  50. guidellm/presentation/data_models.py +25 -21
  51. guidellm/presentation/injector.py +2 -3
  52. guidellm/scheduler/__init__.py +65 -26
  53. guidellm/scheduler/constraints.py +1035 -0
  54. guidellm/scheduler/environments.py +252 -0
  55. guidellm/scheduler/scheduler.py +140 -368
  56. guidellm/scheduler/schemas.py +272 -0
  57. guidellm/scheduler/strategies.py +519 -0
  58. guidellm/scheduler/worker.py +391 -420
  59. guidellm/scheduler/worker_group.py +707 -0
  60. guidellm/schemas/__init__.py +31 -0
  61. guidellm/schemas/info.py +159 -0
  62. guidellm/schemas/request.py +216 -0
  63. guidellm/schemas/response.py +119 -0
  64. guidellm/schemas/stats.py +228 -0
  65. guidellm/{config.py → settings.py} +32 -21
  66. guidellm/utils/__init__.py +95 -8
  67. guidellm/utils/auto_importer.py +98 -0
  68. guidellm/utils/cli.py +46 -2
  69. guidellm/utils/console.py +183 -0
  70. guidellm/utils/encoding.py +778 -0
  71. guidellm/utils/functions.py +134 -0
  72. guidellm/utils/hf_datasets.py +1 -2
  73. guidellm/utils/hf_transformers.py +4 -4
  74. guidellm/utils/imports.py +9 -0
  75. guidellm/utils/messaging.py +1118 -0
  76. guidellm/utils/mixins.py +115 -0
  77. guidellm/utils/pydantic_utils.py +411 -0
  78. guidellm/utils/random.py +3 -4
  79. guidellm/utils/registry.py +220 -0
  80. guidellm/utils/singleton.py +133 -0
  81. guidellm/{objects → utils}/statistics.py +341 -247
  82. guidellm/utils/synchronous.py +159 -0
  83. guidellm/utils/text.py +163 -50
  84. guidellm/utils/typing.py +41 -0
  85. guidellm/version.py +1 -1
  86. {guidellm-0.4.0a18.dist-info → guidellm-0.4.0a155.dist-info}/METADATA +33 -10
  87. guidellm-0.4.0a155.dist-info/RECORD +96 -0
  88. guidellm/backend/__init__.py +0 -23
  89. guidellm/backend/backend.py +0 -259
  90. guidellm/backend/openai.py +0 -705
  91. guidellm/backend/response.py +0 -136
  92. guidellm/benchmark/aggregator.py +0 -760
  93. guidellm/benchmark/benchmark.py +0 -837
  94. guidellm/benchmark/scenario.py +0 -104
  95. guidellm/data/prideandprejudice.txt.gz +0 -0
  96. guidellm/dataset/__init__.py +0 -22
  97. guidellm/dataset/creator.py +0 -213
  98. guidellm/dataset/entrypoints.py +0 -42
  99. guidellm/dataset/file.py +0 -92
  100. guidellm/dataset/hf_datasets.py +0 -62
  101. guidellm/dataset/in_memory.py +0 -132
  102. guidellm/dataset/synthetic.py +0 -287
  103. guidellm/objects/__init__.py +0 -18
  104. guidellm/objects/pydantic.py +0 -89
  105. guidellm/request/__init__.py +0 -18
  106. guidellm/request/loader.py +0 -284
  107. guidellm/request/request.py +0 -79
  108. guidellm/request/types.py +0 -10
  109. guidellm/scheduler/queues.py +0 -25
  110. guidellm/scheduler/result.py +0 -155
  111. guidellm/scheduler/strategy.py +0 -495
  112. guidellm-0.4.0a18.dist-info/RECORD +0 -62
  113. {guidellm-0.4.0a18.dist-info → guidellm-0.4.0a155.dist-info}/WHEEL +0 -0
  114. {guidellm-0.4.0a18.dist-info → guidellm-0.4.0a155.dist-info}/entry_points.txt +0 -0
  115. {guidellm-0.4.0a18.dist-info → guidellm-0.4.0a155.dist-info}/licenses/LICENSE +0 -0
  116. {guidellm-0.4.0a18.dist-info → guidellm-0.4.0a155.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,192 @@
1
+ from __future__ import annotations
2
+
3
+ import contextlib
4
+ import csv
5
+ import json
6
+ from collections.abc import Callable
7
+ from io import StringIO
8
+ from typing import Any, cast
9
+
10
+ from datasets import Dataset
11
+ from transformers import PreTrainedTokenizerBase
12
+
13
+ from guidellm.data.deserializers.deserializer import (
14
+ DataNotSupportedError,
15
+ DatasetDeserializer,
16
+ DatasetDeserializerFactory,
17
+ )
18
+
19
+ __all__ = [
20
+ "InMemoryCsvDatasetDeserializer",
21
+ "InMemoryDictDatasetDeserializer",
22
+ "InMemoryDictListDatasetDeserializer",
23
+ "InMemoryItemListDatasetDeserializer",
24
+ "InMemoryJsonStrDatasetDeserializer",
25
+ ]
26
+
27
+
28
@DatasetDeserializerFactory.register("in_memory_dict")
class InMemoryDictDatasetDeserializer(DatasetDeserializer):
    """Wrap an in-memory mapping of column name to list of values in a Dataset."""

    def __call__(
        self,
        data: Any,
        processor_factory: Callable[[], PreTrainedTokenizerBase],
        random_seed: int,
        **data_kwargs: dict[str, Any],
    ) -> dict[str, list]:
        """
        Validate ``data`` as a column dictionary and hand it to ``Dataset.from_dict``.

        :param data: Candidate payload; accepted only when it is a non-empty dict
            whose keys are all ``str`` and whose values are all ``list``.
        :param processor_factory: Accepted for interface parity; not used here.
        :param random_seed: Accepted for interface parity; not used here.
        :param data_kwargs: Passed through unchanged to ``Dataset.from_dict``.
        :return: Whatever ``Dataset.from_dict(data, **data_kwargs)`` returns.
        :raises DataNotSupportedError: If ``data`` is not a non-empty
            ``dict[str, list]`` or its lists differ in length.
        """
        del processor_factory, random_seed  # unused for in-memory payloads

        is_column_dict = (
            bool(data)
            and isinstance(data, dict)
            and all(
                isinstance(name, str) and isinstance(column, list)
                for name, column in data.items()
            )
        )
        if not is_column_dict:
            raise DataNotSupportedError(
                f"Unsupported data for InMemoryDictDatasetDeserializer, "
                f"expected dict[str, list], got {data}"
            )

        # Every column is measured against the length of the first one.
        columns = list(data.values())
        rows = len(columns[0])
        if any(len(column) != rows for column in columns):
            raise DataNotSupportedError(
                "All lists in the data dictionary must have the same length, "
                f"expected {rows} for all keys {list(data.keys())}"
            )

        return Dataset.from_dict(data, **data_kwargs)
60
+
61
+
62
@DatasetDeserializerFactory.register("in_memory_dict_list")
class InMemoryDictListDatasetDeserializer(DatasetDeserializer):
    """Convert an in-memory list of row dictionaries into a Dataset."""

    def __call__(
        self,
        data: Any,
        processor_factory: Callable[[], PreTrainedTokenizerBase],
        random_seed: int,
        **data_kwargs: dict[str, Any],
    ) -> dict[str, list]:
        """
        Validate ``data`` as a list of same-keyed dicts and pivot it to columns.

        :param data: Candidate payload; accepted only when it is a non-empty list
            of dicts with string keys, all sharing the key set of the first row.
        :param processor_factory: Accepted for interface parity; not used here.
        :param random_seed: Accepted for interface parity; not used here.
        :param data_kwargs: Passed through unchanged to ``Dataset.from_dict``.
        :return: Whatever ``Dataset.from_dict`` returns for the pivoted columns.
        :raises DataNotSupportedError: If ``data`` has the wrong shape or any row
            has a different key set than the first row.
        """
        del processor_factory, random_seed  # unused for in-memory payloads

        well_formed = (
            bool(data)
            and isinstance(data, list)
            and all(isinstance(row, dict) for row in data)
            and all(isinstance(name, str) for row in data for name in row)
        )
        if not well_formed:
            raise DataNotSupportedError(
                f"Unsupported data for InMemoryDictListDatasetDeserializer, "
                f"expected list of dicts, got {data}"
            )

        records = cast("list[dict[str, Any]]", data)
        first_keys = set(records[0].keys())
        for index, item in enumerate(records):
            if set(item.keys()) != first_keys:
                raise DataNotSupportedError(
                    f"All dictionaries must have the same keys. "
                    f"Expected keys: {first_keys}, "
                    f"got keys at index {index}: {set(item.keys())}"
                )

        # Pivot rows into columns; column order follows iteration of first_keys.
        columns = {key: [record[key] for record in records] for key in first_keys}

        return Dataset.from_dict(columns, **data_kwargs)
101
+
102
+
103
@DatasetDeserializerFactory.register("in_memory_item_list")
class InMemoryItemListDatasetDeserializer(DatasetDeserializer):
    """Wrap a flat in-memory list of primitive values as a one-column Dataset."""

    def __call__(
        self,
        data: Any,
        processor_factory: Callable[[], PreTrainedTokenizerBase],
        random_seed: int,
        **data_kwargs: dict[str, Any],
    ) -> dict[str, list]:
        """
        Validate ``data`` as a list of primitives and store it under one column.

        :param data: Candidate payload; accepted only when it is a non-empty list
            whose items are all ``str``, ``int``, ``float``, ``bool`` or ``None``.
        :param processor_factory: Accepted for interface parity; not used here.
        :param random_seed: Accepted for interface parity; not used here.
        :param data_kwargs: ``column_name`` (default ``"data"``) is removed and
            used as the column name; the rest goes to ``Dataset.from_dict``.
        :return: Whatever ``Dataset.from_dict`` returns for the single column.
        :raises DataNotSupportedError: If ``data`` is not a non-empty list of
            primitive items.
        """
        del processor_factory, random_seed  # unused for in-memory payloads

        primitive_types = (str, int, float, bool, type(None))
        supported = (
            bool(data)
            and isinstance(data, list)
            and all(isinstance(item, primitive_types) for item in data)
        )
        if not supported:
            raise DataNotSupportedError(
                f"Unsupported data for InMemoryItemListDatasetDeserializer, "
                f"expected list of primitive items, got {data}"
            )

        column_name = data_kwargs.pop("column_name", "data")
        single_column = {column_name: data}

        return Dataset.from_dict(single_column, **data_kwargs)
128
+
129
+
130
@DatasetDeserializerFactory.register("in_memory_json_str")
class InMemoryJsonStrDatasetDeserializer(DatasetDeserializer):
    """Parse a JSON string and delegate to the matching in-memory deserializer."""

    def __call__(
        self,
        data: Any,
        processor_factory: Callable[[], PreTrainedTokenizerBase],
        random_seed: int,
        **data_kwargs: dict[str, Any],
    ) -> dict[str, list]:
        """
        Decode ``data`` as JSON and try the dict, dict-list and item-list
        deserializers in that order, returning the first successful result.

        :param data: Candidate payload; only strings that, once stripped, are
            wrapped in ``{...}`` or ``[...]`` are attempted.
        :param processor_factory: Forwarded to the delegated deserializer.
        :param random_seed: Forwarded to the delegated deserializer.
        :param data_kwargs: Forwarded as keyword arguments to the delegated
            deserializer.
        :return: The result of the first delegated deserializer that accepts the
            parsed value.
        :raises DataNotSupportedError: If ``data`` is not a JSON-looking string,
            fails to parse, or no delegated deserializer accepts it.
        """
        if (
            isinstance(data, str)
            and (json_str := data.strip())
            and (
                (json_str.startswith("{") and json_str.endswith("}"))
                or (json_str.startswith("[") and json_str.endswith("]"))
            )
        ):
            # Best effort: any parse or conversion failure falls through to the
            # DataNotSupportedError raised at the end of this method.
            with contextlib.suppress(Exception):
                parsed = json.loads(data)

                for deserializer in [
                    InMemoryDictDatasetDeserializer,
                    InMemoryDictListDatasetDeserializer,
                    InMemoryItemListDatasetDeserializer,
                ]:
                    with contextlib.suppress(DataNotSupportedError):
                        # Arguments must follow the shared __call__ signature:
                        # (data, processor_factory, random_seed, **data_kwargs).
                        # Passing data_kwargs positionally raised a TypeError
                        # that was silently swallowed, so nothing ever matched.
                        return deserializer()(
                            parsed, processor_factory, random_seed, **data_kwargs
                        )

        raise DataNotSupportedError(
            f"Unsupported data for InMemoryJsonStrDatasetDeserializer, "
            f"expected JSON string with a list or dict of items, got {data}"
        )
164
+
165
+
166
@DatasetDeserializerFactory.register("in_memory_csv_str")
class InMemoryCsvDatasetDeserializer(DatasetDeserializer):
    """Parse an in-memory CSV string and delegate to the dict-list deserializer."""

    def __call__(
        self,
        data: Any,
        processor_factory: Callable[[], PreTrainedTokenizerBase],
        random_seed: int,
        **data_kwargs: dict[str, Any],
    ) -> dict[str, list]:
        """
        Read ``data`` with ``csv.DictReader`` and deserialize the resulting rows.

        :param data: Candidate payload; only strings that are non-empty after
            stripping are attempted.
        :param processor_factory: Forwarded to the dict-list deserializer.
        :param random_seed: Forwarded to the dict-list deserializer.
        :param data_kwargs: Forwarded to the dict-list deserializer.
        :return: The result of ``InMemoryDictListDatasetDeserializer`` for the
            parsed rows.
        :raises DataNotSupportedError: If ``data`` is not a usable CSV string or
            anything fails while parsing or converting it.
        """
        # A non-empty stripped string always splits into at least one line, so
        # the truthiness check alone decides whether parsing is attempted.
        if isinstance(data, str) and data.strip():
            # Any failure here is swallowed on purpose and reported below.
            with contextlib.suppress(Exception):
                rows = list(csv.DictReader(StringIO(data)))

                return InMemoryDictListDatasetDeserializer()(
                    rows, processor_factory, random_seed, **data_kwargs
                )

        raise DataNotSupportedError(
            f"Unsupported data for InMemoryCsvDatasetDeserializer, "
            f"expected CSV string, got {type(data)}"
        )
@@ -0,0 +1,346 @@
1
+ from __future__ import annotations
2
+
3
+ import math
4
+ from collections.abc import Callable, Iterator
5
+ from pathlib import Path
6
+ from random import Random
7
+ from typing import Any
8
+
9
+ import yaml
10
+ from datasets import Features, IterableDataset, Value
11
+ from faker import Faker
12
+ from pydantic import ConfigDict, Field, model_validator
13
+ from transformers import PreTrainedTokenizerBase
14
+
15
+ from guidellm.data.deserializers.deserializer import (
16
+ DataNotSupportedError,
17
+ DatasetDeserializer,
18
+ DatasetDeserializerFactory,
19
+ )
20
+ from guidellm.utils import IntegerRangeSampler, StandardBaseModel
21
+
22
+ __all__ = [
23
+ "SyntheticTextDatasetConfig",
24
+ "SyntheticTextDatasetDeserializer",
25
+ "SyntheticTextGenerator",
26
+ "SyntheticTextPrefixBucketConfig",
27
+ ]
28
+
29
+
30
# One weighted bucket of shared prefixes for synthetic prompts. Documented with
# comments rather than a docstring so the model's generated schema is unchanged.
class SyntheticTextPrefixBucketConfig(StandardBaseModel):
    # Relative weight of this bucket; must be greater than zero.
    bucket_weight: int = Field(
        default=100,
        gt=0,
        description="Weight of this bucket in the overall distribution.",
    )
    # How many distinct prefixes the bucket contains; at least one.
    prefix_count: int = Field(
        default=1,
        ge=1,
        description="The number of unique prefixes to generate for this bucket.",
    )
    # Token length of each prefix in the bucket; zero is allowed.
    prefix_tokens: int = Field(
        default=0,
        ge=0,
        description="The number of prefix tokens per-prompt for this bucket.",
    )
46
+
47
+
48
# Settings for synthetic text generation. Extra fields are allowed so that the
# flat ``prefix_count`` / ``prefix_tokens`` options can be supplied and folded
# into ``prefix_buckets`` by the validator below. Documented with comments
# rather than a docstring so the model's generated schema is unchanged.
class SyntheticTextDatasetConfig(StandardBaseModel):
    model_config = ConfigDict(
        extra="allow",
    )

    prefix_buckets: list[SyntheticTextPrefixBucketConfig] | None = Field(
        default=None,
        description="Buckets for the prefix tokens distribution.",
    )

    # Prompt length controls (required average, optional spread and bounds).
    prompt_tokens: int = Field(
        gt=0,
        description="The average number of text tokens generated for prompts.",
    )
    prompt_tokens_stdev: int | None = Field(
        default=None,
        gt=0,
        description="The standard deviation of the tokens generated for prompts.",
    )
    prompt_tokens_min: int | None = Field(
        default=None,
        gt=0,
        description="The minimum number of text tokens generated for prompts.",
    )
    prompt_tokens_max: int | None = Field(
        default=None,
        gt=0,
        description="The maximum number of text tokens generated for prompts.",
    )

    # Output length controls (required average, optional spread and bounds).
    output_tokens: int = Field(
        gt=0,
        description="The average number of text tokens generated for outputs.",
    )
    output_tokens_stdev: int | None = Field(
        default=None,
        gt=0,
        description="The standard deviation of the tokens generated for outputs.",
    )
    output_tokens_min: int | None = Field(
        default=None,
        gt=0,
        description="The minimum number of text tokens generated for outputs.",
    )
    output_tokens_max: int | None = Field(
        default=None,
        gt=0,
        description="The maximum number of text tokens generated for outputs.",
    )

    source: str = Field(
        default="data:prideandprejudice.txt.gz",
        description="The source of the text data to be used for generation.",
    )

    @model_validator(mode="after")
    def check_prefix_options(self) -> SyntheticTextDatasetConfig:
        """
        Fold the flat ``prefix_count`` / ``prefix_tokens`` extras into a bucket.

        :return: This instance, with ``prefix_buckets`` replaced by a single
            bucket when either flat option was supplied.
        :raises ValueError: If a flat option is supplied together with a
            non-empty ``prefix_buckets``.
        """
        extras = self.__pydantic_extra__
        prefix_count = extras.get("prefix_count", None)  # type: ignore[attr-defined]
        prefix_tokens = extras.get("prefix_tokens", None)  # type: ignore[attr-defined]

        # Neither flat option given: leave prefix_buckets exactly as provided.
        if prefix_count is None and prefix_tokens is None:
            return self

        if self.prefix_buckets:
            raise ValueError(
                "prefix_buckets is mutually exclusive"
                " with prefix_count and prefix_tokens"
            )

        single_bucket = SyntheticTextPrefixBucketConfig(
            prefix_count=prefix_count or 1,
            prefix_tokens=prefix_tokens or 0,
        )
        self.prefix_buckets = [single_bucket]

        return self
119
+
120
+
121
+ class SyntheticTextGenerator:
122
+ def __init__(
123
+ self,
124
+ config: SyntheticTextDatasetConfig,
125
+ processor: PreTrainedTokenizerBase,
126
+ random_seed: int = 42,
127
+ ):
128
+ self.config = config
129
+ self.processor = processor
130
+ self.random_seed = random_seed
131
+
132
+ def __iter__(self) -> Iterator[dict[str, Any]]:
133
+ samples_generated = 0
134
+
135
+ faker = Faker()
136
+ faker.seed_instance(self.random_seed)
137
+ prompt_tokens_sampler = iter(
138
+ IntegerRangeSampler(
139
+ average=self.config.prompt_tokens,
140
+ variance=self.config.prompt_tokens_stdev,
141
+ min_value=self.config.prompt_tokens_min,
142
+ max_value=self.config.prompt_tokens_max,
143
+ random_seed=self.random_seed,
144
+ )
145
+ )
146
+ output_tokens_sampler = iter(
147
+ IntegerRangeSampler(
148
+ average=self.config.output_tokens,
149
+ variance=self.config.output_tokens_stdev,
150
+ min_value=self.config.output_tokens_min,
151
+ max_value=self.config.output_tokens_max,
152
+ random_seed=self.random_seed + 1, # ensure diff dist from prompts
153
+ )
154
+ )
155
+
156
+ # Create a shared prefix if specified
157
+ rand = Random(self.random_seed + 3)
158
+ prefix_iter = self._create_prefix_iter(faker, rand)
159
+
160
+ while True:
161
+ prompt_tokens_count = next(prompt_tokens_sampler)
162
+ output_tokens_count = next(output_tokens_sampler)
163
+
164
+ yield {
165
+ "prefix": next(prefix_iter),
166
+ "prompt": self._create_prompt(
167
+ prompt_tokens_count, faker, f"{samples_generated} "
168
+ ),
169
+ "prompt_tokens_count": prompt_tokens_count,
170
+ "output_tokens_count": output_tokens_count,
171
+ }
172
+ samples_generated += 1
173
+
174
+ def _create_prompt(
175
+ self, prompt_tokens_count: int, faker: Faker, unique: str = ""
176
+ ) -> str:
177
+ prompt_token_ids = []
178
+ avg_chars_per_token = 5
179
+ margin_of_safety = 1.5
180
+ attempts = 0
181
+
182
+ while len(prompt_token_ids) < prompt_tokens_count:
183
+ attempts += 1
184
+ num_chars = (
185
+ prompt_tokens_count * avg_chars_per_token * margin_of_safety * attempts
186
+ )
187
+ text = unique + faker.text(max_nb_chars=num_chars)
188
+ prompt_token_ids = self.processor.encode(text)
189
+
190
+ return self.processor.decode(
191
+ prompt_token_ids[:prompt_tokens_count], skip_special_tokens=True
192
+ )
193
+
194
+ def _create_prefix_iter(self, faker: Faker, rand: Random) -> Iterator[str]:
195
+ if not self.config.prefix_buckets:
196
+ while True:
197
+ yield ""
198
+
199
+ # Increase weights to ensure an integer number of samples per per-prefix
200
+ least_common_prefix_count = math.lcm(
201
+ *(bucket.prefix_count for bucket in self.config.prefix_buckets)
202
+ )
203
+ unnorm_weights = [
204
+ least_common_prefix_count * bucket.bucket_weight // bucket.prefix_count
205
+ for bucket in self.config.prefix_buckets
206
+ ]
207
+ # Use GCD to reduce the weights to smallest integer ratio
208
+ common_divisor = math.gcd(*unnorm_weights)
209
+
210
+ # Create prefix list maintaining the correct distribution
211
+ prefixes = []
212
+ for bucket, weight in zip(
213
+ self.config.prefix_buckets, unnorm_weights, strict=False
214
+ ):
215
+ bucket_prefixes = [
216
+ self._create_prompt(bucket.prefix_tokens, faker)
217
+ for _ in range(bucket.prefix_count)
218
+ ]
219
+ sample_count = weight // common_divisor
220
+ prefixes.extend(bucket_prefixes * sample_count)
221
+
222
+ while True:
223
+ yield rand.choice(prefixes)
224
+
225
+
226
@DatasetDeserializerFactory.register("synthetic_text")
class SyntheticTextDatasetDeserializer(DatasetDeserializer):
    """Build a streaming synthetic-text dataset from a config object, file or string."""

    def __call__(
        self,
        data: Any,
        processor_factory: Callable[[], PreTrainedTokenizerBase],
        random_seed: int,
        **data_kwargs: dict[str, Any],
    ) -> IterableDataset:
        """
        Resolve ``data`` to a ``SyntheticTextDatasetConfig`` and wrap a
        ``SyntheticTextGenerator`` in an ``IterableDataset``.

        :param data: A ``SyntheticTextDatasetConfig``, a path to a config file,
            or an inline config string (JSON or comma separated ``key=value``).
        :param processor_factory: Called once to obtain the processor handed to
            the generator.
        :param random_seed: Seed handed to the generator.
        :param data_kwargs: Only forwarded when this method calls itself with a
            loaded config; not passed to the dataset constructor.
        :return: Dataset with ``prefix``, ``prompt``, ``prompt_tokens_count`` and
            ``output_tokens_count`` features.
        :raises DataNotSupportedError: If ``data`` cannot be resolved to a config.
        """
        # Config file pathways, deserialize and call self again
        if (config := self._load_config_file(data)) is not None:
            return self(config, processor_factory, random_seed, **data_kwargs)

        # Config str pathways, deserialize and call self again
        if (config := self._load_config_str(data)) is not None:
            return self(config, processor_factory, random_seed, **data_kwargs)

        if not isinstance(data, SyntheticTextDatasetConfig):
            raise DataNotSupportedError(
                "Unsupported data for SyntheticTextDatasetDeserializer, "
                "expected SyntheticTextDatasetConfig, str or Path to a config file, "
                f"got {data}"
            )

        return IterableDataset.from_generator(
            SyntheticTextGenerator,
            gen_kwargs={
                "config": data,
                "processor": processor_factory(),
                "random_seed": random_seed,
            },
            features=Features(
                {
                    "prefix": Value("string"),
                    "prompt": Value("string"),
                    "prompt_tokens_count": Value("int32"),
                    "output_tokens_count": Value("int32"),
                }
            ),
        )

    def _load_config_file(self, data: Any) -> SyntheticTextDatasetConfig | None:
        """
        Load a config from ``data`` when it names an existing file.

        :param data: Anything; only ``str`` and ``Path`` values are probed.
        :return: The parsed config, or ``None`` when ``data`` is not a path to
            an existing file (so the caller can try other interpretations).
        :raises DataNotSupportedError: If the file exists but has an unsupported
            suffix or its content fails to parse or validate.
        """
        if not isinstance(data, (str, Path)):
            return None

        data_path = Path(data)
        try:
            if not data_path.is_file():
                return None
        except (OSError, ValueError):
            # Inline config strings may be longer than the OS allows for a file
            # name or otherwise be invalid as a path; that means "not a file",
            # so let the caller fall back to string parsing instead of crashing.
            return None

        suffix = data_path.suffix.lower()
        error = None

        if suffix == ".json":
            try:
                return SyntheticTextDatasetConfig.model_validate_json(
                    data_path.read_text()
                )
            except Exception as err:  # noqa: BLE001
                error = err
        elif suffix in {".yaml", ".yml", ".config"}:
            try:
                return SyntheticTextDatasetConfig.model_validate(
                    yaml.safe_load(data_path.read_text())
                )
            except Exception as err:  # noqa: BLE001
                error = err

        err_message = (
            f"Unsupported file {data_path} for "
            f"SyntheticTextDatasetDeserializer, expected .json, "
            f".yaml, .yml, or .config"
        )

        if error is not None:
            err_message += f" with error: {error}"
            raise DataNotSupportedError(err_message) from error
        raise DataNotSupportedError(err_message)

    def _load_config_str(self, data: str) -> SyntheticTextDatasetConfig | None:
        """
        Parse an inline config string given as JSON or ``key=value`` pairs.

        :param data: Anything; non-string values are ignored.
        :return: The parsed config, or ``None`` when ``data`` is not a string.
        :raises DataNotSupportedError: If ``data`` is a string that matches
            neither format or fails to parse or validate.
        """
        if not isinstance(data, str):
            return None

        data_str = data.strip()
        error = None

        if (data_str.startswith("{") and data_str.endswith("}")) or (
            data_str.startswith("[") and data_str.endswith("]")
        ):
            try:
                return SyntheticTextDatasetConfig.model_validate_json(data_str)
            except Exception as err:  # noqa: BLE001
                error = err

        # The key=value form is only attempted with more than one "=" sign.
        if data_str.count("=") > 1:
            # key=value pairs separated by commas
            try:
                config_dict = {}
                for item in data_str.split(","):
                    key, value = item.split("=")
                    stripped_value = value.strip()
                    # Purely numeric values become ints; the rest stay strings.
                    config_dict[key.strip()] = (
                        int(stripped_value)
                        if stripped_value.isnumeric()
                        else stripped_value
                    )

                return SyntheticTextDatasetConfig.model_validate(config_dict)
            except Exception as err:  # noqa: BLE001
                error = err

        err_message = (
            "Unsupported string data for SyntheticTextDatasetDeserializer, "
            f"expected JSON or key-value pairs, got {data}"
        )
        if error is not None:
            err_message += f" with error: {error}"
            raise DataNotSupportedError(err_message) from error
        raise DataNotSupportedError(err_message)
@@ -0,0 +1,145 @@
1
+ from __future__ import annotations
2
+
3
+ import contextlib
4
+ from collections.abc import Callable, Iterator
5
+ from typing import Any, Literal
6
+
7
+ import torch
8
+ from torch.utils.data import Sampler
9
+ from torch.utils.data.dataloader import DataLoader as PyTorchDataLoader
10
+ from torch.utils.data.dataset import IterableDataset as TorchIterableDataset
11
+ from transformers import PreTrainedTokenizerBase
12
+
13
+ from guidellm.data.deserializers import DatasetDeserializerFactory
14
+ from guidellm.data.preprocessors import DataDependentPreprocessor, DatasetPreprocessor
15
+ from guidellm.logger import logger
16
+
17
+ __all__ = ["DataLoader", "DatasetsIterator"]
18
+
19
+
20
class DatasetsIterator(TorchIterableDataset):
    """Iterable dataset that reads several deserialized datasets in lockstep."""

    def __init__(
        self,
        data: list[Any],
        data_args: list[dict[str, Any]] | None,
        data_samples: int,
        processor_factory: Callable[[], PreTrainedTokenizerBase],
        preprocessors: list[DatasetPreprocessor | DataDependentPreprocessor],
        random_seed: int,
    ):
        """
        Deserialize every data source, set up preprocessors and optionally
        pre-generate a fixed number of rows.

        :param data: Non-empty list of data sources, one per dataset.
        :param data_args: Per-source keyword arguments for deserialization;
            when empty or ``None`` an empty dict is used for every source.
        :param data_samples: Number of rows to generate up front; a falsy value
            disables pre-generation.
        :param processor_factory: Forwarded to the deserializer factory.
        :param preprocessors: Callables applied in order to every row.
        :param random_seed: Forwarded to the deserializer factory.
        :raises ValueError: If ``data`` is empty or not a list, if its length
            differs from ``data_args``, or if fewer than ``data_samples`` rows
            can be generated.
        """
        if not data or not isinstance(data, list):
            raise ValueError(f"Data must be a non-empty list, got {data}.")

        if not data_args:
            data_args = [{} for _ in data]

        if len(data) != len(data_args):
            raise ValueError(
                f"Length of data ({len(data)}) must match length of data_args "
                f"({len(data_args)})."
            )

        self.datasets = []
        for datum, data_kwargs in zip(data, data_args, strict=False):
            self.datasets.append(
                DatasetDeserializerFactory.deserialize(
                    data=datum,
                    processor_factory=processor_factory,
                    random_seed=random_seed,
                    **data_kwargs,
                )
            )
        self.preprocessors = preprocessors
        for preprocessor in self.preprocessors:
            # Data-dependent preprocessors get to inspect the loaded datasets.
            if isinstance(preprocessor, DataDependentPreprocessor):
                preprocessor.setup_data(
                    datasets=self.datasets,
                    data_args=data_args,
                )
        self.precache: list[Any] | None = (
            list(self.generator(data_samples)) if data_samples else None
        )

    def __iter__(self):
        """
        Yield this worker's share of rows, from the precache when one exists
        and otherwise from a fresh pass over the datasets.
        """
        worker_info = torch.utils.data.get_worker_info()
        worker_modulus = worker_info.num_workers if worker_info is not None else 1
        worker_index = worker_info.id if worker_info is not None else 0

        if self.precache:
            for index, item in enumerate(self.precache):
                if (index + worker_index) % worker_modulus == 0:
                    yield item
        else:
            yield from self.generator(modulus=worker_modulus, offset=worker_index)

    def generator(
        self,
        max_items: int | None = None,
        modulus: int | None = None,
        offset: int | None = None,
    ) -> Iterator[Any]:
        """
        Yield preprocessed rows built from one item of every dataset.

        Iteration ends when ``max_items`` rows have been counted or as soon as
        any dataset is exhausted. Rows whose fetch or preprocessing raises are
        logged and skipped without counting towards ``max_items``.

        :param max_items: Maximum number of rows to count; ``None`` for no limit.
        :param modulus: With ``offset``, keep only rows whose 1-based count
            satisfies ``count % modulus == offset``.
        :param offset: See ``modulus``; both must be given for filtering.
        :return: Iterator over the preprocessed rows.
        :raises ValueError: If ``max_items`` is given and the datasets run out
            before that many rows were counted.
        """
        gen_count = 0
        dataset_iters = [iter(dataset) for dataset in self.datasets]

        while max_items is None or gen_count < max_items:
            try:
                items = [next(dataset_iter) for dataset_iter in dataset_iters]
            except StopIteration:
                # A dataset ran dry: this is the normal end of the data, not a
                # bad row. Treating it as a skippable error made the loop spin
                # forever on finite datasets.
                break
            except Exception as err:
                logger.error(f"Skipping data row due to error: {err}")
                continue

            gen_count += 1

            if (
                modulus is not None
                and offset is not None
                and (gen_count % modulus) != offset
            ):
                continue

            try:
                row = {"items": items}
                for preprocessor in self.preprocessors:
                    row = preprocessor(row)
            except Exception as err:
                logger.error(f"Skipping data row due to error: {err}")
                # The failed row must not count towards max_items.
                gen_count -= 1
                continue

            yield row

        if max_items is not None and gen_count < max_items:
            raise ValueError(
                f"Requested {max_items} samples, but only {gen_count} "
                "available from the provided datasets."
            )
112
+
113
+
114
class DataLoader(PyTorchDataLoader):
    """PyTorch data loader over a ``DatasetsIterator`` with a batch size of one."""

    def __init__(
        self,
        data: list[Any],
        data_args: list[dict[str, Any]] | None,
        data_samples: int,
        processor_factory: Callable[[], PreTrainedTokenizerBase],
        preprocessors: list[DatasetPreprocessor | DataDependentPreprocessor],
        collator: Callable,
        sampler: Sampler[int] | Literal["shuffle"] | None = None,
        num_workers: int | None = 1,
        random_seed: int = 42,
        **kwargs: Any,
    ):
        """
        Build the underlying ``DatasetsIterator`` and configure the base loader.

        :param data: Data sources forwarded to ``DatasetsIterator``.
        :param data_args: Per-source arguments forwarded to ``DatasetsIterator``.
        :param data_samples: Pre-generation count forwarded to ``DatasetsIterator``.
        :param processor_factory: Forwarded to ``DatasetsIterator``.
        :param preprocessors: Forwarded to ``DatasetsIterator``.
        :param collator: Used as the base loader's ``collate_fn``.
        :param sampler: ``"shuffle"`` turns on the base loader's ``shuffle``
            flag; any other value is passed through as its ``sampler``.
        :param num_workers: Worker count; ``None`` or ``0`` becomes ``0``.
        :param random_seed: Forwarded to ``DatasetsIterator``.
        :param kwargs: Extra keyword arguments for the base loader.
        """
        dataset = DatasetsIterator(
            data=data,
            data_args=data_args,
            data_samples=data_samples,
            processor_factory=processor_factory,
            preprocessors=preprocessors,
            random_seed=random_seed,
        )

        # "shuffle" is a shorthand and is never forwarded as a sampler object.
        # NOTE(review): the dataset is an iterable-style dataset; confirm the
        # base loader accepts shuffle=True or an explicit sampler for it.
        shuffle_requested = sampler == "shuffle"
        explicit_sampler = sampler if sampler != "shuffle" else None
        worker_count = num_workers or 0

        super().__init__(
            dataset=dataset,
            batch_size=1,
            shuffle=shuffle_requested,
            sampler=explicit_sampler,
            collate_fn=collator,
            num_workers=worker_count,
            **kwargs,
        )
@@ -0,0 +1,25 @@
1
+ from .formatters import (
2
+ GenerativeAudioTranscriptionRequestFormatter,
3
+ GenerativeAudioTranslationRequestFormatter,
4
+ GenerativeChatCompletionsRequestFormatter,
5
+ GenerativeTextCompletionsRequestFormatter,
6
+ )
7
+ from .mappers import GenerativeColumnMapper
8
+ from .preprocessor import (
9
+ DataDependentPreprocessor,
10
+ DatasetPreprocessor,
11
+ PreprocessorRegistry,
12
+ )
13
+
14
# Public names of the preprocessors package. Every entry must be bound by one
# of the imports in this module: a listed but undefined name makes
# ``from ... import *`` fail with AttributeError, which is why the former
# "ColumnMapper" and "ColumnMapperRegistry" entries were removed.
__all__ = [
    "DataDependentPreprocessor",
    "DatasetPreprocessor",
    "GenerativeAudioTranscriptionRequestFormatter",
    "GenerativeAudioTranslationRequestFormatter",
    "GenerativeChatCompletionsRequestFormatter",
    "GenerativeColumnMapper",
    "GenerativeTextCompletionsRequestFormatter",
    "PreprocessorRegistry",
]