caul 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- caul-0.1.0/PKG-INFO +13 -0
- caul-0.1.0/pyproject.toml +48 -0
- caul-0.1.0/src/caul/__init__.py +0 -0
- caul-0.1.0/src/caul/configs/__init__.py +7 -0
- caul-0.1.0/src/caul/configs/asr.py +23 -0
- caul-0.1.0/src/caul/configs/parakeet.py +19 -0
- caul-0.1.0/src/caul/constant.py +31 -0
- caul-0.1.0/src/caul/exception.py +6 -0
- caul-0.1.0/src/caul/filesystem.py +33 -0
- caul-0.1.0/src/caul/handler.py +163 -0
- caul-0.1.0/src/caul/model_handlers/__init__.py +0 -0
- caul-0.1.0/src/caul/model_handlers/asr_model_handler.py +39 -0
- caul-0.1.0/src/caul/model_handlers/helpers.py +67 -0
- caul-0.1.0/src/caul/model_handlers/parakeet.py +73 -0
- caul-0.1.0/src/caul/tasks/__init__.py +0 -0
- caul-0.1.0/src/caul/tasks/asr_task.py +12 -0
- caul-0.1.0/src/caul/tasks/inference/__init__.py +0 -0
- caul-0.1.0/src/caul/tasks/inference/asr_inference.py +24 -0
- caul-0.1.0/src/caul/tasks/inference/parakeet_inference.py +83 -0
- caul-0.1.0/src/caul/tasks/inference/whisper_cpp_inference.py +20 -0
- caul-0.1.0/src/caul/tasks/postprocessing/__init__.py +0 -0
- caul-0.1.0/src/caul/tasks/postprocessing/parakeet_postprocessor.py +55 -0
- caul-0.1.0/src/caul/tasks/preprocessing/__init__.py +0 -0
- caul-0.1.0/src/caul/tasks/preprocessing/helpers.py +32 -0
- caul-0.1.0/src/caul/tasks/preprocessing/parakeet_preprocessor.py +258 -0
- caul-0.1.0/src/caul/utils.py +24 -0
caul-0.1.0/PKG-INFO
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
Metadata-Version: 2.3
|
|
2
|
+
Name: caul
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Python implementation of an ASR service
|
|
5
|
+
Author: Lion Summerbell
|
|
6
|
+
Author-email: Lion Summerbell <lsummerbell@icij.org>
|
|
7
|
+
Requires-Dist: nemo-toolkit[asr]
|
|
8
|
+
Requires-Dist: torch>=2.10.0
|
|
9
|
+
Requires-Dist: numpy
|
|
10
|
+
Requires-Dist: torchaudio>=2.10.0
|
|
11
|
+
Requires-Dist: torchcodec>=0.9.1
|
|
12
|
+
Requires-Dist: librosa>=0.11.0
|
|
13
|
+
Requires-Python: >=3.10.0, <3.14.0
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "caul"
|
|
3
|
+
version = "0.1.0"
|
|
4
|
+
description = "Python implementation of an ASR service"
|
|
5
|
+
requires-python = ">=3.10.0, <3.14.0"
|
|
6
|
+
dependencies = [
|
|
7
|
+
"nemo-toolkit[asr]",
|
|
8
|
+
"torch>=2.10.0",
|
|
9
|
+
"numpy",
|
|
10
|
+
"torchaudio>=2.10.0",
|
|
11
|
+
"torchcodec>=0.9.1",
|
|
12
|
+
"librosa>=0.11.0",
|
|
13
|
+
]
|
|
14
|
+
authors = [
|
|
15
|
+
{name = "Lion Summerbell", email = "lsummerbell@icij.org"}
|
|
16
|
+
]
|
|
17
|
+
[build-system]
|
|
18
|
+
requires = ["uv_build >= 0.9.22, <0.10.0"]
|
|
19
|
+
build-backend = "uv_build"
|
|
20
|
+
|
|
21
|
+
[dependency-groups]
|
|
22
|
+
dev = [
|
|
23
|
+
"black>=25.12.0",
|
|
24
|
+
"pre-commit>=4.5.1",
|
|
25
|
+
"pylint>=4.0.4",
|
|
26
|
+
"pytest>=9.0.2",
|
|
27
|
+
]
|
|
28
|
+
|
|
29
|
+
[tool.pylint."MESSAGES CONTROL"]
|
|
30
|
+
disable = [
|
|
31
|
+
"missing-module-docstring"
|
|
32
|
+
]
|
|
33
|
+
|
|
34
|
+
[tool.pylint.main]
|
|
35
|
+
fail-under = 8.89
|
|
36
|
+
|
|
37
|
+
[tool.pylint.refactoring]
|
|
38
|
+
max-nested-blocks = 5
|
|
39
|
+
never-returning-functions = ["sys.exit", "argparse.parse_error"]
|
|
40
|
+
|
|
41
|
+
[tool.pylint.similarities]
|
|
42
|
+
# Docstrings are removed from the similarity computation
|
|
43
|
+
ignore-docstrings = true
|
|
44
|
+
# Comments are removed from the similarity computation
|
|
45
|
+
ignore-comments = true
|
|
46
|
+
|
|
47
|
+
[tool.black]
|
|
48
|
+
target-version = ["py310"]
|
|
File without changes
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from typing import TYPE_CHECKING
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
|
|
6
|
+
from caul.constant import DEVICE_CPU
|
|
7
|
+
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
from caul.model_handlers.asr_model_handler import ASRModelHandler
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass
class ASRConfig:
    """Base config class"""

    # Model reference string (e.g. a NeMo/HuggingFace model id)
    model_name: str
    # NOTE(review): despite the instance-style annotation this holds the
    # handler *class* -- handler_from_config calls it as a constructor.
    # Confirm and consider annotating as type["ASRModelHandler"].
    model_handler: "ASRModelHandler"
    # Target inference device; plain strings ("cpu", "mps", ...) accepted
    device: str | torch.device = DEVICE_CPU

    def handler_from_config(self) -> "ASRModelHandler":
        """Instantiate the configured handler class with this config.

        :return: an ASRModelHandler instance, or None when no handler is set
        """
        return (
            self.model_handler(config=self) if self.model_handler is not None else None
        )
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from typing import TYPE_CHECKING
|
|
3
|
+
|
|
4
|
+
from caul.configs.asr import ASRConfig
|
|
5
|
+
from caul.constant import EXPECTED_SAMPLE_RATE, PARAKEET_MODEL_REF
|
|
6
|
+
from caul.model_handlers.parakeet import ParakeetModelHandler
|
|
7
|
+
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
from caul.model_handlers.asr_model_handler import ASRModelHandler
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass
class ParakeetConfig(ASRConfig):
    """Config defaults for the Parakeet model family."""

    model_name: str = PARAKEET_MODEL_REF
    # NOTE(review): holds the handler class itself, not an instance
    model_handler: "ASRModelHandler" = ParakeetModelHandler
    # Persist preprocessed segments to temporary wav files
    save_to_filesystem: bool = True
    # Keep tensors on the preprocessed inputs (vs. file paths only)
    return_tensors: bool = True
    # Target sample rate inputs are normalized to
    sample_rate: int = EXPECTED_SAMPLE_RATE
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
# General

