guidellm 0.4.0a21__py3-none-any.whl → 0.4.0a155__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of guidellm might be problematic. Click here for more details.

Files changed (116)
  1. guidellm/__init__.py +5 -2
  2. guidellm/__main__.py +451 -252
  3. guidellm/backends/__init__.py +33 -0
  4. guidellm/backends/backend.py +110 -0
  5. guidellm/backends/openai.py +355 -0
  6. guidellm/backends/response_handlers.py +455 -0
  7. guidellm/benchmark/__init__.py +53 -39
  8. guidellm/benchmark/benchmarker.py +148 -317
  9. guidellm/benchmark/entrypoints.py +466 -128
  10. guidellm/benchmark/output.py +517 -771
  11. guidellm/benchmark/profile.py +580 -280
  12. guidellm/benchmark/progress.py +568 -549
  13. guidellm/benchmark/scenarios/__init__.py +40 -0
  14. guidellm/benchmark/scenarios/chat.json +6 -0
  15. guidellm/benchmark/scenarios/rag.json +6 -0
  16. guidellm/benchmark/schemas.py +2085 -0
  17. guidellm/data/__init__.py +28 -4
  18. guidellm/data/collators.py +16 -0
  19. guidellm/data/deserializers/__init__.py +53 -0
  20. guidellm/data/deserializers/deserializer.py +109 -0
  21. guidellm/data/deserializers/file.py +222 -0
  22. guidellm/data/deserializers/huggingface.py +94 -0
  23. guidellm/data/deserializers/memory.py +192 -0
  24. guidellm/data/deserializers/synthetic.py +346 -0
  25. guidellm/data/loaders.py +145 -0
  26. guidellm/data/preprocessors/__init__.py +25 -0
  27. guidellm/data/preprocessors/formatters.py +412 -0
  28. guidellm/data/preprocessors/mappers.py +198 -0
  29. guidellm/data/preprocessors/preprocessor.py +29 -0
  30. guidellm/data/processor.py +30 -0
  31. guidellm/data/schemas.py +13 -0
  32. guidellm/data/utils/__init__.py +10 -0
  33. guidellm/data/utils/dataset.py +94 -0
  34. guidellm/data/utils/functions.py +18 -0
  35. guidellm/extras/__init__.py +4 -0
  36. guidellm/extras/audio.py +215 -0
  37. guidellm/extras/vision.py +242 -0
  38. guidellm/logger.py +2 -2
  39. guidellm/mock_server/__init__.py +8 -0
  40. guidellm/mock_server/config.py +84 -0
  41. guidellm/mock_server/handlers/__init__.py +17 -0
  42. guidellm/mock_server/handlers/chat_completions.py +280 -0
  43. guidellm/mock_server/handlers/completions.py +280 -0
  44. guidellm/mock_server/handlers/tokenizer.py +142 -0
  45. guidellm/mock_server/models.py +510 -0
  46. guidellm/mock_server/server.py +168 -0
  47. guidellm/mock_server/utils.py +302 -0
  48. guidellm/preprocess/dataset.py +23 -26
  49. guidellm/presentation/builder.py +2 -2
  50. guidellm/presentation/data_models.py +25 -21
  51. guidellm/presentation/injector.py +2 -3
  52. guidellm/scheduler/__init__.py +65 -26
  53. guidellm/scheduler/constraints.py +1035 -0
  54. guidellm/scheduler/environments.py +252 -0
  55. guidellm/scheduler/scheduler.py +140 -368
  56. guidellm/scheduler/schemas.py +272 -0
  57. guidellm/scheduler/strategies.py +519 -0
  58. guidellm/scheduler/worker.py +391 -420
  59. guidellm/scheduler/worker_group.py +707 -0
  60. guidellm/schemas/__init__.py +31 -0
  61. guidellm/schemas/info.py +159 -0
  62. guidellm/schemas/request.py +216 -0
  63. guidellm/schemas/response.py +119 -0
  64. guidellm/schemas/stats.py +228 -0
  65. guidellm/{config.py → settings.py} +32 -21
  66. guidellm/utils/__init__.py +95 -8
  67. guidellm/utils/auto_importer.py +98 -0
  68. guidellm/utils/cli.py +46 -2
  69. guidellm/utils/console.py +183 -0
  70. guidellm/utils/encoding.py +778 -0
  71. guidellm/utils/functions.py +134 -0
  72. guidellm/utils/hf_datasets.py +1 -2
  73. guidellm/utils/hf_transformers.py +4 -4
  74. guidellm/utils/imports.py +9 -0
  75. guidellm/utils/messaging.py +1118 -0
  76. guidellm/utils/mixins.py +115 -0
  77. guidellm/utils/pydantic_utils.py +411 -0
  78. guidellm/utils/random.py +3 -4
  79. guidellm/utils/registry.py +220 -0
  80. guidellm/utils/singleton.py +133 -0
  81. guidellm/{objects → utils}/statistics.py +341 -247
  82. guidellm/utils/synchronous.py +159 -0
  83. guidellm/utils/text.py +163 -50
  84. guidellm/utils/typing.py +41 -0
  85. guidellm/version.py +1 -1
  86. {guidellm-0.4.0a21.dist-info → guidellm-0.4.0a155.dist-info}/METADATA +33 -10
  87. guidellm-0.4.0a155.dist-info/RECORD +96 -0
  88. guidellm/backend/__init__.py +0 -23
  89. guidellm/backend/backend.py +0 -259
  90. guidellm/backend/openai.py +0 -705
  91. guidellm/backend/response.py +0 -136
  92. guidellm/benchmark/aggregator.py +0 -760
  93. guidellm/benchmark/benchmark.py +0 -837
  94. guidellm/benchmark/scenario.py +0 -104
  95. guidellm/data/prideandprejudice.txt.gz +0 -0
  96. guidellm/dataset/__init__.py +0 -22
  97. guidellm/dataset/creator.py +0 -213
  98. guidellm/dataset/entrypoints.py +0 -42
  99. guidellm/dataset/file.py +0 -92
  100. guidellm/dataset/hf_datasets.py +0 -62
  101. guidellm/dataset/in_memory.py +0 -132
  102. guidellm/dataset/synthetic.py +0 -287
  103. guidellm/objects/__init__.py +0 -18
  104. guidellm/objects/pydantic.py +0 -89
  105. guidellm/request/__init__.py +0 -18
  106. guidellm/request/loader.py +0 -284
  107. guidellm/request/request.py +0 -79
  108. guidellm/request/types.py +0 -10
  109. guidellm/scheduler/queues.py +0 -25
  110. guidellm/scheduler/result.py +0 -155
  111. guidellm/scheduler/strategy.py +0 -495
  112. guidellm-0.4.0a21.dist-info/RECORD +0 -62
  113. {guidellm-0.4.0a21.dist-info → guidellm-0.4.0a155.dist-info}/WHEEL +0 -0
  114. {guidellm-0.4.0a21.dist-info → guidellm-0.4.0a155.dist-info}/entry_points.txt +0 -0
  115. {guidellm-0.4.0a21.dist-info → guidellm-0.4.0a155.dist-info}/licenses/LICENSE +0 -0
  116. {guidellm-0.4.0a21.dist-info → guidellm-0.4.0a155.dist-info}/top_level.txt +0 -0
