guidellm 0.4.0a21__py3-none-any.whl → 0.4.0a169__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- guidellm/__init__.py +5 -2
- guidellm/__main__.py +452 -252
- guidellm/backends/__init__.py +33 -0
- guidellm/backends/backend.py +110 -0
- guidellm/backends/openai.py +355 -0
- guidellm/backends/response_handlers.py +455 -0
- guidellm/benchmark/__init__.py +53 -39
- guidellm/benchmark/benchmarker.py +150 -317
- guidellm/benchmark/entrypoints.py +467 -128
- guidellm/benchmark/output.py +519 -771
- guidellm/benchmark/profile.py +580 -280
- guidellm/benchmark/progress.py +568 -549
- guidellm/benchmark/scenarios/__init__.py +40 -0
- guidellm/benchmark/scenarios/chat.json +6 -0
- guidellm/benchmark/scenarios/rag.json +6 -0
- guidellm/benchmark/schemas.py +2086 -0
- guidellm/data/__init__.py +28 -4
- guidellm/data/collators.py +16 -0
- guidellm/data/deserializers/__init__.py +53 -0
- guidellm/data/deserializers/deserializer.py +144 -0
- guidellm/data/deserializers/file.py +222 -0
- guidellm/data/deserializers/huggingface.py +94 -0
- guidellm/data/deserializers/memory.py +194 -0
- guidellm/data/deserializers/synthetic.py +348 -0
- guidellm/data/loaders.py +149 -0
- guidellm/data/preprocessors/__init__.py +25 -0
- guidellm/data/preprocessors/formatters.py +404 -0
- guidellm/data/preprocessors/mappers.py +198 -0
- guidellm/data/preprocessors/preprocessor.py +31 -0
- guidellm/data/processor.py +31 -0
- guidellm/data/schemas.py +13 -0
- guidellm/data/utils/__init__.py +6 -0
- guidellm/data/utils/dataset.py +94 -0
- guidellm/extras/__init__.py +4 -0
- guidellm/extras/audio.py +215 -0
- guidellm/extras/vision.py +242 -0
- guidellm/logger.py +2 -2
- guidellm/mock_server/__init__.py +8 -0
- guidellm/mock_server/config.py +84 -0
- guidellm/mock_server/handlers/__init__.py +17 -0
- guidellm/mock_server/handlers/chat_completions.py +280 -0
- guidellm/mock_server/handlers/completions.py +280 -0
- guidellm/mock_server/handlers/tokenizer.py +142 -0
- guidellm/mock_server/models.py +510 -0
- guidellm/mock_server/server.py +168 -0
- guidellm/mock_server/utils.py +302 -0
- guidellm/preprocess/dataset.py +23 -26
- guidellm/presentation/builder.py +2 -2
- guidellm/presentation/data_models.py +25 -21
- guidellm/presentation/injector.py +2 -3
- guidellm/scheduler/__init__.py +65 -26
- guidellm/scheduler/constraints.py +1035 -0
- guidellm/scheduler/environments.py +252 -0
- guidellm/scheduler/scheduler.py +140 -368
- guidellm/scheduler/schemas.py +272 -0
- guidellm/scheduler/strategies.py +519 -0
- guidellm/scheduler/worker.py +391 -420
- guidellm/scheduler/worker_group.py +707 -0
- guidellm/schemas/__init__.py +31 -0
- guidellm/schemas/info.py +159 -0
- guidellm/schemas/request.py +226 -0
- guidellm/schemas/response.py +119 -0
- guidellm/schemas/stats.py +228 -0
- guidellm/{config.py → settings.py} +32 -21
- guidellm/utils/__init__.py +95 -8
- guidellm/utils/auto_importer.py +98 -0
- guidellm/utils/cli.py +71 -2
- guidellm/utils/console.py +183 -0
- guidellm/utils/encoding.py +778 -0
- guidellm/utils/functions.py +134 -0
- guidellm/utils/hf_datasets.py +1 -2
- guidellm/utils/hf_transformers.py +4 -4
- guidellm/utils/imports.py +9 -0
- guidellm/utils/messaging.py +1118 -0
- guidellm/utils/mixins.py +115 -0
- guidellm/utils/pydantic_utils.py +411 -0
- guidellm/utils/random.py +3 -4
- guidellm/utils/registry.py +220 -0
- guidellm/utils/singleton.py +133 -0
- guidellm/{objects → utils}/statistics.py +341 -247
- guidellm/utils/synchronous.py +159 -0
- guidellm/utils/text.py +163 -50
- guidellm/utils/typing.py +41 -0
- guidellm/version.py +1 -1
- {guidellm-0.4.0a21.dist-info → guidellm-0.4.0a169.dist-info}/METADATA +33 -10
- guidellm-0.4.0a169.dist-info/RECORD +95 -0
- guidellm/backend/__init__.py +0 -23
- guidellm/backend/backend.py +0 -259
- guidellm/backend/openai.py +0 -705
- guidellm/backend/response.py +0 -136
- guidellm/benchmark/aggregator.py +0 -760
- guidellm/benchmark/benchmark.py +0 -837
- guidellm/benchmark/scenario.py +0 -104
- guidellm/data/prideandprejudice.txt.gz +0 -0
- guidellm/dataset/__init__.py +0 -22
- guidellm/dataset/creator.py +0 -213
- guidellm/dataset/entrypoints.py +0 -42
- guidellm/dataset/file.py +0 -92
- guidellm/dataset/hf_datasets.py +0 -62
- guidellm/dataset/in_memory.py +0 -132
- guidellm/dataset/synthetic.py +0 -287
- guidellm/objects/__init__.py +0 -18
- guidellm/objects/pydantic.py +0 -89
- guidellm/request/__init__.py +0 -18
- guidellm/request/loader.py +0 -284
- guidellm/request/request.py +0 -79
- guidellm/request/types.py +0 -10
- guidellm/scheduler/queues.py +0 -25
- guidellm/scheduler/result.py +0 -155
- guidellm/scheduler/strategy.py +0 -495
- guidellm-0.4.0a21.dist-info/RECORD +0 -62
- {guidellm-0.4.0a21.dist-info → guidellm-0.4.0a169.dist-info}/WHEEL +0 -0
- {guidellm-0.4.0a21.dist-info → guidellm-0.4.0a169.dist-info}/entry_points.txt +0 -0
- {guidellm-0.4.0a21.dist-info → guidellm-0.4.0a169.dist-info}/licenses/LICENSE +0 -0
- {guidellm-0.4.0a21.dist-info → guidellm-0.4.0a169.dist-info}/top_level.txt +0 -0
guidellm/data/__init__.py
CHANGED
@@ -1,4 +1,28 @@
+from .collators import GenerativeRequestCollator
+from .deserializers import (
+    DataNotSupportedError,
+    DatasetDeserializer,
+    DatasetDeserializerFactory,
+)
+from .loaders import DataLoader, DatasetsIterator
+from .preprocessors import (
+    DataDependentPreprocessor,
+    DatasetPreprocessor,
+    PreprocessorRegistry,
+)
+from .processor import ProcessorFactory
+from .schemas import GenerativeDatasetColumnType
+
+__all__ = [
+    "DataDependentPreprocessor",
+    "DataLoader",
+    "DataNotSupportedError",
+    "DatasetDeserializer",
+    "DatasetDeserializerFactory",
+    "DatasetPreprocessor",
+    "DatasetsIterator",
+    "GenerativeDatasetColumnType",
+    "GenerativeRequestCollator",
+    "PreprocessorRegistry",
+    "ProcessorFactory",
+]
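The rewritten __all__ above defines the public surface of guidellm.data. A minimal import sketch limited to names re-exported above (how these pieces are wired together is shown by the submodules later in this diff):

    from guidellm.data import (
        DataLoader,
        DatasetDeserializerFactory,
        GenerativeRequestCollator,
        ProcessorFactory,
    )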
guidellm/data/collators.py
ADDED
@@ -0,0 +1,16 @@
+from __future__ import annotations
+
+from guidellm.schemas import GenerationRequest
+
+__all__ = ["GenerativeRequestCollator"]
+
+
+class GenerativeRequestCollator:
+    def __call__(self, batch: list) -> GenerationRequest:
+        if len(batch) != 1:
+            raise NotImplementedError(
+                f"Batch size greater than 1 is not currently supported. "
+                f"Got batch size: {len(batch)}"
+            )
+
+        return batch[0]
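The collator only unwraps single-request batches, so it is effectively an identity function guarded by a batch-size check. A short sketch of calling it directly (generation_request is a hypothetical GenerationRequest instance, not defined in this diff):

    from guidellm.data import GenerativeRequestCollator

    collator = GenerativeRequestCollator()

    # A batch is a list of GenerationRequest objects; only batch size 1 is
    # supported, and the single request is returned unchanged.
    request = collator([generation_request])

    # Any larger batch raises NotImplementedError:
    # collator([request_a, request_b])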
guidellm/data/deserializers/__init__.py
ADDED
@@ -0,0 +1,53 @@
+from .deserializer import (
+    DataNotSupportedError,
+    DatasetDeserializer,
+    DatasetDeserializerFactory,
+)
+from .file import (
+    ArrowFileDatasetDeserializer,
+    CSVFileDatasetDeserializer,
+    DBFileDatasetDeserializer,
+    HDF5FileDatasetDeserializer,
+    JSONFileDatasetDeserializer,
+    ParquetFileDatasetDeserializer,
+    TarFileDatasetDeserializer,
+    TextFileDatasetDeserializer,
+)
+from .huggingface import HuggingFaceDatasetDeserializer
+from .memory import (
+    InMemoryCsvDatasetDeserializer,
+    InMemoryDictDatasetDeserializer,
+    InMemoryDictListDatasetDeserializer,
+    InMemoryItemListDatasetDeserializer,
+    InMemoryJsonStrDatasetDeserializer,
+)
+from .synthetic import (
+    SyntheticTextDatasetConfig,
+    SyntheticTextDatasetDeserializer,
+    SyntheticTextGenerator,
+    SyntheticTextPrefixBucketConfig,
+)
+
+__all__ = [
+    "ArrowFileDatasetDeserializer",
+    "CSVFileDatasetDeserializer",
+    "DBFileDatasetDeserializer",
+    "DataNotSupportedError",
+    "DatasetDeserializer",
+    "DatasetDeserializerFactory",
+    "HDF5FileDatasetDeserializer",
+    "HuggingFaceDatasetDeserializer",
+    "InMemoryCsvDatasetDeserializer",
+    "InMemoryDictDatasetDeserializer",
+    "InMemoryDictListDatasetDeserializer",
+    "InMemoryItemListDatasetDeserializer",
+    "InMemoryJsonStrDatasetDeserializer",
+    "JSONFileDatasetDeserializer",
+    "ParquetFileDatasetDeserializer",
+    "SyntheticTextDatasetConfig",
+    "SyntheticTextDatasetDeserializer",
+    "SyntheticTextGenerator",
+    "SyntheticTextPrefixBucketConfig",
+    "TarFileDatasetDeserializer",
+    "TextFileDatasetDeserializer",
+]
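All of the concrete deserializers are re-exported here alongside the DatasetDeserializer protocol. Since the protocol (defined in deserializer.py below) is runtime_checkable, conformance can be verified structurally; a minimal sketch:

    from guidellm.data.deserializers import (
        DatasetDeserializer,
        TextFileDatasetDeserializer,
    )

    # The protocol only requires a compatible __call__, so isinstance() works here.
    assert isinstance(TextFileDatasetDeserializer(), DatasetDeserializer)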
guidellm/data/deserializers/deserializer.py
ADDED
@@ -0,0 +1,144 @@
+from __future__ import annotations
+
+from collections.abc import Callable
+from typing import Any, Protocol, Union, runtime_checkable
+
+from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
+from transformers import PreTrainedTokenizerBase
+
+from guidellm.data.utils import resolve_dataset_split
+from guidellm.utils import RegistryMixin
+
+__all__ = [
+    "DataNotSupportedError",
+    "DatasetDeserializer",
+    "DatasetDeserializerFactory",
+]
+
+
+class DataNotSupportedError(Exception):
+    """Exception raised when data format is not supported by deserializer."""
+
+
+@runtime_checkable
+class DatasetDeserializer(Protocol):
+    def __call__(
+        self,
+        data: Any,
+        processor_factory: Callable[[], PreTrainedTokenizerBase],
+        random_seed: int,
+        **data_kwargs: dict[str, Any],
+    ) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict: ...
+
+
+class DatasetDeserializerFactory(
+    RegistryMixin[Union["type[DatasetDeserializer]", DatasetDeserializer]],
+):
+    @classmethod
+    def deserialize(
+        cls,
+        data: Any,
+        processor_factory: Callable[[], PreTrainedTokenizerBase],
+        random_seed: int = 42,
+        type_: str | None = None,
+        resolve_split: bool = True,
+        select_columns: list[str] | None = None,
+        remove_columns: list[str] | None = None,
+        **data_kwargs: dict[str, Any],
+    ) -> Dataset | IterableDataset:
+        dataset: Dataset
+
+        if type_ is None:
+            dataset = cls._deserialize_with_registered_deserializers(
+                data, processor_factory, random_seed, **data_kwargs
+            )
+
+        else:
+            dataset = cls._deserialize_with_specified_deserializer(
+                data, type_, processor_factory, random_seed, **data_kwargs
+            )
+
+        if resolve_split:
+            dataset = resolve_dataset_split(dataset)
+
+        if select_columns is not None or remove_columns is not None:
+            column_names = dataset.column_names or list(next(iter(dataset)).keys())
+            if select_columns is not None:
+                remove_columns = [
+                    col for col in column_names if col not in select_columns
+                ]
+
+            dataset = dataset.remove_columns(remove_columns)
+
+        return dataset
+
+    @classmethod
+    def _deserialize_with_registered_deserializers(
+        cls,
+        data: Any,
+        processor_factory: Callable[[], PreTrainedTokenizerBase],
+        random_seed: int = 42,
+        **data_kwargs: dict[str, Any],
+    ) -> Dataset:
+        if cls.registry is None:
+            raise RuntimeError("registry is None; cannot deserialize dataset")
+        dataset: Dataset | None = None
+
+        errors: dict[str, Exception] = {}
+        # Note: There is no priority order for the deserializers, so all deserializers
+        # must be mutually exclusive to ensure deterministic behavior.
+        for _name, deserializer in cls.registry.items():
+            deserializer_fn: DatasetDeserializer = (
+                deserializer() if isinstance(deserializer, type) else deserializer
+            )
+
+            try:
+                dataset = deserializer_fn(
+                    data=data,
+                    processor_factory=processor_factory,
+                    random_seed=random_seed,
+                    **data_kwargs,
+                )
+            except Exception as e:  # noqa: BLE001 # The exceptions are saved.
+                errors[_name] = e
+
+            if dataset is not None:
+                return dataset  # Success
+
+        if len(errors) > 0:
+            err_msgs = ""
+            def sort_key(item):
+                return (isinstance(item[1], DataNotSupportedError), item[0])
+            for key, err in sorted(errors.items(), key=sort_key):
+                err_msgs += f"\n - Deserializer '{key}': ({type(err).__name__}) {err}"
+            raise ValueError(
+                "Data deserialization failed, likely because the input doesn't "
+                f"match any of the input formats. See the {len(errors)} error(s) that "
+                f"occurred while attempting to deserialize the data {data}:{err_msgs}"
+            )
+        return dataset
+
+    @classmethod
+    def _deserialize_with_specified_deserializer(
+        cls,
+        data: Any,
+        type_: str,
+        processor_factory: Callable[[], PreTrainedTokenizerBase],
+        random_seed: int = 42,
+        **data_kwargs: dict[str, Any],
+    ) -> Dataset:
+        deserializer_from_type = cls.get_registered_object(type_)
+        if deserializer_from_type is None:
+            raise ValueError(f"Deserializer type '{type_}' is not registered.")
+        if isinstance(deserializer_from_type, type):
+            deserializer_fn = deserializer_from_type()
+        else:
+            deserializer_fn = deserializer_from_type
+
+        return deserializer_fn(
+            data=data,
+            processor_factory=processor_factory,
+            random_seed=random_seed,
+            **data_kwargs,
+        )
+
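Because DatasetDeserializerFactory extends RegistryMixin, third-party formats can hook into the same lookup used by the built-in deserializers (which register themselves with the same decorator in file.py below). A hedged sketch; the "my_format" name and the deserializer body are illustrative, not part of the package:

    from collections.abc import Callable
    from typing import Any

    from datasets import Dataset
    from transformers import PreTrainedTokenizerBase

    from guidellm.data.deserializers import (
        DatasetDeserializer,
        DatasetDeserializerFactory,
    )


    @DatasetDeserializerFactory.register("my_format")  # hypothetical format name
    class MyFormatDeserializer(DatasetDeserializer):
        def __call__(
            self,
            data: Any,
            processor_factory: Callable[[], PreTrainedTokenizerBase],
            random_seed: int,
            **data_kwargs: dict[str, Any],
        ) -> Dataset:
            # Illustrative only: wrap whatever `data` is into a one-row Dataset.
            return Dataset.from_dict({"text": [str(data)]})


    # With type_ given, only the named deserializer runs; without it, every
    # registered deserializer is tried and the first one that does not raise wins.
    dataset = DatasetDeserializerFactory.deserialize(
        data="hello world",
        processor_factory=lambda: None,  # unused by this deserializer
        type_="my_format",
    )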
guidellm/data/deserializers/file.py
ADDED
@@ -0,0 +1,222 @@
+from __future__ import annotations
+
+from collections.abc import Callable
+from pathlib import Path
+from typing import Any
+
+import pandas as pd
+from datasets import Dataset, load_dataset
+from transformers import PreTrainedTokenizerBase
+
+from guidellm.data.deserializers.deserializer import (
+    DataNotSupportedError,
+    DatasetDeserializer,
+    DatasetDeserializerFactory,
+)
+
+__all__ = [
+    "ArrowFileDatasetDeserializer",
+    "CSVFileDatasetDeserializer",
+    "DBFileDatasetDeserializer",
+    "HDF5FileDatasetDeserializer",
+    "JSONFileDatasetDeserializer",
+    "ParquetFileDatasetDeserializer",
+    "TarFileDatasetDeserializer",
+    "TextFileDatasetDeserializer",
+]
+
+
+@DatasetDeserializerFactory.register("text_file")
+class TextFileDatasetDeserializer(DatasetDeserializer):
+    def __call__(
+        self,
+        data: Any,
+        processor_factory: Callable[[], PreTrainedTokenizerBase],
+        random_seed: int,
+        **data_kwargs: dict[str, Any],
+    ) -> Dataset:
+        _ = (processor_factory, random_seed)  # Ignore unused args format errors
+
+        if (
+            not isinstance(data, str | Path)
+            or not (path := Path(data)).exists()
+            or not path.is_file()
+            or path.suffix.lower() not in {".txt", ".text"}
+        ):
+            raise DataNotSupportedError(
+                "Unsupported data for TextFileDatasetDeserializer, "
+                f"expected str or Path to a local .txt or .text file, got {data}"
+            )
+
+        with path.open() as file:
+            lines = file.readlines()
+
+        return Dataset.from_dict({"text": lines}, **data_kwargs)
+
+
+@DatasetDeserializerFactory.register("csv_file")
+class CSVFileDatasetDeserializer(DatasetDeserializer):
+    def __call__(
+        self,
+        data: Any,
+        processor_factory: Callable[[], PreTrainedTokenizerBase],
+        random_seed: int,
+        **data_kwargs: dict[str, Any],
+    ) -> Dataset:
+        _ = (processor_factory, random_seed)
+        if (
+            not isinstance(data, str | Path)
+            or not (path := Path(data)).exists()
+            or not path.is_file()
+            or path.suffix.lower() != ".csv"
+        ):
+            raise DataNotSupportedError(
+                "Unsupported data for CSVFileDatasetDeserializer, "
+                f"expected str or Path to a local .csv file, got {data}"
+            )
+
+        return load_dataset("csv", data_files=str(path), **data_kwargs)
+
+
+@DatasetDeserializerFactory.register("json_file")
+class JSONFileDatasetDeserializer(DatasetDeserializer):
+    def __call__(
+        self,
+        data: Any,
+        processor_factory: Callable[[], PreTrainedTokenizerBase],
+        random_seed: int,
+        **data_kwargs: dict[str, Any],
+    ) -> Dataset:
+        _ = (processor_factory, random_seed)
+        if (
+            not isinstance(data, str | Path)
+            or not (path := Path(data)).exists()
+            or not path.is_file()
+            or path.suffix.lower() not in {".json", ".jsonl"}
+        ):
+            raise DataNotSupportedError(
+                f"Unsupported data for JSONFileDatasetDeserializer, "
+                f"expected str or Path to a local .json or .jsonl file, got {data}"
+            )
+
+        return load_dataset("json", data_files=str(path), **data_kwargs)
+
+
+@DatasetDeserializerFactory.register("parquet_file")
+class ParquetFileDatasetDeserializer(DatasetDeserializer):
+    def __call__(
+        self,
+        data: Any,
+        processor_factory: Callable[[], PreTrainedTokenizerBase],
+        random_seed: int,
+        **data_kwargs: dict[str, Any],
+    ) -> Dataset:
+        _ = (processor_factory, random_seed)
+        if (
+            not isinstance(data, str | Path)
+            or not (path := Path(data)).exists()
+            or not path.is_file()
+            or path.suffix.lower() != ".parquet"
+        ):
+            raise DataNotSupportedError(
+                f"Unsupported data for ParquetFileDatasetDeserializer, "
+                f"expected str or Path to a local .parquet file, got {data}"
+            )
+
+        return load_dataset("parquet", data_files=str(path), **data_kwargs)
+
+
+@DatasetDeserializerFactory.register("arrow_file")
+class ArrowFileDatasetDeserializer(DatasetDeserializer):
+    def __call__(
+        self,
+        data: Any,
+        processor_factory: Callable[[], PreTrainedTokenizerBase],
+        random_seed: int,
+        **data_kwargs: dict[str, Any],
+    ) -> Dataset:
+        _ = (processor_factory, random_seed)
+        if (
+            not isinstance(data, str | Path)
+            or not (path := Path(data)).exists()
+            or not path.is_file()
+            or path.suffix.lower() != ".arrow"
+        ):
+            raise DataNotSupportedError(
+                f"Unsupported data for ArrowFileDatasetDeserializer, "
+                f"expected str or Path to a local .arrow file, got {data}"
+            )
+
+        return load_dataset("arrow", data_files=str(path), **data_kwargs)
+
+
+@DatasetDeserializerFactory.register("hdf5_file")
+class HDF5FileDatasetDeserializer(DatasetDeserializer):
+    def __call__(
+        self,
+        data: Any,
+        processor_factory: Callable[[], PreTrainedTokenizerBase],
+        random_seed: int,
+        **data_kwargs: dict[str, Any],
+    ) -> Dataset:
+        _ = (processor_factory, random_seed)
+        if (
+            not isinstance(data, str | Path)
+            or not (path := Path(data)).exists()
+            or not path.is_file()
+            or path.suffix.lower() not in {".hdf5", ".h5"}
+        ):
+            raise DataNotSupportedError(
+                f"Unsupported data for HDF5FileDatasetDeserializer, "
+                f"expected str or Path to a local .hdf5 or .h5 file, got {data}"
+            )
+
+        return Dataset.from_pandas(pd.read_hdf(str(path)), **data_kwargs)
+
+
+@DatasetDeserializerFactory.register("db_file")
+class DBFileDatasetDeserializer(DatasetDeserializer):
+    def __call__(
+        self,
+        data: Any,
+        processor_factory: Callable[[], PreTrainedTokenizerBase],
+        random_seed: int,
+        **data_kwargs: dict[str, Any],
+    ) -> dict[str, list]:
+        _ = (processor_factory, random_seed)
+        if (
+            not isinstance(data, str | Path)
+            or not (path := Path(data)).exists()
+            or not path.is_file()
+            or path.suffix.lower() != ".db"
+        ):
+            raise DataNotSupportedError(
+                f"Unsupported data for DBFileDatasetDeserializer, "
+                f"expected str or Path to a local .db file, got {data}"
+            )
+
+        return Dataset.from_sql(con=str(path), **data_kwargs)
+
+
+@DatasetDeserializerFactory.register("tar_file")
+class TarFileDatasetDeserializer(DatasetDeserializer):
+    def __call__(
+        self,
+        data: Any,
+        processor_factory: Callable[[], PreTrainedTokenizerBase],
+        random_seed: int,
+        **data_kwargs: dict[str, Any],
+    ) -> dict[str, list]:
+        _ = (processor_factory, random_seed)
+        if (
+            not isinstance(data, str | Path)
+            or not (path := Path(data)).exists()
+            or not path.is_file()
+            or path.suffix.lower() != ".tar"
+        ):
+            raise DataNotSupportedError(
+                f"Unsupported data for TarFileDatasetDeserializer, "
+                f"expected str or Path to a local .tar file, got {data}"
+            )
+
+        return load_dataset("webdataset", data_files=str(path), **data_kwargs)
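Each file deserializer only validates the extension and then defers to datasets (load_dataset, Dataset.from_pandas, Dataset.from_sql), so a local file can also be forced through a specific loader by its registered name. A short sketch with an illustrative path:

    from guidellm.data.deserializers import DatasetDeserializerFactory

    # Explicitly select the "csv_file" deserializer registered above; extra
    # keyword arguments are forwarded to datasets.load_dataset.
    dataset = DatasetDeserializerFactory.deserialize(
        data="prompts.csv",              # hypothetical local file
        processor_factory=lambda: None,  # unused by the file deserializers
        type_="csv_file",
    )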
guidellm/data/deserializers/huggingface.py
ADDED
@@ -0,0 +1,94 @@
+from __future__ import annotations
+
+from collections.abc import Callable
+from pathlib import Path
+from typing import Any
+
+from datasets import (
+    Dataset,
+    DatasetDict,
+    IterableDataset,
+    IterableDatasetDict,
+    load_dataset,
+    load_from_disk,
+)
+from datasets.exceptions import (
+    DataFilesNotFoundError,
+    DatasetNotFoundError,
+    FileNotFoundDatasetsError,
+)
+from transformers import PreTrainedTokenizerBase
+
+from guidellm.data.deserializers.deserializer import (
+    DataNotSupportedError,
+    DatasetDeserializer,
+    DatasetDeserializerFactory,
+)
+
+__all__ = ["HuggingFaceDatasetDeserializer"]
+
+
+@DatasetDeserializerFactory.register("huggingface")
+class HuggingFaceDatasetDeserializer(DatasetDeserializer):
+    def __call__(
+        self,
+        data: Any,
+        processor_factory: Callable[[], PreTrainedTokenizerBase],
+        random_seed: int,
+        **data_kwargs: dict[str, Any],
+    ) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict:
+        _ = (processor_factory, random_seed)
+
+        if isinstance(
+            data, Dataset | IterableDataset | DatasetDict | IterableDatasetDict
+        ):
+            return data
+
+        load_error = None
+
+        if (
+            isinstance(data, str | Path)
+            and (path := Path(data)).exists()
+            and ((path.is_file() and path.suffix == ".py") or path.is_dir())
+        ):
+            # Handle python script or nested python script in a directory
+            try:
+                return load_dataset(str(data), **data_kwargs)
+            except (
+                FileNotFoundDatasetsError,
+                DatasetNotFoundError,
+                DataFilesNotFoundError,
+            ) as err:
+                load_error = err
+            except Exception:  # noqa: BLE001
+                # Try loading as a local dataset directory next
+                try:
+                    return load_from_disk(str(data), **data_kwargs)
+                except (
+                    FileNotFoundDatasetsError,
+                    DatasetNotFoundError,
+                    DataFilesNotFoundError,
+                ) as err2:
+                    load_error = err2
+
+        try:
+            # Handle dataset identifier from the Hugging Face Hub
+            return load_dataset(str(data), **data_kwargs)
+        except (
+            FileNotFoundDatasetsError,
+            DatasetNotFoundError,
+            DataFilesNotFoundError,
+        ) as err:
+            load_error = err
+
+        not_supported = DataNotSupportedError(
+            "Unsupported data for HuggingFaceDatasetDeserializer, "
+            "expected Dataset, IterableDataset, DatasetDict, IterableDatasetDict, "
+            "str or Path to a local dataset directory or a local .py dataset script, "
+            f"got {data} and HF load error: {load_error}"
+        )
+
+        if load_error is not None:
+            raise not_supported from load_error
+        else:
+            raise not_supported