torchaudio-2.9.1-cp310-cp310-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchaudio/.dylibs/libc++.1.0.dylib +0 -0
- torchaudio/__init__.py +204 -0
- torchaudio/_extension/__init__.py +61 -0
- torchaudio/_extension/utils.py +133 -0
- torchaudio/_internal/__init__.py +10 -0
- torchaudio/_internal/module_utils.py +171 -0
- torchaudio/_torchcodec.py +340 -0
- torchaudio/compliance/__init__.py +5 -0
- torchaudio/compliance/kaldi.py +813 -0
- torchaudio/datasets/__init__.py +47 -0
- torchaudio/datasets/cmuarctic.py +157 -0
- torchaudio/datasets/cmudict.py +186 -0
- torchaudio/datasets/commonvoice.py +86 -0
- torchaudio/datasets/dr_vctk.py +121 -0
- torchaudio/datasets/fluentcommands.py +108 -0
- torchaudio/datasets/gtzan.py +1118 -0
- torchaudio/datasets/iemocap.py +147 -0
- torchaudio/datasets/librilight_limited.py +111 -0
- torchaudio/datasets/librimix.py +133 -0
- torchaudio/datasets/librispeech.py +174 -0
- torchaudio/datasets/librispeech_biasing.py +189 -0
- torchaudio/datasets/libritts.py +168 -0
- torchaudio/datasets/ljspeech.py +107 -0
- torchaudio/datasets/musdb_hq.py +139 -0
- torchaudio/datasets/quesst14.py +136 -0
- torchaudio/datasets/snips.py +157 -0
- torchaudio/datasets/speechcommands.py +183 -0
- torchaudio/datasets/tedlium.py +218 -0
- torchaudio/datasets/utils.py +54 -0
- torchaudio/datasets/vctk.py +143 -0
- torchaudio/datasets/voxceleb1.py +309 -0
- torchaudio/datasets/yesno.py +89 -0
- torchaudio/functional/__init__.py +130 -0
- torchaudio/functional/_alignment.py +128 -0
- torchaudio/functional/filtering.py +1685 -0
- torchaudio/functional/functional.py +2505 -0
- torchaudio/lib/__init__.py +0 -0
- torchaudio/lib/_torchaudio.so +0 -0
- torchaudio/lib/libtorchaudio.so +0 -0
- torchaudio/models/__init__.py +85 -0
- torchaudio/models/_hdemucs.py +1008 -0
- torchaudio/models/conformer.py +293 -0
- torchaudio/models/conv_tasnet.py +330 -0
- torchaudio/models/decoder/__init__.py +64 -0
- torchaudio/models/decoder/_ctc_decoder.py +568 -0
- torchaudio/models/decoder/_cuda_ctc_decoder.py +187 -0
- torchaudio/models/deepspeech.py +84 -0
- torchaudio/models/emformer.py +884 -0
- torchaudio/models/rnnt.py +816 -0
- torchaudio/models/rnnt_decoder.py +339 -0
- torchaudio/models/squim/__init__.py +11 -0
- torchaudio/models/squim/objective.py +326 -0
- torchaudio/models/squim/subjective.py +150 -0
- torchaudio/models/tacotron2.py +1046 -0
- torchaudio/models/wav2letter.py +72 -0
- torchaudio/models/wav2vec2/__init__.py +45 -0
- torchaudio/models/wav2vec2/components.py +1167 -0
- torchaudio/models/wav2vec2/model.py +1579 -0
- torchaudio/models/wav2vec2/utils/__init__.py +7 -0
- torchaudio/models/wav2vec2/utils/import_fairseq.py +213 -0
- torchaudio/models/wav2vec2/utils/import_huggingface.py +134 -0
- torchaudio/models/wav2vec2/wavlm_attention.py +214 -0
- torchaudio/models/wavernn.py +409 -0
- torchaudio/pipelines/__init__.py +102 -0
- torchaudio/pipelines/_source_separation_pipeline.py +109 -0
- torchaudio/pipelines/_squim_pipeline.py +156 -0
- torchaudio/pipelines/_tts/__init__.py +16 -0
- torchaudio/pipelines/_tts/impl.py +385 -0
- torchaudio/pipelines/_tts/interface.py +255 -0
- torchaudio/pipelines/_tts/utils.py +230 -0
- torchaudio/pipelines/_wav2vec2/__init__.py +0 -0
- torchaudio/pipelines/_wav2vec2/aligner.py +87 -0
- torchaudio/pipelines/_wav2vec2/impl.py +1699 -0
- torchaudio/pipelines/_wav2vec2/utils.py +346 -0
- torchaudio/pipelines/rnnt_pipeline.py +380 -0
- torchaudio/transforms/__init__.py +78 -0
- torchaudio/transforms/_multi_channel.py +467 -0
- torchaudio/transforms/_transforms.py +2138 -0
- torchaudio/utils/__init__.py +4 -0
- torchaudio/utils/download.py +89 -0
- torchaudio/version.py +2 -0
- torchaudio-2.9.1.dist-info/METADATA +133 -0
- torchaudio-2.9.1.dist-info/RECORD +86 -0
- torchaudio-2.9.1.dist-info/WHEEL +5 -0
- torchaudio-2.9.1.dist-info/licenses/LICENSE +25 -0
- torchaudio-2.9.1.dist-info/top_level.txt +1 -0
torchaudio/models/emformer.py
@@ -0,0 +1,884 @@
import math
from typing import List, Optional, Tuple

import torch


__all__ = ["Emformer"]


def _lengths_to_padding_mask(lengths: torch.Tensor) -> torch.Tensor:
    batch_size = lengths.shape[0]
    max_length = int(torch.max(lengths).item())
    padding_mask = torch.arange(max_length, device=lengths.device, dtype=lengths.dtype).expand(
        batch_size, max_length
    ) >= lengths.unsqueeze(1)
    return padding_mask
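# Illustration: for lengths = torch.tensor([2, 4]), the helper above returns
#   tensor([[False, False,  True,  True],
#           [False, False, False, False]])
# i.e. True marks positions at or beyond each sequence's valid length.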

def _gen_padding_mask(
    utterance: torch.Tensor,
    right_context: torch.Tensor,
    summary: torch.Tensor,
    lengths: torch.Tensor,
    mems: torch.Tensor,
    left_context_key: Optional[torch.Tensor] = None,
) -> Optional[torch.Tensor]:
    T = right_context.size(0) + utterance.size(0) + summary.size(0)
    B = right_context.size(1)
    if B == 1:
        padding_mask = None
    else:
        right_context_blocks_length = T - torch.max(lengths).int() - summary.size(0)
        left_context_blocks_length = left_context_key.size(0) if left_context_key is not None else 0
        klengths = lengths + mems.size(0) + right_context_blocks_length + left_context_blocks_length
        padding_mask = _lengths_to_padding_mask(lengths=klengths)
    return padding_mask


def _get_activation_module(activation: str) -> torch.nn.Module:
    if activation == "relu":
        return torch.nn.ReLU()
    elif activation == "gelu":
        return torch.nn.GELU()
    elif activation == "silu":
        return torch.nn.SiLU()
    else:
        raise ValueError(f"Unsupported activation {activation}")


def _get_weight_init_gains(weight_init_scale_strategy: Optional[str], num_layers: int) -> List[Optional[float]]:
    if weight_init_scale_strategy is None:
        return [None for _ in range(num_layers)]
    elif weight_init_scale_strategy == "depthwise":
        return [1.0 / math.sqrt(layer_idx + 1) for layer_idx in range(num_layers)]
    elif weight_init_scale_strategy == "constant":
        return [1.0 / math.sqrt(2) for layer_idx in range(num_layers)]
    else:
        raise ValueError(f"Unsupported weight_init_scale_strategy value {weight_init_scale_strategy}")


def _gen_attention_mask_block(
    col_widths: List[int], col_mask: List[bool], num_rows: int, device: torch.device
) -> torch.Tensor:
    if len(col_widths) != len(col_mask):
        raise ValueError("Length of col_widths must match that of col_mask")

    mask_block = [
        torch.ones(num_rows, col_width, device=device)
        if is_ones_col
        else torch.zeros(num_rows, col_width, device=device)
        for col_width, is_ones_col in zip(col_widths, col_mask)
    ]
    return torch.cat(mask_block, dim=1)


class _EmformerAttention(torch.nn.Module):
    r"""Emformer layer attention module.

    Args:
        input_dim (int): input dimension.
        num_heads (int): number of attention heads in each Emformer layer.
        dropout (float, optional): dropout probability. (Default: 0.0)
        weight_init_gain (float or None, optional): scale factor to apply when initializing
            attention module parameters. (Default: ``None``)
        tanh_on_mem (bool, optional): if ``True``, applies tanh to memory elements. (Default: ``False``)
        negative_inf (float, optional): value to use for negative infinity in attention weights. (Default: -1e8)
    """

    def __init__(
        self,
        input_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        weight_init_gain: Optional[float] = None,
        tanh_on_mem: bool = False,
        negative_inf: float = -1e8,
    ):
        super().__init__()

        if input_dim % num_heads != 0:
            raise ValueError(f"input_dim ({input_dim}) is not a multiple of num_heads ({num_heads}).")

        self.input_dim = input_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.tanh_on_mem = tanh_on_mem
        self.negative_inf = negative_inf

        self.scaling = (self.input_dim // self.num_heads) ** -0.5

        self.emb_to_key_value = torch.nn.Linear(input_dim, 2 * input_dim, bias=True)
        self.emb_to_query = torch.nn.Linear(input_dim, input_dim, bias=True)
        self.out_proj = torch.nn.Linear(input_dim, input_dim, bias=True)

        if weight_init_gain:
            torch.nn.init.xavier_uniform_(self.emb_to_key_value.weight, gain=weight_init_gain)
            torch.nn.init.xavier_uniform_(self.emb_to_query.weight, gain=weight_init_gain)

    def _gen_key_value(self, input: torch.Tensor, mems: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        T, _, _ = input.shape
        summary_length = mems.size(0) + 1
        right_ctx_utterance_block = input[: T - summary_length]
        mems_right_ctx_utterance_block = torch.cat([mems, right_ctx_utterance_block])
        key, value = self.emb_to_key_value(mems_right_ctx_utterance_block).chunk(chunks=2, dim=2)
        return key, value

    def _gen_attention_probs(
        self,
        attention_weights: torch.Tensor,
        attention_mask: torch.Tensor,
        padding_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:
        attention_weights_float = attention_weights.float()
        attention_weights_float = attention_weights_float.masked_fill(attention_mask.unsqueeze(0), self.negative_inf)
        T = attention_weights.size(1)
        B = attention_weights.size(0) // self.num_heads
        if padding_mask is not None:
            attention_weights_float = attention_weights_float.view(B, self.num_heads, T, -1)
            attention_weights_float = attention_weights_float.masked_fill(
                padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), self.negative_inf
            )
            attention_weights_float = attention_weights_float.view(B * self.num_heads, T, -1)
        attention_probs = torch.nn.functional.softmax(attention_weights_float, dim=-1).type_as(attention_weights)
        return torch.nn.functional.dropout(attention_probs, p=float(self.dropout), training=self.training)

    def _forward_impl(
        self,
        utterance: torch.Tensor,
        lengths: torch.Tensor,
        right_context: torch.Tensor,
        summary: torch.Tensor,
        mems: torch.Tensor,
        attention_mask: torch.Tensor,
        left_context_key: Optional[torch.Tensor] = None,
        left_context_val: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        B = utterance.size(1)
        T = right_context.size(0) + utterance.size(0) + summary.size(0)

        # Compute query with [right context, utterance, summary].
        query = self.emb_to_query(torch.cat([right_context, utterance, summary]))

        # Compute key and value with [mems, right context, utterance].
        key, value = self.emb_to_key_value(torch.cat([mems, right_context, utterance])).chunk(chunks=2, dim=2)

        if left_context_key is not None and left_context_val is not None:
            right_context_blocks_length = T - torch.max(lengths).int() - summary.size(0)
            key = torch.cat(
                [
                    key[: mems.size(0) + right_context_blocks_length],
                    left_context_key,
                    key[mems.size(0) + right_context_blocks_length :],
                ],
            )
            value = torch.cat(
                [
                    value[: mems.size(0) + right_context_blocks_length],
                    left_context_val,
                    value[mems.size(0) + right_context_blocks_length :],
                ],
            )

        # Compute attention weights from query, key, and value.
        reshaped_query, reshaped_key, reshaped_value = [
            tensor.contiguous().view(-1, B * self.num_heads, self.input_dim // self.num_heads).transpose(0, 1)
            for tensor in [query, key, value]
        ]
        attention_weights = torch.bmm(reshaped_query * self.scaling, reshaped_key.transpose(1, 2))

        # Compute padding mask.
        padding_mask = _gen_padding_mask(utterance, right_context, summary, lengths, mems, left_context_key)

        # Compute attention probabilities.
        attention_probs = self._gen_attention_probs(attention_weights, attention_mask, padding_mask)

        # Compute attention.
        attention = torch.bmm(attention_probs, reshaped_value)
        if attention.shape != (
            B * self.num_heads,
            T,
            self.input_dim // self.num_heads,
        ):
            raise AssertionError("Computed attention has incorrect dimensions")
        attention = attention.transpose(0, 1).contiguous().view(T, B, self.input_dim)

        # Apply output projection.
        output_right_context_mems = self.out_proj(attention)

        summary_length = summary.size(0)
        output_right_context = output_right_context_mems[: T - summary_length]
        output_mems = output_right_context_mems[T - summary_length :]
        if self.tanh_on_mem:
            output_mems = torch.tanh(output_mems)
        else:
            output_mems = torch.clamp(output_mems, min=-10, max=10)

        return output_right_context, output_mems, key, value

    def forward(
        self,
        utterance: torch.Tensor,
        lengths: torch.Tensor,
        right_context: torch.Tensor,
        summary: torch.Tensor,
        mems: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Forward pass for training.

        B: batch size;
        D: feature dimension of each frame;
        T: number of utterance frames;
        R: number of right context frames;
        S: number of summary elements;
        M: number of memory elements.

        Args:
            utterance (torch.Tensor): utterance frames, with shape `(T, B, D)`.
            lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``utterance``.
            right_context (torch.Tensor): right context frames, with shape `(R, B, D)`.
            summary (torch.Tensor): summary elements, with shape `(S, B, D)`.
            mems (torch.Tensor): memory elements, with shape `(M, B, D)`.
            attention_mask (torch.Tensor): attention mask for underlying attention module.

        Returns:
            (Tensor, Tensor):
                Tensor
                    output frames corresponding to utterance and right_context, with shape `(T + R, B, D)`.
                Tensor
                    updated memory elements, with shape `(M, B, D)`.
        """
        output, output_mems, _, _ = self._forward_impl(utterance, lengths, right_context, summary, mems, attention_mask)
        return output, output_mems[:-1]

    @torch.jit.export
    def infer(
        self,
        utterance: torch.Tensor,
        lengths: torch.Tensor,
        right_context: torch.Tensor,
        summary: torch.Tensor,
        mems: torch.Tensor,
        left_context_key: torch.Tensor,
        left_context_val: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""Forward pass for inference.

        B: batch size;
        D: feature dimension of each frame;
        T: number of utterance frames;
        R: number of right context frames;
        S: number of summary elements;
        M: number of memory elements.

        Args:
            utterance (torch.Tensor): utterance frames, with shape `(T, B, D)`.
            lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``utterance``.
            right_context (torch.Tensor): right context frames, with shape `(R, B, D)`.
            summary (torch.Tensor): summary elements, with shape `(S, B, D)`.
            mems (torch.Tensor): memory elements, with shape `(M, B, D)`.
            left_context_key (torch.Tensor): left context attention key computed from preceding invocation.
            left_context_val (torch.Tensor): left context attention value computed from preceding invocation.

        Returns:
            (Tensor, Tensor, Tensor, and Tensor):
                Tensor
                    output frames corresponding to utterance and right_context, with shape `(T + R, B, D)`.
                Tensor
                    updated memory elements, with shape `(M, B, D)`.
                Tensor
                    attention key computed for left context and utterance.
                Tensor
                    attention value computed for left context and utterance.
        """
        query_dim = right_context.size(0) + utterance.size(0) + summary.size(0)
        key_dim = right_context.size(0) + utterance.size(0) + mems.size(0) + left_context_key.size(0)
        attention_mask = torch.zeros(query_dim, key_dim).to(dtype=torch.bool, device=utterance.device)
        attention_mask[-1, : mems.size(0)] = True
        output, output_mems, key, value = self._forward_impl(
            utterance,
            lengths,
            right_context,
            summary,
            mems,
            attention_mask,
            left_context_key=left_context_key,
            left_context_val=left_context_val,
        )
        return (
            output,
            output_mems,
            key[mems.size(0) + right_context.size(0) :],
            value[mems.size(0) + right_context.size(0) :],
        )


class _EmformerLayer(torch.nn.Module):
    r"""Emformer layer that constitutes Emformer.

    Args:
        input_dim (int): input dimension.
        num_heads (int): number of attention heads.
        ffn_dim: (int): hidden layer dimension of feedforward network.
        segment_length (int): length of each input segment.
        dropout (float, optional): dropout probability. (Default: 0.0)
        activation (str, optional): activation function to use in feedforward network.
            Must be one of ("relu", "gelu", "silu"). (Default: "relu")
        left_context_length (int, optional): length of left context. (Default: 0)
        max_memory_size (int, optional): maximum number of memory elements to use. (Default: 0)
        weight_init_gain (float or None, optional): scale factor to apply when initializing
            attention module parameters. (Default: ``None``)
        tanh_on_mem (bool, optional): if ``True``, applies tanh to memory elements. (Default: ``False``)
        negative_inf (float, optional): value to use for negative infinity in attention weights. (Default: -1e8)
    """

    def __init__(
        self,
        input_dim: int,
        num_heads: int,
        ffn_dim: int,
        segment_length: int,
        dropout: float = 0.0,
        activation: str = "relu",
        left_context_length: int = 0,
        max_memory_size: int = 0,
        weight_init_gain: Optional[float] = None,
        tanh_on_mem: bool = False,
        negative_inf: float = -1e8,
    ):
        super().__init__()

        self.attention = _EmformerAttention(
            input_dim=input_dim,
            num_heads=num_heads,
            dropout=dropout,
            weight_init_gain=weight_init_gain,
            tanh_on_mem=tanh_on_mem,
            negative_inf=negative_inf,
        )
        self.dropout = torch.nn.Dropout(dropout)
        self.memory_op = torch.nn.AvgPool1d(kernel_size=segment_length, stride=segment_length, ceil_mode=True)

        activation_module = _get_activation_module(activation)
        self.pos_ff = torch.nn.Sequential(
            torch.nn.LayerNorm(input_dim),
            torch.nn.Linear(input_dim, ffn_dim),
            activation_module,
            torch.nn.Dropout(dropout),
            torch.nn.Linear(ffn_dim, input_dim),
            torch.nn.Dropout(dropout),
        )
        self.layer_norm_input = torch.nn.LayerNorm(input_dim)
        self.layer_norm_output = torch.nn.LayerNorm(input_dim)

        self.left_context_length = left_context_length
        self.segment_length = segment_length
        self.max_memory_size = max_memory_size
        self.input_dim = input_dim

        self.use_mem = max_memory_size > 0

    def _init_state(self, batch_size: int, device: Optional[torch.device]) -> List[torch.Tensor]:
        empty_memory = torch.zeros(self.max_memory_size, batch_size, self.input_dim, device=device)
        left_context_key = torch.zeros(self.left_context_length, batch_size, self.input_dim, device=device)
        left_context_val = torch.zeros(self.left_context_length, batch_size, self.input_dim, device=device)
        past_length = torch.zeros(1, batch_size, dtype=torch.int32, device=device)
        return [empty_memory, left_context_key, left_context_val, past_length]

    def _unpack_state(self, state: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        past_length = state[3][0][0].item()
        past_left_context_length = min(self.left_context_length, past_length)
        past_mem_length = min(self.max_memory_size, math.ceil(past_length / self.segment_length))
        pre_mems = state[0][self.max_memory_size - past_mem_length :]
        lc_key = state[1][self.left_context_length - past_left_context_length :]
        lc_val = state[2][self.left_context_length - past_left_context_length :]
        return pre_mems, lc_key, lc_val

    def _pack_state(
        self,
        next_k: torch.Tensor,
        next_v: torch.Tensor,
        update_length: int,
        mems: torch.Tensor,
        state: List[torch.Tensor],
    ) -> List[torch.Tensor]:
        new_k = torch.cat([state[1], next_k])
        new_v = torch.cat([state[2], next_v])
        state[0] = torch.cat([state[0], mems])[-self.max_memory_size :]
        state[1] = new_k[new_k.shape[0] - self.left_context_length :]
        state[2] = new_v[new_v.shape[0] - self.left_context_length :]
        state[3] = state[3] + update_length
        return state

    def _process_attention_output(
        self,
        rc_output: torch.Tensor,
        utterance: torch.Tensor,
        right_context: torch.Tensor,
    ) -> torch.Tensor:
        result = self.dropout(rc_output) + torch.cat([right_context, utterance])
        result = self.pos_ff(result) + result
        result = self.layer_norm_output(result)
        return result

    def _apply_pre_attention_layer_norm(
        self, utterance: torch.Tensor, right_context: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        layer_norm_input = self.layer_norm_input(torch.cat([right_context, utterance]))
        return (
            layer_norm_input[right_context.size(0) :],
            layer_norm_input[: right_context.size(0)],
        )

    def _apply_post_attention_ffn(
        self, rc_output: torch.Tensor, utterance: torch.Tensor, right_context: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        rc_output = self._process_attention_output(rc_output, utterance, right_context)
        return rc_output[right_context.size(0) :], rc_output[: right_context.size(0)]

    def _apply_attention_forward(
        self,
        utterance: torch.Tensor,
        lengths: torch.Tensor,
        right_context: torch.Tensor,
        mems: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if attention_mask is None:
            raise ValueError("attention_mask must be not None when for_inference is False")

        if self.use_mem:
            summary = self.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)
        else:
            summary = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device)
        rc_output, next_m = self.attention(
            utterance=utterance,
            lengths=lengths,
            right_context=right_context,
            summary=summary,
            mems=mems,
            attention_mask=attention_mask,
        )
        return rc_output, next_m

    def _apply_attention_infer(
        self,
        utterance: torch.Tensor,
        lengths: torch.Tensor,
        right_context: torch.Tensor,
        mems: torch.Tensor,
        state: Optional[List[torch.Tensor]],
    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
        if state is None:
            state = self._init_state(utterance.size(1), device=utterance.device)
        pre_mems, lc_key, lc_val = self._unpack_state(state)
        if self.use_mem:
            summary = self.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)
            summary = summary[:1]
        else:
            summary = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device)
        rc_output, next_m, next_k, next_v = self.attention.infer(
            utterance=utterance,
            lengths=lengths,
            right_context=right_context,
            summary=summary,
            mems=pre_mems,
            left_context_key=lc_key,
            left_context_val=lc_val,
        )
        state = self._pack_state(next_k, next_v, utterance.size(0), mems, state)
        return rc_output, next_m, state

    def forward(
        self,
        utterance: torch.Tensor,
        lengths: torch.Tensor,
        right_context: torch.Tensor,
        mems: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""Forward pass for training.

        B: batch size;
        D: feature dimension of each frame;
        T: number of utterance frames;
        R: number of right context frames;
        M: number of memory elements.

        Args:
            utterance (torch.Tensor): utterance frames, with shape `(T, B, D)`.
            lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``utterance``.
            right_context (torch.Tensor): right context frames, with shape `(R, B, D)`.
            mems (torch.Tensor): memory elements, with shape `(M, B, D)`.
            attention_mask (torch.Tensor): attention mask for underlying attention module.

        Returns:
            (Tensor, Tensor, Tensor):
                Tensor
                    encoded utterance frames, with shape `(T, B, D)`.
                Tensor
                    updated right context frames, with shape `(R, B, D)`.
                Tensor
                    updated memory elements, with shape `(M, B, D)`.
        """
        (
            layer_norm_utterance,
            layer_norm_right_context,
        ) = self._apply_pre_attention_layer_norm(utterance, right_context)
        rc_output, output_mems = self._apply_attention_forward(
            layer_norm_utterance,
            lengths,
            layer_norm_right_context,
            mems,
            attention_mask,
        )
        output_utterance, output_right_context = self._apply_post_attention_ffn(rc_output, utterance, right_context)
        return output_utterance, output_right_context, output_mems

    @torch.jit.export
    def infer(
        self,
        utterance: torch.Tensor,
        lengths: torch.Tensor,
        right_context: torch.Tensor,
        state: Optional[List[torch.Tensor]],
        mems: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor], torch.Tensor]:
        r"""Forward pass for inference.

        B: batch size;
        D: feature dimension of each frame;
        T: number of utterance frames;
        R: number of right context frames;
        M: number of memory elements.

        Args:
            utterance (torch.Tensor): utterance frames, with shape `(T, B, D)`.
            lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``utterance``.
            right_context (torch.Tensor): right context frames, with shape `(R, B, D)`.
            state (List[torch.Tensor] or None): list of tensors representing layer internal state
                generated in preceding invocation of ``infer``.
            mems (torch.Tensor): memory elements, with shape `(M, B, D)`.

        Returns:
            (Tensor, Tensor, List[torch.Tensor], Tensor):
                Tensor
                    encoded utterance frames, with shape `(T, B, D)`.
                Tensor
                    updated right context frames, with shape `(R, B, D)`.
                List[Tensor]
                    list of tensors representing layer internal state
                    generated in current invocation of ``infer``.
                Tensor
                    updated memory elements, with shape `(M, B, D)`.
        """
        (
            layer_norm_utterance,
            layer_norm_right_context,
        ) = self._apply_pre_attention_layer_norm(utterance, right_context)
        rc_output, output_mems, output_state = self._apply_attention_infer(
            layer_norm_utterance, lengths, layer_norm_right_context, mems, state
        )
        output_utterance, output_right_context = self._apply_post_attention_ffn(rc_output, utterance, right_context)
        return output_utterance, output_right_context, output_state, output_mems

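# Note: the per-layer `state` returned by _EmformerLayer.infer is a list of four tensors:
#   state[0] - memory bank, shape (max_memory_size, B, input_dim)
#   state[1] - left-context keys, shape (left_context_length, B, input_dim)
#   state[2] - left-context values, shape (left_context_length, B, input_dim)
#   state[3] - number of utterance frames processed so far, shape (1, B), dtype torch.int32
# _unpack_state slices out the portion that is valid so far, and _pack_state shifts the newest
# memory, key, and value entries into these fixed-size buffers after each chunk.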

class _EmformerImpl(torch.nn.Module):
    def __init__(
        self,
        emformer_layers: torch.nn.ModuleList,
        segment_length: int,
        left_context_length: int = 0,
        right_context_length: int = 0,
        max_memory_size: int = 0,
    ):
        super().__init__()

        self.use_mem = max_memory_size > 0
        self.memory_op = torch.nn.AvgPool1d(
            kernel_size=segment_length,
            stride=segment_length,
            ceil_mode=True,
        )
        self.emformer_layers = emformer_layers
        self.left_context_length = left_context_length
        self.right_context_length = right_context_length
        self.segment_length = segment_length
        self.max_memory_size = max_memory_size

    def _gen_right_context(self, input: torch.Tensor) -> torch.Tensor:
        T = input.shape[0]
        num_segs = math.ceil((T - self.right_context_length) / self.segment_length)
        right_context_blocks = []
        for seg_idx in range(num_segs - 1):
            start = (seg_idx + 1) * self.segment_length
            end = start + self.right_context_length
            right_context_blocks.append(input[start:end])
        right_context_blocks.append(input[T - self.right_context_length :])
        return torch.cat(right_context_blocks)

    def _gen_attention_mask_col_widths(self, seg_idx: int, utterance_length: int) -> List[int]:
        num_segs = math.ceil(utterance_length / self.segment_length)
        rc = self.right_context_length
        lc = self.left_context_length
        rc_start = seg_idx * rc
        rc_end = rc_start + rc
        seg_start = max(seg_idx * self.segment_length - lc, 0)
        seg_end = min((seg_idx + 1) * self.segment_length, utterance_length)
        rc_length = self.right_context_length * num_segs

        if self.use_mem:
            m_start = max(seg_idx - self.max_memory_size, 0)
            mem_length = num_segs - 1
            col_widths = [
                m_start,  # before memory
                seg_idx - m_start,  # memory
                mem_length - seg_idx,  # after memory
                rc_start,  # before right context
                rc,  # right context
                rc_length - rc_end,  # after right context
                seg_start,  # before query segment
                seg_end - seg_start,  # query segment
                utterance_length - seg_end,  # after query segment
            ]
        else:
            col_widths = [
                rc_start,  # before right context
                rc,  # right context
                rc_length - rc_end,  # after right context
                seg_start,  # before query segment
                seg_end - seg_start,  # query segment
                utterance_length - seg_end,  # after query segment
            ]

        return col_widths

    def _gen_attention_mask(self, input: torch.Tensor) -> torch.Tensor:
        utterance_length = input.size(0)
        num_segs = math.ceil(utterance_length / self.segment_length)

        rc_mask = []
        query_mask = []
        summary_mask = []

        if self.use_mem:
            num_cols = 9
            # memory, right context, query segment
            rc_q_cols_mask = [idx in [1, 4, 7] for idx in range(num_cols)]
            # right context, query segment
            s_cols_mask = [idx in [4, 7] for idx in range(num_cols)]
            masks_to_concat = [rc_mask, query_mask, summary_mask]
        else:
            num_cols = 6
            # right context, query segment
            rc_q_cols_mask = [idx in [1, 4] for idx in range(num_cols)]
            s_cols_mask = None
            masks_to_concat = [rc_mask, query_mask]

        for seg_idx in range(num_segs):
            col_widths = self._gen_attention_mask_col_widths(seg_idx, utterance_length)

            rc_mask_block = _gen_attention_mask_block(
                col_widths, rc_q_cols_mask, self.right_context_length, input.device
            )
            rc_mask.append(rc_mask_block)

            query_mask_block = _gen_attention_mask_block(
                col_widths,
                rc_q_cols_mask,
                min(
                    self.segment_length,
                    utterance_length - seg_idx * self.segment_length,
                ),
                input.device,
            )
            query_mask.append(query_mask_block)

            if s_cols_mask is not None:
                summary_mask_block = _gen_attention_mask_block(col_widths, s_cols_mask, 1, input.device)
                summary_mask.append(summary_mask_block)

        attention_mask = (1 - torch.cat([torch.cat(mask) for mask in masks_to_concat])).to(torch.bool)
        return attention_mask

    def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Forward pass for training and non-streaming inference.

        B: batch size;
        T: max number of input frames in batch;
        D: feature dimension of each frame.

        Args:
            input (torch.Tensor): utterance frames right-padded with right context frames, with
                shape `(B, T + right_context_length, D)`.
            lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid utterance frames for i-th batch element in ``input``.

        Returns:
            (Tensor, Tensor):
                Tensor
                    output frames, with shape `(B, T, D)`.
                Tensor
                    output lengths, with shape `(B,)` and i-th element representing
                    number of valid frames for i-th batch element in output frames.
        """
        input = input.permute(1, 0, 2)
        right_context = self._gen_right_context(input)
        utterance = input[: input.size(0) - self.right_context_length]
        attention_mask = self._gen_attention_mask(utterance)
        mems = (
            self.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[:-1]
            if self.use_mem
            else torch.empty(0).to(dtype=input.dtype, device=input.device)
        )
        output = utterance
        for layer in self.emformer_layers:
            output, right_context, mems = layer(output, lengths, right_context, mems, attention_mask)
        return output.permute(1, 0, 2), lengths

    @torch.jit.export
    def infer(
        self,
        input: torch.Tensor,
        lengths: torch.Tensor,
        states: Optional[List[List[torch.Tensor]]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
        r"""Forward pass for streaming inference.

        B: batch size;
        D: feature dimension of each frame.

        Args:
            input (torch.Tensor): utterance frames right-padded with right context frames, with
                shape `(B, segment_length + right_context_length, D)`.
            lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``input``.
            states (List[List[torch.Tensor]] or None, optional): list of lists of tensors
                representing internal state generated in preceding invocation of ``infer``. (Default: ``None``)

        Returns:
            (Tensor, Tensor, List[List[Tensor]]):
                Tensor
                    output frames, with shape `(B, segment_length, D)`.
                Tensor
                    output lengths, with shape `(B,)` and i-th element representing
                    number of valid frames for i-th batch element in output frames.
                List[List[Tensor]]
                    output states; list of lists of tensors representing internal state
                    generated in current invocation of ``infer``.
        """
        if input.size(1) != self.segment_length + self.right_context_length:
            raise ValueError(
                "Per configured segment_length and right_context_length"
                f", expected size of {self.segment_length + self.right_context_length} for dimension 1 of input"
                f", but got {input.size(1)}."
            )
        input = input.permute(1, 0, 2)
        right_context_start_idx = input.size(0) - self.right_context_length
        right_context = input[right_context_start_idx:]
        utterance = input[:right_context_start_idx]
        output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
        mems = (
            self.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)
            if self.use_mem
            else torch.empty(0).to(dtype=input.dtype, device=input.device)
        )
        output = utterance
        output_states: List[List[torch.Tensor]] = []
        for layer_idx, layer in enumerate(self.emformer_layers):
            output, right_context, output_state, mems = layer.infer(
                output,
                output_lengths,
                right_context,
                None if states is None else states[layer_idx],
                mems,
            )
            output_states.append(output_state)

        return output.permute(1, 0, 2), output_lengths, output_states


class Emformer(_EmformerImpl):
    r"""Emformer architecture introduced in
    *Emformer: Efficient Memory Transformer Based Acoustic Model for Low Latency Streaming Speech Recognition*
    :cite:`shi2021emformer`.

    See Also:
        * :func:`~torchaudio.models.emformer_rnnt_model`,
          :func:`~torchaudio.models.emformer_rnnt_base`: factory functions.
        * :class:`torchaudio.pipelines.RNNTBundle`: ASR pipelines with pretrained model.

    Args:
        input_dim (int): input dimension.
        num_heads (int): number of attention heads in each Emformer layer.
        ffn_dim (int): hidden layer dimension of each Emformer layer's feedforward network.
        num_layers (int): number of Emformer layers to instantiate.
        segment_length (int): length of each input segment.
        dropout (float, optional): dropout probability. (Default: 0.0)
        activation (str, optional): activation function to use in each Emformer layer's
            feedforward network. Must be one of ("relu", "gelu", "silu"). (Default: "relu")
        left_context_length (int, optional): length of left context. (Default: 0)
        right_context_length (int, optional): length of right context. (Default: 0)
        max_memory_size (int, optional): maximum number of memory elements to use. (Default: 0)
        weight_init_scale_strategy (str or None, optional): per-layer weight initialization scaling
            strategy. Must be one of ("depthwise", "constant", ``None``). (Default: "depthwise")
        tanh_on_mem (bool, optional): if ``True``, applies tanh to memory elements. (Default: ``False``)
        negative_inf (float, optional): value to use for negative infinity in attention weights. (Default: -1e8)

    Examples:
        >>> emformer = Emformer(512, 8, 2048, 20, 4, right_context_length=1)
        >>> input = torch.rand(128, 400, 512)  # batch, num_frames, feature_dim
        >>> lengths = torch.randint(1, 200, (128,))  # batch
        >>> output, lengths = emformer(input, lengths)
        >>> input = torch.rand(128, 5, 512)
        >>> lengths = torch.ones(128) * 5
        >>> output, lengths, states = emformer.infer(input, lengths, None)
    """

    def __init__(
        self,
        input_dim: int,
        num_heads: int,
        ffn_dim: int,
        num_layers: int,
        segment_length: int,
        dropout: float = 0.0,
        activation: str = "relu",
        left_context_length: int = 0,
        right_context_length: int = 0,
        max_memory_size: int = 0,
        weight_init_scale_strategy: Optional[str] = "depthwise",
        tanh_on_mem: bool = False,
        negative_inf: float = -1e8,
    ):
        weight_init_gains = _get_weight_init_gains(weight_init_scale_strategy, num_layers)
        emformer_layers = torch.nn.ModuleList(
            [
                _EmformerLayer(
                    input_dim,
                    num_heads,
                    ffn_dim,
                    segment_length,
                    dropout=dropout,
                    activation=activation,
                    left_context_length=left_context_length,
                    max_memory_size=max_memory_size,
                    weight_init_gain=weight_init_gains[layer_idx],
                    tanh_on_mem=tanh_on_mem,
                    negative_inf=negative_inf,
                )
                for layer_idx in range(num_layers)
            ]
        )
        super().__init__(
            emformer_layers,
            segment_length,
            left_context_length=left_context_length,
            right_context_length=right_context_length,
            max_memory_size=max_memory_size,
        )
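As a usage sketch complementing the class docstring example above, the snippet below runs streaming inference chunk by chunk, carrying the per-layer `states` returned by each `infer` call into the next one. It assumes the usual public import path `torchaudio.models.Emformer`; the constructor arguments mirror the docstring example, and the random input and the fixed-length loop are purely illustrative, not code from the package.

import torch
from torchaudio.models import Emformer

emformer = Emformer(512, 8, 2048, 20, 4, right_context_length=1)
emformer.eval()

segment_length, right_context_length, feature_dim = 4, 1, 512
batch_size = 1

states = None  # per-layer internal state, threaded through successive calls
for _ in range(10):
    # Each chunk carries one segment plus the look-ahead (right-context) frames.
    chunk = torch.rand(batch_size, segment_length + right_context_length, feature_dim)
    lengths = torch.full((batch_size,), segment_length + right_context_length)
    with torch.no_grad():
        output, output_lengths, states = emformer.infer(chunk, lengths, states)
    # output has shape (batch_size, segment_length, feature_dim)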