DEVICE_CPU = "cpu"

DEVICE_GPU = "gpu"

DEVICE_MPS = "mps"

# Mono 16 kHz is the rate the supported ASR models expect
EXPECTED_SAMPLE_RATE = 16000

EXPECTED_FORMAT = "wav"

# Number of samples in one minute of audio at the expected rate
EXPECTED_SAMPLE_MINUTE = EXPECTED_SAMPLE_RATE * 60

# Parakeet

PARAKEET_MODEL_REF = "nvidia/parakeet-tdt-0.6b-v3"

PARAKEET_INFERENCE_MAX_DURATION_MIN = (
    20  # actually 24, but we want to give ourselves some room
)

# NOTE(review): despite the KHZ suffix this is a *sample count*
# (minutes * samples-per-minute), not a kilohertz value
PARAKEET_INFERENCE_MAX_DURATION_KHZ = (
    PARAKEET_INFERENCE_MAX_DURATION_MIN * EXPECTED_SAMPLE_MINUTE
)

PARAKEET = "parakeet"

# Whisper

WHISPER_CPP = "whisper-cpp"
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
from tempfile import mkstemp
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import torchaudio
|
|
5
|
+
|
|
6
|
+
from caul.constant import EXPECTED_SAMPLE_RATE, EXPECTED_FORMAT
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def save_tensor(audio_tensor: torch.Tensor) -> str:
    """Filesystem routine for audio tensor; defaults to wav

    :param audio_tensor: input tensor, assumed mono (aud_length,)
    :return: string file uri
    """
    # TODO: Change paths to run_id + tensor uuid + pagination
    # Allow for remote paths
    import os

    # torchcodec requires the .wav suffix; create the temp file with the
    # suffix directly instead of appending it afterwards, which left an
    # orphaned suffix-less temp file behind.
    fd, file_path = mkstemp(suffix=".wav")

    # BUG FIX: mkstemp returns an *open* OS-level descriptor -- the
    # original discarded it, leaking one fd per saved tensor.
    os.close(fd)

    # Channel required as first dim
    audio_tensor = audio_tensor.unsqueeze(0)

    torchaudio.save(
        file_path,
        audio_tensor,
        sample_rate=EXPECTED_SAMPLE_RATE,
        format=EXPECTED_FORMAT,
    )

    return file_path
