guidellm 0.4.0a18__py3-none-any.whl → 0.4.0a155__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.

Potentially problematic release: this version of guidellm might be problematic.

Files changed (116)
  1. guidellm/__init__.py +5 -2
  2. guidellm/__main__.py +451 -252
  3. guidellm/backends/__init__.py +33 -0
  4. guidellm/backends/backend.py +110 -0
  5. guidellm/backends/openai.py +355 -0
  6. guidellm/backends/response_handlers.py +455 -0
  7. guidellm/benchmark/__init__.py +53 -39
  8. guidellm/benchmark/benchmarker.py +148 -317
  9. guidellm/benchmark/entrypoints.py +466 -128
  10. guidellm/benchmark/output.py +517 -771
  11. guidellm/benchmark/profile.py +580 -280
  12. guidellm/benchmark/progress.py +568 -549
  13. guidellm/benchmark/scenarios/__init__.py +40 -0
  14. guidellm/benchmark/scenarios/chat.json +6 -0
  15. guidellm/benchmark/scenarios/rag.json +6 -0
  16. guidellm/benchmark/schemas.py +2085 -0
  17. guidellm/data/__init__.py +28 -4
  18. guidellm/data/collators.py +16 -0
  19. guidellm/data/deserializers/__init__.py +53 -0
  20. guidellm/data/deserializers/deserializer.py +109 -0
  21. guidellm/data/deserializers/file.py +222 -0
  22. guidellm/data/deserializers/huggingface.py +94 -0
  23. guidellm/data/deserializers/memory.py +192 -0
  24. guidellm/data/deserializers/synthetic.py +346 -0
  25. guidellm/data/loaders.py +145 -0
  26. guidellm/data/preprocessors/__init__.py +25 -0
  27. guidellm/data/preprocessors/formatters.py +412 -0
  28. guidellm/data/preprocessors/mappers.py +198 -0
  29. guidellm/data/preprocessors/preprocessor.py +29 -0
  30. guidellm/data/processor.py +30 -0
  31. guidellm/data/schemas.py +13 -0
  32. guidellm/data/utils/__init__.py +10 -0
  33. guidellm/data/utils/dataset.py +94 -0
  34. guidellm/data/utils/functions.py +18 -0
  35. guidellm/extras/__init__.py +4 -0
  36. guidellm/extras/audio.py +215 -0
  37. guidellm/extras/vision.py +242 -0
  38. guidellm/logger.py +2 -2
  39. guidellm/mock_server/__init__.py +8 -0
  40. guidellm/mock_server/config.py +84 -0
  41. guidellm/mock_server/handlers/__init__.py +17 -0
  42. guidellm/mock_server/handlers/chat_completions.py +280 -0
  43. guidellm/mock_server/handlers/completions.py +280 -0
  44. guidellm/mock_server/handlers/tokenizer.py +142 -0
  45. guidellm/mock_server/models.py +510 -0
  46. guidellm/mock_server/server.py +168 -0
  47. guidellm/mock_server/utils.py +302 -0
  48. guidellm/preprocess/dataset.py +23 -26
  49. guidellm/presentation/builder.py +2 -2
  50. guidellm/presentation/data_models.py +25 -21
  51. guidellm/presentation/injector.py +2 -3
  52. guidellm/scheduler/__init__.py +65 -26
  53. guidellm/scheduler/constraints.py +1035 -0
  54. guidellm/scheduler/environments.py +252 -0
  55. guidellm/scheduler/scheduler.py +140 -368
  56. guidellm/scheduler/schemas.py +272 -0
  57. guidellm/scheduler/strategies.py +519 -0
  58. guidellm/scheduler/worker.py +391 -420
  59. guidellm/scheduler/worker_group.py +707 -0
  60. guidellm/schemas/__init__.py +31 -0
  61. guidellm/schemas/info.py +159 -0
  62. guidellm/schemas/request.py +216 -0
  63. guidellm/schemas/response.py +119 -0
  64. guidellm/schemas/stats.py +228 -0
  65. guidellm/{config.py → settings.py} +32 -21
  66. guidellm/utils/__init__.py +95 -8
  67. guidellm/utils/auto_importer.py +98 -0
  68. guidellm/utils/cli.py +46 -2
  69. guidellm/utils/console.py +183 -0
  70. guidellm/utils/encoding.py +778 -0
  71. guidellm/utils/functions.py +134 -0
  72. guidellm/utils/hf_datasets.py +1 -2
  73. guidellm/utils/hf_transformers.py +4 -4
  74. guidellm/utils/imports.py +9 -0
  75. guidellm/utils/messaging.py +1118 -0
  76. guidellm/utils/mixins.py +115 -0
  77. guidellm/utils/pydantic_utils.py +411 -0
  78. guidellm/utils/random.py +3 -4
  79. guidellm/utils/registry.py +220 -0
  80. guidellm/utils/singleton.py +133 -0
  81. guidellm/{objects → utils}/statistics.py +341 -247
  82. guidellm/utils/synchronous.py +159 -0
  83. guidellm/utils/text.py +163 -50
  84. guidellm/utils/typing.py +41 -0
  85. guidellm/version.py +1 -1
  86. {guidellm-0.4.0a18.dist-info → guidellm-0.4.0a155.dist-info}/METADATA +33 -10
  87. guidellm-0.4.0a155.dist-info/RECORD +96 -0
  88. guidellm/backend/__init__.py +0 -23
  89. guidellm/backend/backend.py +0 -259
  90. guidellm/backend/openai.py +0 -705
  91. guidellm/backend/response.py +0 -136
  92. guidellm/benchmark/aggregator.py +0 -760
  93. guidellm/benchmark/benchmark.py +0 -837
  94. guidellm/benchmark/scenario.py +0 -104
  95. guidellm/data/prideandprejudice.txt.gz +0 -0
  96. guidellm/dataset/__init__.py +0 -22
  97. guidellm/dataset/creator.py +0 -213
  98. guidellm/dataset/entrypoints.py +0 -42
  99. guidellm/dataset/file.py +0 -92
  100. guidellm/dataset/hf_datasets.py +0 -62
  101. guidellm/dataset/in_memory.py +0 -132
  102. guidellm/dataset/synthetic.py +0 -287
  103. guidellm/objects/__init__.py +0 -18
  104. guidellm/objects/pydantic.py +0 -89
  105. guidellm/request/__init__.py +0 -18
  106. guidellm/request/loader.py +0 -284
  107. guidellm/request/request.py +0 -79
  108. guidellm/request/types.py +0 -10
  109. guidellm/scheduler/queues.py +0 -25
  110. guidellm/scheduler/result.py +0 -155
  111. guidellm/scheduler/strategy.py +0 -495
  112. guidellm-0.4.0a18.dist-info/RECORD +0 -62
  113. {guidellm-0.4.0a18.dist-info → guidellm-0.4.0a155.dist-info}/WHEEL +0 -0
  114. {guidellm-0.4.0a18.dist-info → guidellm-0.4.0a155.dist-info}/entry_points.txt +0 -0
  115. {guidellm-0.4.0a18.dist-info → guidellm-0.4.0a155.dist-info}/licenses/LICENSE +0 -0
  116. {guidellm-0.4.0a18.dist-info → guidellm-0.4.0a155.dist-info}/top_level.txt +0 -0
guidellm/data/__init__.py CHANGED
@@ -1,4 +1,28 @@
-"""
-Required for python < 3.12
-https://docs.python.org/3/library/importlib.resources.html#importlib.resources.files
-"""
+from .collators import GenerativeRequestCollator
+from .deserializers import (
+    DataNotSupportedError,
+    DatasetDeserializer,
+    DatasetDeserializerFactory,
+)
+from .loaders import DataLoader, DatasetsIterator
+from .preprocessors import (
+    DataDependentPreprocessor,
+    DatasetPreprocessor,
+    PreprocessorRegistry,
+)
+from .processor import ProcessorFactory
+from .schemas import GenerativeDatasetColumnType
+
+__all__ = [
+    "DataDependentPreprocessor",
+    "DataLoader",
+    "DataNotSupportedError",
+    "DatasetDeserializer",
+    "DatasetDeserializerFactory",
+    "DatasetPreprocessor",
+    "DatasetsIterator",
+    "GenerativeDatasetColumnType",
+    "GenerativeRequestCollator",
+    "PreprocessorRegistry",
+    "ProcessorFactory",
+]
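
guidellm/data's __init__ previously carried only a compatibility docstring; it now re-exports the full data pipeline. A quick smoke test of the new import surface, assuming this 0.4.0a155 build is installed (importing the package also registers the built-in deserializers as a side effect):

from guidellm.data import (
    DataLoader,
    DatasetDeserializerFactory,
    GenerativeRequestCollator,
)

# Registered format names, e.g. "csv_file", "json_file", "huggingface"
print(DatasetDeserializerFactory.registry.keys())
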
guidellm/data/collators.py ADDED
@@ -0,0 +1,16 @@
+from __future__ import annotations
+
+from guidellm.schemas import GenerationRequest
+
+__all__ = ["GenerativeRequestCollator"]
+
+
+class GenerativeRequestCollator:
+    def __call__(self, batch: list) -> GenerationRequest:
+        if len(batch) != 1:
+            raise NotImplementedError(
+                f"Batch size greater than 1 is not currently supported. "
+                f"Got batch size: {len(batch)}"
+            )
+
+        return batch[0]
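
The collator is deliberately a pass-through: it unwraps a single-element batch and rejects anything larger. A minimal behavioral sketch (the dict below is a stand-in for a real GenerationRequest; the collator does not enforce the annotation at runtime):

from guidellm.data import GenerativeRequestCollator

collator = GenerativeRequestCollator()

request = {"prompt": "Hello"}  # stand-in for a GenerationRequest instance
assert collator([request]) is request  # single-element batches are unwrapped

try:
    collator([request, request])  # any other batch size is rejected
except NotImplementedError as err:
    print(err)
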
guidellm/data/deserializers/__init__.py ADDED
@@ -0,0 +1,53 @@
+from .deserializer import (
+    DataNotSupportedError,
+    DatasetDeserializer,
+    DatasetDeserializerFactory,
+)
+from .file import (
+    ArrowFileDatasetDeserializer,
+    CSVFileDatasetDeserializer,
+    DBFileDatasetDeserializer,
+    HDF5FileDatasetDeserializer,
+    JSONFileDatasetDeserializer,
+    ParquetFileDatasetDeserializer,
+    TarFileDatasetDeserializer,
+    TextFileDatasetDeserializer,
+)
+from .huggingface import HuggingFaceDatasetDeserializer
+from .memory import (
+    InMemoryCsvDatasetDeserializer,
+    InMemoryDictDatasetDeserializer,
+    InMemoryDictListDatasetDeserializer,
+    InMemoryItemListDatasetDeserializer,
+    InMemoryJsonStrDatasetDeserializer,
+)
+from .synthetic import (
+    SyntheticTextDatasetConfig,
+    SyntheticTextDatasetDeserializer,
+    SyntheticTextGenerator,
+    SyntheticTextPrefixBucketConfig,
+)
+
+__all__ = [
+    "ArrowFileDatasetDeserializer",
+    "CSVFileDatasetDeserializer",
+    "DBFileDatasetDeserializer",
+    "DataNotSupportedError",
+    "DatasetDeserializer",
+    "DatasetDeserializerFactory",
+    "HDF5FileDatasetDeserializer",
+    "HuggingFaceDatasetDeserializer",
+    "InMemoryCsvDatasetDeserializer",
+    "InMemoryDictDatasetDeserializer",
+    "InMemoryDictListDatasetDeserializer",
+    "InMemoryItemListDatasetDeserializer",
+    "InMemoryJsonStrDatasetDeserializer",
+    "JSONFileDatasetDeserializer",
+    "ParquetFileDatasetDeserializer",
+    "SyntheticTextDatasetConfig",
+    "SyntheticTextDatasetDeserializer",
+    "SyntheticTextGenerator",
+    "SyntheticTextPrefixBucketConfig",
+    "TarFileDatasetDeserializer",
+    "TextFileDatasetDeserializer",
+]
guidellm/data/deserializers/deserializer.py ADDED
@@ -0,0 +1,109 @@
+from __future__ import annotations
+
+import contextlib
+from collections.abc import Callable
+from typing import Any, Protocol, Union, runtime_checkable
+
+from datasets import Dataset, IterableDataset
+from transformers import PreTrainedTokenizerBase
+
+from guidellm.data.utils import resolve_dataset_split
+from guidellm.utils import RegistryMixin
+
+__all__ = [
+    "DataNotSupportedError",
+    "DatasetDeserializer",
+    "DatasetDeserializerFactory",
+]
+
+
+class DataNotSupportedError(Exception):
+    """Exception raised when data format is not supported by deserializer."""
+
+
+@runtime_checkable
+class DatasetDeserializer(Protocol):
+    def __call__(
+        self,
+        data: Any,
+        processor_factory: Callable[[], PreTrainedTokenizerBase],
+        random_seed: int,
+        **data_kwargs: dict[str, Any],
+    ) -> dict[str, list]: ...
+
+
+class DatasetDeserializerFactory(
+    RegistryMixin[Union["type[DatasetDeserializer]", DatasetDeserializer]],
+):
+    @classmethod
+    def deserialize(
+        cls,
+        data: Any,
+        processor_factory: Callable[[], PreTrainedTokenizerBase],
+        random_seed: int = 42,
+        type_: str | None = None,
+        resolve_split: bool = True,
+        select_columns: list[str] | None = None,
+        remove_columns: list[str] | None = None,
+        **data_kwargs: dict[str, Any],
+    ) -> Dataset | IterableDataset:
+        dataset = None
+
+        if type_ is None:
+            errors = []
+            # Note: There is no priority order for the deserializers, so all
+            # deserializers must be mutually exclusive for deterministic behavior.
+            for name, deserializer in cls.registry.items():
+                deserializer_fn: DatasetDeserializer = (
+                    deserializer() if isinstance(deserializer, type) else deserializer
+                )
+
+                try:
+                    with contextlib.suppress(DataNotSupportedError):
+                        dataset = deserializer_fn(
+                            data=data,
+                            processor_factory=processor_factory,
+                            random_seed=random_seed,
+                            **data_kwargs,
+                        )
+                except Exception as e:
+                    errors.append(e)
+
+                if dataset is not None:
+                    break  # Found one that works. Continuing could overwrite it.
+
+            if dataset is None and len(errors) > 0:
+                raise DataNotSupportedError(f"data deserialization failed; {len(errors)} errors occurred while "
+                    f"attempting to deserialize data {data}: {errors}")
+
+        elif (deserializer := cls.get_registered_object(type_)) is not None:
+            deserializer_fn: DatasetDeserializer = (
+                deserializer() if isinstance(deserializer, type) else deserializer
+            )
+
+            dataset = deserializer_fn(
+                data=data,
+                processor_factory=processor_factory,
+                random_seed=random_seed,
+                **data_kwargs,
+            )
+
+        if dataset is None:
+            raise DataNotSupportedError(
+                f"No suitable deserializer found for data {data} "
+                f"with kwargs {data_kwargs} and deserializer type {type_}."
+            )
+
+        if resolve_split:
+            dataset = resolve_dataset_split(dataset)
+
+        if select_columns is not None or remove_columns is not None:
+            column_names = dataset.column_names or list(next(iter(dataset)).keys())
+            if select_columns is not None:
+                remove_columns = [
+                    col for col in column_names if col not in select_columns
+                ]
+
+            dataset = dataset.remove_columns(remove_columns)
+
+        return dataset
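
DatasetDeserializerFactory is a RegistryMixin, so new formats plug in with the same register decorator the built-in deserializers use below. A hedged sketch of a custom deserializer; the "string_list" name and the class body are hypothetical, not part of the package:

from collections.abc import Callable
from typing import Any

from datasets import Dataset
from transformers import PreTrainedTokenizerBase

from guidellm.data.deserializers import (
    DataNotSupportedError,
    DatasetDeserializerFactory,
)


@DatasetDeserializerFactory.register("string_list")  # hypothetical format name
class StringListDatasetDeserializer:
    def __call__(
        self,
        data: Any,
        processor_factory: Callable[[], PreTrainedTokenizerBase],
        random_seed: int,
        **data_kwargs: dict[str, Any],
    ) -> Dataset:
        _ = (processor_factory, random_seed, data_kwargs)
        # Raising DataNotSupportedError signals "not mine" so the factory's
        # auto-detection loop moves on to the next registered deserializer.
        if not isinstance(data, list) or not all(isinstance(x, str) for x in data):
            raise DataNotSupportedError("expected a list of strings")
        return Dataset.from_dict({"text": data})

Because deserialize probes registered deserializers in no guaranteed order when type_ is None, a custom deserializer must raise DataNotSupportedError for any input it does not own, exactly as the note inside the loop warns.
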
guidellm/data/deserializers/file.py ADDED
@@ -0,0 +1,222 @@
+from __future__ import annotations
+
+from collections.abc import Callable
+from pathlib import Path
+from typing import Any
+
+import pandas as pd
+from datasets import Dataset, load_dataset
+from transformers import PreTrainedTokenizerBase
+
+from guidellm.data.deserializers.deserializer import (
+    DataNotSupportedError,
+    DatasetDeserializer,
+    DatasetDeserializerFactory,
+)
+
+__all__ = [
+    "ArrowFileDatasetDeserializer",
+    "CSVFileDatasetDeserializer",
+    "DBFileDatasetDeserializer",
+    "HDF5FileDatasetDeserializer",
+    "JSONFileDatasetDeserializer",
+    "ParquetFileDatasetDeserializer",
+    "TarFileDatasetDeserializer",
+    "TextFileDatasetDeserializer",
+]
+
+
+@DatasetDeserializerFactory.register("text_file")
+class TextFileDatasetDeserializer(DatasetDeserializer):
+    def __call__(
+        self,
+        data: Any,
+        processor_factory: Callable[[], PreTrainedTokenizerBase],
+        random_seed: int,
+        **data_kwargs: dict[str, Any],
+    ) -> dict[str, list]:
+        _ = (processor_factory, random_seed)  # Ignore unused args format errors
+
+        if (
+            not isinstance(data, (str, Path))
+            or not (path := Path(data)).exists()
+            or not path.is_file()
+            or path.suffix.lower() not in {".txt", ".text"}
+        ):
+            raise DataNotSupportedError(
+                "Unsupported data for TextFileDatasetDeserializer, "
+                f"expected str or Path to a local .txt or .text file, got {data}"
+            )
+
+        with path.open() as file:
+            lines = file.readlines()
+
+        return Dataset.from_dict({"text": lines}, **data_kwargs)
+
+
+@DatasetDeserializerFactory.register("csv_file")
+class CSVFileDatasetDeserializer(DatasetDeserializer):
+    def __call__(
+        self,
+        data: Any,
+        processor_factory: Callable[[], PreTrainedTokenizerBase],
+        random_seed: int,
+        **data_kwargs: dict[str, Any],
+    ) -> dict[str, list]:
+        _ = (processor_factory, random_seed)
+        if (
+            not isinstance(data, (str, Path))
+            or not (path := Path(data)).exists()
+            or not path.is_file()
+            or path.suffix.lower() != ".csv"
+        ):
+            raise DataNotSupportedError(
+                "Unsupported data for CSVFileDatasetDeserializer, "
+                f"expected str or Path to a local .csv file, got {data}"
+            )
+
+        return load_dataset("csv", data_files=str(path), **data_kwargs)
+
+
+@DatasetDeserializerFactory.register("json_file")
+class JSONFileDatasetDeserializer(DatasetDeserializer):
+    def __call__(
+        self,
+        data: Any,
+        processor_factory: Callable[[], PreTrainedTokenizerBase],
+        random_seed: int,
+        **data_kwargs: dict[str, Any],
+    ) -> dict[str, list]:
+        _ = (processor_factory, random_seed)
+        if (
+            not isinstance(data, (str, Path))
+            or not (path := Path(data)).exists()
+            or not path.is_file()
+            or path.suffix.lower() not in {".json", ".jsonl"}
+        ):
+            raise DataNotSupportedError(
+                f"Unsupported data for JSONFileDatasetDeserializer, "
+                f"expected str or Path to a local .json or .jsonl file, got {data}"
+            )
+
+        return load_dataset("json", data_files=str(path), **data_kwargs)
+
+
+@DatasetDeserializerFactory.register("parquet_file")
+class ParquetFileDatasetDeserializer(DatasetDeserializer):
+    def __call__(
+        self,
+        data: Any,
+        processor_factory: Callable[[], PreTrainedTokenizerBase],
+        random_seed: int,
+        **data_kwargs: dict[str, Any],
+    ) -> dict[str, list]:
+        _ = (processor_factory, random_seed)
+        if (
+            not isinstance(data, (str, Path))
+            or not (path := Path(data)).exists()
+            or not path.is_file()
+            or path.suffix.lower() != ".parquet"
+        ):
+            raise DataNotSupportedError(
+                f"Unsupported data for ParquetFileDatasetDeserializer, "
+                f"expected str or Path to a local .parquet file, got {data}"
+            )
+
+        return load_dataset("parquet", data_files=str(path), **data_kwargs)
+
+
+@DatasetDeserializerFactory.register("arrow_file")
+class ArrowFileDatasetDeserializer(DatasetDeserializer):
+    def __call__(
+        self,
+        data: Any,
+        processor_factory: Callable[[], PreTrainedTokenizerBase],
+        random_seed: int,
+        **data_kwargs: dict[str, Any],
+    ) -> dict[str, list]:
+        _ = (processor_factory, random_seed)
+        if (
+            not isinstance(data, (str, Path))
+            or not (path := Path(data)).exists()
+            or not path.is_file()
+            or path.suffix.lower() != ".arrow"
+        ):
+            raise DataNotSupportedError(
+                f"Unsupported data for ArrowFileDatasetDeserializer, "
+                f"expected str or Path to a local .arrow file, got {data}"
+            )
+
+        return load_dataset("arrow", data_files=str(path), **data_kwargs)
+
+
+@DatasetDeserializerFactory.register("hdf5_file")
+class HDF5FileDatasetDeserializer(DatasetDeserializer):
+    def __call__(
+        self,
+        data: Any,
+        processor_factory: Callable[[], PreTrainedTokenizerBase],
+        random_seed: int,
+        **data_kwargs: dict[str, Any],
+    ) -> dict[str, list]:
+        _ = (processor_factory, random_seed)
+        if (
+            not isinstance(data, (str, Path))
+            or not (path := Path(data)).exists()
+            or not path.is_file()
+            or path.suffix.lower() not in {".hdf5", ".h5"}
+        ):
+            raise DataNotSupportedError(
+                f"Unsupported data for HDF5FileDatasetDeserializer, "
+                f"expected str or Path to a local .hdf5 or .h5 file, got {data}"
+            )
+
+        return Dataset.from_pandas(pd.read_hdf(str(path)), **data_kwargs)
+
+
+@DatasetDeserializerFactory.register("db_file")
+class DBFileDatasetDeserializer(DatasetDeserializer):
+    def __call__(
+        self,
+        data: Any,
+        processor_factory: Callable[[], PreTrainedTokenizerBase],
+        random_seed: int,
+        **data_kwargs: dict[str, Any],
+    ) -> dict[str, list]:
+        _ = (processor_factory, random_seed)
+        if (
+            not isinstance(data, (str, Path))
+            or not (path := Path(data)).exists()
+            or not path.is_file()
+            or path.suffix.lower() != ".db"
+        ):
+            raise DataNotSupportedError(
+                f"Unsupported data for DBFileDatasetDeserializer, "
+                f"expected str or Path to a local .db file, got {data}"
+            )
+
+        return Dataset.from_sql(con=str(path), **data_kwargs)
+
+
+@DatasetDeserializerFactory.register("tar_file")
+class TarFileDatasetDeserializer(DatasetDeserializer):
+    def __call__(
+        self,
+        data: Any,
+        processor_factory: Callable[[], PreTrainedTokenizerBase],
+        random_seed: int,
+        **data_kwargs: dict[str, Any],
+    ) -> dict[str, list]:
+        _ = (processor_factory, random_seed)
+        if (
+            not isinstance(data, (str, Path))
+            or not (path := Path(data)).exists()
+            or not path.is_file()
+            or path.suffix.lower() != ".tar"
+        ):
+            raise DataNotSupportedError(
+                f"Unsupported data for TarFileDatasetDeserializer, "
+                f"expected str or Path to a local .tar file, got {data}"
+            )
+
+        return load_dataset("webdataset", data_files=str(path), **data_kwargs)
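
End users typically go through the factory rather than instantiating these classes directly. A sketch assuming a local prompts.csv (hypothetical file) and a tokenizer factory; with the default resolve_split=True, the DatasetDict returned by load_dataset is reduced to a single split via resolve_dataset_split:

from functools import partial

from transformers import AutoTokenizer

from guidellm.data.deserializers import DatasetDeserializerFactory

dataset = DatasetDeserializerFactory.deserialize(
    "prompts.csv",  # hypothetical local file
    processor_factory=partial(AutoTokenizer.from_pretrained, "gpt2"),
    type_="csv_file",  # dispatch directly instead of probing every deserializer
)
print(dataset)
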
guidellm/data/deserializers/huggingface.py ADDED
@@ -0,0 +1,94 @@
+from __future__ import annotations
+
+from collections.abc import Callable
+from pathlib import Path
+from typing import Any
+
+from datasets import (
+    Dataset,
+    DatasetDict,
+    IterableDataset,
+    IterableDatasetDict,
+    load_dataset,
+    load_from_disk,
+)
+from datasets.exceptions import (
+    DataFilesNotFoundError,
+    DatasetNotFoundError,
+    FileNotFoundDatasetsError,
+)
+from transformers import PreTrainedTokenizerBase
+
+from guidellm.data.deserializers.deserializer import (
+    DataNotSupportedError,
+    DatasetDeserializer,
+    DatasetDeserializerFactory,
+)
+
+__all__ = ["HuggingFaceDatasetDeserializer"]
+
+
+@DatasetDeserializerFactory.register("huggingface")
+class HuggingFaceDatasetDeserializer(DatasetDeserializer):
+    def __call__(
+        self,
+        data: Any,
+        processor_factory: Callable[[], PreTrainedTokenizerBase],
+        random_seed: int,
+        **data_kwargs: dict[str, Any],
+    ) -> dict[str, list]:
+        _ = (processor_factory, random_seed)
+
+        if isinstance(
+            data, Dataset | IterableDataset | DatasetDict | IterableDatasetDict
+        ):
+            return data
+
+        load_error = None
+
+        if (
+            isinstance(data, str | Path)
+            and (path := Path(data)).exists()
+            and ((path.is_file() and path.suffix == ".py") or path.is_dir())
+        ):
+            # Handle python script or nested python script in a directory
+            try:
+                return load_dataset(str(data), **data_kwargs)
+            except (
+                FileNotFoundDatasetsError,
+                DatasetNotFoundError,
+                DataFilesNotFoundError,
+            ) as err:
+                load_error = err
+            except Exception:  # noqa: BLE001
+                # Try loading as a local dataset directory next
+                try:
+                    return load_from_disk(str(data), **data_kwargs)
+                except (
+                    FileNotFoundDatasetsError,
+                    DatasetNotFoundError,
+                    DataFilesNotFoundError,
+                ) as err2:
+                    load_error = err2
+
+        try:
+            # Handle dataset identifier from the Hugging Face Hub
+            return load_dataset(str(data), **data_kwargs)
+        except (
+            FileNotFoundDatasetsError,
+            DatasetNotFoundError,
+            DataFilesNotFoundError,
+        ) as err:
+            load_error = err
+
+        not_supported = DataNotSupportedError(
+            "Unsupported data for HuggingFaceDatasetDeserializer, "
+            "expected Dataset, IterableDataset, DatasetDict, IterableDatasetDict, "
+            "str or Path to a local dataset directory or a local .py dataset script, "
+            f"got {data} and HF load error: {load_error}"
+        )
+
+        if load_error is not None:
+            raise not_supported from load_error
+        else:
+            raise not_supported
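
The Hugging Face deserializer accepts three shapes of input, tried in order: objects that are already datasets types, local script or directory paths, and Hub identifiers. A sketch of the last case, assuming network access to the Hub; "rotten_tomatoes" is just a small public dataset chosen for illustration:

from transformers import AutoTokenizer

from guidellm.data.deserializers import HuggingFaceDatasetDeserializer

deserializer = HuggingFaceDatasetDeserializer()
dataset_dict = deserializer(
    "rotten_tomatoes",  # Hub identifier; extra kwargs are forwarded to load_dataset
    processor_factory=lambda: AutoTokenizer.from_pretrained("gpt2"),
    random_seed=42,
)
print(dataset_dict)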