guidellm 0.3.1__py3-none-any.whl → 0.6.0a5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (141) hide show
  1. guidellm/__init__.py +5 -2
  2. guidellm/__main__.py +524 -255
  3. guidellm/backends/__init__.py +33 -0
  4. guidellm/backends/backend.py +109 -0
  5. guidellm/backends/openai.py +340 -0
  6. guidellm/backends/response_handlers.py +428 -0
  7. guidellm/benchmark/__init__.py +69 -39
  8. guidellm/benchmark/benchmarker.py +160 -316
  9. guidellm/benchmark/entrypoints.py +560 -127
  10. guidellm/benchmark/outputs/__init__.py +24 -0
  11. guidellm/benchmark/outputs/console.py +633 -0
  12. guidellm/benchmark/outputs/csv.py +721 -0
  13. guidellm/benchmark/outputs/html.py +473 -0
  14. guidellm/benchmark/outputs/output.py +169 -0
  15. guidellm/benchmark/outputs/serialized.py +69 -0
  16. guidellm/benchmark/profiles.py +718 -0
  17. guidellm/benchmark/progress.py +553 -556
  18. guidellm/benchmark/scenarios/__init__.py +40 -0
  19. guidellm/benchmark/scenarios/chat.json +6 -0
  20. guidellm/benchmark/scenarios/rag.json +6 -0
  21. guidellm/benchmark/schemas/__init__.py +66 -0
  22. guidellm/benchmark/schemas/base.py +402 -0
  23. guidellm/benchmark/schemas/generative/__init__.py +55 -0
  24. guidellm/benchmark/schemas/generative/accumulator.py +841 -0
  25. guidellm/benchmark/schemas/generative/benchmark.py +163 -0
  26. guidellm/benchmark/schemas/generative/entrypoints.py +381 -0
  27. guidellm/benchmark/schemas/generative/metrics.py +927 -0
  28. guidellm/benchmark/schemas/generative/report.py +158 -0
  29. guidellm/data/__init__.py +34 -4
  30. guidellm/data/builders.py +541 -0
  31. guidellm/data/collators.py +16 -0
  32. guidellm/data/config.py +120 -0
  33. guidellm/data/deserializers/__init__.py +49 -0
  34. guidellm/data/deserializers/deserializer.py +141 -0
  35. guidellm/data/deserializers/file.py +223 -0
  36. guidellm/data/deserializers/huggingface.py +94 -0
  37. guidellm/data/deserializers/memory.py +194 -0
  38. guidellm/data/deserializers/synthetic.py +246 -0
  39. guidellm/data/entrypoints.py +52 -0
  40. guidellm/data/loaders.py +190 -0
  41. guidellm/data/preprocessors/__init__.py +27 -0
  42. guidellm/data/preprocessors/formatters.py +410 -0
  43. guidellm/data/preprocessors/mappers.py +196 -0
  44. guidellm/data/preprocessors/preprocessor.py +30 -0
  45. guidellm/data/processor.py +29 -0
  46. guidellm/data/schemas.py +175 -0
  47. guidellm/data/utils/__init__.py +6 -0
  48. guidellm/data/utils/dataset.py +94 -0
  49. guidellm/extras/__init__.py +4 -0
  50. guidellm/extras/audio.py +220 -0
  51. guidellm/extras/vision.py +242 -0
  52. guidellm/logger.py +2 -2
  53. guidellm/mock_server/__init__.py +8 -0
  54. guidellm/mock_server/config.py +84 -0
  55. guidellm/mock_server/handlers/__init__.py +17 -0
  56. guidellm/mock_server/handlers/chat_completions.py +280 -0
  57. guidellm/mock_server/handlers/completions.py +280 -0
  58. guidellm/mock_server/handlers/tokenizer.py +142 -0
  59. guidellm/mock_server/models.py +510 -0
  60. guidellm/mock_server/server.py +238 -0
  61. guidellm/mock_server/utils.py +302 -0
  62. guidellm/scheduler/__init__.py +69 -26
  63. guidellm/scheduler/constraints/__init__.py +49 -0
  64. guidellm/scheduler/constraints/constraint.py +325 -0
  65. guidellm/scheduler/constraints/error.py +411 -0
  66. guidellm/scheduler/constraints/factory.py +182 -0
  67. guidellm/scheduler/constraints/request.py +312 -0
  68. guidellm/scheduler/constraints/saturation.py +722 -0
  69. guidellm/scheduler/environments.py +252 -0
  70. guidellm/scheduler/scheduler.py +137 -368
  71. guidellm/scheduler/schemas.py +358 -0
  72. guidellm/scheduler/strategies.py +617 -0
  73. guidellm/scheduler/worker.py +413 -419
  74. guidellm/scheduler/worker_group.py +712 -0
  75. guidellm/schemas/__init__.py +65 -0
  76. guidellm/schemas/base.py +417 -0
  77. guidellm/schemas/info.py +188 -0
  78. guidellm/schemas/request.py +235 -0
  79. guidellm/schemas/request_stats.py +349 -0
  80. guidellm/schemas/response.py +124 -0
  81. guidellm/schemas/statistics.py +1018 -0
  82. guidellm/{config.py → settings.py} +31 -24
  83. guidellm/utils/__init__.py +71 -8
  84. guidellm/utils/auto_importer.py +98 -0
  85. guidellm/utils/cli.py +132 -5
  86. guidellm/utils/console.py +566 -0
  87. guidellm/utils/encoding.py +778 -0
  88. guidellm/utils/functions.py +159 -0
  89. guidellm/utils/hf_datasets.py +1 -2
  90. guidellm/utils/hf_transformers.py +4 -4
  91. guidellm/utils/imports.py +9 -0
  92. guidellm/utils/messaging.py +1118 -0
  93. guidellm/utils/mixins.py +115 -0
  94. guidellm/utils/random.py +3 -4
  95. guidellm/utils/registry.py +220 -0
  96. guidellm/utils/singleton.py +133 -0
  97. guidellm/utils/synchronous.py +159 -0
  98. guidellm/utils/text.py +163 -50
  99. guidellm/utils/typing.py +41 -0
  100. guidellm/version.py +2 -2
  101. guidellm-0.6.0a5.dist-info/METADATA +364 -0
  102. guidellm-0.6.0a5.dist-info/RECORD +109 -0
  103. guidellm/backend/__init__.py +0 -23
  104. guidellm/backend/backend.py +0 -259
  105. guidellm/backend/openai.py +0 -708
  106. guidellm/backend/response.py +0 -136
  107. guidellm/benchmark/aggregator.py +0 -760
  108. guidellm/benchmark/benchmark.py +0 -837
  109. guidellm/benchmark/output.py +0 -997
  110. guidellm/benchmark/profile.py +0 -409
  111. guidellm/benchmark/scenario.py +0 -104
  112. guidellm/data/prideandprejudice.txt.gz +0 -0
  113. guidellm/dataset/__init__.py +0 -22
  114. guidellm/dataset/creator.py +0 -213
  115. guidellm/dataset/entrypoints.py +0 -42
  116. guidellm/dataset/file.py +0 -92
  117. guidellm/dataset/hf_datasets.py +0 -62
  118. guidellm/dataset/in_memory.py +0 -132
  119. guidellm/dataset/synthetic.py +0 -287
  120. guidellm/objects/__init__.py +0 -18
  121. guidellm/objects/pydantic.py +0 -89
  122. guidellm/objects/statistics.py +0 -953
  123. guidellm/preprocess/__init__.py +0 -3
  124. guidellm/preprocess/dataset.py +0 -374
  125. guidellm/presentation/__init__.py +0 -28
  126. guidellm/presentation/builder.py +0 -27
  127. guidellm/presentation/data_models.py +0 -232
  128. guidellm/presentation/injector.py +0 -66
  129. guidellm/request/__init__.py +0 -18
  130. guidellm/request/loader.py +0 -284
  131. guidellm/request/request.py +0 -79
  132. guidellm/request/types.py +0 -10
  133. guidellm/scheduler/queues.py +0 -25
  134. guidellm/scheduler/result.py +0 -155
  135. guidellm/scheduler/strategy.py +0 -495
  136. guidellm-0.3.1.dist-info/METADATA +0 -329
  137. guidellm-0.3.1.dist-info/RECORD +0 -62
  138. {guidellm-0.3.1.dist-info → guidellm-0.6.0a5.dist-info}/WHEEL +0 -0
  139. {guidellm-0.3.1.dist-info → guidellm-0.6.0a5.dist-info}/entry_points.txt +0 -0
  140. {guidellm-0.3.1.dist-info → guidellm-0.6.0a5.dist-info}/licenses/LICENSE +0 -0
  141. {guidellm-0.3.1.dist-info → guidellm-0.6.0a5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,411 @@
1
+ """
2
+ Error-based constraint implementations.
3
+
4
+ Provides constraint types for limiting benchmark execution based on error rates
5
+ and error counts. These constraints monitor request error status to determine
6
+ when to stop benchmark execution due to excessive errors.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import time
12
+ from typing import Any, Literal, cast
13
+
14
+ from pydantic import Field, field_validator
15
+
16
+ from guidellm.scheduler.schemas import (
17
+ SchedulerProgress,
18
+ SchedulerState,
19
+ SchedulerUpdateAction,
20
+ )
21
+ from guidellm.schemas import RequestInfo
22
+ from guidellm.settings import settings
23
+
24
+ from .constraint import Constraint, PydanticConstraintInitializer
25
+ from .factory import ConstraintsInitializerFactory
26
+
27
+ __all__ = [
28
+ "MaxErrorRateConstraint",
29
+ "MaxErrorsConstraint",
30
+ "MaxGlobalErrorRateConstraint",
31
+ ]
32
+
33
+
34
@ConstraintsInitializerFactory.register(
    ["max_errors", "max_err", "max_error", "max_errs"]
)
class MaxErrorsConstraint(PydanticConstraintInitializer):
    """
    Constraint that limits execution based on absolute error count.

    Stops both request queuing and all request processing once
    ``state.errored_requests`` reaches the configured threshold. When
    ``max_errors`` is a list, the threshold is picked by ``current_index``
    (advanced on every ``create_constraint`` call) and clamps to the last entry.
    """

    type_: Literal["max_errors"] = "max_errors"  # type: ignore[assignment]
    max_errors: int | float | list[int | float] = Field(
        description="Maximum number of errors allowed before triggering constraint",
    )
    current_index: int = Field(default=-1, description="Current index in error list")

    @classmethod
    def validated_kwargs(
        cls, max_errors: int | float | list[int | float] | None = None, **kwargs
    ) -> dict[str, Any]:
        """
        Validate and process arguments for MaxErrorsConstraint creation.

        :param max_errors: Maximum number of errors to allow
        :param kwargs: Supports the max_err, max_error, and max_errs aliases
            and an optional current_index
        :return: Dictionary with the max_errors and current_index fields
        """
        if max_errors is None:
            # Fall back to the first alias that actually carries a value
            max_errors = next(
                (
                    kwargs[alias]
                    for alias in ("max_errors", "max_err", "max_error", "max_errs")
                    if kwargs.get(alias) is not None
                ),
                None,
            )

        return {
            "max_errors": max_errors,
            "current_index": kwargs.get("current_index", -1),
        }

    def create_constraint(self, **_kwargs) -> Constraint:
        """
        Advance the threshold index and return a copy to use as the constraint.

        :param _kwargs: Additional keyword arguments (unused)
        :return: Copy of this instance taken after incrementing current_index
        """
        self.current_index += 1

        return cast("Constraint", self.model_copy())

    def __call__(
        self, state: SchedulerState, request_info: RequestInfo
    ) -> SchedulerUpdateAction:
        """
        Evaluate constraint against current error count.

        :param state: Current scheduler state with error counts
        :param request_info: Request whose completed_at is used as the stop time
            when the threshold is reached (falls back to the current time)
        :return: Action indicating whether to continue or stop operations
        """
        if isinstance(self.max_errors, int | float):
            max_errors = self.max_errors
        else:
            # current_index is -1 before create_constraint; clamp into the list
            position = min(max(0, self.current_index), len(self.max_errors) - 1)
            max_errors = self.max_errors[position]

        errors_exceeded = state.errored_requests >= max_errors
        stop_time = (
            (request_info.completed_at or time.time()) if errors_exceeded else None
        )

        return SchedulerUpdateAction(
            request_queuing="stop" if errors_exceeded else "continue",
            request_processing="stop_all" if errors_exceeded else "continue",
            metadata={
                "max_errors": max_errors,
                "errors_exceeded": errors_exceeded,
                "current_errors": state.errored_requests,
                "stop_time": stop_time,
            },
            progress=SchedulerProgress(stop_time=stop_time),
        )

    @field_validator("max_errors")
    @classmethod
    def _validate_max_errors(
        cls, value: int | float | list[int | float]
    ) -> int | float | list[int | float]:
        """
        Ensure every threshold is a positive number.

        :param value: Single threshold or list of thresholds
        :return: The scalar for single-entry input, otherwise the list
        :raises ValueError: If the list is empty or any entry is not positive
        """
        values = value if isinstance(value, list) else [value]
        if not values:
            # An empty list would otherwise fail with IndexError in __call__
            raise ValueError(f"max_errors must be set and truthful, received {values}")

        for val in values:
            if not val:
                raise ValueError(
                    "max_errors must be set and truthful, "
                    f"received {values} ({val} failed)"
                )
            if not isinstance(val, int | float) or val <= 0:
                raise ValueError(
                    "max_errors must be a positive num, "
                    f"received {values} ({val} failed)"
                )

        return values[0] if len(values) == 1 else values
138
+
139
+
140
@ConstraintsInitializerFactory.register(
    ["max_error_rate", "max_err_rate", "max_errors_rate"]
)
class MaxErrorRateConstraint(PydanticConstraintInitializer):
    """
    Constraint that limits execution based on sliding window error rate.

    Records whether each finished request errored in a bounded window and stops
    queuing and processing when the windowed error rate reaches the threshold.
    The constraint only triggers once ``state.processed_requests`` has reached
    ``window_size``.
    """

    type_: Literal["max_error_rate"] = "max_error_rate"  # type: ignore[assignment]
    max_error_rate: int | float | list[int | float] = Field(
        description="Maximum error rate allowed (0.0, 1.0)"
    )
    window_size: int | float = Field(
        default=30,
        gt=0,
        description="Size of sliding window for calculating error rate",
    )
    error_window: list[bool] = Field(
        default_factory=list,
        description="Sliding window tracking error status of recent requests",
    )
    current_index: int = Field(
        default=-1, description="Current index in the error window"
    )

    @classmethod
    def validated_kwargs(
        cls, max_error_rate: int | float | list[int | float] | None = None, **kwargs
    ) -> dict[str, Any]:
        """
        Validate and process arguments for MaxErrorRateConstraint creation.

        :param max_error_rate: Maximum error rate to allow; may be omitted when
            supplied through one of the aliases in kwargs
        :param kwargs: Supports the max_err_rate and max_errors_rate aliases and
            optional window_size, error_window, and current_index
        :return: Dictionary with max_error_rate, window_size, error_window,
            and current_index fields
        """
        if max_error_rate is None:
            # Fall back to the first alias that actually carries a value
            max_error_rate = next(
                (
                    kwargs[alias]
                    for alias in ("max_error_rate", "max_err_rate", "max_errors_rate")
                    if kwargs.get(alias) is not None
                ),
                None,
            )

        return {
            "max_error_rate": max_error_rate,
            "window_size": kwargs.get(
                "window_size", settings.constraint_error_window_size
            ),
            "error_window": kwargs.get("error_window", []),
            "current_index": kwargs.get("current_index", -1),
        }

    def create_constraint(self, **_kwargs) -> Constraint:
        """
        Create a new instance of MaxErrorRateConstraint (due to stateful window).

        :param _kwargs: Additional keyword arguments (unused)
        :return: Independent copy of this constraint
        """
        self.current_index += 1

        # Deep copy so the returned constraint owns its error_window list; a
        # shallow copy would share (and cross-contaminate) the same list.
        return cast("Constraint", self.model_copy(deep=True))

    def __call__(
        self, state: SchedulerState, request_info: RequestInfo
    ) -> SchedulerUpdateAction:
        """
        Evaluate constraint against sliding window error rate.

        :param state: Current scheduler state with request counts
        :param request_info: Individual request with completion status
        :return: Action indicating whether to continue or stop operations
        """
        if isinstance(self.max_error_rate, int | float):
            max_error_rate = self.max_error_rate
        else:
            # current_index is -1 before create_constraint; clamp into the list
            position = min(max(0, self.current_index), len(self.max_error_rate) - 1)
            max_error_rate = self.max_error_rate[position]

        # Only finished requests contribute to the window
        if request_info.status in ["completed", "errored", "cancelled"]:
            self.error_window.append(request_info.status == "errored")
            if len(self.error_window) > self.window_size:
                self.error_window.pop(0)

        error_count = sum(self.error_window)
        window_requests = len(self.error_window)
        error_rate = (
            error_count / float(window_requests) if window_requests > 0 else 0.0
        )
        exceeded_min_processed = state.processed_requests >= self.window_size
        exceeded_error_rate = error_rate >= max_error_rate
        exceeded = exceeded_min_processed and exceeded_error_rate
        stop_time = (request_info.completed_at or time.time()) if exceeded else None

        return SchedulerUpdateAction(
            request_queuing="stop" if exceeded else "continue",
            request_processing="stop_all" if exceeded else "continue",
            metadata={
                "max_error_rate": max_error_rate,
                "window_size": self.window_size,
                "error_count": error_count,
                "processed_count": state.processed_requests,
                "current_window_size": len(self.error_window),
                "current_error_rate": error_rate,
                "exceeded_min_processed": exceeded_min_processed,
                "exceeded_error_rate": exceeded_error_rate,
                "exceeded": exceeded,
                "stop_time": stop_time,
            },
            # Report the stop time the same way the other error constraints do
            progress=SchedulerProgress(stop_time=stop_time),
        )

    @field_validator("max_error_rate")
    @classmethod
    def _validate_max_error_rate(
        cls, value: int | float | list[int | float]
    ) -> int | float | list[int | float]:
        """
        Ensure every rate is a number strictly between 0 and 1.

        :param value: Single rate or list of rates
        :return: The scalar for single-entry input, otherwise the list
        :raises ValueError: If the list is empty or any entry is out of range
        """
        values = value if isinstance(value, list) else [value]
        if not values:
            # An empty list would otherwise fail with IndexError in __call__
            raise ValueError(
                f"max_error_rate must be set and truthful, received {values}"
            )

        for val in values:
            if not val:
                raise ValueError(
                    "max_error_rate must be set and truthful, "
                    f"received {values} ({val} failed)"
                )
            if not isinstance(val, int | float) or val <= 0 or val >= 1:
                raise ValueError(
                    "max_error_rate must be a number between 0 and 1, "
                    f"received {values} ({val} failed)"
                )

        return values[0] if len(values) == 1 else values
277
+
278
+
279
@ConstraintsInitializerFactory.register(
    ["max_global_error_rate", "max_global_err_rate", "max_global_errors_rate"]
)
class MaxGlobalErrorRateConstraint(PydanticConstraintInitializer):
    """
    Constraint that limits execution based on global error rate.

    Computes ``state.errored_requests / state.processed_requests`` and stops
    queuing and processing when that rate reaches the threshold. The constraint
    only triggers once ``min_processed`` requests have been processed (or
    immediately when ``min_processed`` is None).
    """

    type_: Literal["max_global_error_rate"] = "max_global_error_rate"  # type: ignore[assignment]
    # Lists are accepted so the list handling in validated_kwargs, the validator
    # and __call__ is reachable; scalar input behaves as before.
    max_error_rate: int | float | list[int | float] = Field(
        description="Maximum error rate allowed (0.0 to 1.0)"
    )
    min_processed: int | float | None = Field(
        default=30,
        gt=0,
        description="Minimum requests processed before applying error rate constraint",
    )
    current_index: int = Field(
        default=-1, description="Current index for list-based max_error_rate values"
    )

    @classmethod
    def validated_kwargs(
        cls, max_error_rate: int | float | list[int | float] | None = None, **kwargs
    ) -> dict[str, Any]:
        """
        Validate and process arguments for MaxGlobalErrorRateConstraint creation.

        :param max_error_rate: Maximum error rate to allow; may be omitted when
            supplied through one of the aliases in kwargs
        :param kwargs: Supports max_global_error_rate, max_global_err_rate,
            max_global_errors_rate, optional min_processed, and optional
            current_index
        :return: Dictionary with max_error_rate, min_processed,
            and current_index fields
        """
        if max_error_rate is None:
            # Fall back to the first alias that actually carries a value
            max_error_rate = next(
                (
                    kwargs[alias]
                    for alias in (
                        "max_global_error_rate",
                        "max_global_err_rate",
                        "max_global_errors_rate",
                    )
                    if kwargs.get(alias) is not None
                ),
                None,
            )

        return {
            "max_error_rate": max_error_rate,
            "min_processed": kwargs.get(
                "min_processed", settings.constraint_error_min_processed
            ),
            "current_index": kwargs.get("current_index", -1),
        }

    def create_constraint(self, **_kwargs) -> Constraint:
        """
        Advance the threshold index and return a copy to use as the constraint.

        :param _kwargs: Additional keyword arguments (unused)
        :return: Copy of this instance taken after incrementing current_index
        """
        self.current_index += 1

        return cast("Constraint", self.model_copy())

    def __call__(
        self, state: SchedulerState, request_info: RequestInfo
    ) -> SchedulerUpdateAction:
        """
        Evaluate constraint against global error rate.

        :param state: Current scheduler state with global request and error counts
        :param request_info: Request whose completed_at is used as the stop time
            when the threshold is reached (falls back to the current time)
        :return: Action indicating whether to continue or stop operations
        """
        if isinstance(self.max_error_rate, int | float):
            max_error_rate = self.max_error_rate
        else:
            # current_index is -1 before create_constraint; clamp into the list
            position = min(max(0, self.current_index), len(self.max_error_rate) - 1)
            max_error_rate = self.max_error_rate[position]

        exceeded_min_processed = (
            self.min_processed is None or state.processed_requests >= self.min_processed
        )
        # Guard the division: no processed requests means no measurable rate
        error_rate = (
            state.errored_requests / float(state.processed_requests)
            if state.processed_requests > 0
            else 0.0
        )
        exceeded_error_rate = error_rate >= max_error_rate
        exceeded = exceeded_min_processed and exceeded_error_rate
        stop_time = (request_info.completed_at or time.time()) if exceeded else None

        return SchedulerUpdateAction(
            request_queuing="stop" if exceeded else "continue",
            request_processing="stop_all" if exceeded else "continue",
            metadata={
                "max_error_rate": max_error_rate,
                "min_processed": self.min_processed,
                "processed_requests": state.processed_requests,
                "errored_requests": state.errored_requests,
                "error_rate": error_rate,
                "exceeded_min_processed": exceeded_min_processed,
                "exceeded_error_rate": exceeded_error_rate,
                "exceeded": exceeded,
                "stop_time": stop_time,
            },
            progress=SchedulerProgress(stop_time=stop_time),
        )

    @field_validator("max_error_rate")
    @classmethod
    def _validate_max_error_rate(
        cls, value: int | float | list[int | float]
    ) -> int | float | list[int | float]:
        """
        Ensure every rate is a number strictly between 0 and 1.

        :param value: Single rate or list of rates
        :return: The scalar for single-entry input, otherwise the list
        :raises ValueError: If the list is empty or any entry is out of range
        """
        values = value if isinstance(value, list) else [value]
        if not values:
            # An empty list would otherwise fail with IndexError in __call__
            raise ValueError(
                f"max_error_rate must be set and truthful, received {values}"
            )

        for val in values:
            if not val:
                raise ValueError(
                    "max_error_rate must be set and truthful, "
                    f"received {values} ({val} failed)"
                )
            if not isinstance(val, int | float) or val <= 0 or val >= 1:
                raise ValueError(
                    "max_error_rate must be a number between 0 and 1, "
                    f"received {values} ({val} failed)"
                )

        return values[0] if len(values) == 1 else values
@@ -0,0 +1,182 @@
1
+ """
2
+ Factory for creating and managing constraint initializers.
3
+
4
+ Provides centralized access to registered constraint types with support for
5
+ creating constraints from configuration dictionaries, simple values, or
6
+ pre-configured instances.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from typing import Any
12
+
13
+ from guidellm.scheduler.constraints.constraint import (
14
+ Constraint,
15
+ ConstraintInitializer,
16
+ SerializableConstraintInitializer,
17
+ UnserializableConstraintInitializer,
18
+ )
19
+ from guidellm.utils import InfoMixin, RegistryMixin
20
+
21
+ __all__ = ["ConstraintsInitializerFactory"]
22
+
23
+
24
class ConstraintsInitializerFactory(RegistryMixin[ConstraintInitializer]):
    """
    Registry factory for creating and managing constraint initializers.

    Provides centralized access to registered constraint types with support for
    creating constraints from configuration dictionaries, simple values, or
    pre-configured instances. Handles constraint resolution and type validation
    for the scheduler constraint system.

    Example:
    ::
        from guidellm.scheduler import ConstraintsInitializerFactory

        # Register new constraint type
        @ConstraintsInitializerFactory.register("new_constraint")
        class NewConstraint:
            def create_constraint(self, **kwargs) -> Constraint:
                return lambda state, request: SchedulerUpdateAction()

        # Create and use constraint
        constraint = ConstraintsInitializerFactory.create_constraint("new_constraint")
    """

    @classmethod
    def create(cls, key: str, *args, **kwargs) -> ConstraintInitializer:
        """
        Create a constraint initializer for the specified key.

        Classes deriving from SerializableConstraintInitializer are built from
        the output of their ``validated_kwargs``; any other registered object is
        called directly with the given arguments.

        :param key: Registered constraint initializer key
        :param args: Positional arguments for initializer creation
        :param kwargs: Keyword arguments for initializer creation
        :return: Configured constraint initializer instance
        :raises ValueError: If the key is not registered in the factory
        """
        if cls.registry is None or key not in cls.registry:
            raise ValueError(f"Unknown constraint initializer key: {key}")

        initializer_class = cls.registry[key]
        is_serializable = isinstance(initializer_class, type) and issubclass(
            initializer_class, SerializableConstraintInitializer
        )

        if is_serializable:
            return initializer_class(
                **initializer_class.validated_kwargs(*args, **kwargs)  # type: ignore[misc]
            )

        return initializer_class(*args, **kwargs)  # type: ignore[operator]

    @classmethod
    def serialize(cls, initializer: ConstraintInitializer) -> dict[str, Any]:
        """
        Serialize constraint initializer to dictionary format.

        :param initializer: Constraint initializer to serialize
        :return: Dictionary representation or unserializable placeholder
        """
        if isinstance(initializer, SerializableConstraintInitializer):
            return initializer.model_dump()

        # Not serializable: keep only descriptive info about the original object
        placeholder = UnserializableConstraintInitializer(
            orig_info=InfoMixin.extract_from_obj(initializer)
        )
        return placeholder.model_dump()

    @classmethod
    def deserialize(
        cls, initializer_dict: dict[str, Any]
    ) -> SerializableConstraintInitializer | UnserializableConstraintInitializer:
        """
        Deserialize constraint initializer from dictionary format.

        :param initializer_dict: Dictionary representation of constraint initializer
        :return: Reconstructed constraint initializer instance
        :raises ValueError: If constraint type is unknown or cannot be deserialized
        """
        type_key = initializer_dict.get("type_")

        if type_key == "unserializable":
            return UnserializableConstraintInitializer.model_validate(initializer_dict)

        if cls.registry is not None and type_key and type_key in cls.registry:
            initializer_class = cls.registry[type_key]
            if hasattr(initializer_class, "model_validate"):
                return initializer_class.model_validate(initializer_dict)  # type: ignore[return-value]
            return initializer_class(**initializer_dict)  # type: ignore[return-value,operator]

        raise ValueError(
            f"Cannot deserialize unknown constraint initializer: "
            f"{initializer_dict.get('type_', 'unknown')}"
        )

    @classmethod
    def create_constraint(cls, key: str, *args, **kwargs) -> Constraint:
        """
        Create a constraint instance for the specified key.

        :param key: Registered constraint initializer key
        :param args: Positional arguments for constraint creation
        :param kwargs: Keyword arguments for constraint creation
        :return: Configured constraint function ready for evaluation
        :raises ValueError: If the key is not registered in the factory
        """
        return cls.create(key, *args, **kwargs).create_constraint()

    @classmethod
    def resolve(
        cls,
        initializers: dict[
            str,
            Any | dict[str, Any] | Constraint | ConstraintInitializer,
        ],
    ) -> dict[str, Constraint]:
        """
        Resolve mixed constraint specifications to callable constraints.

        Values are handled in this order: an existing Constraint is used as is,
        a ConstraintInitializer has ``create_constraint`` called on it, a dict is
        expanded as keyword arguments for the registered initializer, and any
        other value is passed as the single positional argument.

        :param initializers: Dictionary mapping constraint keys to specifications
        :return: Dictionary mapping constraint keys to callable functions
        :raises ValueError: If any key is not registered in the factory
        """
        constraints = {}

        for key, val in initializers.items():
            if isinstance(val, Constraint):
                constraints[key] = val
            elif isinstance(val, ConstraintInitializer):
                constraints[key] = val.create_constraint()
            elif isinstance(val, dict):
                constraints[key] = cls.create_constraint(key, **val)
            else:
                constraints[key] = cls.create_constraint(key, val)

        return constraints

    @classmethod
    def resolve_constraints(
        cls,
        constraints: dict[str, Any | dict[str, Any] | Constraint],
    ) -> dict[str, Constraint]:
        """
        Resolve constraints from mixed constraint specifications.

        Delegates to ``resolve`` so both entry points share one implementation;
        as a result, initializer instances are also accepted here.

        :param constraints: Dictionary mapping constraint keys to specifications
        :return: Dictionary mapping constraint keys to callable functions
        :raises ValueError: If any constraint key is not registered
        """
        return cls.resolve(constraints)