guidellm 0.3.1__py3-none-any.whl → 0.6.0a5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- guidellm/__init__.py +5 -2
- guidellm/__main__.py +524 -255
- guidellm/backends/__init__.py +33 -0
- guidellm/backends/backend.py +109 -0
- guidellm/backends/openai.py +340 -0
- guidellm/backends/response_handlers.py +428 -0
- guidellm/benchmark/__init__.py +69 -39
- guidellm/benchmark/benchmarker.py +160 -316
- guidellm/benchmark/entrypoints.py +560 -127
- guidellm/benchmark/outputs/__init__.py +24 -0
- guidellm/benchmark/outputs/console.py +633 -0
- guidellm/benchmark/outputs/csv.py +721 -0
- guidellm/benchmark/outputs/html.py +473 -0
- guidellm/benchmark/outputs/output.py +169 -0
- guidellm/benchmark/outputs/serialized.py +69 -0
- guidellm/benchmark/profiles.py +718 -0
- guidellm/benchmark/progress.py +553 -556
- guidellm/benchmark/scenarios/__init__.py +40 -0
- guidellm/benchmark/scenarios/chat.json +6 -0
- guidellm/benchmark/scenarios/rag.json +6 -0
- guidellm/benchmark/schemas/__init__.py +66 -0
- guidellm/benchmark/schemas/base.py +402 -0
- guidellm/benchmark/schemas/generative/__init__.py +55 -0
- guidellm/benchmark/schemas/generative/accumulator.py +841 -0
- guidellm/benchmark/schemas/generative/benchmark.py +163 -0
- guidellm/benchmark/schemas/generative/entrypoints.py +381 -0
- guidellm/benchmark/schemas/generative/metrics.py +927 -0
- guidellm/benchmark/schemas/generative/report.py +158 -0
- guidellm/data/__init__.py +34 -4
- guidellm/data/builders.py +541 -0
- guidellm/data/collators.py +16 -0
- guidellm/data/config.py +120 -0
- guidellm/data/deserializers/__init__.py +49 -0
- guidellm/data/deserializers/deserializer.py +141 -0
- guidellm/data/deserializers/file.py +223 -0
- guidellm/data/deserializers/huggingface.py +94 -0
- guidellm/data/deserializers/memory.py +194 -0
- guidellm/data/deserializers/synthetic.py +246 -0
- guidellm/data/entrypoints.py +52 -0
- guidellm/data/loaders.py +190 -0
- guidellm/data/preprocessors/__init__.py +27 -0
- guidellm/data/preprocessors/formatters.py +410 -0
- guidellm/data/preprocessors/mappers.py +196 -0
- guidellm/data/preprocessors/preprocessor.py +30 -0
- guidellm/data/processor.py +29 -0
- guidellm/data/schemas.py +175 -0
- guidellm/data/utils/__init__.py +6 -0
- guidellm/data/utils/dataset.py +94 -0
- guidellm/extras/__init__.py +4 -0
- guidellm/extras/audio.py +220 -0
- guidellm/extras/vision.py +242 -0
- guidellm/logger.py +2 -2
- guidellm/mock_server/__init__.py +8 -0
- guidellm/mock_server/config.py +84 -0
- guidellm/mock_server/handlers/__init__.py +17 -0
- guidellm/mock_server/handlers/chat_completions.py +280 -0
- guidellm/mock_server/handlers/completions.py +280 -0
- guidellm/mock_server/handlers/tokenizer.py +142 -0
- guidellm/mock_server/models.py +510 -0
- guidellm/mock_server/server.py +238 -0
- guidellm/mock_server/utils.py +302 -0
- guidellm/scheduler/__init__.py +69 -26
- guidellm/scheduler/constraints/__init__.py +49 -0
- guidellm/scheduler/constraints/constraint.py +325 -0
- guidellm/scheduler/constraints/error.py +411 -0
- guidellm/scheduler/constraints/factory.py +182 -0
- guidellm/scheduler/constraints/request.py +312 -0
- guidellm/scheduler/constraints/saturation.py +722 -0
- guidellm/scheduler/environments.py +252 -0
- guidellm/scheduler/scheduler.py +137 -368
- guidellm/scheduler/schemas.py +358 -0
- guidellm/scheduler/strategies.py +617 -0
- guidellm/scheduler/worker.py +413 -419
- guidellm/scheduler/worker_group.py +712 -0
- guidellm/schemas/__init__.py +65 -0
- guidellm/schemas/base.py +417 -0
- guidellm/schemas/info.py +188 -0
- guidellm/schemas/request.py +235 -0
- guidellm/schemas/request_stats.py +349 -0
- guidellm/schemas/response.py +124 -0
- guidellm/schemas/statistics.py +1018 -0
- guidellm/{config.py → settings.py} +31 -24
- guidellm/utils/__init__.py +71 -8
- guidellm/utils/auto_importer.py +98 -0
- guidellm/utils/cli.py +132 -5
- guidellm/utils/console.py +566 -0
- guidellm/utils/encoding.py +778 -0
- guidellm/utils/functions.py +159 -0
- guidellm/utils/hf_datasets.py +1 -2
- guidellm/utils/hf_transformers.py +4 -4
- guidellm/utils/imports.py +9 -0
- guidellm/utils/messaging.py +1118 -0
- guidellm/utils/mixins.py +115 -0
- guidellm/utils/random.py +3 -4
- guidellm/utils/registry.py +220 -0
- guidellm/utils/singleton.py +133 -0
- guidellm/utils/synchronous.py +159 -0
- guidellm/utils/text.py +163 -50
- guidellm/utils/typing.py +41 -0
- guidellm/version.py +2 -2
- guidellm-0.6.0a5.dist-info/METADATA +364 -0
- guidellm-0.6.0a5.dist-info/RECORD +109 -0
- guidellm/backend/__init__.py +0 -23
- guidellm/backend/backend.py +0 -259
- guidellm/backend/openai.py +0 -708
- guidellm/backend/response.py +0 -136
- guidellm/benchmark/aggregator.py +0 -760
- guidellm/benchmark/benchmark.py +0 -837
- guidellm/benchmark/output.py +0 -997
- guidellm/benchmark/profile.py +0 -409
- guidellm/benchmark/scenario.py +0 -104
- guidellm/data/prideandprejudice.txt.gz +0 -0
- guidellm/dataset/__init__.py +0 -22
- guidellm/dataset/creator.py +0 -213
- guidellm/dataset/entrypoints.py +0 -42
- guidellm/dataset/file.py +0 -92
- guidellm/dataset/hf_datasets.py +0 -62
- guidellm/dataset/in_memory.py +0 -132
- guidellm/dataset/synthetic.py +0 -287
- guidellm/objects/__init__.py +0 -18
- guidellm/objects/pydantic.py +0 -89
- guidellm/objects/statistics.py +0 -953
- guidellm/preprocess/__init__.py +0 -3
- guidellm/preprocess/dataset.py +0 -374
- guidellm/presentation/__init__.py +0 -28
- guidellm/presentation/builder.py +0 -27
- guidellm/presentation/data_models.py +0 -232
- guidellm/presentation/injector.py +0 -66
- guidellm/request/__init__.py +0 -18
- guidellm/request/loader.py +0 -284
- guidellm/request/request.py +0 -79
- guidellm/request/types.py +0 -10
- guidellm/scheduler/queues.py +0 -25
- guidellm/scheduler/result.py +0 -155
- guidellm/scheduler/strategy.py +0 -495
- guidellm-0.3.1.dist-info/METADATA +0 -329
- guidellm-0.3.1.dist-info/RECORD +0 -62
- {guidellm-0.3.1.dist-info → guidellm-0.6.0a5.dist-info}/WHEEL +0 -0
- {guidellm-0.3.1.dist-info → guidellm-0.6.0a5.dist-info}/entry_points.txt +0 -0
- {guidellm-0.3.1.dist-info → guidellm-0.6.0a5.dist-info}/licenses/LICENSE +0 -0
- {guidellm-0.3.1.dist-info → guidellm-0.6.0a5.dist-info}/top_level.txt +0 -0
guidellm/data/schemas.py
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Literal
|
|
4
|
+
|
|
5
|
+
from pydantic import ConfigDict, Field, model_validator
|
|
6
|
+
|
|
7
|
+
from guidellm.schemas import StandardBaseModel
|
|
8
|
+
|
|
9
|
+
# Public API of this module. Every public (non-underscore) class and alias
# defined below is listed so that star-imports and documentation tooling see
# the complete set, including PreprocessDatasetConfig.
__all__ = [
    "DataConfig",
    "DataNotSupportedError",
    "GenerativeDatasetColumnType",
    "PreprocessDatasetConfig",
    "SyntheticTextDatasetConfig",
    "SyntheticTextPrefixBucketConfig",
]
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
# Closed set of column-type names accepted for generative datasets: two token
# count columns, a prefix column, and one column per modality (text, image,
# video, audio).
GenerativeDatasetColumnType = Literal[
    "prompt_tokens_count_column",
    "output_tokens_count_column",
    "prefix_column",
    "text_column",
    "image_column",
    "video_column",
    "audio_column",
]
|
|
27
|
+
|
|
28
|
+
class DataNotSupportedError(Exception):
    """
    Signals that a data format cannot be handled by a deserializer or config.
    """
