guidellm 0.4.0a21__py3-none-any.whl → 0.4.0a155__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of guidellm might be problematic. Click here for more details.
- guidellm/__init__.py +5 -2
- guidellm/__main__.py +451 -252
- guidellm/backends/__init__.py +33 -0
- guidellm/backends/backend.py +110 -0
- guidellm/backends/openai.py +355 -0
- guidellm/backends/response_handlers.py +455 -0
- guidellm/benchmark/__init__.py +53 -39
- guidellm/benchmark/benchmarker.py +148 -317
- guidellm/benchmark/entrypoints.py +466 -128
- guidellm/benchmark/output.py +517 -771
- guidellm/benchmark/profile.py +580 -280
- guidellm/benchmark/progress.py +568 -549
- guidellm/benchmark/scenarios/__init__.py +40 -0
- guidellm/benchmark/scenarios/chat.json +6 -0
- guidellm/benchmark/scenarios/rag.json +6 -0
- guidellm/benchmark/schemas.py +2085 -0
- guidellm/data/__init__.py +28 -4
- guidellm/data/collators.py +16 -0
- guidellm/data/deserializers/__init__.py +53 -0
- guidellm/data/deserializers/deserializer.py +109 -0
- guidellm/data/deserializers/file.py +222 -0
- guidellm/data/deserializers/huggingface.py +94 -0
- guidellm/data/deserializers/memory.py +192 -0
- guidellm/data/deserializers/synthetic.py +346 -0
- guidellm/data/loaders.py +145 -0
- guidellm/data/preprocessors/__init__.py +25 -0
- guidellm/data/preprocessors/formatters.py +412 -0
- guidellm/data/preprocessors/mappers.py +198 -0
- guidellm/data/preprocessors/preprocessor.py +29 -0
- guidellm/data/processor.py +30 -0
- guidellm/data/schemas.py +13 -0
- guidellm/data/utils/__init__.py +10 -0
- guidellm/data/utils/dataset.py +94 -0
- guidellm/data/utils/functions.py +18 -0
- guidellm/extras/__init__.py +4 -0
- guidellm/extras/audio.py +215 -0
- guidellm/extras/vision.py +242 -0
- guidellm/logger.py +2 -2
- guidellm/mock_server/__init__.py +8 -0
- guidellm/mock_server/config.py +84 -0
- guidellm/mock_server/handlers/__init__.py +17 -0
- guidellm/mock_server/handlers/chat_completions.py +280 -0
- guidellm/mock_server/handlers/completions.py +280 -0
- guidellm/mock_server/handlers/tokenizer.py +142 -0
- guidellm/mock_server/models.py +510 -0
- guidellm/mock_server/server.py +168 -0
- guidellm/mock_server/utils.py +302 -0
- guidellm/preprocess/dataset.py +23 -26
- guidellm/presentation/builder.py +2 -2
- guidellm/presentation/data_models.py +25 -21
- guidellm/presentation/injector.py +2 -3
- guidellm/scheduler/__init__.py +65 -26
- guidellm/scheduler/constraints.py +1035 -0
- guidellm/scheduler/environments.py +252 -0
- guidellm/scheduler/scheduler.py +140 -368
- guidellm/scheduler/schemas.py +272 -0
- guidellm/scheduler/strategies.py +519 -0
- guidellm/scheduler/worker.py +391 -420
- guidellm/scheduler/worker_group.py +707 -0
- guidellm/schemas/__init__.py +31 -0
- guidellm/schemas/info.py +159 -0
- guidellm/schemas/request.py +216 -0
- guidellm/schemas/response.py +119 -0
- guidellm/schemas/stats.py +228 -0
- guidellm/{config.py → settings.py} +32 -21
- guidellm/utils/__init__.py +95 -8
- guidellm/utils/auto_importer.py +98 -0
- guidellm/utils/cli.py +46 -2
- guidellm/utils/console.py +183 -0
- guidellm/utils/encoding.py +778 -0
- guidellm/utils/functions.py +134 -0
- guidellm/utils/hf_datasets.py +1 -2
- guidellm/utils/hf_transformers.py +4 -4
- guidellm/utils/imports.py +9 -0
- guidellm/utils/messaging.py +1118 -0
- guidellm/utils/mixins.py +115 -0
- guidellm/utils/pydantic_utils.py +411 -0
- guidellm/utils/random.py +3 -4
- guidellm/utils/registry.py +220 -0
- guidellm/utils/singleton.py +133 -0
- guidellm/{objects → utils}/statistics.py +341 -247
- guidellm/utils/synchronous.py +159 -0
- guidellm/utils/text.py +163 -50
- guidellm/utils/typing.py +41 -0
- guidellm/version.py +1 -1
- {guidellm-0.4.0a21.dist-info → guidellm-0.4.0a155.dist-info}/METADATA +33 -10
- guidellm-0.4.0a155.dist-info/RECORD +96 -0
- guidellm/backend/__init__.py +0 -23
- guidellm/backend/backend.py +0 -259
- guidellm/backend/openai.py +0 -705
- guidellm/backend/response.py +0 -136
- guidellm/benchmark/aggregator.py +0 -760
- guidellm/benchmark/benchmark.py +0 -837
- guidellm/benchmark/scenario.py +0 -104
- guidellm/data/prideandprejudice.txt.gz +0 -0
- guidellm/dataset/__init__.py +0 -22
- guidellm/dataset/creator.py +0 -213
- guidellm/dataset/entrypoints.py +0 -42
- guidellm/dataset/file.py +0 -92
- guidellm/dataset/hf_datasets.py +0 -62
- guidellm/dataset/in_memory.py +0 -132
- guidellm/dataset/synthetic.py +0 -287
- guidellm/objects/__init__.py +0 -18
- guidellm/objects/pydantic.py +0 -89
- guidellm/request/__init__.py +0 -18
- guidellm/request/loader.py +0 -284
- guidellm/request/request.py +0 -79
- guidellm/request/types.py +0 -10
- guidellm/scheduler/queues.py +0 -25
- guidellm/scheduler/result.py +0 -155
- guidellm/scheduler/strategy.py +0 -495
- guidellm-0.4.0a21.dist-info/RECORD +0 -62
- {guidellm-0.4.0a21.dist-info → guidellm-0.4.0a155.dist-info}/WHEEL +0 -0
- {guidellm-0.4.0a21.dist-info → guidellm-0.4.0a155.dist-info}/entry_points.txt +0 -0
- {guidellm-0.4.0a21.dist-info → guidellm-0.4.0a155.dist-info}/licenses/LICENSE +0 -0
- {guidellm-0.4.0a21.dist-info → guidellm-0.4.0a155.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,192 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import contextlib
|
|
4
|
+
import csv
|
|
5
|
+
import json
|
|
6
|
+
from collections.abc import Callable
|
|
7
|
+
from io import StringIO
|
|
8
|
+
from typing import Any, cast
|
|
9
|
+
|
|
10
|
+
from datasets import Dataset
|
|
11
|
+
from transformers import PreTrainedTokenizerBase
|
|
12
|
+
|
|
13
|
+
from guidellm.data.deserializers.deserializer import (
|
|
14
|
+
DataNotSupportedError,
|
|
15
|
+
DatasetDeserializer,
|
|
16
|
+
DatasetDeserializerFactory,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
__all__ = [
|
|
20
|
+
"InMemoryCsvDatasetDeserializer",
|
|
21
|
+
"InMemoryDictDatasetDeserializer",
|
|
22
|
+
"InMemoryDictListDatasetDeserializer",
|
|
23
|
+
"InMemoryItemListDatasetDeserializer",
|
|
24
|
+
"InMemoryJsonStrDatasetDeserializer",
|
|
25
|
+
]
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@DatasetDeserializerFactory.register("in_memory_dict")
class InMemoryDictDatasetDeserializer(DatasetDeserializer):
    """
    Deserializer for column-oriented in-memory data (``dict[str, list]``).

    Registered with ``DatasetDeserializerFactory`` under ``"in_memory_dict"``.
    """

    def __call__(
        self,
        data: Any,
        processor_factory: Callable[[], PreTrainedTokenizerBase],
        random_seed: int,
        **data_kwargs: dict[str, Any],
    ) -> dict[str, list]:
        """
        Validate ``data`` and hand it to ``Dataset.from_dict``.

        :param data: Candidate payload; must be a non-empty dict whose keys are
            strings and whose values are lists of equal length.
        :param processor_factory: Unused by this deserializer.
        :param random_seed: Unused by this deserializer.
        :param data_kwargs: Forwarded unchanged to ``Dataset.from_dict``.
        :return: The result of ``Dataset.from_dict(data, **data_kwargs)``.
        :raises DataNotSupportedError: If ``data`` has the wrong shape or the
            columns differ in length.
        """
        del processor_factory, random_seed  # part of the shared signature only

        is_column_dict = (
            bool(data)
            and isinstance(data, dict)
            and all(
                isinstance(name, str) and isinstance(column, list)
                for name, column in data.items()
            )
        )
        if not is_column_dict:
            raise DataNotSupportedError(
                f"Unsupported data for InMemoryDictDatasetDeserializer, "
                f"expected dict[str, list], got {data}"
            )

        # Every column must be as long as the first one.
        column_lengths = [len(column) for column in data.values()]
        rows = column_lengths[0]
        if any(length != rows for length in column_lengths):
            raise DataNotSupportedError(
                "All lists in the data dictionary must have the same length, "
                f"expected {rows} for all keys {list(data.keys())}"
            )

        return Dataset.from_dict(data, **data_kwargs)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
@DatasetDeserializerFactory.register("in_memory_dict_list")
class InMemoryDictListDatasetDeserializer(DatasetDeserializer):
    """
    Deserializer for row-oriented in-memory data (a list of dicts).

    Registered with ``DatasetDeserializerFactory`` under
    ``"in_memory_dict_list"``.
    """

    def __call__(
        self,
        data: Any,
        processor_factory: Callable[[], PreTrainedTokenizerBase],
        random_seed: int,
        **data_kwargs: dict[str, Any],
    ) -> dict[str, list]:
        """
        Validate the rows, pivot them into columns and build a dataset.

        :param data: Candidate payload; must be a non-empty list of dicts with
            string keys, all sharing the same key set.
        :param processor_factory: Unused by this deserializer.
        :param random_seed: Unused by this deserializer.
        :param data_kwargs: Forwarded unchanged to ``Dataset.from_dict``.
        :return: The result of ``Dataset.from_dict`` on the pivoted columns.
        :raises DataNotSupportedError: If ``data`` has the wrong shape or any
            row's keys differ from the first row's keys.
        """
        _ = (processor_factory, random_seed)  # Ignore unused args format errors

        if (
            not data
            or not isinstance(data, list)
            or not all(isinstance(item, dict) for item in data)
            or not all(isinstance(key, str) for item in data for key in item)
        ):
            raise DataNotSupportedError(
                f"Unsupported data for InMemoryDictListDatasetDeserializer, "
                f"expected list of dicts, got {data}"
            )

        rows = cast("list[dict[str, Any]]", data)
        # Keep the first row's key order for the output columns. Iterating a
        # set of str keys here would make column order depend on per-process
        # hash randomisation and therefore differ between runs.
        column_names = list(rows[0])
        first_keys = set(column_names)
        for index, item in enumerate(rows):
            if set(item.keys()) != first_keys:
                raise DataNotSupportedError(
                    f"All dictionaries must have the same keys. "
                    f"Expected keys: {first_keys}, "
                    f"got keys at index {index}: {set(item.keys())}"
                )

        # Convert list of dicts to dict of lists
        result_dict = {name: [item[name] for item in rows] for name in column_names}

        return Dataset.from_dict(result_dict, **data_kwargs)
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
@DatasetDeserializerFactory.register("in_memory_item_list")
class InMemoryItemListDatasetDeserializer(DatasetDeserializer):
    """
    Deserializer for a flat in-memory list of primitive values.

    Registered with ``DatasetDeserializerFactory`` under
    ``"in_memory_item_list"``.
    """

    def __call__(
        self,
        data: Any,
        processor_factory: Callable[[], PreTrainedTokenizerBase],
        random_seed: int,
        **data_kwargs: dict[str, Any],
    ) -> dict[str, list]:
        """
        Wrap a list of primitives into a single-column dataset.

        :param data: Candidate payload; must be a non-empty list containing
            only ``str``, ``int``, ``float``, ``bool`` or ``None`` items.
        :param processor_factory: Unused by this deserializer.
        :param random_seed: Unused by this deserializer.
        :param data_kwargs: ``column_name`` (default ``"data"``) is removed and
            used as the column name; the rest goes to ``Dataset.from_dict``.
        :return: The result of ``Dataset.from_dict`` for the single column.
        :raises DataNotSupportedError: If ``data`` is not such a list.
        """
        del processor_factory, random_seed  # part of the shared signature only

        allowed_types = (str, int, float, bool, type(None))
        is_primitive_list = (
            bool(data)
            and isinstance(data, list)
            and all(isinstance(entry, allowed_types) for entry in data)
        )
        if not is_primitive_list:
            raise DataNotSupportedError(
                f"Unsupported data for InMemoryItemListDatasetDeserializer, "
                f"expected list of primitive items, got {data}"
            )

        target_column = data_kwargs.pop("column_name", "data")
        columns = {target_column: data}

        return Dataset.from_dict(columns, **data_kwargs)
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
@DatasetDeserializerFactory.register("in_memory_json_str")
class InMemoryJsonStrDatasetDeserializer(DatasetDeserializer):
    """
    Deserializer for a JSON document supplied as an in-memory string.

    Registered with ``DatasetDeserializerFactory`` under
    ``"in_memory_json_str"``. The parsed value is handed, in order, to the
    dict, dict-list and item-list in-memory deserializers.
    """

    def __call__(
        self,
        data: Any,
        processor_factory: Callable[[], PreTrainedTokenizerBase],
        random_seed: int,
        **data_kwargs: dict[str, Any],
    ) -> dict[str, list]:
        """
        Parse ``data`` as JSON and delegate to the first deserializer that
        accepts the parsed value.

        :param data: Candidate payload; must be a string that, once stripped,
            is wrapped in ``{...}`` or ``[...]``.
        :param processor_factory: Passed through to the delegate deserializer.
        :param random_seed: Passed through to the delegate deserializer.
        :param data_kwargs: Passed through to the delegate deserializer.
        :return: Whatever the accepting delegate deserializer returns.
        :raises DataNotSupportedError: If ``data`` is not a JSON-looking string,
            fails to parse, or no delegate accepts the parsed value. The
            underlying failure, when there is one, is attached as the cause.
        """
        failure: Exception | None = None

        if (
            isinstance(data, str)
            and (json_str := data.strip())
            and (
                (json_str.startswith("{") and json_str.endswith("}"))
                or (json_str.startswith("[") and json_str.endswith("]"))
            )
        ):
            try:
                parsed = json.loads(data)
            except (ValueError, RecursionError) as err:
                # json.JSONDecodeError is a ValueError subclass.
                failure = err
            else:
                for deserializer in (
                    InMemoryDictDatasetDeserializer,
                    InMemoryDictListDatasetDeserializer,
                    InMemoryItemListDatasetDeserializer,
                ):
                    try:
                        # The delegates take (data, processor_factory,
                        # random_seed, **data_kwargs). Passing data_kwargs
                        # positionally raised a TypeError that was silently
                        # swallowed, so every JSON string was rejected.
                        return deserializer()(
                            parsed, processor_factory, random_seed, **data_kwargs
                        )
                    except DataNotSupportedError:
                        continue  # wrong shape for this delegate, try the next
                    except Exception as err:  # noqa: BLE001
                        # Keep the contract that callers only ever see
                        # DataNotSupportedError from this deserializer.
                        failure = err
                        break

        raise DataNotSupportedError(
            f"Unsupported data for InMemoryJsonStrDatasetDeserializer, "
            f"expected JSON string with a list or dict of items, got {data}"
        ) from failure
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
@DatasetDeserializerFactory.register("in_memory_csv_str")
class InMemoryCsvDatasetDeserializer(DatasetDeserializer):
    """
    Deserializer for CSV text supplied as an in-memory string.

    Registered with ``DatasetDeserializerFactory`` under
    ``"in_memory_csv_str"``. Rows are read with ``csv.DictReader`` and then
    delegated to ``InMemoryDictListDatasetDeserializer``.
    """

    def __call__(
        self,
        data: Any,
        processor_factory: Callable[[], PreTrainedTokenizerBase],
        random_seed: int,
        **data_kwargs: dict[str, Any],
    ) -> dict[str, list]:
        """
        Parse ``data`` as CSV (first line is the header) and build a dataset.

        :param data: Candidate payload; must be a string that is not blank.
        :param processor_factory: Passed through to the dict-list deserializer.
        :param random_seed: Passed through to the dict-list deserializer.
        :param data_kwargs: Passed through to the dict-list deserializer.
        :return: Whatever ``InMemoryDictListDatasetDeserializer`` returns.
        :raises DataNotSupportedError: If ``data`` is not a non-blank string or
            parsing/delegation fails. The underlying failure, when there is
            one, is attached as the cause.
        """
        failure: Exception | None = None

        # A non-blank string always splits into at least one line, so the
        # former ``len(csv_str.split("\n")) > 0`` test added nothing.
        if isinstance(data, str) and data.strip():
            try:
                rows = list(csv.DictReader(StringIO(data)))

                return InMemoryDictListDatasetDeserializer()(
                    rows, processor_factory, random_seed, **data_kwargs
                )
            except Exception as err:  # noqa: BLE001
                # Best-effort parse: report every failure as unsupported data,
                # but keep the original error as the cause for debugging.
                failure = err

        raise DataNotSupportedError(
            f"Unsupported data for InMemoryCsvDatasetDeserializer, "
            f"expected CSV string, got {type(data)}"
        ) from failure
|
|
@@ -0,0 +1,346 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import math
|
|
4
|
+
from collections.abc import Callable, Iterator
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from random import Random
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
import yaml
|
|
10
|
+
from datasets import Features, IterableDataset, Value
|
|
11
|
+
from faker import Faker
|
|
12
|
+
from pydantic import ConfigDict, Field, model_validator
|
|
13
|
+
from transformers import PreTrainedTokenizerBase
|
|
14
|
+
|
|
15
|
+
from guidellm.data.deserializers.deserializer import (
|
|
16
|
+
DataNotSupportedError,
|
|
17
|
+
DatasetDeserializer,
|
|
18
|
+
DatasetDeserializerFactory,
|
|
19
|
+
)
|
|
20
|
+
from guidellm.utils import IntegerRangeSampler, StandardBaseModel
|
|
21
|
+
|
|
22
|
+
__all__ = [
|
|
23
|
+
"SyntheticTextDatasetConfig",
|
|
24
|
+
"SyntheticTextDatasetDeserializer",
|
|
25
|
+
"SyntheticTextGenerator",
|
|
26
|
+
"SyntheticTextPrefixBucketConfig",
|
|
27
|
+
]
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class SyntheticTextPrefixBucketConfig(StandardBaseModel):
    # One bucket of the prefix distribution used by SyntheticTextGenerator:
    # how many distinct prefixes to create, how long each one is, and how
    # heavily the bucket is weighted relative to the other buckets.

    # Relative weight; must be strictly positive.
    bucket_weight: int = Field(
        default=100,
        gt=0,
        description="Weight of this bucket in the overall distribution.",
    )
    # Number of distinct prefixes generated for the bucket; at least one.
    prefix_count: int = Field(
        default=1,
        ge=1,
        description="The number of unique prefixes to generate for this bucket.",
    )
    # Token length of every prefix in the bucket; zero is allowed.
    prefix_tokens: int = Field(
        default=0,
        ge=0,
        description="The number of prefix tokens per-prompt for this bucket.",
    )
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class SyntheticTextDatasetConfig(StandardBaseModel):
    # Settings for synthetic text generation: prompt/output token-count
    # distributions plus an optional prefix-bucket distribution.
    #
    # Unknown keys are accepted (extra="allow") so that the shorthand keys
    # ``prefix_count`` / ``prefix_tokens`` can be picked up by the validator
    # below and turned into a single prefix bucket.
    model_config = ConfigDict(
        extra="allow",
    )

    prefix_buckets: list[SyntheticTextPrefixBucketConfig] | None = Field(
        default=None,
        description="Buckets for the prefix tokens distribution.",
    )
    prompt_tokens: int = Field(
        gt=0,
        description="The average number of text tokens generated for prompts.",
    )
    prompt_tokens_stdev: int | None = Field(
        default=None,
        gt=0,
        description="The standard deviation of the tokens generated for prompts.",
    )
    prompt_tokens_min: int | None = Field(
        default=None,
        gt=0,
        description="The minimum number of text tokens generated for prompts.",
    )
    prompt_tokens_max: int | None = Field(
        default=None,
        gt=0,
        description="The maximum number of text tokens generated for prompts.",
    )
    output_tokens: int = Field(
        gt=0,
        description="The average number of text tokens generated for outputs.",
    )
    output_tokens_stdev: int | None = Field(
        default=None,
        gt=0,
        description="The standard deviation of the tokens generated for outputs.",
    )
    output_tokens_min: int | None = Field(
        default=None,
        gt=0,
        description="The minimum number of text tokens generated for outputs.",
    )
    output_tokens_max: int | None = Field(
        default=None,
        gt=0,
        description="The maximum number of text tokens generated for outputs.",
    )
    # NOTE(review): SyntheticTextGenerator in this file never reads ``source``
    # (text comes from Faker) — confirm whether the field is still needed.
    source: str = Field(
        default="data:prideandprejudice.txt.gz",
        description="The source of the text data to be used for generation.",
    )

    @model_validator(mode="after")
    def check_prefix_options(self) -> SyntheticTextDatasetConfig:
        # Translate the shorthand extras ``prefix_count`` / ``prefix_tokens``
        # into a one-element ``prefix_buckets`` list. Supplying the shorthand
        # together with explicit buckets is rejected with a ValueError.
        extras = self.__pydantic_extra__  # type: ignore[attr-defined]
        prefix_count = extras.get("prefix_count", None)
        prefix_tokens = extras.get("prefix_tokens", None)

        # Neither shorthand key given: leave the configuration untouched.
        if prefix_count is None and prefix_tokens is None:
            return self

        if self.prefix_buckets:
            raise ValueError(
                "prefix_buckets is mutually exclusive"
                " with prefix_count and prefix_tokens"
            )

        # A missing (or falsy) half of the shorthand falls back to the
        # bucket defaults of one prefix and zero tokens.
        shorthand_bucket = SyntheticTextPrefixBucketConfig(
            prefix_count=prefix_count or 1,
            prefix_tokens=prefix_tokens or 0,
        )
        self.prefix_buckets = [shorthand_bucket]

        return self
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
class SyntheticTextGenerator:
|
|
122
|
+
def __init__(
|
|
123
|
+
self,
|
|
124
|
+
config: SyntheticTextDatasetConfig,
|
|
125
|
+
processor: PreTrainedTokenizerBase,
|
|
126
|
+
random_seed: int = 42,
|
|
127
|
+
):
|
|
128
|
+
self.config = config
|
|
129
|
+
self.processor = processor
|
|
130
|
+
self.random_seed = random_seed
|
|
131
|
+
|
|
132
|
+
def __iter__(self) -> Iterator[dict[str, Any]]:
|
|
133
|
+
samples_generated = 0
|
|
134
|
+
|
|
135
|
+
faker = Faker()
|
|
136
|
+
faker.seed_instance(self.random_seed)
|
|
137
|
+
prompt_tokens_sampler = iter(
|
|
138
|
+
IntegerRangeSampler(
|
|
139
|
+
average=self.config.prompt_tokens,
|
|
140
|
+
variance=self.config.prompt_tokens_stdev,
|
|
141
|
+
min_value=self.config.prompt_tokens_min,
|
|
142
|
+
max_value=self.config.prompt_tokens_max,
|
|
143
|
+
random_seed=self.random_seed,
|
|
144
|
+
)
|
|
145
|
+
)
|
|
146
|
+
output_tokens_sampler = iter(
|
|
147
|
+
IntegerRangeSampler(
|
|
148
|
+
average=self.config.output_tokens,
|
|
149
|
+
variance=self.config.output_tokens_stdev,
|
|
150
|
+
min_value=self.config.output_tokens_min,
|
|
151
|
+
max_value=self.config.output_tokens_max,
|
|
152
|
+
random_seed=self.random_seed + 1, # ensure diff dist from prompts
|
|
153
|
+
)
|
|
154
|
+
)
|
|
155
|
+
|
|
156
|
+
# Create a shared prefix if specified
|
|
157
|
+
rand = Random(self.random_seed + 3)
|
|
158
|
+
prefix_iter = self._create_prefix_iter(faker, rand)
|
|
159
|
+
|
|
160
|
+
while True:
|
|
161
|
+
prompt_tokens_count = next(prompt_tokens_sampler)
|
|
162
|
+
output_tokens_count = next(output_tokens_sampler)
|
|
163
|
+
|
|
164
|
+
yield {
|
|
165
|
+
"prefix": next(prefix_iter),
|
|
166
|
+
"prompt": self._create_prompt(
|
|
167
|
+
prompt_tokens_count, faker, f"{samples_generated} "
|
|
168
|
+
),
|
|
169
|
+
"prompt_tokens_count": prompt_tokens_count,
|
|
170
|
+
"output_tokens_count": output_tokens_count,
|
|
171
|
+
}
|
|
172
|
+
samples_generated += 1
|
|
173
|
+
|
|
174
|
+
def _create_prompt(
|
|
175
|
+
self, prompt_tokens_count: int, faker: Faker, unique: str = ""
|
|
176
|
+
) -> str:
|
|
177
|
+
prompt_token_ids = []
|
|
178
|
+
avg_chars_per_token = 5
|
|
179
|
+
margin_of_safety = 1.5
|
|
180
|
+
attempts = 0
|
|
181
|
+
|
|
182
|
+
while len(prompt_token_ids) < prompt_tokens_count:
|
|
183
|
+
attempts += 1
|
|
184
|
+
num_chars = (
|
|
185
|
+
prompt_tokens_count * avg_chars_per_token * margin_of_safety * attempts
|
|
186
|
+
)
|
|
187
|
+
text = unique + faker.text(max_nb_chars=num_chars)
|
|
188
|
+
prompt_token_ids = self.processor.encode(text)
|
|
189
|
+
|
|
190
|
+
return self.processor.decode(
|
|
191
|
+
prompt_token_ids[:prompt_tokens_count], skip_special_tokens=True
|
|
192
|
+
)
|
|
193
|
+
|
|
194
|
+
def _create_prefix_iter(self, faker: Faker, rand: Random) -> Iterator[str]:
|
|
195
|
+
if not self.config.prefix_buckets:
|
|
196
|
+
while True:
|
|
197
|
+
yield ""
|
|
198
|
+
|
|
199
|
+
# Increase weights to ensure an integer number of samples per per-prefix
|
|
200
|
+
least_common_prefix_count = math.lcm(
|
|
201
|
+
*(bucket.prefix_count for bucket in self.config.prefix_buckets)
|
|
202
|
+
)
|
|
203
|
+
unnorm_weights = [
|
|
204
|
+
least_common_prefix_count * bucket.bucket_weight // bucket.prefix_count
|
|
205
|
+
for bucket in self.config.prefix_buckets
|
|
206
|
+
]
|
|
207
|
+
# Use GCD to reduce the weights to smallest integer ratio
|
|
208
|
+
common_divisor = math.gcd(*unnorm_weights)
|
|
209
|
+
|
|
210
|
+
# Create prefix list maintaining the correct distribution
|
|
211
|
+
prefixes = []
|
|
212
|
+
for bucket, weight in zip(
|
|
213
|
+
self.config.prefix_buckets, unnorm_weights, strict=False
|
|
214
|
+
):
|
|
215
|
+
bucket_prefixes = [
|
|
216
|
+
self._create_prompt(bucket.prefix_tokens, faker)
|
|
217
|
+
for _ in range(bucket.prefix_count)
|
|
218
|
+
]
|
|
219
|
+
sample_count = weight // common_divisor
|
|
220
|
+
prefixes.extend(bucket_prefixes * sample_count)
|
|
221
|
+
|
|
222
|
+
while True:
|
|
223
|
+
yield rand.choice(prefixes)
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
@DatasetDeserializerFactory.register("synthetic_text")
class SyntheticTextDatasetDeserializer(DatasetDeserializer):
    """
    Deserializer that turns a synthetic-text configuration into a dataset.

    Registered with ``DatasetDeserializerFactory`` under ``"synthetic_text"``.
    The configuration may be a ``SyntheticTextDatasetConfig`` instance, a path
    to a ``.json`` / ``.yaml`` / ``.yml`` / ``.config`` file, a JSON string,
    or a comma-separated ``key=value`` string.
    """

    def __call__(
        self,
        data: Any,
        processor_factory: Callable[[], PreTrainedTokenizerBase],
        random_seed: int,
        **data_kwargs: dict[str, Any],
    ) -> IterableDataset:
        """
        Resolve ``data`` to a config and wrap ``SyntheticTextGenerator`` in an
        ``IterableDataset``.

        :param data: Config object, config file path, or config string.
        :param processor_factory: Called once to obtain the processor handed
            to the generator.
        :param random_seed: Seed handed to the generator.
        :param data_kwargs: Only forwarded on the recursive call made after a
            file or string has been parsed into a config.
        :return: Dataset with the string columns ``prefix`` / ``prompt`` and
            the int32 columns ``prompt_tokens_count`` / ``output_tokens_count``.
        :raises DataNotSupportedError: If ``data`` cannot be resolved to a
            ``SyntheticTextDatasetConfig``.
        """
        # Config file pathways, deserialize and call self again
        if (config := self._load_config_file(data)) is not None:
            return self(config, processor_factory, random_seed, **data_kwargs)

        # Config str pathways, deserialize and call self again
        if (config := self._load_config_str(data)) is not None:
            return self(config, processor_factory, random_seed, **data_kwargs)

        if not isinstance(data, SyntheticTextDatasetConfig):
            raise DataNotSupportedError(
                "Unsupported data for SyntheticTextDatasetDeserializer, "
                "expected SyntheticTextDatasetConfig, str or Path to a config file, "
                f"got {data}"
            )

        return IterableDataset.from_generator(
            SyntheticTextGenerator,
            gen_kwargs={
                "config": data,
                "processor": processor_factory(),
                "random_seed": random_seed,
            },
            features=Features(
                {
                    "prefix": Value("string"),
                    "prompt": Value("string"),
                    "prompt_tokens_count": Value("int32"),
                    "output_tokens_count": Value("int32"),
                }
            ),
        )

    def _load_config_file(self, data: Any) -> SyntheticTextDatasetConfig | None:
        """
        Load a config from ``data`` when it names an existing file.

        :param data: Anything; only ``str`` / ``Path`` values pointing at an
            existing regular file are handled.
        :return: The parsed config, or ``None`` when ``data`` is not a path to
            an existing file.
        :raises DataNotSupportedError: If the file exists but has an
            unsupported suffix or its content fails to parse/validate.
        """
        if not isinstance(data, (str, Path)):
            return None

        data_path = Path(data)
        try:
            if not data_path.is_file():
                return None
        except (OSError, ValueError):
            # ``data`` may be an inline JSON / key=value config rather than a
            # path. Probing such a string can fail outright (e.g. "File name
            # too long") instead of returning False; treat that as "not a
            # file" so the string pathway still gets its turn.
            return None

        suffix = data_path.suffix.lower()
        error = None

        if suffix == ".json":
            try:
                return SyntheticTextDatasetConfig.model_validate_json(
                    data_path.read_text()
                )
            except Exception as err:  # noqa: BLE001
                error = err
        elif suffix in {".yaml", ".yml", ".config"}:
            try:
                return SyntheticTextDatasetConfig.model_validate(
                    yaml.safe_load(data_path.read_text())
                )
            except Exception as err:  # noqa: BLE001
                error = err

        err_message = (
            f"Unsupported file {data_path} for "
            f"SyntheticTextDatasetDeserializer, expected .json, "
            f".yaml, .yml, or .config"
        )

        if error is not None:
            err_message += f" with error: {error}"
            raise DataNotSupportedError(err_message) from error
        raise DataNotSupportedError(err_message)

    def _load_config_str(self, data: str) -> SyntheticTextDatasetConfig | None:
        """
        Parse a config from an inline string.

        Two syntaxes are tried: a JSON document wrapped in ``{...}`` or
        ``[...]``, then comma-separated ``key=value`` pairs (only attempted
        when the string contains more than one ``=``). Values consisting
        solely of numeric characters are converted with ``int``.

        :param data: Anything; non-strings are ignored.
        :return: The parsed config, or ``None`` when ``data`` is not a string.
        :raises DataNotSupportedError: If ``data`` is a string that matches
            neither syntax or fails to parse/validate.
        """
        if not isinstance(data, str):
            return None

        data_str = data.strip()
        error = None

        if (data_str.startswith("{") and data_str.endswith("}")) or (
            data_str.startswith("[") and data_str.endswith("]")
        ):
            try:
                return SyntheticTextDatasetConfig.model_validate_json(data_str)
            except Exception as err:  # noqa: BLE001
                error = err

        if data_str.count("=") > 1:
            # key=value pairs separated by commas
            try:
                config_dict = {}
                for item in data_str.split(","):
                    # Unpacking fails (and is reported below) unless the item
                    # contains exactly one "=".
                    key, value = item.split("=")
                    value = value.strip()
                    config_dict[key.strip()] = (
                        int(value) if value.isnumeric() else value
                    )

                return SyntheticTextDatasetConfig.model_validate(config_dict)
            except Exception as err:  # noqa: BLE001
                error = err

        err_message = (
            "Unsupported string data for SyntheticTextDatasetDeserializer, "
            f"expected JSON or key-value pairs, got {data}"
        )
        if error is not None:
            err_message += f" with error: {error}"
            raise DataNotSupportedError(err_message) from error
        raise DataNotSupportedError(err_message)
|
guidellm/data/loaders.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import contextlib
|
|
4
|
+
from collections.abc import Callable, Iterator
|
|
5
|
+
from typing import Any, Literal
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
from torch.utils.data import Sampler
|
|
9
|
+
from torch.utils.data.dataloader import DataLoader as PyTorchDataLoader
|
|
10
|
+
from torch.utils.data.dataset import IterableDataset as TorchIterableDataset
|
|
11
|
+
from transformers import PreTrainedTokenizerBase
|
|
12
|
+
|
|
13
|
+
from guidellm.data.deserializers import DatasetDeserializerFactory
|
|
14
|
+
from guidellm.data.preprocessors import DataDependentPreprocessor, DatasetPreprocessor
|
|
15
|
+
from guidellm.logger import logger
|
|
16
|
+
|
|
17
|
+
__all__ = ["DataLoader", "DatasetsIterator"]
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class DatasetsIterator(TorchIterableDataset):
    """
    Iterable torch dataset that reads several deserialized datasets in
    lockstep.

    Each produced row starts as ``{"items": [row_from_dataset_0, ...]}`` and
    is then passed through every preprocessor in order. When ``data_samples``
    is truthy, that many rows are generated up front and replayed on
    iteration; otherwise rows are generated lazily.
    """

    def __init__(
        self,
        data: list[Any],
        data_args: list[dict[str, Any]] | None,
        data_samples: int,
        processor_factory: Callable[[], PreTrainedTokenizerBase],
        preprocessors: list[DatasetPreprocessor | DataDependentPreprocessor],
        random_seed: int,
    ):
        """
        :param data: Non-empty list of dataset sources, one per dataset.
        :param data_args: Per-source keyword arguments for deserialization;
            when falsy, an empty dict is used for every source.
        :param data_samples: Number of rows to precache; falsy disables the
            precache.
        :param processor_factory: Forwarded to the deserializer factory.
        :param preprocessors: Callables applied to each row in order.
            ``DataDependentPreprocessor`` instances first receive the datasets
            through ``setup_data``.
        :param random_seed: Forwarded to the deserializer factory.
        :raises ValueError: If ``data`` is empty or not a list, if the lengths
            of ``data`` and ``data_args`` differ, or if fewer than
            ``data_samples`` rows are available for the precache.
        """
        if not data or not isinstance(data, list):
            raise ValueError(f"Data must be a non-empty list, got {data}.")

        if not data_args:
            data_args = [{} for _ in data]

        if len(data) != len(data_args):
            raise ValueError(
                f"Length of data ({len(data)}) must match length of data_args "
                f"({len(data_args)})."
            )

        self.datasets = []
        for datum, data_kwargs in zip(data, data_args, strict=False):
            self.datasets.append(
                DatasetDeserializerFactory.deserialize(
                    data=datum,
                    processor_factory=processor_factory,
                    random_seed=random_seed,
                    **data_kwargs,
                )
            )
        self.preprocessors = preprocessors
        for preprocessor in self.preprocessors:
            if isinstance(preprocessor, DataDependentPreprocessor):
                preprocessor.setup_data(
                    datasets=self.datasets,
                    data_args=data_args,
                )
        self.precache: list[Any] | None = (
            list(self.generator(data_samples)) if data_samples else None
        )

    def __iter__(self):
        """
        Yield this worker's share of the rows.

        Inside a torch DataLoader worker the rows are split round-robin across
        workers using the worker id; outside of one, every row is yielded.
        """
        worker_info = torch.utils.data.get_worker_info()
        worker_modulus = worker_info.num_workers if worker_info is not None else 1
        worker_index = worker_info.id if worker_info is not None else 0

        if self.precache:
            for index, item in enumerate(self.precache):
                if (index + worker_index) % worker_modulus == 0:
                    yield item
        else:
            yield from self.generator(modulus=worker_modulus, offset=worker_index)

    def generator(
        self,
        max_items: int | None = None,
        modulus: int | None = None,
        offset: int | None = None,
    ) -> Iterator[Any]:
        """
        Generate preprocessed rows until a dataset runs out or ``max_items``
        rows have been counted.

        :param max_items: Stop after this many counted rows; ``None`` means
            run until the first dataset is exhausted.
        :param modulus: With ``offset``, keep only rows whose 1-based count
            satisfies ``count % modulus == offset``.
        :param offset: See ``modulus``.
        :return: Iterator of preprocessed rows.
        :raises ValueError: If ``max_items`` is set and the datasets ran out
            before that many rows were counted.
        """
        gen_count = 0
        dataset_iters = [iter(dataset) for dataset in self.datasets]

        while max_items is None or gen_count < max_items:
            try:
                row = {
                    "items": [next(dataset_iter) for dataset_iter in dataset_iters]
                }
            except StopIteration:
                # A dataset is exhausted. StopIteration is an Exception
                # subclass, so it must be handled before the generic handler;
                # otherwise the loop logs and retries forever.
                break
            except Exception as err:  # noqa: BLE001
                # The row was never counted, so there is nothing to undo.
                logger.error(f"Skipping data row due to error: {err}")
                continue

            gen_count += 1

            if (
                modulus is not None
                and offset is not None
                and (gen_count % modulus) != offset
            ):
                continue

            try:
                for preprocessor in self.preprocessors:
                    row = preprocessor(row)
            except Exception as err:  # noqa: BLE001
                logger.error(f"Skipping data row due to error: {err}")
                gen_count -= 1  # the failed row must not count towards max_items
                continue

            yield row

        if max_items is not None and gen_count < max_items:
            raise ValueError(
                f"Requested {max_items} samples, but only {gen_count} "
                "available from the provided datasets."
            )
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
class DataLoader(PyTorchDataLoader):
    """
    Torch ``DataLoader`` preconfigured around a ``DatasetsIterator``.

    The iterator is built from the given sources and the loader is created
    with a fixed batch size of 1.
    """

    def __init__(
        self,
        data: list[Any],
        data_args: list[dict[str, Any]] | None,
        data_samples: int,
        processor_factory: Callable[[], PreTrainedTokenizerBase],
        preprocessors: list[DatasetPreprocessor | DataDependentPreprocessor],
        collator: Callable,
        sampler: Sampler[int] | Literal["shuffle"] | None = None,
        num_workers: int | None = 1,
        random_seed: int = 42,
        **kwargs: Any,
    ):
        """
        :param data: Dataset sources, passed to ``DatasetsIterator``.
        :param data_args: Per-source arguments, passed to ``DatasetsIterator``.
        :param data_samples: Precache size, passed to ``DatasetsIterator``.
        :param processor_factory: Passed to ``DatasetsIterator``.
        :param preprocessors: Passed to ``DatasetsIterator``.
        :param collator: Used as the loader's ``collate_fn``.
        :param sampler: ``"shuffle"`` turns on the loader's ``shuffle`` flag;
            any other value is passed through as the loader's ``sampler``.
        :param num_workers: Worker count; ``None`` (or 0) becomes 0.
        :param random_seed: Passed to ``DatasetsIterator``.
        :param kwargs: Extra keyword arguments for the torch ``DataLoader``.
        """
        # NOTE(review): torch documents that iterable-style datasets do not
        # accept ``shuffle=True`` or an explicit sampler — confirm that the
        # "shuffle"/sampler options are usable with DatasetsIterator.
        shuffle_requested = sampler == "shuffle"
        explicit_sampler = sampler if sampler != "shuffle" else None
        worker_count = num_workers or 0

        rows = DatasetsIterator(
            data=data,
            data_args=data_args,
            data_samples=data_samples,
            processor_factory=processor_factory,
            preprocessors=preprocessors,
            random_seed=random_seed,
        )

        super().__init__(
            dataset=rows,
            batch_size=1,
            shuffle=shuffle_requested,
            sampler=explicit_sampler,
            collate_fn=collator,
            num_workers=worker_count,
            **kwargs,
        )
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
from .formatters import (
|
|
2
|
+
GenerativeAudioTranscriptionRequestFormatter,
|
|
3
|
+
GenerativeAudioTranslationRequestFormatter,
|
|
4
|
+
GenerativeChatCompletionsRequestFormatter,
|
|
5
|
+
GenerativeTextCompletionsRequestFormatter,
|
|
6
|
+
)
|
|
7
|
+
from .mappers import GenerativeColumnMapper
|
|
8
|
+
from .preprocessor import (
|
|
9
|
+
DataDependentPreprocessor,
|
|
10
|
+
DatasetPreprocessor,
|
|
11
|
+
PreprocessorRegistry,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
# Public names of this package. Every entry must be bound by the imports
# above: "ColumnMapper" and "ColumnMapperRegistry" were listed without being
# imported, which makes ``from <package> import *`` raise AttributeError.
__all__ = [
    "DataDependentPreprocessor",
    "DatasetPreprocessor",
    "GenerativeAudioTranscriptionRequestFormatter",
    "GenerativeAudioTranslationRequestFormatter",
    "GenerativeChatCompletionsRequestFormatter",
    "GenerativeColumnMapper",
    "GenerativeTextCompletionsRequestFormatter",
    "PreprocessorRegistry",
]
|