guidellm 0.3.1__py3-none-any.whl → 0.6.0a5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (141)
  1. guidellm/__init__.py +5 -2
  2. guidellm/__main__.py +524 -255
  3. guidellm/backends/__init__.py +33 -0
  4. guidellm/backends/backend.py +109 -0
  5. guidellm/backends/openai.py +340 -0
  6. guidellm/backends/response_handlers.py +428 -0
  7. guidellm/benchmark/__init__.py +69 -39
  8. guidellm/benchmark/benchmarker.py +160 -316
  9. guidellm/benchmark/entrypoints.py +560 -127
  10. guidellm/benchmark/outputs/__init__.py +24 -0
  11. guidellm/benchmark/outputs/console.py +633 -0
  12. guidellm/benchmark/outputs/csv.py +721 -0
  13. guidellm/benchmark/outputs/html.py +473 -0
  14. guidellm/benchmark/outputs/output.py +169 -0
  15. guidellm/benchmark/outputs/serialized.py +69 -0
  16. guidellm/benchmark/profiles.py +718 -0
  17. guidellm/benchmark/progress.py +553 -556
  18. guidellm/benchmark/scenarios/__init__.py +40 -0
  19. guidellm/benchmark/scenarios/chat.json +6 -0
  20. guidellm/benchmark/scenarios/rag.json +6 -0
  21. guidellm/benchmark/schemas/__init__.py +66 -0
  22. guidellm/benchmark/schemas/base.py +402 -0
  23. guidellm/benchmark/schemas/generative/__init__.py +55 -0
  24. guidellm/benchmark/schemas/generative/accumulator.py +841 -0
  25. guidellm/benchmark/schemas/generative/benchmark.py +163 -0
  26. guidellm/benchmark/schemas/generative/entrypoints.py +381 -0
  27. guidellm/benchmark/schemas/generative/metrics.py +927 -0
  28. guidellm/benchmark/schemas/generative/report.py +158 -0
  29. guidellm/data/__init__.py +34 -4
  30. guidellm/data/builders.py +541 -0
  31. guidellm/data/collators.py +16 -0
  32. guidellm/data/config.py +120 -0
  33. guidellm/data/deserializers/__init__.py +49 -0
  34. guidellm/data/deserializers/deserializer.py +141 -0
  35. guidellm/data/deserializers/file.py +223 -0
  36. guidellm/data/deserializers/huggingface.py +94 -0
  37. guidellm/data/deserializers/memory.py +194 -0
  38. guidellm/data/deserializers/synthetic.py +246 -0
  39. guidellm/data/entrypoints.py +52 -0
  40. guidellm/data/loaders.py +190 -0
  41. guidellm/data/preprocessors/__init__.py +27 -0
  42. guidellm/data/preprocessors/formatters.py +410 -0
  43. guidellm/data/preprocessors/mappers.py +196 -0
  44. guidellm/data/preprocessors/preprocessor.py +30 -0
  45. guidellm/data/processor.py +29 -0
  46. guidellm/data/schemas.py +175 -0
  47. guidellm/data/utils/__init__.py +6 -0
  48. guidellm/data/utils/dataset.py +94 -0
  49. guidellm/extras/__init__.py +4 -0
  50. guidellm/extras/audio.py +220 -0
  51. guidellm/extras/vision.py +242 -0
  52. guidellm/logger.py +2 -2
  53. guidellm/mock_server/__init__.py +8 -0
  54. guidellm/mock_server/config.py +84 -0
  55. guidellm/mock_server/handlers/__init__.py +17 -0
  56. guidellm/mock_server/handlers/chat_completions.py +280 -0
  57. guidellm/mock_server/handlers/completions.py +280 -0
  58. guidellm/mock_server/handlers/tokenizer.py +142 -0
  59. guidellm/mock_server/models.py +510 -0
  60. guidellm/mock_server/server.py +238 -0
  61. guidellm/mock_server/utils.py +302 -0
  62. guidellm/scheduler/__init__.py +69 -26
  63. guidellm/scheduler/constraints/__init__.py +49 -0
  64. guidellm/scheduler/constraints/constraint.py +325 -0
  65. guidellm/scheduler/constraints/error.py +411 -0
  66. guidellm/scheduler/constraints/factory.py +182 -0
  67. guidellm/scheduler/constraints/request.py +312 -0
  68. guidellm/scheduler/constraints/saturation.py +722 -0
  69. guidellm/scheduler/environments.py +252 -0
  70. guidellm/scheduler/scheduler.py +137 -368
  71. guidellm/scheduler/schemas.py +358 -0
  72. guidellm/scheduler/strategies.py +617 -0
  73. guidellm/scheduler/worker.py +413 -419
  74. guidellm/scheduler/worker_group.py +712 -0
  75. guidellm/schemas/__init__.py +65 -0
  76. guidellm/schemas/base.py +417 -0
  77. guidellm/schemas/info.py +188 -0
  78. guidellm/schemas/request.py +235 -0
  79. guidellm/schemas/request_stats.py +349 -0
  80. guidellm/schemas/response.py +124 -0
  81. guidellm/schemas/statistics.py +1018 -0
  82. guidellm/{config.py → settings.py} +31 -24
  83. guidellm/utils/__init__.py +71 -8
  84. guidellm/utils/auto_importer.py +98 -0
  85. guidellm/utils/cli.py +132 -5
  86. guidellm/utils/console.py +566 -0
  87. guidellm/utils/encoding.py +778 -0
  88. guidellm/utils/functions.py +159 -0
  89. guidellm/utils/hf_datasets.py +1 -2
  90. guidellm/utils/hf_transformers.py +4 -4
  91. guidellm/utils/imports.py +9 -0
  92. guidellm/utils/messaging.py +1118 -0
  93. guidellm/utils/mixins.py +115 -0
  94. guidellm/utils/random.py +3 -4
  95. guidellm/utils/registry.py +220 -0
  96. guidellm/utils/singleton.py +133 -0
  97. guidellm/utils/synchronous.py +159 -0
  98. guidellm/utils/text.py +163 -50
  99. guidellm/utils/typing.py +41 -0
  100. guidellm/version.py +2 -2
  101. guidellm-0.6.0a5.dist-info/METADATA +364 -0
  102. guidellm-0.6.0a5.dist-info/RECORD +109 -0
  103. guidellm/backend/__init__.py +0 -23
  104. guidellm/backend/backend.py +0 -259
  105. guidellm/backend/openai.py +0 -708
  106. guidellm/backend/response.py +0 -136
  107. guidellm/benchmark/aggregator.py +0 -760
  108. guidellm/benchmark/benchmark.py +0 -837
  109. guidellm/benchmark/output.py +0 -997
  110. guidellm/benchmark/profile.py +0 -409
  111. guidellm/benchmark/scenario.py +0 -104
  112. guidellm/data/prideandprejudice.txt.gz +0 -0
  113. guidellm/dataset/__init__.py +0 -22
  114. guidellm/dataset/creator.py +0 -213
  115. guidellm/dataset/entrypoints.py +0 -42
  116. guidellm/dataset/file.py +0 -92
  117. guidellm/dataset/hf_datasets.py +0 -62
  118. guidellm/dataset/in_memory.py +0 -132
  119. guidellm/dataset/synthetic.py +0 -287
  120. guidellm/objects/__init__.py +0 -18
  121. guidellm/objects/pydantic.py +0 -89
  122. guidellm/objects/statistics.py +0 -953
  123. guidellm/preprocess/__init__.py +0 -3
  124. guidellm/preprocess/dataset.py +0 -374
  125. guidellm/presentation/__init__.py +0 -28
  126. guidellm/presentation/builder.py +0 -27
  127. guidellm/presentation/data_models.py +0 -232
  128. guidellm/presentation/injector.py +0 -66
  129. guidellm/request/__init__.py +0 -18
  130. guidellm/request/loader.py +0 -284
  131. guidellm/request/request.py +0 -79
  132. guidellm/request/types.py +0 -10
  133. guidellm/scheduler/queues.py +0 -25
  134. guidellm/scheduler/result.py +0 -155
  135. guidellm/scheduler/strategy.py +0 -495
  136. guidellm-0.3.1.dist-info/METADATA +0 -329
  137. guidellm-0.3.1.dist-info/RECORD +0 -62
  138. {guidellm-0.3.1.dist-info → guidellm-0.6.0a5.dist-info}/WHEEL +0 -0
  139. {guidellm-0.3.1.dist-info → guidellm-0.6.0a5.dist-info}/entry_points.txt +0 -0
  140. {guidellm-0.3.1.dist-info → guidellm-0.6.0a5.dist-info}/licenses/LICENSE +0 -0
  141. {guidellm-0.3.1.dist-info → guidellm-0.6.0a5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,175 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Literal
4
+
5
+ from pydantic import ConfigDict, Field, model_validator
6
+
7
+ from guidellm.schemas import StandardBaseModel
8
+
9
# Public API of this module. PreprocessDatasetConfig is a public config model
# defined below alongside the others, so it is exported with them.
__all__ = [
    "DataConfig",
    "DataNotSupportedError",
    "GenerativeDatasetColumnType",
    "PreprocessDatasetConfig",
    "SyntheticTextDatasetConfig",
    "SyntheticTextPrefixBucketConfig",
]
16
+
17
+
18
# The fixed set of column-type identifiers recognized for generative datasets.
GenerativeDatasetColumnType = Literal[
    # token-count columns
    "prompt_tokens_count_column",
    "output_tokens_count_column",
    # text columns
    "prefix_column",
    "text_column",
    # media columns
    "image_column",
    "video_column",
    "audio_column",
]
27
+
28
class DataNotSupportedError(Exception):
    """
    Raised when a deserializer or config cannot handle the given data format.
    """
32
+
33
class DataConfig(StandardBaseModel):
    """
    Shared base class for the data package's configuration models.

    Subclasses may be supplied either as key-value pairs or as JSON.
    """
38
+
39
class PreprocessDatasetConfig(DataConfig):
    """
    Token-count targets used when preprocessing a dataset.

    Prompt and output lengths are each described by a required average plus an
    optional standard deviation, minimum, and maximum. ``prefix_tokens_max``
    optionally caps the number of tokens left in prefixes. Every value that is
    provided must be greater than zero.
    """

    # --- prompt token targets ---
    prompt_tokens: int = Field(
        gt=0,
        description="The average number of text tokens retained or added to prompts.",
    )
    prompt_tokens_stdev: int | None = Field(
        default=None,
        gt=0,
        description=(
            "The standard deviation of the number of tokens retained in or "
            "added to prompts."
        ),
    )
    prompt_tokens_min: int | None = Field(
        default=None,
        gt=0,
        description="The minimum number of text tokens retained or added to prompts.",
    )
    prompt_tokens_max: int | None = Field(
        default=None,
        gt=0,
        description="The maximum number of text tokens retained or added to prompts.",
    )

    # --- output token targets ---
    output_tokens: int = Field(
        gt=0,
        description="The average number of text tokens retained or added to outputs.",
    )
    output_tokens_stdev: int | None = Field(
        default=None,
        gt=0,
        description=(
            "The standard deviation of the number of tokens retained or "
            "added to outputs."
        ),
    )
    output_tokens_min: int | None = Field(
        default=None,
        gt=0,
        description="The minimum number of text tokens retained or added to outputs.",
    )
    output_tokens_max: int | None = Field(
        default=None,
        gt=0,
        description="The maximum number of text tokens retained or added to outputs.",
    )

    # --- prefix limits ---
    prefix_tokens_max: int | None = Field(
        default=None,
        gt=0,
        description="The maximum number of text tokens left in the prefixes.",
    )
86
+
87
class SyntheticTextPrefixBucketConfig(StandardBaseModel):
    """
    One bucket of the synthetic prefix distribution.

    A bucket specifies how many unique prefixes to generate, how many tokens
    each prefix has, and the bucket's relative weight among all buckets.
    """

    bucket_weight: int = Field(
        default=100,
        gt=0,
        description="Weight of this bucket in the overall distribution.",
    )
    prefix_count: int = Field(
        default=1,
        ge=1,
        description="The number of unique prefixes to generate for this bucket.",
    )
    prefix_tokens: int = Field(
        default=0,
        ge=0,
        description="The number of prefix tokens per-prompt for this bucket.",
    )
103
+
104
+
105
class SyntheticTextDatasetConfig(DataConfig):
    """
    Configuration for generating a synthetic text dataset.

    Prompt and output lengths are each described by a required average token
    count plus an optional standard deviation, minimum, and maximum.

    Prefixes are configured through ``prefix_buckets``. As a shorthand for a
    single bucket, the extra keys ``prefix_count`` and/or ``prefix_tokens`` may
    be supplied instead; they are converted into ``prefix_buckets`` during
    validation and then removed from the model's extra fields.
    """

    # Extra keys are allowed so the prefix_count / prefix_tokens shorthand can
    # be accepted and translated by check_prefix_options.
    model_config = ConfigDict(
        extra="allow",
    )

    prompt_tokens: int = Field(
        description="The average number of text tokens generated for prompts.",
        gt=0,
    )
    prompt_tokens_stdev: int | None = Field(
        description="The standard deviation of the tokens generated for prompts.",
        gt=0,
        default=None,
    )
    prompt_tokens_min: int | None = Field(
        description="The minimum number of text tokens generated for prompts.",
        gt=0,
        default=None,
    )
    prompt_tokens_max: int | None = Field(
        description="The maximum number of text tokens generated for prompts.",
        gt=0,
        default=None,
    )
    output_tokens: int = Field(
        description="The average number of text tokens generated for outputs.",
        gt=0,
    )
    output_tokens_stdev: int | None = Field(
        description="The standard deviation of the tokens generated for outputs.",
        gt=0,
        default=None,
    )
    output_tokens_min: int | None = Field(
        description="The minimum number of text tokens generated for outputs.",
        gt=0,
        default=None,
    )
    output_tokens_max: int | None = Field(
        description="The maximum number of text tokens generated for outputs.",
        gt=0,
        default=None,
    )

    prefix_buckets: list[SyntheticTextPrefixBucketConfig] | None = Field(
        description="Buckets for the prefix tokens distribution.",
        default=None,
    )

    @model_validator(mode="after")
    def check_prefix_options(self) -> SyntheticTextDatasetConfig:
        """
        Translate the ``prefix_count`` / ``prefix_tokens`` shorthand into
        ``prefix_buckets``.

        The shorthand keys are removed from the extra fields before
        ``prefix_buckets`` is assigned. This lets a dumped config
        (``model_dump()``) be validated again without tripping the
        mutual-exclusion check. A falsy ``prefix_count`` becomes 1 and a falsy
        ``prefix_tokens`` becomes 0.

        :return: This validated config instance.
        :raises ValueError: If the shorthand keys are combined with a non-empty
            ``prefix_buckets``.
        """
        extras = self.__pydantic_extra__
        if extras is None:
            return self

        # Consume the shorthand so it is not carried alongside prefix_buckets.
        prefix_count = extras.pop("prefix_count", None)
        prefix_tokens = extras.pop("prefix_tokens", None)
        if prefix_count is None and prefix_tokens is None:
            return self

        if self.prefix_buckets:
            raise ValueError(
                "prefix_buckets is mutually exclusive"
                " with prefix_count and prefix_tokens"
            )

        self.prefix_buckets = [
            SyntheticTextPrefixBucketConfig(
                prefix_count=prefix_count or 1,
                prefix_tokens=prefix_tokens or 0,
            )
        ]
        return self
@@ -0,0 +1,6 @@
1
+ from .dataset import DEFAULT_SPLITS, resolve_dataset_split
2
+
3
# Names re-exported from the dataset utilities submodule.
__all__ = ["DEFAULT_SPLITS", "resolve_dataset_split"]
@@ -0,0 +1,94 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Literal
4
+
5
+ from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
6
+
7
# Public API of this module.
__all__ = [
    "DEFAULT_SPLITS",
    "resolve_dataset_split",
]
8
+
9
+
10
# Conventional split names grouped by purpose. resolve_dataset_split searches
# these groups in insertion order (train, calib, val, test), and each list in
# order, when it must pick a split automatically, so ordering is significant.
DEFAULT_SPLITS: dict[Literal["train", "calib", "val", "test"], list[str]] = {
    "train": [
        "train",
        "training",
        "train_set",
        "training_set",
        "train_dataset",
        "training_dataset",
        "train_data",
        "training_data",
        "pretrain",
        "pretrain_set",
        "pretrain_dataset",
        "pretrain_data",
        "pretraining",
    ],
    "calib": [
        "calibration",
        "calib",
        "cal",
        "calibration_set",
        "calib_set",
        "cal_set",
        "calibration_dataset",
        "calib_dataset",
        # was a duplicate "cal_set"; the *_dataset pattern calls for cal_dataset
        "cal_dataset",
        "calibration_data",
        "calib_data",
        "cal_data",
    ],
    "val": [
        "validation",
        "val",
        "valid",
        "validation_set",
        "val_set",
        "validation_dataset",
        "val_dataset",
        "validation_data",
        "val_data",
        "dev",
        "dev_set",
        "dev_dataset",
        "dev_data",
    ],
    "test": [
        "test",
        "testing",
        "test_set",
        "testing_set",
        "test_dataset",
        "testing_dataset",
        "test_data",
        "testing_data",
        "eval",
        "eval_set",
        "eval_dataset",
        "eval_data",
    ],
}
70
+
71
+
72
def resolve_dataset_split(
    dataset: Dataset | IterableDataset | DatasetDict | IterableDatasetDict,
    split: str | None = None,
) -> Dataset | IterableDataset:
    """
    Select a single split from a Hugging Face dataset that may contain several.

    :param dataset: A single dataset, or a ``DatasetDict`` /
        ``IterableDatasetDict`` mapping split names to datasets.
    :param split: Name of the split to return. Only valid when ``dataset`` is a
        ``DatasetDict`` / ``IterableDatasetDict``.
    :return: The requested split when ``split`` is given. Otherwise, ``dataset``
        itself if it is already a single ``Dataset`` / ``IterableDataset``.
        Otherwise, the first split whose name appears in ``DEFAULT_SPLITS``
        (groups checked in order). Otherwise, the first split in the dict.
    :raises ValueError: If ``split`` is given but not present, if ``split`` is
        given for a dataset without splits, or if the dict holds no splits.
    """
    if split is not None:
        if not isinstance(dataset, DatasetDict | IterableDatasetDict):
            raise ValueError(
                f"Requested split '{split}' but dataset has no splits: {dataset}."
            )
        if split not in dataset:
            raise ValueError(
                f"Requested split '{split}' not found in dataset: {dataset}."
            )
        return dataset[split]

    if isinstance(dataset, Dataset | IterableDataset):
        return dataset

    for candidates in DEFAULT_SPLITS.values():
        for candidate in candidates:
            if candidate in dataset:
                return dataset[candidate]

    # No conventionally named split: fall back to the first one available,
    # failing with a clear message (instead of IndexError) when there are none.
    first_split = next(iter(dataset), None)
    if first_split is None:
        raise ValueError(f"Dataset contains no splits to select from: {dataset}.")
    return dataset[first_split]
@@ -0,0 +1,4 @@
1
+ """
2
+ Code that depends on optional dependencies.
3
+ Each submodule should be deferred imported.
4
+ """
@@ -0,0 +1,220 @@
1
+ from __future__ import annotations
2
+
3
+ import base64
4
+ from pathlib import Path
5
+ from typing import Any, Literal
6
+
7
+ import httpx
8
+ import numpy as np
9
+ import torch
10
+
11
+ try:
12
+ from torchcodec import AudioSamples
13
+ from torchcodec.decoders import AudioDecoder
14
+ from torchcodec.encoders import AudioEncoder
15
+ except ImportError as e:
16
+ raise ImportError("Please install guidellm[audio] to use audio features") from e
17
+
18
# Public API of the audio helpers.
__all__ = ["encode_audio", "is_url"]
22
+
23
+
24
def is_url(text: Any) -> bool:
    """Return True only for a ``str`` that starts with ``http://`` or ``https://``."""
    if not isinstance(text, str):
        return False
    return text.startswith(("http://", "https://"))
26
+
27
+
28
def _parse_bitrate(bitrate: str) -> int:
    """
    Convert a bitrate string into bits per second.

    Accepts a plain integer (``"64000"``) or an integer with a single ``k``/``K``
    suffix meaning kilobits per second (``"64k"`` -> ``64000``).

    :raises ValueError: If ``bitrate`` is in neither form.
    """
    spec = bitrate.strip().lower()
    try:
        if spec.endswith("k"):
            # Strip exactly one suffix character (rstrip would accept "64kk").
            return int(spec[:-1]) * 1000
        return int(spec)
    except ValueError as err:
        raise ValueError(f"Invalid bitrate: {bitrate!r}") from err


def encode_audio(
    audio: AudioDecoder
    | bytes
    | str
    | Path
    | np.ndarray
    | torch.Tensor
    | dict[str, Any],
    b64encode: bool = False,
    sample_rate: int | None = None,
    file_name: str = "audio.wav",
    encode_sample_rate: int = 16000,
    max_duration: float | None = None,
    mono: bool = True,
    audio_format: str = "mp3",
    bitrate: str = "64k",
) -> dict[
    Literal[
        "type",
        "audio",
        "format",
        "mimetype",
        "audio_samples",
        "audio_seconds",
        "audio_bytes",
        "file_name",
    ],
    str | int | float | bytes | None,
]:
    """
    Decode audio (if necessary) and re-encode it into the specified format.

    :param audio: Input audio. Accepted kinds are an ``AudioDecoder``; encoded
        ``bytes``; a file path or http(s) URL; a numpy array or torch tensor;
        or a dict with ``"data"``/``"url"`` entries.
    :param b64encode: If True, ``"audio"`` is a base64 ``str`` and ``"type"``
        is ``"audio_base64"``. Otherwise ``"audio"`` is raw ``bytes`` and
        ``"type"`` is ``"audio_file"``.
    :param sample_rate: Sample rate of the input. Required for already-decoded
        float tensors/arrays.
    :param file_name: File name reported when ``audio`` is not a path or URL.
    :param encode_sample_rate: Sample rate passed to the encoder.
    :param max_duration: If set, keep at most this many seconds of audio.
    :param mono: If True, request a single output channel.
    :param audio_format: Output format, e.g. ``"mp3"``; compared
        case-insensitively.
    :param bitrate: Bit rate as bits/s (``"64000"``) or kbit/s with a ``k``
        suffix (``"64k"``). It is only applied for mp3.
    :return: A dict describing the encoded audio (see keys in the signature).
    :raises ValueError: If ``bitrate`` cannot be parsed or ``audio`` cannot be
        decoded.
    """
    # Parse the bitrate first so bad input fails before any decoding/download.
    bitrate_val = _parse_bitrate(bitrate)
    format_val = audio_format.lower()

    samples = _decode_audio(audio, sample_rate=sample_rate, max_duration=max_duration)

    encoded_audio = _encode_audio(
        samples=samples,
        resample_rate=encode_sample_rate,
        bitrate=bitrate_val,
        audio_format=format_val,
        mono=mono,
    )

    return {
        "type": "audio_base64" if b64encode else "audio_file",
        "audio": (
            base64.b64encode(encoded_audio).decode("utf-8")
            if b64encode
            else encoded_audio
        ),
        # NOTE(review): for URLs this is the last path component, which may
        # include a query string, and its extension may differ from format_val.
        "file_name": get_file_name(audio)
        if isinstance(audio, str | Path)
        else file_name,
        # Report the normalized format actually used for encoding, matching
        # the mimetype below.
        "format": format_val,
        "mimetype": f"audio/{format_val}",
        # NOTE(review): this is the decoded sample *rate*, not a sample count;
        # confirm what downstream consumers of "audio_samples" expect.
        "audio_samples": samples.sample_rate,
        "audio_seconds": samples.duration_seconds,
        "audio_bytes": len(encoded_audio),
    }
89
+
90
+
91
def _decode_audio(
    audio: AudioDecoder
    | bytes
    | str
    | Path
    | np.ndarray
    | torch.Tensor
    | dict[str, Any],
    sample_rate: int | None = None,
    max_duration: float | None = None,
) -> AudioSamples:
    """
    Decode audio from the supported input types into ``AudioSamples``.

    Supported inputs:

    - ``dict``: uses its ``"data"`` entry, falling back to ``"url"``. An
      optional ``"sample_rate"`` / ``"sampling_rate"`` overrides
      ``sample_rate``. The result is decoded recursively.
    - ``np.ndarray``: converted to a tensor and handled as one.
    - ``AudioDecoder``: samples are read directly from it.
    - floating-point ``torch.Tensor``: treated as already-decoded samples,
      either 1-D (mono) or 2-D channels-first. ``sample_rate`` is required.
    - ``torch.uint8`` tensor or ``bytes``: treated as encoded audio.
    - ``str`` / ``Path``: an http(s) URL to download, or a local file path.

    :param audio: The audio input.
    :param sample_rate: Sample rate of decoded input. For encoded tensor/bytes
        input it is forwarded to ``AudioDecoder``.
    :param max_duration: If set, keep at most this many seconds from the start.
    :return: The decoded samples.
    :raises ValueError: For unsupported inputs, dicts without data/url,
        decoded input without a sample rate, or missing local files.
    """
    if isinstance(audio, dict):
        sample_rate = audio.get("sample_rate", audio.get("sampling_rate", sample_rate))
        # Prefer "data" but fall back to "url" when data is absent or None.
        audio_data = audio.get("data")
        if audio_data is None:
            audio_data = audio.get("url")
        if audio_data is None:
            raise ValueError(
                f"Audio dict must contain either 'data' or 'url' keys, got {audio}"
            )
        return _decode_audio(
            audio=audio_data,
            sample_rate=sample_rate,
            max_duration=max_duration,
        )

    if isinstance(audio, np.ndarray):
        audio = torch.from_numpy(audio)

    # HF datasets return AudioDecoder for audio columns
    if isinstance(audio, AudioDecoder):
        return audio.get_samples_played_in_range(stop_seconds=max_duration)

    if isinstance(audio, torch.Tensor):
        if torch.is_floating_point(audio):
            return _samples_from_float_tensor(audio, sample_rate, max_duration)
        if audio.dtype == torch.uint8:
            return _decode_encoded_audio(audio, sample_rate, max_duration)
        raise ValueError(f"Unsupported audio type: {type(audio)}")

    if isinstance(audio, bytes):
        return _decode_encoded_audio(audio, sample_rate, max_duration)

    if isinstance(audio, str | Path):
        # sample_rate is intentionally not forwarded for path/URL sources.
        return _decode_encoded_audio(_read_audio_source(audio), None, max_duration)

    raise ValueError(f"Unsupported audio type: {type(audio)}")


def _samples_from_float_tensor(
    audio: torch.Tensor,
    sample_rate: int | None,
    max_duration: float | None,
) -> AudioSamples:
    """Wrap already-decoded float samples (1-D mono or 2-D channels-first)."""
    if sample_rate is None:
        raise ValueError("Sample rate must be set for decoded audio")

    if audio.ndim == 1:
        # A 1-D signal is a single channel; add the channel axis so that the
        # time axis is dimension 1, as the rest of this helper expects.
        audio = audio.unsqueeze(0)

    full_duration = audio.shape[1] / sample_rate
    if max_duration is not None:
        data = audio[:, : int(max_duration * sample_rate)]
        duration = min(max_duration, full_duration)
    else:
        data = audio
        duration = full_duration

    return AudioSamples(
        data=data,
        pts_seconds=0.0,
        duration_seconds=duration,
        sample_rate=sample_rate,
    )


def _decode_encoded_audio(
    source: torch.Tensor | bytes,
    sample_rate: int | None,
    max_duration: float | None,
) -> AudioSamples:
    """Decode encoded audio (bytes or a uint8 tensor) with ``AudioDecoder``."""
    decoder = AudioDecoder(source=source, sample_rate=sample_rate)
    return decoder.get_samples_played_in_range(stop_seconds=max_duration)


def _read_audio_source(source: str | Path) -> bytes:
    """Fetch encoded audio bytes from an http(s) URL or a local file path."""
    if is_url(source):
        # Follow redirects: httpx does not by default, and raise_for_status
        # would otherwise reject 3xx responses from redirected URLs.
        response = httpx.get(source, follow_redirects=True)
        response.raise_for_status()
        return response.content

    path = Path(source)
    if not path.exists():
        raise ValueError(f"Audio file does not exist: {source}")
    return path.read_bytes()
194
+
195
+
196
def _encode_audio(
    samples: AudioSamples,
    resample_rate: int | None = None,
    bitrate: int = 64000,
    audio_format: str = "mp3",
    mono: bool = True,
) -> bytes:
    """
    Encode decoded ``samples`` into ``audio_format`` and return the raw bytes.

    :param samples: Decoded audio; its ``data`` and ``sample_rate`` feed the
        encoder.
    :param resample_rate: Forwarded to the encoder as ``sample_rate``.
    :param bitrate: Bit rate in bits/s. It is only forwarded when
        ``audio_format`` is ``"mp3"``; otherwise ``None`` is passed.
    :param audio_format: Output format passed to the encoder.
    :param mono: If True, request one output channel; otherwise pass ``None``.
    :return: The encoded audio as bytes.
    """
    bit_rate_arg = bitrate if audio_format == "mp3" else None
    channels_arg = 1 if mono else None

    encoded = AudioEncoder(
        samples=samples.data,
        sample_rate=samples.sample_rate,
    ).to_tensor(
        format=audio_format,
        bit_rate=bit_rate_arg,
        num_channels=channels_arg,
        sample_rate=resample_rate,
    )
    return encoded.numpy().tobytes()
216
+
217
+
218
def get_file_name(path: Path | str) -> str:
    """Return the final path component (base name) of ``path``."""
    as_path = path if isinstance(path, Path) else Path(path)
    return as_path.name