minicpmo-utils 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cosyvoice/__init__.py +17 -0
- cosyvoice/bin/average_model.py +93 -0
- cosyvoice/bin/export_jit.py +103 -0
- cosyvoice/bin/export_onnx.py +120 -0
- cosyvoice/bin/inference_deprecated.py +126 -0
- cosyvoice/bin/train.py +195 -0
- cosyvoice/cli/__init__.py +0 -0
- cosyvoice/cli/cosyvoice.py +209 -0
- cosyvoice/cli/frontend.py +238 -0
- cosyvoice/cli/model.py +386 -0
- cosyvoice/dataset/__init__.py +0 -0
- cosyvoice/dataset/dataset.py +151 -0
- cosyvoice/dataset/processor.py +434 -0
- cosyvoice/flow/decoder.py +494 -0
- cosyvoice/flow/flow.py +281 -0
- cosyvoice/flow/flow_matching.py +227 -0
- cosyvoice/flow/length_regulator.py +70 -0
- cosyvoice/hifigan/discriminator.py +230 -0
- cosyvoice/hifigan/f0_predictor.py +58 -0
- cosyvoice/hifigan/generator.py +582 -0
- cosyvoice/hifigan/hifigan.py +67 -0
- cosyvoice/llm/llm.py +610 -0
- cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
- cosyvoice/tokenizer/tokenizer.py +279 -0
- cosyvoice/transformer/__init__.py +0 -0
- cosyvoice/transformer/activation.py +84 -0
- cosyvoice/transformer/attention.py +330 -0
- cosyvoice/transformer/convolution.py +145 -0
- cosyvoice/transformer/decoder.py +396 -0
- cosyvoice/transformer/decoder_layer.py +132 -0
- cosyvoice/transformer/embedding.py +302 -0
- cosyvoice/transformer/encoder.py +474 -0
- cosyvoice/transformer/encoder_layer.py +236 -0
- cosyvoice/transformer/label_smoothing_loss.py +96 -0
- cosyvoice/transformer/positionwise_feed_forward.py +115 -0
- cosyvoice/transformer/subsampling.py +383 -0
- cosyvoice/transformer/upsample_encoder.py +320 -0
- cosyvoice/utils/__init__.py +0 -0
- cosyvoice/utils/class_utils.py +83 -0
- cosyvoice/utils/common.py +186 -0
- cosyvoice/utils/executor.py +176 -0
- cosyvoice/utils/file_utils.py +129 -0
- cosyvoice/utils/frontend_utils.py +136 -0
- cosyvoice/utils/losses.py +57 -0
- cosyvoice/utils/mask.py +265 -0
- cosyvoice/utils/scheduler.py +738 -0
- cosyvoice/utils/train_utils.py +367 -0
- cosyvoice/vllm/cosyvoice2.py +103 -0
- matcha/__init__.py +0 -0
- matcha/app.py +357 -0
- matcha/cli.py +418 -0
- matcha/hifigan/__init__.py +0 -0
- matcha/hifigan/config.py +28 -0
- matcha/hifigan/denoiser.py +64 -0
- matcha/hifigan/env.py +17 -0
- matcha/hifigan/meldataset.py +217 -0
- matcha/hifigan/models.py +368 -0
- matcha/hifigan/xutils.py +60 -0
- matcha/models/__init__.py +0 -0
- matcha/models/baselightningmodule.py +209 -0
- matcha/models/components/__init__.py +0 -0
- matcha/models/components/decoder.py +443 -0
- matcha/models/components/flow_matching.py +132 -0
- matcha/models/components/text_encoder.py +410 -0
- matcha/models/components/transformer.py +316 -0
- matcha/models/matcha_tts.py +239 -0
- matcha/onnx/__init__.py +0 -0
- matcha/onnx/export.py +181 -0
- matcha/onnx/infer.py +168 -0
- matcha/text/__init__.py +53 -0
- matcha/text/cleaners.py +116 -0
- matcha/text/numbers.py +71 -0
- matcha/text/symbols.py +17 -0
- matcha/train.py +122 -0
- matcha/utils/__init__.py +5 -0
- matcha/utils/audio.py +82 -0
- matcha/utils/generate_data_statistics.py +111 -0
- matcha/utils/instantiators.py +56 -0
- matcha/utils/logging_utils.py +53 -0
- matcha/utils/model.py +90 -0
- matcha/utils/monotonic_align/__init__.py +22 -0
- matcha/utils/monotonic_align/setup.py +7 -0
- matcha/utils/pylogger.py +21 -0
- matcha/utils/rich_utils.py +101 -0
- matcha/utils/utils.py +219 -0
- minicpmo/__init__.py +24 -0
- minicpmo/utils.py +636 -0
- minicpmo/version.py +2 -0
- minicpmo_utils-0.1.0.dist-info/METADATA +72 -0
- minicpmo_utils-0.1.0.dist-info/RECORD +148 -0
- minicpmo_utils-0.1.0.dist-info/WHEEL +5 -0
- minicpmo_utils-0.1.0.dist-info/top_level.txt +5 -0
- s3tokenizer/__init__.py +153 -0
- s3tokenizer/assets/BAC009S0764W0121.wav +0 -0
- s3tokenizer/assets/BAC009S0764W0122.wav +0 -0
- s3tokenizer/assets/mel_filters.npz +0 -0
- s3tokenizer/cli.py +183 -0
- s3tokenizer/model.py +546 -0
- s3tokenizer/model_v2.py +605 -0
- s3tokenizer/utils.py +390 -0
- stepaudio2/__init__.py +40 -0
- stepaudio2/cosyvoice2/__init__.py +1 -0
- stepaudio2/cosyvoice2/flow/__init__.py +0 -0
- stepaudio2/cosyvoice2/flow/decoder_dit.py +585 -0
- stepaudio2/cosyvoice2/flow/flow.py +230 -0
- stepaudio2/cosyvoice2/flow/flow_matching.py +205 -0
- stepaudio2/cosyvoice2/transformer/__init__.py +0 -0
- stepaudio2/cosyvoice2/transformer/attention.py +328 -0
- stepaudio2/cosyvoice2/transformer/embedding.py +119 -0
- stepaudio2/cosyvoice2/transformer/encoder_layer.py +163 -0
- stepaudio2/cosyvoice2/transformer/positionwise_feed_forward.py +56 -0
- stepaudio2/cosyvoice2/transformer/subsampling.py +79 -0
- stepaudio2/cosyvoice2/transformer/upsample_encoder_v2.py +483 -0
- stepaudio2/cosyvoice2/utils/__init__.py +1 -0
- stepaudio2/cosyvoice2/utils/class_utils.py +41 -0
- stepaudio2/cosyvoice2/utils/common.py +101 -0
- stepaudio2/cosyvoice2/utils/mask.py +49 -0
- stepaudio2/flashcosyvoice/__init__.py +0 -0
- stepaudio2/flashcosyvoice/cli.py +424 -0
- stepaudio2/flashcosyvoice/config.py +80 -0
- stepaudio2/flashcosyvoice/cosyvoice2.py +160 -0
- stepaudio2/flashcosyvoice/cosyvoice3.py +1 -0
- stepaudio2/flashcosyvoice/engine/__init__.py +0 -0
- stepaudio2/flashcosyvoice/engine/block_manager.py +114 -0
- stepaudio2/flashcosyvoice/engine/llm_engine.py +125 -0
- stepaudio2/flashcosyvoice/engine/model_runner.py +310 -0
- stepaudio2/flashcosyvoice/engine/scheduler.py +77 -0
- stepaudio2/flashcosyvoice/engine/sequence.py +90 -0
- stepaudio2/flashcosyvoice/modules/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/flow.py +198 -0
- stepaudio2/flashcosyvoice/modules/flow_components/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/flow_components/estimator.py +974 -0
- stepaudio2/flashcosyvoice/modules/flow_components/upsample_encoder.py +998 -0
- stepaudio2/flashcosyvoice/modules/hifigan.py +249 -0
- stepaudio2/flashcosyvoice/modules/hifigan_components/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/hifigan_components/layers.py +433 -0
- stepaudio2/flashcosyvoice/modules/qwen2.py +92 -0
- stepaudio2/flashcosyvoice/modules/qwen2_components/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/qwen2_components/layers.py +616 -0
- stepaudio2/flashcosyvoice/modules/sampler.py +231 -0
- stepaudio2/flashcosyvoice/utils/__init__.py +0 -0
- stepaudio2/flashcosyvoice/utils/audio.py +77 -0
- stepaudio2/flashcosyvoice/utils/context.py +28 -0
- stepaudio2/flashcosyvoice/utils/loader.py +116 -0
- stepaudio2/flashcosyvoice/utils/memory.py +19 -0
- stepaudio2/stepaudio2.py +204 -0
- stepaudio2/token2wav.py +248 -0
- stepaudio2/utils.py +91 -0
|
@@ -0,0 +1,328 @@
|
|
|
1
|
+
# Copyright (c) 2019 Shigeki Karita
|
|
2
|
+
# 2020 Mobvoi Inc (Binbin Zhang)
|
|
3
|
+
# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
|
|
4
|
+
# 2024 Alibaba Inc (Xiang Lyu)
|
|
5
|
+
#
|
|
6
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
7
|
+
# you may not use this file except in compliance with the License.
|
|
8
|
+
# You may obtain a copy of the License at
|
|
9
|
+
#
|
|
10
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
11
|
+
#
|
|
12
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
13
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
14
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
15
|
+
# See the License for the specific language governing permissions and
|
|
16
|
+
# limitations under the License.
|
|
17
|
+
"""Multi-Head Attention layer definition."""
|
|
18
|
+
|
|
19
|
+
import math
|
|
20
|
+
from typing import Tuple
|
|
21
|
+
|
|
22
|
+
import torch
|
|
23
|
+
from torch import nn
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class MultiHeadedAttention(nn.Module):
    """Multi-head scaled dot-product attention with an optional KV cache.

    Args:
        n_head (int): Number of attention heads.
        n_feat (int): Model width; must be divisible by ``n_head``.
        dropout_rate (float): Dropout probability applied to the attention
            weights.
        key_bias (bool): Whether the key projection ``linear_k`` has a bias.
    """

    def __init__(self,
                 n_head: int,
                 n_feat: int,
                 dropout_rate: float,
                 key_bias: bool = True):
        """Create the q/k/v/output projections and the attention dropout."""
        super().__init__()
        assert n_feat % n_head == 0
        # Width of a single head; values use the same width as keys.
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_q = nn.Linear(n_feat, n_feat)
        self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
        self.linear_v = nn.Linear(n_feat, n_feat)
        self.linear_out = nn.Linear(n_feat, n_feat)
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward_qkv(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Project query, key and value and split each into heads.

        Args:
            query (torch.Tensor): (#batch, time1, size).
            key (torch.Tensor): (#batch, time2, size).
            value (torch.Tensor): (#batch, time2, size).

        Returns:
            Tuple of three tensors shaped (#batch, n_head, time1, d_k),
            (#batch, n_head, time2, d_k) and (#batch, n_head, time2, d_k)
            for the projected query, key and value respectively.
        """
        batch = query.size(0)
        # view -> (batch, time, head, d_k); transpose -> (batch, head, time, d_k)
        query_h = self.linear_q(query).view(batch, -1, self.h,
                                            self.d_k).transpose(1, 2)
        key_h = self.linear_k(key).view(batch, -1, self.h,
                                        self.d_k).transpose(1, 2)
        value_h = self.linear_v(value).view(batch, -1, self.h,
                                            self.d_k).transpose(1, 2)
        return query_h, key_h, value_h

    def forward_attention(
        self,
        value: torch.Tensor,
        scores: torch.Tensor,
        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
    ) -> torch.Tensor:
        """Turn raw attention scores into the projected context vectors.

        Args:
            value (torch.Tensor): Per-head values (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention logits
                (#batch, n_head, time1, time2).
            mask (torch.Tensor): (#batch, 1, time2) or (#batch, time1, time2);
                zero/False entries are excluded from attention. A (0, 0, 0)
                tensor is a placeholder meaning "no mask".

        Returns:
            torch.Tensor: (#batch, time1, d_model) after ``linear_out``.
        """
        batch = value.size(0)
        # A non-empty mask arrives during PyTorch training and for the ONNX
        # 16/4 export mode; the empty placeholder is used by the other ONNX
        # and JIT streaming modes.
        if mask.size(2) > 0:
            # (batch, 1, *, time2); True marks keys that must be ignored. For
            # the last chunk the mask can be longer than the key axis, so it
            # is cropped to the number of scored keys.
            ignore = mask.unsqueeze(1).eq(0)[:, :, :, :scores.size(-1)]
            weights = torch.softmax(scores.masked_fill(ignore, -float('inf')),
                                    dim=-1)
            # Force masked positions to exactly zero; this also clears the
            # NaNs softmax yields for rows whose keys are all masked.
            weights = weights.masked_fill(ignore, 0.0)
        else:
            weights = torch.softmax(scores, dim=-1)

        # (batch, head, time1, d_k) -> (batch, time1, head * d_k)
        context = torch.matmul(self.dropout(weights), value)
        context = context.transpose(1, 2).contiguous().view(
            batch, -1, self.h * self.d_k)
        return self.linear_out(context)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        pos_emb: torch.Tensor = torch.empty(0),
        cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run scaled dot-product attention, extending the KV cache.

        Args:
            query (torch.Tensor): (#batch, time1, size).
            key (torch.Tensor): (#batch, time2, size).
            value (torch.Tensor): (#batch, time2, size).
            mask (torch.Tensor): (#batch, 1, time2) for a padding mask or
                (#batch, time1, time2) for a full mask; (0, 0, 0) means
                "no mask".
            pos_emb (torch.Tensor): Not used by this class; accepted so the
                signature matches ``RelPositionMultiHeadedAttention.forward``.
            cache (torch.Tensor): Previous keys and values concatenated on
                the last axis, documented as (1, head, cache_t, d_k * 2).
                An empty tensor (size(0) == 0) means "no cache".

        Returns:
            Tuple of the attention output (#batch, time1, d_model) and the
            new cache: all keys and values seen so far, concatenated on the
            last axis (cache_t + time2 entries along dim 2). Trimming the
            cache is left to the caller.
        """
        query_h, key_h, value_h = self.forward_qkv(query, key, value)

        # ONNX export always feeds a cache tensor (possibly with zero time
        # steps) so the split/concat path is always taken; concatenating and
        # splitting zero-length tensors is a no-op. JIT feeds a (0, 0, 0, 0)
        # cache for the first chunk and relies on this branch being skipped.
        if cache.size(0) > 0:
            key_cache, value_cache = torch.split(cache,
                                                 cache.size(-1) // 2,
                                                 dim=-1)
            key_h = torch.cat([key_cache, key_h], dim=2)
            value_h = torch.cat([value_cache, value_h], dim=2)
        new_cache = torch.cat((key_h, value_h), dim=-1)

        scores = torch.matmul(query_h, key_h.transpose(-2, -1)) / math.sqrt(
            self.d_k)
        return self.forward_attention(value_h, scores, mask), new_cache
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
class RelPositionMultiHeadedAttention(MultiHeadedAttention):
    """Multi-head attention with Transformer-XL style relative positions.

    Paper: https://arxiv.org/abs/1901.02860

    Args:
        n_head (int): Number of attention heads.
        n_feat (int): Model width; must be divisible by ``n_head``.
        dropout_rate (float): Dropout probability on attention weights.
        key_bias (bool): Whether the key projection has a bias term.
    """

    def __init__(self,
                 n_head: int,
                 n_feat: int,
                 dropout_rate: float,
                 key_bias: bool = True):
        """Add the positional projection and the two learnable biases."""
        super().__init__(n_head, n_feat, dropout_rate, key_bias)
        # Projection applied to the positional embedding (no bias).
        self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
        # Per-head biases u and v used for the content (a/c) and position
        # (b/d) terms, see Section 3.3 of the paper above.
        self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
        self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
        for bias in (self.pos_bias_u, self.pos_bias_v):
            torch.nn.init.xavier_uniform_(bias)

    def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
        """Realign relative-position scores so column j refers to key j.

        Args:
            x (torch.Tensor): (batch, head, time1, 2*time1-1) scores against
                every relative offset, where time1 is the query length.

        Returns:
            torch.Tensor: The shifted scores, keeping only the first
            ``x.size(-1) // 2 + 1`` columns.
        """
        batch, head, time1, n_pos = x.size(0), x.size(1), x.size(2), x.size(3)
        zero_pad = torch.zeros((batch, head, time1, 1),
                               device=x.device,
                               dtype=x.dtype)
        # Prepend a zero column, reinterpret the buffer with swapped trailing
        # sizes and drop the first row: this performs the per-row shift
        # without any gather.
        x_padded = torch.cat([zero_pad, x], dim=-1).view(
            batch, head, n_pos + 1, time1)
        shifted = x_padded[:, :, 1:].view_as(x)
        return shifted[:, :, :, :n_pos // 2 + 1]

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        pos_emb: torch.Tensor = torch.empty(0),
        cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Scaled dot-product attention with relative positional terms.

        Args:
            query (torch.Tensor): (#batch, time1, size).
            key (torch.Tensor): (#batch, time2, size).
            value (torch.Tensor): (#batch, time2, size).
            mask (torch.Tensor): (#batch, 1, time2) or (#batch, time1, time2);
                (0, 0, 0) means "no mask".
            pos_emb (torch.Tensor): Positional embedding
                (#batch_pos, time_pos, size). When the resulting position
                scores do not have the same shape as the content scores they
                are realigned with ``rel_shift``.
            cache (torch.Tensor): Previous keys/values concatenated on the
                last axis, documented as (1, head, cache_t, d_k * 2); an empty
                tensor means "no cache".

        Returns:
            Tuple of the attention output (#batch, time1, d_model) and the
            new key/value cache (keys and values concatenated on the last
            axis, with the cached steps prepended along dim 2).
        """
        query_h, key_h, value_h = self.forward_qkv(query, key, value)
        # (batch, time1, head, d_k) so the (head, d_k) biases broadcast.
        query_h = query_h.transpose(1, 2)

        # Same cache protocol as MultiHeadedAttention.forward: ONNX always
        # feeds a (possibly zero-length) cache, JIT feeds (0, 0, 0, 0) for
        # the first chunk. Cache trimming happens in the encoder.
        if cache is not None and cache.size(0) > 0:
            key_cache, value_cache = torch.split(cache,
                                                 cache.size(-1) // 2,
                                                 dim=-1)
            key_h = torch.cat([key_cache, key_h], dim=2)
            value_h = torch.cat([value_cache, value_h], dim=2)
        new_cache = torch.cat((key_h, value_h), dim=-1)

        # (batch_pos, head, time_pos, d_k)
        pos_h = self.linear_pos(pos_emb).view(pos_emb.size(0), -1, self.h,
                                              self.d_k).transpose(1, 2)

        # Back to (batch, head, time1, d_k) after adding the biases.
        query_u = (query_h + self.pos_bias_u).transpose(1, 2)
        query_v = (query_h + self.pos_bias_v).transpose(1, 2)

        # Content terms (a) + (c) and position terms (b) + (d).
        matrix_ac = torch.matmul(query_u, key_h.transpose(-2, -1))
        matrix_bd = torch.matmul(query_v, pos_h.transpose(-2, -1))
        # rel_shift is kept because ESPnet-style relative embeddings are used.
        if matrix_ac.shape != matrix_bd.shape:
            matrix_bd = self.rel_shift(matrix_bd)

        scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
        return self.forward_attention(value_h, scores, mask), new_cache
|
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
|
|
2
|
+
# 2024 Alibaba Inc (Xiang Lyu)
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
# Modified from ESPnet(https://github.com/espnet/espnet)
|
|
16
|
+
"""Positonal Encoding Module."""
|
|
17
|
+
|
|
18
|
+
import math
|
|
19
|
+
from typing import Tuple, Union
|
|
20
|
+
|
|
21
|
+
import torch
|
|
22
|
+
import torch.nn.functional as F
|
|
23
|
+
import numpy as np
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class EspnetRelPositionalEncoding(torch.nn.Module):
    """ESPnet-style relative positional encoding.

    The table holds sinusoidal encodings for relative offsets
    ``+(L-1) .. 0 .. -(L-1)`` so a window centred on offset 0 can be sliced
    out for any input length. See https://github.com/espnet/espnet/pull/2816
    and Appendix B of https://arxiv.org/abs/1901.02860.

    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout probability.
        max_len (int): Input length the table is pre-built for.
    """

    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
        """Build the encoding table for ``max_len`` frames."""
        super().__init__()
        self.d_model = d_model
        self.xscale = math.sqrt(self.d_model)
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.pe = None
        # A (1, max_len) float32 CPU stand-in; only its length, dtype and
        # device are read by extend_pe.
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))

    def extend_pe(self, x: torch.Tensor):
        """Make sure ``self.pe`` covers ``x.size(1)`` and matches ``x``.

        The table is rebuilt only when it is too short; otherwise it is just
        cast to ``x``'s dtype/device when those differ.
        """
        length = x.size(1)
        # A table for input length L has 2 * L - 1 rows (offsets +/-(L-1)).
        if self.pe is not None and self.pe.size(1) >= length * 2 - 1:
            if self.pe.dtype != x.dtype or self.pe.device != x.device:
                self.pe = self.pe.to(dtype=x.dtype, device=x.device)
            return

        position = torch.arange(0, length, dtype=torch.float32).unsqueeze(1)
        inv_freq = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        angles = position * inv_freq

        # Positive offsets (key left of query, i > j) and negative offsets
        # (key right of query, i < j); even columns sin, odd columns cos.
        positive = torch.zeros(length, self.d_model)
        negative = torch.zeros(length, self.d_model)
        positive[:, 0::2] = torch.sin(angles)
        positive[:, 1::2] = torch.cos(angles)
        negative[:, 0::2] = torch.sin(-angles)
        negative[:, 1::2] = torch.cos(-angles)

        # Reverse the positive half and drop the duplicated offset 0 from the
        # negative half, giving rows ordered from +(L-1) down to -(L-1), as
        # required by the shifting trick.
        table = torch.cat([torch.flip(positive, [0]), negative[1:]], dim=0)
        self.pe = table.unsqueeze(0).to(device=x.device, dtype=x.dtype)

    def forward(self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0) \
            -> Tuple[torch.Tensor, torch.Tensor]:
        """Scale ``x`` and return it with its relative positional embedding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).
            offset (int or torch.Tensor): Forwarded to ``position_encoding``.

        Returns:
            Tuple of ``dropout(x * sqrt(d_model))`` and
            ``dropout(pos_emb)`` where ``pos_emb`` is (1, 2*time-1, d_model).
        """
        self.extend_pe(x)
        scaled = x * self.xscale
        pos_emb = self.position_encoding(size=x.size(1), offset=offset)
        return self.dropout(scaled), self.dropout(pos_emb)

    def position_encoding(self,
                          offset: Union[int, torch.Tensor],
                          size: int) -> torch.Tensor:
        """Slice the ``2*size-1`` rows centred on relative offset 0.

        Note: ``forward`` applies dropout to this result, so in a streaming
        setting where this is called once per chunk, dropout is applied
        several times instead of once per utterance.

        Args:
            offset (int or torch.Tensor): Start offset; not used by this
                implementation, since the slice is always centred on offset 0.
            size (int): Number of frames the encoding is needed for.

        Returns:
            torch.Tensor: (1, 2*size-1, d_model) slice of ``self.pe``.
        """
        center = self.pe.size(1) // 2
        return self.pe[:, center - size + 1: center + size]
|
|
@@ -0,0 +1,163 @@
|
|
|
1
|
+
# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
|
|
2
|
+
# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
# Modified from ESPnet(https://github.com/espnet/espnet)
|
|
16
|
+
"""Encoder self-attention layer definition."""
|
|
17
|
+
|
|
18
|
+
from typing import Optional, Tuple
|
|
19
|
+
|
|
20
|
+
import torch
|
|
21
|
+
from torch import nn
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class ConformerEncoderLayer(nn.Module):
|
|
25
|
+
"""Encoder layer module.
|
|
26
|
+
Args:
|
|
27
|
+
size (int): Input dimension.
|
|
28
|
+
self_attn (torch.nn.Module): Self-attention module instance.
|
|
29
|
+
`MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
|
|
30
|
+
instance can be used as the argument.
|
|
31
|
+
feed_forward (torch.nn.Module): Feed-forward module instance.
|
|
32
|
+
`PositionwiseFeedForward` instance can be used as the argument.
|
|
33
|
+
feed_forward_macaron (torch.nn.Module): Additional feed-forward module
|
|
34
|
+
instance.
|
|
35
|
+
`PositionwiseFeedForward` instance can be used as the argument.
|
|
36
|
+
conv_module (torch.nn.Module): Convolution module instance.
|
|
37
|
+
`ConvlutionModule` instance can be used as the argument.
|
|
38
|
+
dropout_rate (float): Dropout rate.
|
|
39
|
+
normalize_before (bool):
|
|
40
|
+
True: use layer_norm before each sub-block.
|
|
41
|
+
False: use layer_norm after each sub-block.
|
|
42
|
+
enable_cuda_graph (bool): Control whether to enable CUDA Graph.
|
|
43
|
+
"""
|
|
44
|
+
|
|
45
|
+
def __init__(
|
|
46
|
+
self,
|
|
47
|
+
size: int,
|
|
48
|
+
self_attn: torch.nn.Module,
|
|
49
|
+
feed_forward: Optional[nn.Module] = None,
|
|
50
|
+
feed_forward_macaron: Optional[nn.Module] = None,
|
|
51
|
+
conv_module: Optional[nn.Module] = None,
|
|
52
|
+
dropout_rate: float = 0.1,
|
|
53
|
+
normalize_before: bool = True,
|
|
54
|
+
):
|
|
55
|
+
"""Construct an EncoderLayer object."""
|
|
56
|
+
super().__init__()
|
|
57
|
+
self.self_attn = self_attn
|
|
58
|
+
self.feed_forward = feed_forward
|
|
59
|
+
self.feed_forward_macaron = feed_forward_macaron
|
|
60
|
+
self.conv_module = conv_module
|
|
61
|
+
self.norm_ff = nn.LayerNorm(size, eps=1e-12) # for the FNN module
|
|
62
|
+
self.norm_mha = nn.LayerNorm(size, eps=1e-12) # for the MHA module
|
|
63
|
+
if feed_forward_macaron is not None:
|
|
64
|
+
self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-12)
|
|
65
|
+
self.ff_scale = 0.5
|
|
66
|
+
else:
|
|
67
|
+
self.ff_scale = 1.0
|
|
68
|
+
if self.conv_module is not None:
|
|
69
|
+
self.norm_conv = nn.LayerNorm(size, eps=1e-12) # for the CNN module
|
|
70
|
+
self.norm_final = nn.LayerNorm(
|
|
71
|
+
size, eps=1e-12) # for the final output of the block
|
|
72
|
+
self.dropout = nn.Dropout(dropout_rate)
|
|
73
|
+
self.size = size
|
|
74
|
+
self.normalize_before = normalize_before
|
|
75
|
+
|
|
76
|
+
def forward(
|
|
77
|
+
self,
|
|
78
|
+
x: torch.Tensor,
|
|
79
|
+
mask: torch.Tensor,
|
|
80
|
+
pos_emb: torch.Tensor,
|
|
81
|
+
mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
|
|
82
|
+
att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
|
|
83
|
+
cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
|
|
84
|
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
85
|
+
"""Compute encoded features.
|
|
86
|
+
|
|
87
|
+
Args:
|
|
88
|
+
x (torch.Tensor): (#batch, time, size)
|
|
89
|
+
mask (torch.Tensor): Mask tensor for the input (#batch, time,time),
|
|
90
|
+
(0, 0, 0) means fake mask.
|
|
91
|
+
pos_emb (torch.Tensor): positional encoding, must not be None
|
|
92
|
+
for ConformerEncoderLayer.
|
|
93
|
+
mask_pad (torch.Tensor): batch padding mask used for conv module.
|
|
94
|
+
(#batch, 1,time), (0, 0, 0) means fake mask.
|
|
95
|
+
att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
|
|
96
|
+
(#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
|
|
97
|
+
cnn_cache (torch.Tensor): Convolution cache in conformer layer
|
|
98
|
+
(#batch=1, size, cache_t2)
|
|
99
|
+
Returns:
|
|
100
|
+
torch.Tensor: Output tensor (#batch, time, size).
|
|
101
|
+
torch.Tensor: Mask tensor (#batch, time, time).
|
|
102
|
+
torch.Tensor: att_cache tensor,
|
|
103
|
+
(#batch=1, head, cache_t1 + time, d_k * 2).
|
|
104
|
+
torch.Tensor: cnn_cahce tensor (#batch, size, cache_t2).
|
|
105
|
+
"""
|
|
106
|
+
return self._forward_impl(x, mask, pos_emb, mask_pad, att_cache, cnn_cache)
|
|
107
|
+
|
|
108
|
+
def _forward_impl(
    self,
    x: torch.Tensor,
    mask: torch.Tensor,
    pos_emb: torch.Tensor,
    mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
    att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
    cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Original forward-pass implementation of the layer.

    Arguments and return values are the same as for ``forward``, which
    simply delegates here.

    Each sub-module (optional macaron feed-forward, self-attention, optional
    convolution, feed-forward) is wrapped in a residual connection. When
    ``self.normalize_before`` is true, the sub-module's norm is applied to its
    input (pre-norm). Otherwise it is applied after the residual addition
    (post-norm).

    Returns:
        Tuple of (output features, the input ``mask`` unchanged, the new
        attention cache returned by ``self_attn``, the new conv cache).
    """
    # whether to use macaron style
    # (an extra feed-forward sub-module placed before self-attention)
    if self.feed_forward_macaron is not None:
        residual = x
        if self.normalize_before:
            x = self.norm_ff_macaron(x)
        # ff_scale (set in __init__) scales the feed-forward contribution.
        x = residual + self.ff_scale * self.dropout(
            self.feed_forward_macaron(x))
        if not self.normalize_before:
            x = self.norm_ff_macaron(x)

    # multi-headed self-attention module
    residual = x
    if self.normalize_before:
        x = self.norm_mha(x)
    # att_cache: (b, head, cache_t, d_k*2)
    # self_attn returns the attention output and the updated KEY/VALUE cache.
    x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb,
                                          att_cache)
    x = residual + self.dropout(x_att)
    if not self.normalize_before:
        x = self.norm_mha(x)

    # convolution module
    # Fake new cnn cache here, and then change it in conv_module.
    # This empty (0, 0, 0) tensor is what gets returned when the layer has no
    # conv module.
    new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
    if self.conv_module is not None:
        residual = x
        if self.normalize_before:
            x = self.norm_conv(x)
        x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
        x = residual + self.dropout(x)

        # norm_conv is only created in __init__ when conv_module is set,
        # so this post-norm must stay inside the conv branch.
        if not self.normalize_before:
            x = self.norm_conv(x)

    # feed forward module
    residual = x
    if self.normalize_before:
        x = self.norm_ff(x)

    x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
    if not self.normalize_before:
        x = self.norm_ff(x)

    # norm_final (the block's output norm) exists only when a conv module is
    # configured, mirroring how __init__ creates it.
    if self.conv_module is not None:
        x = self.norm_final(x)
    return x, mask, new_att_cache, new_cnn_cache
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
# Copyright (c) 2019 Shigeki Karita
|
|
2
|
+
# 2020 Mobvoi Inc (Binbin Zhang)
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
"""Positionwise feed forward layer definition."""
|
|
16
|
+
|
|
17
|
+
from typing import Optional

import torch
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class PositionwiseFeedForward(torch.nn.Module):
    """Positionwise feed-forward layer.

    The same two-layer network (Linear -> activation -> dropout -> Linear) is
    applied independently at each position of the sequence. The output
    dimension is the same as the input dimension.

    Args:
        idim (int): Input (and output) dimension.
        hidden_units (int): Number of hidden units of the inner layer.
        dropout_rate (float): Dropout rate applied after the activation.
        activation (torch.nn.Module, optional): Activation module applied
            after the first linear layer. When omitted or None, a new
            ``torch.nn.ReLU`` is created for this instance.
    """

    def __init__(
        self,
        idim: int,
        hidden_units: int,
        dropout_rate: float,
        activation: Optional[torch.nn.Module] = None,
    ):
        """Construct a PositionwiseFeedForward object."""
        super().__init__()
        # Build the default activation per instance. A module default in the
        # signature would be created once at definition time and shared by
        # every layer constructed without an explicit activation.
        if activation is None:
            activation = torch.nn.ReLU()
        # Submodules are registered in the same order as before
        # (w_1, activation, dropout, w_2).
        self.w_1 = torch.nn.Linear(idim, hidden_units)
        self.activation = activation
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.w_2 = torch.nn.Linear(hidden_units, idim)

    def forward(self, xs: torch.Tensor) -> torch.Tensor:
        """Apply the feed-forward network position-wise.

        Args:
            xs: input tensor (B, L, D)
        Returns:
            output tensor, (B, L, D)
        """
        return self.w_2(self.dropout(self.activation(self.w_1(xs))))
|
|
56
|
+
|