guidellm 0.1.0__py3-none-any.whl → 0.2.0rc20250418__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of guidellm might be problematic. Click here for more details.

Files changed (69) hide show
  1. guidellm/__init__.py +38 -6
  2. guidellm/__main__.py +294 -0
  3. guidellm/backend/__init__.py +19 -6
  4. guidellm/backend/backend.py +238 -0
  5. guidellm/backend/openai.py +532 -122
  6. guidellm/backend/response.py +132 -0
  7. guidellm/benchmark/__init__.py +73 -0
  8. guidellm/benchmark/aggregator.py +760 -0
  9. guidellm/benchmark/benchmark.py +838 -0
  10. guidellm/benchmark/benchmarker.py +334 -0
  11. guidellm/benchmark/entrypoints.py +141 -0
  12. guidellm/benchmark/output.py +946 -0
  13. guidellm/benchmark/profile.py +409 -0
  14. guidellm/benchmark/progress.py +720 -0
  15. guidellm/config.py +34 -56
  16. guidellm/data/__init__.py +4 -0
  17. guidellm/data/prideandprejudice.txt.gz +0 -0
  18. guidellm/dataset/__init__.py +22 -0
  19. guidellm/dataset/creator.py +213 -0
  20. guidellm/dataset/entrypoints.py +42 -0
  21. guidellm/dataset/file.py +90 -0
  22. guidellm/dataset/hf_datasets.py +62 -0
  23. guidellm/dataset/in_memory.py +132 -0
  24. guidellm/dataset/synthetic.py +262 -0
  25. guidellm/objects/__init__.py +18 -0
  26. guidellm/objects/pydantic.py +60 -0
  27. guidellm/objects/statistics.py +947 -0
  28. guidellm/request/__init__.py +12 -10
  29. guidellm/request/loader.py +281 -0
  30. guidellm/request/request.py +79 -0
  31. guidellm/scheduler/__init__.py +51 -3
  32. guidellm/scheduler/result.py +137 -0
  33. guidellm/scheduler/scheduler.py +382 -0
  34. guidellm/scheduler/strategy.py +493 -0
  35. guidellm/scheduler/types.py +7 -0
  36. guidellm/scheduler/worker.py +511 -0
  37. guidellm/utils/__init__.py +16 -29
  38. guidellm/utils/colors.py +8 -0
  39. guidellm/utils/hf_transformers.py +35 -0
  40. guidellm/utils/random.py +43 -0
  41. guidellm/utils/text.py +118 -357
  42. {guidellm-0.1.0.dist-info → guidellm-0.2.0rc20250418.dist-info}/METADATA +96 -79
  43. guidellm-0.2.0rc20250418.dist-info/RECORD +48 -0
  44. {guidellm-0.1.0.dist-info → guidellm-0.2.0rc20250418.dist-info}/WHEEL +1 -1
  45. guidellm-0.2.0rc20250418.dist-info/entry_points.txt +2 -0
  46. guidellm/backend/base.py +0 -320
  47. guidellm/core/__init__.py +0 -24
  48. guidellm/core/distribution.py +0 -190
  49. guidellm/core/report.py +0 -321
  50. guidellm/core/request.py +0 -44
  51. guidellm/core/result.py +0 -545
  52. guidellm/core/serializable.py +0 -169
  53. guidellm/executor/__init__.py +0 -10
  54. guidellm/executor/base.py +0 -213
  55. guidellm/executor/profile_generator.py +0 -343
  56. guidellm/main.py +0 -336
  57. guidellm/request/base.py +0 -194
  58. guidellm/request/emulated.py +0 -391
  59. guidellm/request/file.py +0 -76
  60. guidellm/request/transformers.py +0 -100
  61. guidellm/scheduler/base.py +0 -374
  62. guidellm/scheduler/load_generator.py +0 -196
  63. guidellm/utils/injector.py +0 -70
  64. guidellm/utils/progress.py +0 -196
  65. guidellm/utils/transformers.py +0 -151
  66. guidellm-0.1.0.dist-info/RECORD +0 -35
  67. guidellm-0.1.0.dist-info/entry_points.txt +0 -3
  68. {guidellm-0.1.0.dist-info → guidellm-0.2.0rc20250418.dist-info/licenses}/LICENSE +0 -0
  69. {guidellm-0.1.0.dist-info → guidellm-0.2.0rc20250418.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,511 @@
1
+ import asyncio
2
+ import math
3
+ import multiprocessing
4
+ import multiprocessing.queues
5
+ import time
6
+ from abc import ABC, abstractmethod
7
+ from collections.abc import AsyncGenerator
8
+ from dataclasses import dataclass
9
+ from typing import (
10
+ Any,
11
+ Generic,
12
+ Literal,
13
+ Optional,
14
+ Union,
15
+ )
16
+
17
+ from loguru import logger
18
+ from pydantic import Field
19
+
20
+ from guidellm.backend import (
21
+ Backend,
22
+ BackendType,
23
+ RequestArgs,
24
+ ResponseSummary,
25
+ StreamingTextResponse,
26
+ )
27
+ from guidellm.objects import StandardBaseModel
28
+ from guidellm.request import GenerationRequest
29
+ from guidellm.scheduler.result import SchedulerRequestInfo
30
+ from guidellm.scheduler.types import RequestT, ResponseT
31
+
32
+ __all__ = [
33
+ "WorkerProcessRequest",
34
+ "WorkerProcessResult",
35
+ "ResolveStatus",
36
+ "WorkerDescription",
37
+ "RequestsWorker",
38
+ "GenerativeRequestsWorkerDescription",
39
+ "GenerativeRequestsWorker",
40
+ ]
41
+
42
+
43
@dataclass
class WorkerProcessRequest(Generic[RequestT]):
    """
    Envelope for one request pulled off the requests queue by a worker process.
    It bundles the request payload with the timing values that the process
    loops forward to `resolve_scheduler_request`.
    """

    # Payload that will be handed to the worker's `resolve` method.
    request: RequestT
    # Wall-clock time (compared against time.time()) at which resolution
    # of the request should begin.
    start_time: float
    # Wall-clock deadline for the request; math.inf when there is no timeout.
    timeout_time: float
    # Timestamp reported back as `queued_time` in the scheduler request info.
    queued_time: float
49
+
50
+
51
@dataclass
class WorkerProcessResult(Generic[RequestT, ResponseT]):
    """
    Message a worker process places on the results queue.
    One message is emitted for each stage of a request's life cycle:
    scheduled, started, and completed.
    """

    # Life-cycle stage this message reports.
    type_: Literal["request_scheduled", "request_start", "request_complete"]
    # The request the message refers to.
    request: RequestT
    # Response from `resolve`; None for the scheduled and start messages.
    response: Optional[ResponseT]
    # Timing and status bookkeeping accumulated for the request so far.
    info: SchedulerRequestInfo
57
+
58
+
59
@dataclass
class ResolveStatus:
    """
    Outcome flags and timestamps produced by a worker's `resolve` call.
    The values are copied onto the scheduler request info once the
    request has finished resolving.
    """

    # True once the request was actually dispatched to the backend.
    requested: bool
    # True when a full response was received.
    completed: bool
    # True when resolution failed (including timeouts).
    errored: bool
    # True when resolution was cut short by the timeout.
    canceled: bool

    # Start / end timestamps of the request itself; -1 until populated.
    request_start: float
    request_end: float
68
+
69
+
70
class WorkerDescription(StandardBaseModel):
    """
    Base serializable description of a worker.
    Subclasses override `type_` with their own literal tag and add fields.
    """

    # Tag identifying the kind of worker being described.
    type_: Literal["worker"] = "worker"
72
+
73
+
74
class RequestsWorker(ABC, Generic[RequestT, ResponseT]):
    """
    An abstract base class for a worker that processes requests.
    This class defines the interface for a worker that can resolve requests
    asynchronously or synchronously within the Scheduler class.
    Subclasses must implement the `resolve` method,
    which takes a request directly given from the load generator,
    along with the desired start_time for the request and a timeout_time.
    The `resolve` method should return the response from the backend.
    """

    @property
    @abstractmethod
    def description(self) -> WorkerDescription:
        """
        An abstract property that must be implemented by subclasses.
        This property should return a Serializable class representing the information
        about the worker instance.
        """
        ...

    @abstractmethod
    async def prepare_multiprocessing(self):
        """
        An abstract method that must be implemented by subclasses.
        This is useful for workers that have instance state that can not
        be shared across processes and should be cleared out and re-initialized
        for each new process.
        """
        ...

    @abstractmethod
    async def resolve(
        self,
        request: RequestT,
        timeout_time: float,
    ) -> tuple[ResolveStatus, ResponseT]:
        """
        An abstract method that must be implemented by subclasses.
        This method should handle the resolution of a request through asyncio,
        including any necessary backend processing and response handling.

        :param request: The request to be resolved generated by the load generator.
        :param timeout_time: The timeout time for the request, if there is no timeout
            given, then this will be math.inf.
        :return: A tuple of the resolve status and the response from the worker.
        """
        ...

    async def get_request(
        self, requests_queue: multiprocessing.Queue
    ) -> Optional[WorkerProcessRequest[RequestT]]:
        """
        Fetch the next item from the requests queue without blocking the event loop.

        :param requests_queue: The queue to read the next process request from.
        :return: The next process request; the process loops treat None as
            the signal to stop.
        """
        # Queue.get blocks, so it is run in a thread to keep the loop responsive.
        return await asyncio.to_thread(requests_queue.get)  # type: ignore[attr-defined]

    async def send_result(
        self,
        results_queue: multiprocessing.Queue,
        result: WorkerProcessResult[RequestT, ResponseT],
    ):
        """
        Put a result message on the results queue without blocking the event loop.

        :param results_queue: The queue the result is written to.
        :param result: The result message to send.
        """
        await asyncio.to_thread(results_queue.put, result)  # type: ignore[attr-defined]

    async def resolve_scheduler_request(
        self,
        request: Any,
        queued_time: float,
        dequeued_time: float,
        start_time: float,
        timeout_time: float,
        results_queue: multiprocessing.Queue,
        process_id: int,
    ):
        """
        Resolve one request and report its progress on the results queue.
        Three messages are emitted in turn: "request_scheduled" immediately,
        "request_start" once the targeted start time has been reached, and
        "request_complete" after `resolve` has returned.

        :param request: The request to resolve.
        :param queued_time: The time the request was queued.
        :param dequeued_time: The time the request was taken off the queue.
        :param start_time: The targeted time at which to begin resolving.
        :param timeout_time: The timeout time forwarded to `resolve`.
        :param results_queue: The queue progress messages are written to.
        :param process_id: The id of the worker process handling the request.
        """
        info = SchedulerRequestInfo(
            targeted_start_time=start_time,
            queued_time=queued_time,
            dequeued_time=dequeued_time,
            scheduled_time=time.time(),
            process_id=process_id,
        )
        result: WorkerProcessResult[RequestT, ResponseT] = WorkerProcessResult(
            type_="request_scheduled",
            request=request,
            response=None,
            info=info,
        )
        # Sent in the background so queue I/O never delays the request timing.
        # NOTE(review): the task reference is not retained and asyncio only
        # keeps weak references to tasks -- confirm results cannot be dropped.
        asyncio.create_task(self.send_result(results_queue, result))

        # Hold the request until its targeted start time, if that is in the future.
        if (wait_time := start_time - time.time()) > 0:
            await asyncio.sleep(wait_time)

        info.worker_start = time.time()
        result = WorkerProcessResult(
            type_="request_start",
            request=request,
            response=None,
            info=info,
        )
        asyncio.create_task(self.send_result(results_queue, result))

        status, response = await self.resolve(request, timeout_time)
        info.worker_end = time.time()
        # Mirror the outcome reported by `resolve` onto the request info.
        info.requested = status.requested
        info.completed = status.completed
        info.errored = status.errored
        info.canceled = status.canceled
        info.request_start = status.request_start
        info.request_end = status.request_end
        result = WorkerProcessResult(
            type_="request_complete",
            request=request,
            response=response,
            info=info,
        )
        asyncio.create_task(self.send_result(results_queue, result))

    def process_loop_synchronous(
        self,
        requests_queue: multiprocessing.Queue,
        results_queue: multiprocessing.Queue,
        process_id: int,
    ):
        """
        Run the worker loop, resolving one request at a time.
        Requests are read from the queue until None is received. Any exception
        escaping the loop is logged rather than propagated.

        :param requests_queue: The queue to read process requests from.
        :param results_queue: The queue to write result messages to.
        :param process_id: The id of this worker process.
        """

        async def _process_runner():
            while (
                process_request := await self.get_request(requests_queue)
            ) is not None:
                dequeued_time = time.time()

                await self.resolve_scheduler_request(
                    request=process_request.request,
                    queued_time=process_request.queued_time,
                    dequeued_time=dequeued_time,
                    start_time=process_request.start_time,
                    timeout_time=process_request.timeout_time,
                    results_queue=results_queue,
                    process_id=process_id,
                )

        try:
            asyncio.run(_process_runner())
        except Exception as exc:  # noqa: BLE001
            # loguru has no exc_info/stack_info switches: extra kwargs are fed
            # to str.format on the message, so the traceback is attached via
            # opt() and the values are passed as format arguments instead.
            logger.opt(exception=exc).error(
                "Error in worker process {}: {}", process_id, exc
            )

    def process_loop_asynchronous(
        self,
        requests_queue: multiprocessing.Queue,
        results_queue: multiprocessing.Queue,
        max_concurrency: int,
        process_id: int,
    ):
        """
        Run the worker loop, resolving up to max_concurrency requests at once.
        Requests are read from the queue until None is received. Any exception
        escaping the loop is logged rather than propagated.

        :param requests_queue: The queue to read process requests from.
        :param results_queue: The queue to write result messages to.
        :param max_concurrency: The maximum number of requests in flight.
        :param process_id: The id of this worker process.
        """

        async def _process_runner():
            pending = asyncio.Semaphore(max_concurrency)

            # A semaphore created with 0 permits could never be acquired.
            if pending.locked():
                raise ValueError("Async worker called with max_concurrency < 1")

            def _task_done(_: asyncio.Task):
                # Free the slot regardless of how the task finished.
                pending.release()

            while (
                process_request := await self.get_request(requests_queue)
            ) is not None:
                dequeued_time = time.time()

                # Wait for a free slot before launching another request.
                await pending.acquire()

                task = asyncio.create_task(
                    self.resolve_scheduler_request(
                        request=process_request.request,
                        queued_time=process_request.queued_time,
                        dequeued_time=dequeued_time,
                        start_time=process_request.start_time,
                        timeout_time=process_request.timeout_time,
                        results_queue=results_queue,
                        process_id=process_id,
                    )
                )
                task.add_done_callback(_task_done)
                await asyncio.sleep(0)  # enable start task immediately

        try:
            asyncio.run(_process_runner())
        except Exception as exc:  # noqa: BLE001
            # See process_loop_synchronous: loguru needs opt(exception=...)
            # and format arguments rather than exc_info/stack_info kwargs.
            logger.opt(exception=exc).error(
                "Error in worker process {}: {}", process_id, exc
            )
265
+
266
+
267
class GenerativeRequestsWorkerDescription(WorkerDescription):
    """
    Serializable description of a generative requests worker.
    Extends the base worker description with details of the backend in use.
    """

    # Tag identifying this kind of worker; narrows the base class literal.
    type_: Literal["generative_requests_worker"] = "generative_requests_worker"  # type: ignore[assignment]
    # Kind of backend the worker sends requests to.
    backend_type: BackendType
    # Target reported by the backend.
    backend_target: str
    # Model reported by the backend.
    backend_model: str
    # Additional backend-provided details; empty dict when none are given.
    backend_info: dict[str, Any] = Field(default_factory=dict)
275
+
276
+
277
class GenerativeRequestsWorker(RequestsWorker[GenerationRequest, ResponseSummary]):
    """
    Worker that resolves generation requests through a backend.
    It dispatches each request to the matching backend call, enforces the
    request timeout, and converts whatever was received (or not received)
    into a ResponseSummary.

    :param backend: The backend to use for handling requests.
        This should be an instance of Backend such as an OpenAIHTTPBackend.
    """

    def __init__(self, backend: Backend):
        self.backend = backend

    @property
    def description(self) -> GenerativeRequestsWorkerDescription:
        """
        Describe this worker and the backend it is bound to.

        :return: The description of the worker.
        """
        backend = self.backend
        return GenerativeRequestsWorkerDescription(
            backend_type=backend.type_,
            backend_target=backend.target,
            # the description field is a plain string, so a missing
            # model is recorded as the text "None"
            backend_model=backend.model or "None",
            backend_info=backend.info,
        )

    async def prepare_multiprocessing(self):
        """
        Prepare the worker for use in a new process.
        The work is delegated to the backend's own prepare_multiprocessing.
        """
        await self.backend.prepare_multiprocessing()

    def process_loop_synchronous(
        self,
        requests_queue: multiprocessing.Queue,
        results_queue: multiprocessing.Queue,
        process_id: int,
    ):
        """
        Validate the backend, then run the one-at-a-time worker loop.

        :param requests_queue: The queue to read process requests from.
        :param results_queue: The queue to write result messages to.
        :param process_id: The id of this worker process.
        """
        # validation happens before the loop starts consuming requests
        asyncio.run(self.backend.validate())
        super().process_loop_synchronous(
            requests_queue=requests_queue,
            results_queue=results_queue,
            process_id=process_id,
        )

    def process_loop_asynchronous(
        self,
        requests_queue: multiprocessing.Queue,
        results_queue: multiprocessing.Queue,
        max_concurrency: int,
        process_id: int,
    ):
        """
        Validate the backend, then run the concurrent worker loop.

        :param requests_queue: The queue to read process requests from.
        :param results_queue: The queue to write result messages to.
        :param max_concurrency: The maximum number of requests in flight.
        :param process_id: The id of this worker process.
        """
        # validation happens before the loop starts consuming requests
        asyncio.run(self.backend.validate())
        super().process_loop_asynchronous(
            requests_queue=requests_queue,
            results_queue=results_queue,
            max_concurrency=max_concurrency,
            process_id=process_id,
        )

    async def resolve(
        self,
        request: GenerationRequest,
        timeout_time: float,
    ) -> tuple[ResolveStatus, ResponseSummary]:
        """
        Send a request to the backend and summarise the outcome.
        Failures are never raised to the caller; they are recorded on the
        returned status and in the error field of the returned summary.

        :param request: The request to resolve.
        :param timeout_time: The time to wait for a response before timing out.
            If timeout_time is math.inf, the request will not timeout.
        :return: The resolve status together with a ResponseSummary; on
            failure the summary carries the error message.
        """
        resolve_start_time = time.time()
        latest = None
        failure: Optional[str] = None
        status = ResolveStatus(
            requested=False,
            completed=False,
            errored=False,
            canceled=False,
            request_start=-1,
            request_end=-1,
        )

        try:
            # deadline already behind us: skip the backend call entirely
            if timeout_time < time.time():
                raise asyncio.TimeoutError("The timeout time has already passed.")

            status.requested = True
            request_func, request_kwargs = self._create_request_func_kwargs(request)

            async def _consume():
                # remember the most recent item so a timeout still leaves
                # the latest backend state available to the caller
                nonlocal latest
                async for item in request_func(**request_kwargs):  # type: ignore[operator]
                    latest = item

            await asyncio.wait_for(
                _consume(),
                timeout=timeout_time - time.time() if timeout_time < math.inf else None,
            )

            if not latest:
                raise ValueError(
                    f"No response received for request: {request} "
                    f"and backend: {self.backend}"
                )
            if not isinstance(latest, ResponseSummary):
                raise ValueError(
                    f"Received no ResponseSummary for request: {request} "
                    f"and backend: {self.backend}, received: {latest}"
                )

            status.completed = True
        except asyncio.TimeoutError:
            failure = "TimeoutError: The request timed out before completing."
            status.errored = True
            status.canceled = True
        except Exception as exc:  # noqa: BLE001
            failure = str(exc)
            status.errored = True

        return self._handle_response(
            status=status,
            request=request,
            response=latest,
            error=failure,
            resolve_start_time=resolve_start_time,
        )

    def _create_request_func_kwargs(
        self,
        request: GenerationRequest,
    ) -> tuple[
        AsyncGenerator[Union[StreamingTextResponse, ResponseSummary], None],
        dict[str, Any],
    ]:
        """
        Pick the backend call for a request and build its keyword arguments.

        :param request: The request to dispatch.
        :return: The backend callable and the keyword arguments to call it with.
        :raises ValueError: If the request type is not supported.
        """
        request_func: AsyncGenerator[
            Union[StreamingTextResponse, ResponseSummary], None
        ]
        request_kwargs: dict[str, Any]

        # the two supported request types differ only in the backend call
        # and in the keyword under which the content is passed
        kind = request.request_type
        if kind == "text_completions":
            request_func = self.backend.text_completions  # type: ignore[assignment]
            content_key = "prompt"
        elif kind == "chat_completions":
            request_func = self.backend.chat_completions  # type: ignore[assignment]
            content_key = "content"
        else:
            raise ValueError(
                f"Invalid request type: {request.request_type} for {request}"
            )

        # request.params is unpacked last so it can override the defaults
        request_kwargs = {
            content_key: request.content,
            "request_id": request.request_id,
            "prompt_token_count": request.stats.get("prompt_tokens", None),
            "output_token_count": request.constraints.get("output_tokens", None),
            **request.params,
        }

        return request_func, request_kwargs

    def _handle_response(
        self,
        status: ResolveStatus,
        request: GenerationRequest,
        response: Any,
        error: Optional[str],
        resolve_start_time: float,
    ) -> tuple[ResolveStatus, ResponseSummary]:
        """
        Normalise whatever the backend produced into a ResponseSummary.
        A partial streaming response is converted into a summary, a missing
        or unrecognised response is replaced by an empty summary, and a
        ResponseSummary is used as received. The error is then stored on the
        summary and the request start/end times are copied onto the status.

        :param status: The status to update with request start and end times.
        :param request: The request that was resolved.
        :param response: The last item received from the backend, if any.
        :param error: The error message recorded while resolving, if any.
        :param resolve_start_time: The time at which resolving began.
        :return: The updated status and the resulting ResponseSummary.
        """
        if isinstance(response, StreamingTextResponse):
            # only a partial stream arrived; summarise what was received
            partial = response
            response = ResponseSummary(
                value=partial.value,
                request_args=RequestArgs(
                    target=self.backend.target,
                    headers={},
                    payload={},
                ),
                start_time=partial.start_time,
                end_time=time.time(),
                first_iter_time=partial.first_iter_time,
                last_iter_time=partial.time if partial.iter_count > 0 else None,
                request_prompt_tokens=request.stats.get("prompt_tokens", None),
                request_output_tokens=request.constraints.get("output_tokens", None),
                response_prompt_tokens=None,
                response_output_tokens=partial.iter_count,
                request_id=request.request_id,
                error=error or "Unknown error",
            )
        elif not isinstance(response, ResponseSummary):
            # nothing received or invalid response, fill in defaults for error
            if response:
                error = (
                    f"Invalid response: {type(response)} for request: {request}; "
                    + (error or "")
                )

            response = ResponseSummary(
                value="",
                request_args=RequestArgs(
                    target=self.backend.target,
                    headers={},
                    payload={},
                ),
                start_time=resolve_start_time,
                end_time=status.request_end,
                first_iter_time=None,
                last_iter_time=None,
                request_id=request.request_id,
                error=error or "Unknown error",
            )

        response.error = error
        status.request_start = response.start_time
        status.request_end = response.end_time

        return status, response
@@ -1,40 +1,27 @@
1
- from .injector import create_report, inject_data
2
- from .progress import BenchmarkReportProgress
1
+ from .colors import Colors
2
+ from .hf_transformers import (
3
+ check_load_processor,
4
+ )
5
+ from .random import IntegerRangeSampler
3
6
  from .text import (
7
+ EndlessTextCreator,
4
8
  clean_text,
5
9
  filter_text,
6
- is_path,
7
- is_path_like,
8
- is_url,
10
+ is_puncutation,
9
11
  load_text,
10
- load_text_lines,
11
- parse_text_objects,
12
- split_lines_by_punctuation,
13
12
  split_text,
14
- )
15
- from .transformers import (
16
- load_transformers_dataset,
17
- resolve_transformers_dataset,
18
- resolve_transformers_dataset_column,
19
- resolve_transformers_dataset_split,
13
+ split_text_list_by_length,
20
14
  )
21
15
 
22
16
  __all__ = [
23
- "BenchmarkReportProgress",
24
- "clean_text",
25
- "create_report",
17
+ "IntegerRangeSampler",
18
+ "Colors",
19
+ "check_load_processor",
26
20
  "filter_text",
27
- "inject_data",
28
- "is_path",
29
- "is_path_like",
30
- "is_url",
31
- "load_text",
32
- "load_text_lines",
33
- "load_transformers_dataset",
34
- "parse_text_objects",
35
- "resolve_transformers_dataset",
36
- "resolve_transformers_dataset_column",
37
- "resolve_transformers_dataset_split",
38
- "split_lines_by_punctuation",
21
+ "clean_text",
39
22
  "split_text",
23
+ "load_text",
24
+ "is_puncutation",
25
+ "EndlessTextCreator",
26
+ "split_text_list_by_length",
40
27
  ]
@@ -0,0 +1,8 @@
1
+ __all__ = ["Colors"]
2
+
3
+
4
class Colors:
    """Named colour strings, one per message category."""

    # informational messages
    INFO: str = "light_steel_blue"
    # progress updates
    PROGRESS: str = "dark_slate_gray1"
    # success messages
    SUCCESS: str = "chartreuse1"
    # error messages
    ERROR: str = "orange_red1"
@@ -0,0 +1,35 @@
1
+ from pathlib import Path
2
+ from typing import Any, Optional, Union
3
+
4
+ from transformers import AutoTokenizer, PreTrainedTokenizerBase # type: ignore[import]
5
+
6
+ __all__ = [
7
+ "check_load_processor",
8
+ ]
9
+
10
+
11
def check_load_processor(
    processor: Optional[Union[str, Path, PreTrainedTokenizerBase]],
    processor_args: Optional[dict[str, Any]],
    error_msg: str,
) -> PreTrainedTokenizerBase:
    """
    Return a tokenizer instance for the given processor specification.
    A string or Path is loaded through AutoTokenizer.from_pretrained; any
    other value is used as given and then type-checked.

    :param processor: A name or path to load, or an already built tokenizer.
    :param processor_args: Extra keyword arguments for from_pretrained;
        None is treated as no extra arguments.
    :param error_msg: Text describing the caller, appended to error messages.
    :return: The resulting tokenizer.
    :raises ValueError: If no processor is given, loading fails, or the
        result is not a PreTrainedTokenizerBase.
    """
    if processor is None:
        raise ValueError(f"Processor/Tokenizer is required for {error_msg}.")

    tokenizer = processor
    if isinstance(processor, (str, Path)):
        # only the load itself can fail, so only it is guarded
        try:
            tokenizer = AutoTokenizer.from_pretrained(
                processor,
                **(processor_args or {}),
            )
        except Exception as err:
            raise ValueError(
                f"Failed to load processor/Tokenizer for {error_msg}."
            ) from err

    if isinstance(tokenizer, PreTrainedTokenizerBase):
        return tokenizer

    raise ValueError(f"Invalid processor/Tokenizer for {error_msg}.")
@@ -0,0 +1,43 @@
1
+ import random
2
+ from collections.abc import Iterator
3
+ from typing import Optional
4
+
5
+ __all__ = ["IntegerRangeSampler"]
6
+
7
+
8
class IntegerRangeSampler:
    """
    Endless, seeded source of integers centred on an average value.
    Iterating yields values that always lie within the inclusive range
    [minimum, maximum].

    :param average: The centre of the distribution.
    :param variance: The spread of the distribution. It is passed to
        random.gauss as sigma, i.e. it is used as a standard deviation.
        None or 0 disables the gaussian sampling.
    :param min_value: The inclusive lower bound. When None it defaults to
        average - 5 * variance (never below 1), or to average when there
        is no variance.
    :param max_value: The inclusive upper bound. When None it defaults to
        average + 5 * variance, or to average when there is no variance.
    :param random_seed: The seed for the sampler's private random generator.
    """

    def __init__(
        self,
        average: int,
        variance: Optional[int],
        min_value: Optional[int],
        max_value: Optional[int],
        random_seed: int,
    ):
        self.average = average
        self.variance = variance
        self.min_value = min_value
        self.max_value = max_value
        self.seed = random_seed
        self.rng = random.Random(random_seed)  # noqa: S311

    def __iter__(self) -> Iterator[int]:
        """
        Yield integers forever.
        With equal bounds the single value is repeated; without a variance
        values are drawn uniformly from the bounds; otherwise values are
        drawn from a gaussian and clamped to the bounds.

        :return: An infinite iterator of sampled integers.
        :raises ValueError: During iteration, if there is no variance and
            the lower bound is greater than the upper bound.
        """
        # Default bounds reach five "variance" units either side of the average.
        spread = 5 * self.variance if self.variance else 0

        calc_min = self.min_value
        if calc_min is None:
            calc_min = max(1, self.average - spread)
        calc_max = self.max_value
        if calc_max is None:
            calc_max = self.average + spread

        while True:
            if calc_min == calc_max:
                yield calc_min
            elif not self.variance:
                # randint is inclusive on both ends, so the upper bound is
                # passed as-is; adding 1 would allow values above calc_max.
                yield self.rng.randint(calc_min, calc_max)
            else:
                rand = self.rng.gauss(self.average, self.variance)
                yield round(max(calc_min, min(calc_max, rand)))