|
|
@@ -0,0 +1,163 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
from caul.configs.asr import ASRConfig
|
|
10
|
+
from caul.exception import (
|
|
11
|
+
MissingModelSpecificationException,
|
|
12
|
+
UnsupportedModelException,
|
|
13
|
+
)
|
|
14
|
+
from caul.configs import MODEL_FAMILY_CONFIG_MAP
|
|
15
|
+
from caul.tasks.inference.asr_inference import (
|
|
16
|
+
ASRModelHandlerResult,
|
|
17
|
+
)
|
|
18
|
+
from caul.model_handlers.asr_model_handler import ASRModelHandler, ASRModelHandlerResult
|
|
19
|
+
from caul.utils import dict_key_fuzzy_match
|
|
20
|
+
|
|
21
|
+
logger = logging.getLogger(__name__)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class ASRHandler:
    """Primary application handler. Routes audio inputs to one or more
    model handlers, optionally dispatching per input on language."""

    # pylint: disable=R0913,R0917

    def __init__(
        self,
        models: (
            list[str | ASRModelHandler | ASRConfig] | str | ASRModelHandler | ASRConfig
        ),
        device: torch.device | str = None,
        language_map: dict[str, int] = None,
    ):
        """Primary application handler class. Handles transcription agnostically.

        :param models: Model_handler(s) or string reference(s)
        :param device: cuda/cpu/mps
        :param language_map: Map from ISO-639-3 language code to index of inference_handler
        """
        self.device = device

        # Fresh dict per instance -- avoids the shared-mutable-default pitfall
        if language_map is None:
            language_map = {}

        self.language_map = language_map

        self.model_handlers = []

        if isinstance(models, list) and len(models) == 0:
            raise MissingModelSpecificationException(
                "At least one model name or model handler must be provided"
            )

        if not isinstance(models, list):
            models = [models]

        for model in models:
            if isinstance(model, str):
                # Resolve a string reference (e.g. "parakeet") to its config
                supported_model_config = dict_key_fuzzy_match(
                    MODEL_FAMILY_CONFIG_MAP, model
                )

                if supported_model_config is None:
                    raise UnsupportedModelException(f"Unsupported model '{model}'")

                # Set device after instantiation
                supported_model_handler = supported_model_config.handler_from_config()
                supported_model_handler.set_device(self.device)

                self.model_handlers.append(supported_model_handler)
            elif isinstance(model, ASRModelHandler):
                self.model_handlers.append(model)
            else:
                raise UnsupportedModelException(f"Unsupported model type '{model}'")

    def __repr__(self):
        # BUG FIX: the original omitted the closing '>' and produced
        # "<ASRHandler models: [...] "
        return f"<ASRHandler models: {self.model_handlers}>"

    def startup(self):
        """Run all model handler startup procedures"""
        for model_handler in self.model_handlers:
            model_handler.startup()

    def shutdown(self):
        """Garbage collect model handlers"""
        self.model_handlers = []

    def get_handler_by_language(self, language: str) -> ASRModelHandler:
        """Get model_handler from language map or return first reference if language is not mapped

        :param language: ISO-639-3 language code
        :return: ASRModelHandler
        """
        reference_idx = self.language_map.get(
            language, 0
        )  # default to primary inference_handler when no language given

        if len(self.model_handlers) <= reference_idx:
            raise UnsupportedModelException(
                "Language is mapped to a model index which does not exist"
            )

        return self.model_handlers[reference_idx]

    def transcribe(
        self,
        inputs: list[np.ndarray | torch.Tensor | str] | np.ndarray | torch.Tensor | str,
        languages: list[str] = None,
    ) -> list[ASRModelHandlerResult]:
        """Transcribe audio tensors or strings. Returns a tuple of (transcription, score). A list
        of languages of len(inputs) may be passed to direct inputs to certain inference_handlers.

        :param inputs: List of np.ndarray or torch.Tensor or str, or a singleton of same types
        :param languages: List of ISO-639-3 language codes
        :return: HandlerResult
        """
        if len(self.model_handlers) == 0:
            raise MissingModelSpecificationException(
                "At least one model name or model handler must be provided"
            )

        if not isinstance(inputs, list):
            inputs = [inputs]

        audios_by_language = {}
        model_handler_results_by_language = {}
        batch_language_ordering = []
        model_handler_results = []

        if languages is None:
            # Default to first model handler
            return self.model_handlers[0].process(inputs)

        # Sort by language where present, preserving original order for returning result
        # NOTE(review): assumes len(languages) >= len(inputs) -- confirm callers
        for idx, aud in enumerate(inputs):
            language = languages[idx]

            if language not in audios_by_language:
                audios_by_language[language] = []

            batch_language_ordering.append(language)
            audios_by_language[language].append(aud)

        # Run inference_handler on language batch
        for language, audio_list in audios_by_language.items():
            model_handler = self.get_handler_by_language(language)
            model_handler_results_by_language[language] = model_handler.process(
                audio_list
            )

        # For use with .pop()
        batch_language_ordering.reverse()

        # Reassemble and postprocess
        for language in batch_language_ordering:
            model_handler_result = model_handler_results_by_language[language].pop()

            model_handler_results.append(model_handler_result)

        # BUG FIX: the loop above walks the *reversed* ordering and pops from
        # the tail of each per-language list, so results accumulate in reverse
        # input order; the original returned them that way. Reverse once more
        # to restore the documented original input order.
        model_handler_results.reverse()

        return model_handler_results