@@ -1,287 +0,0 @@
1
- import json
2
- import random
3
- from collections.abc import Iterable, Iterator
4
- from itertools import cycle
5
- from pathlib import Path
6
- from typing import Any, Literal, Optional, Union
7
-
8
- import yaml
9
- from datasets import (
10
- Dataset,
11
- DatasetDict,
12
- IterableDataset,
13
- IterableDatasetDict,
14
- )
15
- from pydantic import BaseModel, Field
16
- from transformers import PreTrainedTokenizerBase # type: ignore[import]
17
-
18
- from guidellm.dataset.creator import ColumnInputTypes, DatasetCreator
19
- from guidellm.utils import EndlessTextCreator, IntegerRangeSampler, check_load_processor
20
-
21
- __all__ = [
22
- "SyntheticDatasetConfig",
23
- "SyntheticDatasetCreator",
24
- "SyntheticTextItemsGenerator",
25
- ]
26
-
27
-
28
class SyntheticDatasetConfig(BaseModel):
    """
    Configuration for generating a synthetic text dataset.

    Prompt and output lengths are described by an average token count,
    optionally with a standard deviation and min/max bounds. ``prefix_tokens``
    adds a shared prefix to every prompt. Instances can be parsed from a config
    file path, a JSON object string, or comma-separated ``key=value`` pairs via
    :meth:`parse_str`.
    """

    prefix_tokens: int = Field(
        description="The number of shared prefix tokens to prepend to each prompt.",
        ge=0,
        default=0,
    )
    prompt_tokens: int = Field(
        description="The average number of text tokens generated for prompts.",
        gt=0,
    )
    prompt_tokens_stdev: Optional[int] = Field(
        description="The standard deviation of the tokens generated for prompts.",
        gt=0,
        default=None,
    )
    prompt_tokens_min: Optional[int] = Field(
        description="The minimum number of text tokens generated for prompts.",
        gt=0,
        default=None,
    )
    prompt_tokens_max: Optional[int] = Field(
        description="The maximum number of text tokens generated for prompts.",
        gt=0,
        default=None,
    )
    output_tokens: int = Field(
        description="The average number of text tokens generated for outputs.",
        gt=0,
    )
    output_tokens_stdev: Optional[int] = Field(
        description="The standard deviation of the tokens generated for outputs.",
        gt=0,
        default=None,
    )
    output_tokens_min: Optional[int] = Field(
        description="The minimum number of text tokens generated for outputs.",
        gt=0,
        default=None,
    )
    output_tokens_max: Optional[int] = Field(
        description="The maximum number of text tokens generated for outputs.",
        gt=0,
        default=None,
    )
    samples: int = Field(
        description="The number of samples to generate for the dataset.",
        gt=0,
        default=1000,
    )
    source: str = Field(
        description="The source of the text data to be used for generation.",
        default="data:prideandprejudice.txt.gz",
    )

    @staticmethod
    def parse_str(data: Union[str, Path]) -> "SyntheticDatasetConfig":
        """
        Parse a config from a file path, a JSON string, or key=value pairs.

        :param data: A ``Path``; a string ending in ``.config`` or ``.yaml``
            (treated as a file path); a JSON object string; or comma-separated
            ``key=value`` pairs such as ``"prompt_tokens=128,output_tokens=64"``.
        :return: The parsed configuration.
        :raises ValueError: If ``data`` matches none of the supported formats.
        """
        if isinstance(data, Path):
            return SyntheticDatasetConfig.parse_config_file(data)

        # Strip once so a path with stray surrounding whitespace still opens.
        text = data.strip()
        if text.endswith((".config", ".yaml")):
            return SyntheticDatasetConfig.parse_config_file(text)
        if text.startswith("{"):
            return SyntheticDatasetConfig.parse_json(text)
        if text.count("=") > 1:
            return SyntheticDatasetConfig.parse_key_value_pairs(text)

        raise ValueError(
            f"Unsupported data format. Expected JSON or key-value pairs, got {data}"
        )

    @staticmethod
    def parse_json(data: str) -> "SyntheticDatasetConfig":
        """
        Parse a config from a JSON object string.

        :param data: JSON text whose top level is an object of config fields.
        :return: The parsed configuration.
        :raises ValueError: If the JSON is malformed or its top level is not
            an object.
        """
        config_dict = json.loads(data.strip())
        if not isinstance(config_dict, dict):
            raise ValueError(
                f"Expected a JSON object for the synthetic config, got {data}"
            )

        return SyntheticDatasetConfig(**config_dict)

    @staticmethod
    def parse_key_value_pairs(data: str) -> "SyntheticDatasetConfig":
        """
        Parse a config from comma-separated ``key=value`` pairs.

        Values consisting only of decimal digits are converted to ``int``;
        anything else is kept as a string. Empty items (e.g. from a trailing
        comma) are ignored. Only the first ``=`` in a pair separates key and
        value, so values may themselves contain ``=``.

        :param data: The pairs, e.g. ``"prompt_tokens=128,output_tokens=64"``.
        :return: The parsed configuration.
        :raises ValueError: If a non-empty item has no ``=``.
        """
        config_dict: dict[str, Any] = {}
        for item in data.strip().split(","):
            if not item.strip():
                continue

            key, separator, value = item.partition("=")
            if not separator:
                raise ValueError(f"Invalid key-value pair (missing '='): {item!r}")

            key, value = key.strip(), value.strip()
            # isdecimal (unlike isnumeric) only accepts characters int() can parse.
            config_dict[key] = int(value) if value.isdecimal() else value

        return SyntheticDatasetConfig(**config_dict)  # type: ignore[arg-type]

    @staticmethod
    def parse_config_file(data: Union[str, Path]) -> "SyntheticDatasetConfig":
        """
        Parse a config from a YAML (or JSON, a YAML subset) file.

        :param data: Path to the config file.
        :return: The parsed configuration.
        :raises ValueError: If the file's top level is not a mapping
            (including an empty file).
        """
        with Path(data).open("r", encoding="utf-8") as file:
            config_dict = yaml.safe_load(file)

        if not isinstance(config_dict, dict):
            raise ValueError(
                f"Expected a mapping in synthetic config file {data}, "
                f"got {type(config_dict).__name__}"
            )

        return SyntheticDatasetConfig(**config_dict)
