caul-0.1.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
caul-0.1.0/PKG-INFO ADDED
@@ -0,0 +1,13 @@
+ Metadata-Version: 2.3
+ Name: caul
+ Version: 0.1.0
+ Summary: Python implementation of an ASR service
+ Author: Lion Summerbell
+ Author-email: Lion Summerbell <lsummerbell@icij.org>
+ Requires-Dist: nemo-toolkit[asr]
+ Requires-Dist: torch>=2.10.0
+ Requires-Dist: numpy
+ Requires-Dist: torchaudio>=2.10.0
+ Requires-Dist: torchcodec>=0.9.1
+ Requires-Dist: librosa>=0.11.0
+ Requires-Python: >=3.10.0, <3.14.0
caul-0.1.0/pyproject.toml ADDED
@@ -0,0 +1,48 @@
+ [project]
+ name = "caul"
+ version = "0.1.0"
+ description = "Python implementation of an ASR service"
+ requires-python = ">=3.10.0, <3.14.0"
+ dependencies = [
+     "nemo-toolkit[asr]",
+     "torch>=2.10.0",
+     "numpy",
+     "torchaudio>=2.10.0",
+     "torchcodec>=0.9.1",
+     "librosa>=0.11.0",
+ ]
+ authors = [
+     {name = "Lion Summerbell", email = "lsummerbell@icij.org"}
+ ]
+ [build-system]
+ requires = ["uv_build >= 0.9.22, <0.10.0"]
+ build-backend = "uv_build"
+
+ [dependency-groups]
+ dev = [
+     "black>=25.12.0",
+     "pre-commit>=4.5.1",
+     "pylint>=4.0.4",
+     "pytest>=9.0.2",
+ ]
+
+ [tool.pylint."MESSAGES CONTROL"]
+ disable = [
+     "missing-module-docstring"
+ ]
+
+ [tool.pylint.main]
+ fail-under = 8.89
+
+ [tool.pylint.refactoring]
+ max-nested-blocks = 5
+ never-returning-functions = ["sys.exit", "argparse.parse_error"]
+
+ [tool.pylint.similarities]
+ # Docstrings are removed from the similarity computation
+ ignore-docstrings = true
+ # Comments are removed from the similarity computation
+ ignore-comments = true
+
+ [tool.black]
+ target-version = ["py310"]
caul-0.1.0/src/caul/__init__.py ADDED
File without changes
caul-0.1.0/src/caul/configs/__init__.py ADDED
@@ -0,0 +1,7 @@
+ from caul.configs.parakeet import ParakeetConfig
+ from caul.constant import PARAKEET
+
+
+ MODEL_FAMILY_CONFIG_MAP = {
+     PARAKEET: ParakeetConfig(),
+ }
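
A brief usage sketch, using only the names defined above: resolving the model family string to a config, then to a ready-to-start handler.

from caul.configs import MODEL_FAMILY_CONFIG_MAP
from caul.constant import PARAKEET

# Look up the Parakeet config and instantiate its model handler class.
config = MODEL_FAMILY_CONFIG_MAP[PARAKEET]
model_handler = config.handler_from_config()  # no model weights loaded yet
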
caul-0.1.0/src/caul/configs/asr.py ADDED
@@ -0,0 +1,24 @@
+ from dataclasses import dataclass
+ from typing import TYPE_CHECKING
+
+ import torch
+
+ from caul.constant import DEVICE_CPU
+
+ if TYPE_CHECKING:
+     from caul.model_handlers.asr_model_handler import ASRModelHandler
+
+
+ @dataclass
+ class ASRConfig:
+     """Base config class"""
+
+     model_name: str
+     # Handler class (not instance); instantiated in handler_from_config
+     model_handler: "type[ASRModelHandler]"
+     device: str | torch.device = DEVICE_CPU
+
+     def handler_from_config(self) -> "ASRModelHandler | None":
+         return (
+             self.model_handler(config=self) if self.model_handler is not None else None
+         )
caul-0.1.0/src/caul/configs/parakeet.py ADDED
@@ -0,0 +1,20 @@
+ from dataclasses import dataclass
+ from typing import TYPE_CHECKING
+
+ from caul.configs.asr import ASRConfig
+ from caul.constant import EXPECTED_SAMPLE_RATE, PARAKEET_MODEL_REF
+ from caul.model_handlers.parakeet import ParakeetModelHandler
+
+ if TYPE_CHECKING:
+     from caul.model_handlers.asr_model_handler import ASRModelHandler
+
+
+ @dataclass
+ class ParakeetConfig(ASRConfig):
+     """Config for the Parakeet model family"""
+
+     model_name: str = PARAKEET_MODEL_REF
+     model_handler: "type[ASRModelHandler]" = ParakeetModelHandler
+     save_to_filesystem: bool = True
+     return_tensors: bool = True
+     sample_rate: int = EXPECTED_SAMPLE_RATE
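
A minimal sketch of overriding the config defaults; the device and flag values here are illustrative.

import torch

from caul.configs.parakeet import ParakeetConfig

# Run on GPU and skip the temp-file round trip, keeping the default model ref.
config = ParakeetConfig(device=torch.device("cuda"), save_to_filesystem=False)
handler = config.handler_from_config()
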
caul-0.1.0/src/caul/constant.py ADDED
@@ -0,0 +1,33 @@
+ # General
+
+ DEVICE_CPU = "cpu"
+
+ DEVICE_GPU = "gpu"
+
+ DEVICE_MPS = "mps"
+
+ EXPECTED_SAMPLE_RATE = 16000
+
+ EXPECTED_FORMAT = "wav"
+
+ # Samples per minute of audio at the expected sample rate
+ EXPECTED_SAMPLE_MINUTE = EXPECTED_SAMPLE_RATE * 60
+
+ # Parakeet
+
+ PARAKEET_MODEL_REF = "nvidia/parakeet-tdt-0.6b-v3"
+
+ PARAKEET_INFERENCE_MAX_DURATION_MIN = (
+     20  # actually 24, but we want to give ourselves some room
+ )
+
+ # Maximum inference duration as a sample count (the name says kHz, but these are samples)
+ PARAKEET_INFERENCE_MAX_DURATION_KHZ = (
+     PARAKEET_INFERENCE_MAX_DURATION_MIN * EXPECTED_SAMPLE_MINUTE
+ )
+
+ PARAKEET = "parakeet"
+
+ # Whisper
+
+ WHISPER_CPP = "whisper-cpp"
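
For orientation, the duration constants work out as follows (plain arithmetic, no package imports):

samples_per_minute = 16000 * 60        # EXPECTED_SAMPLE_MINUTE = 960,000
max_samples = 20 * samples_per_minute  # PARAKEET_INFERENCE_MAX_DURATION_KHZ = 19,200,000
print(max_samples / 16000 / 60)        # 20.0 minutes of audio per inference pass
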
caul-0.1.0/src/caul/exception.py ADDED
@@ -0,0 +1,6 @@
+ class MissingModelSpecificationException(Exception):
+     """Raise if referencing a missing model"""
+
+
+ class UnsupportedModelException(Exception):
+     """Raise if an unsupported model type is passed"""
caul-0.1.0/src/caul/filesystem.py ADDED
@@ -0,0 +1,34 @@
+ import os
+
+ from tempfile import mkstemp
+
+ import torch
+ import torchaudio
+
+ from caul.constant import EXPECTED_SAMPLE_RATE, EXPECTED_FORMAT
+
+
+ def save_tensor(audio_tensor: torch.Tensor) -> str:
+     """Filesystem routine for an audio tensor; defaults to wav
+
+     :param audio_tensor: input tensor
+     :return: string file uri
+     """
+     # TODO: Change paths to run_id + tensor uuid + pagination
+     # Allow for remote paths
+
+     # torchcodec requires a file extension, so create the temp file with one
+     fd, file_path = mkstemp(suffix=f".{EXPECTED_FORMAT}")
+     os.close(fd)
+
+     # Channel required as first dim
+     audio_tensor = audio_tensor.unsqueeze(0)
+
+     torchaudio.save(
+         file_path,
+         audio_tensor,
+         sample_rate=EXPECTED_SAMPLE_RATE,
+         format=EXPECTED_FORMAT,
+     )
+
+     return file_path
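
A short sketch exercising save_tensor with a synthetic one-second tone; the frequency is arbitrary.

import math

import torch

from caul.constant import EXPECTED_SAMPLE_RATE
from caul.filesystem import save_tensor

# One second of a 440 Hz sine at the expected 16 kHz sample rate.
t = torch.arange(EXPECTED_SAMPLE_RATE) / EXPECTED_SAMPLE_RATE
tone = torch.sin(2 * math.pi * 440 * t)

print(save_tensor(tone))  # e.g. /tmp/tmpabc123.wav
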
@@ -0,0 +1,174 @@
+ import logging
+
+ import numpy as np
+ import torch
+
+ from caul.configs import MODEL_FAMILY_CONFIG_MAP
+ from caul.configs.asr import ASRConfig
+ from caul.exception import (
+     MissingModelSpecificationException,
+     UnsupportedModelException,
+ )
+ from caul.model_handlers.asr_model_handler import ASRModelHandler
+ from caul.model_handlers.helpers import ASRModelHandlerResult
+ from caul.utils import dict_key_fuzzy_match
+
+ logger = logging.getLogger(__name__)
+
+
+ class ASRHandler:
+     """ASRHandler class"""
+
+     # pylint: disable=R0913,R0917
+
+     def __init__(
+         self,
+         models: (
+             list[str | ASRModelHandler | ASRConfig] | str | ASRModelHandler | ASRConfig
+         ),
+         device: torch.device | str | None = None,
+         language_map: dict[str, int] | None = None,
+     ):
+         """Primary application handler class. Handles transcription agnostically.
+
+         :param models: Model handler(s), config(s), or string reference(s)
+         :param device: cuda/cpu/mps
+         :param language_map: Map from ISO-639-3 language code to index of inference_handler
+         """
+         self.device = device
+
+         if language_map is None:
+             language_map = {}
+
+         self.language_map = language_map
+
+         self.model_handlers = []
+
+         if isinstance(models, list) and len(models) == 0:
+             raise MissingModelSpecificationException(
+                 "At least one model name or model handler must be provided"
+             )
+
+         if not isinstance(models, list):
+             models = [models]
+
+         for model in models:
+             if isinstance(model, str):
+                 supported_model_config = dict_key_fuzzy_match(
+                     MODEL_FAMILY_CONFIG_MAP, model
+                 )
+
+                 if supported_model_config is None:
+                     raise UnsupportedModelException(f"Unsupported model '{model}'")
+
+                 # Set device after instantiation, but only when one was given;
+                 # otherwise keep the device from the config
+                 supported_model_handler = supported_model_config.handler_from_config()
+
+                 if self.device is not None:
+                     supported_model_handler.set_device(self.device)
+
+                 self.model_handlers.append(supported_model_handler)
+             elif isinstance(model, ASRConfig):
+                 supported_model_handler = model.handler_from_config()
+
+                 if self.device is not None:
+                     supported_model_handler.set_device(self.device)
+
+                 self.model_handlers.append(supported_model_handler)
+             elif isinstance(model, ASRModelHandler):
+                 self.model_handlers.append(model)
+             else:
+                 raise UnsupportedModelException(f"Unsupported model type '{model}'")
+
+     def __repr__(self):
+         return f"<ASRHandler models: {self.model_handlers}>"
+
+     def startup(self):
+         """Run all model handler startup procedures"""
+         for model_handler in self.model_handlers:
+             model_handler.startup()
+
+     def shutdown(self):
+         """Shut down and release model handlers"""
+         for model_handler in self.model_handlers:
+             model_handler.shutdown()
+
+         self.model_handlers = []
+
+     def get_handler_by_language(self, language: str) -> ASRModelHandler:
+         """Get model_handler from language map or return first reference if language is not mapped
+
+         :param language: ISO-639-3 language code
+         :return: ASRModelHandler
+         """
+         reference_idx = self.language_map.get(
+             language, 0
+         )  # default to primary inference_handler when no language given
+
+         if len(self.model_handlers) <= reference_idx:
+             raise UnsupportedModelException(
+                 "Language is mapped to a model index which does not exist"
+             )
+
+         return self.model_handlers[reference_idx]
+
+     def transcribe(
+         self,
+         inputs: list[np.ndarray | torch.Tensor | str] | np.ndarray | torch.Tensor | str,
+         languages: list[str] | None = None,
+     ) -> list[ASRModelHandlerResult]:
+         """Transcribe audio tensors or strings, returning a list of ASRModelHandlerResult. A list
+         of languages of len(inputs) may be passed to direct inputs to certain inference_handlers.
+
+         :param inputs: List of np.ndarray or torch.Tensor or str, or a singleton of same types
+         :param languages: List of ISO-639-3 language codes
+         :return: List of ASRModelHandlerResult
+         """
+         if len(self.model_handlers) == 0:
+             raise MissingModelSpecificationException(
+                 "At least one model name or model handler must be provided"
+             )
+
+         if not isinstance(inputs, list):
+             inputs = [inputs]
+
+         audios_by_language = {}
+         model_handler_results_by_language = {}
+         batch_language_ordering = []
+         model_handler_results = []
+
+         if languages is None:
+             # Default to first model handler
+             return self.model_handlers[0].process(inputs)
+
+         # Sort by language where present, preserving original order for returning result
+         for idx, aud in enumerate(inputs):
+             language = languages[idx]
+
+             if language not in audios_by_language:
+                 audios_by_language[language] = []
+
+             batch_language_ordering.append(language)
+             audios_by_language[language].append(aud)
+
+         # Run inference_handler on language batch
+         for language, audio_list in audios_by_language.items():
+             model_handler = self.get_handler_by_language(language)
+             model_handler_results_by_language[language] = model_handler.process(
+                 audio_list
+             )
+
+         # For use with .pop()
+         batch_language_ordering.reverse()
+
+         # Reassemble and postprocess
+         for language in batch_language_ordering:
+             model_handler_result = model_handler_results_by_language[language].pop()
+
+             model_handler_results.append(model_handler_result)
+
+         # Popping in reverse built the list back-to-front; restore input ordering
+         model_handler_results.reverse()
+
+         return model_handler_results
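
An end-to-end sketch of the handler above. The module path is assumed (it is not shown in this diff), "parakeet" resolves through the fuzzy match against MODEL_FAMILY_CONFIG_MAP, and the input file is hypothetical.

from caul.handler import ASRHandler  # module path assumed; adjust to the package layout

handler = ASRHandler(models="parakeet", device="cpu")
handler.startup()  # downloads/loads the NeMo model

results = handler.transcribe("interview.wav")  # hypothetical 16 kHz mono wav
for result in results:
    print(result.score, result.transcription)

handler.shutdown()
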
caul-0.1.0/src/caul/model_handlers/__init__.py ADDED
File without changes
caul-0.1.0/src/caul/model_handlers/asr_model_handler.py ADDED
@@ -0,0 +1,40 @@
+ from abc import ABC, abstractmethod
+
+ import numpy as np
+ import torch
+
+ from caul.configs.asr import ASRConfig
+ from caul.model_handlers.helpers import ASRModelHandlerResult
+ from caul.tasks.asr_task import ASRTask
+
+
+ class ASRModelHandler(ABC):
+     """ASR model handler abstract"""
+
+     # pylint: disable=R0903
+
+     def __init__(self, config: "ASRConfig", *args, **kwargs):
+         self.config = config
+         self.tasks: list[ASRTask] = []
+
+     @abstractmethod
+     def startup(self):
+         """Generic method to load ASR resources"""
+
+     @abstractmethod
+     def shutdown(self):
+         """Generic method to unload ASR resources"""
+
+     def process(
+         self,
+         inputs: list[np.ndarray | torch.Tensor | str] | np.ndarray | torch.Tensor | str,
+     ) -> list[ASRModelHandlerResult]:
+         """Generic sequential processing method for ASR model handlers"""
+
+         # Each task consumes the previous task's output (preprocess -> infer -> postprocess)
+         output = inputs
+
+         for task in self.tasks:
+             output = task.process(output)
+
+         return output
caul-0.1.0/src/caul/model_handlers/helpers.py ADDED
@@ -0,0 +1,74 @@
+ from dataclasses import dataclass
+
+ from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis
+
+
+ @dataclass
+ class ASRModelHandlerResult:
+     """Base result class for ASR models"""
+
+     input_ordering: int = -1
+     transcription: list[tuple] | None = None
+     score: float | None = None
+
+
+ @dataclass
+ class ParakeetModelHandlerResult(ASRModelHandlerResult):
+     """Result handler for ParakeetInferenceHandler objects"""
+
+     def parse_parakeet_hypothesis(
+         self, hypothesis: Hypothesis
+     ) -> ASRModelHandlerResult:
+         """Parse a hypothesis returned by a Parakeet RNNT model
+
+         :param hypothesis: Parakeet hypothesis
+         :return: copy of self
+         """
+         self.transcription = (
+             [
+                 (s["start"], s["end"], s["segment"])
+                 for s in hypothesis.timestamp.get("segment")
+             ]
+             if hypothesis.timestamp.get("segment") is not None
+             else [(0.0, 0.0, hypothesis.text)]
+         )
+         self.score = round(hypothesis.score, 2)
+
+         return self
+
+     def concat(self, model_result: ASRModelHandlerResult) -> ASRModelHandlerResult:
+         """Left fold with ParakeetModelHandlerResult object
+
+         :param model_result: ParakeetModelHandlerResult
+         :return: copy of self
+         """
+         if model_result is None:
+             return self
+
+         if self.transcription is None:
+             self.transcription = []
+
+         # We have to weight by total segment len; measure both sides *before*
+         # concatenating, or self's duration would read the appended tail
+         transcription_duration = (
+             self.transcription[-1][1] if len(self.transcription) > 0 else 0.0
+         )
+         model_result_duration = model_result.transcription[-1][1]
+         total_duration = transcription_duration + model_result_duration
+
+         # Offset incoming segment timestamps, which restart at zero per segment
+         self.transcription += [
+             (start + transcription_duration, end + transcription_duration, text)
+             for start, end, text in model_result.transcription
+         ]
+
+         self.score = round(
+             (
+                 (self.score or 0.0) * transcription_duration
+                 + model_result.score * model_result_duration
+             )
+             / total_duration,
+             2,
+         )
+
+         return self
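
To make the score weighting in concat concrete, a worked example with illustrative numbers:

# A 10-minute segment scored 0.90 folded with a 5-minute segment scored 0.60:
left_score, left_duration = 0.90, 10.0
right_score, right_duration = 0.60, 5.0

merged = (left_score * left_duration + right_score * right_duration) / (
    left_duration + right_duration
)
print(round(merged, 2))  # 0.8, weighted toward the longer segment
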
caul-0.1.0/src/caul/model_handlers/parakeet.py ADDED
@@ -0,0 +1,80 @@
+ from typing import TYPE_CHECKING
+
+ import torch
+
+ from caul.constant import DEVICE_CPU, PARAKEET_MODEL_REF
+ from caul.model_handlers.asr_model_handler import ASRModelHandler
+ from caul.tasks.inference.parakeet_inference import ParakeetInferenceHandler
+ from caul.tasks.postprocessing.parakeet_postprocessor import ParakeetPostprocessor
+ from caul.tasks.preprocessing.parakeet_preprocessor import ParakeetPreprocessor
+
+ if TYPE_CHECKING:
+     from caul.configs import ParakeetConfig
+
+
+ class ParakeetModelHandler(ASRModelHandler):
+     """Model handler for Parakeet family"""
+
+     def __init__(
+         self,
+         config: "ParakeetConfig" = None,
+         model_name: str = PARAKEET_MODEL_REF,
+         device: str | torch.device = DEVICE_CPU,
+     ):
+         super().__init__(config=config)
+
+         if config is not None and config.model_name is not None:
+             model_name = config.model_name
+
+         self.model_name = model_name
+
+         if config is not None and config.device is not None:
+             device = config.device
+
+         if isinstance(device, str):
+             device = torch.device(device)
+
+         self.device = device
+
+         # Fall back to preprocessor defaults when no config is given
+         save_to_filesystem = config.save_to_filesystem if config is not None else True
+         return_tensors = config.return_tensors if config is not None else True
+
+         self.preprocessor = ParakeetPreprocessor(
+             save_to_filesystem=save_to_filesystem,
+             return_tensors=return_tensors,
+         )
+         # Use the resolved name/device rather than re-reading a possibly absent config
+         self.inference_handler = ParakeetInferenceHandler(
+             model_name=self.model_name, device=self.device
+         )
+         self.postprocessor = ParakeetPostprocessor()
+
+         self.tasks = [self.preprocessor, self.inference_handler, self.postprocessor]
+
+     def set_device(self, device: str | torch.device = DEVICE_CPU):
+         """Set/change device here and on inference_handler
+
+         :param device: device to use
+         """
+         if isinstance(device, str):
+             device = torch.device(device)
+
+         self.device = device
+
+         self.inference_handler.set_device(device)
+
+         return self
+
+     def startup(self):
+         """Load model"""
+         self.inference_handler.load()
+
+     def shutdown(self):
+         """Unload model and drop task references"""
+         self.inference_handler.unload()
+
+         self.preprocessor = None
+         self.inference_handler = None
+         self.postprocessor = None
+         self.tasks = []
caul-0.1.0/src/caul/tasks/__init__.py ADDED
File without changes
caul-0.1.0/src/caul/tasks/asr_task.py ADDED
@@ -0,0 +1,12 @@
+ from abc import ABC, abstractmethod
+ from typing import Any
+
+
+ class ASRTask(ABC):
+     """Generic ASR task"""
+
+     # pylint: disable=R0903
+
+     @abstractmethod
+     def process(self, inputs: Any, *args, **kwargs) -> list:
+         """Generic processing task"""
caul-0.1.0/src/caul/tasks/inference/__init__.py ADDED
File without changes
caul-0.1.0/src/caul/tasks/inference/asr_inference.py ADDED
@@ -0,0 +1,24 @@
+ from abc import abstractmethod
+
+ from caul.model_handlers.helpers import ASRModelHandlerResult
+ from caul.tasks.asr_task import ASRTask
+
+
+ class ASRInferenceHandler(ASRTask):
+     """Abstract for ASR inference"""
+
+     @abstractmethod
+     def process(self, inputs: list, *args, **kwargs) -> list[ASRModelHandlerResult]:
+         """Run inference over a batch of inputs
+
+         :param inputs: List of inference inputs
+         :return: List of ASRModelHandlerResult
+         """
+
+     @abstractmethod
+     def load(self):
+         """Load model"""
+
+     @abstractmethod
+     def unload(self):
+         """Unload model"""
caul-0.1.0/src/caul/tasks/inference/parakeet_inference.py ADDED
@@ -0,0 +1,94 @@
+ import torch
+
+ import nemo.collections.asr as nemo_asr
+
+ from caul.constant import DEVICE_CPU
+ from caul.model_handlers.helpers import ParakeetModelHandlerResult
+ from caul.tasks.inference.asr_inference import (
+     ASRInferenceHandler,
+ )
+ from caul.tasks.preprocessing.helpers import PreprocessedInput
+
+
+ class ParakeetInferenceHandler(ASRInferenceHandler):
+     """Inference handler for NVIDIA's Parakeet family of ASR models. Supports up to 24 minutes of
+     audio (batched or unbatched) in a single pass. Assumes that audio inputs (wav files or tensors)
+     are single-channel with a sample rate of 16000; this last is very important for segmenting.
+     """
+
+     def __init__(self, model_name: str, device: str | torch.device = DEVICE_CPU):
+         self.model_name = model_name
+
+         if isinstance(device, str):
+             device = torch.device(device)
+
+         self.device = device
+         self.model = None
+
+     def load(self):
+         """Load model; default to CPU where no device is present"""
+         device = self.device
+
+         if device is None:
+             device = DEVICE_CPU
+
+         self.model = nemo_asr.models.ASRModel.from_pretrained(
+             self.model_name, map_location=torch.device(device)
+         ).eval()
+
+     def unload(self):
+         """Unload model"""
+         self.model = None
+
+     def set_device(self, device: str | torch.device = DEVICE_CPU):
+         """Set/change device"""
+         if isinstance(device, str):
+             device = torch.device(device)
+
+         self.device = device
+
+         return self
+
+     def process(
+         self,
+         inputs: list[list[PreprocessedInput]] | list[PreprocessedInput],
+         timestamps: bool = True,
+     ) -> list[ParakeetModelHandlerResult]:
+         """Transcribe batches of audio tensors or file paths of max duration <= 20 minutes
+
+         :param inputs: Batches of PreprocessedInput, or a single batch
+         :param timestamps: Whether to include timestamps with transcriptions
+         :return: List of results
+         """
+         if len(inputs) == 0:
+             return []
+
+         if isinstance(inputs[0], PreprocessedInput):
+             inputs = [inputs]
+
+         # Lazy-load in case startup() was not called
+         if self.model is None:
+             self.load()
+
+         transcriptions = []
+
+         for input_batch in inputs:
+             hypotheses = self.model.transcribe(
+                 [
+                     # Fall back to the saved file path when tensors were not kept
+                     i.tensor.to(self.device)
+                     if i.tensor is not None
+                     else i.metadata.preprocessed_file_path
+                     for i in input_batch
+                 ],
+                 timestamps=timestamps,
+             )
+             # Get timestamped segments if available, otherwise default to whole text
+             for idx, hyp in enumerate(hypotheses):
+                 input_ordering_idx = input_batch[idx].metadata.input_ordering
+                 model_result = ParakeetModelHandlerResult(
+                     input_ordering=input_ordering_idx
+                 ).parse_parakeet_hypothesis(hyp)
+                 transcriptions.append(model_result)
+
+         return transcriptions
caul-0.1.0/src/caul/tasks/inference/whisper_cpp_inference.py ADDED
@@ -0,0 +1,27 @@
+ from caul.tasks.inference.asr_inference import (
+     ASRInferenceHandler,
+     ASRModelHandlerResult,
+ )
+
+
+ class WhisperCPPInferenceHandler(ASRInferenceHandler):
+     """Handler for WhisperCPP; wrapper round subprocess calls"""
+
+     # pylint: disable=R0903
+
+     def process(
+         self,
+         inputs: list[str],
+     ) -> list[ASRModelHandlerResult]:
+         """Transcribe audio files via whisper.cpp (not yet implemented)
+
+         :param inputs: List of audio file paths
+         :return: List of ASRModelHandlerResult
+         """
+         raise NotImplementedError
+
+     def load(self):
+         """Load model (not yet implemented)"""
+
+     def unload(self):
+         """Unload model (not yet implemented)"""
caul-0.1.0/src/caul/tasks/postprocessing/__init__.py ADDED
File without changes
caul-0.1.0/src/caul/tasks/postprocessing/parakeet_postprocessor.py ADDED
@@ -0,0 +1,55 @@
+ from functools import reduce
+ from itertools import groupby
+
+ from caul.tasks.asr_task import ASRTask
+ from caul.tasks.inference.parakeet_inference import ParakeetModelHandlerResult
+
+
+ class ParakeetPostprocessor(ASRTask):
+     """Postprocessing logic for ParakeetInferenceHandler output"""
+
+     def process(
+         self, inputs: list[ParakeetModelHandlerResult]
+     ) -> list[ParakeetModelHandlerResult]:
+         """Process indexed ParakeetInferenceHandler results and return them in their original
+         ordering
+
+         :param inputs: List of parakeet model results
+         :return: list of parakeet model results in input ordering
+         """
+
+         return self.map_results_to_inputs(inputs)
+
+     @staticmethod
+     def map_results_to_inputs(
+         batched_results: list[ParakeetModelHandlerResult],
+     ) -> list[ParakeetModelHandlerResult]:
+         """Remap unordered and segmented results to original inputs for return
+
+         :param batched_results: list of unordered ParakeetModelHandlerResult, still
+             segmented
+         :return: list[ParakeetModelHandlerResult]
+         """
+         unbatched_results = []
+
+         # Sort by input index; groupby only groups consecutive keys
+         batched_results = sorted(batched_results, key=lambda r: r.input_ordering)
+
+         # Concat segmented results
+         results_grouped_by_index = groupby(
+             batched_results, key=lambda r: r.input_ordering
+         )
+
+         for _, group_results in results_grouped_by_index:
+             group_results = list(group_results)
+
+             merged_results = (
+                 reduce(lambda left, right: left.concat(right), group_results)
+                 if len(group_results) > 1
+                 else group_results[0]
+             )
+             unbatched_results.append(merged_results)
+
+         # TODO: Drop index from result
+
+         return unbatched_results
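
A hedged sketch of map_results_to_inputs reassembling segmented output; timestamps (seconds) and scores are illustrative.

from caul.model_handlers.helpers import ParakeetModelHandlerResult
from caul.tasks.postprocessing.parakeet_postprocessor import ParakeetPostprocessor

# Two segments of input 0 (arriving out of order) and one result for input 1.
results = [
    ParakeetModelHandlerResult(input_ordering=1, transcription=[(0.0, 60.0, "b")], score=0.7),
    ParakeetModelHandlerResult(input_ordering=0, transcription=[(0.0, 1200.0, "a1")], score=0.9),
    ParakeetModelHandlerResult(input_ordering=0, transcription=[(0.0, 600.0, "a2")], score=0.6),
]

merged = ParakeetPostprocessor.map_results_to_inputs(results)
# merged[0] folds both input-0 segments (weighted score 0.8); merged[1] is input 1 as-is.
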
caul-0.1.0/src/caul/tasks/preprocessing/__init__.py ADDED
File without changes
caul-0.1.0/src/caul/tasks/preprocessing/helpers.py ADDED
@@ -0,0 +1,39 @@
+ import datetime
+ import uuid
+
+ from dataclasses import dataclass, field
+ from typing import Optional
+
+ import torch
+
+
+ def _utc_now_iso() -> str:
+     """Current UTC time, second precision, ISO-8601"""
+     return (
+         datetime.datetime.now(datetime.timezone.utc).replace(microsecond=0).isoformat()
+     )
+
+
+ @dataclass
+ class InputMetadata:
+     """Preprocessed input metadata"""
+
+     input_ordering: int
+     duration: float
+     start_time: int = 0
+     end_time: int = 0
+     # default_factory so each instance gets a fresh timestamp and uuid, rather
+     # than values computed once at class-definition time
+     preprocessed_at: str = field(default_factory=_utc_now_iso)
+     uuid: str = field(default_factory=lambda: uuid.uuid4().hex)
+     input_format: Optional[str] = None
+     input_file_path: Optional[str] = None
+     preprocessed_file_path: Optional[str] = None
+
+
+ @dataclass
+ class PreprocessedInput:
+     """Preprocessed input wrapper"""
+
+     metadata: InputMetadata
+     tensor: Optional[torch.Tensor | list] = None
caul-0.1.0/src/caul/tasks/preprocessing/parakeet_preprocessor.py ADDED
@@ -0,0 +1,271 @@
+ import librosa
+ import torch
+
+ import numpy as np
+ import torchaudio
+
+ from caul.constant import (
+     PARAKEET_INFERENCE_MAX_DURATION_KHZ,
+     EXPECTED_SAMPLE_MINUTE,
+     EXPECTED_SAMPLE_RATE,
+     PARAKEET_INFERENCE_MAX_DURATION_MIN,
+ )
+ from caul.filesystem import save_tensor
+ from caul.tasks.asr_task import ASRTask
+ from caul.tasks.preprocessing.helpers import PreprocessedInput, InputMetadata
+
+
+ class ParakeetPreprocessor(ASRTask):
+     """Preprocessing logic for ParakeetInferenceHandler inputs"""
+
+     def __init__(
+         self,
+         save_to_filesystem: bool = True,
+         return_tensors: bool = True,
+         sample_rate: int = EXPECTED_SAMPLE_RATE,
+     ):
+         super().__init__()
+
+         self.save_to_filesystem = save_to_filesystem
+         self.return_tensors = return_tensors
+         self.sample_rate = sample_rate
+
+     def process(
+         self,
+         inputs: list[np.ndarray | torch.Tensor | str] | np.ndarray | torch.Tensor | str,
+         input_sample_rates: list[int] | int | None = None,
+     ) -> list[list[PreprocessedInput]]:
+         """Segment and batch audio inputs
+
+         :param inputs: List of np.ndarray or torch.Tensor or str, or singleton of same types
+         :param input_sample_rates: sample rate(s) of audio inputs
+         :return: batches of preprocessed audio inputs
+         """
+         if not isinstance(inputs, list):
+             inputs = [inputs]
+
+         if isinstance(input_sample_rates, int):
+             input_sample_rates = [input_sample_rates]
+
+         preprocessed_inputs = self.preprocess_inputs(inputs, input_sample_rates)
+         batches = self.batch_audio_tensors(preprocessed_inputs)
+
+         return batches
+
+     def preprocess_inputs(
+         self,
+         inputs: list[np.ndarray | torch.Tensor | str],
+         input_sample_rates: list[int] | None = None,
+     ) -> list[PreprocessedInput]:
+         """Accepts audio inputs as a list of file paths, np.ndarray, or torch.Tensor, converting to
+         torch.Tensor, normalizing, and segmenting inputs longer than 20 minutes (just under
+         Parakeet's max), first by silences or with overlaps where none are available
+
+         :param inputs: List of np.ndarray or torch.Tensor or str
+         :param input_sample_rates: sample rate(s) of audio inputs
+         :return: List of processed inputs
+         """
+         preprocessed_inputs = []
+
+         # Load arrays and divide into max_length segments
+         for input_idx, audio_input in enumerate(inputs):
+             input_file_path = None
+             new_file_path = None
+             input_format = None
+             sample_rate = None
+
+             # Load audio files as arrays; the file itself reports its sample rate
+             if isinstance(audio_input, str):
+                 input_file_path = audio_input
+                 input_format = (
+                     input_file_path.split(".")[-1]
+                     if len(input_file_path.split(".")) > 1
+                     else None
+                 )
+                 audio_input, sample_rate = torchaudio.load(audio_input)
+
+             if isinstance(audio_input, np.ndarray):
+                 audio_input = torch.Tensor(audio_input)
+
+             # For raw arrays/tensors, fall back to the caller-provided rate, then to
+             # the expected default; never clobber a rate read from the file itself
+             if sample_rate is None:
+                 if input_sample_rates is not None and len(input_sample_rates) > input_idx:
+                     sample_rate = input_sample_rates[input_idx]
+                 else:
+                     sample_rate = self.sample_rate
+
+             # Normalize
+             audio_input = self.normalize(audio_input, sample_rate)
+
+             # Segment where necessary
+             duration_samples = audio_input.shape[-1]
+             tensor_segments = [audio_input]
+
+             if duration_samples > PARAKEET_INFERENCE_MAX_DURATION_KHZ:
+                 tensor_segments = self.segment_audio_tensor(audio_input)
+
+             for tensor_segment in tensor_segments:
+                 # Duration in minutes, measured per segment rather than whole input
+                 segment_duration = tensor_segment.shape[-1] / EXPECTED_SAMPLE_MINUTE
+
+                 # Create temporary filesystem reference if applicable
+                 if self.save_to_filesystem:
+                     new_file_path = save_tensor(tensor_segment)
+
+                 if not self.return_tensors:
+                     tensor_segment = None
+
+                 # Create preprocessed input
+                 metadata = InputMetadata(
+                     input_ordering=input_idx,
+                     duration=segment_duration,
+                     input_format=input_format,
+                     input_file_path=input_file_path,
+                     preprocessed_file_path=new_file_path,
+                 )
+
+                 preprocessed_input = PreprocessedInput(
+                     tensor=tensor_segment,
+                     metadata=metadata,
+                 )
+
+                 preprocessed_inputs.append(preprocessed_input)
+
+         return preprocessed_inputs
+
+     def normalize(self, audio_tensor: torch.Tensor, sample_rate: int) -> torch.Tensor:
+         """Normalize audio_tensor (single channel, sample rate = 16000)
+
+         :param audio_tensor: input tensor
+         :param sample_rate: input sample rate
+         :return: normalized tensor
+         """
+         if sample_rate != self.sample_rate:
+             audio_tensor = self.resample_waveform(audio_tensor, sample_rate)
+
+         # Multi-channel dims (channels, aud_length); downmix to mono (aud_length)
+         if len(audio_tensor.shape) > 1:
+             audio_tensor = audio_tensor.mean(dim=0)
+
+         return audio_tensor
+
+     @staticmethod
+     def segment_audio_tensor(
+         audio_tensor: torch.Tensor,
+         frame_len: int = 2048,
+         silence_thresh_db: int = 35,
+         hop_len: int = 512,
+         kept_silence_len_secs: float = 0.15,
+         min_silence_len_secs: float = 0.5,
+         max_segment_len_secs: float = PARAKEET_INFERENCE_MAX_DURATION_MIN * 60,
+     ) -> list[torch.Tensor]:
+         """Splits on silences with librosa, falling back to overlaps where min segments
+         are not sufficient to safely divide audio.
+
+         :param audio_tensor: input tensor
+         :param frame_len: number of samples per analysis frame
+         :param silence_thresh_db: decibel threshold below reference treated as silence
+         :param hop_len: number of samples between analysis frames
+         :param kept_silence_len_secs: seconds of surrounding silence to keep per segment
+         :param min_silence_len_secs: minimum silence duration in seconds to split on
+         :param max_segment_len_secs: maximum segment length in seconds
+         :return: list of tensor segments
+         """
+         # TODO: Implement fallback to overlaps
+         tensor_segments = []
+
+         # Intervals between silences
+         nonsilent_intervals = librosa.effects.split(
+             audio_tensor.numpy(),
+             top_db=silence_thresh_db,
+             frame_length=frame_len,
+             hop_length=hop_len,
+         )
+
+         merged = []
+         # Seconds-to-samples conversions use the sample rate, not samples-per-minute
+         min_silence_sample_len = int(min_silence_len_secs * EXPECTED_SAMPLE_RATE)
+         kept_silence_sample_len = int(kept_silence_len_secs * EXPECTED_SAMPLE_RATE)
+         max_segment_sample_len = int(max_segment_len_secs * EXPECTED_SAMPLE_RATE)
+
+         # Merge intervals separated by short silences (lists, so the ends stay mutable)
+         for start, end in nonsilent_intervals:
+             if len(merged) == 0:
+                 merged.append([start, end])
+             else:
+                 _, prev_end = merged[-1]
+                 if start - prev_end < min_silence_sample_len:
+                     merged[-1][1] = end
+                 else:
+                     merged.append([start, end])
+
+         # Segment controlling max length
+         for start, end in merged:
+             start = max(0, start - kept_silence_sample_len)
+             end = min(audio_tensor.shape[-1], end + kept_silence_sample_len)
+
+             while end - start > max_segment_sample_len:
+                 segment_end = start + max_segment_sample_len
+                 tensor_segments.append(audio_tensor[start:segment_end])
+                 start = segment_end
+
+             # Remainder (or the whole interval when under the max)
+             if end > start:
+                 tensor_segments.append(audio_tensor[start:end])
+
+         return tensor_segments
+
+     @staticmethod
+     def batch_audio_tensors(  # pylint: disable=R0914
+         preprocessed_inputs: list[PreprocessedInput],
+     ) -> list[list[PreprocessedInput]]:
+         """Batch audio tensors by duration, 20 minutes max per batch, optimizing for tightly packed
+         batches.
+
+         :param preprocessed_inputs: list of PreprocessedInput
+         :return: list of list[PreprocessedInput]
+         """
+
+         # Sort by duration
+         preprocessed_inputs = sorted(
+             preprocessed_inputs, key=lambda p: p.metadata.duration, reverse=True
+         )
+
+         # Now this becomes a bin-packing minimization problem. Choosing the bin with the
+         # most remaining space is a variant of worst-fit decreasing.
+
+         bins = [[]]
+         bins_len = [0]
+
+         # With each pass, choose a bin by maximizing remaining space
+         for preprocessed_input in preprocessed_inputs:
+             bin_len_diffs = []
+
+             for bin_len in bins_len:
+                 bin_len_diffs.append(PARAKEET_INFERENCE_MAX_DURATION_MIN - bin_len)
+
+             # Open a new bin when the item fits nowhere
+             if max(bin_len_diffs) < preprocessed_input.metadata.duration:
+                 bins.append([])
+                 bins_len.append(0)
+                 bin_len_diffs.append(PARAKEET_INFERENCE_MAX_DURATION_MIN)
+
+             max_diff_idx = np.argmax(bin_len_diffs)
+
+             bins[max_diff_idx].append(preprocessed_input)
+             bins_len[max_diff_idx] += preprocessed_input.metadata.duration
+
+         return bins
+
+     def resample_waveform(
+         self, waveform: torch.Tensor, sample_rate: int
+     ) -> torch.Tensor:
+         """Resample when sample rate is not 16000
+
+         :param waveform: torch.Tensor
+         :param sample_rate: int
+         :return: resampled torch.Tensor
+         """
+         transform = torchaudio.transforms.Resample(sample_rate, self.sample_rate)
+         return transform(waveform)
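
A standalone rehearsal of the batching policy above (worst-fit decreasing over a 20-minute capacity); the durations are illustrative.

CAPACITY = 20  # minutes, as in PARAKEET_INFERENCE_MAX_DURATION_MIN
durations = [15, 12, 8, 5, 3]

bins = []
for d in sorted(durations, reverse=True):
    # Place each item in the fitting bin with the most remaining space,
    # opening a new bin when it fits nowhere.
    candidates = [b for b in bins if CAPACITY - sum(b) >= d]
    if candidates:
        max(candidates, key=lambda b: CAPACITY - sum(b)).append(d)
    else:
        bins.append([d])

print(bins)  # [[15, 5], [12, 8], [3]]
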
caul-0.1.0/src/caul/utils.py ADDED
@@ -0,0 +1,24 @@
+ from typing import Any
+
+
+ def dict_key_fuzzy_match(dict_obj: dict, search_key: str) -> Any | None:
+     """Match a dict key name fuzzily (returning the first match)
+
+     :param dict_obj: dictionary to search in
+     :param search_key: key name
+     :return: key value (if it exists)
+     """
+
+     if search_key in dict_obj:
+         return dict_obj[search_key]
+
+     fuzzy_matches = [
+         dict_value
+         for dict_key, dict_value in dict_obj.items()
+         if dict_key in search_key or search_key in dict_key
+     ]
+
+     if len(fuzzy_matches) > 0:
+         return fuzzy_matches[0]
+
+     return None
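
A quick sketch of the fuzzy match, using a stand-in registry (the value string is illustrative):

from caul.constant import PARAKEET
from caul.utils import dict_key_fuzzy_match

registry = {PARAKEET: "parakeet-config"}

# An exact key or a substring match in either direction resolves to the same entry:
print(dict_key_fuzzy_match(registry, "parakeet"))                     # parakeet-config
print(dict_key_fuzzy_match(registry, "nvidia/parakeet-tdt-0.6b-v3"))  # parakeet-config
print(dict_key_fuzzy_match(registry, "para"))                         # parakeet-config
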