|
|
32
|
+
|
|
33
|
+
# Marker base class: it declares no fields of its own. The concrete configs
# in this module (PreprocessDatasetConfig, SyntheticTextDatasetConfig) derive
# from it and add their fields.
class DataConfig(StandardBaseModel):
    """
    A generic parent class for various configs for the data package
    that can be passed in as key-value pairs or JSON.
    """
|
|
38
|
+
|
|
39
|
+
# Token-length settings for dataset preprocessing. Only ``prompt_tokens`` and
# ``output_tokens`` are required; every other field defaults to None. All
# values, when given, must be strictly positive (gt=0).
class PreprocessDatasetConfig(DataConfig):
    # --- prompt sizing ---
    prompt_tokens: int = Field(
        gt=0,
        description="The average number of text tokens retained or added to prompts.",
    )
    prompt_tokens_stdev: int | None = Field(
        default=None,
        gt=0,
        description="The standard deviation of the number of tokens retained in or "
        "added to prompts.",
    )
    prompt_tokens_min: int | None = Field(
        default=None,
        gt=0,
        description="The minimum number of text tokens retained or added to prompts.",
    )
    prompt_tokens_max: int | None = Field(
        default=None,
        gt=0,
        description="The maximum number of text tokens retained or added to prompts.",
    )

    # --- output sizing ---
    output_tokens: int = Field(
        gt=0,
        description="The average number of text tokens retained or added to outputs.",
    )
    output_tokens_stdev: int | None = Field(
        default=None,
        gt=0,
        description="The standard deviation of the number of tokens retained or "
        "added to outputs.",
    )
    output_tokens_min: int | None = Field(
        default=None,
        gt=0,
        description="The minimum number of text tokens retained or added to outputs.",
    )
    output_tokens_max: int | None = Field(
        default=None,
        gt=0,
        description="The maximum number of text tokens retained or added to outputs.",
    )

    # --- prefix sizing ---
    prefix_tokens_max: int | None = Field(
        default=None,
        gt=0,
        description="The maximum number of text tokens left in the prefixes.",
    )
