minicpmo-utils 0.1.0__py3-none-any.whl
This diff shows the contents of a publicly available package version as released to one of the supported registries. It is provided for informational purposes only and reflects the package as it appears in its public registry.
- cosyvoice/__init__.py +17 -0
- cosyvoice/bin/average_model.py +93 -0
- cosyvoice/bin/export_jit.py +103 -0
- cosyvoice/bin/export_onnx.py +120 -0
- cosyvoice/bin/inference_deprecated.py +126 -0
- cosyvoice/bin/train.py +195 -0
- cosyvoice/cli/__init__.py +0 -0
- cosyvoice/cli/cosyvoice.py +209 -0
- cosyvoice/cli/frontend.py +238 -0
- cosyvoice/cli/model.py +386 -0
- cosyvoice/dataset/__init__.py +0 -0
- cosyvoice/dataset/dataset.py +151 -0
- cosyvoice/dataset/processor.py +434 -0
- cosyvoice/flow/decoder.py +494 -0
- cosyvoice/flow/flow.py +281 -0
- cosyvoice/flow/flow_matching.py +227 -0
- cosyvoice/flow/length_regulator.py +70 -0
- cosyvoice/hifigan/discriminator.py +230 -0
- cosyvoice/hifigan/f0_predictor.py +58 -0
- cosyvoice/hifigan/generator.py +582 -0
- cosyvoice/hifigan/hifigan.py +67 -0
- cosyvoice/llm/llm.py +610 -0
- cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
- cosyvoice/tokenizer/tokenizer.py +279 -0
- cosyvoice/transformer/__init__.py +0 -0
- cosyvoice/transformer/activation.py +84 -0
- cosyvoice/transformer/attention.py +330 -0
- cosyvoice/transformer/convolution.py +145 -0
- cosyvoice/transformer/decoder.py +396 -0
- cosyvoice/transformer/decoder_layer.py +132 -0
- cosyvoice/transformer/embedding.py +302 -0
- cosyvoice/transformer/encoder.py +474 -0
- cosyvoice/transformer/encoder_layer.py +236 -0
- cosyvoice/transformer/label_smoothing_loss.py +96 -0
- cosyvoice/transformer/positionwise_feed_forward.py +115 -0
- cosyvoice/transformer/subsampling.py +383 -0
- cosyvoice/transformer/upsample_encoder.py +320 -0
- cosyvoice/utils/__init__.py +0 -0
- cosyvoice/utils/class_utils.py +83 -0
- cosyvoice/utils/common.py +186 -0
- cosyvoice/utils/executor.py +176 -0
- cosyvoice/utils/file_utils.py +129 -0
- cosyvoice/utils/frontend_utils.py +136 -0
- cosyvoice/utils/losses.py +57 -0
- cosyvoice/utils/mask.py +265 -0
- cosyvoice/utils/scheduler.py +738 -0
- cosyvoice/utils/train_utils.py +367 -0
- cosyvoice/vllm/cosyvoice2.py +103 -0
- matcha/__init__.py +0 -0
- matcha/app.py +357 -0
- matcha/cli.py +418 -0
- matcha/hifigan/__init__.py +0 -0
- matcha/hifigan/config.py +28 -0
- matcha/hifigan/denoiser.py +64 -0
- matcha/hifigan/env.py +17 -0
- matcha/hifigan/meldataset.py +217 -0
- matcha/hifigan/models.py +368 -0
- matcha/hifigan/xutils.py +60 -0
- matcha/models/__init__.py +0 -0
- matcha/models/baselightningmodule.py +209 -0
- matcha/models/components/__init__.py +0 -0
- matcha/models/components/decoder.py +443 -0
- matcha/models/components/flow_matching.py +132 -0
- matcha/models/components/text_encoder.py +410 -0
- matcha/models/components/transformer.py +316 -0
- matcha/models/matcha_tts.py +239 -0
- matcha/onnx/__init__.py +0 -0
- matcha/onnx/export.py +181 -0
- matcha/onnx/infer.py +168 -0
- matcha/text/__init__.py +53 -0
- matcha/text/cleaners.py +116 -0
- matcha/text/numbers.py +71 -0
- matcha/text/symbols.py +17 -0
- matcha/train.py +122 -0
- matcha/utils/__init__.py +5 -0
- matcha/utils/audio.py +82 -0
- matcha/utils/generate_data_statistics.py +111 -0
- matcha/utils/instantiators.py +56 -0
- matcha/utils/logging_utils.py +53 -0
- matcha/utils/model.py +90 -0
- matcha/utils/monotonic_align/__init__.py +22 -0
- matcha/utils/monotonic_align/setup.py +7 -0
- matcha/utils/pylogger.py +21 -0
- matcha/utils/rich_utils.py +101 -0
- matcha/utils/utils.py +219 -0
- minicpmo/__init__.py +24 -0
- minicpmo/utils.py +636 -0
- minicpmo/version.py +2 -0
- minicpmo_utils-0.1.0.dist-info/METADATA +72 -0
- minicpmo_utils-0.1.0.dist-info/RECORD +148 -0
- minicpmo_utils-0.1.0.dist-info/WHEEL +5 -0
- minicpmo_utils-0.1.0.dist-info/top_level.txt +5 -0
- s3tokenizer/__init__.py +153 -0
- s3tokenizer/assets/BAC009S0764W0121.wav +0 -0
- s3tokenizer/assets/BAC009S0764W0122.wav +0 -0
- s3tokenizer/assets/mel_filters.npz +0 -0
- s3tokenizer/cli.py +183 -0
- s3tokenizer/model.py +546 -0
- s3tokenizer/model_v2.py +605 -0
- s3tokenizer/utils.py +390 -0
- stepaudio2/__init__.py +40 -0
- stepaudio2/cosyvoice2/__init__.py +1 -0
- stepaudio2/cosyvoice2/flow/__init__.py +0 -0
- stepaudio2/cosyvoice2/flow/decoder_dit.py +585 -0
- stepaudio2/cosyvoice2/flow/flow.py +230 -0
- stepaudio2/cosyvoice2/flow/flow_matching.py +205 -0
- stepaudio2/cosyvoice2/transformer/__init__.py +0 -0
- stepaudio2/cosyvoice2/transformer/attention.py +328 -0
- stepaudio2/cosyvoice2/transformer/embedding.py +119 -0
- stepaudio2/cosyvoice2/transformer/encoder_layer.py +163 -0
- stepaudio2/cosyvoice2/transformer/positionwise_feed_forward.py +56 -0
- stepaudio2/cosyvoice2/transformer/subsampling.py +79 -0
- stepaudio2/cosyvoice2/transformer/upsample_encoder_v2.py +483 -0
- stepaudio2/cosyvoice2/utils/__init__.py +1 -0
- stepaudio2/cosyvoice2/utils/class_utils.py +41 -0
- stepaudio2/cosyvoice2/utils/common.py +101 -0
- stepaudio2/cosyvoice2/utils/mask.py +49 -0
- stepaudio2/flashcosyvoice/__init__.py +0 -0
- stepaudio2/flashcosyvoice/cli.py +424 -0
- stepaudio2/flashcosyvoice/config.py +80 -0
- stepaudio2/flashcosyvoice/cosyvoice2.py +160 -0
- stepaudio2/flashcosyvoice/cosyvoice3.py +1 -0
- stepaudio2/flashcosyvoice/engine/__init__.py +0 -0
- stepaudio2/flashcosyvoice/engine/block_manager.py +114 -0
- stepaudio2/flashcosyvoice/engine/llm_engine.py +125 -0
- stepaudio2/flashcosyvoice/engine/model_runner.py +310 -0
- stepaudio2/flashcosyvoice/engine/scheduler.py +77 -0
- stepaudio2/flashcosyvoice/engine/sequence.py +90 -0
- stepaudio2/flashcosyvoice/modules/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/flow.py +198 -0
- stepaudio2/flashcosyvoice/modules/flow_components/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/flow_components/estimator.py +974 -0
- stepaudio2/flashcosyvoice/modules/flow_components/upsample_encoder.py +998 -0
- stepaudio2/flashcosyvoice/modules/hifigan.py +249 -0
- stepaudio2/flashcosyvoice/modules/hifigan_components/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/hifigan_components/layers.py +433 -0
- stepaudio2/flashcosyvoice/modules/qwen2.py +92 -0
- stepaudio2/flashcosyvoice/modules/qwen2_components/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/qwen2_components/layers.py +616 -0
- stepaudio2/flashcosyvoice/modules/sampler.py +231 -0
- stepaudio2/flashcosyvoice/utils/__init__.py +0 -0
- stepaudio2/flashcosyvoice/utils/audio.py +77 -0
- stepaudio2/flashcosyvoice/utils/context.py +28 -0
- stepaudio2/flashcosyvoice/utils/loader.py +116 -0
- stepaudio2/flashcosyvoice/utils/memory.py +19 -0
- stepaudio2/stepaudio2.py +204 -0
- stepaudio2/token2wav.py +248 -0
- stepaudio2/utils.py +91 -0
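The wheel installs five top-level packages (per top_level.txt): cosyvoice, matcha, minicpmo, s3tokenizer and stepaudio2. A minimal post-install smoke test could look like the sketch below; it assumes only the module names listed above, and treats a __version__ attribute on minicpmo as an assumption (minicpmo/version.py is present, but its contents are not shown in this diff).

# Hypothetical smoke test after installing the wheel (a sketch, not part of the package).
import importlib

for name in ("cosyvoice", "matcha", "minicpmo", "s3tokenizer", "stepaudio2"):
    module = importlib.import_module(name)  # raises ImportError if the install is broken
    print(name, "->", module.__file__)

# minicpmo/version.py exists in the wheel; exposing __version__ is assumed, not confirmed.
import minicpmo
print(getattr(minicpmo, "__version__", "unknown"))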
cosyvoice/transformer/subsampling.py
@@ -0,0 +1,383 @@
# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
#               2024 Alibaba Inc (Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Subsampling layer definition."""

from typing import Tuple, Union

import torch


class BaseSubsampling(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.right_context = 0
        self.subsampling_rate = 1

    def position_encoding(self, offset: Union[int, torch.Tensor],
                          size: int) -> torch.Tensor:
        return self.pos_enc.position_encoding(offset, size)


class EmbedinigNoSubsampling(BaseSubsampling):
    """Embedding input without subsampling
    """

    def __init__(self, idim: int, odim: int, dropout_rate: float,
                 pos_enc_class: torch.nn.Module):
        super().__init__()
        self.embed = torch.nn.Embedding(idim, odim)
        self.pos_enc = pos_enc_class

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        offset: Union[int, torch.Tensor] = 0
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Input x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).

        Returns:
            torch.Tensor: linear input tensor (#batch, time', odim),
                where time' = time .
            torch.Tensor: linear input mask (#batch, 1, time'),
                where time' = time .

        """
        x = self.embed(x)
        x, pos_emb = self.pos_enc(x, offset)
        return x, pos_emb, x_mask


class LinearNoSubsampling(BaseSubsampling):
    """Linear transform the input without subsampling

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.

    """

    def __init__(self, idim: int, odim: int, dropout_rate: float,
                 pos_enc_class: torch.nn.Module):
        """Construct an linear object."""
        super().__init__()
        self.out = torch.nn.Sequential(
            torch.nn.Linear(idim, odim),
            torch.nn.LayerNorm(odim, eps=1e-5),
            torch.nn.Dropout(dropout_rate),
        )
        self.pos_enc = pos_enc_class
        self.right_context = 0
        self.subsampling_rate = 1

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        offset: Union[int, torch.Tensor] = 0
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Input x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).

        Returns:
            torch.Tensor: linear input tensor (#batch, time', odim),
                where time' = time .
            torch.Tensor: linear input mask (#batch, 1, time'),
                where time' = time .

        """
        x = self.out(x)
        x, pos_emb = self.pos_enc(x, offset)
        return x, pos_emb, x_mask


class Conv1dSubsampling2(BaseSubsampling):
    """Convolutional 1D subsampling (to 1/2 length).
    It is designed for Whisper, ref:
    https://github.com/openai/whisper/blob/main/whisper/model.py

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.

    """

    def __init__(self, idim: int, odim: int, dropout_rate: float,
                 pos_enc_class: torch.nn.Module):
        """Construct an Conv1dSubsampling2 object."""
        super().__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv1d(idim, odim, kernel_size=3, padding=1),
            torch.nn.GELU(),
            torch.nn.Conv1d(odim, odim, kernel_size=3, stride=2, padding=1),
            torch.nn.GELU(),
        )
        self.pos_enc = pos_enc_class
        # The right context for every conv layer is computed by:
        # (kernel_size - 1) * frame_rate_of_this_layer
        self.subsampling_rate = 2
        # 4 = (3 - 1) * 1 + (3 - 1) * 1
        self.right_context = 4

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        offset: Union[int, torch.Tensor] = 0
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = time // 2.
            torch.Tensor: Subsampled mask (#batch, 1, time'),
                where time' = time // 2.
            torch.Tensor: positional encoding

        """
        time = x.size(1)
        x = x.transpose(1, 2)  # (b, f, t)
        x = self.conv(x)
        x = x.transpose(1, 2)  # (b, t, f)
        x, pos_emb = self.pos_enc(x, offset)
        return x, pos_emb, x_mask[:, :, (time + 1) % 2::2]


class Conv2dSubsampling4(BaseSubsampling):
    """Convolutional 2D subsampling (to 1/4 length).

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.

    """

    def __init__(self, idim: int, odim: int, dropout_rate: float,
                 pos_enc_class: torch.nn.Module):
        """Construct an Conv2dSubsampling4 object."""
        super().__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 2),
            torch.nn.ReLU(),
        )
        self.out = torch.nn.Sequential(
            torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim))
        self.pos_enc = pos_enc_class
        # The right context for every conv layer is computed by:
        # (kernel_size - 1) * frame_rate_of_this_layer
        self.subsampling_rate = 4
        # 6 = (3 - 1) * 1 + (3 - 1) * 2
        self.right_context = 6

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        offset: Union[int, torch.Tensor] = 0
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = time // 4.
            torch.Tensor: Subsampled mask (#batch, 1, time'),
                where time' = time // 4.
            torch.Tensor: positional encoding

        """
        x = x.unsqueeze(1)  # (b, c=1, t, f)
        x = self.conv(x)
        b, c, t, f = x.size()
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        x, pos_emb = self.pos_enc(x, offset)
        return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2]


class Conv2dSubsampling6(BaseSubsampling):
    """Convolutional 2D subsampling (to 1/6 length).
    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.
        pos_enc (torch.nn.Module): Custom position encoding layer.
    """

    def __init__(self, idim: int, odim: int, dropout_rate: float,
                 pos_enc_class: torch.nn.Module):
        """Construct an Conv2dSubsampling6 object."""
        super().__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 5, 3),
            torch.nn.ReLU(),
        )
        self.linear = torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3),
                                      odim)
        self.pos_enc = pos_enc_class
        # 10 = (3 - 1) * 1 + (5 - 1) * 2
        self.subsampling_rate = 6
        self.right_context = 10

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        offset: Union[int, torch.Tensor] = 0
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Subsample x.
        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = time // 6.
            torch.Tensor: Subsampled mask (#batch, 1, time'),
                where time' = time // 6.
            torch.Tensor: positional encoding
        """
        x = x.unsqueeze(1)  # (b, c, t, f)
        x = self.conv(x)
        b, c, t, f = x.size()
        x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
        x, pos_emb = self.pos_enc(x, offset)
        return x, pos_emb, x_mask[:, :, 2::2][:, :, 4::3]


class Conv2dSubsampling8(BaseSubsampling):
    """Convolutional 2D subsampling (to 1/8 length).

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.

    """

    def __init__(self, idim: int, odim: int, dropout_rate: float,
                 pos_enc_class: torch.nn.Module):
        """Construct an Conv2dSubsampling8 object."""
        super().__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 2),
            torch.nn.ReLU(),
        )
        self.linear = torch.nn.Linear(
            odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim)
        self.pos_enc = pos_enc_class
        self.subsampling_rate = 8
        # 14 = (3 - 1) * 1 + (3 - 1) * 2 + (3 - 1) * 4
        self.right_context = 14

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        offset: Union[int, torch.Tensor] = 0
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = time // 8.
            torch.Tensor: Subsampled mask (#batch, 1, time'),
                where time' = time // 8.
            torch.Tensor: positional encoding
        """
        x = x.unsqueeze(1)  # (b, c, t, f)
        x = self.conv(x)
        b, c, t, f = x.size()
        x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
        x, pos_emb = self.pos_enc(x, offset)
        return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2][:, :, 2::2]


class LegacyLinearNoSubsampling(BaseSubsampling):
    """Linear transform the input without subsampling

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.

    """

    def __init__(self, idim: int, odim: int, dropout_rate: float,
                 pos_enc_class: torch.nn.Module):
        """Construct an linear object."""
        super().__init__()
        self.out = torch.nn.Sequential(
            torch.nn.Linear(idim, odim),
            torch.nn.LayerNorm(odim, eps=1e-5),
            torch.nn.Dropout(dropout_rate),
            torch.nn.ReLU(),
        )
        self.pos_enc = pos_enc_class
        self.right_context = 0
        self.subsampling_rate = 1

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        offset: Union[int, torch.Tensor] = 0
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Input x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).

        Returns:
            torch.Tensor: linear input tensor (#batch, time', odim),
                where time' = time .
            torch.Tensor: linear input mask (#batch, 1, time'),
                where time' = time .

        """
        x = self.out(x)
        x, pos_emb = self.pos_enc(x, offset)
        return x, pos_emb, x_mask
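The subsampling classes above share one contract: forward(x, x_mask, offset) returns the transformed features, a positional-embedding tensor produced by pos_enc, and a correspondingly sliced mask. A quick shape check for Conv2dSubsampling4 is sketched below; DummyPosEnc is a hypothetical stand-in for the positional-encoding classes in cosyvoice/transformer/embedding.py, used only to keep the example self-contained.

# Shape sanity-check for Conv2dSubsampling4 (a sketch, not part of the package).
import torch
from cosyvoice.transformer.subsampling import Conv2dSubsampling4


class DummyPosEnc(torch.nn.Module):
    # Hypothetical stand-in: the real classes return (dropped-out x, pos_emb);
    # here we just return zeros of a matching shape.
    def forward(self, x: torch.Tensor, offset: int = 0):
        return x, torch.zeros(1, x.size(1), x.size(2))


sub = Conv2dSubsampling4(idim=80, odim=256, dropout_rate=0.1,
                         pos_enc_class=DummyPosEnc())
x = torch.randn(2, 100, 80)                      # (batch, time, feat)
mask = torch.ones(2, 1, 100, dtype=torch.bool)   # (batch, 1, time)
y, pos_emb, y_mask = sub(x, mask)
print(y.shape, y_mask.shape)                     # (2, 24, 256) and (2, 1, 24): time // 4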
cosyvoice/transformer/upsample_encoder.py
@@ -0,0 +1,320 @@
# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
#               2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
#               2024 Alibaba Inc (Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Encoder definition."""
from typing import Tuple

import torch
from torch import nn
from torch.nn import functional as F

from cosyvoice.transformer.convolution import ConvolutionModule
from cosyvoice.transformer.encoder_layer import ConformerEncoderLayer
from cosyvoice.transformer.positionwise_feed_forward import PositionwiseFeedForward
from cosyvoice.utils.class_utils import (
    COSYVOICE_EMB_CLASSES,
    COSYVOICE_SUBSAMPLE_CLASSES,
    COSYVOICE_ATTENTION_CLASSES,
    COSYVOICE_ACTIVATION_CLASSES,
)
from cosyvoice.utils.mask import make_pad_mask
from cosyvoice.utils.mask import add_optional_chunk_mask


class Upsample1D(nn.Module):
    """A 1D upsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        use_conv_transpose (`bool`, default `False`):
            option to use a convolution transpose.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
    """

    def __init__(self, channels: int, out_channels: int, stride: int = 2):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels
        self.stride = stride
        # In this mode, first repeat interpolate, than conv with stride=1
        self.conv = nn.Conv1d(self.channels, self.out_channels, stride * 2 + 1, stride=1, padding=0)

    def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        outputs = F.interpolate(inputs, scale_factor=float(self.stride), mode="nearest")
        outputs = F.pad(outputs, (self.stride * 2, 0), value=0.0)
        outputs = self.conv(outputs)
        return outputs, input_lengths * self.stride


class PreLookaheadLayer(nn.Module):
    def __init__(self, channels: int, pre_lookahead_len: int = 1):
        super().__init__()
        self.channels = channels
        self.pre_lookahead_len = pre_lookahead_len
        self.conv1 = nn.Conv1d(
            channels, channels,
            kernel_size=pre_lookahead_len + 1,
            stride=1, padding=0,
        )
        self.conv2 = nn.Conv1d(
            channels, channels,
            kernel_size=3, stride=1, padding=0,
        )

    def forward(self, inputs: torch.Tensor, context: torch.Tensor = torch.zeros(0, 0, 0)) -> torch.Tensor:
        """
        inputs: (batch_size, seq_len, channels)
        """
        outputs = inputs.transpose(1, 2).contiguous()
        context = context.transpose(1, 2).contiguous()
        # look ahead
        if context.size(2) == 0:
            outputs = F.pad(outputs, (0, self.pre_lookahead_len), mode='constant', value=0.0)
        else:
            assert self.training is False, 'you have passed context, make sure that you are running inference mode'
            assert context.size(2) == self.pre_lookahead_len
            outputs = F.pad(torch.concat([outputs, context], dim=2), (0, self.pre_lookahead_len - context.size(2)), mode='constant', value=0.0)
        outputs = F.leaky_relu(self.conv1(outputs))
        # outputs
        outputs = F.pad(outputs, (self.conv2.kernel_size[0] - 1, 0), mode='constant', value=0.0)
        outputs = self.conv2(outputs)
        outputs = outputs.transpose(1, 2).contiguous()

        # residual connection
        outputs = outputs + inputs
        return outputs


class UpsampleConformerEncoder(torch.nn.Module):

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: str = "conv2d",
        pos_enc_layer_type: str = "rel_pos",
        normalize_before: bool = True,
        static_chunk_size: int = 0,
        use_dynamic_chunk: bool = False,
        global_cmvn: torch.nn.Module = None,
        use_dynamic_left_chunk: bool = False,
        positionwise_conv_kernel_size: int = 1,
        macaron_style: bool = True,
        selfattention_layer_type: str = "rel_selfattn",
        activation_type: str = "swish",
        use_cnn_module: bool = True,
        cnn_module_kernel: int = 15,
        causal: bool = False,
        cnn_module_norm: str = "batch_norm",
        key_bias: bool = True,
        gradient_checkpointing: bool = False,
    ):
        """
        Args:
            input_size (int): input dim
            output_size (int): dimension of attention
            attention_heads (int): the number of heads of multi head attention
            linear_units (int): the hidden units number of position-wise feed
                forward
            num_blocks (int): the number of decoder blocks
            dropout_rate (float): dropout rate
            attention_dropout_rate (float): dropout rate in attention
            positional_dropout_rate (float): dropout rate after adding
                positional encoding
            input_layer (str): input layer type.
                optional [linear, conv2d, conv2d6, conv2d8]
            pos_enc_layer_type (str): Encoder positional encoding layer type.
                opitonal [abs_pos, scaled_abs_pos, rel_pos, no_pos]
            normalize_before (bool):
                True: use layer_norm before each sub-block of a layer.
                False: use layer_norm after each sub-block of a layer.
            static_chunk_size (int): chunk size for static chunk training and
                decoding
            use_dynamic_chunk (bool): whether use dynamic chunk size for
                training or not, You can only use fixed chunk(chunk_size > 0)
                or dyanmic chunk size(use_dynamic_chunk = True)
            global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module
            use_dynamic_left_chunk (bool): whether use dynamic left chunk in
                dynamic chunk training
            key_bias: whether use bias in attention.linear_k, False for whisper models.
            gradient_checkpointing: rerunning a forward-pass segment for each
                checkpointed segment during backward.
        """
        super().__init__()
        self._output_size = output_size

        self.global_cmvn = global_cmvn
        self.embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
            input_size,
            output_size,
            dropout_rate,
            COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size,
                                                      positional_dropout_rate),
        )

        self.normalize_before = normalize_before
        self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
        self.static_chunk_size = static_chunk_size
        self.use_dynamic_chunk = use_dynamic_chunk
        self.use_dynamic_left_chunk = use_dynamic_left_chunk
        self.gradient_checkpointing = gradient_checkpointing
        activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
        # self-attention module definition
        encoder_selfattn_layer_args = (
            attention_heads,
            output_size,
            attention_dropout_rate,
            key_bias,
        )
        # feed-forward module definition
        positionwise_layer_args = (
            output_size,
            linear_units,
            dropout_rate,
            activation,
        )
        # convolution module definition
        convolution_layer_args = (output_size, cnn_module_kernel, activation,
                                  cnn_module_norm, causal)
        self.pre_lookahead_layer = PreLookaheadLayer(channels=512, pre_lookahead_len=3)
        self.encoders = torch.nn.ModuleList([
            ConformerEncoderLayer(
                output_size,
                COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
                    *encoder_selfattn_layer_args),
                PositionwiseFeedForward(*positionwise_layer_args),
                PositionwiseFeedForward(
                    *positionwise_layer_args) if macaron_style else None,
                ConvolutionModule(
                    *convolution_layer_args) if use_cnn_module else None,
                dropout_rate,
                normalize_before,
            ) for _ in range(num_blocks)
        ])
        self.up_layer = Upsample1D(channels=512, out_channels=512, stride=2)
        self.up_embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
            input_size,
            output_size,
            dropout_rate,
            COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size,
                                                      positional_dropout_rate),
        )
        self.up_encoders = torch.nn.ModuleList([
            ConformerEncoderLayer(
                output_size,
                COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
                    *encoder_selfattn_layer_args),
                PositionwiseFeedForward(*positionwise_layer_args),
                PositionwiseFeedForward(
                    *positionwise_layer_args) if macaron_style else None,
                ConvolutionModule(
                    *convolution_layer_args) if use_cnn_module else None,
                dropout_rate,
                normalize_before,
            ) for _ in range(4)
        ])

    def output_size(self) -> int:
        return self._output_size

    def forward(
        self,
        xs: torch.Tensor,
        xs_lens: torch.Tensor,
        context: torch.Tensor = torch.zeros(0, 0, 0),
        decoding_chunk_size: int = 0,
        num_decoding_left_chunks: int = -1,
        streaming: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Embed positions in tensor.

        Args:
            xs: padded input tensor (B, T, D)
            xs_lens: input length (B)
            decoding_chunk_size: decoding chunk size for dynamic chunk
                0: default for training, use random dynamic chunk.
                <0: for decoding, use full chunk.
                >0: for decoding, use fixed chunk size as set.
            num_decoding_left_chunks: number of left chunks, this is for decoding,
                the chunk size is decoding_chunk_size.
                >=0: use num_decoding_left_chunks
                <0: use all left chunks
        Returns:
            encoder output tensor xs, and subsampled masks
            xs: padded output tensor (B, T' ~= T/subsample_rate, D)
            masks: torch.Tensor batch padding mask after subsample
                (B, 1, T' ~= T/subsample_rate)
        NOTE(xcsong):
            We pass the `__call__` method of the modules instead of `forward` to the
            checkpointing API because `__call__` attaches all the hooks of the module.
            https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
        """
        T = xs.size(1)
        masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
        if self.global_cmvn is not None:
            xs = self.global_cmvn(xs)
        xs, pos_emb, masks = self.embed(xs, masks)
        if context.size(1) != 0:
            assert self.training is False, 'you have passed context, make sure that you are running inference mode'
            context_masks = torch.ones(1, 1, context.size(1)).to(masks)
            context, _, _ = self.embed(context, context_masks, offset=xs.size(1))
        mask_pad = masks  # (B, 1, T/subsample_rate)
        chunk_masks = add_optional_chunk_mask(xs, masks, False, False, 0, self.static_chunk_size if streaming is True else 0, -1)
        # lookahead + conformer encoder
        xs = self.pre_lookahead_layer(xs, context=context)
        xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)

        # upsample + conformer encoder
        xs = xs.transpose(1, 2).contiguous()
        xs, xs_lens = self.up_layer(xs, xs_lens)
        xs = xs.transpose(1, 2).contiguous()
        T = xs.size(1)
        masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
        xs, pos_emb, masks = self.up_embed(xs, masks)
        mask_pad = masks  # (B, 1, T/subsample_rate)
        chunk_masks = add_optional_chunk_mask(xs, masks, False, False, 0, self.static_chunk_size * self.up_layer.stride if streaming is True else 0, -1)
        xs = self.forward_up_layers(xs, chunk_masks, pos_emb, mask_pad)

        if self.normalize_before:
            xs = self.after_norm(xs)
        # Here we assume the mask is not changed in encoder layers, so just
        # return the masks before encoder layers, and the masks will be used
        # for cross attention with decoder later
        return xs, masks

    def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
                       pos_emb: torch.Tensor,
                       mask_pad: torch.Tensor) -> torch.Tensor:
        for layer in self.encoders:
            xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
        return xs

    def forward_up_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
                          pos_emb: torch.Tensor,
                          mask_pad: torch.Tensor) -> torch.Tensor:
        for layer in self.up_encoders:
            xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
        return xs
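In UpsampleConformerEncoder.forward above, the features are transposed to (batch, channels, time) before self.up_layer, which doubles both the frame count and xs_lens. The sketch below exercises Upsample1D on its own to show that bookkeeping; the tensor sizes are illustrative, only the stride=2 arithmetic comes from the code above.

# Stand-alone check of the Upsample1D length bookkeeping (a sketch).
import torch
from cosyvoice.transformer.upsample_encoder import Upsample1D

up = Upsample1D(channels=512, out_channels=512, stride=2)
feats = torch.randn(2, 512, 50)        # (batch, channels, time), as passed inside forward()
lengths = torch.tensor([50, 37])
out, out_lengths = up(feats, lengths)
# nearest-neighbour x2 interpolation, a left pad of stride*2, then a
# kernel-size stride*2+1 Conv1d: the output length is exactly 2 * T.
print(out.shape)       # torch.Size([2, 512, 100])
print(out_lengths)     # tensor([100, 74])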