125
-
126
-
127
class SyntheticTextItemsGenerator(
    Iterable[
        dict[
            Literal["prompt", "prompt_tokens_count", "output_tokens_count"],
            Union[str, int],
        ]
    ]
):
    """
    Iterable producing synthetic prompt items from a text source.

    Each item holds a decoded ``prompt`` string plus the sampled
    ``prompt_tokens_count`` (shared prefix + per-sample tokens) and
    ``output_tokens_count``. All randomness is seeded from ``random_seed``.
    """

    def __init__(
        self,
        config: "SyntheticDatasetConfig",
        processor: "PreTrainedTokenizerBase",
        random_seed: int,
    ):
        """
        :param config: Token-count and sample settings for generation.
        :param processor: Tokenizer used to encode source text and decode prompts.
        :param random_seed: Base seed; it and fixed offsets of it seed the
            prompt/output samplers and the text-position RNG.
        """
        self.config = config
        self.processor = processor
        self.random_seed = random_seed
        self.text_creator = EndlessTextCreator(
            data=config.source,
        )

    def __iter__(
        self,
    ) -> Iterator[
        dict[
            Literal["prompt", "prompt_tokens_count", "output_tokens_count"],
            Union[str, int],
        ]
    ]:
        """
        Yield ``config.samples`` synthetic prompt items.

        Each prompt is the shared prefix tokens, then one vocabulary id taken
        cyclically from the tokenizer's vocab, then source text starting at a
        random word offset. The result is decoded back to a string.
        """
        prompt_tokens_sampler = IntegerRangeSampler(
            average=self.config.prompt_tokens,
            variance=self.config.prompt_tokens_stdev,
            min_value=self.config.prompt_tokens_min,
            max_value=self.config.prompt_tokens_max,
            random_seed=self.random_seed,
        )
        output_tokens_sampler = IntegerRangeSampler(
            average=self.config.output_tokens,
            variance=self.config.output_tokens_stdev,
            min_value=self.config.output_tokens_min,
            max_value=self.config.output_tokens_max,
            random_seed=self.random_seed + 1,  # ensure diff dist from prompts
        )
        # ensure diff distribution from output tokens
        rand = random.Random(self.random_seed + 2)  # noqa: S311
        # Vary the first per-sample token across prompts by cycling vocab ids.
        unique_prefix_iter = cycle(self.processor.get_vocab().values())

        prefix_index = rand.randint(0, len(self.text_creator.words))
        prefix_tokens = self._create_prompt(self.config.prefix_tokens, prefix_index)

        # range() comes first so zip stops before drawing extra sampler values.
        for _, prompt_tokens, output_tokens in zip(
            range(self.config.samples),
            prompt_tokens_sampler,
            output_tokens_sampler,
        ):
            start_index = rand.randint(0, len(self.text_creator.words))
            prompt_text = self.processor.decode(
                prefix_tokens
                + self._create_prompt(
                    prompt_tokens, start_index, next(unique_prefix_iter)
                ),
                skip_special_tokens=True,
            )
            yield {
                "prompt": prompt_text,
                "prompt_tokens_count": self.config.prefix_tokens + prompt_tokens,
                "output_tokens_count": output_tokens,
            }

    def _create_prompt(
        self, prompt_tokens: int, start_index: int, unique_prefix: Optional[int] = None
    ) -> list[int]:
        """
        Build a list of at most ``prompt_tokens`` token ids from source text.

        Binary-searches the number of words, starting at ``start_index``, whose
        encoding (plus the optional ``unique_prefix`` token) reaches the target
        length. Any overshoot at a word boundary is then trimmed.

        :param prompt_tokens: Target token count; ``<= 0`` returns ``[]``.
        :param start_index: Word offset into the text source.
        :param unique_prefix: Optional token id placed first. ``0`` is a valid
            token id and is kept.
        :return: The prompt's token ids.
        """
        if prompt_tokens <= 0:
            return []

        left = start_index
        # Search at most 4 words per requested token.
        right = start_index + 4 * prompt_tokens
        # Compare against None: token id 0 is falsy but a real vocab entry.
        start_tokens = [unique_prefix] if unique_prefix is not None else []

        while left < right:
            mid = (left + right) // 2
            test_prompt = self.text_creator.create_text(start_index, mid - start_index)
            test_tokens = start_tokens + self.processor.encode(test_prompt)

            if len(test_tokens) == prompt_tokens:
                return test_tokens
            elif len(test_tokens) < prompt_tokens:
                left = mid + 1
            else:
                right = mid

        final_text = self.text_creator.create_text(start_index, left - start_index)
        # The search can settle on a word count whose encoding overshoots the
        # target; trim so the count reported by __iter__ matches the tokens.
        return (start_tokens + self.processor.encode(final_text))[:prompt_tokens]
