guidellm 0.4.0a21__py3-none-any.whl → 0.4.0a169__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of guidellm might be problematic.
Files changed (115)
  1. guidellm/__init__.py +5 -2
  2. guidellm/__main__.py +452 -252
  3. guidellm/backends/__init__.py +33 -0
  4. guidellm/backends/backend.py +110 -0
  5. guidellm/backends/openai.py +355 -0
  6. guidellm/backends/response_handlers.py +455 -0
  7. guidellm/benchmark/__init__.py +53 -39
  8. guidellm/benchmark/benchmarker.py +150 -317
  9. guidellm/benchmark/entrypoints.py +467 -128
  10. guidellm/benchmark/output.py +519 -771
  11. guidellm/benchmark/profile.py +580 -280
  12. guidellm/benchmark/progress.py +568 -549
  13. guidellm/benchmark/scenarios/__init__.py +40 -0
  14. guidellm/benchmark/scenarios/chat.json +6 -0
  15. guidellm/benchmark/scenarios/rag.json +6 -0
  16. guidellm/benchmark/schemas.py +2086 -0
  17. guidellm/data/__init__.py +28 -4
  18. guidellm/data/collators.py +16 -0
  19. guidellm/data/deserializers/__init__.py +53 -0
  20. guidellm/data/deserializers/deserializer.py +144 -0
  21. guidellm/data/deserializers/file.py +222 -0
  22. guidellm/data/deserializers/huggingface.py +94 -0
  23. guidellm/data/deserializers/memory.py +194 -0
  24. guidellm/data/deserializers/synthetic.py +348 -0
  25. guidellm/data/loaders.py +149 -0
  26. guidellm/data/preprocessors/__init__.py +25 -0
  27. guidellm/data/preprocessors/formatters.py +404 -0
  28. guidellm/data/preprocessors/mappers.py +198 -0
  29. guidellm/data/preprocessors/preprocessor.py +31 -0
  30. guidellm/data/processor.py +31 -0
  31. guidellm/data/schemas.py +13 -0
  32. guidellm/data/utils/__init__.py +6 -0
  33. guidellm/data/utils/dataset.py +94 -0
  34. guidellm/extras/__init__.py +4 -0
  35. guidellm/extras/audio.py +215 -0
  36. guidellm/extras/vision.py +242 -0
  37. guidellm/logger.py +2 -2
  38. guidellm/mock_server/__init__.py +8 -0
  39. guidellm/mock_server/config.py +84 -0
  40. guidellm/mock_server/handlers/__init__.py +17 -0
  41. guidellm/mock_server/handlers/chat_completions.py +280 -0
  42. guidellm/mock_server/handlers/completions.py +280 -0
  43. guidellm/mock_server/handlers/tokenizer.py +142 -0
  44. guidellm/mock_server/models.py +510 -0
  45. guidellm/mock_server/server.py +168 -0
  46. guidellm/mock_server/utils.py +302 -0
  47. guidellm/preprocess/dataset.py +23 -26
  48. guidellm/presentation/builder.py +2 -2
  49. guidellm/presentation/data_models.py +25 -21
  50. guidellm/presentation/injector.py +2 -3
  51. guidellm/scheduler/__init__.py +65 -26
  52. guidellm/scheduler/constraints.py +1035 -0
  53. guidellm/scheduler/environments.py +252 -0
  54. guidellm/scheduler/scheduler.py +140 -368
  55. guidellm/scheduler/schemas.py +272 -0
  56. guidellm/scheduler/strategies.py +519 -0
  57. guidellm/scheduler/worker.py +391 -420
  58. guidellm/scheduler/worker_group.py +707 -0
  59. guidellm/schemas/__init__.py +31 -0
  60. guidellm/schemas/info.py +159 -0
  61. guidellm/schemas/request.py +226 -0
  62. guidellm/schemas/response.py +119 -0
  63. guidellm/schemas/stats.py +228 -0
  64. guidellm/{config.py → settings.py} +32 -21
  65. guidellm/utils/__init__.py +95 -8
  66. guidellm/utils/auto_importer.py +98 -0
  67. guidellm/utils/cli.py +71 -2
  68. guidellm/utils/console.py +183 -0
  69. guidellm/utils/encoding.py +778 -0
  70. guidellm/utils/functions.py +134 -0
  71. guidellm/utils/hf_datasets.py +1 -2
  72. guidellm/utils/hf_transformers.py +4 -4
  73. guidellm/utils/imports.py +9 -0
  74. guidellm/utils/messaging.py +1118 -0
  75. guidellm/utils/mixins.py +115 -0
  76. guidellm/utils/pydantic_utils.py +411 -0
  77. guidellm/utils/random.py +3 -4
  78. guidellm/utils/registry.py +220 -0
  79. guidellm/utils/singleton.py +133 -0
  80. guidellm/{objects → utils}/statistics.py +341 -247
  81. guidellm/utils/synchronous.py +159 -0
  82. guidellm/utils/text.py +163 -50
  83. guidellm/utils/typing.py +41 -0
  84. guidellm/version.py +1 -1
  85. {guidellm-0.4.0a21.dist-info → guidellm-0.4.0a169.dist-info}/METADATA +33 -10
  86. guidellm-0.4.0a169.dist-info/RECORD +95 -0
  87. guidellm/backend/__init__.py +0 -23
  88. guidellm/backend/backend.py +0 -259
  89. guidellm/backend/openai.py +0 -705
  90. guidellm/backend/response.py +0 -136
  91. guidellm/benchmark/aggregator.py +0 -760
  92. guidellm/benchmark/benchmark.py +0 -837
  93. guidellm/benchmark/scenario.py +0 -104
  94. guidellm/data/prideandprejudice.txt.gz +0 -0
  95. guidellm/dataset/__init__.py +0 -22
  96. guidellm/dataset/creator.py +0 -213
  97. guidellm/dataset/entrypoints.py +0 -42
  98. guidellm/dataset/file.py +0 -92
  99. guidellm/dataset/hf_datasets.py +0 -62
  100. guidellm/dataset/in_memory.py +0 -132
  101. guidellm/dataset/synthetic.py +0 -287
  102. guidellm/objects/__init__.py +0 -18
  103. guidellm/objects/pydantic.py +0 -89
  104. guidellm/request/__init__.py +0 -18
  105. guidellm/request/loader.py +0 -284
  106. guidellm/request/request.py +0 -79
  107. guidellm/request/types.py +0 -10
  108. guidellm/scheduler/queues.py +0 -25
  109. guidellm/scheduler/result.py +0 -155
  110. guidellm/scheduler/strategy.py +0 -495
  111. guidellm-0.4.0a21.dist-info/RECORD +0 -62
  112. {guidellm-0.4.0a21.dist-info → guidellm-0.4.0a169.dist-info}/WHEEL +0 -0
  113. {guidellm-0.4.0a21.dist-info → guidellm-0.4.0a169.dist-info}/entry_points.txt +0 -0
  114. {guidellm-0.4.0a21.dist-info → guidellm-0.4.0a169.dist-info}/licenses/LICENSE +0 -0
  115. {guidellm-0.4.0a21.dist-info → guidellm-0.4.0a169.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,194 @@
+from __future__ import annotations
+
+import contextlib
+import csv
+import json
+from collections.abc import Callable
+from io import StringIO
+from typing import Any, cast
+
+from datasets import Dataset
+from transformers import PreTrainedTokenizerBase
+
+from guidellm.data.deserializers.deserializer import (
+    DataNotSupportedError,
+    DatasetDeserializer,
+    DatasetDeserializerFactory,
+)
+
+__all__ = [
+    "InMemoryCsvDatasetDeserializer",
+    "InMemoryDictDatasetDeserializer",
+    "InMemoryDictListDatasetDeserializer",
+    "InMemoryItemListDatasetDeserializer",
+    "InMemoryJsonStrDatasetDeserializer",
+]
+
+
+@DatasetDeserializerFactory.register("in_memory_dict")
+class InMemoryDictDatasetDeserializer(DatasetDeserializer):
+    def __call__(
+        self,
+        data: Any,
+        processor_factory: Callable[[], PreTrainedTokenizerBase],
+        random_seed: int,
+        **data_kwargs: dict[str, Any],
+    ) -> Dataset:
+        _ = (processor_factory, random_seed)  # Ignore unused args format errors
+
+        if (
+            not data
+            or not isinstance(data, dict)
+            or not all(
+                isinstance(key, str) and isinstance(val, list)
+                for key, val in data.items()
+            )
+        ):
+            raise DataNotSupportedError(
+                f"Unsupported data for InMemoryDictDatasetDeserializer, "
+                f"expected dict[str, list], got {data}"
+            )
+
+        rows = len(list(data.values())[0])
+        if not all(len(val) == rows for val in data.values()):
+            raise DataNotSupportedError(
+                "All lists in the data dictionary must have the same length, "
+                f"expected {rows} for all keys {list(data.keys())}"
+            )
+
+        return Dataset.from_dict(data, **data_kwargs)
+
+
+@DatasetDeserializerFactory.register("in_memory_dict_list")
+class InMemoryDictListDatasetDeserializer(DatasetDeserializer):
+    def __call__(
+        self,
+        data: Any,
+        processor_factory: Callable[[], PreTrainedTokenizerBase],
+        random_seed: int,
+        **data_kwargs: dict[str, Any],
+    ) -> Dataset:
+        _ = (processor_factory, random_seed)  # Ignore unused args format errors
+
+        if (
+            not data
+            or not isinstance(data, list)
+            or not all(isinstance(item, dict) for item in data)
+            or not all(isinstance(key, str) for item in data for key in item)
+        ):
+            raise DataNotSupportedError(
+                f"Unsupported data for InMemoryDictListDatasetDeserializer, "
+                f"expected list of dicts, got {data}"
+            )
+
+        typed_data: list[dict[str, Any]] = cast("list[dict[str, Any]]", data)
+        first_keys = set(typed_data[0].keys())
+        for index, item in enumerate(typed_data):
+            if set(item.keys()) != first_keys:
+                raise DataNotSupportedError(
+                    f"All dictionaries must have the same keys. "
+                    f"Expected keys: {first_keys}, "
+                    f"got keys at index {index}: {set(item.keys())}"
+                )
+
+        # Convert list of dicts to dict of lists
+        result_dict: dict = {key: [] for key in first_keys}
+        for item in typed_data:
+            for key, value in item.items():
+                result_dict[key].append(value)
+
+        return Dataset.from_dict(result_dict, **data_kwargs)
+
+
+@DatasetDeserializerFactory.register("in_memory_item_list")
+class InMemoryItemListDatasetDeserializer(DatasetDeserializer):
+    def __call__(
+        self,
+        data: Any,
+        processor_factory: Callable[[], PreTrainedTokenizerBase],
+        random_seed: int,
+        **data_kwargs: dict[str, Any],
+    ) -> Dataset:
+        _ = (processor_factory, random_seed)  # Ignore unused args format errors
+
+        primitive_types = (str, int, float, bool, type(None))
+        if (
+            not data
+            or not isinstance(data, list)
+            or not all(isinstance(item, primitive_types) for item in data)
+        ):
+            raise DataNotSupportedError(
+                f"Unsupported data for InMemoryItemListDatasetDeserializer, "
+                f"expected list of primitive items, got {data}"
+            )
+
+        column_name = data_kwargs.pop("column_name", "data")
+
+        return Dataset.from_dict({column_name: data}, **data_kwargs)
+
+
+@DatasetDeserializerFactory.register("in_memory_json_str")
+class InMemoryJsonStrDatasetDeserializer(DatasetDeserializer):
+    def __call__(
+        self,
+        data: Any,
+        processor_factory: Callable[[], PreTrainedTokenizerBase],
+        random_seed: int,
+        **data_kwargs: dict[str, Any],
+    ) -> Dataset:
+        if (
+            isinstance(data, str)
+            and (json_str := data.strip())
+            and (
+                (json_str.startswith("{") and json_str.endswith("}"))
+                or (json_str.startswith("[") and json_str.endswith("]"))
+            )
+        ):
+            with contextlib.suppress(Exception):
+                parsed_data = json.loads(data)
+
+                deserializers = [
+                    InMemoryDictDatasetDeserializer(),
+                    InMemoryDictListDatasetDeserializer(),
+                    InMemoryItemListDatasetDeserializer(),
+                ]
+
+                for deserializer in deserializers:
+                    with contextlib.suppress(DataNotSupportedError):
+                        return deserializer(
+                            parsed_data, processor_factory, random_seed, **data_kwargs
+                        )
+
+        raise DataNotSupportedError(
+            f"Unsupported data for InMemoryJsonStrDatasetDeserializer, "
+            f"expected JSON string with a list or dict of items, got {data}"
+        )
+
+
+@DatasetDeserializerFactory.register("in_memory_csv_str")
+class InMemoryCsvDatasetDeserializer(DatasetDeserializer):
+    def __call__(
+        self,
+        data: Any,
+        processor_factory: Callable[[], PreTrainedTokenizerBase],
+        random_seed: int,
+        **data_kwargs: dict[str, Any],
+    ) -> Dataset:
+        if (
+            isinstance(data, str)
+            and (csv_str := data.strip())
+            and len(csv_str.split("\n")) > 0
+        ):
+            with contextlib.suppress(Exception):
+                csv_buffer = StringIO(data)
+                reader = csv.DictReader(csv_buffer)
+                rows = list(reader)
+
+                return InMemoryDictListDatasetDeserializer()(
+                    rows, processor_factory, random_seed, **data_kwargs
+                )
+
+        raise DataNotSupportedError(
+            f"Unsupported data for InMemoryCsvDatasetDeserializer, "
+            f"expected CSV string, got {type(data)}"
+        )
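Usage sketch for the in-memory deserializers above: the example is illustrative rather than taken from the package, and assumes the classes are importable from guidellm.data.deserializers.memory (the file this hunk adds). The processor factory and random seed are accepted but unused by these classes.

# Hypothetical example: feed a dict of equal-length columns, or a JSON string,
# straight into the registered in-memory deserializers.
from guidellm.data.deserializers.memory import (
    InMemoryDictDatasetDeserializer,
    InMemoryJsonStrDatasetDeserializer,
)

rows = {"prompt": ["hello", "world"], "output_tokens_count": [16, 32]}

dataset = InMemoryDictDatasetDeserializer()(
    data=rows,
    processor_factory=lambda: None,  # placeholder; unused by the in-memory classes
    random_seed=42,
)
print(dataset.column_names)  # ['prompt', 'output_tokens_count']

# A JSON string is parsed and then routed through the same dict/list deserializers.
json_dataset = InMemoryJsonStrDatasetDeserializer()(
    data='[{"prompt": "hello"}, {"prompt": "world"}]',
    processor_factory=lambda: None,
    random_seed=42,
)
print(len(json_dataset))  # 2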
@@ -0,0 +1,348 @@
+from __future__ import annotations
+
+import math
+from collections.abc import Callable, Iterator
+from pathlib import Path
+from random import Random
+from typing import Any
+
+import yaml
+from datasets import Features, IterableDataset, Value
+from faker import Faker
+from pydantic import ConfigDict, Field, model_validator
+from transformers import PreTrainedTokenizerBase
+
+from guidellm.data.deserializers.deserializer import (
+    DataNotSupportedError,
+    DatasetDeserializer,
+    DatasetDeserializerFactory,
+)
+from guidellm.utils import IntegerRangeSampler, StandardBaseModel
+
+__all__ = [
+    "SyntheticTextDatasetConfig",
+    "SyntheticTextDatasetDeserializer",
+    "SyntheticTextGenerator",
+    "SyntheticTextPrefixBucketConfig",
+]
+
+
+class SyntheticTextPrefixBucketConfig(StandardBaseModel):
+    bucket_weight: int = Field(
+        description="Weight of this bucket in the overall distribution.",
+        gt=0,
+        default=100,
+    )
+    prefix_count: int = Field(
+        description="The number of unique prefixes to generate for this bucket.",
+        ge=1,
+        default=1,
+    )
+    prefix_tokens: int = Field(
+        description="The number of prefix tokens per-prompt for this bucket.",
+        ge=0,
+        default=0,
+    )
+
+
+class SyntheticTextDatasetConfig(StandardBaseModel):
+    model_config = ConfigDict(
+        extra="allow",
+    )
+
+    prefix_buckets: list[SyntheticTextPrefixBucketConfig] | None = Field(
+        description="Buckets for the prefix tokens distribution.",
+        default=None,
+    )
+    prompt_tokens: int = Field(
+        description="The average number of text tokens generated for prompts.",
+        gt=0,
+    )
+    prompt_tokens_stdev: int | None = Field(
+        description="The standard deviation of the tokens generated for prompts.",
+        gt=0,
+        default=None,
+    )
+    prompt_tokens_min: int | None = Field(
+        description="The minimum number of text tokens generated for prompts.",
+        gt=0,
+        default=None,
+    )
+    prompt_tokens_max: int | None = Field(
+        description="The maximum number of text tokens generated for prompts.",
+        gt=0,
+        default=None,
+    )
+    output_tokens: int = Field(
+        description="The average number of text tokens generated for outputs.",
+        gt=0,
+    )
+    output_tokens_stdev: int | None = Field(
+        description="The standard deviation of the tokens generated for outputs.",
+        gt=0,
+        default=None,
+    )
+    output_tokens_min: int | None = Field(
+        description="The minimum number of text tokens generated for outputs.",
+        gt=0,
+        default=None,
+    )
+    output_tokens_max: int | None = Field(
+        description="The maximum number of text tokens generated for outputs.",
+        gt=0,
+        default=None,
+    )
+    source: str = Field(
+        description="The source of the text data to be used for generation.",
+        default="data:prideandprejudice.txt.gz",
+    )
+
+    @model_validator(mode="after")
+    def check_prefix_options(self) -> SyntheticTextDatasetConfig:
+        if self.__pydantic_extra__ is not None:
+            prefix_count = self.__pydantic_extra__.get("prefix_count", None)  # type: ignore[attr-defined]
+            prefix_tokens = self.__pydantic_extra__.get("prefix_tokens", None)  # type: ignore[attr-defined]
+
+            if prefix_count is not None or prefix_tokens is not None:
+                if self.prefix_buckets:
+                    raise ValueError(
+                        "prefix_buckets is mutually exclusive"
+                        " with prefix_count and prefix_tokens"
+                    )
+
+                self.prefix_buckets = [
+                    SyntheticTextPrefixBucketConfig(
+                        prefix_count=prefix_count or 1,
+                        prefix_tokens=prefix_tokens or 0,
+                    )
+                ]
+
+        return self
+
+
+class SyntheticTextGenerator:
+    def __init__(
+        self,
+        config: SyntheticTextDatasetConfig,
+        processor: PreTrainedTokenizerBase,
+        random_seed: int = 42,
+    ):
+        self.config = config
+        self.processor = processor
+        self.random_seed = random_seed
+
+    def __iter__(self) -> Iterator[dict[str, Any]]:
+        samples_generated = 0
+
+        faker = Faker()
+        faker.seed_instance(self.random_seed)
+        prompt_tokens_sampler = iter(
+            IntegerRangeSampler(
+                average=self.config.prompt_tokens,
+                variance=self.config.prompt_tokens_stdev,
+                min_value=self.config.prompt_tokens_min,
+                max_value=self.config.prompt_tokens_max,
+                random_seed=self.random_seed,
+            )
+        )
+        output_tokens_sampler = iter(
+            IntegerRangeSampler(
+                average=self.config.output_tokens,
+                variance=self.config.output_tokens_stdev,
+                min_value=self.config.output_tokens_min,
+                max_value=self.config.output_tokens_max,
+                random_seed=self.random_seed + 1,  # ensure diff dist from prompts
+            )
+        )
+
+        # Create a shared prefix if specified
+        rand = Random(self.random_seed + 3)
+        prefix_iter = self._create_prefix_iter(faker, rand)
+
+        while True:
+            prompt_tokens_count = next(prompt_tokens_sampler)
+            output_tokens_count = next(output_tokens_sampler)
+
+            yield {
+                "prefix": next(prefix_iter),
+                "prompt": self._create_prompt(
+                    prompt_tokens_count, faker, f"{samples_generated} "
+                ),
+                "prompt_tokens_count": prompt_tokens_count,
+                "output_tokens_count": output_tokens_count,
+            }
+            samples_generated += 1
+
+    def _create_prompt(
+        self, prompt_tokens_count: int, faker: Faker, unique: str = ""
+    ) -> str:
+        prompt_token_ids: list[int] = []
+        avg_chars_per_token = 5
+        margin_of_safety = 1.5
+        attempts = 0
+
+        while len(prompt_token_ids) < prompt_tokens_count:
+            attempts += 1
+            num_chars = int(
+                prompt_tokens_count * avg_chars_per_token * margin_of_safety * attempts
+            )
+            text = unique + faker.text(max_nb_chars=num_chars)
+            prompt_token_ids = self.processor.encode(text)
+
+        return self.processor.decode(
+            prompt_token_ids[:prompt_tokens_count], skip_special_tokens=True
+        )
+
+    def _create_prefix_iter(self, faker: Faker, rand: Random) -> Iterator[str]:
+        if not self.config.prefix_buckets:
+            while True:
+                yield ""
+
+        # Increase weights to ensure an integer number of samples per prefix
+        least_common_prefix_count = math.lcm(
+            *(bucket.prefix_count for bucket in self.config.prefix_buckets)
+        )
+        unnorm_weights = [
+            least_common_prefix_count * bucket.bucket_weight // bucket.prefix_count
+            for bucket in self.config.prefix_buckets
+        ]
+        # Use GCD to reduce the weights to smallest integer ratio
+        common_divisor = math.gcd(*unnorm_weights)
+
+        # Create prefix list maintaining the correct distribution
+        prefixes = []
+        for bucket, weight in zip(
+            self.config.prefix_buckets, unnorm_weights, strict=False
+        ):
+            bucket_prefixes = [
+                self._create_prompt(bucket.prefix_tokens, faker)
+                for _ in range(bucket.prefix_count)
+            ]
+            sample_count = weight // common_divisor
+            prefixes.extend(bucket_prefixes * sample_count)
+
+        while True:
+            yield rand.choice(prefixes)
+
+
+@DatasetDeserializerFactory.register("synthetic_text")
+class SyntheticTextDatasetDeserializer(DatasetDeserializer):
+    def __call__(
+        self,
+        data: Any,
+        processor_factory: Callable[[], PreTrainedTokenizerBase],
+        random_seed: int,
+        **data_kwargs: dict[str, Any],
+    ) -> IterableDataset:
+        # Config file pathways, deserialize and call self again
+        if (config := self._load_config_file(data)) is not None:
+            return self(config, processor_factory, random_seed, **data_kwargs)
+
+        # Config str pathways, deserialize and call self again
+        if (config := self._load_config_str(data)) is not None:
+            return self(config, processor_factory, random_seed, **data_kwargs)
+
+        if not isinstance(data, SyntheticTextDatasetConfig):
+            raise DataNotSupportedError(
+                "Unsupported data for SyntheticTextDatasetDeserializer, "
+                "expected SyntheticTextDatasetConfig, str or Path to a config file, "
+                f"got {data}"
+            )
+
+        return IterableDataset.from_generator(
+            SyntheticTextGenerator,
+            gen_kwargs={
+                "config": data,
+                "processor": processor_factory(),
+                "random_seed": random_seed,
+            },
+            features=Features(
+                {
+                    "prefix": Value("string"),
+                    "prompt": Value("string"),
+                    "prompt_tokens_count": Value("int32"),
+                    "output_tokens_count": Value("int32"),
+                }
+            ),
+        )
+
+    def _load_config_file(self, data: Any) -> SyntheticTextDatasetConfig | None:
+        if (not isinstance(data, str) and not isinstance(data, Path)) or (
+            not Path(data).is_file()
+        ):
+            return None
+
+        data_path = Path(data) if isinstance(data, str) else data
+        error = None
+
+        if Path(data).is_file() and data_path.suffix.lower() == ".json":
+            try:
+                return SyntheticTextDatasetConfig.model_validate_json(
+                    data_path.read_text()
+                )
+            except Exception as err:  # noqa: BLE001
+                error = err
+
+        if Path(data).is_file() and data_path.suffix.lower() in {
+            ".yaml",
+            ".yml",
+            ".config",
+        }:
+            try:
+                return SyntheticTextDatasetConfig.model_validate(
+                    yaml.safe_load(data_path.read_text())
+                )
+            except Exception as err:  # noqa: BLE001
+                error = err
+
+        err_message = (
+            f"Unsupported file {data_path} for "
+            f"SyntheticTextDatasetDeserializer, expected .json, "
+            f".yaml, .yml, or .config"
+        )
+
+        if error is not None:
+            err_message += f" with error: {error}"
+            raise DataNotSupportedError(err_message) from error
+        raise DataNotSupportedError(err_message)
+
+    def _load_config_str(self, data: str) -> SyntheticTextDatasetConfig | None:
+        if not isinstance(data, str):
+            return None
+
+        data_str = data.strip()
+        error = None
+
+        if (data_str.startswith("{") and data_str.endswith("}")) or (
+            data_str.startswith("[") and data_str.endswith("]")
+        ):
+            try:
+                return SyntheticTextDatasetConfig.model_validate_json(data_str)
+            except Exception as err:  # noqa: BLE001
+                error = err
+
+        if data_str.count("=") > 1:
+            # key=value pairs separated by commas
+            try:
+                config_dict = {}
+                items = data_str.split(",")
+                for item in items:
+                    key, value = item.split("=")
+                    config_dict[key.strip()] = (
+                        int(value.strip())
+                        if value.strip().isnumeric()
+                        else value.strip()
+                    )
+
+                return SyntheticTextDatasetConfig.model_validate(config_dict)
+            except Exception as err:  # noqa: BLE001
+                error = err
+
+        err_message = (
+            "Unsupported string data for SyntheticTextDatasetDeserializer, "
+            f"expected JSON or key-value pairs, got {data}"
+        )
+        if error is not None:
+            err_message += f" with error: {error}"
+            raise DataNotSupportedError(err_message) from error
+        raise DataNotSupportedError(err_message)
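Usage sketch for the synthetic text pathway above: illustrative only, assuming a Hugging Face tokenizer such as gpt2 can be loaded locally or downloaded. The same configuration may also be supplied as a JSON/YAML file path or as a key=value string, which are the pathways handled by _load_config_file() and _load_config_str().

# Hypothetical example: build a config object and let the deserializer wrap
# SyntheticTextGenerator in an IterableDataset of prefix/prompt/token-count rows.
from transformers import AutoTokenizer

from guidellm.data.deserializers.synthetic import (
    SyntheticTextDatasetConfig,
    SyntheticTextDatasetDeserializer,
)

config = SyntheticTextDatasetConfig(prompt_tokens=256, output_tokens=128)
deserializer = SyntheticTextDatasetDeserializer()

dataset = deserializer(
    data=config,
    processor_factory=lambda: AutoTokenizer.from_pretrained("gpt2"),  # assumed tokenizer
    random_seed=42,
)

# Equivalent, using the key=value string pathway:
dataset = deserializer(
    data="prompt_tokens=256,output_tokens=128",
    processor_factory=lambda: AutoTokenizer.from_pretrained("gpt2"),
    random_seed=42,
)

# The dataset is an infinite generator; take a few samples and stop.
for index, sample in enumerate(dataset):
    print(sample["prompt_tokens_count"], sample["output_tokens_count"])
    if index >= 2:
        break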
@@ -0,0 +1,149 @@
+from __future__ import annotations
+
+import contextlib
+from collections.abc import Callable, Iterator
+from typing import Any, Literal
+
+import torch
+from torch.utils.data import Sampler
+from torch.utils.data.dataloader import DataLoader as PyTorchDataLoader
+from torch.utils.data.dataset import IterableDataset as TorchIterableDataset
+from transformers import PreTrainedTokenizerBase
+
+from guidellm.data.deserializers import DatasetDeserializerFactory
+from guidellm.data.preprocessors import DataDependentPreprocessor, DatasetPreprocessor
+from guidellm.logger import logger
+
+__all__ = ["DataLoader", "DatasetsIterator"]
+
+
+class DatasetsIterator(TorchIterableDataset):
+    def __init__(
+        self,
+        data: list[Any],
+        data_args: list[dict[str, Any]] | None,
+        data_samples: int,
+        processor_factory: Callable[[], PreTrainedTokenizerBase],
+        preprocessors: list[DatasetPreprocessor | DataDependentPreprocessor],
+        random_seed: int,
+    ):
+        if not data or not isinstance(data, list):
+            raise ValueError(f"Data must be a non-empty list, got {data}.")
+
+        if not data_args:
+            data_args = [{} for _ in data]
+
+        if len(data) != len(data_args):
+            raise ValueError(
+                f"Length of data ({len(data)}) must match length of data_args "
+                f"({len(data_args)})."
+            )
+
+        self.datasets = []
+        for datum, data_kwargs in zip(data, data_args, strict=False):
+            self.datasets.append(
+                DatasetDeserializerFactory.deserialize(
+                    data=datum,
+                    processor_factory=processor_factory,
+                    random_seed=random_seed,
+                    **data_kwargs,
+                )
+            )
+        self.preprocessors = preprocessors
+        for preprocessor in self.preprocessors:
+            if isinstance(preprocessor, DataDependentPreprocessor):
+                preprocessor.setup_data(
+                    datasets=self.datasets,
+                    data_args=data_args,
+                )
+        self.precache: list[Any] | None = (
+            list(self.generator(data_samples)) if data_samples else None
+        )
+
+    def __iter__(self):
+        worker_info = torch.utils.data.get_worker_info()
+        worker_modulus = worker_info.num_workers if worker_info is not None else 1
+        worker_index = worker_info.id if worker_info is not None else 0
+
+        if self.precache:
+            for index, item in enumerate(self.precache):
+                if (index + worker_index) % worker_modulus == 0:
+                    yield item
+        else:
+            yield from self.generator(modulus=worker_modulus, offset=worker_index)
+
+    def generator(
+        self,
+        max_items: int | None = None,
+        modulus: int | None = None,
+        offset: int | None = None,
+    ) -> Iterator[Any]:
+        gen_count = 0
+
+        with contextlib.suppress(StopIteration):
+            dataset_iters = [iter(dataset) for dataset in self.datasets]
+
+            while max_items is None or gen_count < max_items:
+                try:
+                    row: dict[str, Any] = {
+                        "items": [next(dataset_iter) for dataset_iter in dataset_iters]
+                    }
+                    gen_count += 1
+
+                    if (
+                        modulus is not None
+                        and offset is not None
+                        and (gen_count % modulus) != offset
+                    ):
+                        continue
+
+                    for preprocessor in self.preprocessors:
+                        # This can assign a GenerationRequest, which would then be
+                        # passed into the preprocessor, which is a type violation.
+                        # This should be fixed at some point.
+                        row = preprocessor(row)  # type: ignore[assignment]
+                    yield row
+                except Exception as err:  # noqa: BLE001 # Exception logged
+                    logger.error(f"Skipping data row due to error: {err}")
+                    gen_count -= 1
+
+        if max_items is not None and gen_count < max_items:
+            raise ValueError(
+                f"Requested {max_items} samples, but only {gen_count} "
+                "available from the provided datasets."
+            )
+
+
+class DataLoader(PyTorchDataLoader):
+    def __init__(
+        self,
+        data: list[Any],
+        data_args: list[dict[str, Any]] | None,
+        data_samples: int,
+        processor_factory: Callable[[], PreTrainedTokenizerBase],
+        preprocessors: list[DatasetPreprocessor | DataDependentPreprocessor],
+        collator: Callable,
+        sampler: Sampler[int] | Literal["shuffle"] | None = None,
+        num_workers: int | None = 1,
+        random_seed: int = 42,
+        **kwargs: Any,
+    ):
+        iterator = DatasetsIterator(
+            data=data,
+            data_args=data_args,
+            data_samples=data_samples,
+            processor_factory=processor_factory,
+            preprocessors=preprocessors,
+            random_seed=random_seed,
+        )
+
+        super().__init__(
+            dataset=iterator,
+            batch_size=1,
+            shuffle=sampler == "shuffle",
+            sampler=sampler if sampler != "shuffle" else None,
+            collate_fn=collator,
+            num_workers=num_workers or 0,
+            **kwargs,
+        )
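Usage sketch for the DataLoader above, illustrative only: the identity collator, the empty preprocessor list, and the assumption that DatasetDeserializerFactory.deserialize routes a plain dict to the in-memory dict deserializer are not taken from the diff.

# Hypothetical example: one in-memory dataset, two pre-cached samples, no
# preprocessing, batches of size 1 passed through unchanged by the collator.
from guidellm.data.loaders import DataLoader

loader = DataLoader(
    data=[{"prompt": ["hello", "world"]}],  # assumed to hit the in-memory dict path
    data_args=[{}],
    data_samples=2,                          # pre-cache two rows up front
    processor_factory=lambda: None,          # placeholder; unused on this data path
    preprocessors=[],
    collator=lambda batch: batch,            # identity collate_fn
    sampler=None,
    num_workers=0,
)

for batch in loader:
    # Each batch is a list holding a single {"items": [...]} row.
    print(batch)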