|
|
File without changes
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import torch
|
|
5
|
+
|
|
6
|
+
from caul.configs.asr import ASRConfig
|
|
7
|
+
from caul.model_handlers.helpers import ASRModelHandlerResult
|
|
8
|
+
from caul.tasks.asr_task import ASRTask
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class ASRModelHandler(ABC):
    """Abstract base for ASR model handlers.

    Concrete handlers populate ``self.tasks`` with an ordered pipeline of
    ASRTask stages; :meth:`process` threads inputs through that pipeline.
    """

    # pylint: disable=R0903

    def __init__(self, config: "ASRConfig", *args, **kwargs):
        # Configuration object driving this handler (may be None)
        self.config = config
        # Ordered pipeline stages, filled in by subclasses
        self.tasks: list[ASRTask] = []

    @abstractmethod
    def startup(self):
        """Generic method to load ASR resources"""

    @abstractmethod
    def shutdown(self):
        """Generic method to unload ASR resources"""

    def process(
        self,
        inputs: list[np.ndarray | torch.Tensor | str] | np.ndarray | torch.Tensor | str,
    ) -> list[ASRModelHandlerResult]:
        """Generic sequential processing method for ASR model handlers.

        Each stage's output becomes the next stage's input; with no stages
        configured, the inputs are returned unchanged.
        """
        current = inputs

        for stage in self.tasks:
            current = stage.process(current)

        return current
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
|
|
3
|
+
from nemo.collections.asr.parts.utils import Hypothesis
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@dataclass
class ASRModelHandlerResult:
    """Base result class for ASR models"""

    # Index of the originating input within the submitted batch (-1 = unset)
    input_ordering: int = -1
    # List of (start, end, text) segment tuples
    transcription: list[tuple] = None
    # Model confidence score
    score: float = None
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass
class ParakeetModelHandlerResult(ASRModelHandlerResult):
    """Result handler for ParakeetInferenceHandler objects"""

    def parse_parakeet_hypothesis(
        self, hypothesis: Hypothesis
    ) -> ASRModelHandlerResult:
        """Parse a hypothesis returned by a Parakeet RNN model

        :param hypothesis: Parakeet hypothesis
        :return: copy of self
        """
        # Prefer timestamped segments; fall back to the full text as one
        # zero-duration segment when timestamps are absent
        self.transcription = (
            [
                (s["start"], s["end"], s["segment"])
                for s in hypothesis.timestamp.get("segment")
            ]
            if hypothesis.timestamp.get("segment") is not None
            else [(0.0, 0.0, hypothesis.text)]
        )
        self.score = round(hypothesis.score, 2)

        return self

    def concat(self, model_result: ASRModelHandlerResult) -> ASRModelHandlerResult:
        """Left fold with ParakeetModelHandlerResult object

        :param model_result: ParakeetModelHandlerResult
        :return: copy of self
        """
        if model_result is None:
            return self

        if self.transcription is None:
            self.transcription = []

        # We have to weight by total segment len.
        # BUG FIX: the original read self.transcription[-1][1] AFTER
        # extending with model_result's segments, so both weights equalled
        # model_result's duration and the weighted mean degenerated to a
        # plain average. Capture our own duration before extending.
        transcription_duration = (
            self.transcription[-1][1] if self.transcription else 0.0
        )
        model_result_duration = model_result.transcription[-1][1]
        total_duration = transcription_duration + model_result_duration

        self.transcription += model_result.transcription

        # BUG FIX: folding into a fresh result (score=None) or zero-length
        # durations raised TypeError/ZeroDivisionError in the original
        if self.score is None or total_duration == 0:
            self.score = model_result.score
        else:
            self.score = round(
                (
                    self.score * transcription_duration
                    + model_result.score * model_result_duration
                )
                / total_duration,
                2,
            )

        return self
|
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
from typing import TYPE_CHECKING
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
from caul.constant import DEVICE_CPU, PARAKEET_MODEL_REF
|
|
6
|
+
from caul.model_handlers.asr_model_handler import ASRModelHandler
|
|
7
|
+
from caul.tasks.inference.parakeet_inference import ParakeetInferenceHandler
|
|
8
|
+
from caul.tasks.postprocessing.parakeet_postprocessor import ParakeetPostprocessor
|
|
9
|
+
from caul.tasks.preprocessing.parakeet_preprocessor import ParakeetPreprocessor
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
from caul.configs import ParakeetConfig
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class ParakeetModelHandler(ASRModelHandler):
    """Model handler for Parakeet family"""

    def __init__(
        self,
        config: "ParakeetConfig" = None,
        model_name: str = PARAKEET_MODEL_REF,
        device: str | torch.device = DEVICE_CPU,
    ):
        """Build the preprocess -> inference -> postprocess task pipeline.

        :param config: optional ParakeetConfig; its values win over the
            keyword fallbacks below when present
        :param model_name: model reference used when no config is given
        :param device: device used when no config is given
        """
        super().__init__(config=config)

        if config is not None and config.model_name is not None:
            model_name = config.model_name

        self.model_name = model_name

        if config is not None and config.device is not None:
            device = config.device

        if isinstance(device, str):
            device = torch.device(device)

        self.device = device

        # BUG FIX: the original dereferenced config.save_to_filesystem /
        # config.return_tensors / config.model_name / config.device
        # unconditionally, raising AttributeError whenever config was None
        # (its declared default). Use resolved self.* values and sensible
        # defaults instead.
        save_to_filesystem = config.save_to_filesystem if config is not None else True
        return_tensors = config.return_tensors if config is not None else True

        self.preprocessor = ParakeetPreprocessor(
            save_to_filesystem=save_to_filesystem,
            return_tensors=return_tensors,
        )
        self.inference_handler = ParakeetInferenceHandler(
            model_name=self.model_name, device=self.device
        )
        self.postprocessor = ParakeetPostprocessor()

        self.tasks = [self.preprocessor, self.inference_handler, self.postprocessor]

    def set_device(self, device: str | torch.device = DEVICE_CPU):
        """Set/change device here and on inference_handler

        :param device: device to use
        """
        if isinstance(device, str):
            device = torch.device(device)

        self.device = device

        self.inference_handler.set_device(device)

        return self

    def startup(self):
        """Load model"""
        self.inference_handler.load()

    def shutdown(self):
        """Drop all pipeline stages so resources can be garbage collected"""
        self.preprocessor = None
        self.inference_handler = None
        self.postprocessor = None
        self.tasks = []