220
-
221
-
222
class SyntheticDatasetCreator(DatasetCreator):
    """
    Dataset creator for synthetic text described by a ``SyntheticDatasetConfig``
    (a config file, a JSON object string, or ``key=value`` pairs).
    """

    @classmethod
    def is_supported(
        cls,
        data: Any,
        data_args: Optional[dict[str, Any]],  # noqa: ARG003
    ) -> bool:
        """
        Report whether ``data`` looks like a synthetic dataset description.

        Accepts an existing ``Path`` with a ``.config``/``.yaml`` suffix, or a
        string that starts with ``{``, contains more than one ``=``, or ends in
        ``.config``/``.yaml``. ``data_args`` is not consulted.
        """
        if isinstance(data, Path):
            return data.exists() and data.suffix in {".config", ".yaml"}

        if not isinstance(data, str):
            return False

        spec: str = data.strip()
        return (
            spec.startswith("{")
            or spec.count("=") > 1
            or spec.endswith((".config", ".yaml"))
        )

    @classmethod
    def handle_create(
        cls,
        data: Any,
        data_args: Optional[dict[str, Any]],
        processor: Optional[Union[str, Path, PreTrainedTokenizerBase]],
        processor_args: Optional[dict[str, Any]],
        random_seed: int,
    ) -> Union[Dataset, DatasetDict, IterableDataset, IterableDatasetDict]:
        """
        Generate every synthetic item eagerly and wrap them in a ``Dataset``.

        A tokenizer is required; ``data_args`` are forwarded to
        ``Dataset.from_list``.
        """
        tokenizer = check_load_processor(
            processor,
            processor_args,
            error_msg=(
                "Processor/tokenizer required for synthetic dataset generation."
            ),
        )
        config = SyntheticDatasetConfig.parse_str(data)
        rows = list(SyntheticTextItemsGenerator(config, tokenizer, random_seed))

        return Dataset.from_list(rows, **(data_args or {}))

    @classmethod
    def extract_args_column_mappings(
        cls,
        data_args: Optional[dict[str, Any]],
    ) -> dict[ColumnInputTypes, str]:
        """
        Return the fixed column names of generated synthetic rows.

        :raises ValueError: If the caller supplied any column mappings, since
            synthetic columns cannot be remapped.
        """
        requested = super().extract_args_column_mappings(data_args)
        if requested:
            raise ValueError(
                "Column mappings are not supported for synthetic datasets. "
                f"Got {requested}"
            )

        return {
            "prompt_column": "prompt",
            "prompt_tokens_count_column": "prompt_tokens_count",
            "output_tokens_count_column": "output_tokens_count",
        }
