guidellm 0.4.0a155__py3-none-any.whl → 0.4.0a173__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of guidellm might be problematic. Click here for more details.
- guidellm/__main__.py +4 -3
- guidellm/benchmark/benchmarker.py +2 -0
- guidellm/benchmark/entrypoints.py +1 -0
- guidellm/benchmark/output.py +3 -1
- guidellm/benchmark/schemas.py +2 -1
- guidellm/data/deserializers/deserializer.py +79 -44
- guidellm/data/deserializers/file.py +14 -14
- guidellm/data/deserializers/huggingface.py +1 -1
- guidellm/data/deserializers/memory.py +20 -18
- guidellm/data/deserializers/synthetic.py +18 -16
- guidellm/data/loaders.py +7 -3
- guidellm/data/preprocessors/formatters.py +24 -32
- guidellm/data/preprocessors/mappers.py +2 -2
- guidellm/data/preprocessors/preprocessor.py +5 -3
- guidellm/data/processor.py +3 -2
- guidellm/data/utils/__init__.py +0 -4
- guidellm/data/utils/dataset.py +2 -2
- guidellm/scheduler/constraints.py +1 -3
- guidellm/scheduler/environments.py +2 -2
- guidellm/scheduler/scheduler.py +1 -1
- guidellm/scheduler/strategies.py +31 -4
- guidellm/scheduler/worker.py +56 -30
- guidellm/scheduler/worker_group.py +33 -31
- guidellm/schemas/request.py +10 -0
- guidellm/utils/cli.py +26 -1
- {guidellm-0.4.0a155.dist-info → guidellm-0.4.0a173.dist-info}/METADATA +1 -1
- {guidellm-0.4.0a155.dist-info → guidellm-0.4.0a173.dist-info}/RECORD +31 -32
- guidellm/data/utils/functions.py +0 -18
- {guidellm-0.4.0a155.dist-info → guidellm-0.4.0a173.dist-info}/WHEEL +0 -0
- {guidellm-0.4.0a155.dist-info → guidellm-0.4.0a173.dist-info}/entry_points.txt +0 -0
- {guidellm-0.4.0a155.dist-info → guidellm-0.4.0a173.dist-info}/licenses/LICENSE +0 -0
- {guidellm-0.4.0a155.dist-info → guidellm-0.4.0a173.dist-info}/top_level.txt +0 -0
guidellm/__main__.py
CHANGED
|
@@ -156,8 +156,9 @@ def benchmark():
|
|
|
156
156
|
)
|
|
157
157
|
@click.option(
|
|
158
158
|
"--rate",
|
|
159
|
-
type=
|
|
160
|
-
|
|
159
|
+
type=str,
|
|
160
|
+
callback=cli_tools.parse_list_floats,
|
|
161
|
+
multiple=False,
|
|
161
162
|
default=BenchmarkGenerativeTextArgs.get_default("rate"),
|
|
162
163
|
help=(
|
|
163
164
|
"Benchmark rate(s) to test. Meaning depends on profile: "
|
|
@@ -383,7 +384,7 @@ def run(**kwargs):
|
|
|
383
384
|
kwargs.get("data_args"), default=[], simplify_single=False
|
|
384
385
|
)
|
|
385
386
|
kwargs["rate"] = cli_tools.format_list_arg(
|
|
386
|
-
kwargs.get("rate"), default=None, simplify_single=
|
|
387
|
+
kwargs.get("rate"), default=None, simplify_single=False
|
|
387
388
|
)
|
|
388
389
|
|
|
389
390
|
disable_console_outputs = kwargs.pop("disable_console_outputs", False)
|
|
guidellm/benchmark/benchmarker.py
CHANGED
|
@@ -57,6 +57,7 @@ class Benchmarker(
|
|
|
57
57
|
backend: BackendInterface[RequestT, ResponseT],
|
|
58
58
|
profile: Profile,
|
|
59
59
|
environment: Environment,
|
|
60
|
+
data: list[Any],
|
|
60
61
|
progress: BenchmarkerProgress[BenchmarkT] | None = None,
|
|
61
62
|
sample_requests: int | None = 20,
|
|
62
63
|
warmup: float | None = None,
|
|
@@ -149,6 +150,7 @@ class Benchmarker(
|
|
|
149
150
|
environment=environment,
|
|
150
151
|
strategy=strategy,
|
|
151
152
|
constraints=constraints,
|
|
153
|
+
data=data,
|
|
152
154
|
)
|
|
153
155
|
if progress:
|
|
154
156
|
await progress.on_benchmark_complete(benchmark)
|
guidellm/benchmark/output.py
CHANGED
|
@@ -649,6 +649,8 @@ class GenerativeBenchmarkerCSV(GenerativeBenchmarkerOutput):
|
|
|
649
649
|
status_dist_summary: StatusDistributionSummary = getattr(
|
|
650
650
|
benchmark.metrics, metric
|
|
651
651
|
)
|
|
652
|
+
if not hasattr(status_dist_summary, status):
|
|
653
|
+
return [], []
|
|
652
654
|
dist_summary: DistributionSummary = getattr(status_dist_summary, status)
|
|
653
655
|
|
|
654
656
|
headers = [
|
|
@@ -688,7 +690,7 @@ class GenerativeBenchmarkerCSV(GenerativeBenchmarkerOutput):
|
|
|
688
690
|
values: list[str] = [
|
|
689
691
|
benchmark.benchmarker.profile.model_dump_json(),
|
|
690
692
|
json.dumps(benchmark.benchmarker.backend),
|
|
691
|
-
json.dumps(benchmark.benchmarker.requests["
|
|
693
|
+
json.dumps(benchmark.benchmarker.requests["data"]),
|
|
692
694
|
]
|
|
693
695
|
|
|
694
696
|
if len(headers) != len(values):
|
guidellm/benchmark/schemas.py
CHANGED
|
@@ -1674,6 +1674,7 @@ class GenerativeBenchmark(Benchmark, StandardBaseDict):
|
|
|
1674
1674
|
environment: Environment,
|
|
1675
1675
|
strategy: SchedulingStrategy,
|
|
1676
1676
|
constraints: dict[str, dict[str, Any]],
|
|
1677
|
+
data: list[Any],
|
|
1677
1678
|
) -> GenerativeBenchmark:
|
|
1678
1679
|
"""
|
|
1679
1680
|
Compile final generative benchmark from accumulated state.
|
|
@@ -1702,7 +1703,7 @@ class GenerativeBenchmark(Benchmark, StandardBaseDict):
|
|
|
1702
1703
|
),
|
|
1703
1704
|
benchmarker=BenchmarkerDict(
|
|
1704
1705
|
profile=profile,
|
|
1705
|
-
requests=
|
|
1706
|
+
requests={"data": data},
|
|
1706
1707
|
backend=backend.info,
|
|
1707
1708
|
environment=environment.info,
|
|
1708
1709
|
),
|
|
guidellm/data/deserializers/deserializer.py
CHANGED
|
@@ -1,10 +1,9 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
import contextlib
|
|
4
3
|
from collections.abc import Callable
|
|
5
4
|
from typing import Any, Protocol, Union, runtime_checkable
|
|
6
5
|
|
|
7
|
-
from datasets import Dataset, IterableDataset
|
|
6
|
+
from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
|
|
8
7
|
from transformers import PreTrainedTokenizerBase
|
|
9
8
|
|
|
10
9
|
from guidellm.data.utils import resolve_dataset_split
|
|
@@ -29,7 +28,7 @@ class DatasetDeserializer(Protocol):
|
|
|
29
28
|
processor_factory: Callable[[], PreTrainedTokenizerBase],
|
|
30
29
|
random_seed: int,
|
|
31
30
|
**data_kwargs: dict[str, Any],
|
|
32
|
-
) ->
|
|
31
|
+
) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict: ...
|
|
33
32
|
|
|
34
33
|
|
|
35
34
|
class DatasetDeserializerFactory(
|
|
@@ -47,51 +46,16 @@ class DatasetDeserializerFactory(
|
|
|
47
46
|
remove_columns: list[str] | None = None,
|
|
48
47
|
**data_kwargs: dict[str, Any],
|
|
49
48
|
) -> Dataset | IterableDataset:
|
|
50
|
-
dataset
|
|
49
|
+
dataset: Dataset
|
|
51
50
|
|
|
52
51
|
if type_ is None:
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
# must be mutually exclusive to ensure deterministic behavior.
|
|
56
|
-
for name, deserializer in cls.registry.items():
|
|
57
|
-
deserializer_fn: DatasetDeserializer = (
|
|
58
|
-
deserializer() if isinstance(deserializer, type) else deserializer
|
|
59
|
-
)
|
|
60
|
-
|
|
61
|
-
try:
|
|
62
|
-
with contextlib.suppress(DataNotSupportedError):
|
|
63
|
-
dataset = deserializer_fn(
|
|
64
|
-
data=data,
|
|
65
|
-
processor_factory=processor_factory,
|
|
66
|
-
random_seed=random_seed,
|
|
67
|
-
**data_kwargs,
|
|
68
|
-
)
|
|
69
|
-
except Exception as e:
|
|
70
|
-
errors.append(e)
|
|
71
|
-
|
|
72
|
-
if dataset is not None:
|
|
73
|
-
break # Found one that works. Continuing could overwrite it.
|
|
74
|
-
|
|
75
|
-
if dataset is None and len(errors) > 0:
|
|
76
|
-
raise DataNotSupportedError(f"data deserialization failed; {len(errors)} errors occurred while "
|
|
77
|
-
f"attempting to deserialize data {data}: {errors}")
|
|
78
|
-
|
|
79
|
-
elif deserializer := cls.get_registered_object(type_) is not None:
|
|
80
|
-
deserializer_fn: DatasetDeserializer = (
|
|
81
|
-
deserializer() if isinstance(deserializer, type) else deserializer
|
|
52
|
+
dataset = cls._deserialize_with_registered_deserializers(
|
|
53
|
+
data, processor_factory, random_seed, **data_kwargs
|
|
82
54
|
)
|
|
83
55
|
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
processor_factory
|
|
87
|
-
random_seed=random_seed,
|
|
88
|
-
**data_kwargs,
|
|
89
|
-
)
|
|
90
|
-
|
|
91
|
-
if dataset is None:
|
|
92
|
-
raise DataNotSupportedError(
|
|
93
|
-
f"No suitable deserializer found for data {data} "
|
|
94
|
-
f"with kwargs {data_kwargs} and deserializer type {type_}."
|
|
56
|
+
else:
|
|
57
|
+
dataset = cls._deserialize_with_specified_deserializer(
|
|
58
|
+
data, type_, processor_factory, random_seed, **data_kwargs
|
|
95
59
|
)
|
|
96
60
|
|
|
97
61
|
if resolve_split:
|
|
@@ -107,3 +71,74 @@ class DatasetDeserializerFactory(
|
|
|
107
71
|
dataset = dataset.remove_columns(remove_columns)
|
|
108
72
|
|
|
109
73
|
return dataset
|
|
74
|
+
|
|
75
|
+
@classmethod
|
|
76
|
+
def _deserialize_with_registered_deserializers(
|
|
77
|
+
cls,
|
|
78
|
+
data: Any,
|
|
79
|
+
processor_factory: Callable[[], PreTrainedTokenizerBase],
|
|
80
|
+
random_seed: int = 42,
|
|
81
|
+
**data_kwargs: dict[str, Any],
|
|
82
|
+
) -> Dataset:
|
|
83
|
+
if cls.registry is None:
|
|
84
|
+
raise RuntimeError("registry is None; cannot deserialize dataset")
|
|
85
|
+
dataset: Dataset | None = None
|
|
86
|
+
|
|
87
|
+
errors: dict[str, Exception] = {}
|
|
88
|
+
# Note: There is no priority order for the deserializers, so all deserializers
|
|
89
|
+
# must be mutually exclusive to ensure deterministic behavior.
|
|
90
|
+
for _name, deserializer in cls.registry.items():
|
|
91
|
+
deserializer_fn: DatasetDeserializer = (
|
|
92
|
+
deserializer() if isinstance(deserializer, type) else deserializer
|
|
93
|
+
)
|
|
94
|
+
|
|
95
|
+
try:
|
|
96
|
+
dataset = deserializer_fn(
|
|
97
|
+
data=data,
|
|
98
|
+
processor_factory=processor_factory,
|
|
99
|
+
random_seed=random_seed,
|
|
100
|
+
**data_kwargs,
|
|
101
|
+
)
|
|
102
|
+
except Exception as e: # noqa: BLE001 # The exceptions are saved.
|
|
103
|
+
errors[_name] = e
|
|
104
|
+
|
|
105
|
+
if dataset is not None:
|
|
106
|
+
return dataset # Success
|
|
107
|
+
|
|
108
|
+
if len(errors) > 0:
|
|
109
|
+
err_msgs = ""
|
|
110
|
+
def sort_key(item):
|
|
111
|
+
return (isinstance(item[1], DataNotSupportedError), item[0])
|
|
112
|
+
for key, err in sorted(errors.items(), key=sort_key):
|
|
113
|
+
err_msgs += f"\n - Deserializer '{key}': ({type(err).__name__}) {err}"
|
|
114
|
+
raise ValueError(
|
|
115
|
+
"Data deserialization failed, likely because the input doesn't "
|
|
116
|
+
f"match any of the input formats. See the {len(errors)} error(s) that "
|
|
117
|
+
f"occurred while attempting to deserialize the data {data}:{err_msgs}"
|
|
118
|
+
)
|
|
119
|
+
return dataset
|
|
120
|
+
|
|
121
|
+
@classmethod
|
|
122
|
+
def _deserialize_with_specified_deserializer(
|
|
123
|
+
cls,
|
|
124
|
+
data: Any,
|
|
125
|
+
type_: str,
|
|
126
|
+
processor_factory: Callable[[], PreTrainedTokenizerBase],
|
|
127
|
+
random_seed: int = 42,
|
|
128
|
+
**data_kwargs: dict[str, Any],
|
|
129
|
+
) -> Dataset:
|
|
130
|
+
deserializer_from_type = cls.get_registered_object(type_)
|
|
131
|
+
if deserializer_from_type is None:
|
|
132
|
+
raise ValueError(f"Deserializer type '{type_}' is not registered.")
|
|
133
|
+
if isinstance(deserializer_from_type, type):
|
|
134
|
+
deserializer_fn = deserializer_from_type()
|
|
135
|
+
else:
|
|
136
|
+
deserializer_fn = deserializer_from_type
|
|
137
|
+
|
|
138
|
+
return deserializer_fn(
|
|
139
|
+
data=data,
|
|
140
|
+
processor_factory=processor_factory,
|
|
141
|
+
random_seed=random_seed,
|
|
142
|
+
**data_kwargs,
|
|
143
|
+
)
|
|
144
|
+
|
|
guidellm/data/deserializers/file.py
CHANGED
|
@@ -34,11 +34,11 @@ class TextFileDatasetDeserializer(DatasetDeserializer):
|
|
|
34
34
|
processor_factory: Callable[[], PreTrainedTokenizerBase],
|
|
35
35
|
random_seed: int,
|
|
36
36
|
**data_kwargs: dict[str, Any],
|
|
37
|
-
) ->
|
|
37
|
+
) -> Dataset:
|
|
38
38
|
_ = (processor_factory, random_seed) # Ignore unused args format errors
|
|
39
39
|
|
|
40
40
|
if (
|
|
41
|
-
not isinstance(data,
|
|
41
|
+
not isinstance(data, str | Path)
|
|
42
42
|
or not (path := Path(data)).exists()
|
|
43
43
|
or not path.is_file()
|
|
44
44
|
or path.suffix.lower() not in {".txt", ".text"}
|
|
@@ -62,10 +62,10 @@ class CSVFileDatasetDeserializer(DatasetDeserializer):
|
|
|
62
62
|
processor_factory: Callable[[], PreTrainedTokenizerBase],
|
|
63
63
|
random_seed: int,
|
|
64
64
|
**data_kwargs: dict[str, Any],
|
|
65
|
-
) ->
|
|
65
|
+
) -> Dataset:
|
|
66
66
|
_ = (processor_factory, random_seed)
|
|
67
67
|
if (
|
|
68
|
-
not isinstance(data,
|
|
68
|
+
not isinstance(data, str | Path)
|
|
69
69
|
or not (path := Path(data)).exists()
|
|
70
70
|
or not path.is_file()
|
|
71
71
|
or path.suffix.lower() != ".csv"
|
|
@@ -86,10 +86,10 @@ class JSONFileDatasetDeserializer(DatasetDeserializer):
|
|
|
86
86
|
processor_factory: Callable[[], PreTrainedTokenizerBase],
|
|
87
87
|
random_seed: int,
|
|
88
88
|
**data_kwargs: dict[str, Any],
|
|
89
|
-
) ->
|
|
89
|
+
) -> Dataset:
|
|
90
90
|
_ = (processor_factory, random_seed)
|
|
91
91
|
if (
|
|
92
|
-
not isinstance(data,
|
|
92
|
+
not isinstance(data, str | Path)
|
|
93
93
|
or not (path := Path(data)).exists()
|
|
94
94
|
or not path.is_file()
|
|
95
95
|
or path.suffix.lower() not in {".json", ".jsonl"}
|
|
@@ -110,10 +110,10 @@ class ParquetFileDatasetDeserializer(DatasetDeserializer):
|
|
|
110
110
|
processor_factory: Callable[[], PreTrainedTokenizerBase],
|
|
111
111
|
random_seed: int,
|
|
112
112
|
**data_kwargs: dict[str, Any],
|
|
113
|
-
) ->
|
|
113
|
+
) -> Dataset:
|
|
114
114
|
_ = (processor_factory, random_seed)
|
|
115
115
|
if (
|
|
116
|
-
not isinstance(data,
|
|
116
|
+
not isinstance(data, str | Path)
|
|
117
117
|
or not (path := Path(data)).exists()
|
|
118
118
|
or not path.is_file()
|
|
119
119
|
or path.suffix.lower() != ".parquet"
|
|
@@ -134,10 +134,10 @@ class ArrowFileDatasetDeserializer(DatasetDeserializer):
|
|
|
134
134
|
processor_factory: Callable[[], PreTrainedTokenizerBase],
|
|
135
135
|
random_seed: int,
|
|
136
136
|
**data_kwargs: dict[str, Any],
|
|
137
|
-
) ->
|
|
137
|
+
) -> Dataset:
|
|
138
138
|
_ = (processor_factory, random_seed)
|
|
139
139
|
if (
|
|
140
|
-
not isinstance(data,
|
|
140
|
+
not isinstance(data, str | Path)
|
|
141
141
|
or not (path := Path(data)).exists()
|
|
142
142
|
or not path.is_file()
|
|
143
143
|
or path.suffix.lower() != ".arrow"
|
|
@@ -158,10 +158,10 @@ class HDF5FileDatasetDeserializer(DatasetDeserializer):
|
|
|
158
158
|
processor_factory: Callable[[], PreTrainedTokenizerBase],
|
|
159
159
|
random_seed: int,
|
|
160
160
|
**data_kwargs: dict[str, Any],
|
|
161
|
-
) ->
|
|
161
|
+
) -> Dataset:
|
|
162
162
|
_ = (processor_factory, random_seed)
|
|
163
163
|
if (
|
|
164
|
-
not isinstance(data,
|
|
164
|
+
not isinstance(data, str | Path)
|
|
165
165
|
or not (path := Path(data)).exists()
|
|
166
166
|
or not path.is_file()
|
|
167
167
|
or path.suffix.lower() not in {".hdf5", ".h5"}
|
|
@@ -185,7 +185,7 @@ class DBFileDatasetDeserializer(DatasetDeserializer):
|
|
|
185
185
|
) -> dict[str, list]:
|
|
186
186
|
_ = (processor_factory, random_seed)
|
|
187
187
|
if (
|
|
188
|
-
not isinstance(data,
|
|
188
|
+
not isinstance(data, str | Path)
|
|
189
189
|
or not (path := Path(data)).exists()
|
|
190
190
|
or not path.is_file()
|
|
191
191
|
or path.suffix.lower() != ".db"
|
|
@@ -209,7 +209,7 @@ class TarFileDatasetDeserializer(DatasetDeserializer):
|
|
|
209
209
|
) -> dict[str, list]:
|
|
210
210
|
_ = (processor_factory, random_seed)
|
|
211
211
|
if (
|
|
212
|
-
not isinstance(data,
|
|
212
|
+
not isinstance(data, str | Path)
|
|
213
213
|
or not (path := Path(data)).exists()
|
|
214
214
|
or not path.is_file()
|
|
215
215
|
or path.suffix.lower() != ".tar"
|
|
guidellm/data/deserializers/huggingface.py
CHANGED
|
@@ -36,7 +36,7 @@ class HuggingFaceDatasetDeserializer(DatasetDeserializer):
|
|
|
36
36
|
processor_factory: Callable[[], PreTrainedTokenizerBase],
|
|
37
37
|
random_seed: int,
|
|
38
38
|
**data_kwargs: dict[str, Any],
|
|
39
|
-
) ->
|
|
39
|
+
) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict:
|
|
40
40
|
_ = (processor_factory, random_seed)
|
|
41
41
|
|
|
42
42
|
if isinstance(
|
|
guidellm/data/deserializers/memory.py
CHANGED
|
@@ -33,7 +33,7 @@ class InMemoryDictDatasetDeserializer(DatasetDeserializer):
|
|
|
33
33
|
processor_factory: Callable[[], PreTrainedTokenizerBase],
|
|
34
34
|
random_seed: int,
|
|
35
35
|
**data_kwargs: dict[str, Any],
|
|
36
|
-
) ->
|
|
36
|
+
) -> Dataset:
|
|
37
37
|
_ = (processor_factory, random_seed) # Ignore unused args format errors
|
|
38
38
|
|
|
39
39
|
if (
|
|
@@ -67,7 +67,7 @@ class InMemoryDictListDatasetDeserializer(DatasetDeserializer):
|
|
|
67
67
|
processor_factory: Callable[[], PreTrainedTokenizerBase],
|
|
68
68
|
random_seed: int,
|
|
69
69
|
**data_kwargs: dict[str, Any],
|
|
70
|
-
) ->
|
|
70
|
+
) -> Dataset:
|
|
71
71
|
_ = (processor_factory, random_seed) # Ignore unused args format errors
|
|
72
72
|
|
|
73
73
|
if (
|
|
@@ -81,9 +81,9 @@ class InMemoryDictListDatasetDeserializer(DatasetDeserializer):
|
|
|
81
81
|
f"expected list of dicts, got {data}"
|
|
82
82
|
)
|
|
83
83
|
|
|
84
|
-
|
|
85
|
-
first_keys = set(
|
|
86
|
-
for index, item in enumerate(
|
|
84
|
+
typed_data: list[dict[str, Any]] = cast("list[dict[str, Any]]", data)
|
|
85
|
+
first_keys = set(typed_data[0].keys())
|
|
86
|
+
for index, item in enumerate(typed_data):
|
|
87
87
|
if set(item.keys()) != first_keys:
|
|
88
88
|
raise DataNotSupportedError(
|
|
89
89
|
f"All dictionaries must have the same keys. "
|
|
@@ -92,8 +92,8 @@ class InMemoryDictListDatasetDeserializer(DatasetDeserializer):
|
|
|
92
92
|
)
|
|
93
93
|
|
|
94
94
|
# Convert list of dicts to dict of lists
|
|
95
|
-
result_dict = {key: [] for key in first_keys}
|
|
96
|
-
for item in
|
|
95
|
+
result_dict: dict = {key: [] for key in first_keys}
|
|
96
|
+
for item in typed_data:
|
|
97
97
|
for key, value in item.items():
|
|
98
98
|
result_dict[key].append(value)
|
|
99
99
|
|
|
@@ -108,7 +108,7 @@ class InMemoryItemListDatasetDeserializer(DatasetDeserializer):
|
|
|
108
108
|
processor_factory: Callable[[], PreTrainedTokenizerBase],
|
|
109
109
|
random_seed: int,
|
|
110
110
|
**data_kwargs: dict[str, Any],
|
|
111
|
-
) ->
|
|
111
|
+
) -> Dataset:
|
|
112
112
|
_ = (processor_factory, random_seed) # Ignore unused args format errors
|
|
113
113
|
|
|
114
114
|
primitive_types = (str, int, float, bool, type(None))
|
|
@@ -135,7 +135,7 @@ class InMemoryJsonStrDatasetDeserializer(DatasetDeserializer):
|
|
|
135
135
|
processor_factory: Callable[[], PreTrainedTokenizerBase],
|
|
136
136
|
random_seed: int,
|
|
137
137
|
**data_kwargs: dict[str, Any],
|
|
138
|
-
) ->
|
|
138
|
+
) -> Dataset:
|
|
139
139
|
if (
|
|
140
140
|
isinstance(data, str)
|
|
141
141
|
and (json_str := data.strip())
|
|
@@ -145,16 +145,18 @@ class InMemoryJsonStrDatasetDeserializer(DatasetDeserializer):
|
|
|
145
145
|
)
|
|
146
146
|
):
|
|
147
147
|
with contextlib.suppress(Exception):
|
|
148
|
-
|
|
148
|
+
parsed_data = json.loads(data)
|
|
149
149
|
|
|
150
|
-
|
|
151
|
-
InMemoryDictDatasetDeserializer,
|
|
152
|
-
InMemoryDictListDatasetDeserializer,
|
|
153
|
-
InMemoryItemListDatasetDeserializer,
|
|
154
|
-
]
|
|
150
|
+
deserializers = [
|
|
151
|
+
InMemoryDictDatasetDeserializer(),
|
|
152
|
+
InMemoryDictListDatasetDeserializer(),
|
|
153
|
+
InMemoryItemListDatasetDeserializer(),
|
|
154
|
+
]
|
|
155
|
+
|
|
156
|
+
for deserializer in deserializers:
|
|
155
157
|
with contextlib.suppress(DataNotSupportedError):
|
|
156
|
-
return deserializer(
|
|
157
|
-
|
|
158
|
+
return deserializer(
|
|
159
|
+
parsed_data, processor_factory, random_seed, **data_kwargs
|
|
158
160
|
)
|
|
159
161
|
|
|
160
162
|
raise DataNotSupportedError(
|
|
@@ -171,7 +173,7 @@ class InMemoryCsvDatasetDeserializer(DatasetDeserializer):
|
|
|
171
173
|
processor_factory: Callable[[], PreTrainedTokenizerBase],
|
|
172
174
|
random_seed: int,
|
|
173
175
|
**data_kwargs: dict[str, Any],
|
|
174
|
-
) ->
|
|
176
|
+
) -> Dataset:
|
|
175
177
|
if (
|
|
176
178
|
isinstance(data, str)
|
|
177
179
|
and (csv_str := data.strip())
|
|
guidellm/data/deserializers/synthetic.py
CHANGED
|
@@ -99,21 +99,23 @@ class SyntheticTextDatasetConfig(StandardBaseModel):
|
|
|
99
99
|
|
|
100
100
|
@model_validator(mode="after")
|
|
101
101
|
def check_prefix_options(self) -> SyntheticTextDatasetConfig:
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
102
|
+
if self.__pydantic_extra__ is not None:
|
|
103
|
+
prefix_count = self.__pydantic_extra__.get("prefix_count", None) # type: ignore[attr-defined]
|
|
104
|
+
prefix_tokens = self.__pydantic_extra__.get("prefix_tokens", None) # type: ignore[attr-defined]
|
|
105
|
+
|
|
106
|
+
if prefix_count is not None or prefix_tokens is not None:
|
|
107
|
+
if self.prefix_buckets:
|
|
108
|
+
raise ValueError(
|
|
109
|
+
"prefix_buckets is mutually exclusive"
|
|
110
|
+
" with prefix_count and prefix_tokens"
|
|
111
|
+
)
|
|
110
112
|
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
113
|
+
self.prefix_buckets = [
|
|
114
|
+
SyntheticTextPrefixBucketConfig(
|
|
115
|
+
prefix_count=prefix_count or 1,
|
|
116
|
+
prefix_tokens=prefix_tokens or 0,
|
|
117
|
+
)
|
|
118
|
+
]
|
|
117
119
|
|
|
118
120
|
return self
|
|
119
121
|
|
|
@@ -174,14 +176,14 @@ class SyntheticTextGenerator:
|
|
|
174
176
|
def _create_prompt(
|
|
175
177
|
self, prompt_tokens_count: int, faker: Faker, unique: str = ""
|
|
176
178
|
) -> str:
|
|
177
|
-
prompt_token_ids = []
|
|
179
|
+
prompt_token_ids: list[int] = []
|
|
178
180
|
avg_chars_per_token = 5
|
|
179
181
|
margin_of_safety = 1.5
|
|
180
182
|
attempts = 0
|
|
181
183
|
|
|
182
184
|
while len(prompt_token_ids) < prompt_tokens_count:
|
|
183
185
|
attempts += 1
|
|
184
|
-
num_chars = (
|
|
186
|
+
num_chars = int(
|
|
185
187
|
prompt_tokens_count * avg_chars_per_token * margin_of_safety * attempts
|
|
186
188
|
)
|
|
187
189
|
text = unique + faker.text(max_nb_chars=num_chars)
|
guidellm/data/loaders.py
CHANGED
|
@@ -17,6 +17,7 @@ from guidellm.logger import logger
|
|
|
17
17
|
__all__ = ["DataLoader", "DatasetsIterator"]
|
|
18
18
|
|
|
19
19
|
|
|
20
|
+
|
|
20
21
|
class DatasetsIterator(TorchIterableDataset):
|
|
21
22
|
def __init__(
|
|
22
23
|
self,
|
|
@@ -85,7 +86,7 @@ class DatasetsIterator(TorchIterableDataset):
|
|
|
85
86
|
|
|
86
87
|
while max_items is None or gen_count < max_items:
|
|
87
88
|
try:
|
|
88
|
-
row = {
|
|
89
|
+
row: dict[str, Any] = {
|
|
89
90
|
"items": [next(dataset_iter) for dataset_iter in dataset_iters]
|
|
90
91
|
}
|
|
91
92
|
gen_count += 1
|
|
@@ -98,9 +99,12 @@ class DatasetsIterator(TorchIterableDataset):
|
|
|
98
99
|
continue
|
|
99
100
|
|
|
100
101
|
for preprocessor in self.preprocessors:
|
|
101
|
-
|
|
102
|
+
# This can assign a GenerationRequest, which would then be
|
|
103
|
+
# passed into the preprocessor, which is a type violation.
|
|
104
|
+
# This should be fixed at some point.
|
|
105
|
+
row = preprocessor(row) # type: ignore[assignment]
|
|
102
106
|
yield row
|
|
103
|
-
except Exception as err:
|
|
107
|
+
except Exception as err: # noqa: BLE001 # Exception logged
|
|
104
108
|
logger.error(f"Skipping data row due to error: {err}")
|
|
105
109
|
gen_count -= 1
|
|
106
110
|
|
|
guidellm/data/preprocessors/formatters.py
CHANGED
|
@@ -7,8 +7,6 @@ from guidellm.data.preprocessors.preprocessor import (
|
|
|
7
7
|
DatasetPreprocessor,
|
|
8
8
|
PreprocessorRegistry,
|
|
9
9
|
)
|
|
10
|
-
from guidellm.data.schemas import GenerativeDatasetColumnType
|
|
11
|
-
from guidellm.data.utils import text_stats
|
|
12
10
|
from guidellm.schemas import GenerationRequest, GenerationRequestArguments, UsageMetrics
|
|
13
11
|
|
|
14
12
|
__all__ = [
|
|
@@ -59,9 +57,13 @@ class GenerativeTextCompletionsRequestFormatter(RequestFormatter):
|
|
|
59
57
|
self.max_tokens: int | None = max_tokens or max_completion_tokens
|
|
60
58
|
|
|
61
59
|
def __call__(
|
|
62
|
-
self, columns: dict[
|
|
60
|
+
self, columns: dict[str, list[Any]]
|
|
63
61
|
) -> GenerationRequest:
|
|
64
|
-
|
|
62
|
+
"""
|
|
63
|
+
:param columns: A dict of GenerativeDatasetColumnType to Any
|
|
64
|
+
"""
|
|
65
|
+
arguments: GenerationRequestArguments = GenerationRequestArguments()
|
|
66
|
+
arguments.body = {} # The type checker works better setting this field here
|
|
65
67
|
input_metrics = UsageMetrics()
|
|
66
68
|
output_metrics = UsageMetrics()
|
|
67
69
|
|
|
@@ -99,10 +101,9 @@ class GenerativeTextCompletionsRequestFormatter(RequestFormatter):
|
|
|
99
101
|
prefix = "".join(pre for pre in columns.get("prefix_column", []) if pre)
|
|
100
102
|
text = "".join(txt for txt in columns.get("text_column", []) if txt)
|
|
101
103
|
if prefix or text:
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
input_metrics.
|
|
105
|
-
input_metrics.text_words = stats.get("num_words")
|
|
104
|
+
prompt = prefix + text
|
|
105
|
+
arguments.body["prompt"] = prompt
|
|
106
|
+
input_metrics.add_text_metrics(prompt)
|
|
106
107
|
|
|
107
108
|
return GenerationRequest(
|
|
108
109
|
request_type="text_completions",
|
|
@@ -142,9 +143,13 @@ class GenerativeChatCompletionsRequestFormatter(RequestFormatter):
|
|
|
142
143
|
)
|
|
143
144
|
|
|
144
145
|
def __call__( # noqa: C901, PLR0912, PLR0915
|
|
145
|
-
self, columns: dict[
|
|
146
|
+
self, columns: dict[str, list[Any]]
|
|
146
147
|
) -> GenerationRequest:
|
|
147
|
-
|
|
148
|
+
"""
|
|
149
|
+
:param columns: A dict of GenerativeDatasetColumnType to Any
|
|
150
|
+
"""
|
|
151
|
+
arguments = GenerationRequestArguments()
|
|
152
|
+
arguments.body = {} # The type checker works best with body assigned here
|
|
148
153
|
input_metrics = UsageMetrics()
|
|
149
154
|
output_metrics = UsageMetrics()
|
|
150
155
|
|
|
@@ -191,27 +196,14 @@ class GenerativeChatCompletionsRequestFormatter(RequestFormatter):
|
|
|
191
196
|
if not prefix:
|
|
192
197
|
continue
|
|
193
198
|
|
|
194
|
-
|
|
195
|
-
if (num_chars := stats.get("num_chars")) is not None:
|
|
196
|
-
input_metrics.text_characters = (
|
|
197
|
-
input_metrics.text_characters or 0
|
|
198
|
-
) + num_chars
|
|
199
|
-
if (num_words := stats.get("num_words")) is not None:
|
|
200
|
-
input_metrics.text_words = (input_metrics.text_words or 0) + num_words
|
|
201
|
-
|
|
199
|
+
input_metrics.add_text_metrics(prefix)
|
|
202
200
|
arguments.body["messages"].append({"role": "system", "content": prefix})
|
|
203
201
|
|
|
204
202
|
for text in columns.get("text_column", []):
|
|
205
203
|
if not text:
|
|
206
204
|
continue
|
|
207
205
|
|
|
208
|
-
|
|
209
|
-
if (num_chars := stats.get("num_chars")) is not None:
|
|
210
|
-
input_metrics.text_characters = (
|
|
211
|
-
input_metrics.text_characters or 0
|
|
212
|
-
) + num_chars
|
|
213
|
-
if (num_words := stats.get("num_words")) is not None:
|
|
214
|
-
input_metrics.text_words = (input_metrics.text_words or 0) + num_words
|
|
206
|
+
input_metrics.add_text_metrics(text)
|
|
215
207
|
|
|
216
208
|
arguments.body["messages"].append(
|
|
217
209
|
{"role": "user", "content": [{"type": "text", "text": text}]}
|
|
@@ -329,9 +321,10 @@ class GenerativeAudioTranscriptionRequestFormatter(RequestFormatter):
|
|
|
329
321
|
self.encode_audio_kwargs = encode_kwargs or {}
|
|
330
322
|
|
|
331
323
|
def __call__( # noqa: C901
|
|
332
|
-
self, columns: dict[
|
|
324
|
+
self, columns: dict[str, list[Any]]
|
|
333
325
|
) -> GenerationRequest:
|
|
334
|
-
arguments = GenerationRequestArguments(
|
|
326
|
+
arguments = GenerationRequestArguments(files={})
|
|
327
|
+
arguments.body = {} # The type checker works best with body assigned here
|
|
335
328
|
input_metrics = UsageMetrics()
|
|
336
329
|
output_metrics = UsageMetrics()
|
|
337
330
|
|
|
@@ -387,10 +380,9 @@ class GenerativeAudioTranscriptionRequestFormatter(RequestFormatter):
|
|
|
387
380
|
prefix = "".join(pre for pre in columns.get("prefix_column", []) if pre)
|
|
388
381
|
text = "".join(txt for txt in columns.get("text_column", []) if txt)
|
|
389
382
|
if prefix or text:
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
input_metrics.
|
|
393
|
-
input_metrics.text_words = stats.get("num_words")
|
|
383
|
+
prompt = prefix + text
|
|
384
|
+
arguments.body["prompt"] = prompt
|
|
385
|
+
input_metrics.add_text_metrics(prompt)
|
|
394
386
|
|
|
395
387
|
return GenerationRequest(
|
|
396
388
|
request_type="audio_transcriptions",
|
|
@@ -405,7 +397,7 @@ class GenerativeAudioTranslationRequestFormatter(
|
|
|
405
397
|
GenerativeAudioTranscriptionRequestFormatter
|
|
406
398
|
):
|
|
407
399
|
def __call__(
|
|
408
|
-
self, columns: dict[
|
|
400
|
+
self, columns: dict[str, list[Any]]
|
|
409
401
|
) -> GenerationRequest:
|
|
410
402
|
result = super().__call__(columns)
|
|
411
403
|
result.request_type = "audio_translations"
|
|
guidellm/data/preprocessors/mappers.py
CHANGED
|
@@ -169,12 +169,12 @@ class GenerativeColumnMapper(DataDependentPreprocessor):
|
|
|
169
169
|
|
|
170
170
|
def __call__(
|
|
171
171
|
self, row: dict[str, Any]
|
|
172
|
-
) -> dict[
|
|
172
|
+
) -> dict[str, list[Any]]:
|
|
173
173
|
if self.datasets_column_mappings is None:
|
|
174
174
|
raise ValueError("DefaultGenerativeColumnMapper not setup with data.")
|
|
175
175
|
|
|
176
176
|
items = cast("dict[int, dict[str, Any]]", row.pop("items"))
|
|
177
|
-
mapped: dict[
|
|
177
|
+
mapped: dict[str, Any] = defaultdict(list)
|
|
178
178
|
|
|
179
179
|
for column_type, column_mappings in self.datasets_column_mappings.items():
|
|
180
180
|
for (
|