|
|
86
|
+
|
|
87
|
+
# One bucket of a synthetic prefix distribution. Every field has a default,
# so an empty bucket is valid: weight 100, one prefix, zero prefix tokens.
class SyntheticTextPrefixBucketConfig(StandardBaseModel):
    bucket_weight: int = Field(
        default=100,
        gt=0,
        description="Weight of this bucket in the overall distribution.",
    )
    prefix_count: int = Field(
        default=1,
        ge=1,
        description="The number of unique prefixes to generate for this bucket.",
    )
    prefix_tokens: int = Field(
        default=0,
        ge=0,
        description="The number of prefix tokens per-prompt for this bucket.",
    )
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
# Settings for synthetic text generation. Besides the declared fields, unknown
# keys are accepted (extra="allow"); the validator below reads top-level
# ``prefix_count`` / ``prefix_tokens`` extras and turns them into a single
# prefix bucket.
class SyntheticTextDatasetConfig(DataConfig):
    # Unknown keys are kept in ``__pydantic_extra__`` instead of being rejected.
    model_config = ConfigDict(
        extra="allow",
    )

    # --- prompt sizing ---
    prompt_tokens: int = Field(
        gt=0,
        description="The average number of text tokens generated for prompts.",
    )
    prompt_tokens_stdev: int | None = Field(
        default=None,
        gt=0,
        description="The standard deviation of the tokens generated for prompts.",
    )
    prompt_tokens_min: int | None = Field(
        default=None,
        gt=0,
        description="The minimum number of text tokens generated for prompts.",
    )
    prompt_tokens_max: int | None = Field(
        default=None,
        gt=0,
        description="The maximum number of text tokens generated for prompts.",
    )

    # --- output sizing ---
    output_tokens: int = Field(
        gt=0,
        description="The average number of text tokens generated for outputs.",
    )
    output_tokens_stdev: int | None = Field(
        default=None,
        gt=0,
        description="The standard deviation of the tokens generated for outputs.",
    )
    output_tokens_min: int | None = Field(
        default=None,
        gt=0,
        description="The minimum number of text tokens generated for outputs.",
    )
    output_tokens_max: int | None = Field(
        default=None,
        gt=0,
        description="The maximum number of text tokens generated for outputs.",
    )

    # --- prefix distribution ---
    prefix_buckets: list[SyntheticTextPrefixBucketConfig] | None = Field(
        default=None,
        description="Buckets for the prefix tokens distribution.",
    )

    @model_validator(mode="after")
    def check_prefix_options(self) -> SyntheticTextDatasetConfig:
        # Nothing to translate when no extra keys were captured.
        extras = self.__pydantic_extra__
        if extras is None:
            return self

        prefix_count = extras.get("prefix_count")  # type: ignore[attr-defined]
        prefix_tokens = extras.get("prefix_tokens")  # type: ignore[attr-defined]
        if prefix_count is None and prefix_tokens is None:
            return self

        # The shorthand keys and explicit buckets cannot be combined.
        if self.prefix_buckets:
            raise ValueError(
                "prefix_buckets is mutually exclusive"
                " with prefix_count and prefix_tokens"
            )

        # Falsy or missing shorthand values fall back to one prefix / zero tokens.
        single_bucket = SyntheticTextPrefixBucketConfig(
            prefix_count=prefix_count or 1,
            prefix_tokens=prefix_tokens or 0,
        )
        self.prefix_buckets = [single_bucket]

        return self