@@ -1,18 +0,0 @@
1
- from .pydantic import StandardBaseModel, StatusBreakdown
2
- from .statistics import (
3
- DistributionSummary,
4
- Percentiles,
5
- RunningStats,
6
- StatusDistributionSummary,
7
- TimeRunningStats,
8
- )
9
-
10
- __all__ = [
11
- "DistributionSummary",
12
- "Percentiles",
13
- "RunningStats",
14
- "StandardBaseModel",
15
- "StatusBreakdown",
16
- "StatusDistributionSummary",
17
- "TimeRunningStats",
18
- ]
@@ -1,89 +0,0 @@
1
- import json
2
- from pathlib import Path
3
- from typing import Any, Generic, Optional, TypeVar
4
-
5
- import yaml
6
- from loguru import logger
7
- from pydantic import BaseModel, ConfigDict, Field
8
-
9
- __all__ = ["StandardBaseModel", "StatusBreakdown"]
10
-
11
- T = TypeVar("T", bound="StandardBaseModel")
12
-
13
-
14
class StandardBaseModel(BaseModel):
    """
    A base class for Pydantic models throughout GuideLLM enabling standard
    configuration and logging.

    Extra input fields are ignored, enum fields store their values,
    assignments are re-validated, and models can be built from object
    attributes. Every instantiation is logged at debug level.
    """

    model_config = ConfigDict(
        extra="ignore",
        use_enum_values=True,
        validate_assignment=True,
        from_attributes=True,
    )

    def __init__(self, /, **data: Any) -> None:
        super().__init__(**data)
        logger.debug(
            "Initialized new instance of {} with data: {}",
            self.__class__.__name__,
            data,
        )

    @classmethod
    def get_default(cls: type[T], field: str) -> Any:
        """
        Get the declared default value for a model field.

        :param field: Name of the model field.
        :return: The field's default value.
        :raises KeyError: If ``field`` is not a field of the model.
        """
        return cls.model_fields[field].default

    @classmethod
    def from_file(cls: type[T], filename: Path, overrides: Optional[dict] = None) -> T:
        """
        Create a new instance of the model from a JSON or YAML file.

        Files whose name ends in ``.json`` are parsed as JSON; anything else is
        parsed as YAML. An empty YAML file is treated as an empty mapping.

        :param filename: Path to the file (a ``str`` path is also accepted).
        :param overrides: Optional values that replace those read from the file.
        :return: The validated model instance.
        :raises ValueError: If the file cannot be parsed or its top level is
            not a mapping.
        """
        path = Path(filename)
        try:
            with path.open(encoding="utf-8") as f:
                if path.name.endswith(".json"):
                    data = json.load(f)
                else:  # Assume everything else is yaml
                    data = yaml.safe_load(f)
        except (json.JSONDecodeError, yaml.YAMLError) as e:
            logger.error(f"Failed to parse {filename} as type {cls.__name__}")
            raise ValueError(f"Error when parsing file: {filename}") from e

        if data is None:
            data = {}  # yaml.safe_load returns None for an empty document
        if not isinstance(data, dict):
            raise ValueError(
                f"Expected a mapping at the top level of {filename}, "
                f"got {type(data).__name__}"
            )

        # overrides defaults to None; dict.update(None) would raise TypeError.
        if overrides:
            data.update(overrides)
        return cls.model_validate(data)
