guidellm 0.3.1__py3-none-any.whl → 0.6.0a5__py3-none-any.whl
This diff compares the contents of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- guidellm/__init__.py +5 -2
- guidellm/__main__.py +524 -255
- guidellm/backends/__init__.py +33 -0
- guidellm/backends/backend.py +109 -0
- guidellm/backends/openai.py +340 -0
- guidellm/backends/response_handlers.py +428 -0
- guidellm/benchmark/__init__.py +69 -39
- guidellm/benchmark/benchmarker.py +160 -316
- guidellm/benchmark/entrypoints.py +560 -127
- guidellm/benchmark/outputs/__init__.py +24 -0
- guidellm/benchmark/outputs/console.py +633 -0
- guidellm/benchmark/outputs/csv.py +721 -0
- guidellm/benchmark/outputs/html.py +473 -0
- guidellm/benchmark/outputs/output.py +169 -0
- guidellm/benchmark/outputs/serialized.py +69 -0
- guidellm/benchmark/profiles.py +718 -0
- guidellm/benchmark/progress.py +553 -556
- guidellm/benchmark/scenarios/__init__.py +40 -0
- guidellm/benchmark/scenarios/chat.json +6 -0
- guidellm/benchmark/scenarios/rag.json +6 -0
- guidellm/benchmark/schemas/__init__.py +66 -0
- guidellm/benchmark/schemas/base.py +402 -0
- guidellm/benchmark/schemas/generative/__init__.py +55 -0
- guidellm/benchmark/schemas/generative/accumulator.py +841 -0
- guidellm/benchmark/schemas/generative/benchmark.py +163 -0
- guidellm/benchmark/schemas/generative/entrypoints.py +381 -0
- guidellm/benchmark/schemas/generative/metrics.py +927 -0
- guidellm/benchmark/schemas/generative/report.py +158 -0
- guidellm/data/__init__.py +34 -4
- guidellm/data/builders.py +541 -0
- guidellm/data/collators.py +16 -0
- guidellm/data/config.py +120 -0
- guidellm/data/deserializers/__init__.py +49 -0
- guidellm/data/deserializers/deserializer.py +141 -0
- guidellm/data/deserializers/file.py +223 -0
- guidellm/data/deserializers/huggingface.py +94 -0
- guidellm/data/deserializers/memory.py +194 -0
- guidellm/data/deserializers/synthetic.py +246 -0
- guidellm/data/entrypoints.py +52 -0
- guidellm/data/loaders.py +190 -0
- guidellm/data/preprocessors/__init__.py +27 -0
- guidellm/data/preprocessors/formatters.py +410 -0
- guidellm/data/preprocessors/mappers.py +196 -0
- guidellm/data/preprocessors/preprocessor.py +30 -0
- guidellm/data/processor.py +29 -0
- guidellm/data/schemas.py +175 -0
- guidellm/data/utils/__init__.py +6 -0
- guidellm/data/utils/dataset.py +94 -0
- guidellm/extras/__init__.py +4 -0
- guidellm/extras/audio.py +220 -0
- guidellm/extras/vision.py +242 -0
- guidellm/logger.py +2 -2
- guidellm/mock_server/__init__.py +8 -0
- guidellm/mock_server/config.py +84 -0
- guidellm/mock_server/handlers/__init__.py +17 -0
- guidellm/mock_server/handlers/chat_completions.py +280 -0
- guidellm/mock_server/handlers/completions.py +280 -0
- guidellm/mock_server/handlers/tokenizer.py +142 -0
- guidellm/mock_server/models.py +510 -0
- guidellm/mock_server/server.py +238 -0
- guidellm/mock_server/utils.py +302 -0
- guidellm/scheduler/__init__.py +69 -26
- guidellm/scheduler/constraints/__init__.py +49 -0
- guidellm/scheduler/constraints/constraint.py +325 -0
- guidellm/scheduler/constraints/error.py +411 -0
- guidellm/scheduler/constraints/factory.py +182 -0
- guidellm/scheduler/constraints/request.py +312 -0
- guidellm/scheduler/constraints/saturation.py +722 -0
- guidellm/scheduler/environments.py +252 -0
- guidellm/scheduler/scheduler.py +137 -368
- guidellm/scheduler/schemas.py +358 -0
- guidellm/scheduler/strategies.py +617 -0
- guidellm/scheduler/worker.py +413 -419
- guidellm/scheduler/worker_group.py +712 -0
- guidellm/schemas/__init__.py +65 -0
- guidellm/schemas/base.py +417 -0
- guidellm/schemas/info.py +188 -0
- guidellm/schemas/request.py +235 -0
- guidellm/schemas/request_stats.py +349 -0
- guidellm/schemas/response.py +124 -0
- guidellm/schemas/statistics.py +1018 -0
- guidellm/{config.py → settings.py} +31 -24
- guidellm/utils/__init__.py +71 -8
- guidellm/utils/auto_importer.py +98 -0
- guidellm/utils/cli.py +132 -5
- guidellm/utils/console.py +566 -0
- guidellm/utils/encoding.py +778 -0
- guidellm/utils/functions.py +159 -0
- guidellm/utils/hf_datasets.py +1 -2
- guidellm/utils/hf_transformers.py +4 -4
- guidellm/utils/imports.py +9 -0
- guidellm/utils/messaging.py +1118 -0
- guidellm/utils/mixins.py +115 -0
- guidellm/utils/random.py +3 -4
- guidellm/utils/registry.py +220 -0
- guidellm/utils/singleton.py +133 -0
- guidellm/utils/synchronous.py +159 -0
- guidellm/utils/text.py +163 -50
- guidellm/utils/typing.py +41 -0
- guidellm/version.py +2 -2
- guidellm-0.6.0a5.dist-info/METADATA +364 -0
- guidellm-0.6.0a5.dist-info/RECORD +109 -0
- guidellm/backend/__init__.py +0 -23
- guidellm/backend/backend.py +0 -259
- guidellm/backend/openai.py +0 -708
- guidellm/backend/response.py +0 -136
- guidellm/benchmark/aggregator.py +0 -760
- guidellm/benchmark/benchmark.py +0 -837
- guidellm/benchmark/output.py +0 -997
- guidellm/benchmark/profile.py +0 -409
- guidellm/benchmark/scenario.py +0 -104
- guidellm/data/prideandprejudice.txt.gz +0 -0
- guidellm/dataset/__init__.py +0 -22
- guidellm/dataset/creator.py +0 -213
- guidellm/dataset/entrypoints.py +0 -42
- guidellm/dataset/file.py +0 -92
- guidellm/dataset/hf_datasets.py +0 -62
- guidellm/dataset/in_memory.py +0 -132
- guidellm/dataset/synthetic.py +0 -287
- guidellm/objects/__init__.py +0 -18
- guidellm/objects/pydantic.py +0 -89
- guidellm/objects/statistics.py +0 -953
- guidellm/preprocess/__init__.py +0 -3
- guidellm/preprocess/dataset.py +0 -374
- guidellm/presentation/__init__.py +0 -28
- guidellm/presentation/builder.py +0 -27
- guidellm/presentation/data_models.py +0 -232
- guidellm/presentation/injector.py +0 -66
- guidellm/request/__init__.py +0 -18
- guidellm/request/loader.py +0 -284
- guidellm/request/request.py +0 -79
- guidellm/request/types.py +0 -10
- guidellm/scheduler/queues.py +0 -25
- guidellm/scheduler/result.py +0 -155
- guidellm/scheduler/strategy.py +0 -495
- guidellm-0.3.1.dist-info/METADATA +0 -329
- guidellm-0.3.1.dist-info/RECORD +0 -62
- {guidellm-0.3.1.dist-info → guidellm-0.6.0a5.dist-info}/WHEEL +0 -0
- {guidellm-0.3.1.dist-info → guidellm-0.6.0a5.dist-info}/entry_points.txt +0 -0
- {guidellm-0.3.1.dist-info → guidellm-0.6.0a5.dist-info}/licenses/LICENSE +0 -0
- {guidellm-0.3.1.dist-info → guidellm-0.6.0a5.dist-info}/top_level.txt +0 -0
guidellm/data/deserializers/synthetic.py
ADDED

@@ -0,0 +1,246 @@

from __future__ import annotations

import math
from collections.abc import Callable, Iterator
from random import Random
from typing import Any

import numpy as np
from datasets import DatasetInfo, Features, IterableDataset, Value
from datasets.iterable_dataset import _BaseExamplesIterable
from faker import Faker
from transformers import PreTrainedTokenizerBase

from guidellm.data.config import load_config
from guidellm.data.deserializers.deserializer import (
    DataNotSupportedError,
    DatasetDeserializer,
    DatasetDeserializerFactory,
)
from guidellm.data.schemas import SyntheticTextDatasetConfig
from guidellm.utils import IntegerRangeSampler

__all__ = [
    "SyntheticTextDataset",
    "SyntheticTextDatasetDeserializer",
]


class _SyntheticTextExamplesIterable(_BaseExamplesIterable):
    """Custom examples iterable for synthetic text generation."""

    def __init__(
        self,
        config: SyntheticTextDatasetConfig,
        processor: PreTrainedTokenizerBase,
        random_seed: int,
    ):
        super().__init__()
        self.config = config
        self.processor = processor
        self.random_seed = random_seed
        self.iteration_count = 0

    def __iter__(self) -> Iterator[tuple[int, dict[str, Any]]]:
        iter_random_seed = self.random_seed + self.iteration_count
        self.iteration_count += 1

        faker = Faker()
        faker.seed_instance(iter_random_seed)
        prompt_tokens_sampler = iter(
            IntegerRangeSampler(
                average=self.config.prompt_tokens,
                variance=self.config.prompt_tokens_stdev,
                min_value=self.config.prompt_tokens_min,
                max_value=self.config.prompt_tokens_max,
                random_seed=iter_random_seed,
            )
        )
        output_tokens_sampler = iter(
            IntegerRangeSampler(
                average=self.config.output_tokens,
                variance=self.config.output_tokens_stdev,
                min_value=self.config.output_tokens_min,
                max_value=self.config.output_tokens_max,
                random_seed=iter_random_seed + 1,  # ensure diff dist from prompts
            )
        )

        # Create a shared prefix if specified
        rand = Random(iter_random_seed + 3)
        prefix_iter = self._create_prefix_iter(faker, rand)
        samples_count = 0

        while True:
            prompt_tokens_count = next(prompt_tokens_sampler)
            output_tokens_count = next(output_tokens_sampler)

            yield (
                samples_count,
                {
                    "prefix": next(prefix_iter),
                    "prompt": self._create_prompt(
                        prompt_tokens_count,
                        faker,
                        f"{self.iteration_count} {samples_count} ",
                    ),
                    "prompt_tokens_count": prompt_tokens_count,
                    "output_tokens_count": output_tokens_count,
                },
            )
            samples_count += 1

    @property
    def is_typed(self) -> bool:
        return True

    @property
    def features(self) -> Features:
        return Features(
            {
                "prefix": Value("string"),
                "prompt": Value("string"),
                "prompt_tokens_count": Value("int32"),
                "output_tokens_count": Value("int32"),
            }
        )

    @property
    def num_shards(self) -> int:
        return 1

    def shuffle_data_sources(
        self,
        generator: np.random.Generator,  # noqa: ARG002
    ) -> _SyntheticTextExamplesIterable:
        """Return self since synthetic data doesn't have fixed sources to shuffle."""
        return self

    def shard_data_sources(
        self,
        num_shards: int,  # noqa: ARG002
        index: int,  # noqa: ARG002
        contiguous: bool = True,  # noqa: ARG002
    ) -> _SyntheticTextExamplesIterable:
        """Return self since synthetic data generation is infinite and stateless."""
        return self

    def load_state_dict(self, state_dict: dict) -> None:
        """Load the state from a state dict."""
        self.iteration_count = state_dict.get("iteration_count", 0)

    def _init_state_dict(self) -> dict:
        """Initialize the state dict for the iterable."""
        self._state_dict = {"iteration_count": self.iteration_count}
        return self._state_dict

    def _create_prompt(
        self, prompt_tokens_count: int, faker: Faker, unique: str = ""
    ) -> str:
        prompt_token_ids: list[int] = []
        avg_chars_per_token = 5
        margin_of_safety = 1.5
        attempts = 0

        while len(prompt_token_ids) < prompt_tokens_count:
            attempts += 1
            num_chars = int(
                prompt_tokens_count * avg_chars_per_token * margin_of_safety * attempts
            )
            text = unique + faker.text(max_nb_chars=num_chars)
            prompt_token_ids = self.processor.encode(text)

        return self.processor.decode(
            prompt_token_ids[:prompt_tokens_count], skip_special_tokens=True
        )

    def _create_prefix_iter(self, faker: Faker, rand: Random) -> Iterator[str]:
        if not self.config.prefix_buckets:
            while True:
                yield ""

        # Increase weights to ensure an integer number of samples per prefix
        least_common_prefix_count = math.lcm(
            *(bucket.prefix_count for bucket in self.config.prefix_buckets)
        )
        unnorm_weights = [
            least_common_prefix_count * bucket.bucket_weight // bucket.prefix_count
            for bucket in self.config.prefix_buckets
        ]
        # Use GCD to reduce the weights to smallest integer ratio
        common_divisor = math.gcd(*unnorm_weights)

        # Create prefix list maintaining the correct distribution
        prefixes = []
        for bucket, weight in zip(
            self.config.prefix_buckets, unnorm_weights, strict=False
        ):
            bucket_prefixes = [
                self._create_prompt(bucket.prefix_tokens, faker)
                for _ in range(bucket.prefix_count)
            ]
            sample_count = weight // common_divisor
            prefixes.extend(bucket_prefixes * sample_count)

        while True:
            yield rand.choice(prefixes)


class SyntheticTextDataset(IterableDataset):
    def __init__(
        self,
        config: SyntheticTextDatasetConfig,
        processor: PreTrainedTokenizerBase,
        random_seed: int = 42,
    ):
        self.config = config
        self.processor = processor
        self.random_seed = random_seed

        # Create the examples iterable
        ex_iterable = _SyntheticTextExamplesIterable(
            config=config,
            processor=processor,
            random_seed=random_seed,
        )

        # Initialize parent with proper ex_iterable
        super().__init__(
            ex_iterable=ex_iterable,
            info=DatasetInfo(
                description="Synthetic text dataset generator",
                features=ex_iterable.features,
            ),
        )

    def set_epoch(self, epoch: int):
        """Set the epoch for the dataset iteration."""
        if isinstance(self._ex_iterable, _SyntheticTextExamplesIterable):
            self._ex_iterable.iteration_count = epoch


@DatasetDeserializerFactory.register("synthetic_text")
class SyntheticTextDatasetDeserializer(DatasetDeserializer):
    def __call__(
        self,
        data: Any,
        processor_factory: Callable[[], PreTrainedTokenizerBase],
        random_seed: int,
        **data_kwargs: dict[str, Any],
    ) -> IterableDataset:
        # Config file and string pathways; deserialize and call self again
        if (config := load_config(data, SyntheticTextDatasetConfig)) is not None:
            return self(config, processor_factory, random_seed, **data_kwargs)

        if not isinstance(data, SyntheticTextDatasetConfig):
            raise DataNotSupportedError(
                "Unsupported data for SyntheticTextDatasetDeserializer, "
                "expected SyntheticTextDatasetConfig, str or Path to a config file, "
                f"got {data}"
            )

        return SyntheticTextDataset(
            config=data,
            processor=processor_factory(),
            random_seed=random_seed,
        )
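For orientation, a minimal usage sketch of the new generator. The SyntheticTextDatasetConfig field names are inferred from the attribute accesses in the code above; the exact constructor signature and defaults are assumptions, and "gpt2" is only a stand-in tokenizer ID.

from itertools import islice

from transformers import AutoTokenizer

from guidellm.data.deserializers.synthetic import SyntheticTextDataset
from guidellm.data.schemas import SyntheticTextDatasetConfig

# Assumed keyword-argument construction; the field names mirror the
# config attributes read by _SyntheticTextExamplesIterable above.
config = SyntheticTextDatasetConfig(
    prompt_tokens=256,
    prompt_tokens_stdev=32,
    prompt_tokens_min=64,
    prompt_tokens_max=512,
    output_tokens=128,
    output_tokens_stdev=16,
    output_tokens_min=32,
    output_tokens_max=256,
)
dataset = SyntheticTextDataset(
    config=config,
    processor=AutoTokenizer.from_pretrained("gpt2"),  # stand-in tokenizer
    random_seed=42,
)

# The stream is infinite, so take only a few rows to inspect.
for row in islice(dataset, 3):
    print(row["prompt_tokens_count"], row["output_tokens_count"], row["prompt"][:60])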
guidellm/data/entrypoints.py
ADDED

@@ -0,0 +1,52 @@

from pathlib import Path
from typing import Any

from transformers import PreTrainedTokenizerBase

from guidellm.data import builders
from guidellm.data.builders import ShortPromptStrategy


def process_dataset(
    data: str | Path,
    output_path: str | Path,
    processor: str | Path | PreTrainedTokenizerBase,
    config: str | Path,
    processor_args: dict[str, Any] | None = None,
    data_args: dict[str, Any] | None = None,
    data_column_mapper: dict[str, str] | None = None,
    short_prompt_strategy: ShortPromptStrategy = ShortPromptStrategy.IGNORE,
    pad_char: str | None = None,
    concat_delimiter: str | None = None,
    include_prefix_in_token_count: bool = False,
    push_to_hub: bool = False,
    hub_dataset_id: str | None = None,
    random_seed: int = 42,
) -> None:
    """
    Main method to process and save a dataset with sampled prompt/output token counts.

    :param data: Path or identifier for the dataset input.
    :param output_path: File path to save the processed dataset.
    :param processor: Tokenizer object or its config.
    :param config: PreprocessDatasetConfig string or file path.
    :param processor_args: Optional processor arguments.
    :param data_args: Optional data loading arguments.
    :param data_column_mapper: Optional column mapping dictionary.
    :param short_prompt_strategy: Strategy for handling short prompts.
    :param pad_char: Character used when padding short prompts.
    :param concat_delimiter: Delimiter for the concatenation strategy.
    :param include_prefix_in_token_count:
        Whether to include the prefix in the prompt token count, simplifying the
        token counts. When True, prefix trimming is disabled and the prefix is kept
        as-is; the prefix token count is subtracted from the prompt token budget
        instead.
    :param push_to_hub: Whether to push to the Hugging Face Hub.
    :param hub_dataset_id: Dataset ID on the Hugging Face Hub.
    :param random_seed: Seed for random sampling.
    :raises ValueError: If the output path is invalid or the pushing conditions
        are unmet.
    """
    builders.process_dataset(
        data, output_path, processor, config, processor_args, data_args,
        data_column_mapper, short_prompt_strategy, pad_char, concat_delimiter,
        include_prefix_in_token_count, push_to_hub, hub_dataset_id, random_seed,
    )
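A hedged invocation sketch of this entrypoint. The input path, output path, tokenizer ID, and config file name below are placeholders; the config file is assumed to be a serialized PreprocessDatasetConfig, per the docstring.

from guidellm.data.builders import ShortPromptStrategy
from guidellm.data.entrypoints import process_dataset

process_dataset(
    data="my_dataset.jsonl",               # placeholder input
    output_path="processed_dataset.jsonl",  # placeholder output
    processor="gpt2",                       # placeholder tokenizer ID
    config="preprocess_config.yaml",        # assumed PreprocessDatasetConfig file
    short_prompt_strategy=ShortPromptStrategy.IGNORE,
    random_seed=42,
)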
guidellm/data/loaders.py
ADDED

@@ -0,0 +1,190 @@

from __future__ import annotations

import contextlib
from collections.abc import Callable, Iterator
from typing import Any, Literal, TypeVar

import torch
from torch.utils.data import Sampler
from torch.utils.data.dataloader import DataLoader as PyTorchDataLoader
from torch.utils.data.dataset import IterableDataset as TorchIterableDataset
from transformers import PreTrainedTokenizerBase

from guidellm.data.deserializers import DatasetDeserializerFactory
from guidellm.data.preprocessors import DataDependentPreprocessor, DatasetPreprocessor
from guidellm.logger import logger
from guidellm.utils import InfoMixin

__all__ = ["DataLoader", "DatasetsIterator"]


DataT = TypeVar("DataT")


class DatasetsIterator(TorchIterableDataset[DataT]):
    def __init__(
        self,
        data: list[Any],
        data_args: list[dict[str, Any]] | None,
        data_samples: int,
        processor_factory: Callable[[], PreTrainedTokenizerBase],
        preprocessors: list[DatasetPreprocessor | DataDependentPreprocessor],
        random_seed: int,
    ):
        if not data or not isinstance(data, list):
            raise ValueError(f"Data must be a non-empty list, got {data}.")

        if not data_args:
            data_args = [{} for _ in data]

        if len(data) != len(data_args):
            raise ValueError(
                f"Length of data ({len(data)}) must match length of data_args "
                f"({len(data_args)})."
            )

        self.datasets = []
        for datum, data_kwargs in zip(data, data_args, strict=False):
            self.datasets.append(
                DatasetDeserializerFactory.deserialize(
                    data=datum,
                    processor_factory=processor_factory,
                    random_seed=random_seed,
                    **data_kwargs,
                )
            )
        self.preprocessors = preprocessors
        for preprocessor in self.preprocessors:
            if isinstance(preprocessor, DataDependentPreprocessor):
                preprocessor.setup_data(
                    datasets=self.datasets,
                    data_args=data_args,
                )
        self.precache: list[Any] | None = (
            list(self.generator(data_samples)) if data_samples else None
        )
        self.epoch = 0

    def __iter__(self) -> Iterator[DataT]:
        worker_info = torch.utils.data.get_worker_info()
        worker_modulus = worker_info.num_workers if worker_info is not None else 1
        worker_index = worker_info.id if worker_info is not None else 0

        if self.precache:
            for index, item in enumerate(self.precache):
                if (index + worker_index) % worker_modulus == 0:
                    yield item
        else:
            yield from self.generator(
                modulus=worker_modulus, offset=worker_index, epoch=self.epoch
            )

    def set_epoch(self, epoch: int):
        self.epoch = epoch

    def generator(
        self,
        max_items: int | None = None,
        modulus: int | None = None,
        offset: int | None = None,
        epoch: int = 0,
    ) -> Iterator[DataT]:
        gen_count = 0

        with contextlib.suppress(StopIteration):
            dataset_iters = []
            for dataset in self.datasets:
                if hasattr(dataset, "set_epoch"):
                    with contextlib.suppress(Exception):
                        dataset.set_epoch(epoch)
                dataset_iters.append(iter(dataset))

            while max_items is None or gen_count < max_items:
                try:
                    row: dict[str, Any] = {
                        "items": [next(dataset_iter) for dataset_iter in dataset_iters]
                    }
                    gen_count += 1

                    if (
                        modulus is not None
                        and offset is not None
                        and (gen_count % modulus) != offset
                    ):
                        continue

                    for preprocessor in self.preprocessors:
                        # This can assign a GenerationRequest, which would then be
                        # passed into the preprocessor, which is a type violation.
                        # This should be fixed at some point.
                        row = preprocessor(row)  # type: ignore[assignment]
                    yield row  # type: ignore[misc]
                except StopIteration:
                    raise  # Stop iteration when any dataset is exhausted
                except Exception as err:  # noqa: BLE001 # Exception logged
                    logger.error(f"Skipping data row due to error: {err}")
                    gen_count -= 1

        if max_items is not None and gen_count < max_items:
            raise ValueError(
                f"Requested {max_items} samples, but only {gen_count} "
                "available from the provided datasets."
            )


class DataLoader(PyTorchDataLoader[DataT], InfoMixin):
    def __init__(
        self,
        data: list[Any],
        data_args: list[dict[str, Any]] | None,
        data_samples: int,
        processor_factory: Callable[[], PreTrainedTokenizerBase],
        preprocessors: list[DatasetPreprocessor | DataDependentPreprocessor],
        collator: Callable,
        sampler: Sampler[int] | Literal["shuffle"] | None = None,
        num_workers: int | None = 1,
        random_seed: int = 42,
        **kwargs: Any,
    ):
        iterator: DatasetsIterator[DataT] = DatasetsIterator(
            data=data,
            data_args=data_args,
            data_samples=data_samples,
            processor_factory=processor_factory,
            preprocessors=preprocessors,
            random_seed=random_seed,
        )
        self._info: dict[str, Any] = {
            "data": str(data),
            "data_args": str(data_args),
            "data_samples": data_samples,
            "preprocessors": [
                preprocessor.__class__.__name__ for preprocessor in preprocessors
            ],
            "collator": collator.__class__.__name__,
            "sampler": str(sampler),
            "num_workers": num_workers,
            "random_seed": random_seed,
        }
        self.epoch = 0

        super().__init__(
            dataset=iterator,
            batch_size=1,
            shuffle=sampler == "shuffle",
            sampler=sampler if sampler != "shuffle" else None,
            collate_fn=collator,
            num_workers=num_workers or 0,
            **kwargs,
        )

    def __iter__(self):
        if isinstance(self.dataset, DatasetsIterator):
            self.dataset.set_epoch(self.epoch)
            self.epoch += 1

        return super().__iter__()

    @property
    def info(self) -> dict[str, Any]:
        return self._info
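Worth noting: DatasetsIterator shards rows round-robin across PyTorch DataLoader workers rather than sharding the underlying datasets. A standalone sketch of the precache-path arithmetic (the generator path is analogous but keeps rows where gen_count % modulus == offset):

rows = list(range(10))
num_workers = 3

for worker_index in range(num_workers):
    # Mirrors the precache branch above: each worker keeps the rows where
    # (index + worker_index) % num_workers == 0.
    kept = [
        row
        for index, row in enumerate(rows)
        if (index + worker_index) % num_workers == 0
    ]
    print(worker_index, kept)
# worker 0 -> [0, 3, 6, 9]; worker 1 -> [2, 5, 8]; worker 2 -> [1, 4, 7]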
guidellm/data/preprocessors/__init__.py
ADDED

@@ -0,0 +1,27 @@

from .formatters import (
    GenerativeAudioTranscriptionRequestFormatter,
    GenerativeAudioTranslationRequestFormatter,
    GenerativeChatCompletionsRequestFormatter,
    GenerativeTextCompletionsRequestFormatter,
    RequestFormatter,
)
from .mappers import GenerativeColumnMapper
from .preprocessor import (
    DataDependentPreprocessor,
    DatasetPreprocessor,
    PreprocessorRegistry,
)

__all__ = [
    "ColumnMapper",
    "ColumnMapperRegistry",
    "DataDependentPreprocessor",
    "DatasetPreprocessor",
    "GenerativeAudioTranscriptionRequestFormatter",
    "GenerativeAudioTranslationRequestFormatter",
    "GenerativeChatCompletionsRequestFormatter",
    "GenerativeColumnMapper",
    "GenerativeTextCompletionsRequestFormatter",
    "PreprocessorRegistry",
    "RequestFormatter",
]
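Two observations on this module. First, ColumnMapper and ColumnMapperRegistry appear in __all__ but are never imported, so a star-import of this package would fail on those two names. Second, a hypothetical sketch of extending the preprocessor registry, assuming PreprocessorRegistry exposes the same register() decorator pattern that DatasetDeserializerFactory uses in synthetic.py above; the row shape matches the {"items": [...]} dict built by DatasetsIterator.generator in loaders.py:

from typing import Any

from guidellm.data.preprocessors import DatasetPreprocessor, PreprocessorRegistry


@PreprocessorRegistry.register("strip_whitespace")  # assumed decorator API
class StripWhitespacePreprocessor(DatasetPreprocessor):
    """Trim stray whitespace from every string column in a row."""

    def __call__(self, row: dict[str, Any]) -> dict[str, Any]:
        row["items"] = [
            {key: val.strip() if isinstance(val, str) else val
             for key, val in item.items()}
            for item in row["items"]
        ]
        return row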