simulstream-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. docs/source/conf.py +47 -0
  2. simulstream/__init__.py +15 -0
  3. simulstream/client/__init__.py +0 -0
  4. simulstream/client/wav_reader_client.py +228 -0
  5. simulstream/config.py +31 -0
  6. simulstream/inference.py +170 -0
  7. simulstream/metrics/__init__.py +0 -0
  8. simulstream/metrics/detokenizers.py +71 -0
  9. simulstream/metrics/logger.py +32 -0
  10. simulstream/metrics/readers.py +348 -0
  11. simulstream/metrics/score_latency.py +130 -0
  12. simulstream/metrics/score_quality.py +169 -0
  13. simulstream/metrics/scorers/__init__.py +0 -0
  14. simulstream/metrics/scorers/latency/__init__.py +115 -0
  15. simulstream/metrics/scorers/latency/mwersegmenter.py +136 -0
  16. simulstream/metrics/scorers/latency/stream_laal.py +119 -0
  17. simulstream/metrics/scorers/quality/__init__.py +132 -0
  18. simulstream/metrics/scorers/quality/comet.py +57 -0
  19. simulstream/metrics/scorers/quality/mwersegmenter.py +93 -0
  20. simulstream/metrics/scorers/quality/sacrebleu.py +59 -0
  21. simulstream/metrics/stats.py +184 -0
  22. simulstream/server/__init__.py +0 -0
  23. simulstream/server/http_server.py +95 -0
  24. simulstream/server/message_processor.py +156 -0
  25. simulstream/server/speech_processors/__init__.py +173 -0
  26. simulstream/server/speech_processors/base.py +135 -0
  27. simulstream/server/speech_processors/base_streamatt.py +320 -0
  28. simulstream/server/speech_processors/canary_sliding_window_retranslation.py +73 -0
  29. simulstream/server/speech_processors/hf_sliding_window_retranslation.py +87 -0
  30. simulstream/server/speech_processors/incremental_output.py +85 -0
  31. simulstream/server/speech_processors/seamless_sliding_window_retranslation.py +84 -0
  32. simulstream/server/speech_processors/seamless_streamatt.py +268 -0
  33. simulstream/server/speech_processors/simuleval_wrapper.py +165 -0
  34. simulstream/server/speech_processors/sliding_window_retranslation.py +135 -0
  35. simulstream/server/speech_processors/vad_wrapper.py +180 -0
  36. simulstream/server/websocket_server.py +236 -0
  37. simulstream-0.1.0.dist-info/METADATA +465 -0
  38. simulstream-0.1.0.dist-info/RECORD +48 -0
  39. simulstream-0.1.0.dist-info/WHEEL +5 -0
  40. simulstream-0.1.0.dist-info/entry_points.txt +8 -0
  41. simulstream-0.1.0.dist-info/licenses/LICENSE +201 -0
  42. simulstream-0.1.0.dist-info/top_level.txt +3 -0
  43. uts/__init__.py +0 -0
  44. uts/metrics/__init__.py +0 -0
  45. uts/metrics/log_reader.py +50 -0
  46. uts/speech_processors/__init__.py +0 -0
  47. uts/speech_processors/test_simuleval_wrapper.py +88 -0
  48. uts/utils.py +5 -0
simulstream/server/speech_processors/incremental_output.py
@@ -0,0 +1,85 @@
+ # Copyright 2025 FBK
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License
+
+ import json
+ from dataclasses import dataclass
+ from typing import List, Callable
+
+
+ @dataclass
+ class IncrementalOutput:
+     """
+     Represents the incremental output of a speech processor for a single
+     processed chunk of audio.
+
+     Attributes:
+         new_tokens (List[str]): List of newly generated tokens in this chunk.
+         new_string (str): Concatenated string representation of the new tokens.
+         deleted_tokens (List[str]): List of tokens that were deleted/overwritten.
+         deleted_string (str): Concatenated string representation of the deleted tokens.
+     """
+     new_tokens: List[str]
+     new_string: str
+     deleted_tokens: List[str]
+     deleted_string: str
+
+     def strings_to_json(self) -> str:
+         """
+         Serialize the incremental output to a JSON string.
+
+         Returns:
+             str: A JSON string containing the newly generated and the deleted text.
+         """
+         return json.dumps({"new": self.new_string, "deleted": self.deleted_string})
+
+
+ def merge_incremental_outputs(
+         outputs: List[IncrementalOutput],
+         tokens_to_string: Callable[[List[str]], str]) -> IncrementalOutput:
+     """
+     Merge the incremental outputs passed as input into a single incremental output.
+     The outputs must be sorted in chronological order.
+
+     Args:
+         outputs (List[IncrementalOutput]): List of incremental outputs to be merged.
+         tokens_to_string (Callable[[List[str]], str]): A function that takes a list of
+             tokens and returns a string that contains the detokenized text.
+     """
+     if len(outputs) == 1:
+         return outputs[0]
+     if len(outputs) == 0:
+         return IncrementalOutput([], "", [], "")
+
+     # copy to avoid mutating the first output's token list in place
+     current_output_tokens = list(outputs[0].new_tokens)
+     current_output_deleted_tokens = outputs[0].deleted_tokens
+     for output in outputs[1:]:
+         num_deleted_tokens = len(output.deleted_tokens)
+         if num_deleted_tokens > 0:
+             if num_deleted_tokens < len(current_output_tokens):
+                 assert output.deleted_tokens == current_output_tokens[-num_deleted_tokens:]
+                 current_output_tokens = current_output_tokens[:-num_deleted_tokens]
+             else:
+                 # more tokens are deleted than were generated so far, so the extra
+                 # deleted tokens must be prepended to the deleted history; the slice
+                 # below also handles an empty `current_output_tokens`, for which
+                 # `[:-0]` would wrongly select nothing
+                 extra_deleted_tokens = output.deleted_tokens[
+                     :num_deleted_tokens - len(current_output_tokens)]
+                 current_output_deleted_tokens = \
+                     extra_deleted_tokens + current_output_deleted_tokens
+                 current_output_tokens = []
+         current_output_tokens += output.new_tokens
+
+     return IncrementalOutput(
+         current_output_tokens,
+         tokens_to_string(current_output_tokens),
+         current_output_deleted_tokens,
+         tokens_to_string(current_output_deleted_tokens))
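As an illustration of how retranslation updates are merged (an editorial sketch, not part of the package; the import path follows the file list above), consider two chronologically ordered outputs where the second retracts the last token of the first:

    from simulstream.server.speech_processors.incremental_output import (
        IncrementalOutput, merge_incremental_outputs)

    # First chunk emits " Hello wor"; the second retracts "▁wor" and emits "▁world"
    first = IncrementalOutput(["▁Hello", "▁wor"], " Hello wor", [], "")
    second = IncrementalOutput(["▁world"], " world", ["▁wor"], " wor")

    merged = merge_incremental_outputs(
        [first, second], lambda toks: "".join(toks).replace("▁", " "))
    assert merged.new_tokens == ["▁Hello", "▁world"]
    assert merged.deleted_tokens == []  # "▁wor" was produced and retracted within the span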
simulstream/server/speech_processors/seamless_sliding_window_retranslation.py
@@ -0,0 +1,84 @@
+ # Copyright 2025 FBK
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License
+
+ from types import SimpleNamespace
+ from typing import List
+
+ import numpy as np
+ import torch
+ from transformers import AutoProcessor, SeamlessM4TModel, SeamlessM4Tv2Model
+
+ from simulstream.server.speech_processors import SAMPLE_RATE
+ from simulstream.server.speech_processors.sliding_window_retranslation import \
+     SlidingWindowRetranslator
+
+
+ class SeamlessSlidingWindowRetranslator(SlidingWindowRetranslator):
+     """
+     Performs sliding-window retranslation with a SeamlessM4T speech-to-text model.
+     """
+
+     @classmethod
+     def load_model(cls, config: SimpleNamespace):
+         if not hasattr(cls, "model") or cls.model is None:
+             cls.processor = AutoProcessor.from_pretrained(config.hf_model_name)
+             seamless_version = getattr(config, "seamless_version", 1)
+             if seamless_version == 2:
+                 cls.model = SeamlessM4Tv2Model.from_pretrained(config.hf_model_name)
+             else:
+                 cls.model = SeamlessM4TModel.from_pretrained(config.hf_model_name)
+             cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+             cls.model.to(cls.device)
+
+     def _generate(self, speech: torch.Tensor) -> List[str]:
+         speech_seconds = speech.shape[1] / 50  # feature frames are 20 ms apart (50 per second)
+         extra_kwargs = {
+             "max_new_tokens": int(max(self.max_tokens_per_second * speech_seconds, 10)),
+             "generate_speech": False
+         }
+         if self.tgt_lang_tag is not None:
+             extra_kwargs["tgt_lang"] = self.tgt_lang_tag
+         generated_ids = self.model.generate(input_features=speech, **extra_kwargs)[0]
+         return self.processor.tokenizer.convert_ids_to_tokens(
+             generated_ids.squeeze(), skip_special_tokens=True)
+
+     def tokens_to_string(self, tokens: List[str]) -> str:
+         # prepend an empty token so that the initial space, if present, is not
+         # removed during detokenization
+         if self.text_history is not None and len(self.text_history) > 0:
+             tokens = [''] + tokens
+         return self.processor.tokenizer.convert_tokens_to_string(tokens)
+
+     def _preprocess(self, waveform: np.ndarray) -> torch.Tensor:
+         """
+         Appends the new waveform to the audio history, keeps only the last
+         `self.window_len` samples, stores the result back into the audio history,
+         and returns the filter-bank features extracted from it.
+         """
+         if self.audio_history is not None:
+             waveform = np.concatenate((self.audio_history, waveform))
+         new_speech_len = len(waveform)
+         if new_speech_len > self.window_len:
+             waveform = waveform[-self.window_len:]
+         self.audio_history = waveform
+         new_speech = self.processor(
+             audios=waveform,
+             sampling_rate=SAMPLE_RATE,
+             return_tensors="pt")
+         return new_speech["input_features"].to(self.device)
+
+     def set_target_language(self, language: str) -> None:
+         self.tgt_lang_tag = language
+
+     def set_source_language(self, language: str) -> None:
+         pass
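The windowing performed in `_preprocess` above can be seen in isolation with the following editorial sketch (the `window_len` value is a hypothetical choice; in the package it comes from the parent `SlidingWindowRetranslator`):

    import numpy as np

    window_len = 16000 * 10  # keep the last 10 s at a 16 kHz sample rate (hypothetical)
    audio_history = None
    for chunk in (np.zeros(16000, dtype=np.float32), np.zeros(16000 * 12, dtype=np.float32)):
        waveform = chunk if audio_history is None else np.concatenate((audio_history, chunk))
        if len(waveform) > window_len:
            waveform = waveform[-window_len:]  # only the most recent window survives
        audio_history = waveform
    # Every step re-translates at most `window_len` samples of the most recent audio.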
simulstream/server/speech_processors/seamless_streamatt.py
@@ -0,0 +1,268 @@
+ # Copyright 2025 FBK
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License
+
+ from types import SimpleNamespace
+ from typing import List, Tuple
+
+ import numpy as np
+ import torch
+ from transformers import AutoProcessor, SeamlessM4TModel, SeamlessM4Tv2Model
+
+ from simulstream.server.speech_processors import SAMPLE_RATE
+ from simulstream.server.speech_processors.base_streamatt import BaseStreamAtt
+
+
+ SHIFT_SIZE = 10
+ SEAMLESS_AUDIO_SUBSAMPLING_FACTOR = 8
+ # `FRAME_LENGTH` is the length of the window used for computing the mel-filterbank features
+ # that, with a sample rate of 16kHz, corresponds to 25ms, while `HOP_LENGTH` is how much the
+ # feature computation is shifted at each step, corresponding to 10ms. Values from
+ # `transformers.models.seamless_m4t.feature_extraction_seamless_m4t.
+ # SeamlessM4TFeatureExtractor._extract_fbank_features`
+ FRAME_LENGTH = 400
+ HOP_LENGTH = 160
+ # `OVERLAP_WINDOW` is the number of waveform samples that overlap between consecutive
+ # feature-extraction windows, as described in `transformers.audio_utils.spectrogram`,
+ # corresponding to 15ms.
+ OVERLAP_WINDOW = FRAME_LENGTH - HOP_LENGTH
+
+
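A quick check of the millisecond values stated in the comments above (editorial sketch; at 16 kHz there are 16 samples per millisecond):

    SAMPLES_PER_MS = 16
    assert 400 // SAMPLES_PER_MS == 25          # FRAME_LENGTH   -> 25 ms window
    assert 160 // SAMPLES_PER_MS == 10          # HOP_LENGTH     -> 10 ms shift
    assert (400 - 160) // SAMPLES_PER_MS == 15  # OVERLAP_WINDOW -> 15 ms overlap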
+ class SeamlessStreamAtt(BaseStreamAtt):
+     """
+     Performs the StreamAtt policy with the chosen textual history selection, using a
+     SeamlessM4T speech-to-text model.
+
+     Args:
+         config (SimpleNamespace): Configuration object.
+             The following additional attributes are expected:
+             - **max_new_tokens (int)**: The maximum number of tokens to generate
+             - **num_beams (int)**: The number of beams of the beam search
+             - **no_repeat_ngram_size (int)**: The size of the n-grams that cannot be
+               repeated in the generated output
+             - **audio_history_max_duration (int)**: Maximum length of the audio history to
+               store, in seconds. Defaults to 360 (6 minutes).
+     """
+
+     def __init__(self, config: SimpleNamespace):
+         super().__init__(config)
+         self.audio_subsampling_factor = SEAMLESS_AUDIO_SUBSAMPLING_FACTOR
+         self.max_new_tokens = getattr(self.config, "max_new_tokens", 128)
+         self.num_beams = getattr(self.config, "num_beams", 5)
+         self.audio_history_max_duration = getattr(
+             self.config, "audio_history_max_duration", 360)
+         self.no_repeat_ngram_size = getattr(self.config, "no_repeat_ngram_size", 5)
+         self.waveform_accumulator = None
+
+     @property
+     def audio_max_len(self) -> float:
+         """
+         Returns the maximum allowed length of the audio history in the space of the
+         encoded audio features. The number of feature frames is obtained by converting
+         `self.audio_history_max_duration` from seconds into milliseconds and dividing it
+         by the shift size (SHIFT_SIZE). Since the SeamlessM4T encoder subsamples the
+         input sequence, the resulting *audio_max_len* is obtained by further dividing
+         by *self.audio_subsampling_factor*.
+         """
+         return self.audio_history_max_duration * 1000 // SHIFT_SIZE // \
+             self.audio_subsampling_factor
+
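A worked instance of the formula above (editorial sketch), using the default `audio_history_max_duration` of 360 seconds:

    audio_history_max_duration = 360  # seconds (the default set in __init__)
    # seconds -> ms, one feature frame every SHIFT_SIZE (10) ms, then 8x encoder subsampling
    audio_max_len = audio_history_max_duration * 1000 // 10 // 8
    assert audio_max_len == 4500  # encoder positions kept in the audio history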
+     @classmethod
+     def load_model(cls, config: SimpleNamespace):
+         if not hasattr(cls, "model") or cls.model is None:
+             cls.processor = AutoProcessor.from_pretrained(config.hf_model_name)
+             seamless_version = getattr(config, "seamless_version", 1)
+             if seamless_version == 2:
+                 cls.model = SeamlessM4Tv2Model.from_pretrained(config.hf_model_name)
+             else:
+                 cls.model = SeamlessM4TModel.from_pretrained(config.hf_model_name)
+             cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+             cls.model.to(cls.device)
+
+     @staticmethod
+     def mean_variance_normalization(features: np.ndarray) -> torch.Tensor:
+         """
+         Normalization function taken from `transformers.models.seamless_m4t.
+         feature_extraction_seamless_m4t.SeamlessM4TFeatureExtractor`.
+         """
+         # torch defaults to ddof=1, while numpy defaults to ddof=0
+         mu = np.expand_dims(features.mean(axis=0), 0)
+         sigma = np.sqrt(np.expand_dims(features.var(axis=0, ddof=1), 0) + 1e-7)
+         normalized = (features - mu) / sigma
+         return torch.tensor(np.array(normalized))
+
+     def _preprocess(self, waveform: np.ndarray) -> torch.Tensor:
+         """
+         Extracts normalized input features for the SeamlessM4T model from the new
+         audio chunk (`waveform`) and the overlapping tail contained in
+         `self.waveform_accumulator`, and concatenates them with the previously
+         extracted features stored in `self.audio_history`.
+
+         Steps:
+             1. Convert the new waveform and the overlapping tail into mel-filterbank features.
+             2. Concatenate the new features with any stored `self.audio_history`.
+             3. Apply mean-variance normalization.
+             4. Cache the overlapping tail of the waveform for the next step.
+
+         Returns:
+             torch.Tensor: Normalized feature tensor on `self.device`.
+         """
+         # Combine with the overlapping part from the previous step, if any
+         if self.waveform_accumulator is not None:
+             waveform = np.concatenate((self.waveform_accumulator, waveform))
+
+         if len(waveform) >= FRAME_LENGTH:
+             # Extract new mel-filterbank features (unnormalized)
+             new_features = self.processor(
+                 audios=waveform,
+                 return_tensors="np",
+                 do_normalize_per_mel_bins=False,
+                 sampling_rate=SAMPLE_RATE,
+             )["input_features"].squeeze(0)  # shape: (T_new, 160)
+
+             # Concatenate with previous features, if available
+             if self.audio_history is not None:
+                 self.audio_history = np.concatenate(
+                     (self.audio_history, new_features), axis=0)
+             else:
+                 self.audio_history = new_features
+
+         # Normalize all features
+         normalized_features = self.mean_variance_normalization(self.audio_history)
+
+         # Store the last part of the waveform for the next preprocessing step
+         self.waveform_accumulator = waveform[-OVERLAP_WINDOW:]
+
+         return normalized_features.unsqueeze(0).to(self.device)
+
+     def get_prefix(self):
+         """
+         Creates a prefix for the generation phase from the previous outputs stored in
+         the text history. The prefix is formatted following the SeamlessM4T template
+         (target language token id + text history ids).
+         """
+         prefix_ids = torch.tensor(
+             [[self.model.generation_config.text_decoder_lang_to_code_id.get(self.tgt_lang)]]
+         ).to(self.device)
+         if self.text_history:
+             text_history_ids = torch.tensor(
+                 self.processor.tokenizer.convert_tokens_to_ids(self.text_history)).unsqueeze(
+                     dim=0).to(self.device)
+             prefix_ids = torch.cat((prefix_ids, text_history_ids), dim=1)
+         return prefix_ids.long()
+
+     def _generate(
+             self, input_features: torch.Tensor, normalize_attn: bool = True
+     ) -> Tuple[List[str], torch.Tensor]:
+         """
+         Generates a new hypothesis, also returning the cross-attention scores.
+         The hypothesis is forced to start with the `prefix_ids` prefix tokens,
+         which are then removed from the returned output tokens.
+         The attention scores are returned with dimensions (sequence_len, n_audio_features),
+         where `sequence_len` is the length of the prefix (excluding the language ID)
+         plus the length of the new hypothesis.
+         """
+         prefix_ids = self.get_prefix()
+         gen_out = self.model.generate(
+             input_features=input_features,
+             decoder_input_ids=prefix_ids,
+             max_new_tokens=self.max_new_tokens,
+             num_beams=self.num_beams,
+             return_dict_in_generate=True,
+             output_attentions=True,
+             no_repeat_ngram_size=self.no_repeat_ngram_size,
+             generate_speech=False)
+         out_ids = list(gen_out.sequences[0])
+
+         # Exclude BOS, prefix, and EOS from the generated sequence
+         new_hypo_ids = out_ids[prefix_ids.shape[1] + 1:-1]
+
+         cross_attn = self.get_cross_attention(
+             gen_out, len(new_hypo_ids), normalize_attn=normalize_attn)
+
+         assert cross_attn.shape[0] == (prefix_ids.shape[1] - 1) + len(new_hypo_ids), \
+             f"Cross attention scores along the token dimension ({cross_attn.shape[0]}) " \
+             f"mismatch with the length of the hypothesis " \
+             f"({(prefix_ids.shape[1] - 1) + len(new_hypo_ids)})."
+
+         new_hypo = self.processor.tokenizer.convert_ids_to_tokens(
+             new_hypo_ids, skip_special_tokens=True)
+         return new_hypo, cross_attn
+
+     def _extract_new_hypo_attention_scores(self, new_hypo_len: int, gen_out):
+         """
+         Extracts the attention scores for the tokens of the new hypothesis from
+         `gen_out.cross_attentions`, which stores the scores of each generation step and
+         each layer, using the beam index selected at each step of the beam search
+         (stored in `beam_indices`) when num_beams > 1.
+         """
+         cross_attns = []
+         if self.num_beams > 1:
+             # Beam search: for each token of the new hypothesis, we select the corresponding
+             # cross attention from the cross attentions stored at each step of the beam
+             # search, using the index contained in the tensor of indices beam_indices
+             # (num_beams * sequence length)
+             for tok_idx in range(new_hypo_len):
+                 # Select the cross attention matrix using the beam_indices
+                 beam_indices = gen_out.beam_indices[:, tok_idx]
+                 # `cross_attentions[0]` holds the forward pass over the prefix, so the step
+                 # that generated the tok_idx-th new token is at index tok_idx + 1; within
+                 # that step, the last query position (-1) is the newly generated token
+                 cross_attn = gen_out.cross_attentions[
+                     tok_idx + 1][self.cross_attn_layer][:, :, -1, :]
+                 cross_attn = cross_attn.index_select(dim=0, index=beam_indices)
+                 cross_attns.append(cross_attn)
+         else:
+             # Greedy search
+             for tok_idx in range(new_hypo_len):
+                 cross_attn = gen_out.cross_attentions[
+                     tok_idx + 1][self.cross_attn_layer][:, :, -1, :]
+                 cross_attns.append(cross_attn)
+
+         # Cross attention scores with shape [sequence_len, num_heads, n_audio_features]
+         cross_attns = torch.stack(cross_attns).squeeze(1)
+         return cross_attns
+
+     def get_cross_attention(
+             self, gen_out, new_hypo_len: int, normalize_attn: bool) -> torch.Tensor:
+         """
+         Given the attention scores for the generated output (including both prefix and new
+         hypothesis), returns the cross attention scores of the configured layer
+         (`self.cross_attn_layer`), averaged along the attention heads dimension.
+         """
+         # The prefix is stored in the first element of the cross_attentions, which is equal
+         # for each beam and, therefore, the first one is selected. Also, BOS and language ID
+         # are excluded from the sequence_len (first two elements).
+         cross_attns = (
+             gen_out.cross_attentions[0][self.cross_attn_layer][0, :, 2:, :].transpose(0, 1))
+
+         if new_hypo_len > 0:
+             new_cross_attns = self._extract_new_hypo_attention_scores(new_hypo_len, gen_out)
+             cross_attns = torch.cat([cross_attns, new_cross_attns], dim=0)
+
+         # Average over the attention heads dimension
+         cross_attns = cross_attns.mean(dim=1)
+
+         # Normalize attention scores
+         if normalize_attn:
+             cross_attns = self.normalize_attn(cross_attns)
+         return cross_attns
+
+     def tokens_to_string(self, tokens: List[str]) -> str:
+         # prepend an empty token so that the initial space, if present, is not
+         # removed during detokenization
+         if self.text_history is not None and len(self.text_history) > 0:
+             tokens = [''] + tokens
+         return self.processor.tokenizer.convert_tokens_to_string(tokens)
+
+     def set_target_language(self, language: str) -> None:
+         self.tgt_lang = language
+
+     def set_source_language(self, language: str) -> None:
+         pass
+
+     def clear(self) -> None:
+         super().clear()
+         self.waveform_accumulator = None
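To make the prefix layout produced by `get_prefix` concrete, here is a minimal editorial sketch with hypothetical token ids (the real ids come from the SeamlessM4T tokenizer and `text_decoder_lang_to_code_id`):

    import torch

    lang_id = 256047                  # hypothetical id of the target-language token
    history_ids = [[4203, 88, 1029]]  # hypothetical ids of previously emitted tokens
    prefix_ids = torch.cat(
        (torch.tensor([[lang_id]]), torch.tensor(history_ids)), dim=1).long()
    # prefix_ids has shape (1, 4) and is passed as `decoder_input_ids`, forcing the
    # decoder to continue the stored text history instead of restarting the translation.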
simulstream/server/speech_processors/simuleval_wrapper.py
@@ -0,0 +1,165 @@
+ # Copyright 2025 FBK
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License
+
+ import logging
+ import torch
+ import numpy as np
+
+ from types import SimpleNamespace
+ from typing import List
+
+ from simulstream.metrics.detokenizers import get_detokenizer
+ from simulstream.server.speech_processors import IncrementalOutput, SpeechProcessor, class_load
+ from simulstream.server.speech_processors import SAMPLE_RATE
+
+
+ logger = logging.getLogger(__name__)
+
+
+ try:
+     from simuleval.agents.agent import SEGMENT_TYPE_DICT
+     from simuleval.agents.actions import Action
+     from simuleval.data.segments import SpeechSegment
+ except Exception as e:
+     logger.error(
+         "Not able to import SimulEval. Please install SimulEval or add it to your PYTHONPATH.")
+     # In case we are running unit tests, avoid failures when importing the wrapper
+     import os
+     is_testing = os.getenv("IS_TESTING")
+     if not is_testing:
+         raise e
+     import builtins
+     # Temporarily inject types to satisfy simuleval_wrapper imports.
+     # These need to remain in builtins for the type annotations in the class definition
+     builtins.Action = type("Action", (), {})
+     builtins.SpeechSegment = type("SpeechSegment", (), {})
+
+
+ class SimulEvalWrapper(SpeechProcessor):
+     """
+     Wrapper processor that calls the configured `simuleval_agent` implemented on
+     SimulEval>=1.1.0.
+     """
+     def __init__(self, config: SimpleNamespace):
+         super().__init__(config)
+         agent_class_name = getattr(config, "simuleval_agent")
+         agent_class = class_load(agent_class_name)
+         config.source_segment_size = config.speech_chunk_size * 1000
+         config.device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.simuleval_agent = agent_class(config)
+         self.latency_unit = config.latency_unit
+         self.segment_type = self._segment_type_class()
+         self.emission_started = False
+         self.detokenizer = get_detokenizer(config)
+
+     def _segment_type_class(self):
+         return SEGMENT_TYPE_DICT[self.simuleval_agent.target_type]
+
+     @classmethod
+     def load_model(cls, config: SimpleNamespace):
+         """
+         In SimulEval, the model is loaded in the init of the `simuleval_agent` and,
+         therefore, a copy of the model is created for each client.
+         """
+         pass
+
+     def set_target_language(self, language: str) -> None:
+         if hasattr(self.simuleval_agent, "tgt_lang"):
+             self.simuleval_agent.tgt_lang = language
+         else:
+             logger.warning("Unable to set the target language for the SimulEval agent.")
+
+     def set_source_language(self, language: str) -> None:
+         pass
+
+     def tokens_to_string(self, tokens: List[str]) -> str:
+         return self.detokenizer(tokens)
+
+     def _process_action(self, action: Action) -> str:
+         """
+         Processes a SimulEval action and updates the target state accordingly.
+
+         If the action is of type READ, no output is produced and an empty prediction
+         string is returned. Otherwise (WRITE action), the method extracts the generated
+         content from the action, wraps it in the appropriate target segment type, and
+         updates the SimulEval agent's target state with this new segment.
+
+         Args:
+             action (Action): The current SimulEval action to process. It can either
+                 request reading more input (READ) or writing an output (WRITE).
+
+         Returns:
+             str: The predicted output text if the action is WRITE, or an empty string
+                 if the action is READ.
+         """
+         if action.is_read():
+             return ""
+
+         prediction = action.content
+         segment = self.segment_type(index=0, content=prediction, finished=action.finished)
+         self.simuleval_agent.states.update_target(segment)
+         return prediction
+
+     def _build_incremental_outputs(self, generated_text: str) -> IncrementalOutput:
+         """
+         Transforms the prediction string from `Action.content` of SimulEval into the
+         required IncrementalOutput format. The token conversion follows the original
+         SimulEval Instance handling (https://github.com/facebookresearch/SimulEval/blob/
+         536de8253b82d805c9845440169a5010ff507357/simuleval/evaluator/instance.py#L209)
+         with the sole exception of not removing the generated spaces in character-level
+         languages. Since SimulEval only supports incremental outputs, no tokens are deleted.
+         """
+         if self.latency_unit in ["word", "spm"]:
+             generated_tokens = generated_text.strip().split()
+         elif self.latency_unit == "char":
+             generated_tokens = list(generated_text.strip())
+         else:
+             raise NotImplementedError
+
+         if self.emission_started and self.latency_unit == "word":
+             generated_text = " " + generated_text
+
+         self.emission_started = True
+
+         return IncrementalOutput(
+             new_tokens=generated_tokens,
+             new_string=generated_text,
+             deleted_tokens=[],
+             deleted_string="",
+         )
+
+
141
+ def process_chunk(self, waveform: np.float32) -> IncrementalOutput:
142
+ source_segment = SpeechSegment(
143
+ index=0,
144
+ content=waveform.tolist(),
145
+ sample_rate=SAMPLE_RATE,
146
+ finished=False,
147
+ tgt_lang=self.simuleval_agent.tgt_lang,
148
+ )
149
+ self.simuleval_agent.states.update_source(source_segment)
150
+ action = self.simuleval_agent.policy(self.simuleval_agent.states)
151
+ prediction = self._process_action(action)
152
+ new_output = self._build_incremental_outputs(prediction)
153
+ return new_output
154
+
155
+ def end_of_stream(self) -> IncrementalOutput:
156
+ self.simuleval_agent.states.source_finished = True
157
+ action = self.simuleval_agent.policy(self.simuleval_agent.states)
158
+ prediction = self._process_action(action)
159
+ new_output = self._build_incremental_outputs(prediction)
160
+ return new_output
161
+
162
+ def clear(self) -> None:
163
+ """ In SimulEval, the agent is reset inside `simuleval_agent` """
164
+ self.simuleval_agent.reset()
165
+ self.emission_started = False