58
-
59
-
60
# Type parameters for the per-status payloads held by StatusBreakdown.
SuccessfulT = TypeVar("SuccessfulT")
ErroredT = TypeVar("ErroredT")
IncompleteT = TypeVar("IncompleteT")
TotalT = TypeVar("TotalT")


class StatusBreakdown(BaseModel, Generic[SuccessfulT, ErroredT, IncompleteT, TotalT]):
    """
    Generic model that groups results by status.

    ``successful``, ``errored`` and ``incomplete`` each hold the results for
    that status. ``total`` is meant to combine all of them and may be left
    unset when it would only duplicate the others. Every field defaults to
    ``None``.
    """

    successful: SuccessfulT = Field(
        default=None,  # type: ignore[assignment]
        description="The results with a successful status.",
    )
    errored: ErroredT = Field(
        default=None,  # type: ignore[assignment]
        description="The results with an errored status.",
    )
    incomplete: IncompleteT = Field(
        default=None,  # type: ignore[assignment]
        description="The results with an incomplete status.",
    )
    total: TotalT = Field(
        default=None,  # type: ignore[assignment]
        description="The combination of all statuses.",
    )
@@ -1,18 +0,0 @@
1
- from .loader import (
2
- GenerativeRequestLoader,
3
- GenerativeRequestLoaderDescription,
4
- RequestLoader,
5
- RequestLoaderDescription,
6
- )
7
- from .request import GenerationRequest
8
- from .types import RequestT, ResponseT
9
-
10
- __all__ = [
11
- "GenerationRequest",
12
- "GenerativeRequestLoader",
13
- "GenerativeRequestLoaderDescription",
14
- "RequestLoader",
15
- "RequestLoaderDescription",
16
- "RequestT",
17
- "ResponseT",
18
- ]
@@ -1,284 +0,0 @@
1
- from abc import abstractmethod
2
- from collections.abc import Iterable, Iterator
3
- from pathlib import Path
4
- from typing import (
5
- Any,
6
- Literal,
7
- Optional,
8
- Union,
9
- )
10
-
11
- from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
12
- from transformers import PreTrainedTokenizerBase # type: ignore[import]
13
-
14
- from guidellm.config import settings
15
- from guidellm.dataset import ColumnInputTypes, load_dataset
16
- from guidellm.objects import StandardBaseModel
17
- from guidellm.request.request import GenerationRequest
18
-
19
- __all__ = [
20
- "GenerativeRequestLoader",
21
- "GenerativeRequestLoaderDescription",
22
- "RequestLoader",
23
- "RequestLoaderDescription",
24
- ]
25
-
26
-
27
class RequestLoaderDescription(StandardBaseModel):
    """Serializable summary of a request loader, tagged by its ``type_`` literal."""

    type_: Literal["request_loader"] = "request_loader"
29
-
30
-
31
class RequestLoader(Iterable):
    """
    Abstract base for sources of requests.

    Concrete loaders must be iterable, report a length, and expose a
    ``RequestLoaderDescription`` via :attr:`description`.
    """

    @abstractmethod
    def __iter__(self) -> Iterator:
        """Yield the requests produced by this loader."""

    @abstractmethod
    def __len__(self) -> int:
        """Return the number of requests this loader yields."""

    @property
    @abstractmethod
    def description(self) -> RequestLoaderDescription:
        """Return a serializable description of this loader."""
41
-
42
-
43
class GenerativeRequestLoaderDescription(RequestLoaderDescription):
    """Description of a generative request loader's data and processor inputs."""

    type_: Literal["generative_request_loader"] = "generative_request_loader"  # type: ignore[assignment]
    # String form of the loader's data source.
    data: str
    data_args: Optional[dict[str, Any]]
    # String form of the processor (name, path, or tokenizer object).
    processor: str
    processor_args: Optional[dict[str, Any]]