|
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
from abc import abstractmethod
|
|
2
|
+
|
|
3
|
+
from caul.model_handlers.asr_model_handler import ASRModelHandlerResult
|
|
4
|
+
from caul.tasks.asr_task import ASRTask
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class ASRInferenceHandler(ASRTask):
    """Abstract for ASR inference"""

    @abstractmethod
    def process(self, inputs: list, *args, **kwargs) -> list[ASRModelHandlerResult]:
        """Run inference over (preprocessed) inputs.

        :param inputs: List of inference inputs
        :return: ASRModelHandlerResult
        """

    @abstractmethod
    def load(self):
        """Load model"""

    @abstractmethod
    def unload(self):
        """Unload model"""
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
import nemo.collections.asr as nemo_asr
|
|
4
|
+
|
|
5
|
+
from caul.constant import DEVICE_CPU
|
|
6
|
+
from caul.model_handlers.helpers import ParakeetModelHandlerResult
|
|
7
|
+
from caul.tasks.inference.asr_inference import (
|
|
8
|
+
ASRInferenceHandler,
|
|
9
|
+
)
|
|
10
|
+
from caul.tasks.preprocessing.helpers import PreprocessedInput
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class ParakeetInferenceHandler(ASRInferenceHandler):
    """Inference handler for NVIDIA's Parakeet family of ASR models. Supports up to 24 minutes of
    audio (batched or unbatched) in a single pass. Assumes that audio inputs (wav files or tensors)
    are single-channel with a sample rate of 16000—this last is very important for segmenting.
    """

    def __init__(self, model_name: str, device: str | torch.device = DEVICE_CPU):
        # Model reference handed to NeMo's from_pretrained in load()
        self.model_name = model_name

        if isinstance(device, str):
            device = torch.device(device)

        self.device = device
        # Lazily populated by load(); None until then
        self.model = None

    def load(self):
        """Load model; default to CPU where no device is present"""
        device = self.device

        if device is None:
            device = DEVICE_CPU

        # eval() disables training-mode layers (dropout etc.) for inference
        self.model = nemo_asr.models.ASRModel.from_pretrained(
            self.model_name, map_location=torch.device(device)
        ).eval()

    def unload(self):
        """Unload model"""
        self.model = None

    def set_device(self, device: str | torch.device = DEVICE_CPU):
        """Set/change device"""
        if isinstance(device, str):
            device = torch.device(device)

        self.device = device

        return self

    def process(
        self,
        inputs: list[list[PreprocessedInput]] | list[PreprocessedInput],
        timestamps: bool = True,
    ) -> list[ParakeetModelHandlerResult]:
        """Transcribe a batch of audio tensors or file names of max duration <= 20 minutes

        :param inputs: List of np.ndarray or torch.Tensor or str, or singleton of same types
        :param timestamps: Whether to include timestamps with transcriptions
        :return: List of results
        """
        if len(inputs) == 0:
            return []

        # A flat list of PreprocessedInput is treated as a single batch
        if isinstance(inputs[0], PreprocessedInput):
            inputs = [inputs]

        transcriptions = []

        for input_batch in inputs:
            # NOTE(review): assumes every PreprocessedInput carries a tensor
            # (return_tensors=True upstream) -- .to() raises on None
            hypotheses = self.model.transcribe(
                [i.tensor.to(self.device) for i in input_batch], timestamps=timestamps
            )
            # Get timestamped segments if available, otherwise default to whole text
            for idx, hyp in enumerate(hypotheses):
                input_ordering_idx = input_batch[idx].metadata.input_ordering
                model_result = ParakeetModelHandlerResult(
                    input_ordering=input_ordering_idx
                ).parse_parakeet_hypothesis(hyp)
                transcriptions.append(model_result)

        return transcriptions
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
from caul.tasks.inference.asr_inference import (
|
|
2
|
+
ASRInferenceHandler,
|
|
3
|
+
ASRModelHandlerResult,
|
|
4
|
+
)
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class WhisperCPPInferenceHandler(ASRInferenceHandler):
    """Handler for WhisperCPP; wrapper round subprocess calls"""

    # pylint: disable=R0903

    def process(
        self,
        inputs: list[str],
    ) -> list[ASRModelHandlerResult]:
        """List of np.ndarray or torch.Tensor or str, or a singleton of same types

        :param inputs: List of np.ndarray or torch.Tensor or str, or a singleton of same types
        :return:
        """
        # NOTE(review): stub -- no implementation yet (implicitly returns
        # None, contradicting the annotated return type). The abstract
        # load()/unload() methods are also not overridden, so this class
        # cannot currently be instantiated.