|
|
guidellm/data/utils/dataset.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Literal
|
|
4
|
+
|
|
5
|
+
from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
|
|
6
|
+
|
|
7
|
+
__all__ = ["DEFAULT_SPLITS", "resolve_dataset_split"]
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
# Known split-name aliases, grouped by canonical split. The groups are probed
# in insertion order (train, calib, val, test) and, inside a group, aliases are
# probed in list order when picking a default split from a dataset dictionary.
DEFAULT_SPLITS: dict[Literal["train", "calib", "val", "test"], list[str]] = {
    "train": [
        "train",
        "training",
        "train_set",
        "training_set",
        "train_dataset",
        "training_dataset",
        "train_data",
        "training_data",
        "pretrain",
        "pretrain_set",
        "pretrain_dataset",
        "pretrain_data",
        "pretraining",
    ],
    "calib": [
        "calibration",
        "calib",
        "cal",
        "calibration_set",
        "calib_set",
        "cal_set",
        "calibration_dataset",
        "calib_dataset",
        # Was a second "cal_set"; the *_dataset alias for "cal" was missing.
        "cal_dataset",
        "calibration_data",
        "calib_data",
        "cal_data",
    ],
    "val": [
        "validation",
        "val",
        "valid",
        "validation_set",
        "val_set",
        "validation_dataset",
        "val_dataset",
        "validation_data",
        "val_data",
        "dev",
        "dev_set",
        "dev_dataset",
        "dev_data",
    ],
    "test": [
        "test",
        "testing",
        "test_set",
        "testing_set",
        "test_dataset",
        "testing_dataset",
        "test_data",
        "testing_data",
        "eval",
        "eval_set",
        "eval_dataset",
        "eval_data",
    ],
}
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def resolve_dataset_split(
    dataset: Dataset | IterableDataset | DatasetDict | IterableDatasetDict,
    split: str | None = None,
) -> Dataset | IterableDataset:
    """
    Pick a single split out of a dataset or dataset dictionary.

    :param dataset: A single dataset, or a dictionary of named splits.
    :param split: Name of the split to return. When None, a single dataset is
        returned unchanged; for a dictionary the first alias from
        DEFAULT_SPLITS that is present is used, falling back to the first
        split in the dictionary.
    :return: The selected dataset.
    :raises ValueError: If ``split`` is given but the dataset has no splits or
        does not contain that split, or if a dataset dictionary is empty.
    """
    if split is not None:
        if not isinstance(dataset, DatasetDict | IterableDatasetDict):
            raise ValueError(
                f"Requested split '{split}' but dataset has no splits: {dataset}."
            )
        if split not in dataset:
            raise ValueError(
                f"Requested split '{split}' not found in dataset: {dataset}."
            )
        return dataset[split]

    if isinstance(dataset, Dataset | IterableDataset):
        return dataset

    # Only the alias lists matter here; the canonical group names are not used.
    for default_splits in DEFAULT_SPLITS.values():
        for default_split in default_splits:
            if default_split in dataset:
                return dataset[default_split]

    # No known alias matched: fall back to the first split, but fail with a
    # clear message (instead of a bare IndexError) when there is none at all.
    first_split = next(iter(dataset.keys()), None)
    if first_split is None:
        raise ValueError(f"Dataset contains no splits to select from: {dataset}.")

    return dataset[first_split]