49
-
50
-
51
class GenerativeRequestLoader(RequestLoader):
    """
    Request loader that converts dataset rows into ``GenerationRequest`` objects.

    The dataset and any explicit column mappings come from ``load_dataset``.
    Columns not mapped explicitly are inferred from the dataset's column names.
    With ``iter_type="finite"`` the dataset is iterated once; with
    ``"infinite"`` it is re-iterated after every full pass.
    """

    # Candidate prompt column names, checked in order when the dataset has
    # several columns and no prompt column was mapped explicitly.
    DEFAULT_PROMPT_COLUMNS = [
        "prompt",
        "prompts",
        "instruction",
        "instructions",
        "question",
        "questions",
        "input",
        "inputs",
        "context",
        "content",
        "conversation",
        "conversations",
        "turn",
        "turns",
        "text",
    ]

    def __init__(
        self,
        data: Union[
            str,
            Path,
            Iterable[Union[str, dict[str, Any]]],
            Dataset,
            DatasetDict,
            IterableDataset,
            IterableDatasetDict,
        ],
        data_args: Optional[dict[str, Any]],
        processor: Optional[Union[str, Path, PreTrainedTokenizerBase]],
        processor_args: Optional[dict[str, Any]],
        shuffle: bool = True,
        iter_type: Literal["finite", "infinite"] = "finite",
        random_seed: int = 42,
    ):
        """
        :param data: Dataset source passed to ``load_dataset``.
        :param data_args: Extra dataset arguments, e.g. a ``text_column`` mapping.
        :param processor: Tokenizer or its identifier, passed to ``load_dataset``.
        :param processor_args: Extra arguments for loading the processor.
        :param shuffle: Shuffle the dataset (seeded) before each pass.
        :param iter_type: ``"finite"`` for one pass, ``"infinite"`` to repeat.
        :param random_seed: Seed for dataset loading and shuffling.
        """
        self.data = data
        self.data_args = data_args
        dataset, args_column_mappings = load_dataset(
            data,
            data_args,
            processor,
            processor_args,
            random_seed,
        )
        self.dataset = dataset
        self.processor = processor
        self.processor_args = processor_args
        self.shuffle = shuffle
        self.iter_type = iter_type
        self.random_seed = random_seed

        self.column_mappings = self._create_column_mappings(args_column_mappings)
        # Infinite loaders keep their live iterator between __iter__ calls so
        # a new consumer resumes where the previous one stopped.
        self.preserve_iter_state = iter_type == "infinite"
        self._preserved_iter = None

    def __iter__(self) -> Iterator[GenerationRequest]:
        """
        Yield one request per dataset row, repeating passes for infinite loaders.

        Iteration stops after two consecutive passes produce no rows. This
        prevents an infinite loader over an empty dataset from spinning
        forever. Two passes are required because the first pass may be a
        resumed iterator that was already exhausted.
        """
        scope_create_count = 0
        consecutive_empty_passes = 0

        while (dataset_iter := self._get_dataset_iter(scope_create_count)) is not None:
            scope_create_count += 1
            produced_any = False

            for item in dataset_iter:
                produced_any = True
                yield self._create_request(item)

            self._preserved_iter = None

            consecutive_empty_passes = 0 if produced_any else consecutive_empty_passes + 1
            if consecutive_empty_passes >= 2:
                break

    def __len__(self) -> int:
        """
        Return the number of rows for finite loaders.

        :raises ValueError: For infinite loaders, or when the row count is unknown.
        """
        if self.iter_type == "finite":
            return self.num_unique_items()

        raise ValueError(f"Unable to determine length of dataset: {self.data}")

    @property
    def description(self) -> GenerativeRequestLoaderDescription:
        """Return a serializable description of this loader's inputs."""
        return GenerativeRequestLoaderDescription(
            data=str(self.data),
            data_args=self.data_args,
            processor=str(self.processor),
            processor_args=self.processor_args,
        )

    def num_unique_items(self, raise_err: bool = True) -> int:
        """
        Return the number of rows in the dataset.

        Uses ``len(dataset)`` when supported. Otherwise falls back to the
        ``num_examples`` recorded for the dataset's split in its metadata.

        :param raise_err: Raise instead of returning ``-1`` when the count is unknown.
        :return: The row count, or ``-1`` if unknown and ``raise_err`` is False.
        :raises ValueError: If the count is unknown and ``raise_err`` is True.
        """
        try:
            return len(self.dataset)
        except Exception:  # noqa: BLE001, S110
            pass  # e.g. streaming datasets do not implement __len__

        # ``info.dataset_size`` is a size in bytes, not a row count, so use
        # the per-split ``num_examples`` metadata instead when available.
        info = getattr(self.dataset, "info", None)
        splits = getattr(info, "splits", None)
        split = getattr(self.dataset, "split", None)
        if splits and split is not None:
            split_info = splits.get(str(split))
            num_examples = getattr(split_info, "num_examples", None)
            # A value of 0 is treated as unrecorded metadata.
            if num_examples:
                return num_examples

        if raise_err:
            raise ValueError("Unable to determine number of items in the dataset")

        return -1

    def _create_column_mappings(
        self,
        args_column_mappings: dict[ColumnInputTypes, str],
    ) -> dict[ColumnInputTypes, str]:
        """
        Resolve the prompt, prompt-token-count and output-token-count columns.

        Explicit mappings win. The prompt column may be given as
        ``text_column`` or ``prompt_column``; the synthetic creator, for
        example, returns ``prompt_column``. Missing mappings are inferred
        from the dataset's column names.
        """
        column_mappings: dict[ColumnInputTypes, str] = {}

        if "text_column" in args_column_mappings:
            column_mappings["prompt_column"] = args_column_mappings["text_column"]
        elif "prompt_column" in args_column_mappings:
            column_mappings["prompt_column"] = args_column_mappings["prompt_column"]
        else:
            column_mappings["prompt_column"] = self._extract_text_column()

        if "prompt_tokens_count_column" in args_column_mappings:
            column_mappings["prompt_tokens_count_column"] = args_column_mappings[
                "prompt_tokens_count_column"
            ]
        elif prompt_tokens_count_column := self._extract_prompt_tokens_count_column():
            column_mappings["prompt_tokens_count_column"] = prompt_tokens_count_column

        if "output_tokens_count_column" in args_column_mappings:
            column_mappings["output_tokens_count_column"] = args_column_mappings[
                "output_tokens_count_column"
            ]
        elif output_tokens_count_column := self._extract_output_tokens_count_column():
            column_mappings["output_tokens_count_column"] = output_tokens_count_column

        return column_mappings

    def _extract_text_column(self) -> str:
        """
        Infer the prompt column.

        Returns the dataset's only column, or else the first name from
        ``DEFAULT_PROMPT_COLUMNS`` that is present.

        :raises ValueError: If the columns cannot be read or none match.
        """
        missing_msg = (
            "Unable to determine text column from dataset and it is required. "
            "To specify the text column, set the 'text_column' key in the "
            "'data_args' dictionary."
        )
        column_names = self._dataset_columns(err_msg=missing_msg)

        if not column_names:
            raise ValueError(missing_msg)

        if len(column_names) == 1:
            return column_names[0]

        for def_column in self.DEFAULT_PROMPT_COLUMNS:
            if def_column in column_names:
                return def_column

        raise ValueError(
            f"Unable to determine text column from dataset columns: {column_names}. "
            "To specify the text column, set the 'text_column' key in the "
            "'data_args' dictionary."
        )

    def _first_present_column(self, candidates: tuple[str, ...]) -> Optional[str]:
        """Return the first of ``candidates`` found in the dataset's columns, or None."""
        column_names = self._dataset_columns()
        if not column_names:
            return None

        return next((name for name in candidates if name in column_names), None)

    def _extract_prompt_tokens_count_column(self) -> Optional[str]:
        """Infer the prompt-token-count column from common names, if present."""
        return self._first_present_column(("prompt_tokens_count", "prompt_tokens"))

    def _extract_output_tokens_count_column(self) -> Optional[str]:
        """Infer the output-token-count column from common names, if present."""
        return self._first_present_column(("output_tokens_count", "output_tokens"))

    def _dataset_columns(self, err_msg: Optional[str] = None) -> Optional[list[str]]:
        """
        Return the dataset's column names.

        :param err_msg: If given, failing to read the columns, or reading an
            empty list, raises ``ValueError(err_msg)``. If omitted, a read
            failure returns ``None``.
        """
        try:
            column_names = self.dataset.column_names
        except Exception as err:
            if err_msg:
                raise ValueError(err_msg) from err
            return None

        if not column_names and err_msg:
            raise ValueError(err_msg) from ValueError(
                f"No column names found in dataset: {self.data}"
            )

        return column_names

    def _get_dataset_iter(
        self, scope_create_count: int
    ) -> Optional[Iterator[dict[str, Any]]]:
        """
        Return the iterator for the next pass, or ``None`` when finished.

        Finite loaders get exactly one pass. Infinite loaders reuse a preserved
        in-progress iterator if one exists; otherwise they start a fresh,
        optionally shuffled, pass.
        """
        if scope_create_count > 0 and self.iter_type != "infinite":
            return None

        if self.preserve_iter_state and self._preserved_iter is not None:
            return self._preserved_iter

        dataset = (
            self.dataset
            if not self.shuffle
            else self.dataset.shuffle(seed=self.random_seed)
        )

        dataset_iter = iter(dataset)

        if self.preserve_iter_state:
            self._preserved_iter = dataset_iter

        return dataset_iter

    def _create_request(self, item: dict[str, Any]) -> GenerationRequest:
        """
        Build a ``GenerationRequest`` from one dataset row.

        Token counts are attached only when their columns are mapped.
        """
        prompt_tokens_column = self.column_mappings.get("prompt_tokens_count_column")
        output_tokens_column = self.column_mappings.get("output_tokens_count_column")
        prompt_tokens = (
            item[prompt_tokens_column] if prompt_tokens_column is not None else None
        )
        output_tokens = (
            item[output_tokens_column] if output_tokens_column is not None else None
        )

        return GenerationRequest(
            request_type=settings.preferred_route,
            content=item[self.column_mappings["prompt_column"]],
            stats=(
                {"prompt_tokens": prompt_tokens} if prompt_tokens is not None else {}
            ),
            constraints=(
                {"output_tokens": output_tokens} if output_tokens is not None else {}
            ),
        )