|
|
File without changes
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
from functools import reduce
|
|
2
|
+
from itertools import groupby
|
|
3
|
+
|
|
4
|
+
from caul.tasks.asr_task import ASRTask
|
|
5
|
+
from caul.tasks.inference.parakeet_inference import ParakeetModelHandlerResult
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class ParakeetPostprocessor(ASRTask):
    """Postprocessing logic for ParakeetInferenceHandler output"""

    def process(
        self, inputs: list[ParakeetModelHandlerResult]
    ) -> list[ParakeetModelHandlerResult]:
        """Process indexed ParakeetInferenceHandler results and return them in their original
        ordering

        :param inputs: List of parakeet model results
        :return: list of parakeet model results in input ordering
        """

        return self.map_results_to_inputs(inputs)

    @staticmethod
    def map_results_to_inputs(
        batched_results: list[ParakeetModelHandlerResult],
    ) -> list[ParakeetModelHandlerResult]:
        """Remap unordered and segmented tensors to original inputs for return

        :param batched_results: list of unordered ParakeetModelHandlerResult, still
            segmented
        :return: list[ParakeetModelHandlerResult]
        """
        # groupby only merges adjacent runs, so order by index first
        ordered = sorted(batched_results, key=lambda r: r.input_ordering)

        remapped = []

        # Fold each input's segment-results into a single result via concat
        for _, chunk_iter in groupby(ordered, key=lambda r: r.input_ordering):
            chunks = list(chunk_iter)

            merged = chunks[0]
            for extra in chunks[1:]:
                merged = merged.concat(extra)

            remapped.append(merged)

        # TODO: Drop index from result

        return remapped
|
|
File without changes
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
import datetime
import uuid

from dataclasses import dataclass, field
from typing import Optional

import torch
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@dataclass
|
|
11
|
+
class InputMetadata:
|
|
12
|
+
"""Preprocessed input metadata"""
|
|
13
|
+
|
|
14
|
+
input_ordering: int
|
|
15
|
+
duration: int
|
|
16
|
+
start_time: int = 0
|
|
17
|
+
end_time: int = 0
|
|
18
|
+
preprocessed_at: str = (
|
|
19
|
+
datetime.datetime.now(datetime.timezone.utc).replace(microsecond=0).isoformat()
|
|
20
|
+
)
|
|
21
|
+
uuid: str = uuid.uuid4().hex
|
|
22
|
+
input_format: str = None
|
|
23
|
+
input_file_path: str = None
|
|
24
|
+
preprocessed_file_path: str = None
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass
|
|
28
|
+
class PreprocessedInput:
|
|
29
|
+
"""Preprocessed input wrapper"""
|
|
30
|
+
|
|
31
|
+
metadata: InputMetadata
|
|
32
|
+
tensor: Optional[torch.Tensor | list] = None
|
|
@@ -0,0 +1,258 @@
|
|
|
1
|
+
import librosa
|
|
2
|
+
import torch
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
import torchaudio
|
|
6
|
+
|
|
7
|
+
from caul.constant import (
|
|
8
|
+
PARAKEET_INFERENCE_MAX_DURATION_KHZ,
|
|
9
|
+
EXPECTED_SAMPLE_MINUTE,
|
|
10
|
+
EXPECTED_SAMPLE_RATE,
|
|
11
|
+
PARAKEET_INFERENCE_MAX_DURATION_MIN,
|
|
12
|
+
)
|
|
13
|
+
from caul.filesystem import save_tensor
|
|
14
|
+
from caul.tasks.asr_task import ASRTask
|
|
15
|
+
from caul.tasks.preprocessing.helpers import PreprocessedInput, InputMetadata
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class ParakeetPreprocessor(ASRTask):
|
|
19
|
+
"""Preprocessing logic for ParakeetInferenceHandler inputs"""
|
|
20
|
+
|
|
21
|
+
    def __init__(
        self,
        save_to_filesystem: bool = True,
        return_tensors: bool = True,
        sample_rate: int = EXPECTED_SAMPLE_RATE,
    ):
        """Configure the preprocessor.

        :param save_to_filesystem: write each preprocessed segment to a temp wav file
        :param return_tensors: keep tensors on the preprocessed inputs
        :param sample_rate: target sample rate inputs are normalized to
        """
        super().__init__()

        self.save_to_filesystem = save_to_filesystem
        self.return_tensors = return_tensors
        self.sample_rate = sample_rate
