guidellm 0.3.1__py3-none-any.whl → 0.6.0a5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (141) hide show
  1. guidellm/__init__.py +5 -2
  2. guidellm/__main__.py +524 -255
  3. guidellm/backends/__init__.py +33 -0
  4. guidellm/backends/backend.py +109 -0
  5. guidellm/backends/openai.py +340 -0
  6. guidellm/backends/response_handlers.py +428 -0
  7. guidellm/benchmark/__init__.py +69 -39
  8. guidellm/benchmark/benchmarker.py +160 -316
  9. guidellm/benchmark/entrypoints.py +560 -127
  10. guidellm/benchmark/outputs/__init__.py +24 -0
  11. guidellm/benchmark/outputs/console.py +633 -0
  12. guidellm/benchmark/outputs/csv.py +721 -0
  13. guidellm/benchmark/outputs/html.py +473 -0
  14. guidellm/benchmark/outputs/output.py +169 -0
  15. guidellm/benchmark/outputs/serialized.py +69 -0
  16. guidellm/benchmark/profiles.py +718 -0
  17. guidellm/benchmark/progress.py +553 -556
  18. guidellm/benchmark/scenarios/__init__.py +40 -0
  19. guidellm/benchmark/scenarios/chat.json +6 -0
  20. guidellm/benchmark/scenarios/rag.json +6 -0
  21. guidellm/benchmark/schemas/__init__.py +66 -0
  22. guidellm/benchmark/schemas/base.py +402 -0
  23. guidellm/benchmark/schemas/generative/__init__.py +55 -0
  24. guidellm/benchmark/schemas/generative/accumulator.py +841 -0
  25. guidellm/benchmark/schemas/generative/benchmark.py +163 -0
  26. guidellm/benchmark/schemas/generative/entrypoints.py +381 -0
  27. guidellm/benchmark/schemas/generative/metrics.py +927 -0
  28. guidellm/benchmark/schemas/generative/report.py +158 -0
  29. guidellm/data/__init__.py +34 -4
  30. guidellm/data/builders.py +541 -0
  31. guidellm/data/collators.py +16 -0
  32. guidellm/data/config.py +120 -0
  33. guidellm/data/deserializers/__init__.py +49 -0
  34. guidellm/data/deserializers/deserializer.py +141 -0
  35. guidellm/data/deserializers/file.py +223 -0
  36. guidellm/data/deserializers/huggingface.py +94 -0
  37. guidellm/data/deserializers/memory.py +194 -0
  38. guidellm/data/deserializers/synthetic.py +246 -0
  39. guidellm/data/entrypoints.py +52 -0
  40. guidellm/data/loaders.py +190 -0
  41. guidellm/data/preprocessors/__init__.py +27 -0
  42. guidellm/data/preprocessors/formatters.py +410 -0
  43. guidellm/data/preprocessors/mappers.py +196 -0
  44. guidellm/data/preprocessors/preprocessor.py +30 -0
  45. guidellm/data/processor.py +29 -0
  46. guidellm/data/schemas.py +175 -0
  47. guidellm/data/utils/__init__.py +6 -0
  48. guidellm/data/utils/dataset.py +94 -0
  49. guidellm/extras/__init__.py +4 -0
  50. guidellm/extras/audio.py +220 -0
  51. guidellm/extras/vision.py +242 -0
  52. guidellm/logger.py +2 -2
  53. guidellm/mock_server/__init__.py +8 -0
  54. guidellm/mock_server/config.py +84 -0
  55. guidellm/mock_server/handlers/__init__.py +17 -0
  56. guidellm/mock_server/handlers/chat_completions.py +280 -0
  57. guidellm/mock_server/handlers/completions.py +280 -0
  58. guidellm/mock_server/handlers/tokenizer.py +142 -0
  59. guidellm/mock_server/models.py +510 -0
  60. guidellm/mock_server/server.py +238 -0
  61. guidellm/mock_server/utils.py +302 -0
  62. guidellm/scheduler/__init__.py +69 -26
  63. guidellm/scheduler/constraints/__init__.py +49 -0
  64. guidellm/scheduler/constraints/constraint.py +325 -0
  65. guidellm/scheduler/constraints/error.py +411 -0
  66. guidellm/scheduler/constraints/factory.py +182 -0
  67. guidellm/scheduler/constraints/request.py +312 -0
  68. guidellm/scheduler/constraints/saturation.py +722 -0
  69. guidellm/scheduler/environments.py +252 -0
  70. guidellm/scheduler/scheduler.py +137 -368
  71. guidellm/scheduler/schemas.py +358 -0
  72. guidellm/scheduler/strategies.py +617 -0
  73. guidellm/scheduler/worker.py +413 -419
  74. guidellm/scheduler/worker_group.py +712 -0
  75. guidellm/schemas/__init__.py +65 -0
  76. guidellm/schemas/base.py +417 -0
  77. guidellm/schemas/info.py +188 -0
  78. guidellm/schemas/request.py +235 -0
  79. guidellm/schemas/request_stats.py +349 -0
  80. guidellm/schemas/response.py +124 -0
  81. guidellm/schemas/statistics.py +1018 -0
  82. guidellm/{config.py → settings.py} +31 -24
  83. guidellm/utils/__init__.py +71 -8
  84. guidellm/utils/auto_importer.py +98 -0
  85. guidellm/utils/cli.py +132 -5
  86. guidellm/utils/console.py +566 -0
  87. guidellm/utils/encoding.py +778 -0
  88. guidellm/utils/functions.py +159 -0
  89. guidellm/utils/hf_datasets.py +1 -2
  90. guidellm/utils/hf_transformers.py +4 -4
  91. guidellm/utils/imports.py +9 -0
  92. guidellm/utils/messaging.py +1118 -0
  93. guidellm/utils/mixins.py +115 -0
  94. guidellm/utils/random.py +3 -4
  95. guidellm/utils/registry.py +220 -0
  96. guidellm/utils/singleton.py +133 -0
  97. guidellm/utils/synchronous.py +159 -0
  98. guidellm/utils/text.py +163 -50
  99. guidellm/utils/typing.py +41 -0
  100. guidellm/version.py +2 -2
  101. guidellm-0.6.0a5.dist-info/METADATA +364 -0
  102. guidellm-0.6.0a5.dist-info/RECORD +109 -0
  103. guidellm/backend/__init__.py +0 -23
  104. guidellm/backend/backend.py +0 -259
  105. guidellm/backend/openai.py +0 -708
  106. guidellm/backend/response.py +0 -136
  107. guidellm/benchmark/aggregator.py +0 -760
  108. guidellm/benchmark/benchmark.py +0 -837
  109. guidellm/benchmark/output.py +0 -997
  110. guidellm/benchmark/profile.py +0 -409
  111. guidellm/benchmark/scenario.py +0 -104
  112. guidellm/data/prideandprejudice.txt.gz +0 -0
  113. guidellm/dataset/__init__.py +0 -22
  114. guidellm/dataset/creator.py +0 -213
  115. guidellm/dataset/entrypoints.py +0 -42
  116. guidellm/dataset/file.py +0 -92
  117. guidellm/dataset/hf_datasets.py +0 -62
  118. guidellm/dataset/in_memory.py +0 -132
  119. guidellm/dataset/synthetic.py +0 -287
  120. guidellm/objects/__init__.py +0 -18
  121. guidellm/objects/pydantic.py +0 -89
  122. guidellm/objects/statistics.py +0 -953
  123. guidellm/preprocess/__init__.py +0 -3
  124. guidellm/preprocess/dataset.py +0 -374
  125. guidellm/presentation/__init__.py +0 -28
  126. guidellm/presentation/builder.py +0 -27
  127. guidellm/presentation/data_models.py +0 -232
  128. guidellm/presentation/injector.py +0 -66
  129. guidellm/request/__init__.py +0 -18
  130. guidellm/request/loader.py +0 -284
  131. guidellm/request/request.py +0 -79
  132. guidellm/request/types.py +0 -10
  133. guidellm/scheduler/queues.py +0 -25
  134. guidellm/scheduler/result.py +0 -155
  135. guidellm/scheduler/strategy.py +0 -495
  136. guidellm-0.3.1.dist-info/METADATA +0 -329
  137. guidellm-0.3.1.dist-info/RECORD +0 -62
  138. {guidellm-0.3.1.dist-info → guidellm-0.6.0a5.dist-info}/WHEEL +0 -0
  139. {guidellm-0.3.1.dist-info → guidellm-0.6.0a5.dist-info}/entry_points.txt +0 -0
  140. {guidellm-0.3.1.dist-info → guidellm-0.6.0a5.dist-info}/licenses/LICENSE +0 -0
  141. {guidellm-0.3.1.dist-info → guidellm-0.6.0a5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,246 @@
1
+ from __future__ import annotations
2
+
3
+ import math
4
+ from collections.abc import Callable, Iterator
5
+ from random import Random
6
+ from typing import Any
7
+
8
+ import numpy as np
9
+ from datasets import DatasetInfo, Features, IterableDataset, Value
10
+ from datasets.iterable_dataset import _BaseExamplesIterable
11
+ from faker import Faker
12
+ from transformers import PreTrainedTokenizerBase
13
+
14
+ from guidellm.data.config import load_config
15
+ from guidellm.data.deserializers.deserializer import (
16
+ DataNotSupportedError,
17
+ DatasetDeserializer,
18
+ DatasetDeserializerFactory,
19
+ )
20
+ from guidellm.data.schemas import SyntheticTextDatasetConfig
21
+ from guidellm.utils import IntegerRangeSampler
22
+
23
+ __all__ = [
24
+ "SyntheticTextDataset",
25
+ "SyntheticTextDatasetDeserializer",
26
+ ]
27
+
28
+
29
class _SyntheticTextExamplesIterable(_BaseExamplesIterable):
    """
    Endless examples iterable that fabricates synthetic text rows.

    Each yielded row holds a ``prefix``, a ``prompt`` and the sampled
    ``prompt_tokens_count`` / ``output_tokens_count``. Every call to
    ``__iter__`` seeds its generators from ``random_seed + iteration_count`` and
    then increments ``iteration_count``, so consecutive passes use different
    seeds.
    """

    def __init__(
        self,
        config: SyntheticTextDatasetConfig,
        processor: PreTrainedTokenizerBase,
        random_seed: int,
    ):
        super().__init__()
        self.config = config
        self.processor = processor
        self.random_seed = random_seed
        self.iteration_count = 0

    def __iter__(self) -> Iterator[tuple[int, dict[str, Any]]]:
        """Yield ``(index, row)`` pairs forever, starting at index 0."""
        seed = self.random_seed + self.iteration_count
        self.iteration_count += 1

        faker = Faker()
        faker.seed_instance(seed)

        prompt_sampler = IntegerRangeSampler(
            average=self.config.prompt_tokens,
            variance=self.config.prompt_tokens_stdev,
            min_value=self.config.prompt_tokens_min,
            max_value=self.config.prompt_tokens_max,
            random_seed=seed,
        )
        # Offset the seed so output lengths do not mirror prompt lengths.
        output_sampler = IntegerRangeSampler(
            average=self.config.output_tokens,
            variance=self.config.output_tokens_stdev,
            min_value=self.config.output_tokens_min,
            max_value=self.config.output_tokens_max,
            random_seed=seed + 1,
        )
        prompt_lengths = iter(prompt_sampler)
        output_lengths = iter(output_sampler)

        # Prefix selection gets its own RNG, seeded apart from the samplers.
        prefixes = self._create_prefix_iter(faker, Random(seed + 3))

        sample_index = 0
        while True:
            prompt_length = next(prompt_lengths)
            output_length = next(output_lengths)

            # The prefix must be drawn before the prompt: both consume the same
            # faker instance, so the order affects the generated text.
            prefix = next(prefixes)
            prompt = self._create_prompt(
                prompt_length,
                faker,
                f"{self.iteration_count} {sample_index} ",
            )

            row = {
                "prefix": prefix,
                "prompt": prompt,
                "prompt_tokens_count": prompt_length,
                "output_tokens_count": output_length,
            }
            yield sample_index, row
            sample_index += 1

    @property
    def is_typed(self) -> bool:
        """Always ``True``; the schema is given by :attr:`features`."""
        return True

    @property
    def features(self) -> Features:
        """Schema of the generated rows: two strings and two int32 counts."""
        schema = {
            "prefix": Value("string"),
            "prompt": Value("string"),
            "prompt_tokens_count": Value("int32"),
            "output_tokens_count": Value("int32"),
        }
        return Features(schema)

    @property
    def num_shards(self) -> int:
        """The generator is exposed as a single shard."""
        return 1

    def shuffle_data_sources(
        self,
        generator: np.random.Generator,  # noqa: ARG002
    ) -> _SyntheticTextExamplesIterable:
        """Return this same instance; there are no data sources to reorder."""
        return self

    def shard_data_sources(
        self,
        num_shards: int,  # noqa: ARG002
        index: int,  # noqa: ARG002
        contiguous: bool = True,  # noqa: ARG002
    ) -> _SyntheticTextExamplesIterable:
        """Return this same instance; the arguments are ignored."""
        return self

    def load_state_dict(self, state_dict: dict) -> None:
        """Restore ``iteration_count`` from ``state_dict`` (0 when absent)."""
        self.iteration_count = state_dict.get("iteration_count", 0)

    def _init_state_dict(self) -> dict:
        """Store and return a state dict holding the current iteration count."""
        state = {"iteration_count": self.iteration_count}
        self._state_dict = state
        return state

    def _create_prompt(
        self, prompt_tokens_count: int, faker: Faker, unique: str = ""
    ) -> str:
        """
        Build text that tokenizes to at least ``prompt_tokens_count`` tokens and
        return the decoding of its first ``prompt_tokens_count`` token ids.

        :param prompt_tokens_count: Number of token ids to keep.
        :param faker: Faker instance that supplies the filler text.
        :param unique: String placed in front of the generated text.
        :return: The decoded, truncated text (special tokens skipped).
        """
        chars_per_token_estimate = 5
        safety_factor = 1.5
        token_ids: list[int] = []
        attempt = 0

        # Ask for more characters on every retry until enough tokens come back.
        while len(token_ids) < prompt_tokens_count:
            attempt += 1
            char_budget = int(
                prompt_tokens_count
                * chars_per_token_estimate
                * safety_factor
                * attempt
            )
            candidate = unique + faker.text(max_nb_chars=char_budget)
            token_ids = self.processor.encode(candidate)

        kept_ids = token_ids[:prompt_tokens_count]
        return self.processor.decode(kept_ids, skip_special_tokens=True)

    def _create_prefix_iter(self, faker: Faker, rand: Random) -> Iterator[str]:
        """
        Yield prefixes forever.

        Without configured prefix buckets every prefix is the empty string.
        Otherwise ``prefix_count`` prefixes are generated per bucket and each is
        repeated in a selection pool according to integer weights derived from
        ``bucket_weight`` and ``prefix_count``; ``rand`` then picks from that
        pool uniformly.
        """
        buckets = self.config.prefix_buckets
        if not buckets:
            while True:
                yield ""

        # Scale by the LCM of the prefix counts so the per-prefix weights can be
        # expressed as integers.
        shared_multiple = math.lcm(*(bucket.prefix_count for bucket in buckets))
        raw_weights = [
            shared_multiple * bucket.bucket_weight // bucket.prefix_count
            for bucket in buckets
        ]
        # Reduce the weights to their smallest integer ratio.
        divisor = math.gcd(*raw_weights)

        pool = []
        for bucket, raw_weight in zip(buckets, raw_weights, strict=False):
            repeats = raw_weight // divisor
            generated = [
                self._create_prompt(bucket.prefix_tokens, faker)
                for _ in range(bucket.prefix_count)
            ]
            pool.extend(generated * repeats)

        while True:
            yield rand.choice(pool)
187
+
188
+
189
class SyntheticTextDataset(IterableDataset):
    """
    Iterable dataset whose rows come from ``_SyntheticTextExamplesIterable``.

    :param config: Synthetic text settings handed to the examples iterable.
    :param processor: Tokenizer handed to the examples iterable.
    :param random_seed: Base seed handed to the examples iterable.
    """

    def __init__(
        self,
        config: SyntheticTextDatasetConfig,
        processor: PreTrainedTokenizerBase,
        random_seed: int = 42,
    ):
        self.config = config
        self.processor = processor
        self.random_seed = random_seed

        # Build the row source first so its feature schema can feed the info.
        examples = _SyntheticTextExamplesIterable(
            config=config,
            processor=processor,
            random_seed=random_seed,
        )
        dataset_info = DatasetInfo(
            description="Synthetic text dataset generator",
            features=examples.features,
        )

        super().__init__(ex_iterable=examples, info=dataset_info)

    def set_epoch(self, epoch: int):
        """Use ``epoch`` as the iteration count of the synthetic row source."""
        examples = self._ex_iterable
        if not isinstance(examples, _SyntheticTextExamplesIterable):
            return
        examples.iteration_count = epoch
220
+
221
+
222
@DatasetDeserializerFactory.register("synthetic_text")
class SyntheticTextDatasetDeserializer(DatasetDeserializer):
    """Deserializer registered as ``synthetic_text``."""

    def __call__(
        self,
        data: Any,
        processor_factory: Callable[[], PreTrainedTokenizerBase],
        random_seed: int,
        **data_kwargs: dict[str, Any],
    ) -> IterableDataset:
        """
        Turn ``data`` into a :class:`SyntheticTextDataset`.

        :param data: A ``SyntheticTextDatasetConfig`` or anything ``load_config``
            can resolve into one.
        :param processor_factory: Called once to obtain the tokenizer.
        :param random_seed: Seed handed to the dataset.
        :param data_kwargs: Only forwarded to the recursive call made after a
            config has been loaded; not used to build the dataset.
        :return: The synthetic dataset.
        :raises DataNotSupportedError: If ``data`` is not a config and
            ``load_config`` returned ``None`` for it.
        """
        # File and string inputs: load the config, then re-enter with it.
        loaded = load_config(data, SyntheticTextDatasetConfig)
        if loaded is not None:
            return self(loaded, processor_factory, random_seed, **data_kwargs)

        if isinstance(data, SyntheticTextDatasetConfig):
            return SyntheticTextDataset(
                config=data,
                processor=processor_factory(),
                random_seed=random_seed,
            )

        raise DataNotSupportedError(
            "Unsupported data for SyntheticTextDatasetDeserializer, "
            "expected SyntheticTextDatasetConfig, str or Path to a config file, "
            f"got {data}"
        )
@@ -0,0 +1,52 @@
1
+ from pathlib import Path
2
+ from typing import Any
3
+
4
+ from transformers import PreTrainedTokenizerBase
5
+
6
+ from guidellm.data import builders
7
+ from guidellm.data.builders import ShortPromptStrategy
8
+
9
+
10
def process_dataset(
    data: str | Path,
    output_path: str | Path,
    processor: str | Path | PreTrainedTokenizerBase,
    config: str | Path,
    processor_args: dict[str, Any] | None = None,
    data_args: dict[str, Any] | None = None,
    data_column_mapper: dict[str, str] | None = None,
    short_prompt_strategy: ShortPromptStrategy = ShortPromptStrategy.IGNORE,
    pad_char: str | None = None,
    concat_delimiter: str | None = None,
    include_prefix_in_token_count: bool = False,
    push_to_hub: bool = False,
    hub_dataset_id: str | None = None,
    random_seed: int = 42,
) -> None:
    """
    Process a dataset with sampled prompt/output token counts and save it.

    Every argument is passed on, in order and positionally, to
    ``builders.process_dataset``.

    :param data: Path or identifier of the input dataset.
    :param output_path: File path where the processed dataset is saved.
    :param processor: Tokenizer object or its config.
    :param config: PreprocessDatasetConfig string or file path.
    :param processor_args: Optional processor arguments.
    :param data_args: Optional data loading arguments.
    :param data_column_mapper: Optional column mapping dictionary.
    :param short_prompt_strategy: Strategy for handling short prompts.
    :param pad_char: Character used when padding short prompts.
    :param concat_delimiter: Delimiter for the concatenation strategy.
    :param include_prefix_in_token_count:
        Whether the prefix counts toward the prompt token count. When True,
        prefix trimming is disabled, the prefix is kept as-is, and its token
        count is subtracted from the prompt token budget instead.
    :param push_to_hub: Whether to push to Hugging Face Hub.
    :param hub_dataset_id: Dataset ID on Hugging Face Hub.
    :param random_seed: Seed for random sampling.
    :raises ValueError: If the output path is invalid or pushing conditions unmet.
    """
    # Kept positional on purpose: the order mirrors this function's signature.
    forwarded = (
        data,
        output_path,
        processor,
        config,
        processor_args,
        data_args,
        data_column_mapper,
        short_prompt_strategy,
        pad_char,
        concat_delimiter,
        include_prefix_in_token_count,
        push_to_hub,
        hub_dataset_id,
        random_seed,
    )
    builders.process_dataset(*forwarded)
@@ -0,0 +1,190 @@
1
+ from __future__ import annotations
2
+
3
+ import contextlib
4
+ from collections.abc import Callable, Iterator
5
+ from typing import Any, Literal, TypeVar
6
+
7
+ import torch
8
+ from torch.utils.data import Sampler
9
+ from torch.utils.data.dataloader import DataLoader as PyTorchDataLoader
10
+ from torch.utils.data.dataset import IterableDataset as TorchIterableDataset
11
+ from transformers import PreTrainedTokenizerBase
12
+
13
+ from guidellm.data.deserializers import DatasetDeserializerFactory
14
+ from guidellm.data.preprocessors import DataDependentPreprocessor, DatasetPreprocessor
15
+ from guidellm.logger import logger
16
+ from guidellm.utils import InfoMixin
17
+
18
+ __all__ = ["DataLoader", "DatasetsIterator"]
19
+
20
+
21
+ DataT = TypeVar("DataT")
22
+
23
+
24
class DatasetsIterator(TorchIterableDataset[DataT]):
    """
    Iterable dataset that reads several deserialized datasets in lockstep.

    Each produced row starts as ``{"items": [...]}`` with one element per
    dataset and is then passed through every preprocessor in order. When
    ``data_samples`` is truthy, that many rows are generated eagerly in the
    constructor and replayed on every iteration; otherwise rows are generated
    lazily on each iteration.

    :param data: Non-empty list of dataset sources to deserialize.
    :param data_args: One kwargs dict per entry of ``data``; a falsy value
        means an empty dict for every entry.
    :param data_samples: Number of rows to precache; falsy disables caching.
    :param processor_factory: Factory passed to the deserializers.
    :param preprocessors: Callables applied to each row, in order.
    :param random_seed: Seed passed to the deserializers.
    :raises ValueError: If ``data`` is empty or not a list, if ``data`` and
        ``data_args`` differ in length, or if precaching cannot produce
        ``data_samples`` rows.
    """

    def __init__(
        self,
        data: list[Any],
        data_args: list[dict[str, Any]] | None,
        data_samples: int,
        processor_factory: Callable[[], PreTrainedTokenizerBase],
        preprocessors: list[DatasetPreprocessor | DataDependentPreprocessor],
        random_seed: int,
    ):
        if not data or not isinstance(data, list):
            raise ValueError(f"Data must be a non-empty list, got {data}.")

        if not data_args:
            data_args = [{} for _ in data]

        if len(data) != len(data_args):
            raise ValueError(
                f"Length of data ({len(data)}) must match length of data_args "
                f"({len(data_args)})."
            )

        self.datasets = []
        for datum, data_kwargs in zip(data, data_args, strict=False):
            self.datasets.append(
                DatasetDeserializerFactory.deserialize(
                    data=datum,
                    processor_factory=processor_factory,
                    random_seed=random_seed,
                    **data_kwargs,
                )
            )

        self.preprocessors = preprocessors
        # Data-dependent preprocessors get the datasets before any row is built.
        for preprocessor in self.preprocessors:
            if isinstance(preprocessor, DataDependentPreprocessor):
                preprocessor.setup_data(
                    datasets=self.datasets,
                    data_args=data_args,
                )

        self.precache: list[Any] | None = (
            list(self.generator(data_samples)) if data_samples else None
        )
        self.epoch = 0

    def __iter__(self) -> Iterator[DataT]:
        """
        Iterate over the rows assigned to the current DataLoader worker.

        Outside a worker process (no worker info) every row is yielded.
        """
        worker_info = torch.utils.data.get_worker_info()
        worker_modulus = worker_info.num_workers if worker_info is not None else 1
        worker_index = worker_info.id if worker_info is not None else 0

        if self.precache:
            for index, item in enumerate(self.precache):
                if (index + worker_index) % worker_modulus == 0:
                    yield item
        else:
            yield from self.generator(
                modulus=worker_modulus, offset=worker_index, epoch=self.epoch
            )

    def set_epoch(self, epoch: int):
        """Record the epoch passed to ``generator`` on the next lazy iteration."""
        self.epoch = epoch

    def generator(
        self,
        max_items: int | None = None,
        modulus: int | None = None,
        offset: int | None = None,
        epoch: int = 0,
    ) -> Iterator[DataT]:
        """
        Generate preprocessed rows from all datasets in lockstep.

        :param max_items: Stop once this many rows have been counted; ``None``
            runs until a dataset is exhausted.
        :param modulus: With ``offset``, keep only rows whose 1-based count
            satisfies ``count % modulus == offset``.
        :param offset: See ``modulus``.
        :param epoch: Passed to ``set_epoch`` of every dataset that has one.
        :raises ValueError: If ``max_items`` is given and the datasets run out
            before that many rows were counted.
        """
        gen_count = 0

        with contextlib.suppress(StopIteration):
            dataset_iters = []
            for dataset in self.datasets:
                if hasattr(dataset, "set_epoch"):
                    # Best effort: a failing set_epoch must not block iteration.
                    with contextlib.suppress(Exception):
                        dataset.set_epoch(epoch)
                dataset_iters.append(iter(dataset))

            while max_items is None or gen_count < max_items:
                # Tracks whether this pass already incremented gen_count, so a
                # failure while fetching the row does not decrement a count
                # that was never added.
                row_counted = False
                try:
                    row: dict[str, Any] = {
                        "items": [next(dataset_iter) for dataset_iter in dataset_iters]
                    }
                    gen_count += 1
                    row_counted = True

                    if (
                        modulus is not None
                        and offset is not None
                        and (gen_count % modulus) != offset
                    ):
                        continue

                    for preprocessor in self.preprocessors:
                        # This can assign a GenerationRequest, which would then be
                        # passed into the preprocessor, which is a type violation.
                        # This should be fixed at some point.
                        row = preprocessor(row)  # type: ignore[assignment]
                    yield row  # type: ignore[misc]
                except StopIteration:
                    raise  # Stop iteration when any dataset is exhausted
                except Exception as err:  # noqa: BLE001 # Exception logged
                    logger.error(f"Skipping data row due to error: {err}")
                    if row_counted:
                        # The failed row was counted above; take it back out.
                        gen_count -= 1

        if max_items is not None and gen_count < max_items:
            raise ValueError(
                f"Requested {max_items} samples, but only {gen_count} "
                "available from the provided datasets."
            )
133
+
134
+
135
class DataLoader(PyTorchDataLoader[DataT], InfoMixin):
    """
    PyTorch data loader over a :class:`DatasetsIterator` with batch size 1.

    :param data: Dataset sources handed to ``DatasetsIterator``.
    :param data_args: Per-source kwargs handed to ``DatasetsIterator``.
    :param data_samples: Precache size handed to ``DatasetsIterator``.
    :param processor_factory: Tokenizer factory handed to ``DatasetsIterator``.
    :param preprocessors: Row preprocessors handed to ``DatasetsIterator``.
    :param collator: Used as the loader's ``collate_fn``.
    :param sampler: ``"shuffle"`` sets ``shuffle=True``; any other value is
        passed through as ``sampler``.
        NOTE(review): PyTorch documents restrictions on ``shuffle``/``sampler``
        for iterable-style datasets — confirm these options are usable here.
    :param num_workers: Worker count; ``None`` or 0 becomes 0.
    :param random_seed: Seed handed to ``DatasetsIterator``.
    :param kwargs: Extra keyword arguments for the PyTorch ``DataLoader``.
    """

    def __init__(
        self,
        data: list[Any],
        data_args: list[dict[str, Any]] | None,
        data_samples: int,
        processor_factory: Callable[[], PreTrainedTokenizerBase],
        preprocessors: list[DatasetPreprocessor | DataDependentPreprocessor],
        collator: Callable,
        sampler: Sampler[int] | Literal["shuffle"] | None = None,
        num_workers: int | None = 1,
        random_seed: int = 42,
        **kwargs: Any,
    ):
        iterator: DatasetsIterator[DataT] = DatasetsIterator(
            data=data,
            data_args=data_args,
            data_samples=data_samples,
            processor_factory=processor_factory,
            preprocessors=preprocessors,
            random_seed=random_seed,
        )

        # Snapshot of the construction arguments, exposed through ``info``.
        preprocessor_names = [type(item).__name__ for item in preprocessors]
        self._info: dict[str, Any] = {
            "data": str(data),
            "data_args": str(data_args),
            "data_samples": data_samples,
            "preprocessors": preprocessor_names,
            "collator": type(collator).__name__,
            "sampler": str(sampler),
            "num_workers": num_workers,
            "random_seed": random_seed,
        }
        self.epoch = 0

        shuffle_requested = sampler == "shuffle"
        explicit_sampler = sampler if sampler != "shuffle" else None
        worker_count = num_workers or 0

        super().__init__(
            dataset=iterator,
            batch_size=1,
            shuffle=shuffle_requested,
            sampler=explicit_sampler,
            collate_fn=collator,
            num_workers=worker_count,
            **kwargs,
        )

    def __iter__(self):
        """Hand the current epoch to the dataset, advance it, then iterate."""
        dataset = self.dataset
        if isinstance(dataset, DatasetsIterator):
            dataset.set_epoch(self.epoch)
        self.epoch += 1

        return super().__iter__()

    @property
    def info(self) -> dict[str, Any]:
        """The dictionary of construction arguments recorded in ``__init__``."""
        return self._info
@@ -0,0 +1,27 @@
1
+ from .formatters import (
2
+ GenerativeAudioTranscriptionRequestFormatter,
3
+ GenerativeAudioTranslationRequestFormatter,
4
+ GenerativeChatCompletionsRequestFormatter,
5
+ GenerativeTextCompletionsRequestFormatter,
6
+ RequestFormatter,
7
+ )
8
+ from .mappers import GenerativeColumnMapper
9
+ from .preprocessor import (
10
+ DataDependentPreprocessor,
11
+ DatasetPreprocessor,
12
+ PreprocessorRegistry,
13
+ )
14
+
15
+ __all__ = [
16
+ "ColumnMapper",
17
+ "ColumnMapperRegistry",
18
+ "DataDependentPreprocessor",
19
+ "DatasetPreprocessor",
20
+ "GenerativeAudioTranscriptionRequestFormatter",
21
+ "GenerativeAudioTranslationRequestFormatter",
22
+ "GenerativeChatCompletionsRequestFormatter",
23
+ "GenerativeColumnMapper",
24
+ "GenerativeTextCompletionsRequestFormatter",
25
+ "PreprocessorRegistry",
26
+ "RequestFormatter",
27
+ ]