simulstream-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +47 -0
- simulstream/__init__.py +15 -0
- simulstream/client/__init__.py +0 -0
- simulstream/client/wav_reader_client.py +228 -0
- simulstream/config.py +31 -0
- simulstream/inference.py +170 -0
- simulstream/metrics/__init__.py +0 -0
- simulstream/metrics/detokenizers.py +71 -0
- simulstream/metrics/logger.py +32 -0
- simulstream/metrics/readers.py +348 -0
- simulstream/metrics/score_latency.py +130 -0
- simulstream/metrics/score_quality.py +169 -0
- simulstream/metrics/scorers/__init__.py +0 -0
- simulstream/metrics/scorers/latency/__init__.py +115 -0
- simulstream/metrics/scorers/latency/mwersegmenter.py +136 -0
- simulstream/metrics/scorers/latency/stream_laal.py +119 -0
- simulstream/metrics/scorers/quality/__init__.py +132 -0
- simulstream/metrics/scorers/quality/comet.py +57 -0
- simulstream/metrics/scorers/quality/mwersegmenter.py +93 -0
- simulstream/metrics/scorers/quality/sacrebleu.py +59 -0
- simulstream/metrics/stats.py +184 -0
- simulstream/server/__init__.py +0 -0
- simulstream/server/http_server.py +95 -0
- simulstream/server/message_processor.py +156 -0
- simulstream/server/speech_processors/__init__.py +173 -0
- simulstream/server/speech_processors/base.py +135 -0
- simulstream/server/speech_processors/base_streamatt.py +320 -0
- simulstream/server/speech_processors/canary_sliding_window_retranslation.py +73 -0
- simulstream/server/speech_processors/hf_sliding_window_retranslation.py +87 -0
- simulstream/server/speech_processors/incremental_output.py +85 -0
- simulstream/server/speech_processors/seamless_sliding_window_retranslation.py +84 -0
- simulstream/server/speech_processors/seamless_streamatt.py +268 -0
- simulstream/server/speech_processors/simuleval_wrapper.py +165 -0
- simulstream/server/speech_processors/sliding_window_retranslation.py +135 -0
- simulstream/server/speech_processors/vad_wrapper.py +180 -0
- simulstream/server/websocket_server.py +236 -0
- simulstream-0.1.0.dist-info/METADATA +465 -0
- simulstream-0.1.0.dist-info/RECORD +48 -0
- simulstream-0.1.0.dist-info/WHEEL +5 -0
- simulstream-0.1.0.dist-info/entry_points.txt +8 -0
- simulstream-0.1.0.dist-info/licenses/LICENSE +201 -0
- simulstream-0.1.0.dist-info/top_level.txt +3 -0
- uts/__init__.py +0 -0
- uts/metrics/__init__.py +0 -0
- uts/metrics/log_reader.py +50 -0
- uts/speech_processors/__init__.py +0 -0
- uts/speech_processors/test_simuleval_wrapper.py +88 -0
- uts/utils.py +5 -0
--- /dev/null
+++ b/simulstream/server/speech_processors/incremental_output.py
@@ -0,0 +1,85 @@
+# Copyright 2025 FBK
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License
+
+import json
+from dataclasses import dataclass
+from typing import List, Callable
+
+
+@dataclass
+class IncrementalOutput:
+    """
+    Represents the incremental output of a speech processor for a single
+    processed chunk of audio.
+
+    Attributes:
+        new_tokens (List[str]): List of newly generated tokens in this chunk.
+        new_string (str): Concatenated string representation of the new tokens.
+        deleted_tokens (List[str]): List of tokens that were deleted/overwritten.
+        deleted_string (str): Concatenated string representation of the deleted tokens.
+    """
+    new_tokens: List[str]
+    new_string: str
+    deleted_tokens: List[str]
+    deleted_string: str
+
+    def strings_to_json(self) -> str:
+        """
+        Serialize the incremental output to a JSON string.
+
+        Returns:
+            str: A JSON string containing the newly generated and the deleted text.
+        """
+        return json.dumps({"new": self.new_string, "deleted": self.deleted_string})
+
+
+def merge_incremental_outputs(
+        outputs: List[IncrementalOutput],
+        tokens_to_string: Callable[[List[str]], str]) -> IncrementalOutput:
+    """
+    Merge the incremental outputs passed as input into a single incremental output.
+    The outputs must be sorted in chronological order.
+
+    Args:
+        outputs (List[IncrementalOutput]): List of incremental outputs to be merged.
+        tokens_to_string (Callable[[List[str]], str]): A function that takes a list of tokens and
+            returns a string that contains the detokenized text.
+    """
+    if len(outputs) == 1:
+        return outputs[0]
+    if len(outputs) == 0:
+        return IncrementalOutput([], "", [], "")
+
+    current_output_tokens = outputs[0].new_tokens
+    current_output_deleted_tokens = outputs[0].deleted_tokens
+    for output in outputs[1:]:
+        num_deleted_tokens = len(output.deleted_tokens)
+        if num_deleted_tokens > 0:
+            if num_deleted_tokens < len(current_output_tokens):
+                assert output.deleted_tokens == current_output_tokens[-num_deleted_tokens:]
+                current_output_tokens = current_output_tokens[:-num_deleted_tokens]
+            else:
+                # we are deleting more than was generated so far, so the extra
+                # deleted tokens should be included
+                extra_deleted_tokens = output.deleted_tokens[:-len(current_output_tokens)]
+                current_output_deleted_tokens = \
+                    extra_deleted_tokens + current_output_deleted_tokens
+                current_output_tokens = []
+        current_output_tokens += output.new_tokens
+
+    return IncrementalOutput(
+        current_output_tokens,
+        tokens_to_string(current_output_tokens),
+        current_output_deleted_tokens,
+        tokens_to_string(current_output_deleted_tokens))
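As an illustration of the merge semantics above, a minimal sketch (the token values are invented, and `" ".join` stands in for a real detokenizer):

    from simulstream.server.speech_processors.incremental_output import (
        IncrementalOutput, merge_incremental_outputs)

    # Chunk 1 emits "hello world"; chunk 2 retracts "world" and re-emits "worlds !".
    out1 = IncrementalOutput(["hello", "world"], "hello world", [], "")
    out2 = IncrementalOutput(["worlds", "!"], "worlds !", ["world"], "world")

    merged = merge_incremental_outputs([out1, out2], tokens_to_string=" ".join)
    print(merged.new_tokens)         # ['hello', 'worlds', '!']
    print(merged.deleted_tokens)     # []
    print(merged.strings_to_json())  # {"new": "hello worlds !", "deleted": ""}

Since the retracted token "world" matches the tail of the previously generated tokens, it is removed from the running hypothesis rather than reported as deleted in the merged result.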
--- /dev/null
+++ b/simulstream/server/speech_processors/seamless_sliding_window_retranslation.py
@@ -0,0 +1,84 @@
+# Copyright 2025 FBK
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License
+
+from types import SimpleNamespace
+from typing import List
+
+import numpy as np
+import torch
+from transformers import AutoProcessor, SeamlessM4TModel, SeamlessM4Tv2Model
+
+from simulstream.server.speech_processors import SAMPLE_RATE
+from simulstream.server.speech_processors.sliding_window_retranslation import \
+    SlidingWindowRetranslator
+
+
+class SeamlessSlidingWindowRetranslator(SlidingWindowRetranslator):
+    """
+    Perform Sliding Window Retranslation with a SeamlessM4T speech-to-text model.
+    """
+
+    @classmethod
+    def load_model(cls, config: SimpleNamespace):
+        if not hasattr(cls, "model") or cls.model is None:
+            cls.processor = AutoProcessor.from_pretrained(config.hf_model_name)
+            seamless_version = getattr(config, "seamless_version", 1)
+            if seamless_version == 2:
+                cls.model = SeamlessM4Tv2Model.from_pretrained(config.hf_model_name)
+            else:
+                cls.model = SeamlessM4TModel.from_pretrained(config.hf_model_name)
+            cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+            cls.model.to(cls.device)
+
+    def _generate(self, speech: torch.Tensor) -> List[str]:
+        speech_seconds = speech.shape[1] / 50  # 1 frame every 20 ms
+        extra_kwargs = {
+            "max_new_tokens": int(max(self.max_tokens_per_second * speech_seconds, 10)),
+            "generate_speech": False
+        }
+        if self.tgt_lang_tag is not None:
+            extra_kwargs["tgt_lang"] = self.tgt_lang_tag
+        generated_ids = self.model.generate(input_features=speech, **extra_kwargs)[0]
+        return self.processor.tokenizer.convert_ids_to_tokens(
+            generated_ids.squeeze(), skip_special_tokens=True)
+
+    def tokens_to_string(self, tokens: List[str]) -> str:
+        # prevent the initial space, if present, from being removed during detokenization
+        if self.text_history is not None and len(self.text_history) > 0:
+            tokens = [''] + tokens
+        return self.processor.tokenizer.convert_tokens_to_string(tokens)
+
+    def _preprocess(self, waveform: np.float32) -> torch.Tensor:
+        """
+        Appends the input waveform to the audio history, keeping only the last
+        `self.window_len` samples, stores the result back in the audio history,
+        and returns the filter-bank features extracted from it.
+        """
+        if self.audio_history is not None:
+            waveform = np.concatenate((self.audio_history, waveform))
+        new_speech_len = len(waveform)
+        if new_speech_len > self.window_len:
+            waveform = waveform[-self.window_len:]
+        self.audio_history = waveform
+        new_speech = self.processor(
+            audios=waveform,
+            sampling_rate=SAMPLE_RATE,
+            return_tensors="pt")
+        return new_speech["input_features"].to(self.device)
+
+    def set_target_language(self, language: str) -> None:
+        self.tgt_lang_tag = language
+
+    def set_source_language(self, language: str) -> None:
+        pass
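The window bookkeeping in `_preprocess` above can be followed in isolation. A self-contained sketch, assuming a hypothetical 10-second `window_len` (that attribute comes from the `SlidingWindowRetranslator` base class, which is not shown in this excerpt) and a 16 kHz sample rate, consistent with the comments elsewhere in the package:

    import numpy as np

    SAMPLE_RATE = 16000              # 16 kHz, matching the package's SAMPLE_RATE
    window_len = 10 * SAMPLE_RATE    # hypothetical 10-second window, in samples

    audio_history = None
    for chunk_seconds in (4, 4, 4):  # three 4-second chunks arrive
        waveform = np.zeros(chunk_seconds * SAMPLE_RATE, dtype=np.float32)
        if audio_history is not None:
            waveform = np.concatenate((audio_history, waveform))
        if len(waveform) > window_len:
            waveform = waveform[-window_len:]    # keep only the most recent samples
        audio_history = waveform
        print(len(audio_history) / SAMPLE_RATE)  # 4.0, 8.0, then capped at 10.0

Retranslation thus always re-decodes the latest window of audio, which is why `_generate` caps `max_new_tokens` proportionally to the window duration.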
--- /dev/null
+++ b/simulstream/server/speech_processors/seamless_streamatt.py
@@ -0,0 +1,268 @@
+# Copyright 2025 FBK
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License
+
+from types import SimpleNamespace
+from typing import List, Tuple
+
+import numpy as np
+import torch
+from transformers import AutoProcessor, SeamlessM4TModel, SeamlessM4Tv2Model
+
+from simulstream.server.speech_processors import SAMPLE_RATE
+from simulstream.server.speech_processors.base_streamatt import BaseStreamAtt
+
+
+SHIFT_SIZE = 10
+SEAMLESS_AUDIO_SUBSAMPLING_FACTOR = 8
+# `FRAME_LENGTH` is the length of the window used for computing the mel-filterbank features that,
+# with a sample rate of 16kHz, corresponds to 25ms, while `HOP_LENGTH` is how much the feature
+# computation is shifted at each step, corresponding to 10ms. Values from `transformers.models.
+# seamless_m4t.feature_extraction_seamless_m4t.SeamlessM4TFeatureExtractor._extract_fbank_features`
+FRAME_LENGTH = 400
+HOP_LENGTH = 160
+# `OVERLAP_WINDOW` is the length of the features that overlap at each step, as described in
+# `transformers.audio_utils.spectrogram`, corresponding to 15ms.
+OVERLAP_WINDOW = FRAME_LENGTH - HOP_LENGTH
+
+
+class SeamlessStreamAtt(BaseStreamAtt):
+    """
+    Perform StreamAtt policy with the chosen textual history selection using a SeamlessM4T
+    speech-to-text model.
+
+    Args:
+        config (SimpleNamespace): Configuration object.
+            The following additional attributes are expected:
+            - **max_new_tokens (int)**: The maximum number of tokens to generate
+            - **num_beams (int)**: The number of beams of the beam search
+            - **no_repeat_ngram_size (int)**: The size of the n-grams that cannot be
+              repeated in the generated output
+            - **audio_history_max_duration (int)**: Maximum length of the audio history to store,
+              in seconds. Defaults to 360 seconds.
+
+    """
+
+    def __init__(self, config: SimpleNamespace):
+        super().__init__(config)
+        self.audio_subsampling_factor = SEAMLESS_AUDIO_SUBSAMPLING_FACTOR
+        self.max_new_tokens = getattr(self.config, "max_new_tokens", 128)
+        self.num_beams = getattr(self.config, "num_beams", 5)
+        self.audio_history_max_duration = getattr(self.config, "audio_history_max_duration", 360)
+        self.no_repeat_ngram_size = getattr(self.config, "no_repeat_ngram_size", 5)
+        self.waveform_accumulator = None
+
+    @property
+    def audio_max_len(self) -> float:
+        """
+        Returns the maximum allowed length of the audio history, converted into the
+        space of encoded audio features. The number of encoded features is obtained
+        by first converting self.audio_history_max_duration, originally in seconds,
+        into milliseconds, and then dividing by the shift size (SHIFT_SIZE). Since
+        the SeamlessM4T encoder subsamples the input sequence, the resulting
+        *audio_max_len* is further divided by *self.audio_subsampling_factor*.
+        """
+        return self.audio_history_max_duration * 1000 // SHIFT_SIZE // \
+            self.audio_subsampling_factor
+
+    @classmethod
+    def load_model(cls, config: SimpleNamespace):
+        if not hasattr(cls, "model") or cls.model is None:
+            cls.processor = AutoProcessor.from_pretrained(config.hf_model_name)
+            seamless_version = getattr(config, "seamless_version", 1)
+            if seamless_version == 2:
+                cls.model = SeamlessM4Tv2Model.from_pretrained(config.hf_model_name)
+            else:
+                cls.model = SeamlessM4TModel.from_pretrained(config.hf_model_name)
+            cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+            cls.model.to(cls.device)
+
+    @staticmethod
+    def mean_variance_normalization(features: np.ndarray) -> torch.Tensor:
+        """
+        Normalization function taken from `transformers.models.seamless_m4t.
+        feature_extraction_seamless_m4t.SeamlessM4TFeatureExtractor`.
+        """
+        # torch defaults to ddof=1, and numpy defaults to ddof=0
+        mu = np.expand_dims(features.mean(axis=0), 0)
+        sigma = np.sqrt(np.expand_dims(features.var(axis=0, ddof=1), 0) + 1e-7)
+        normalized = (features - mu) / sigma
+        return torch.tensor(np.array(normalized))
+
+    def _preprocess(self, waveform: np.ndarray) -> torch.Tensor:
+        """
+        Extract normalized input features for the SeamlessM4T model from the new
+        audio chunk (`waveform`) and the overlapping tail contained in
+        `self.waveform_accumulator`, concatenating them with previously
+        extracted features stored in `self.audio_history`.
+
+        Steps:
+        1. Convert the new waveform and the overlapping tail into mel-filterbank features.
+        2. Concatenate the new features with any stored `self.audio_history`.
+        3. Apply mean-variance normalization.
+        4. Cache the overlapping tail of the waveform for the next step.
+
+        Returns:
+            torch.Tensor: Normalized feature tensor on `self.device`.
+        """
+        # Combine with overlapping part from previous step, if any
+        if self.waveform_accumulator is not None:
+            waveform = np.concatenate((self.waveform_accumulator, waveform))
+
+        if len(waveform) >= FRAME_LENGTH:
+            # Extract new mel-filterbank features (unnormalized)
+            new_features = self.processor(
+                audios=waveform,
+                return_tensors="np",
+                do_normalize_per_mel_bins=False,
+                sampling_rate=SAMPLE_RATE,
+            )["input_features"].squeeze(0)  # shape: (T_new, 160)
+
+            # Concatenate with previous features, if available
+            if self.audio_history is not None:
+                self.audio_history = np.concatenate((self.audio_history, new_features), axis=0)
+            else:
+                self.audio_history = new_features
+
+        # Normalize all features
+        normalized_features = self.mean_variance_normalization(self.audio_history)
+
+        # Store the last part of the waveform for the next preprocessing step
+        self.waveform_accumulator = waveform[-OVERLAP_WINDOW:]
+
+        return normalized_features.unsqueeze(0).to(self.device)
+
+    def get_prefix(self):
+        """
+        Creates a prefix for the generation phase with the previous outputs stored in
+        the text_history. The prefix is formatted following the SeamlessM4T template
+        (tgt language token id + text history ids).
+        """
+        prefix_ids = torch.tensor(
+            [[self.model.generation_config.text_decoder_lang_to_code_id.get(self.tgt_lang)]]
+        ).to(self.device)
+        if self.text_history:
+            text_history_ids = torch.tensor(
+                self.processor.tokenizer.convert_tokens_to_ids(self.text_history)).unsqueeze(
+                dim=0).to(self.device)
+            prefix_ids = torch.cat((prefix_ids, text_history_ids), dim=1)
+        return prefix_ids.long()
+
+    def _generate(
+            self, input_features: torch.Tensor, normalize_attn: bool = True
+    ) -> Tuple[List[str], torch.Tensor]:
+        """
+        Generates a new hypothesis, also returning the cross-attention scores.
+        The hypothesis is forced to start with the `prefix_ids` prefix tokens,
+        which are then removed from the returned output tokens.
+        The attention scores are returned with dimensions (sequence_len, n_audio_features),
+        where `sequence_len` is the combined length of the prefix (excluding the
+        language ID) and of the new hypothesis.
+        """
+        prefix_ids = self.get_prefix()
+        gen_out = self.model.generate(
+            input_features=input_features,
+            decoder_input_ids=prefix_ids,
+            max_new_tokens=self.max_new_tokens,
+            num_beams=self.num_beams,
+            return_dict_in_generate=True,
+            output_attentions=True,
+            no_repeat_ngram_size=self.no_repeat_ngram_size,
+            generate_speech=False)
+        out_ids = list(gen_out.sequences[0])
+
+        # Exclude BOS, prefix, and EOS from the generated sequence
+        new_hypo_ids = out_ids[prefix_ids.shape[1] + 1:-1]
+
+        cross_attn = self.get_cross_attention(
+            gen_out, len(new_hypo_ids), normalize_attn=normalize_attn)
+
+        assert cross_attn.shape[0] == (prefix_ids.shape[1] - 1) + len(new_hypo_ids), \
+            f"Cross attention scores along tokens dimension ({cross_attn.shape[0]}) " \
+            f"do not match the length of the hypothesis " \
+            f"({(prefix_ids.shape[1] - 1) + len(new_hypo_ids)})."
+
+        new_hypo = self.processor.tokenizer.convert_ids_to_tokens(
+            new_hypo_ids, skip_special_tokens=True)
+        return new_hypo, cross_attn
+
+    def _extract_new_hypo_attention_scores(self, new_hypo_len: int, gen_out):
+        """
+        Extract the `cross_attentions` scores from `gen_out`, which are stored for each
+        generation step and layer, selecting them through the beam index chosen at each
+        step of the beam search and stored in `beam_indices` (if num_beams > 1).
+        """
+        cross_attns = []
+        if self.num_beams > 1:
+            # Beam search: for each token of the new hypothesis, we select the corresponding cross
+            # attention from the cross attentions stored at each step of the beam search using the
+            # index contained in the tensor of indices beam_indices (num_beams * sequence length)
+            for tok_idx in range(new_hypo_len):
+                # Select the cross attention matrix using the beam_indices
+                beam_indices = gen_out.beam_indices[:, tok_idx]
+                # tok_idx + 1 skips the prefix-only cross_attentions[0]; -1 selects the new token
+                cross_attn = gen_out.cross_attentions[
+                    tok_idx + 1][self.cross_attn_layer][:, :, -1, :]
+                cross_attn = cross_attn.index_select(dim=0, index=beam_indices)
+                cross_attns.append(cross_attn)
+        else:
+            # Greedy search
+            for tok_idx in range(new_hypo_len):
+                cross_attn = gen_out.cross_attentions[
+                    tok_idx + 1][self.cross_attn_layer][:, :, -1, :]
+                cross_attns.append(cross_attn)
+
+        # Cross attention scores with shape [sequence_len, num_heads, n_audio_features]
+        cross_attns = torch.stack(cross_attns).squeeze(1)
+        return cross_attns
+
+    def get_cross_attention(
+            self, gen_out, new_hypo_len: int, normalize_attn: bool) -> torch.Tensor:
+        """
+        Given the attention scores for the generated output (including both prefix and new
+        hypothesis), this function returns the cross attention scores from the layer
+        `self.cross_attn_layer` by averaging the scores along the attention heads dimension.
+        """
+        # The prefix is stored in the first element of the cross_attentions, which is equal
+        # for each beam; therefore, the first beam is selected. Also, BOS and language ID
+        # are excluded from the sequence_len (first two elements).
+        cross_attns = (
+            gen_out.cross_attentions[0][self.cross_attn_layer][0, :, 2:, :].transpose(0, 1))
+
+        if new_hypo_len > 0:
+            new_cross_attns = self._extract_new_hypo_attention_scores(new_hypo_len, gen_out)
+            cross_attns = torch.cat([cross_attns, new_cross_attns], dim=0)
+
+        # Average on the attention heads dimension
+        cross_attns = cross_attns.mean(dim=1)
+
+        # Normalize attention scores
+        if normalize_attn:
+            cross_attns = self.normalize_attn(cross_attns)
+        return cross_attns
+
+    def tokens_to_string(self, tokens: List[str]) -> str:
+        # prevent the initial space, if present, from being removed during detokenization
+        if self.text_history is not None and len(self.text_history) > 0:
+            tokens = [''] + tokens
+        return self.processor.tokenizer.convert_tokens_to_string(tokens)
+
+    def set_target_language(self, language: str) -> None:
+        self.tgt_lang = language
+
+    def set_source_language(self, language: str) -> None:
+        pass
+
+    def clear(self) -> None:
+        super().clear()
+        self.waveform_accumulator = None
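The constants and defaults above make the buffering and history arithmetic easy to verify; a quick worked check (pure arithmetic, no model needed):

    SAMPLE_RATE = 16000
    FRAME_LENGTH = 400      # 25 ms analysis window at 16 kHz
    HOP_LENGTH = 160        # 10 ms shift
    OVERLAP_WINDOW = FRAME_LENGTH - HOP_LENGTH
    print(OVERLAP_WINDOW / SAMPLE_RATE * 1000)  # 15.0 ms of waveform cached per step

    SHIFT_SIZE = 10                       # ms of audio per feature frame
    SEAMLESS_AUDIO_SUBSAMPLING_FACTOR = 8
    audio_history_max_duration = 360      # default, in seconds
    audio_max_len = (audio_history_max_duration * 1000
                     // SHIFT_SIZE // SEAMLESS_AUDIO_SUBSAMPLING_FACTOR)
    print(audio_max_len)                  # 4500 encoder positions for 360 s of audio

Caching the 240-sample (15 ms) tail is exactly what makes chunked feature extraction match offline extraction: the next chunk's first analysis window starts where the previous hop left off.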
--- /dev/null
+++ b/simulstream/server/speech_processors/simuleval_wrapper.py
@@ -0,0 +1,165 @@
+# Copyright 2025 FBK
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License
+
+import logging
+import torch
+import numpy as np
+
+from types import SimpleNamespace
+from typing import List
+
+from simulstream.metrics.detokenizers import get_detokenizer
+from simulstream.server.speech_processors import IncrementalOutput, SpeechProcessor, class_load
+from simulstream.server.speech_processors import SAMPLE_RATE
+
+
+logger = logging.getLogger(__name__)
+
+
+try:
+    from simuleval.agents.agent import SEGMENT_TYPE_DICT
+    from simuleval.agents.actions import Action
+    from simuleval.data.segments import SpeechSegment
+except Exception as e:
+    logger.error(
+        "Unable to import SimulEval. Please install SimulEval or add it to your PYTHONPATH.")
+    # In case we are running unit tests, avoid failures when importing the Wrapper
+    import os
+    is_testing = os.getenv("IS_TESTING")
+    if not (is_testing is not None and bool(is_testing)):
+        raise e
+    import builtins
+    # Temporarily inject types to satisfy simuleval_wrapper imports
+    # These need to remain in builtins for type annotations in the class definition
+    builtins.Action = type("Action", (), {})
+    builtins.SpeechSegment = type("SpeechSegment", (), {})
+
+
+class SimulEvalWrapper(SpeechProcessor):
+    """
+    Wrapper processor that calls the configured `simuleval_agent` implemented with
+    SimulEval>=1.1.0.
+    """
+    def __init__(self, config: SimpleNamespace):
+        super().__init__(config)
+        agent_class_name = getattr(config, "simuleval_agent")
+        agent_class = class_load(agent_class_name)
+        config.source_segment_size = config.speech_chunk_size * 1000
+        config.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.simuleval_agent = agent_class(config)
+        self.latency_unit = config.latency_unit
+        self.segment_type = self._segment_type_class()
+        self.emission_started = False
+        self.detokenizer = get_detokenizer(config)
+
+    def _segment_type_class(self):
+        return SEGMENT_TYPE_DICT[self.simuleval_agent.target_type]
+
+    @classmethod
+    def load_model(cls, config: SimpleNamespace):
+        """
+        In SimulEval, the model is loaded in the init of the `simuleval_agent` and,
+        therefore, a separate copy of the model is created for each client.
+        """
+        pass
+
+    def set_target_language(self, language: str) -> None:
+        if hasattr(self.simuleval_agent, "tgt_lang"):
+            self.simuleval_agent.tgt_lang = language
+        else:
+            logger.warning("Unable to set the target language for the SimulEval agent.")
+
+    def set_source_language(self, language: str) -> None:
+        pass
+
+    def tokens_to_string(self, tokens: List[str]) -> str:
+        return self.detokenizer(tokens)
+
+    def _process_action(self, action: Action) -> str:
+        """
+        Processes a SimulEval action and updates the target state accordingly.
+
+        If the action is of type READ, no output is produced and an empty prediction string is
+        returned. Otherwise (WRITE action), the method extracts the generated content from the
+        action, wraps it in the appropriate target segment type, and updates the SimulEval agent's
+        target state with this new segment.
+
+        Args:
+            action (Action): The current SimulEval action to process. It can either
+                request reading more input (READ) or writing an output (WRITE).
+
+        Returns:
+            str: The predicted output text if the action is WRITE, or an empty string
+                if the action is READ.
+        """
+        if action.is_read():
+            return ""
+
+        prediction = action.content
+        segment = self.segment_type(index=0, content=prediction, finished=action.finished)
+        self.simuleval_agent.states.update_target(segment)
+        return prediction
+
+    def _build_incremental_outputs(self, generated_text: str) -> IncrementalOutput:
+        """
+        Transform the prediction string from `Action.content` of SimulEval into the required
+        IncrementalOutput format. The token conversion follows the original SimulEval Instance
+        handling (https://github.com/facebookresearch/SimulEval/blob/
+        536de8253b82d805c9845440169a5010ff507357/simuleval/evaluator/instance.py#L209) with the
+        sole exception of not removing generated spaces in character-level languages.
+        Since SimulEval only supports incremental outputs, no tokens are deleted.
+        """
+        if self.latency_unit in ["word", "spm"]:
+            generated_tokens = generated_text.strip().split()
+        elif self.latency_unit == "char":
+            generated_tokens = list(generated_text.strip())
+        else:
+            raise NotImplementedError
+
+        if self.emission_started and self.latency_unit == "word":
+            generated_text = " " + generated_text
+
+        self.emission_started = True
+
+        return IncrementalOutput(
+            new_tokens=generated_tokens,
+            new_string=generated_text,
+            deleted_tokens=[],
+            deleted_string="",
+        )
+
+    def process_chunk(self, waveform: np.float32) -> IncrementalOutput:
+        source_segment = SpeechSegment(
+            index=0,
+            content=waveform.tolist(),
+            sample_rate=SAMPLE_RATE,
+            finished=False,
+            tgt_lang=self.simuleval_agent.tgt_lang,
+        )
+        self.simuleval_agent.states.update_source(source_segment)
+        action = self.simuleval_agent.policy(self.simuleval_agent.states)
+        prediction = self._process_action(action)
+        new_output = self._build_incremental_outputs(prediction)
+        return new_output
+
+    def end_of_stream(self) -> IncrementalOutput:
+        self.simuleval_agent.states.source_finished = True
+        action = self.simuleval_agent.policy(self.simuleval_agent.states)
+        prediction = self._process_action(action)
+        new_output = self._build_incremental_outputs(prediction)
+        return new_output
+
+    def clear(self) -> None:
+        """ In SimulEval, resetting is handled by the `simuleval_agent` itself. """
+        self.simuleval_agent.reset()
+        self.emission_started = False
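The token splitting performed by `_build_incremental_outputs` is self-contained and can be sketched directly (the sample text is invented for illustration):

    # Splitting rules per latency_unit, mirroring _build_incremental_outputs
    generated_text = "ciao a tutti"

    word_tokens = generated_text.strip().split()  # ['ciao', 'a', 'tutti'] (word/spm)
    char_tokens = list(generated_text.strip())    # ['c', 'i', 'a', 'o', ' ', 'a', ...]

    # Once emission has started, word-level output is prefixed with a space so that
    # consecutive WRITE strings concatenate into well-formed text:
    emission_started = True
    if emission_started:
        generated_text = " " + generated_text
    print(repr(generated_text))  # ' ciao a tutti'

Because SimulEval agents only ever append to their output, the wrapper always returns empty `deleted_tokens`, in contrast to the retranslation-based processors above.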