|
|
32
|
+
|
|
33
|
+
def process(
|
|
34
|
+
self,
|
|
35
|
+
inputs: list[np.ndarray | torch.Tensor | str] | np.ndarray | torch.Tensor | str,
|
|
36
|
+
input_sample_rates: list[int] | int = None,
|
|
37
|
+
) -> list[list[PreprocessedInput]]:
|
|
38
|
+
"""Segment and batch audio inputs
|
|
39
|
+
|
|
40
|
+
:param inputs: List of np.ndarray or torch.Tensor or str, or singleton of same types
|
|
41
|
+
:param input_sample_rates: sample rate(s) of audio inputs
|
|
42
|
+
:return: batches of indexed preprocessed audio tensors (input_idx, preprocessed_input)
|
|
43
|
+
"""
|
|
44
|
+
if not isinstance(inputs, list):
|
|
45
|
+
inputs = [inputs]
|
|
46
|
+
|
|
47
|
+
preprocessed_inputs = self.preprocess_inputs(inputs, input_sample_rates)
|
|
48
|
+
batches = self.batch_audio_tensors(preprocessed_inputs)
|
|
49
|
+
|
|
50
|
+
return batches
|
|
51
|
+
|
|
52
|
+
def preprocess_inputs(
|
|
53
|
+
self,
|
|
54
|
+
inputs: list[np.ndarray | torch.Tensor | str],
|
|
55
|
+
input_sample_rates: list[int] = None,
|
|
56
|
+
) -> list[PreprocessedInput]:
|
|
57
|
+
"""Accepts audio inputs as a list of file paths, np.ndarray, or torch.Tensor, converting to
|
|
58
|
+
torch.Tensor, normalizing, segmenting inputs longer than 20 minutes (just under Parakeet's
|
|
59
|
+
max) first by silences or with overlaps where not available, and batching segments
|
|
60
|
+
|
|
61
|
+
:param inputs: List of np.ndarray or torch.Tensor or str, or a singleton of same types
|
|
62
|
+
:param input_sample_rates: sample rate(s) of audio inputs
|
|
63
|
+
:return: List of processed inputs
|
|
64
|
+
"""
|
|
65
|
+
preprocessed_inputs = []
|
|
66
|
+
|
|
67
|
+
# Load arrays and divide into max_length segments
|
|
68
|
+
for input_idx, audio_input in enumerate(inputs):
|
|
69
|
+
input_file_path = None
|
|
70
|
+
new_file_path = None
|
|
71
|
+
input_format = None
|
|
72
|
+
|
|
73
|
+
# Load audio files as arrays
|
|
74
|
+
if isinstance(audio_input, str):
|
|
75
|
+
input_file_path = audio_input
|
|
76
|
+
input_format = (
|
|
77
|
+
input_file_path.split(".")[-1]
|
|
78
|
+
if len(input_file_path.split(".")) > 1
|
|
79
|
+
else None
|
|
80
|
+
)
|
|
81
|
+
audio_input, sample_rate = torchaudio.load(audio_input)
|
|
82
|
+
|
|
83
|
+
if isinstance(audio_input, np.ndarray):
|
|
84
|
+
audio_input = torch.Tensor(audio_input)
|
|
85
|
+
|
|
86
|
+
# Normalize
|
|
87
|
+
if input_sample_rates is not None and len(input_sample_rates) > input_idx:
|
|
88
|
+
sample_rate = input_sample_rates[input_idx]
|
|
89
|
+
else:
|
|
90
|
+
sample_rate = self.sample_rate
|
|
91
|
+
|
|
92
|
+
audio_input = self.normalize(audio_input, sample_rate)
|
|
93
|
+
|
|
94
|
+
# Segment where necessary
|
|
95
|
+
duration_khz = audio_input.shape[-1]
|
|
96
|
+
tensor_segments = [audio_input]
|
|
97
|
+
|
|
98
|
+
if duration_khz > PARAKEET_INFERENCE_MAX_DURATION_KHZ:
|
|
99
|
+
tensor_segments = self.segment_audio_tensor(audio_input)
|
|
100
|
+
|
|
101
|
+
for tensor_segment in tensor_segments:
|
|
102
|
+
# Create temporary filesystem reference if applicable
|
|
103
|
+
if self.save_to_filesystem:
|
|
104
|
+
new_file_path = save_tensor(tensor_segment)
|
|
105
|
+
|
|
106
|
+
if not self.return_tensors:
|
|
107
|
+
tensor_segment = None
|
|
108
|
+
|
|
109
|
+
# Create preprocessed input
|
|
110
|
+
metadata = InputMetadata(
|
|
111
|
+
input_ordering=input_idx,
|
|
112
|
+
duration=duration_khz / EXPECTED_SAMPLE_MINUTE,
|
|
113
|
+
input_format=input_format,
|
|
114
|
+
input_file_path=input_file_path,
|
|
115
|
+
preprocessed_file_path=new_file_path,
|
|
116
|
+
)
|
|
117
|
+
|
|
118
|
+
preprocessed_input = PreprocessedInput(
|
|
119
|
+
tensor=tensor_segment,
|
|
120
|
+
metadata=metadata,
|
|
121
|
+
)
|
|
122
|
+
|
|
123
|
+
preprocessed_inputs.append(preprocessed_input)
|
|
124
|
+
|
|
125
|
+
return preprocessed_inputs
|
|
126
|
+
|
|
127
|
+
def normalize(self, audio_tensor: torch.Tensor, sample_rate: int) -> torch.Tensor:
    """Normalize audio_tensor to the model's expected format (mono, sample rate = self.sample_rate)

    Resamples when the input rate differs from ``self.sample_rate``, then
    downmixes multi-channel audio to a single 1-D waveform.

    :param audio_tensor: input tensor, shape (samples,) or (channels, samples)
    :param sample_rate: input sample rate
    :return: normalized 1-D tensor
    """
    if sample_rate != self.sample_rate:
        audio_tensor = self.resample_waveform(audio_tensor, sample_rate)

    # Collapse (channels, samples) to mono (samples,). The previous
    # squeeze(0) only handled single-channel (1, n) input; true stereo
    # (2, n) was left 2-D. Averaging over the channel dim handles both:
    # for a (1, n) tensor, mean(dim=0) equals squeeze(0).
    if audio_tensor.dim() > 1:
        audio_tensor = audio_tensor.mean(dim=0)

    return audio_tensor