|
guidellm/extras/audio.py
ADDED
|
@@ -0,0 +1,220 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import base64
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Any, Literal
|
|
6
|
+
|
|
7
|
+
import httpx
|
|
8
|
+
import numpy as np
|
|
9
|
+
import torch
|
|
10
|
+
|
|
11
|
+
try:
|
|
12
|
+
from torchcodec import AudioSamples
|
|
13
|
+
from torchcodec.decoders import AudioDecoder
|
|
14
|
+
from torchcodec.encoders import AudioEncoder
|
|
15
|
+
except ImportError as e:
|
|
16
|
+
raise ImportError("Please install guidellm[audio] to use audio features") from e
|
|
17
|
+
|
|
18
|
+
# Names exported by ``from guidellm.extras.audio import *``.
__all__ = ["encode_audio", "is_url"]
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def is_url(text: Any) -> bool:
    """Return True when ``text`` is a string starting with http:// or https://."""
    if not isinstance(text, str):
        return False

    return text.startswith(("http://", "https://"))
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def encode_audio(
    audio: AudioDecoder
    | bytes
    | str
    | Path
    | np.ndarray
    | torch.Tensor
    | dict[str, Any],
    b64encode: bool = False,
    sample_rate: int | None = None,
    file_name: str = "audio.wav",
    encode_sample_rate: int = 16000,
    max_duration: float | None = None,
    mono: bool = True,
    audio_format: str = "mp3",
    bitrate: str = "64k",
) -> dict[
    Literal[
        "type",
        "audio",
        "format",
        "mimetype",
        "audio_samples",
        "audio_seconds",
        "audio_bytes",
        "file_name",
    ],
    str | int | float | bytes | None,
]:
    """
    Decode audio (if necessary) and re-encode to specified format.

    :param audio: The audio input; forwarded unchanged to ``_decode_audio``.
    :param b64encode: When True the encoded audio is returned as a base64
        string and ``type`` is "audio_base64"; otherwise raw bytes are
        returned and ``type`` is "audio_file".
    :param sample_rate: Sample rate forwarded to ``_decode_audio``.
    :param file_name: Name reported in the result when ``audio`` is not a
        str or Path; for str/Path inputs the name is derived from the input.
    :param encode_sample_rate: Sample rate requested from the encoder.
    :param max_duration: Maximum duration in seconds, forwarded to
        ``_decode_audio``.
    :param mono: Forwarded to ``_encode_audio``.
    :param audio_format: Target format; lower-cased before encoding and when
        building the mimetype.
    :param bitrate: Bit rate as a string, either a plain integer or an integer
        with a "k" suffix meaning thousands ("64k" -> 64000). Matching is
        case-insensitive and surrounding whitespace is ignored.
    :return: A dict with the encoded payload and its metadata.
    :raises ValueError: If ``bitrate`` cannot be parsed as an integer.
    """
    samples = _decode_audio(audio, sample_rate=sample_rate, max_duration=max_duration)

    # Normalise first so that "64K" and " 64k " parse the same as "64k".
    bitrate_text = bitrate.strip().lower()
    if bitrate_text.endswith("k"):
        bitrate_val = int(bitrate_text.rstrip("k")) * 1000
    else:
        bitrate_val = int(bitrate_text)
    format_val = audio_format.lower()

    encoded_audio = _encode_audio(
        samples=samples,
        resample_rate=encode_sample_rate,
        bitrate=bitrate_val,
        audio_format=format_val,
        mono=mono,
    )

    # Path-like inputs carry their own name; everything else uses the default.
    if isinstance(audio, str | Path):
        resolved_name = get_file_name(audio)
    else:
        resolved_name = file_name

    return {
        "type": "audio_base64" if b64encode else "audio_file",
        "audio": (
            base64.b64encode(encoded_audio).decode("utf-8")
            if b64encode
            else encoded_audio
        ),
        "file_name": resolved_name,
        "format": audio_format,
        "mimetype": f"audio/{format_val}",
        # NOTE(review): this stores the decoded sample *rate*, not a sample
        # count, despite the key name -- confirm what consumers expect.
        "audio_samples": samples.sample_rate,
        "audio_seconds": samples.duration_seconds,
        "audio_bytes": len(encoded_audio),
    }
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def _decode_audio(  # noqa: C901, PLR0912
    audio: AudioDecoder
    | bytes
    | str
    | Path
    | np.ndarray
    | torch.Tensor
    | dict[str, Any],
    sample_rate: int | None = None,
    max_duration: float | None = None,
) -> AudioSamples:
    """
    Decode audio from various input types into AudioSamples.

    :param audio: One of: a dict holding the audio under "data" or "url"
        (optionally with "sample_rate" / "sampling_rate"), a numpy array, an
        AudioDecoder, a float tensor of already-decoded samples (1-D, or 2-D
        indexed as [channel, sample]), a uint8 tensor or bytes of encoded
        audio, or a str/Path naming a local file or an http(s) URL.
    :param sample_rate: Required for float tensors; for encoded inputs it is
        passed to the decoder. A rate found in a dict input takes precedence.
    :param max_duration: If set, only the first ``max_duration`` seconds are
        kept.
    :return: The decoded samples.
    :raises ValueError: If a dict has neither usable "data" nor "url", a float
        tensor comes without a sample rate, a local file does not exist, or
        the input type is not supported.
    """
    # If input is a dict, unwrap it into a function call
    if isinstance(audio, dict):
        sample_rate = audio.get("sample_rate", audio.get("sampling_rate", sample_rate))
        # Prefer "data"; fall back to "url" when "data" is absent or None.
        audio_data = audio.get("data")
        if audio_data is None:
            audio_data = audio.get("url")
        if audio_data is None:
            raise ValueError(
                f"Audio dict must contain either 'data' or 'url' keys, got {audio}"
            )
        return _decode_audio(
            audio=audio_data,
            sample_rate=sample_rate,
            max_duration=max_duration,
        )

    # Convert numpy array to torch tensor and re-call
    if isinstance(audio, np.ndarray):
        return _decode_audio(
            audio=torch.from_numpy(audio),
            sample_rate=sample_rate,
            max_duration=max_duration,
        )

    # HF datasets return AudioDecoder for audio column
    if isinstance(audio, AudioDecoder):
        return audio.get_samples_played_in_range(stop_seconds=max_duration)

    if isinstance(audio, torch.Tensor):
        # If float stream assume decoded audio
        if torch.is_floating_point(audio):
            if sample_rate is None:
                raise ValueError("Sample rate must be set for decoded audio")

            # A 1-D waveform is treated as a single channel so the
            # [channel, sample] indexing below is valid.
            waveform = audio.unsqueeze(0) if audio.dim() == 1 else audio
            full_duration = waveform.shape[1] / sample_rate
            # If max_duration is set, trim the audio to that duration
            if max_duration is not None:
                num_samples = int(max_duration * sample_rate)
                duration = min(max_duration, full_duration)
                waveform = waveform[:, :num_samples]
            else:
                duration = full_duration

            return AudioSamples(
                data=waveform,
                pts_seconds=0.0,
                duration_seconds=duration,
                sample_rate=sample_rate,
            )

        # If bytes tensor assume encoded audio
        if audio.dtype == torch.uint8:
            decoder = AudioDecoder(
                source=audio,
                sample_rate=sample_rate,
            )
            return decoder.get_samples_played_in_range(stop_seconds=max_duration)

        raise ValueError(f"Unsupported audio type: {type(audio)}")

    # Everything left must resolve to encoded bytes before decoding.
    encoded: bytes
    if isinstance(audio, bytes):
        encoded = audio
    elif isinstance(audio, str | Path):
        # If str or Path, assume file path or URL to encoded audio
        if isinstance(audio, str) and is_url(audio):
            response = httpx.get(audio)
            response.raise_for_status()
            encoded = response.content
        else:
            if not Path(audio).exists():
                raise ValueError(f"Audio file does not exist: {audio}")
            encoded = Path(audio).read_bytes()
    else:
        raise ValueError(f"Unsupported audio type: {type(audio)}")

    # The requested sample rate is applied to files and URLs as well, matching
    # the handling of raw bytes.
    decoder = AudioDecoder(
        source=encoded,
        sample_rate=sample_rate,
    )
    return decoder.get_samples_played_in_range(stop_seconds=max_duration)
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def _encode_audio(
    samples: AudioSamples,
    resample_rate: int | None = None,
    bitrate: int = 64000,
    audio_format: str = "mp3",
    mono: bool = True,
) -> bytes:
    """
    Encode decoded samples into ``audio_format`` and return the raw bytes.

    :param samples: Decoded audio; its data and sample rate feed the encoder.
    :param resample_rate: Sample rate requested for the encoded output.
    :param bitrate: Bit rate passed to the encoder, only when the format is
        exactly "mp3".
    :param audio_format: Format name passed to the encoder.
    :param mono: When True a single output channel is requested; otherwise the
        channel count is left unspecified.
    :return: The encoded audio as bytes.
    """
    # The bit rate is only forwarded for mp3; other formats get None.
    target_bit_rate = bitrate if audio_format == "mp3" else None
    channel_count = 1 if mono else None

    encoded = AudioEncoder(
        samples=samples.data,
        sample_rate=samples.sample_rate,
    ).to_tensor(
        format=audio_format,
        bit_rate=target_bit_rate,
        num_channels=channel_count,
        sample_rate=resample_rate,
    )

    return encoded.numpy().tobytes()
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
def get_file_name(path: Path | str) -> str:
    """Return the final component of ``path``."""
    as_path = Path(path)
    return as_path.name
|