simulstream-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. docs/source/conf.py +47 -0
  2. simulstream/__init__.py +15 -0
  3. simulstream/client/__init__.py +0 -0
  4. simulstream/client/wav_reader_client.py +228 -0
  5. simulstream/config.py +31 -0
  6. simulstream/inference.py +170 -0
  7. simulstream/metrics/__init__.py +0 -0
  8. simulstream/metrics/detokenizers.py +71 -0
  9. simulstream/metrics/logger.py +32 -0
  10. simulstream/metrics/readers.py +348 -0
  11. simulstream/metrics/score_latency.py +130 -0
  12. simulstream/metrics/score_quality.py +169 -0
  13. simulstream/metrics/scorers/__init__.py +0 -0
  14. simulstream/metrics/scorers/latency/__init__.py +115 -0
  15. simulstream/metrics/scorers/latency/mwersegmenter.py +136 -0
  16. simulstream/metrics/scorers/latency/stream_laal.py +119 -0
  17. simulstream/metrics/scorers/quality/__init__.py +132 -0
  18. simulstream/metrics/scorers/quality/comet.py +57 -0
  19. simulstream/metrics/scorers/quality/mwersegmenter.py +93 -0
  20. simulstream/metrics/scorers/quality/sacrebleu.py +59 -0
  21. simulstream/metrics/stats.py +184 -0
  22. simulstream/server/__init__.py +0 -0
  23. simulstream/server/http_server.py +95 -0
  24. simulstream/server/message_processor.py +156 -0
  25. simulstream/server/speech_processors/__init__.py +173 -0
  26. simulstream/server/speech_processors/base.py +135 -0
  27. simulstream/server/speech_processors/base_streamatt.py +320 -0
  28. simulstream/server/speech_processors/canary_sliding_window_retranslation.py +73 -0
  29. simulstream/server/speech_processors/hf_sliding_window_retranslation.py +87 -0
  30. simulstream/server/speech_processors/incremental_output.py +85 -0
  31. simulstream/server/speech_processors/seamless_sliding_window_retranslation.py +84 -0
  32. simulstream/server/speech_processors/seamless_streamatt.py +268 -0
  33. simulstream/server/speech_processors/simuleval_wrapper.py +165 -0
  34. simulstream/server/speech_processors/sliding_window_retranslation.py +135 -0
  35. simulstream/server/speech_processors/vad_wrapper.py +180 -0
  36. simulstream/server/websocket_server.py +236 -0
  37. simulstream-0.1.0.dist-info/METADATA +465 -0
  38. simulstream-0.1.0.dist-info/RECORD +48 -0
  39. simulstream-0.1.0.dist-info/WHEEL +5 -0
  40. simulstream-0.1.0.dist-info/entry_points.txt +8 -0
  41. simulstream-0.1.0.dist-info/licenses/LICENSE +201 -0
  42. simulstream-0.1.0.dist-info/top_level.txt +3 -0
  43. uts/__init__.py +0 -0
  44. uts/metrics/__init__.py +0 -0
  45. uts/metrics/log_reader.py +50 -0
  46. uts/speech_processors/__init__.py +0 -0
  47. uts/speech_processors/test_simuleval_wrapper.py +88 -0
  48. uts/utils.py +5 -0
simulstream/server/speech_processors/base.py
@@ -0,0 +1,135 @@
+ # Copyright 2025 FBK
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from abc import abstractmethod
+ from types import SimpleNamespace
+ from typing import List, Union, Dict
+
+ import numpy as np
+ import torch
+
+ from simulstream.server.speech_processors import SpeechProcessor
+ from simulstream.server.speech_processors.incremental_output import IncrementalOutput
+
+
+ class BaseSpeechProcessor(SpeechProcessor):
+     """
+     A partial implementation of :class:`SpeechProcessor` that provides
+     common logic for handling incremental speech-to-text processing.
+
+     This class defines the high-level workflow for processing an incoming
+     audio chunk (preprocessing, generation, updating histories, building
+     outputs), while leaving the model-specific details to subclasses.
+
+     Subclasses must implement the abstract helper methods to define how
+     audio is preprocessed, tokens are generated, and histories are updated.
+     """
+
+     def __init__(self, config: SimpleNamespace):
+         super().__init__(config)
+         self.tgt_lang_tag = None
+         self.src_lang_tag = None
+         self.audio_history = None
+         self.text_history = None
+
+     @abstractmethod
+     def _preprocess(self, waveform: np.ndarray) -> Union[Dict[str, torch.Tensor], torch.Tensor]:
+         """
+         Convert a raw audio waveform into model-ready features.
+
+         Args:
+             waveform (np.ndarray): A 1D NumPy array of type ``float32``
+                 containing normalized audio samples in the range ``[-1.0, 1.0]``.
+                 This is the newly received chunk of audio.
+
+         Returns:
+             Union[Dict[str, torch.Tensor], torch.Tensor]:
+                 A tensor or a dictionary of tensors ready for model inference.
+         """
+         ...
+
+     @abstractmethod
+     def _update_speech_history(
+             self,
+             new_speech: torch.Tensor,
+             generated_tokens: List[str],
+             new_output: IncrementalOutput) -> None:
+         """
+         Update the internal audio history.
+
+         Args:
+             new_speech (torch.Tensor): The newly preprocessed speech features.
+             generated_tokens (List[str]): Tokens generated from the new input.
+             new_output (IncrementalOutput): Incremental output object containing
+                 new and deleted tokens/strings.
+         """
+         ...
+
+     @abstractmethod
+     def _update_text_history(
+             self,
+             new_speech: torch.Tensor,
+             generated_tokens: List[str],
+             new_output: IncrementalOutput) -> None:
+         """
+         Update the internal text history with the newly generated tokens.
+
+         Args:
+             new_speech (torch.Tensor): The newly preprocessed speech features.
+             generated_tokens (List[str]): Tokens generated from the new input.
+             new_output (IncrementalOutput): Incremental output object containing
+                 new and deleted tokens/strings.
+         """
+         ...
+
+     @abstractmethod
+     def _generate(self, speech: Union[Dict[str, torch.Tensor], torch.Tensor]) -> List[str]:
+         """
+         Generate tokens from the given speech features.
+
+         Args:
+             speech (Union[Dict[str, torch.Tensor], torch.Tensor]):
+                 Model-ready speech features as produced by :meth:`_preprocess`.
+
+         Returns:
+             List[str]: A list of generated tokens.
+         """
+         ...
+
+     @abstractmethod
+     def _build_incremental_outputs(self, generated_tokens: List[str]) -> IncrementalOutput:
+         """
+         Build an :class:`IncrementalOutput` object from the generated tokens.
+
+         Args:
+             generated_tokens (List[str]): Tokens generated by :meth:`_generate`.
+
+         Returns:
+             IncrementalOutput: The structured incremental output.
+         """
+         ...
+
+     def process_chunk(self, waveform: np.ndarray) -> IncrementalOutput:
+         speech = self._preprocess(waveform)
+         generated_tokens = self._generate(speech)
+         new_output = self._build_incremental_outputs(generated_tokens)
+         self._update_speech_history(speech, generated_tokens, new_output)
+         self._update_text_history(speech, generated_tokens, new_output)
+         return new_output
+
+     def clear(self) -> None:
+         self.text_history = None
+         self.audio_history = None
+         self.src_lang_tag = None
+         self.tgt_lang_tag = None
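To make the workflow concrete, here is a minimal sketch (not part of the package) of a concrete subclass showing how process_chunk chains the abstract hooks. EchoProcessor and its trivial one-token "model" are illustrative assumptions; a real subclass would wrap an actual speech-to-text model.

# Hypothetical minimal subclass of BaseSpeechProcessor, for illustration only.
from typing import List

import numpy as np
import torch

from simulstream.server.speech_processors.base import BaseSpeechProcessor
from simulstream.server.speech_processors.incremental_output import IncrementalOutput


class EchoProcessor(BaseSpeechProcessor):
    def _preprocess(self, waveform: np.ndarray) -> torch.Tensor:
        # Model-ready features: here just the raw samples as a tensor.
        return torch.from_numpy(waveform)

    def _generate(self, speech: torch.Tensor) -> List[str]:
        # A real model would decode tokens; we emit one dummy token per chunk.
        return ["\u2581chunk"]

    def _build_incremental_outputs(self, generated_tokens: List[str]) -> IncrementalOutput:
        return IncrementalOutput(
            new_tokens=generated_tokens,
            new_string=" ".join(t.lstrip("\u2581") for t in generated_tokens),
            deleted_tokens=[],
            deleted_string="")

    def _update_speech_history(self, new_speech, generated_tokens, new_output) -> None:
        self.audio_history = new_speech  # keep only the latest chunk

    def _update_text_history(self, new_speech, generated_tokens, new_output) -> None:
        self.text_history = (self.text_history or []) + generated_tokens

Each call to process_chunk on such a subclass then runs preprocess, generate, output building, and the two history updates in that fixed order.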
simulstream/server/speech_processors/base_streamatt.py
@@ -0,0 +1,320 @@
+ # Copyright 2025 FBK
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import torch
+ import logging
+ import numpy as np
+
+ from types import SimpleNamespace
+ from abc import abstractmethod
+ from typing import List, Tuple
+
+ from simulstream.server.speech_processors import class_load
+ from simulstream.server.speech_processors.base import BaseSpeechProcessor
+ from simulstream.server.speech_processors.incremental_output import IncrementalOutput
+
+
+ BOW_PREFIX = "\u2581"
+
+
+ logger = logging.getLogger(__name__)
+
+
+ class BaseStreamAtt(BaseSpeechProcessor):
+     """
+     A partial implementation of :class:`BaseSpeechProcessor` that provides common logic for the
+     StreamAtt policy, introduced in:
+
+     S. Papi, et al. 2024. *"StreamAtt: Direct Streaming Speech-to-Text Translation with
+     Attention-based Audio History Selection"* (https://aclanthology.org/2024.acl-long.202/)
+
+     The approach relies on selecting the audio history based on the cross-attention mechanism.
+     Specifically, the history for the next decoding step is defined as follows:
+      - First, the new textual history is selected by the **text_history_method**, which is in
+        charge of selecting the tokens to retain;
+      - Second, the new audio history is selected according to the cross-attention scores between
+        the audio features and the retained textual history, by discarding past features that are
+        not attended by any token of the textual history.
+
+     Derived classes should implement the following methods:
+      - **audio_max_len**: Returns the maximum audio feature length.
+      - **load_model**: Loads the model onto the device.
+      - **_preprocess**: Preprocesses the audio features before feeding them into the model.
+      - **_generate**: Generation step that also returns the cross-attention scores.
+
+     Args:
+         config (SimpleNamespace): Configuration object. The following attributes are expected:
+          - **text_history (SimpleNamespace)**: Configuration with the following attribute:
+             - **type (str)**: Name of the class that determines the text history to keep as
+               context for the next predictions.
+          - **audio_subsampling_factor (int)**: Subsampling factor of the model, if any.
+            Defaults to 1.
+          - **text_history_max_len (int)**: The maximum length of the textual history, after which
+            the current content is cut. Defaults to 128.
+          - **cross_attention_layer (int)**: Layer from which the cross-attention is extracted.
+            Defaults to 3.
+          - **cutoff_frame_num (int)**: Number of final frames that cannot be attended by tokens
+            in the AlignAtt policy. Defaults to 2.
+          - **word_level_postprocess (bool)**: Whether to postprocess the generated tokens to keep
+            only complete words in the emitted hypothesis. To be disabled when operating with
+            character-level languages. Defaults to True.
+     """
+
+     def __init__(self, config: SimpleNamespace):
+         super().__init__(config)
+         self.config = config
+         text_history_config = self.config.text_history
+         text_history_cls = class_load(text_history_config.type)
+         self.text_history_method = text_history_cls(text_history_config)
+         self.audio_subsampling_factor = getattr(self.config, "audio_subsampling_factor", 1)
+         self.text_history_max_len = getattr(self.config, "text_history_max_len", 128)
+         self.cross_attn_layer = getattr(self.config, "cross_attention_layer", 3)
+         self.cutoff_frame_num = getattr(self.config, "cutoff_frame_num", 2)
+         self.word_level_postprocess = getattr(self.config, "word_level_postprocess", True)
+         self.unselected_tokens = []
+
+     @property
+     @abstractmethod
+     def audio_max_len(self) -> int:
+         """
+         Return the maximum allowed length of the audio features, beyond which the audio is cut.
+         """
+         ...
+
+     @abstractmethod
+     def _generate(self, speech: torch.Tensor) -> Tuple[List[str], torch.Tensor]:
+         """
+         Generate tokens from the given speech features, together with the cross-attention scores.
+
+         Args:
+             speech (torch.Tensor): Model-ready speech features as produced by :meth:`_preprocess`.
+
+         Returns:
+             Tuple[List[str], torch.Tensor]:
+                 List[str]: A list of generated tokens.
+                 torch.Tensor: Cross-attention scores with shape (generated_tokens, input_length).
+         """
+         ...
+
+     @staticmethod
+     def normalize_attn(attn):
+         """
+         Normalize the cross-attention scores of each frame across the generated tokens
+         to avoid attention sinks.
+         """
+         std = attn.std(axis=0)
+         std[std == 0.] = 1.0
+         mean = attn.mean(axis=0)
+         return (attn - mean) / std
+
+     def _update_text_history(self, new_output: List[str]) -> int:
+         if self.text_history:
+             self.text_history += new_output
+         else:
+             self.text_history = new_output
+         new_history = self.text_history_method.select_text_history(self.text_history)
+         discarded_text = len(self.text_history) - len(new_history)
+         self.text_history = new_history
+
+         # Ensure the maximum text history length is not exceeded
+         if self.text_history and len(self.text_history) > self.text_history_max_len:
+             if logger.isEnabledFor(logging.WARNING):
+                 logger.warning(
+                     f"The textual history has hit the maximum predefined length of "
+                     f"{self.text_history_max_len}")
+             self.text_history = self.text_history[-self.text_history_max_len:]
+         return discarded_text
+
+     def _cut_audio_exceeding_maxlen(self):
+         # Ensure the maximum audio history length is not exceeded
+         if len(self.audio_history) > self.audio_max_len:
+             if logger.isEnabledFor(logging.WARNING):
+                 logger.warning(
+                     f"The audio history has hit the maximum predefined length of "
+                     f"{self.audio_max_len}")
+             self.audio_history = self.audio_history[-self.audio_max_len:]
+
+     def _update_speech_history(self, discarded_text: int, cross_attn: torch.Tensor) -> None:
+         # If no history is discarded, there is no need for attention-based audio trimming
+         if discarded_text == 0:
+             # Check that the audio history does not exceed the maximum allowed length
+             self._cut_audio_exceeding_maxlen()
+             return
+
+         # Trim the cross-attention by excluding the discarded newly generated tokens and the
+         # discarded textual history. Output shape: (text_history_len, n_audio_features)
+         cross_attn = cross_attn[discarded_text:discarded_text + len(self.text_history), :]
+
+         # Compute the frame to which each token of the textual history mostly attends
+         most_attended_idxs = torch.argmax(cross_attn.float(), dim=1)
+
+         # Find the first feature that is attended
+         if most_attended_idxs.shape[0] > 1:
+             # Multiple tokens: sort and get the earliest attended frame
+             sorted_idxs = torch.sort(most_attended_idxs)[0]
+             earliest_attended_idx = sorted_idxs[0]
+         else:
+             # Only one token: use its most attended frame
+             earliest_attended_idx = most_attended_idxs[0]
+
+         # Multiply by the subsampling factor to recover the original number of frames
+         frames_to_cut = earliest_attended_idx * self.audio_subsampling_factor
+
+         # Cut the unattended audio features
+         self.audio_history = self.audio_history[frames_to_cut:]
+
+         # Check that the audio history does not exceed the maximum allowed length
+         self._cut_audio_exceeding_maxlen()
+
+     @staticmethod
+     def _strip_incomplete_words(tokens: List[str]) -> List[str]:
+         """
+         Remove the last incomplete word(s) from the new hypothesis.
+
+         Args:
+             tokens (List[str]): Selected tokens, possibly containing partial words to be removed.
+
+         Returns:
+             List[str]: The list of generated tokens with partial words removed.
+         """
+         tokens_to_write = []
+         # Iterate from the end and count how many trailing tokens to drop
+         num_tokens_incomplete = 0
+         for tok in reversed(tokens):
+             num_tokens_incomplete += 1
+             if tok.startswith(BOW_PREFIX):
+                 # Slice off the trailing incomplete tokens
+                 tokens_to_write = tokens[:-num_tokens_incomplete]
+                 break
+         return tokens_to_write
+
+     def alignatt_policy(self, generated_tokens, cross_attn) -> List[str]:
+         """
+         Apply the AlignAtt policy by cutting off tokens whose attention falls
+         beyond the allowed frame range.
+         The AlignAtt policy was introduced in:
+         S. Papi, et al. 2023. *"AlignAtt: Using Attention-based Audio-Translation
+         Alignments as a Guide for Simultaneous Speech Translation"*
+         (https://www.isca-archive.org/interspeech_2023/papi23_interspeech.html)
+         """
+         # Select the attention scores corresponding to the newly generated tokens
+         cross_attn = cross_attn[-len(generated_tokens):, :]
+         selected_tokens = generated_tokens
+
+         # Find the frame to which each token mostly attends
+         most_attended_frames = torch.argmax(cross_attn, dim=1)
+         cutoff = cross_attn.size(1) - self.cutoff_frame_num
+
+         # Find the first token that attends beyond the cutoff frame
+         invalid_tok_ids = torch.where(most_attended_frames >= cutoff)[0]
+
+         # Truncate tokens up to the first invalid alignment (if any)
+         if len(invalid_tok_ids) > 0:
+             selected_tokens = selected_tokens[:invalid_tok_ids[0]]
+
+         if self.word_level_postprocess:
+             selected_tokens = self._strip_incomplete_words(selected_tokens)
+
+         # Store the unselected tokens, to be emitted at the end of the stream
+         self.unselected_tokens = generated_tokens[len(selected_tokens):]
+
+         return selected_tokens
+
+     def _build_incremental_outputs(self, generated_tokens: List[str]) -> IncrementalOutput:
+         return IncrementalOutput(
+             new_tokens=generated_tokens,
+             new_string=self.tokens_to_string(generated_tokens),
+             deleted_tokens=[],
+             deleted_string="",
+         )
+
+     def process_chunk(self, waveform: np.ndarray) -> IncrementalOutput:
+         speech = self._preprocess(waveform)
+         # Generate a new hypothesis with its corresponding cross-attention scores (no prefix)
+         generated_tokens, cross_attn = self._generate(speech)
+         # Select the part of the new hypothesis to be emitted
+         selected_output = self.alignatt_policy(generated_tokens, cross_attn)
+         incremental_output = self._build_incremental_outputs(selected_output)
+         # Discard textual history, if needed
+         discarded_text = self._update_text_history(selected_output)
+         # Trim the audio corresponding to the discarded textual history
+         self._update_speech_history(discarded_text, cross_attn)
+         return incremental_output
+
+     def end_of_stream(self) -> IncrementalOutput:
+         last_output = self._build_incremental_outputs(self.unselected_tokens)
+         self.unselected_tokens = []
+         return last_output
+
+     def clear(self) -> None:
+         super().clear()
+         self.text_history = None
+         self.audio_history = None
+         self.unselected_tokens = []
+
+
+ class FixedWordsTextHistory:
+     """
+     Fixed-words textual history selection method that retains a pre-defined
+     number of words in the history (*history_words*).
+
+     The current implementation supports only SentencePiece.
+     """
+     def __init__(self, config: SimpleNamespace):
+         self.history_words = getattr(config, "history_words", 20)
+         self.config = config
+
+     def select_text_history(self, text_history: List[str]):
+         words_to_keep = self.history_words
+         new_history = []
+         for token in reversed(text_history):
+             new_history.append(token)
+             # Check whether BOW_PREFIX (the word boundary in SentencePiece) is contained in the
+             # token, meaning that we have reached the beginning of a word to be counted
+             if BOW_PREFIX in token:
+                 words_to_keep -= 1
+                 # When all the words to keep have been consumed, the accumulation stops
+                 # and the collected history is returned
+                 if words_to_keep == 0:
+                     break
+         # Reverse the list
+         return new_history[::-1]
+
+
+ class PunctuationTextHistory:
+     """
+     Punctuation-based textual history selection method that retains the text
+     after the last strong punctuation character.
+
+     The current implementation supports only SentencePiece.
+     """
+
+     STRONG_PUNCTUATION = [".", "!", "?", ":", ";"]
+
+     def __init__(self, config: SimpleNamespace):
+         self.config = config
+
+     def select_text_history(self, text_history):
+         new_history = []
+         for token in reversed(text_history):
+             prefix_token = token
+             contains_punctuation = False
+             for punct in self.STRONG_PUNCTUATION:
+                 if punct in prefix_token:
+                     contains_punctuation = True
+                     break
+             if contains_punctuation:
+                 break
+             new_history.append(token)
+         # Reverse the list
+         return new_history[::-1]
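To make the AlignAtt cutoff concrete, here is a small standalone example (not from the package) that mirrors the selection logic of alignatt_policy: tokens whose most-attended frame falls within the last cutoff_frame_num frames are withheld until more audio arrives. The toy tokens and attention values are invented for illustration.

import torch

tokens = ["\u2581the", "\u2581cat", "\u2581sat", "s"]
# Fake cross-attention: rows = tokens, cols = 6 audio frames.
cross_attn = torch.tensor([
    [0.9, 0.1, 0.0, 0.0, 0.0, 0.0],   # mostly attends frame 0
    [0.0, 0.8, 0.2, 0.0, 0.0, 0.0],   # mostly attends frame 1
    [0.0, 0.0, 0.1, 0.2, 0.7, 0.0],   # mostly attends frame 4 -> beyond cutoff
    [0.0, 0.0, 0.0, 0.0, 0.1, 0.9],   # mostly attends frame 5 -> beyond cutoff
])
cutoff_frame_num = 2
most_attended = torch.argmax(cross_attn, dim=1)       # tensor([0, 1, 4, 5])
cutoff = cross_attn.size(1) - cutoff_frame_num        # 6 - 2 = 4
invalid = torch.where(most_attended >= cutoff)[0]     # tensor([2, 3])
selected = tokens[:invalid[0]] if len(invalid) > 0 else tokens
print(selected)  # ['▁the', '▁cat']: '▁sat' and the partial 's' wait for more audio

With word_level_postprocess enabled, _strip_incomplete_words would additionally drop any trailing subwords that do not complete a word.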
simulstream/server/speech_processors/canary_sliding_window_retranslation.py
@@ -0,0 +1,73 @@
+ # Copyright 2025 FBK
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from types import SimpleNamespace
+ from typing import List
+
+ import numpy as np
+ import torch
+ from nemo.collections.asr.models import ASRModel
+
+ from simulstream.server.speech_processors import SAMPLE_RATE
+ from simulstream.server.speech_processors.sliding_window_retranslation import \
+     SlidingWindowRetranslator
+
+
+ class CanarySlidingWindowRetranslator(SlidingWindowRetranslator):
+     """
+     Performs Sliding Window Retranslation with Canary.
+     """
+
+     @classmethod
+     def load_model(cls, config: SimpleNamespace):
+         if not hasattr(cls, "model") or cls.model is None:
+             cls.model = ASRModel.from_pretrained(model_name=config.model_name)
+             cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+             assert cls.model.preprocessor._sample_rate == SAMPLE_RATE
+             cls.model.to(cls.device)
+
+     def _generate(self, speech: torch.Tensor) -> List[str]:
+         output = self.model.transcribe(
+             speech, source_lang=self.src_lang_tag, target_lang=self.tgt_lang_tag)
+         return self.model.tokenizer.ids_to_tokens(output[0].y_sequence)
+
+     def tokens_to_string(self, tokens: List[str]) -> str:
+         # Avoid that the initial space, if present, gets removed during detokenization
+         check_for_init_space = self.text_history is not None and len(self.text_history) > 0
+         if check_for_init_space:
+             tokens = [' '] + tokens
+         text = self.model.tokenizer.tokens_to_text(tokens)
+         if check_for_init_space:
+             text = text[1:]
+         return text
+
+     def _preprocess(self, waveform: np.ndarray) -> torch.Tensor:
+         """
+         Appends the incoming waveform to the audio history, keeps only the last
+         `self.window_len` samples, stores the result as the new audio history, and
+         returns it as a tensor on the model device.
+         """
+         if self.audio_history is not None:
+             waveform = np.concatenate((self.audio_history, waveform))
+         new_speech_len = len(waveform)
+         if new_speech_len > self.window_len:
+             waveform = waveform[-self.window_len:]
+         self.audio_history = waveform
+         return torch.tensor(waveform).to(self.device)
+
+     def set_target_language(self, language: str) -> None:
+         self.tgt_lang_tag = language
+
+     def set_source_language(self, language: str) -> None:
+         self.src_lang_tag = language
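The sliding-window buffering done by _preprocess can be illustrated in isolation. This is a standalone sketch; the 16 kHz sample rate, 10 s window, and 2 s chunk size below are assumed values for the example, not package defaults.

import numpy as np

SAMPLE_RATE = 16000
window_len = 10 * SAMPLE_RATE          # keep at most 10 s of audio
audio_history = None

def push_chunk(chunk: np.ndarray) -> np.ndarray:
    global audio_history
    if audio_history is not None:
        chunk = np.concatenate((audio_history, chunk))
    if len(chunk) > window_len:
        chunk = chunk[-window_len:]    # drop the oldest samples
    audio_history = chunk
    return chunk

for _ in range(7):                      # push 7 x 2 s = 14 s of audio
    window = push_chunk(np.zeros(2 * SAMPLE_RATE, dtype=np.float32))
print(len(window) / SAMPLE_RATE)        # 10.0 -> only the last 10 s are retranslated

Because each call retranslates the whole window from scratch, the window length trades translation quality (more context) against per-chunk latency.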
simulstream/server/speech_processors/hf_sliding_window_retranslation.py
@@ -0,0 +1,87 @@
+ # Copyright 2025 FBK
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from types import SimpleNamespace
+ from typing import List
+
+ import numpy as np
+ import torch
+ from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
+
+ from simulstream.server.speech_processors import SAMPLE_RATE
+ from simulstream.server.speech_processors.sliding_window_retranslation import \
+     SlidingWindowRetranslator
+
+
+ class HFSlidingWindowRetranslator(SlidingWindowRetranslator):
+     """
+     Performs Sliding Window Retranslation with a Hugging Face speech-to-text model.
+     """
+
+     @classmethod
+     def load_model(cls, config: SimpleNamespace):
+         if not hasattr(cls, "model") or cls.model is None:
+             lang_tags = None
+             if hasattr(config, "supported_langs") and config.supported_langs is not None:
+                 lang_tags = [
+                     config.lang_tag_template.format(lang) for lang in config.supported_langs]
+             cls.processor = AutoProcessor.from_pretrained(
+                 config.hf_model_name,
+                 additional_special_tokens=lang_tags)
+             cls.model = AutoModelForSpeechSeq2Seq.from_pretrained(
+                 config.hf_model_name, trust_remote_code=True)
+             cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+             cls.model.to(cls.device)
+
+     def _generate(self, speech: torch.Tensor) -> List[str]:
+         speech_seconds = speech.shape[1] / 100  # 1 frame every 10 ms
+         extra_kwargs = {
+             "max_new_tokens": int(max(self.max_tokens_per_second * speech_seconds, 10))}
+         if self.tgt_lang_tag is not None:
+             extra_kwargs["forced_bos_token_id"] = self.tgt_lang_tag
+         generated_ids = self.model.generate(speech, **extra_kwargs)[0]
+         return self.processor.tokenizer.convert_ids_to_tokens(
+             generated_ids, skip_special_tokens=True)
+
+     def tokens_to_string(self, tokens: List[str]) -> str:
+         # Avoid that the initial space, if present, gets removed during detokenization
+         if self.text_history is not None and len(self.text_history) > 0:
+             tokens = [''] + tokens
+         return self.processor.tokenizer.convert_tokens_to_string(tokens)
+
+     def _preprocess(self, waveform: np.ndarray) -> torch.Tensor:
+         """
+         Appends the incoming waveform to the audio history, keeps only the last
+         `self.window_len` samples, stores the result as the new audio history, and
+         extracts the filter-bank features to be fed to the model.
+         """
+         if self.audio_history is not None:
+             waveform = np.concatenate((self.audio_history, waveform))
+         new_speech_len = len(waveform)
+         if new_speech_len > self.window_len:
+             waveform = waveform[-self.window_len:]
+         self.audio_history = waveform
+         new_speech = self.processor(
+             waveform,
+             sampling_rate=SAMPLE_RATE,
+             return_tensors="pt")["input_features"]
+         return new_speech.to(self.device)
+
+     def set_target_language(self, language: str) -> None:
+         lang_tag_id = self.processor.tokenizer.convert_tokens_to_ids(
+             self.config.lang_tag_template.format(language))
+         self.tgt_lang_tag = torch.tensor(lang_tag_id, dtype=torch.int, device=self.device)
+
+     def set_source_language(self, language: str) -> None:
+         pass
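A configuration sketch may help tie the pieces together. Everything below is a hypothetical example: the model name, language-tag template, and the config-only constructor are assumptions made for illustration, not the package's documented schema.

from types import SimpleNamespace

config = SimpleNamespace(
    hf_model_name="facebook/s2t-medium-mustc-multilingual-st",  # any HF speech seq2seq model
    lang_tag_template="<lang:{}>",        # how language tags are spelled in the vocab (assumed)
    supported_langs=["de", "es", "it"],   # registered as additional special tokens
    window_len=10 * 16000,                # sliding window length, in samples (assumed)
    max_tokens_per_second=10,             # caps max_new_tokens per generate() call (assumed)
)

HFSlidingWindowRetranslator.load_model(config)   # loads processor + model once per class
retranslator = HFSlidingWindowRetranslator(config)  # assumes a config-only constructor
retranslator.set_target_language("de")           # resolves "<lang:de>" to its token id

From then on, each audio chunk passed to process_chunk is appended to the window, re-featurized, and fully retranslated with the target-language tag forced as the first decoded token.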