|
|
142
|
+
|
|
143
|
+
@staticmethod
def segment_audio_tensor(
    audio_tensor: torch.Tensor,
    frame_len: int = 2048,
    silence_thresh_db: int = 35,
    hop_len: int = 512,
    kept_silence_len_secs: float = 0.15,
    min_silence_len_secs: float = 0.5,
    max_segment_len_secs: float = EXPECTED_SAMPLE_MINUTE
    * PARAKEET_INFERENCE_MAX_DURATION_MIN,
) -> list[torch.Tensor]:
    """Splits on silences with librosa, falling back to overlaps where min segments
    are not sufficient to safely divide audio.

    :param audio_tensor: input tensor
    :param frame_len: number of samples per analysis frame
    :param silence_thresh_db: threshold (in dB below peak) under which audio counts as silence
    :param hop_len: number of samples between analysis frames
    :param kept_silence_len_secs: seconds of silence padding kept around each segment
    :param min_silence_len_secs: silences shorter than this merge adjacent intervals
    :param max_segment_len_secs: maximum length of an emitted segment, in seconds
    :return: list of tensor segments
    """
    # TODO: Implement fallback to overlaps
    tensor_segments = []

    # Intervals between silences
    nonsilent_intervals = librosa.effects.split(
        audio_tensor.numpy(),
        top_db=silence_thresh_db,
        frame_length=frame_len,
        hop_length=hop_len,
    )

    merged = []
    # NOTE(review): EXPECTED_SAMPLE_MINUTE is used as a seconds->samples
    # conversion factor here despite its name — confirm the constant's units.
    min_silence_sample_len = int(min_silence_len_secs * EXPECTED_SAMPLE_MINUTE)
    kept_silence_sample_len = int(kept_silence_len_secs * EXPECTED_SAMPLE_MINUTE)
    max_segment_sample_len = int(max_segment_len_secs * EXPECTED_SAMPLE_MINUTE)

    # Merge intervals separated by short silences. Entries are lists
    # (not tuples): the merge step mutates the last entry in place —
    # the previous tuple version raised TypeError on assignment.
    for start, end in nonsilent_intervals:
        if merged and start - merged[-1][1] < min_silence_sample_len:
            merged[-1][1] = end
        else:
            merged.append([start, end])

    # Segment controlling max length
    for start, end in merged:
        start = max(0, start - kept_silence_sample_len)
        end = min(audio_tensor.shape[-1], end + kept_silence_sample_len)

        # Slice off full-size chunks, advancing start each time (the
        # previous version never advanced start, looping forever).
        while end - start > max_segment_sample_len:
            segment_end = start + max_segment_sample_len
            tensor_segments.append(audio_tensor[start:segment_end])
            start = segment_end

        # Keep the final (<= max length) remainder, which was
        # previously dropped entirely.
        if end > start:
            tensor_segments.append(audio_tensor[start:end])

    return tensor_segments
|
|
205
|
+
|
|
206
|
+
@staticmethod
def batch_audio_tensors(  # pylint: disable=R0914
    preprocessed_inputs: list[PreprocessedInput],
) -> list[list[PreprocessedInput]]:
    """Batch audio tensors by duration, 20 minutes max per batch, optimizing for tightly packed
    batches.

    Inputs are sorted longest-first, then distributed with a worst-fit
    decreasing heuristic: each input goes to the bin with the most
    remaining capacity, opening a fresh bin whenever no existing bin has
    strictly more headroom than the input's duration.

    :param preprocessed_inputs: list of PreprocessedInput
    :return: list of list[PreprocessedInput]
    """
    # Longest inputs placed first packs bins more tightly.
    ordered = sorted(
        preprocessed_inputs, key=lambda p: p.metadata.duration, reverse=True
    )

    bins: list[list[PreprocessedInput]] = [[]]
    bins_len: list[float] = [0]

    for item in ordered:
        duration = item.metadata.duration

        # Headroom left in each bin before hitting the cap.
        remaining = [
            PARAKEET_INFERENCE_MAX_DURATION_MIN - used for used in bins_len
        ]

        # No bin can take this input: open a fresh empty one.
        if max(remaining) <= duration:
            bins.append([])
            bins_len.append(0)
            remaining.append(PARAKEET_INFERENCE_MAX_DURATION_MIN)

        # Worst fit: first bin with the largest remaining capacity
        # (same tie-breaking as argmax — first occurrence of the max).
        target = max(range(len(remaining)), key=remaining.__getitem__)
        bins[target].append(item)
        bins_len[target] += duration

    return bins
|
|
247
|
+
|
|
248
|
+
def resample_waveform(
    self, waveform: torch.Tensor, sample_rate: int
) -> torch.Tensor:
    """Resample a waveform from ``sample_rate`` to the target ``self.sample_rate``

    :param waveform: audio tensor to resample
    :param sample_rate: the waveform's current sample rate
    :return: resampled torch.Tensor
    """
    resampler = torchaudio.transforms.Resample(
        orig_freq=sample_rate, new_freq=self.sample_rate
    )
    return resampler(waveform)
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def dict_key_fuzzy_match(dict_obj: dict, search_key: str) -> Any | None:
|
|
5
|
+
"""Match a dict key name fuzzily (returning first match)
|
|
6
|
+
|
|
7
|
+
:param dict_obj: dictionary to search in
|
|
8
|
+
:param search_key: key name
|
|
9
|
+
:return: key value (if exists)
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
if search_key in dict_obj:
|
|
13
|
+
return dict_obj[search_key]
|
|
14
|
+
|
|
15
|
+
fuzzy_matches = [
|
|
16
|
+
dict_value
|
|
17
|
+
for dict_key, dict_value in dict_obj.items()
|
|
18
|
+
if dict_key in search_key or search_key in dict_key
|
|
19
|
+
]
|
|
20
|
+
|
|
21
|
+
if len(fuzzy_matches) > 0:
|
|
22
|
+
return fuzzy_matches[0]
|
|
23
|
+
|
|
24
|
+
return None
|