torchaudio-2.8.0-cp312-cp312-win_amd64.whl → torchaudio-2.9.0-cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of torchaudio might be problematic.
- torchaudio/__init__.py +179 -39
- torchaudio/_extension/__init__.py +1 -14
- torchaudio/_extension/utils.py +0 -47
- torchaudio/_internal/module_utils.py +12 -3
- torchaudio/_torchcodec.py +73 -85
- torchaudio/datasets/cmuarctic.py +1 -1
- torchaudio/datasets/utils.py +1 -1
- torchaudio/functional/__init__.py +0 -2
- torchaudio/functional/_alignment.py +1 -1
- torchaudio/functional/filtering.py +70 -55
- torchaudio/functional/functional.py +26 -60
- torchaudio/lib/_torchaudio.pyd +0 -0
- torchaudio/lib/libtorchaudio.pyd +0 -0
- torchaudio/models/decoder/__init__.py +14 -2
- torchaudio/models/decoder/_ctc_decoder.py +6 -6
- torchaudio/models/decoder/_cuda_ctc_decoder.py +1 -1
- torchaudio/models/squim/objective.py +2 -2
- torchaudio/pipelines/_source_separation_pipeline.py +1 -1
- torchaudio/pipelines/_squim_pipeline.py +2 -2
- torchaudio/pipelines/_tts/utils.py +1 -1
- torchaudio/pipelines/rnnt_pipeline.py +4 -4
- torchaudio/transforms/__init__.py +1 -0
- torchaudio/transforms/_transforms.py +2 -2
- torchaudio/utils/__init__.py +2 -9
- torchaudio/utils/download.py +1 -3
- torchaudio/version.py +2 -2
- {torchaudio-2.8.0.dist-info → torchaudio-2.9.0.dist-info}/METADATA +8 -11
- torchaudio-2.9.0.dist-info/RECORD +85 -0
- {torchaudio-2.8.0.dist-info → torchaudio-2.9.0.dist-info}/top_level.txt +0 -1
- torchaudio/_backend/__init__.py +0 -61
- torchaudio/_backend/backend.py +0 -53
- torchaudio/_backend/common.py +0 -52
- torchaudio/_backend/ffmpeg.py +0 -334
- torchaudio/_backend/soundfile.py +0 -54
- torchaudio/_backend/soundfile_backend.py +0 -457
- torchaudio/_backend/sox.py +0 -91
- torchaudio/_backend/utils.py +0 -350
- torchaudio/backend/__init__.py +0 -8
- torchaudio/backend/_no_backend.py +0 -25
- torchaudio/backend/_sox_io_backend.py +0 -294
- torchaudio/backend/common.py +0 -13
- torchaudio/backend/no_backend.py +0 -14
- torchaudio/backend/soundfile_backend.py +0 -14
- torchaudio/backend/sox_io_backend.py +0 -14
- torchaudio/io/__init__.py +0 -20
- torchaudio/io/_effector.py +0 -347
- torchaudio/io/_playback.py +0 -72
- torchaudio/kaldi_io.py +0 -150
- torchaudio/prototype/__init__.py +0 -0
- torchaudio/prototype/datasets/__init__.py +0 -4
- torchaudio/prototype/datasets/musan.py +0 -68
- torchaudio/prototype/functional/__init__.py +0 -26
- torchaudio/prototype/functional/_dsp.py +0 -441
- torchaudio/prototype/functional/_rir.py +0 -382
- torchaudio/prototype/functional/functional.py +0 -193
- torchaudio/prototype/models/__init__.py +0 -39
- torchaudio/prototype/models/_conformer_wav2vec2.py +0 -801
- torchaudio/prototype/models/_emformer_hubert.py +0 -337
- torchaudio/prototype/models/conv_emformer.py +0 -529
- torchaudio/prototype/models/hifi_gan.py +0 -342
- torchaudio/prototype/models/rnnt.py +0 -717
- torchaudio/prototype/models/rnnt_decoder.py +0 -402
- torchaudio/prototype/pipelines/__init__.py +0 -21
- torchaudio/prototype/pipelines/_vggish/__init__.py +0 -7
- torchaudio/prototype/pipelines/_vggish/_vggish_impl.py +0 -236
- torchaudio/prototype/pipelines/_vggish/_vggish_pipeline.py +0 -83
- torchaudio/prototype/pipelines/hifigan_pipeline.py +0 -233
- torchaudio/prototype/pipelines/rnnt_pipeline.py +0 -58
- torchaudio/prototype/transforms/__init__.py +0 -9
- torchaudio/prototype/transforms/_transforms.py +0 -461
- torchaudio/sox_effects/__init__.py +0 -10
- torchaudio/sox_effects/sox_effects.py +0 -275
- torchaudio/utils/ffmpeg_utils.py +0 -11
- torchaudio/utils/sox_utils.py +0 -118
- torchaudio-2.8.0.dist-info/RECORD +0 -145
- torio/__init__.py +0 -8
- torio/_extension/__init__.py +0 -13
- torio/_extension/utils.py +0 -147
- torio/io/__init__.py +0 -9
- torio/io/_streaming_media_decoder.py +0 -977
- torio/io/_streaming_media_encoder.py +0 -502
- torio/lib/__init__.py +0 -0
- torio/lib/_torio_ffmpeg4.pyd +0 -0
- torio/lib/_torio_ffmpeg5.pyd +0 -0
- torio/lib/_torio_ffmpeg6.pyd +0 -0
- torio/lib/libtorio_ffmpeg4.pyd +0 -0
- torio/lib/libtorio_ffmpeg5.pyd +0 -0
- torio/lib/libtorio_ffmpeg6.pyd +0 -0
- torio/utils/__init__.py +0 -4
- torio/utils/ffmpeg_utils.py +0 -275
- {torchaudio-2.8.0.dist-info → torchaudio-2.9.0.dist-info}/WHEEL +0 -0
- {torchaudio-2.8.0.dist-info → torchaudio-2.9.0.dist-info}/licenses/LICENSE +0 -0
torchaudio/prototype/models/conv_emformer.py
@@ -1,529 +0,0 @@
-import math
-from typing import List, Optional, Tuple
-
-import torch
-from torchaudio.models.emformer import _EmformerAttention, _EmformerImpl, _get_weight_init_gains
-from torchaudio._internal.module_utils import dropping_class_support, dropping_support
-
-
-
-def _get_activation_module(activation: str) -> torch.nn.Module:
-    if activation == "relu":
-        return torch.nn.ReLU()
-    elif activation == "gelu":
-        return torch.nn.GELU()
-    elif activation == "silu":
-        return torch.nn.SiLU()
-    else:
-        raise ValueError(f"Unsupported activation {activation}")
-
-
-class _ResidualContainer(torch.nn.Module):
-    def __init__(self, module: torch.nn.Module, output_weight: int):
-        super().__init__()
-        self.module = module
-        self.output_weight = output_weight
-
-    def forward(self, input: torch.Tensor):
-        output = self.module(input)
-        return output * self.output_weight + input
-
-
-class _ConvolutionModule(torch.nn.Module):
-    def __init__(
-        self,
-        input_dim: int,
-        segment_length: int,
-        right_context_length: int,
-        kernel_size: int,
-        activation: str = "silu",
-        dropout: float = 0.0,
-    ):
-        super().__init__()
-        self.input_dim = input_dim
-        self.segment_length = segment_length
-        self.right_context_length = right_context_length
-        self.state_size = kernel_size - 1
-
-        self.pre_conv = torch.nn.Sequential(
-            torch.nn.LayerNorm(input_dim), torch.nn.Linear(input_dim, 2 * input_dim, bias=True), torch.nn.GLU()
-        )
-        self.conv = torch.nn.Conv1d(
-            in_channels=input_dim,
-            out_channels=input_dim,
-            kernel_size=kernel_size,
-            stride=1,
-            padding=0,
-            groups=input_dim,
-        )
-        self.post_conv = torch.nn.Sequential(
-            torch.nn.LayerNorm(input_dim),
-            _get_activation_module(activation),
-            torch.nn.Linear(input_dim, input_dim, bias=True),
-            torch.nn.Dropout(p=dropout),
-        )
-
-    def _split_right_context(self, utterance: torch.Tensor, right_context: torch.Tensor) -> torch.Tensor:
-        T, B, D = right_context.size()
-        if T % self.right_context_length != 0:
-            raise ValueError("Tensor length should be divisible by its right context length")
-        num_segments = T // self.right_context_length
-        # (num_segments, right context length, B, D)
-        right_context_segments = right_context.reshape(num_segments, self.right_context_length, B, D)
-        right_context_segments = right_context_segments.permute(0, 2, 1, 3).reshape(
-            num_segments * B, self.right_context_length, D
-        )
-
-        pad_segments = []  # [(kernel_size - 1, B, D), ...]
-        for seg_idx in range(num_segments):
-            end_idx = min(self.state_size + (seg_idx + 1) * self.segment_length, utterance.size(0))
-            start_idx = end_idx - self.state_size
-            pad_segments.append(utterance[start_idx:end_idx, :, :])
-
-        pad_segments = torch.cat(pad_segments, dim=1).permute(1, 0, 2)  # (num_segments * B, kernel_size - 1, D)
-        return torch.cat([pad_segments, right_context_segments], dim=1).permute(0, 2, 1)
-
-    def _merge_right_context(self, right_context: torch.Tensor, B: int) -> torch.Tensor:
-        # (num_segments * B, D, right_context_length)
-        right_context = right_context.reshape(-1, B, self.input_dim, self.right_context_length)
-        right_context = right_context.permute(0, 3, 1, 2)
-        return right_context.reshape(-1, B, self.input_dim)  # (right_context_length * num_segments, B, D)
-
-    def forward(
-        self, utterance: torch.Tensor, right_context: torch.Tensor, state: Optional[torch.Tensor]
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        input = torch.cat((right_context, utterance))  # input: (T, B, D)
-        x = self.pre_conv(input)
-        x_right_context, x_utterance = x[: right_context.size(0), :, :], x[right_context.size(0) :, :, :]
-        x_utterance = x_utterance.permute(1, 2, 0)  # (B, D, T_utterance)
-
-        if state is None:
-            state = torch.zeros(
-                input.size(1),
-                input.size(2),
-                self.state_size,
-                device=input.device,
-                dtype=input.dtype,
-            )  # (B, D, T)
-        state_x_utterance = torch.cat([state, x_utterance], dim=2)
-
-        conv_utterance = self.conv(state_x_utterance)  # (B, D, T_utterance)
-        conv_utterance = conv_utterance.permute(2, 0, 1)
-
-        if self.right_context_length > 0:
-            # (B * num_segments, D, right_context_length + kernel_size - 1)
-            right_context_block = self._split_right_context(state_x_utterance.permute(2, 0, 1), x_right_context)
-            conv_right_context_block = self.conv(right_context_block)  # (B * num_segments, D, right_context_length)
-            # (T_right_context, B, D)
-            conv_right_context = self._merge_right_context(conv_right_context_block, input.size(1))
-            y = torch.cat([conv_right_context, conv_utterance], dim=0)
-        else:
-            y = conv_utterance
-
-        output = self.post_conv(y) + input
-        new_state = state_x_utterance[:, :, -self.state_size :]
-        return output[right_context.size(0) :], output[: right_context.size(0)], new_state
-
-    def infer(
-        self, utterance: torch.Tensor, right_context: torch.Tensor, state: Optional[torch.Tensor]
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        input = torch.cat((utterance, right_context))
-        x = self.pre_conv(input)  # (T, B, D)
-        x = x.permute(1, 2, 0)  # (B, D, T)
-
-        if state is None:
-            state = torch.zeros(
-                input.size(1),
-                input.size(2),
-                self.state_size,
-                device=input.device,
-                dtype=input.dtype,
-            )  # (B, D, T)
-        state_x = torch.cat([state, x], dim=2)
-        conv_out = self.conv(state_x)
-        conv_out = conv_out.permute(2, 0, 1)  # T, B, D
-        output = self.post_conv(conv_out) + input
-        new_state = state_x[:, :, -self.state_size - right_context.size(0) : -right_context.size(0)]
-        return output[: utterance.size(0)], output[utterance.size(0) :], new_state
-
-
-class _ConvEmformerLayer(torch.nn.Module):
-    r"""Convolution-augmented Emformer layer that constitutes ConvEmformer.
-
-    Args:
-        input_dim (int): input dimension.
-        num_heads (int): number of attention heads.
-        ffn_dim: (int): hidden layer dimension of feedforward network.
-        segment_length (int): length of each input segment.
-        kernel_size (int): size of kernel to use in convolution module.
-        dropout (float, optional): dropout probability. (Default: 0.0)
-        ffn_activation (str, optional): activation function to use in feedforward network.
-            Must be one of ("relu", "gelu", "silu"). (Default: "relu")
-        left_context_length (int, optional): length of left context. (Default: 0)
-        right_context_length (int, optional): length of right context. (Default: 0)
-        max_memory_size (int, optional): maximum number of memory elements to use. (Default: 0)
-        weight_init_gain (float or None, optional): scale factor to apply when initializing
-            attention module parameters. (Default: ``None``)
-        tanh_on_mem (bool, optional): if ``True``, applies tanh to memory elements. (Default: ``False``)
-        negative_inf (float, optional): value to use for negative infinity in attention weights. (Default: -1e8)
-        conv_activation (str, optional): activation function to use in convolution module.
-            Must be one of ("relu", "gelu", "silu"). (Default: "silu")
-    """
-
-    def __init__(
-        self,
-        input_dim: int,
-        num_heads: int,
-        ffn_dim: int,
-        segment_length: int,
-        kernel_size: int,
-        dropout: float = 0.0,
-        ffn_activation: str = "relu",
-        left_context_length: int = 0,
-        right_context_length: int = 0,
-        max_memory_size: int = 0,
-        weight_init_gain: Optional[float] = None,
-        tanh_on_mem: bool = False,
-        negative_inf: float = -1e8,
-        conv_activation: str = "silu",
-    ):
-        super().__init__()
-        # TODO: implement talking heads attention.
-        self.attention = _EmformerAttention(
-            input_dim=input_dim,
-            num_heads=num_heads,
-            dropout=dropout,
-            weight_init_gain=weight_init_gain,
-            tanh_on_mem=tanh_on_mem,
-            negative_inf=negative_inf,
-        )
-        self.dropout = torch.nn.Dropout(dropout)
-        self.memory_op = torch.nn.AvgPool1d(kernel_size=segment_length, stride=segment_length, ceil_mode=True)
-
-        activation_module = _get_activation_module(ffn_activation)
-        self.ffn0 = _ResidualContainer(
-            torch.nn.Sequential(
-                torch.nn.LayerNorm(input_dim),
-                torch.nn.Linear(input_dim, ffn_dim),
-                activation_module,
-                torch.nn.Dropout(dropout),
-                torch.nn.Linear(ffn_dim, input_dim),
-                torch.nn.Dropout(dropout),
-            ),
-            0.5,
-        )
-        self.ffn1 = _ResidualContainer(
-            torch.nn.Sequential(
-                torch.nn.LayerNorm(input_dim),
-                torch.nn.Linear(input_dim, ffn_dim),
-                activation_module,
-                torch.nn.Dropout(dropout),
-                torch.nn.Linear(ffn_dim, input_dim),
-                torch.nn.Dropout(dropout),
-            ),
-            0.5,
-        )
-        self.layer_norm_input = torch.nn.LayerNorm(input_dim)
-        self.layer_norm_output = torch.nn.LayerNorm(input_dim)
-
-        self.conv = _ConvolutionModule(
-            input_dim=input_dim,
-            kernel_size=kernel_size,
-            activation=conv_activation,
-            dropout=dropout,
-            segment_length=segment_length,
-            right_context_length=right_context_length,
-        )
-
-        self.left_context_length = left_context_length
-        self.segment_length = segment_length
-        self.max_memory_size = max_memory_size
-        self.input_dim = input_dim
-        self.kernel_size = kernel_size
-        self.use_mem = max_memory_size > 0
-
-    def _init_state(self, batch_size: int, device: Optional[torch.device]) -> List[torch.Tensor]:
-        empty_memory = torch.zeros(self.max_memory_size, batch_size, self.input_dim, device=device)
-        left_context_key = torch.zeros(self.left_context_length, batch_size, self.input_dim, device=device)
-        left_context_val = torch.zeros(self.left_context_length, batch_size, self.input_dim, device=device)
-        past_length = torch.zeros(1, batch_size, dtype=torch.int32, device=device)
-        conv_cache = torch.zeros(
-            batch_size,
-            self.input_dim,
-            self.kernel_size - 1,
-            device=device,
-        )
-        return [empty_memory, left_context_key, left_context_val, past_length, conv_cache]
-
-    def _unpack_state(self, state: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-        past_length = state[3][0][0].item()
-        past_left_context_length = min(self.left_context_length, past_length)
-        past_mem_length = min(self.max_memory_size, math.ceil(past_length / self.segment_length))
-        pre_mems = state[0][self.max_memory_size - past_mem_length :]
-        lc_key = state[1][self.left_context_length - past_left_context_length :]
-        lc_val = state[2][self.left_context_length - past_left_context_length :]
-        conv_cache = state[4]
-        return pre_mems, lc_key, lc_val, conv_cache
-
-    def _pack_state(
-        self,
-        next_k: torch.Tensor,
-        next_v: torch.Tensor,
-        update_length: int,
-        mems: torch.Tensor,
-        conv_cache: torch.Tensor,
-        state: List[torch.Tensor],
-    ) -> List[torch.Tensor]:
-        new_k = torch.cat([state[1], next_k])
-        new_v = torch.cat([state[2], next_v])
-        state[0] = torch.cat([state[0], mems])[-self.max_memory_size :]
-        state[1] = new_k[new_k.shape[0] - self.left_context_length :]
-        state[2] = new_v[new_v.shape[0] - self.left_context_length :]
-        state[3] = state[3] + update_length
-        state[4] = conv_cache
-        return state
-
-    def _apply_pre_attention(
-        self, utterance: torch.Tensor, right_context: torch.Tensor, summary: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-        x = torch.cat([right_context, utterance, summary])
-        ffn0_out = self.ffn0(x)
-        layer_norm_input_out = self.layer_norm_input(ffn0_out)
-        layer_norm_input_right_context, layer_norm_input_utterance, layer_norm_input_summary = (
-            layer_norm_input_out[: right_context.size(0)],
-            layer_norm_input_out[right_context.size(0) : right_context.size(0) + utterance.size(0)],
-            layer_norm_input_out[right_context.size(0) + utterance.size(0) :],
-        )
-        return ffn0_out, layer_norm_input_right_context, layer_norm_input_utterance, layer_norm_input_summary
-
-    def _apply_post_attention(
-        self,
-        rc_output: torch.Tensor,
-        ffn0_out: torch.Tensor,
-        conv_cache: Optional[torch.Tensor],
-        rc_length: int,
-        utterance_length: int,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        result = self.dropout(rc_output) + ffn0_out[: rc_length + utterance_length]
-        conv_utterance, conv_right_context, conv_cache = self.conv(result[rc_length:], result[:rc_length], conv_cache)
-        result = torch.cat([conv_right_context, conv_utterance])
-        result = self.ffn1(result)
-        result = self.layer_norm_output(result)
-        output_utterance, output_right_context = result[rc_length:], result[:rc_length]
-        return output_utterance, output_right_context, conv_cache
-
-    def forward(
-        self,
-        utterance: torch.Tensor,
-        lengths: torch.Tensor,
-        right_context: torch.Tensor,
-        mems: torch.Tensor,
-        attention_mask: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        r"""Forward pass for training.
-
-        B: batch size;
-        D: feature dimension of each frame;
-        T: number of utterance frames;
-        R: number of right context frames;
-        M: number of memory elements.
-
-        Args:
-            utterance (torch.Tensor): utterance frames, with shape `(T, B, D)`.
-            lengths (torch.Tensor): with shape `(B,)` and i-th element representing
-                number of valid frames for i-th batch element in ``utterance``.
-            right_context (torch.Tensor): right context frames, with shape `(R, B, D)`.
-            mems (torch.Tensor): memory elements, with shape `(M, B, D)`.
-            attention_mask (torch.Tensor): attention mask for underlying attention module.
-
-        Returns:
-            (Tensor, Tensor, Tensor):
-                Tensor
-                    encoded utterance frames, with shape `(T, B, D)`.
-                Tensor
-                    updated right context frames, with shape `(R, B, D)`.
-                Tensor
-                    updated memory elements, with shape `(M, B, D)`.
-        """
-        if self.use_mem:
-            summary = self.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)
-        else:
-            summary = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device)
-
-        (
-            ffn0_out,
-            layer_norm_input_right_context,
-            layer_norm_input_utterance,
-            layer_norm_input_summary,
-        ) = self._apply_pre_attention(utterance, right_context, summary)
-
-        rc_output, output_mems = self.attention(
-            utterance=layer_norm_input_utterance,
-            lengths=lengths,
-            right_context=layer_norm_input_right_context,
-            summary=layer_norm_input_summary,
-            mems=mems,
-            attention_mask=attention_mask,
-        )
-
-        output_utterance, output_right_context, _ = self._apply_post_attention(
-            rc_output, ffn0_out, None, right_context.size(0), utterance.size(0)
-        )
-
-        return output_utterance, output_right_context, output_mems
-
-    @torch.jit.export
-    def infer(
-        self,
-        utterance: torch.Tensor,
-        lengths: torch.Tensor,
-        right_context: torch.Tensor,
-        state: Optional[List[torch.Tensor]],
-        mems: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor], torch.Tensor]:
-        r"""Forward pass for inference.
-
-        B: batch size;
-        D: feature dimension of each frame;
-        T: number of utterance frames;
-        R: number of right context frames;
-        M: number of memory elements.
-
-        Args:
-            utterance (torch.Tensor): utterance frames, with shape `(T, B, D)`.
-            lengths (torch.Tensor): with shape `(B,)` and i-th element representing
-                number of valid frames for i-th batch element in ``utterance``.
-            right_context (torch.Tensor): right context frames, with shape `(R, B, D)`.
-            state (List[torch.Tensor] or None): list of tensors representing layer internal state
-                generated in preceding invocation of ``infer``.
-            mems (torch.Tensor): memory elements, with shape `(M, B, D)`.
-
-        Returns:
-            (Tensor, Tensor, List[torch.Tensor], Tensor):
-                Tensor
-                    encoded utterance frames, with shape `(T, B, D)`.
-                Tensor
-                    updated right context frames, with shape `(R, B, D)`.
-                List[Tensor]
-                    list of tensors representing layer internal state
-                    generated in current invocation of ``infer``.
-                Tensor
-                    updated memory elements, with shape `(M, B, D)`.
-        """
-        if self.use_mem:
-            summary = self.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[:1]
-        else:
-            summary = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device)
-
-        (
-            ffn0_out,
-            layer_norm_input_right_context,
-            layer_norm_input_utterance,
-            layer_norm_input_summary,
-        ) = self._apply_pre_attention(utterance, right_context, summary)
-
-        if state is None:
-            state = self._init_state(layer_norm_input_utterance.size(1), device=layer_norm_input_utterance.device)
-        pre_mems, lc_key, lc_val, conv_cache = self._unpack_state(state)
-
-        rc_output, next_m, next_k, next_v = self.attention.infer(
-            utterance=layer_norm_input_utterance,
-            lengths=lengths,
-            right_context=layer_norm_input_right_context,
-            summary=layer_norm_input_summary,
-            mems=pre_mems,
-            left_context_key=lc_key,
-            left_context_val=lc_val,
-        )
-
-        output_utterance, output_right_context, conv_cache = self._apply_post_attention(
-            rc_output, ffn0_out, conv_cache, right_context.size(0), utterance.size(0)
-        )
-        output_state = self._pack_state(next_k, next_v, utterance.size(0), mems, conv_cache, state)
-        return output_utterance, output_right_context, output_state, next_m
-
-
-@dropping_class_support
-class ConvEmformer(_EmformerImpl):
-    r"""Implements the convolution-augmented streaming transformer architecture introduced in
-    *Streaming Transformer Transducer based Speech Recognition Using Non-Causal Convolution*
-    :cite:`9747706`.
-
-    Args:
-        input_dim (int): input dimension.
-        num_heads (int): number of attention heads in each ConvEmformer layer.
-        ffn_dim (int): hidden layer dimension of each ConvEmformer layer's feedforward network.
-        num_layers (int): number of ConvEmformer layers to instantiate.
-        segment_length (int): length of each input segment.
-        kernel_size (int): size of kernel to use in convolution modules.
-        dropout (float, optional): dropout probability. (Default: 0.0)
-        ffn_activation (str, optional): activation function to use in feedforward networks.
-            Must be one of ("relu", "gelu", "silu"). (Default: "relu")
-        left_context_length (int, optional): length of left context. (Default: 0)
-        right_context_length (int, optional): length of right context. (Default: 0)
-        max_memory_size (int, optional): maximum number of memory elements to use. (Default: 0)
-        weight_init_scale_strategy (str or None, optional): per-layer weight initialization scaling
-            strategy. Must be one of ("depthwise", "constant", ``None``). (Default: "depthwise")
-        tanh_on_mem (bool, optional): if ``True``, applies tanh to memory elements. (Default: ``False``)
-        negative_inf (float, optional): value to use for negative infinity in attention weights. (Default: -1e8)
-        conv_activation (str, optional): activation function to use in convolution modules.
-            Must be one of ("relu", "gelu", "silu"). (Default: "silu")
-
-    Examples:
-        >>> conv_emformer = ConvEmformer(80, 4, 1024, 12, 16, 8, right_context_length=4)
-        >>> input = torch.rand(10, 200, 80)
-        >>> lengths = torch.randint(1, 200, (10,))
-        >>> output, lengths = conv_emformer(input, lengths)
-        >>> input = torch.rand(4, 20, 80)
-        >>> lengths = torch.ones(4) * 20
-        >>> output, lengths, states = conv_emformer.infer(input, lengths, None)
-    """
-
-    @dropping_support
-    def __init__(
-        self,
-        input_dim: int,
-        num_heads: int,
-        ffn_dim: int,
-        num_layers: int,
-        segment_length: int,
-        kernel_size: int,
-        dropout: float = 0.0,
-        ffn_activation: str = "relu",
-        left_context_length: int = 0,
-        right_context_length: int = 0,
-        max_memory_size: int = 0,
-        weight_init_scale_strategy: Optional[str] = "depthwise",
-        tanh_on_mem: bool = False,
-        negative_inf: float = -1e8,
-        conv_activation: str = "silu",
-    ):
-        weight_init_gains = _get_weight_init_gains(weight_init_scale_strategy, num_layers)
-        emformer_layers = torch.nn.ModuleList(
-            [
-                _ConvEmformerLayer(
-                    input_dim,
-                    num_heads,
-                    ffn_dim,
-                    segment_length,
-                    kernel_size,
-                    dropout=dropout,
-                    ffn_activation=ffn_activation,
-                    left_context_length=left_context_length,
-                    right_context_length=right_context_length,
-                    max_memory_size=max_memory_size,
-                    weight_init_gain=weight_init_gains[layer_idx],
-                    tanh_on_mem=tanh_on_mem,
-                    negative_inf=negative_inf,
-                    conv_activation=conv_activation,
-                )
-                for layer_idx in range(num_layers)
-            ]
-        )
-        super().__init__(
-            emformer_layers,
-            segment_length,
-            left_context_length=left_context_length,
-            right_context_length=right_context_length,
-            max_memory_size=max_memory_size,
-        )