minicpmo-utils 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cosyvoice/__init__.py +17 -0
- cosyvoice/bin/average_model.py +93 -0
- cosyvoice/bin/export_jit.py +103 -0
- cosyvoice/bin/export_onnx.py +120 -0
- cosyvoice/bin/inference_deprecated.py +126 -0
- cosyvoice/bin/train.py +195 -0
- cosyvoice/cli/__init__.py +0 -0
- cosyvoice/cli/cosyvoice.py +209 -0
- cosyvoice/cli/frontend.py +238 -0
- cosyvoice/cli/model.py +386 -0
- cosyvoice/dataset/__init__.py +0 -0
- cosyvoice/dataset/dataset.py +151 -0
- cosyvoice/dataset/processor.py +434 -0
- cosyvoice/flow/decoder.py +494 -0
- cosyvoice/flow/flow.py +281 -0
- cosyvoice/flow/flow_matching.py +227 -0
- cosyvoice/flow/length_regulator.py +70 -0
- cosyvoice/hifigan/discriminator.py +230 -0
- cosyvoice/hifigan/f0_predictor.py +58 -0
- cosyvoice/hifigan/generator.py +582 -0
- cosyvoice/hifigan/hifigan.py +67 -0
- cosyvoice/llm/llm.py +610 -0
- cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
- cosyvoice/tokenizer/tokenizer.py +279 -0
- cosyvoice/transformer/__init__.py +0 -0
- cosyvoice/transformer/activation.py +84 -0
- cosyvoice/transformer/attention.py +330 -0
- cosyvoice/transformer/convolution.py +145 -0
- cosyvoice/transformer/decoder.py +396 -0
- cosyvoice/transformer/decoder_layer.py +132 -0
- cosyvoice/transformer/embedding.py +302 -0
- cosyvoice/transformer/encoder.py +474 -0
- cosyvoice/transformer/encoder_layer.py +236 -0
- cosyvoice/transformer/label_smoothing_loss.py +96 -0
- cosyvoice/transformer/positionwise_feed_forward.py +115 -0
- cosyvoice/transformer/subsampling.py +383 -0
- cosyvoice/transformer/upsample_encoder.py +320 -0
- cosyvoice/utils/__init__.py +0 -0
- cosyvoice/utils/class_utils.py +83 -0
- cosyvoice/utils/common.py +186 -0
- cosyvoice/utils/executor.py +176 -0
- cosyvoice/utils/file_utils.py +129 -0
- cosyvoice/utils/frontend_utils.py +136 -0
- cosyvoice/utils/losses.py +57 -0
- cosyvoice/utils/mask.py +265 -0
- cosyvoice/utils/scheduler.py +738 -0
- cosyvoice/utils/train_utils.py +367 -0
- cosyvoice/vllm/cosyvoice2.py +103 -0
- matcha/__init__.py +0 -0
- matcha/app.py +357 -0
- matcha/cli.py +418 -0
- matcha/hifigan/__init__.py +0 -0
- matcha/hifigan/config.py +28 -0
- matcha/hifigan/denoiser.py +64 -0
- matcha/hifigan/env.py +17 -0
- matcha/hifigan/meldataset.py +217 -0
- matcha/hifigan/models.py +368 -0
- matcha/hifigan/xutils.py +60 -0
- matcha/models/__init__.py +0 -0
- matcha/models/baselightningmodule.py +209 -0
- matcha/models/components/__init__.py +0 -0
- matcha/models/components/decoder.py +443 -0
- matcha/models/components/flow_matching.py +132 -0
- matcha/models/components/text_encoder.py +410 -0
- matcha/models/components/transformer.py +316 -0
- matcha/models/matcha_tts.py +239 -0
- matcha/onnx/__init__.py +0 -0
- matcha/onnx/export.py +181 -0
- matcha/onnx/infer.py +168 -0
- matcha/text/__init__.py +53 -0
- matcha/text/cleaners.py +116 -0
- matcha/text/numbers.py +71 -0
- matcha/text/symbols.py +17 -0
- matcha/train.py +122 -0
- matcha/utils/__init__.py +5 -0
- matcha/utils/audio.py +82 -0
- matcha/utils/generate_data_statistics.py +111 -0
- matcha/utils/instantiators.py +56 -0
- matcha/utils/logging_utils.py +53 -0
- matcha/utils/model.py +90 -0
- matcha/utils/monotonic_align/__init__.py +22 -0
- matcha/utils/monotonic_align/setup.py +7 -0
- matcha/utils/pylogger.py +21 -0
- matcha/utils/rich_utils.py +101 -0
- matcha/utils/utils.py +219 -0
- minicpmo/__init__.py +24 -0
- minicpmo/utils.py +636 -0
- minicpmo/version.py +2 -0
- minicpmo_utils-0.1.0.dist-info/METADATA +72 -0
- minicpmo_utils-0.1.0.dist-info/RECORD +148 -0
- minicpmo_utils-0.1.0.dist-info/WHEEL +5 -0
- minicpmo_utils-0.1.0.dist-info/top_level.txt +5 -0
- s3tokenizer/__init__.py +153 -0
- s3tokenizer/assets/BAC009S0764W0121.wav +0 -0
- s3tokenizer/assets/BAC009S0764W0122.wav +0 -0
- s3tokenizer/assets/mel_filters.npz +0 -0
- s3tokenizer/cli.py +183 -0
- s3tokenizer/model.py +546 -0
- s3tokenizer/model_v2.py +605 -0
- s3tokenizer/utils.py +390 -0
- stepaudio2/__init__.py +40 -0
- stepaudio2/cosyvoice2/__init__.py +1 -0
- stepaudio2/cosyvoice2/flow/__init__.py +0 -0
- stepaudio2/cosyvoice2/flow/decoder_dit.py +585 -0
- stepaudio2/cosyvoice2/flow/flow.py +230 -0
- stepaudio2/cosyvoice2/flow/flow_matching.py +205 -0
- stepaudio2/cosyvoice2/transformer/__init__.py +0 -0
- stepaudio2/cosyvoice2/transformer/attention.py +328 -0
- stepaudio2/cosyvoice2/transformer/embedding.py +119 -0
- stepaudio2/cosyvoice2/transformer/encoder_layer.py +163 -0
- stepaudio2/cosyvoice2/transformer/positionwise_feed_forward.py +56 -0
- stepaudio2/cosyvoice2/transformer/subsampling.py +79 -0
- stepaudio2/cosyvoice2/transformer/upsample_encoder_v2.py +483 -0
- stepaudio2/cosyvoice2/utils/__init__.py +1 -0
- stepaudio2/cosyvoice2/utils/class_utils.py +41 -0
- stepaudio2/cosyvoice2/utils/common.py +101 -0
- stepaudio2/cosyvoice2/utils/mask.py +49 -0
- stepaudio2/flashcosyvoice/__init__.py +0 -0
- stepaudio2/flashcosyvoice/cli.py +424 -0
- stepaudio2/flashcosyvoice/config.py +80 -0
- stepaudio2/flashcosyvoice/cosyvoice2.py +160 -0
- stepaudio2/flashcosyvoice/cosyvoice3.py +1 -0
- stepaudio2/flashcosyvoice/engine/__init__.py +0 -0
- stepaudio2/flashcosyvoice/engine/block_manager.py +114 -0
- stepaudio2/flashcosyvoice/engine/llm_engine.py +125 -0
- stepaudio2/flashcosyvoice/engine/model_runner.py +310 -0
- stepaudio2/flashcosyvoice/engine/scheduler.py +77 -0
- stepaudio2/flashcosyvoice/engine/sequence.py +90 -0
- stepaudio2/flashcosyvoice/modules/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/flow.py +198 -0
- stepaudio2/flashcosyvoice/modules/flow_components/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/flow_components/estimator.py +974 -0
- stepaudio2/flashcosyvoice/modules/flow_components/upsample_encoder.py +998 -0
- stepaudio2/flashcosyvoice/modules/hifigan.py +249 -0
- stepaudio2/flashcosyvoice/modules/hifigan_components/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/hifigan_components/layers.py +433 -0
- stepaudio2/flashcosyvoice/modules/qwen2.py +92 -0
- stepaudio2/flashcosyvoice/modules/qwen2_components/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/qwen2_components/layers.py +616 -0
- stepaudio2/flashcosyvoice/modules/sampler.py +231 -0
- stepaudio2/flashcosyvoice/utils/__init__.py +0 -0
- stepaudio2/flashcosyvoice/utils/audio.py +77 -0
- stepaudio2/flashcosyvoice/utils/context.py +28 -0
- stepaudio2/flashcosyvoice/utils/loader.py +116 -0
- stepaudio2/flashcosyvoice/utils/memory.py +19 -0
- stepaudio2/stepaudio2.py +204 -0
- stepaudio2/token2wav.py +248 -0
- stepaudio2/utils.py +91 -0
@@ -0,0 +1,998 @@
import math
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F


def subsequent_chunk_mask(
        size: int,
        chunk_size: int,
        num_left_chunks: int = -1,
        device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create mask for subsequent steps (size, size) with chunk size,
       this is for streaming encoder

    Args:
        size (int): size of mask
        chunk_size (int): size of chunk
        num_left_chunks (int): number of left chunks
            <0: use full chunk
            >=0: use num_left_chunks
        device (torch.device): "cpu" or "cuda" or torch.Tensor.device

    Returns:
        torch.Tensor: mask

    Examples:
        >>> subsequent_chunk_mask(4, 2)
        [[1, 1, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 1, 1],
         [1, 1, 1, 1]]
    """
    # NOTE this modified implementation meets onnx export requirements, but it doesn't support num_left_chunks
    pos_idx = torch.arange(size, device=device)
    block_value = (torch.div(pos_idx, chunk_size, rounding_mode='trunc') + 1) * chunk_size
    ret = pos_idx.unsqueeze(0) < block_value.unsqueeze(1)
    return ret
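
# A minimal usage sketch of `subsequent_chunk_mask` (values follow from the
# implementation above): with size=6 and chunk_size=2 every frame can attend
# up to the end of its own chunk, so the mask is block-causal.
#   >>> subsequent_chunk_mask(6, 2).int()
#   tensor([[1, 1, 0, 0, 0, 0],
#           [1, 1, 0, 0, 0, 0],
#           [1, 1, 1, 1, 0, 0],
#           [1, 1, 1, 1, 0, 0],
#           [1, 1, 1, 1, 1, 1],
#           [1, 1, 1, 1, 1, 1]], dtype=torch.int32)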


def add_optional_chunk_mask(xs: torch.Tensor,
                            masks: torch.Tensor,
                            use_dynamic_chunk: bool,
                            use_dynamic_left_chunk: bool,
                            decoding_chunk_size: int,
                            static_chunk_size: int,
                            num_decoding_left_chunks: int,
                            enable_full_context: bool = True):
    """ Apply optional mask for encoder.

    Args:
        xs (torch.Tensor): padded input, (B, L, D), L for max length
        mask (torch.Tensor): mask for xs, (B, 1, L)
        use_dynamic_chunk (bool): whether to use dynamic chunk or not
        use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
            training.
        decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
            0: default for training, use random dynamic chunk.
            <0: for decoding, use full chunk.
            >0: for decoding, use fixed chunk size as set.
        static_chunk_size (int): chunk size for static chunk training/decoding
            if it's greater than 0; if use_dynamic_chunk is true,
            this parameter will be ignored
        num_decoding_left_chunks: number of left chunks, this is for decoding,
            the chunk size is decoding_chunk_size.
            >=0: use num_decoding_left_chunks
            <0: use all left chunks
        enable_full_context (bool):
            True: chunk size is either [1, 25] or full context(max_len)
            False: chunk size ~ U[1, 25]

    Returns:
        torch.Tensor: chunk mask of the input xs.
    """
    # Whether to use chunk mask or not
    if use_dynamic_chunk:
        max_len = xs.size(1)
        if decoding_chunk_size < 0:
            chunk_size = max_len
            num_left_chunks = -1
        elif decoding_chunk_size > 0:
            chunk_size = decoding_chunk_size
            num_left_chunks = num_decoding_left_chunks
        else:
            # chunk size is either [1, 25] or full context(max_len).
            # Since we use 4 times subsampling and allow up to 1s(100 frames)
            # delay, the maximum frame is 100 / 4 = 25.
            chunk_size = torch.randint(1, max_len, (1, )).item()
            num_left_chunks = -1
            if chunk_size > max_len // 2 and enable_full_context:
                chunk_size = max_len
            else:
                chunk_size = chunk_size % 25 + 1
                if use_dynamic_left_chunk:
                    max_left_chunks = (max_len - 1) // chunk_size
                    num_left_chunks = torch.randint(0, max_left_chunks,
                                                    (1, )).item()
        chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
                                            num_left_chunks,
                                            xs.device)  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        chunk_masks = masks & chunk_masks  # (B, L, L)
    elif static_chunk_size > 0:
        num_left_chunks = num_decoding_left_chunks
        chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
                                            num_left_chunks,
                                            xs.device)  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        chunk_masks = masks & chunk_masks  # (B, L, L)
    else:
        chunk_masks = masks
    assert chunk_masks.dtype == torch.bool
    if (chunk_masks.sum(dim=-1) == 0).sum().item() != 0:
        print('get chunk_masks all false at some timestep, force set to true, make sure they are masked in future computation!')
        chunk_masks[chunk_masks.sum(dim=-1) == 0] = True
    return chunk_masks
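
# A minimal sketch of the path this file exercises (use_dynamic_chunk is
# always False below): with a static chunk size the padding mask (B, 1, L) is
# broadcast against the block-causal chunk mask to give (B, L, L).
#   >>> xs = torch.zeros(1, 4, 8)                      # (B, L, D), values unused
#   >>> masks = torch.ones(1, 1, 4, dtype=torch.bool)  # (B, 1, L)
#   >>> cm = add_optional_chunk_mask(xs, masks, False, False, 0, 2, -1)
#   >>> cm.shape
#   torch.Size([1, 4, 4])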


def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Make mask tensor containing indices of padded part.

    See description of make_non_pad_mask.

    Args:
        lengths (torch.Tensor): Batch of lengths (B,).
    Returns:
        torch.Tensor: Mask tensor containing indices of padded part.

    Examples:
        >>> lengths = [5, 3, 2]
        >>> make_pad_mask(lengths)
        masks = [[0, 0, 0, 0, 0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]
    """
    batch_size = lengths.size(0)
    max_len = max_len if max_len > 0 else lengths.max().item()
    seq_range = torch.arange(0,
                             max_len,
                             dtype=torch.int64,
                             device=lengths.device)
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    seq_length_expand = lengths.unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand
    return mask
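
# A minimal usage sketch of `make_pad_mask` (values follow from the code
# above): True marks padded positions; the encoder below uses the negation,
# unsqueezed to (B, 1, T), as its padding mask.
#   >>> lengths = torch.tensor([5, 3, 2])
#   >>> make_pad_mask(lengths).int()
#   tensor([[0, 0, 0, 0, 0],
#           [0, 0, 0, 1, 1],
#           [0, 0, 1, 1, 1]], dtype=torch.int32)
#   >>> (~make_pad_mask(lengths)).unsqueeze(1).shape
#   torch.Size([3, 1, 5])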


class EspnetRelPositionalEncoding(torch.nn.Module):
    """Relative positional encoding module (new implementation).

    Details can be found in https://github.com/espnet/espnet/pull/2816.

    See : Appendix B in https://arxiv.org/abs/1901.02860

    Args:
        d_model (int): Embedding dimension.
        max_len (int): Maximum input length.

    """

    def __init__(self, d_model: int, max_len: int = 5000):
        super(EspnetRelPositionalEncoding, self).__init__()
        self.d_model = d_model
        self.xscale = math.sqrt(self.d_model)
        self.pe = None
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))

    def extend_pe(self, x: torch.Tensor):
        """Reset the positional encodings."""
        if self.pe is not None:
            # self.pe contains both positive and negative parts
            # the length of self.pe is 2 * input_len - 1
            if self.pe.size(1) >= x.size(1) * 2 - 1:
                if self.pe.dtype != x.dtype or self.pe.device != x.device:
                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                return
        # Suppose `i` means the position of the query vector and `j` means the
        # position of the key vector. We use positive relative positions when
        # keys are to the left (i>j) and negative relative positions otherwise (i<j).
        pe_positive = torch.zeros(x.size(1), self.d_model)
        pe_negative = torch.zeros(x.size(1), self.d_model)
        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        pe_positive[:, 0::2] = torch.sin(position * div_term)
        pe_positive[:, 1::2] = torch.cos(position * div_term)
        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)

        # Reverse the order of positive indices and concat both positive and
        # negative indices. This is used to support the shifting trick
        # as in https://arxiv.org/abs/1901.02860
        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
        pe_negative = pe_negative[1:].unsqueeze(0)
        pe = torch.cat([pe_positive, pe_negative], dim=1)
        self.pe = pe.to(device=x.device, dtype=x.dtype)

    def forward(self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0) \
            -> Tuple[torch.Tensor, torch.Tensor]:
        """Add positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).

        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`).

        """
        self.extend_pe(x)
        x = x * self.xscale
        pos_emb = self.position_encoding(size=x.size(1), offset=offset)
        return x, pos_emb

    def position_encoding(self,
                          offset: Union[int, torch.Tensor],
                          size: int) -> torch.Tensor:
        """ For getting encoding in a streaming fashion

        Attention!!!!!
        we apply dropout only once at the whole utterance level in a
        non-streaming way, but will call this function several times with
        increasing input size in a streaming scenario, so the dropout will
        be applied several times.

        Args:
            offset (int or torch.tensor): start offset
            size (int): required size of position encoding

        Returns:
            torch.Tensor: Corresponding encoding
        """
        # How to subscript a Union type:
        #   https://github.com/pytorch/pytorch/issues/69434
        if isinstance(offset, int):
            pos_emb = self.pe[
                :,
                self.pe.size(1) // 2 - size - offset + 1: self.pe.size(1) // 2 + size + offset,
            ]
        elif isinstance(offset, torch.Tensor):
            pos_emb = self.pe[
                :,
                self.pe.size(1) // 2 - size - offset + 1: self.pe.size(1) // 2 + size + offset,
            ]
        return pos_emb
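
# A minimal shape sketch for `EspnetRelPositionalEncoding` (d_model=64 is an
# illustrative choice): forward() scales the input by sqrt(d_model) and
# returns a relative position embedding covering both positive and negative
# offsets, hence length 2*T - 1.
#   >>> pos_enc = EspnetRelPositionalEncoding(64)
#   >>> x, pos_emb = pos_enc(torch.randn(1, 10, 64))
#   >>> x.shape, pos_emb.shape
#   (torch.Size([1, 10, 64]), torch.Size([1, 19, 64]))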


class LinearNoSubsampling(torch.nn.Module):
    """Linear transform the input without subsampling

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        pos_enc_class (torch.nn.Module): Positional encoding class.

    """

    def __init__(self, idim: int, odim: int,
                 pos_enc_class: torch.nn.Module):
        super().__init__()
        self.out = torch.nn.Sequential(
            torch.nn.Linear(idim, odim),
            torch.nn.LayerNorm(odim, eps=1e-5),
        )
        self.pos_enc = pos_enc_class
        self.right_context = 0
        self.subsampling_rate = 1

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        offset: Union[int, torch.Tensor] = 0
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Input x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).

        Returns:
            torch.Tensor: linear input tensor (#batch, time', odim),
                where time' = time.
            torch.Tensor: linear input mask (#batch, 1, time'),
                where time' = time.

        """
        x = self.out(x)
        x, pos_emb = self.pos_enc(x, offset)
        return x, pos_emb, x_mask

    def position_encoding(self, offset: Union[int, torch.Tensor],
                          size: int) -> torch.Tensor:
        return self.pos_enc.position_encoding(offset, size)


class Upsample1D(nn.Module):
    """A 1D upsampling layer with a convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs.
        out_channels (`int`):
            number of output channels.
        stride (`int`, default `2`):
            upsampling factor applied to the time axis.
    """

    def __init__(self, channels: int, out_channels: int, stride: int = 2):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels
        self.stride = stride
        # In this mode, first repeat interpolate, then conv with stride=1
        self.conv = nn.Conv1d(self.channels, self.out_channels, stride * 2 + 1, stride=1, padding=0)

    def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        outputs = F.interpolate(inputs, scale_factor=float(self.stride), mode="nearest")
        outputs = F.pad(outputs, (self.stride * 2, 0), value=0.0)
        outputs = self.conv(outputs)
        return outputs, input_lengths * self.stride
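
# A minimal shape sketch for `Upsample1D` (sizes are illustrative): nearest
# interpolation doubles the time axis, and the left padding of stride*2
# exactly compensates the kernel of stride*2 + 1, so T frames become 2*T and
# the returned lengths are doubled as well.
#   >>> up = Upsample1D(channels=512, out_channels=512, stride=2)
#   >>> ys, ys_lens = up(torch.randn(1, 512, 10), torch.tensor([10]))
#   >>> ys.shape, ys_lens
#   (torch.Size([1, 512, 20]), tensor([20]))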


class PreLookaheadLayer(nn.Module):
    def __init__(self, channels: int, pre_lookahead_len: int = 1):
        super().__init__()
        self.channels = channels
        self.pre_lookahead_len = pre_lookahead_len
        self.conv1 = nn.Conv1d(
            channels, channels,
            kernel_size=pre_lookahead_len + 1,
            stride=1, padding=0,
        )
        self.conv2 = nn.Conv1d(
            channels, channels,
            kernel_size=3, stride=1, padding=0,
        )

    def forward(self, inputs: torch.Tensor, context: torch.Tensor = torch.zeros(0, 0, 0)) -> torch.Tensor:
        """
        inputs: (batch_size, seq_len, channels)
        """
        outputs = inputs.transpose(1, 2).contiguous()
        context = context.transpose(1, 2).contiguous()
        # look ahead
        if context.size(2) == 0:
            outputs = F.pad(outputs, (0, self.pre_lookahead_len), mode='constant', value=0.0)
        else:
            assert self.training is False, 'you have passed context, make sure that you are running inference mode'
            assert context.size(2) == self.pre_lookahead_len
            outputs = F.pad(torch.concat([outputs, context], dim=2), (0, self.pre_lookahead_len - context.size(2)), mode='constant', value=0.0)
        outputs = F.leaky_relu(self.conv1(outputs))
        # outputs
        outputs = F.pad(outputs, (self.conv2.kernel_size[0] - 1, 0), mode='constant', value=0.0)
        outputs = self.conv2(outputs)
        outputs = outputs.transpose(1, 2).contiguous()

        # residual connection
        outputs = outputs + inputs
        return outputs
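
# A minimal shape sketch for `PreLookaheadLayer` (sizes are illustrative):
# conv1 consumes the pre_lookahead_len future frames added by right padding,
# conv2 is made causal by left padding, and the residual connection keeps the
# output shaped exactly like the (batch_size, seq_len, channels) input.
#   >>> layer = PreLookaheadLayer(channels=512, pre_lookahead_len=3)
#   >>> layer(torch.randn(1, 10, 512)).shape
#   torch.Size([1, 10, 512])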


class MultiHeadedAttention(nn.Module):
    """Multi-Head Attention layer.

    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
        key_bias (bool): Whether to use bias in key linear layer.

    """

    def __init__(self,
                 n_head: int,
                 n_feat: int,
                 dropout_rate: float,
                 key_bias: bool = True):
        super().__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_q = nn.Linear(n_feat, n_feat)
        self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
        self.linear_v = nn.Linear(n_feat, n_feat)
        self.linear_out = nn.Linear(n_feat, n_feat)
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward_qkv(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Transform query, key and value.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).

        Returns:
            torch.Tensor: Transformed query tensor, size
                (#batch, n_head, time1, d_k).
            torch.Tensor: Transformed key tensor, size
                (#batch, n_head, time2, d_k).
            torch.Tensor: Transformed value tensor, size
                (#batch, n_head, time2, d_k).

        """
        n_batch = query.size(0)
        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
        q = q.transpose(1, 2)  # (batch, head, time1, d_k)
        k = k.transpose(1, 2)  # (batch, head, time2, d_k)
        v = v.transpose(1, 2)  # (batch, head, time2, d_k)

        return q, k, v

    def forward_attention(
        self,
        value: torch.Tensor,
        scores: torch.Tensor,
        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
    ) -> torch.Tensor:
        """Compute attention context vector.

        Args:
            value (torch.Tensor): Transformed value, size
                (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score, size
                (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask, size (#batch, 1, time2) or
                (#batch, time1, time2), (0, 0, 0) means fake mask.

        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).

        """
        n_batch = value.size(0)
        # NOTE(xcsong): When will `if mask.size(2) > 0` be True?
        #   1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
        #           1st chunk to ease the onnx export.]
        #   2. pytorch training
        if mask.size(2) > 0:  # time2 > 0
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            # For last chunk, time2 might be larger than scores.size(-1)
            mask = mask[:, :, :, :scores.size(-1)]  # (batch, 1, *, time2)
            scores = scores.masked_fill(mask, -float('inf'))
            attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0)  # (batch, head, time1, time2)
        # NOTE(xcsong): When will `if mask.size(2) > 0` be False?
        #   1. onnx(16/-1, -1/-1, 16/0)
        #   2. jit (16/-1, -1/-1, 16/0, 16/4)
        else:
            attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)

        p_attn = self.dropout(attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
                                                 self.h * self.d_k)
             )  # (batch, time1, d_model)

        return self.linear_out(x)  # (batch, time1, d_model)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        pos_emb: torch.Tensor = torch.empty(0),
        cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute scaled dot product attention.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).
                1. When applying cross attention between decoder and encoder,
                   the batch padding mask for input is in (#batch, 1, T) shape.
                2. When applying self attention of encoder,
                   the mask is in (#batch, T, T) shape.
                3. When applying self attention of decoder,
                   the mask is in (#batch, L, L) shape.
                4. If different positions in the decoder see different blocks
                   of the encoder, such as Mocha, the passed-in mask could be
                   in (#batch, L, T) shape. But there is no such case in current
                   CosyVoice.
            cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
            torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`

        """
        q, k, v = self.forward_qkv(query, key, value)

        # NOTE(xcsong):
        #   when export onnx model, for 1st chunk, we feed
        #       cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
        #       or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
        #       In all modes, `if cache.size(0) > 0` will always be `True`
        #       and we will always do splitting and
        #       concatenation(this will simplify onnx export). Note that
        #       it's OK to concat & split zero-shaped tensors(see code below).
        #   when export jit model, for 1st chunk, we always feed
        #       cache(0, 0, 0, 0) since jit supports dynamic if-branch.
        #   >>> a = torch.ones((1, 2, 0, 4))
        #   >>> b = torch.ones((1, 2, 3, 4))
        #   >>> c = torch.cat((a, b), dim=2)
        #   >>> torch.equal(b, c)        # True
        #   >>> d = torch.split(a, 2, dim=-1)
        #   >>> torch.equal(d[0], d[1])  # True
        if cache.size(0) > 0:
            key_cache, value_cache = torch.split(cache,
                                                 cache.size(-1) // 2,
                                                 dim=-1)
            k = torch.cat([key_cache, k], dim=2)
            v = torch.cat([value_cache, v], dim=2)
        # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
        #   non-trivial to calculate `next_cache_start` here.
        new_cache = torch.cat((k, v), dim=-1)

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        return self.forward_attention(v, scores, mask), new_cache
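
# A minimal sketch of chunk-wise decoding with the key/value cache (sizes are
# illustrative): the returned cache concatenates K and V along the last dim,
# and feeding it back lets the next chunk attend to all previous frames.
#   >>> attn = MultiHeadedAttention(n_head=4, n_feat=64, dropout_rate=0.0)
#   >>> x1 = torch.randn(1, 6, 64)
#   >>> out1, cache = attn(x1, x1, x1)            # empty default cache
#   >>> cache.shape                               # (B, head, time2, d_k * 2)
#   torch.Size([1, 4, 6, 32])
#   >>> x2 = torch.randn(1, 2, 64)
#   >>> out2, cache = attn(x2, x2, x2, cache=cache)
#   >>> out2.shape, cache.shape
#   (torch.Size([1, 2, 64]), torch.Size([1, 4, 8, 32]))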


class RelPositionMultiHeadedAttention(MultiHeadedAttention):
    """Multi-Head Attention layer with relative position encoding.
    Paper: https://arxiv.org/abs/1901.02860
    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
        key_bias (bool): Whether to use bias in key linear layer.
    """

    def __init__(self,
                 n_head: int,
                 n_feat: int,
                 dropout_rate: float,
                 key_bias: bool = True):
        super().__init__(n_head, n_feat, dropout_rate, key_bias)
        # linear transformation for positional encoding
        self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
        # these two learnable biases are used in matrix c and matrix d
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
        self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
        torch.nn.init.xavier_uniform_(self.pos_bias_u)
        torch.nn.init.xavier_uniform_(self.pos_bias_v)

    def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
        """Compute relative positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
                time1 means the length of query vector.

        Returns:
            torch.Tensor: Output tensor.

        """
        zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
                               device=x.device,
                               dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=-1)

        x_padded = x_padded.view(x.size()[0],
                                 x.size()[1],
                                 x.size(3) + 1, x.size(2))
        x = x_padded[:, :, 1:].view_as(x)[
            :, :, :, : x.size(-1) // 2 + 1
        ]  # only keep the positions from 0 to time2
        return x

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        pos_emb: torch.Tensor = torch.empty(0),
        cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2), (0, 0, 0) means fake mask.
            pos_emb (torch.Tensor): Positional embedding tensor
                (#batch, time2, size).
            cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`
        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
            torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`
        """
        q, k, v = self.forward_qkv(query, key, value)
        q = q.transpose(1, 2)  # (batch, time1, head, d_k)

        # NOTE(xcsong):
        #   when export onnx model, for 1st chunk, we feed
        #       cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
        #       or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
        #       In all modes, `if cache.size(0) > 0` will always be `True`
        #       and we will always do splitting and
        #       concatenation(this will simplify onnx export). Note that
        #       it's OK to concat & split zero-shaped tensors(see code below).
        #   when export jit model, for 1st chunk, we always feed
        #       cache(0, 0, 0, 0) since jit supports dynamic if-branch.
        #   >>> a = torch.ones((1, 2, 0, 4))
        #   >>> b = torch.ones((1, 2, 3, 4))
        #   >>> c = torch.cat((a, b), dim=2)
        #   >>> torch.equal(b, c)        # True
        #   >>> d = torch.split(a, 2, dim=-1)
        #   >>> torch.equal(d[0], d[1])  # True
        if cache.size(0) > 0:
            key_cache, value_cache = torch.split(cache,
                                                 cache.size(-1) // 2,
                                                 dim=-1)
            k = torch.cat([key_cache, k], dim=2)
            v = torch.cat([value_cache, v], dim=2)
        # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
        #   non-trivial to calculate `next_cache_start` here.
        new_cache = torch.cat((k, v), dim=-1)

        n_batch_pos = pos_emb.size(0)
        p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
        p = p.transpose(1, 2)  # (batch, head, time1, d_k)

        # (batch, head, time1, d_k)
        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
        # (batch, head, time1, d_k)
        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)

        # compute attention score
        # first compute matrix a and matrix c
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        # (batch, head, time1, time2)
        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))

        # compute matrix b and matrix d
        # (batch, head, time1, time2)
        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
        # NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used
        if matrix_ac.shape != matrix_bd.shape:
            matrix_bd = self.rel_shift(matrix_bd)

        scores = (matrix_ac + matrix_bd) / math.sqrt(
            self.d_k)  # (batch, head, time1, time2)

        return self.forward_attention(v, scores, mask), new_cache


class PositionwiseFeedForward(torch.nn.Module):
    """Positionwise feed forward layer.

    FeedForward is applied on each position of the sequence.
    The output dim is the same as the input dim.

    Args:
        idim (int): Input dimension.
        hidden_units (int): The number of hidden units.
        dropout_rate (float): Dropout rate.
        activation (torch.nn.Module): Activation function
    """

    def __init__(
        self,
        idim: int,
        hidden_units: int,
        dropout_rate: float,
        activation: torch.nn.Module = torch.nn.ReLU(),
    ):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = torch.nn.Linear(idim, hidden_units)
        self.activation = activation
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.w_2 = torch.nn.Linear(hidden_units, idim)

    def forward(self, xs: torch.Tensor) -> torch.Tensor:
        """Forward function.

        Args:
            xs: input tensor (B, L, D)
        Returns:
            output tensor, (B, L, D)
        """
        return self.w_2(self.dropout(self.activation(self.w_1(xs))))


class ConformerEncoderLayer(nn.Module):
    """Encoder layer module.
    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
            instance can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward` instance can be used as the argument.
        feed_forward_macaron (torch.nn.Module): Additional feed-forward module
            instance.
            `PositionwiseFeedForward` instance can be used as the argument.
        conv_module (torch.nn.Module): Convolution module instance.
            `ConvolutionModule` instance can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool):
            True: use layer_norm before each sub-block.
            False: use layer_norm after each sub-block.
    """

    def __init__(
        self,
        size: int,
        self_attn: torch.nn.Module,
        feed_forward: Optional[nn.Module] = None,
        feed_forward_macaron: Optional[nn.Module] = None,
        conv_module: Optional[nn.Module] = None,
        dropout_rate: float = 0.0,
        normalize_before: bool = True,
    ):
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.feed_forward_macaron = feed_forward_macaron
        self.conv_module = conv_module
        self.norm_ff = nn.LayerNorm(size, eps=1e-12)  # for the FNN module
        self.norm_mha = nn.LayerNorm(size, eps=1e-12)  # for the MHA module
        if feed_forward_macaron is not None:
            self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-12)
            self.ff_scale = 0.5
        else:
            self.ff_scale = 1.0
        if self.conv_module is not None:
            self.norm_conv = nn.LayerNorm(size, eps=1e-12)  # for the CNN module
            self.norm_final = nn.LayerNorm(
                size, eps=1e-12)  # for the final output of the block
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        pos_emb: torch.Tensor,
        mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
        cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute encoded features.

        Args:
            x (torch.Tensor): (#batch, time, size)
            mask (torch.Tensor): Mask tensor for the input (#batch, time, time),
                (0, 0, 0) means fake mask.
            pos_emb (torch.Tensor): positional encoding, must not be None
                for ConformerEncoderLayer.
            mask_pad (torch.Tensor): batch padding mask used for conv module.
                (#batch, 1, time), (0, 0, 0) means fake mask.
            att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
                (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
            cnn_cache (torch.Tensor): Convolution cache in conformer layer
                (#batch=1, size, cache_t2)
        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time, time).
            torch.Tensor: att_cache tensor,
                (#batch=1, head, cache_t1 + time, d_k * 2).
            torch.Tensor: cnn_cache tensor (#batch, size, cache_t2).
        """

        # whether to use macaron style
        if self.feed_forward_macaron is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_ff_macaron(x)
            x = residual + self.ff_scale * self.dropout(
                self.feed_forward_macaron(x))
            if not self.normalize_before:
                x = self.norm_ff_macaron(x)

        # multi-headed self-attention module
        residual = x
        if self.normalize_before:
            x = self.norm_mha(x)
        x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb,
                                              att_cache)
        x = residual + self.dropout(x_att)
        if not self.normalize_before:
            x = self.norm_mha(x)

        # convolution module
        # Fake new cnn cache here, and then change it in conv_module
        new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
        if self.conv_module is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_conv(x)
            x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
            x = residual + self.dropout(x)

            if not self.normalize_before:
                x = self.norm_conv(x)

        # feed forward module
        residual = x
        if self.normalize_before:
            x = self.norm_ff(x)

        x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm_ff(x)

        if self.conv_module is not None:
            x = self.norm_final(x)

        return x, mask, new_att_cache, new_cnn_cache


class UpsampleConformerEncoder(torch.nn.Module):
    """
    Args:
        input_size (int): input dim
        output_size (int): dimension of attention
        attention_heads (int): the number of heads of multi head attention
        linear_units (int): the hidden units number of position-wise feed
            forward
        num_blocks (int): the number of decoder blocks
        static_chunk_size (int): chunk size for static chunk training and
            decoding
        use_dynamic_chunk (bool): whether to use dynamic chunk size for
            training or not, you can only use fixed chunk(chunk_size > 0)
            or dynamic chunk size(use_dynamic_chunk = True)
        use_dynamic_left_chunk (bool): whether to use dynamic left chunk in
            dynamic chunk training
        key_bias: whether to use bias in attention.linear_k, False for whisper models.
    """

    def __init__(
        self,
        input_size: int = 512,
        output_size: int = 512,
        attention_heads: int = 8,
        linear_units: int = 2048,
        num_blocks: int = 6,
        static_chunk_size: int = 25,
        use_dynamic_chunk: bool = False,
        use_dynamic_left_chunk: bool = False,
        key_bias: bool = True,
    ):
        super().__init__()
        self._output_size = output_size

        self.embed = LinearNoSubsampling(
            input_size, output_size,
            EspnetRelPositionalEncoding(output_size),
        )

        self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
        self.static_chunk_size = static_chunk_size
        self.use_dynamic_chunk = use_dynamic_chunk
        self.use_dynamic_left_chunk = use_dynamic_left_chunk
        activation = torch.nn.SiLU()
        # self-attention module definition
        encoder_selfattn_layer_args = (
            attention_heads,
            output_size,
            0.0,
            key_bias,
        )
        # feed-forward module definition
        positionwise_layer_args = (
            output_size,
            linear_units,
            0.0,
            activation,
        )
        # convolution module definition
        self.pre_lookahead_layer = PreLookaheadLayer(channels=512, pre_lookahead_len=3)
        self.encoders = torch.nn.ModuleList([
            ConformerEncoderLayer(
                output_size,
                RelPositionMultiHeadedAttention(*encoder_selfattn_layer_args),
                PositionwiseFeedForward(*positionwise_layer_args),
            ) for _ in range(num_blocks)
        ])
        self.up_layer = Upsample1D(channels=512, out_channels=512, stride=2)
        self.up_embed = LinearNoSubsampling(
            input_size, output_size,
            EspnetRelPositionalEncoding(output_size),
        )
        self.up_encoders = torch.nn.ModuleList([
            ConformerEncoderLayer(
                output_size,
                RelPositionMultiHeadedAttention(*encoder_selfattn_layer_args),
                PositionwiseFeedForward(*positionwise_layer_args),
            ) for _ in range(4)
        ])

    def output_size(self) -> int:
        return self._output_size

    def forward(
        self,
        xs: torch.Tensor,
        xs_lens: torch.Tensor,
        context: torch.Tensor = torch.zeros(0, 0, 0),
        decoding_chunk_size: int = 0,
        num_decoding_left_chunks: int = -1,
        streaming: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Embed positions in tensor.

        Args:
            xs: padded input tensor (B, T, D)
            xs_lens: input length (B)
            decoding_chunk_size: decoding chunk size for dynamic chunk
                0: default for training, use random dynamic chunk.
                <0: for decoding, use full chunk.
                >0: for decoding, use fixed chunk size as set.
            num_decoding_left_chunks: number of left chunks, this is for decoding,
                the chunk size is decoding_chunk_size.
                >=0: use num_decoding_left_chunks
                <0: use all left chunks
        Returns:
            encoder output tensor xs, and subsampled masks
            xs: padded output tensor (B, T' ~= T/subsample_rate, D)
            masks: torch.Tensor batch padding mask after subsample
                (B, 1, T' ~= T/subsample_rate)
        NOTE(xcsong):
            We pass the `__call__` method of the modules instead of `forward` to the
            checkpointing API because `__call__` attaches all the hooks of the module.
            https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
        """
        T = xs.size(1)
        masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
        xs, pos_emb, masks = self.embed(xs, masks)
        if context.size(1) != 0:
            assert self.training is False, 'you have passed context, make sure that you are running inference mode'
            context_masks = torch.ones(1, 1, context.size(1)).to(masks)
            context, _, _ = self.embed(context, context_masks, offset=xs.size(1))
        mask_pad = masks  # (B, 1, T/subsample_rate)
        chunk_masks = add_optional_chunk_mask(xs, masks, False, False, 0, self.static_chunk_size if streaming is True else 0, -1)
        # lookahead + conformer encoder
        xs = self.pre_lookahead_layer(xs, context=context)
        xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)

        # upsample + conformer encoder
        xs = xs.transpose(1, 2).contiguous()
        xs, xs_lens = self.up_layer(xs, xs_lens)
        xs = xs.transpose(1, 2).contiguous()
        T = xs.size(1)
        masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
        xs, pos_emb, masks = self.up_embed(xs, masks)
        mask_pad = masks  # (B, 1, T/subsample_rate)
        chunk_masks = add_optional_chunk_mask(xs, masks, False, False, 0, self.static_chunk_size * self.up_layer.stride if streaming is True else 0, -1)
        xs = self.forward_up_layers(xs, chunk_masks, pos_emb, mask_pad)

        xs = self.after_norm(xs)
        # Here we assume the mask is not changed in encoder layers, so just
        # return the masks before encoder layers, and the masks will be used
        # for cross attention with decoder later
        return xs, masks

    def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
                       pos_emb: torch.Tensor,
                       mask_pad: torch.Tensor) -> torch.Tensor:
        for layer in self.encoders:
            xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
        return xs

    def forward_up_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
                          pos_emb: torch.Tensor,
                          mask_pad: torch.Tensor) -> torch.Tensor:
        for layer in self.up_encoders:
            xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
        return xs
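
# A minimal end-to-end sketch with the default hyper-parameters above (sizes
# are illustrative): the encoder keeps the 512-dim features, runs the
# pre-lookahead and conformer blocks, then Upsample1D doubles the time axis
# before the second stack, so 10 input frames come back as 20.
#   >>> encoder = UpsampleConformerEncoder().eval()
#   >>> with torch.no_grad():
#   ...     ys, masks = encoder(torch.randn(1, 10, 512), torch.tensor([10]))
#   >>> ys.shape, masks.shape
#   (torch.Size([1, 20, 512]), torch.Size([1, 1, 20]))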