simulstream-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +47 -0
- simulstream/__init__.py +15 -0
- simulstream/client/__init__.py +0 -0
- simulstream/client/wav_reader_client.py +228 -0
- simulstream/config.py +31 -0
- simulstream/inference.py +170 -0
- simulstream/metrics/__init__.py +0 -0
- simulstream/metrics/detokenizers.py +71 -0
- simulstream/metrics/logger.py +32 -0
- simulstream/metrics/readers.py +348 -0
- simulstream/metrics/score_latency.py +130 -0
- simulstream/metrics/score_quality.py +169 -0
- simulstream/metrics/scorers/__init__.py +0 -0
- simulstream/metrics/scorers/latency/__init__.py +115 -0
- simulstream/metrics/scorers/latency/mwersegmenter.py +136 -0
- simulstream/metrics/scorers/latency/stream_laal.py +119 -0
- simulstream/metrics/scorers/quality/__init__.py +132 -0
- simulstream/metrics/scorers/quality/comet.py +57 -0
- simulstream/metrics/scorers/quality/mwersegmenter.py +93 -0
- simulstream/metrics/scorers/quality/sacrebleu.py +59 -0
- simulstream/metrics/stats.py +184 -0
- simulstream/server/__init__.py +0 -0
- simulstream/server/http_server.py +95 -0
- simulstream/server/message_processor.py +156 -0
- simulstream/server/speech_processors/__init__.py +173 -0
- simulstream/server/speech_processors/base.py +135 -0
- simulstream/server/speech_processors/base_streamatt.py +320 -0
- simulstream/server/speech_processors/canary_sliding_window_retranslation.py +73 -0
- simulstream/server/speech_processors/hf_sliding_window_retranslation.py +87 -0
- simulstream/server/speech_processors/incremental_output.py +85 -0
- simulstream/server/speech_processors/seamless_sliding_window_retranslation.py +84 -0
- simulstream/server/speech_processors/seamless_streamatt.py +268 -0
- simulstream/server/speech_processors/simuleval_wrapper.py +165 -0
- simulstream/server/speech_processors/sliding_window_retranslation.py +135 -0
- simulstream/server/speech_processors/vad_wrapper.py +180 -0
- simulstream/server/websocket_server.py +236 -0
- simulstream-0.1.0.dist-info/METADATA +465 -0
- simulstream-0.1.0.dist-info/RECORD +48 -0
- simulstream-0.1.0.dist-info/WHEEL +5 -0
- simulstream-0.1.0.dist-info/entry_points.txt +8 -0
- simulstream-0.1.0.dist-info/licenses/LICENSE +201 -0
- simulstream-0.1.0.dist-info/top_level.txt +3 -0
- uts/__init__.py +0 -0
- uts/metrics/__init__.py +0 -0
- uts/metrics/log_reader.py +50 -0
- uts/speech_processors/__init__.py +0 -0
- uts/speech_processors/test_simuleval_wrapper.py +88 -0
- uts/utils.py +5 -0
--- /dev/null
+++ b/simulstream/server/speech_processors/base.py
@@ -0,0 +1,135 @@
+# Copyright 2025 FBK
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License
+
+from abc import abstractmethod
+from types import SimpleNamespace
+from typing import List, Union, Dict
+
+import numpy as np
+import torch
+
+from simulstream.server.speech_processors import SpeechProcessor
+from simulstream.server.speech_processors.incremental_output import IncrementalOutput
+
+
+class BaseSpeechProcessor(SpeechProcessor):
+    """
+    A partial implementation of :class:`SpeechProcessor` that provides
+    common logic for handling incremental speech-to-text processing.
+
+    This class defines the high-level workflow of processing an incoming
+    audio chunk (preprocessing, generation, updating history, building
+    outputs), while leaving the model-specific details to subclasses.
+
+    Subclasses must implement the abstract helper methods to define how
+    audio is preprocessed, tokens are generated, and histories are updated.
+    """
+
+    def __init__(self, config: SimpleNamespace):
+        super().__init__(config)
+        self.tgt_lang_tag = None
+        self.src_lang_tag = None
+        self.audio_history = None
+        self.text_history = None
+
+    @abstractmethod
+    def _preprocess(self, waveform: np.float32) -> Union[Dict[str, torch.Tensor], torch.Tensor]:
+        """
+        Convert a raw audio waveform into model-ready features.
+
+        Args:
+            waveform (np.ndarray): A 1D NumPy array of type ``float32``
+                containing normalized audio samples in the range ``[-1.0, 1.0]``. This is the new
+                incoming chunk of audio.
+
+        Returns:
+            Union[Dict[str, torch.Tensor], torch.Tensor]:
+                A tensor or a dictionary of tensors ready for model inference.
+        """
+        ...
+
+    @abstractmethod
+    def _update_speech_history(
+            self,
+            new_speech: torch.Tensor,
+            generated_tokens: List[str],
+            new_output: IncrementalOutput) -> None:
+        """
+        Update the internal audio history.
+
+        Args:
+            new_speech (torch.Tensor): The newly preprocessed speech features.
+            generated_tokens (List[str]): Tokens generated from the new input.
+            new_output (IncrementalOutput): Incremental output object containing
+                new and deleted tokens/strings.
+        """
+        ...
+
+    @abstractmethod
+    def _update_text_history(
+            self,
+            new_speech: torch.Tensor,
+            generated_tokens: List[str],
+            new_output: IncrementalOutput) -> None:
+        """
+        Update the internal text history with new generated tokens.
+
+        Args:
+            new_speech (torch.Tensor): The newly preprocessed speech features.
+            generated_tokens (List[str]): Tokens generated from the new input.
+            new_output (IncrementalOutput): Incremental output object containing
+                new and deleted tokens/strings.
+        """
+        ...
+
+    @abstractmethod
+    def _generate(self, speech: Union[Dict[str, torch.Tensor], torch.Tensor]) -> List[str]:
+        """
+        Generate tokens from the given speech features.
+
+        Args:
+            speech (Union[Dict[str, torch.Tensor], torch.Tensor]):
+                Model-ready speech features as produced by :meth:`_preprocess`.
+
+        Returns:
+            List[str]: A list of generated tokens.
+        """
+        ...
+
+    @abstractmethod
+    def _build_incremental_outputs(self, generated_tokens: List[str]) -> IncrementalOutput:
+        """
+        Build an :class:`IncrementalOutput` object from generated tokens.
+
+        Args:
+            generated_tokens (List[str]): Tokens generated by :meth:`_generate`.
+
+        Returns:
+            IncrementalOutput: The structured incremental output.
+        """
+        ...
+
+    def process_chunk(self, waveform: np.float32) -> IncrementalOutput:
+        speech = self._preprocess(waveform)
+        generated_tokens = self._generate(speech)
+        new_output = self._build_incremental_outputs(generated_tokens)
+        self._update_speech_history(speech, generated_tokens, new_output)
+        self._update_text_history(speech, generated_tokens, new_output)
+        return new_output
+
+    def clear(self) -> None:
+        self.text_history = None
+        self.audio_history = None
+        self.src_lang_tag = None
+        self.tgt_lang_tag = None
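
To make the template-method contract above concrete, here is a minimal sketch of a subclass (illustrative only, not shipped with the package). The "model" is a stand-in that emits fixed SentencePiece-style tokens; the `IncrementalOutput` keyword arguments follow the usage shown in `base_streamatt.py` below, and anything else about `IncrementalOutput`'s fields is an assumption.

```python
# Hypothetical BaseSpeechProcessor subclass; not part of simulstream.
from typing import List

import numpy as np
import torch

from simulstream.server.speech_processors.base import BaseSpeechProcessor
from simulstream.server.speech_processors.incremental_output import IncrementalOutput


class ToySpeechProcessor(BaseSpeechProcessor):

    def _preprocess(self, waveform: np.ndarray) -> torch.Tensor:
        # Turn the raw float32 chunk into a (1, n_samples) tensor
        return torch.from_numpy(waveform).unsqueeze(0)

    def _generate(self, speech: torch.Tensor) -> List[str]:
        # A real subclass would run model inference here
        return ["\u2581hello", "\u2581world"]

    def _build_incremental_outputs(self, generated_tokens: List[str]) -> IncrementalOutput:
        # "\u2581" is the SentencePiece word-boundary marker (BOW_PREFIX below)
        detokenized = "".join(generated_tokens).replace("\u2581", " ").strip()
        return IncrementalOutput(
            new_tokens=generated_tokens,
            new_string=detokenized,
            deleted_tokens=[],
            deleted_string="")

    def _update_speech_history(self, new_speech, generated_tokens, new_output) -> None:
        # Keep only the latest chunk; real processors implement smarter policies
        self.audio_history = new_speech

    def _update_text_history(self, new_speech, generated_tokens, new_output) -> None:
        # Assumes IncrementalOutput exposes a new_tokens attribute, as used above
        self.text_history = (self.text_history or []) + new_output.new_tokens
```

`process_chunk` then drives the whole pipeline: preprocess, generate, build the incremental output, and update both histories before returning it.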

--- /dev/null
+++ b/simulstream/server/speech_processors/base_streamatt.py
@@ -0,0 +1,320 @@
+# Copyright 2025 FBK
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License
+
+import torch
+import logging
+import numpy as np
+
+from types import SimpleNamespace
+from abc import abstractmethod
+from typing import List, Tuple
+
+from simulstream.server.speech_processors import class_load
+from simulstream.server.speech_processors.base import BaseSpeechProcessor
+from simulstream.server.speech_processors.incremental_output import IncrementalOutput
+
+
+BOW_PREFIX = "\u2581"
+
+
+logger = logging.getLogger(__name__)
+
+
+class BaseStreamAtt(BaseSpeechProcessor):
+    """
+    A partial implementation of :class:`BaseSpeechProcessor` that provides common logic for the
+    StreamAtt policy, introduced in:
+
+    S. Papi, et al. 2024. *"StreamAtt: Direct Streaming Speech-to-Text Translation with
+    Attention-based Audio History Selection"* (https://aclanthology.org/2024.acl-long.202/)
+
+    The approach relies on selecting the audio history based on the cross-attention mechanism.
+    Specifically, the history for the next decoding step is defined as follows:
+    - First, the new textual history is selected by the **text_history_method**, which is in
+      charge of selecting the tokens to retain;
+    - Second, the new audio history is selected according to cross-attention scores between the
+      audio features and the retained textual history by discarding past features that are not
+      attended by any tokens of the textual history.
+
+    The derived class should implement the following methods:
+    - **audio_max_len**: Returns the maximum audio feature length.
+    - **load_model**: Loads the model to device.
+    - **_preprocess**: Preprocesses the audio features before feeding them into the model.
+    - **_generate**: Generation that also returns the cross-attention scores.
+
+    Args:
+        config (SimpleNamespace): Configuration object. The following attributes are expected:
+            - **text_history (SimpleNamespace)**: sub-config with the following attribute:
+                - **type (str)**: Name of the class to use to determine the text history to keep
+                  as context for next predictions.
+            - **audio_subsampling_factor (int)**: Subsampling factor of the model, if any.
+              Defaults to 1.
+            - **text_history_max_len (int)**: The maximum length of the textual history after
+              which the current content is cut. Defaults to 128.
+            - **cross_attention_layer (int)**: Layer from which to extract the cross-attention.
+            - **cutoff_frame_num (int)**: Number of last frames that cannot be attended by tokens
+              in the AlignAtt policy.
+            - **word_level_postprocess (bool)**: Whether to postprocess the generated tokens to
+              keep only complete words in the emitted hypothesis. To be disabled when operating
+              with character-level languages. Defaults to True.
+    """
+
+    def __init__(self, config: SimpleNamespace):
+        super().__init__(config)
+        self.config = config
+        text_history_config = self.config.text_history
+        text_history_cls = class_load(text_history_config.type)
+        self.text_history_method = text_history_cls(text_history_config)
+        self.audio_subsampling_factor = getattr(self.config, "audio_subsampling_factor", 1)
+        self.text_history_max_len = getattr(self.config, "text_history_max_len", 128)
+        self.cross_attn_layer = getattr(self.config, "cross_attention_layer", 3)
+        self.cutoff_frame_num = getattr(self.config, "cutoff_frame_num", 2)
+        self.word_level_postprocess = getattr(self.config, "word_level_postprocess", True)
+        self.unselected_tokens = []
+
+    @property
+    @abstractmethod
+    def audio_max_len(self) -> float:
+        """
+        Return the maximum allowed length of the audio features, beyond which the audio is cut off.
+        """
+        ...
+
+    @abstractmethod
+    def _generate(self, speech: torch.Tensor) -> Tuple[List[str], torch.Tensor]:
+        """
+        Generate tokens from the given speech features together with the cross-attention scores.
+
+        Args:
+            speech (torch.Tensor): Model-ready speech features as produced by :meth:`_preprocess`.
+
+        Returns:
+            Tuple[List[str], torch.Tensor]:
+                List[str]: A list of generated tokens.
+                torch.Tensor: Cross-attention scores with dimension (generated_tokens,
+                input_length).
+        """
+        ...
+
+    @staticmethod
+    def normalize_attn(attn):
+        """
+        Normalize the cross-attention scores along the frame dimension to avoid attention sinks.
+        """
+        std = attn.std(axis=0)
+        std[std == 0.] = 1.0
+        mean = attn.mean(axis=0)
+        return (attn - mean) / std
+
+    def _update_text_history(self, new_output: List[str]) -> int:
+        if self.text_history:
+            self.text_history += new_output
+        else:
+            self.text_history = new_output
+        new_history = self.text_history_method.select_text_history(self.text_history)
+        discarded_text = len(self.text_history) - len(new_history)
+        self.text_history = new_history
+
+        # Ensure not exceeding max text history length
+        if self.text_history and len(self.text_history) > self.text_history_max_len:
+            if logger.isEnabledFor(logging.WARNING):
+                logger.warning(
+                    f"The textual history has hit the maximum predefined length of "
+                    f"{self.text_history_max_len}")
+            self.text_history = self.text_history[-self.text_history_max_len:]
+        return discarded_text
+
+    def _cut_audio_exceeding_maxlen(self):
+        # Ensure not exceeding max audio history length
+        if len(self.audio_history) > self.audio_max_len:
+            if logger.isEnabledFor(logging.WARNING):
+                logger.warning(
+                    f"The audio history has hit the maximum predefined length of "
+                    f"{self.audio_max_len}")
+            self.audio_history = self.audio_history[-self.audio_max_len:]
+
+    def _update_speech_history(self, discarded_text: int, cross_attn: torch.Tensor) -> None:
+        # If no history is discarded, no need for attention-based audio trimming
+        if discarded_text == 0:
+            # Check audio history not exceeding maximum allowed length
+            self._cut_audio_exceeding_maxlen()
+            return
+
+        # Trim the cross-attention by excluding the discarded new generated tokens and the
+        # discarded textual history. Output shape: (text_history_len, n_audio_features)
+        cross_attn = cross_attn[discarded_text:discarded_text + len(self.text_history), :]
+
+        # Compute the frame to which each token of the textual history mostly attends
+        most_attended_idxs = torch.argmax(cross_attn.float(), dim=1)
+
+        # Find the first feature that is attended
+        if most_attended_idxs.shape[0] > 1:
+            # Multiple tokens: sort and get the earliest attended frame
+            sorted_idxs = torch.sort(most_attended_idxs)[0]
+            earliest_attended_idx = sorted_idxs[0]
+        else:
+            # Only one token: use the unique most attended frame
+            earliest_attended_idx = most_attended_idxs[0]
+
+        # Multiply by the subsampling factor to recover the original number of frames
+        frames_to_cut = earliest_attended_idx * self.audio_subsampling_factor
+
+        # Cut the unattended audio features
+        self.audio_history = self.audio_history[frames_to_cut:]
+
+        # Check audio history not exceeding maximum allowed length
+        self._cut_audio_exceeding_maxlen()
+
+    @staticmethod
+    def _strip_incomplete_words(tokens: List[str]) -> List[str]:
+        """
+        Remove last incomplete word(s) from the new hypothesis.
+
+        Args:
+            tokens (List[str]): selected tokens, possibly containing partial words to be removed.
+
+        Returns:
+            List[str]: A list of generated tokens from which partial words are removed.
+        """
+        tokens_to_write = []
+        # iterate from the end and count how many trailing tokens to drop
+        num_tokens_incomplete = 0
+        for tok in reversed(tokens):
+            num_tokens_incomplete += 1
+            if tok.startswith(BOW_PREFIX):
+                # slice off the trailing incomplete tokens
+                tokens_to_write = tokens[:-num_tokens_incomplete]
+                break
+        return tokens_to_write
+
+    def alignatt_policy(self, generated_tokens, cross_attn) -> List[str]:
+        """
+        Apply the AlignAtt policy by cutting off tokens whose attention falls
+        beyond the allowed frame range.
+        The AlignAtt policy was introduced in:
+        S. Papi, et al. 2023. *"AlignAtt: Using Attention-based Audio-Translation
+        Alignments as a Guide for Simultaneous Speech Translation"*
+        (https://www.isca-archive.org/interspeech_2023/papi23_interspeech.html)
+        """
+        # Select attention scores corresponding to the new generated tokens
+        cross_attn = cross_attn[-len(generated_tokens):, :]
+        selected_tokens = generated_tokens
+
+        # Find the frame to which each token mostly attends
+        most_attended_frames = torch.argmax(cross_attn, dim=1)
+        cutoff = cross_attn.size(1) - self.cutoff_frame_num
+
+        # Find the first token that attends beyond the cutoff frame
+        invalid_tok_ids = torch.where(most_attended_frames >= cutoff)[0]
+
+        # Truncate tokens up to the first invalid alignment (if any)
+        if len(invalid_tok_ids) > 0:
+            selected_tokens = selected_tokens[:invalid_tok_ids[0]]
+
+        if self.word_level_postprocess:
+            selected_tokens = self._strip_incomplete_words(selected_tokens)
+
+        # Store unselected tokens, to be used in the case of end of stream
+        self.unselected_tokens = generated_tokens[len(selected_tokens):]
+
+        return selected_tokens
+
+    def _build_incremental_outputs(self, generated_tokens: List[str]) -> IncrementalOutput:
+        return IncrementalOutput(
+            new_tokens=generated_tokens,
+            new_string=self.tokens_to_string(generated_tokens),
+            deleted_tokens=[],
+            deleted_string="",
+        )
+
+    def process_chunk(self, waveform: np.float32) -> IncrementalOutput:
+        speech = self._preprocess(waveform)
+        # Generate new hypothesis with its corresponding cross-attention scores (no prefix)
+        generated_tokens, cross_attn = self._generate(speech)
+        # Select the part of the new hypothesis to be emitted, and trim cross-attention accordingly
+        selected_output = self.alignatt_policy(generated_tokens, cross_attn)
+        incremental_output = self._build_incremental_outputs(selected_output)
+        # Discard textual history, if needed
+        discarded_text = self._update_text_history(selected_output)
+        # Trim audio corresponding to the discarded textual history
+        self._update_speech_history(discarded_text, cross_attn)
+        return incremental_output
+
+    def end_of_stream(self) -> IncrementalOutput:
+        last_output = self._build_incremental_outputs(self.unselected_tokens)
+        self.unselected_tokens = []
+        return last_output
+
+    def clear(self) -> None:
+        super().clear()
+        self.text_history = None
+        self.audio_history = None
+        self.unselected_tokens = []
+
+
+class FixedWordsTextHistory:
+    """
+    Fixed Words textual history selection method that retains a pre-defined
+    number of words in the history (*history_words*).
+
+    The current implementation supports only SentencePiece.
+    """
+    def __init__(self, config: SimpleNamespace):
+        self.history_words = getattr(config, "history_words", 20)
+        self.config = config
+
+    def select_text_history(self, text_history: List[str]):
+        words_to_keep = self.history_words
+        new_history = []
+        for token in reversed(text_history):
+            new_history.append(token)
+            # Check if 'BOW_PREFIX' (space in SentencePiece) is contained in the token,
+            # meaning that we reached the beginning of the word that should be counted
+            if BOW_PREFIX in token:
+                words_to_keep -= 1
+                # When all the words to keep are consumed, the accumulation is stopped
+                # and the prefix is returned
+                if words_to_keep == 0:
+                    break
+        # Reverse the list
+        return new_history[::-1]


+class PunctuationTextHistory:
+    """
+    Punctuation textual history selection method that retains the sentence
+    before the last strong punctuation character.
+
+    The current implementation supports only SentencePiece.
+    """
+
+    STRONG_PUNCTUATION = [".", "!", "?", ":", ";"]
+
+    def __init__(self, config: SimpleNamespace):
+        self.config = config
+
+    def select_text_history(self, text_history):
+        new_history = []
+        for token in reversed(text_history):
+            prefix_token = token
+            contains_punctuation = False
+            for punct in self.STRONG_PUNCTUATION:
+                if punct in prefix_token:
+                    contains_punctuation = True
+                    break
+            if contains_punctuation:
+                break
+            new_history.append(token)
+        # Reverse the list
+        return new_history[::-1]
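
The numbers below walk through the AlignAtt cutoff implemented in `alignatt_policy`: with 4 audio frames and `cutoff_frame_num=2`, any token whose most-attended frame is frame 2 or 3 is held back until more audio arrives. This is a self-contained illustration with made-up attention weights, not package code.

```python
import torch

# Toy cross-attention: rows = generated tokens, columns = 4 audio frames.
tokens = ["\u2581we", "\u2581are", "\u2581stream", "ing"]
cross_attn = torch.tensor([
    [0.70, 0.20, 0.05, 0.05],  # "▁we"     -> mostly attends frame 0
    [0.10, 0.60, 0.20, 0.10],  # "▁are"    -> mostly attends frame 1
    [0.05, 0.10, 0.15, 0.70],  # "▁stream" -> mostly attends frame 3
    [0.10, 0.10, 0.20, 0.60],  # "ing"     -> mostly attends frame 3
])
cutoff_frame_num = 2

most_attended_frames = torch.argmax(cross_attn, dim=1)  # tensor([0, 1, 3, 3])
cutoff = cross_attn.size(1) - cutoff_frame_num          # frames >= 2 are too recent
invalid_tok_ids = torch.where(most_attended_frames >= cutoff)[0]
selected = tokens[:int(invalid_tok_ids[0])] if len(invalid_tok_ids) > 0 else tokens
print(selected)  # ['▁we', '▁are']; "▁stream"/"ing" wait for more audio
```

The text-history selectors can be exercised the same way. With `history_words=3`, `FixedWordsTextHistory` keeps subword pieces without the `▁` prefix attached to the word they complete:

```python
from types import SimpleNamespace

from simulstream.server.speech_processors.base_streamatt import FixedWordsTextHistory

selector = FixedWordsTextHistory(SimpleNamespace(history_words=3))
history = ["\u2581good", "\u2581morning", "\u2581every", "one", "\u2581and", "\u2581welcome"]
print(selector.select_text_history(history))
# ['▁every', 'one', '▁and', '▁welcome'], i.e. the last 3 words "everyone and welcome"
```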

--- /dev/null
+++ b/simulstream/server/speech_processors/canary_sliding_window_retranslation.py
@@ -0,0 +1,73 @@
+# Copyright 2025 FBK
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License
+
+from types import SimpleNamespace
+from typing import List
+
+import numpy as np
+import torch
+from nemo.collections.asr.models import ASRModel
+
+from simulstream.server.speech_processors import SAMPLE_RATE
+from simulstream.server.speech_processors.sliding_window_retranslation import \
+    SlidingWindowRetranslator
+
+
+class CanarySlidingWindowRetranslator(SlidingWindowRetranslator):
+    """
+    Perform Sliding Window Retranslation with Canary.
+    """
+
+    @classmethod
+    def load_model(cls, config: SimpleNamespace):
+        if not hasattr(cls, "model") or cls.model is None:
+            cls.model = ASRModel.from_pretrained(model_name=config.model_name)
+            cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+            assert cls.model.preprocessor._sample_rate == SAMPLE_RATE
+            cls.model.to(cls.device)
+
+    def _generate(self, speech: torch.Tensor) -> List[str]:
+        output = self.model.transcribe(
+            speech, source_lang=self.src_lang_tag, target_lang=self.tgt_lang_tag)
+        return self.model.tokenizer.ids_to_tokens(output[0].y_sequence)
+
+    def tokens_to_string(self, tokens: List[str]) -> str:
+        # prevent the initial space, if present, from being removed in the detokenization
+        check_for_init_space = self.text_history is not None and len(self.text_history) > 0
+        if check_for_init_space:
+            tokens = [' '] + tokens
+        text = self.model.tokenizer.tokens_to_text(tokens)
+        if check_for_init_space:
+            text = text[1:]
+        return text
+
+    def _preprocess(self, waveform: np.float32) -> torch.Tensor:
+        """
+        Appends the new waveform to the audio history, keeping only the last
+        `self.window_len` samples. The result is stored as the new audio history
+        and returned as a tensor on the model device.
+        """
+        if self.audio_history is not None:
+            waveform = np.concatenate((self.audio_history, waveform))
+        new_speech_len = len(waveform)
+        if new_speech_len > self.window_len:
+            waveform = waveform[-self.window_len:]
+        self.audio_history = waveform
+        return torch.tensor(waveform).to(self.device)
+
+    def set_target_language(self, language: str) -> None:
+        self.tgt_lang_tag = language
+
+    def set_source_language(self, language: str) -> None:
+        self.src_lang_tag = language
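
Both sliding-window retranslators share the buffering scheme in `_preprocess`: the incoming chunk is appended to the audio history, and only the last `window_len` samples are kept. A self-contained sketch of that arithmetic follows; the 16 kHz rate and 10 s window are assumed values for illustration, the real ones come from the configuration.

```python
import numpy as np

SAMPLE_RATE = 16_000           # assumed for this sketch
window_len = SAMPLE_RATE * 10  # keep at most the last 10 s of audio

audio_history = None


def buffer_chunk(waveform: np.ndarray) -> np.ndarray:
    """Append the chunk to the history and keep the last window_len samples."""
    global audio_history
    if audio_history is not None:
        waveform = np.concatenate((audio_history, waveform))
    if len(waveform) > window_len:
        waveform = waveform[-window_len:]  # drop the oldest samples
    audio_history = waveform
    return waveform


# Feed twelve 1-second chunks: the window saturates at 10 s.
for _ in range(12):
    window = buffer_chunk(np.zeros(SAMPLE_RATE, dtype=np.float32))
print(len(window) / SAMPLE_RATE)  # 10.0
```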

--- /dev/null
+++ b/simulstream/server/speech_processors/hf_sliding_window_retranslation.py
@@ -0,0 +1,87 @@
+# Copyright 2025 FBK
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License
+
+from types import SimpleNamespace
+from typing import List
+
+import numpy as np
+import torch
+from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
+
+from simulstream.server.speech_processors import SAMPLE_RATE
+from simulstream.server.speech_processors.sliding_window_retranslation import \
+    SlidingWindowRetranslator
+
+
+class HFSlidingWindowRetranslator(SlidingWindowRetranslator):
+    """
+    Perform Sliding Window Retranslation with a Huggingface speech-to-text model.
+    """
+
+    @classmethod
+    def load_model(cls, config: SimpleNamespace):
+        if not hasattr(cls, "model") or cls.model is None:
+            lang_tags = None
+            if hasattr(config, "supported_langs") and config.supported_langs is not None:
+                lang_tags = [
+                    config.lang_tag_template.format(lang) for lang in config.supported_langs]
+            cls.processor = AutoProcessor.from_pretrained(
+                config.hf_model_name,
+                additional_special_tokens=lang_tags)
+            cls.model = AutoModelForSpeechSeq2Seq.from_pretrained(
+                config.hf_model_name, trust_remote_code=True)
+            cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+            cls.model.to(cls.device)
+
+    def _generate(self, speech: torch.Tensor) -> List[str]:
+        speech_seconds = speech.shape[1] / 100  # 1 frame every 10 ms
+        extra_kwargs = {
+            "max_new_tokens": int(max(self.max_tokens_per_second * speech_seconds, 10))}
+        if self.tgt_lang_tag is not None:
+            extra_kwargs["forced_bos_token_id"] = self.tgt_lang_tag
+        generated_ids = self.model.generate(speech, **extra_kwargs)[0]
+        return self.processor.tokenizer.convert_ids_to_tokens(
+            generated_ids, skip_special_tokens=True)
+
+    def tokens_to_string(self, tokens: List[str]) -> str:
+        # prevent the initial space, if present, from being removed in the detokenization
+        if self.text_history is not None and len(self.text_history) > 0:
+            tokens = [''] + tokens
+        return self.processor.tokenizer.convert_tokens_to_string(tokens)
+
+    def _preprocess(self, waveform: np.float32) -> torch.Tensor:
+        """
+        Appends the new waveform to the audio history, keeping only the last
+        `self.window_len` samples. The result is stored as the new audio history,
+        and the filter-bank features extracted by the processor are returned.
+        """
+        if self.audio_history is not None:
+            waveform = np.concatenate((self.audio_history, waveform))
+        new_speech_len = len(waveform)
+        if new_speech_len > self.window_len:
+            waveform = waveform[-self.window_len:]
+        self.audio_history = waveform
+        new_speech = self.processor(
+            waveform,
+            sampling_rate=SAMPLE_RATE,
+            return_tensors="pt")["input_features"]
+        return new_speech.to(self.device)
+
+    def set_target_language(self, language: str) -> None:
+        lang_tag_id = self.processor.tokenizer.convert_tokens_to_ids(
+            self.config.lang_tag_template.format(language))
+        self.tgt_lang_tag = torch.tensor(lang_tag_id, dtype=torch.int, device=self.device)
+
+    def set_source_language(self, language: str) -> None:
+        pass
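
For reference, a hypothetical configuration for `HFSlidingWindowRetranslator`. The attribute names `hf_model_name`, `supported_langs`, and `lang_tag_template` match the ones read above; the values are illustrative, and `window_len`/`max_tokens_per_second` are assumed to be consumed by the parent `SlidingWindowRetranslator`, whose code is not shown in this excerpt.

```python
from types import SimpleNamespace

# Illustrative values only; none of these are defaults shipped with simulstream.
config = SimpleNamespace(
    hf_model_name="openai/whisper-small",  # any HF speech-to-text checkpoint
    supported_langs=["en", "it"],
    lang_tag_template="<|{}|>",            # "it" -> "<|it|>"
    window_len=16_000 * 10,                # presumably used by the parent class
    max_tokens_per_second=5,               # caps max_new_tokens in _generate
)
```

With these values, a full 10 s window (1,000 feature frames at 10 ms each) allows at most `int(max(5 * 10, 10)) = 50` new tokens per retranslation pass.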