xinference 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic. Click here for more details.
- xinference/_compat.py +1 -0
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +54 -1
- xinference/client/restful/restful_client.py +82 -2
- xinference/constants.py +3 -0
- xinference/core/chat_interface.py +297 -83
- xinference/core/model.py +24 -3
- xinference/core/progress_tracker.py +16 -8
- xinference/core/supervisor.py +51 -1
- xinference/core/worker.py +315 -47
- xinference/deploy/cmdline.py +33 -1
- xinference/model/audio/core.py +11 -1
- xinference/model/audio/megatts.py +105 -0
- xinference/model/audio/model_spec.json +24 -1
- xinference/model/audio/model_spec_modelscope.json +26 -1
- xinference/model/core.py +14 -0
- xinference/model/embedding/core.py +6 -1
- xinference/model/flexible/core.py +6 -1
- xinference/model/image/core.py +6 -1
- xinference/model/image/model_spec.json +17 -1
- xinference/model/image/model_spec_modelscope.json +17 -1
- xinference/model/llm/__init__.py +4 -6
- xinference/model/llm/core.py +5 -0
- xinference/model/llm/llama_cpp/core.py +46 -17
- xinference/model/llm/llm_family.json +530 -85
- xinference/model/llm/llm_family.py +24 -1
- xinference/model/llm/llm_family_modelscope.json +572 -1
- xinference/model/llm/mlx/core.py +16 -2
- xinference/model/llm/reasoning_parser.py +3 -3
- xinference/model/llm/sglang/core.py +111 -13
- xinference/model/llm/transformers/__init__.py +14 -0
- xinference/model/llm/transformers/core.py +31 -6
- xinference/model/llm/transformers/deepseek_vl.py +1 -1
- xinference/model/llm/transformers/deepseek_vl2.py +287 -0
- xinference/model/llm/transformers/gemma3.py +17 -2
- xinference/model/llm/transformers/intern_vl.py +28 -18
- xinference/model/llm/transformers/minicpmv26.py +21 -2
- xinference/model/llm/transformers/qwen-omni.py +308 -0
- xinference/model/llm/transformers/qwen2_audio.py +1 -1
- xinference/model/llm/transformers/qwen2_vl.py +20 -4
- xinference/model/llm/utils.py +37 -15
- xinference/model/llm/vllm/core.py +184 -8
- xinference/model/llm/vllm/distributed_executor.py +320 -0
- xinference/model/rerank/core.py +22 -12
- xinference/model/utils.py +118 -1
- xinference/model/video/core.py +6 -1
- xinference/thirdparty/deepseek_vl2/__init__.py +31 -0
- xinference/thirdparty/deepseek_vl2/models/__init__.py +26 -0
- xinference/thirdparty/deepseek_vl2/models/configuration_deepseek.py +210 -0
- xinference/thirdparty/deepseek_vl2/models/conversation.py +310 -0
- xinference/thirdparty/deepseek_vl2/models/modeling_deepseek.py +1975 -0
- xinference/thirdparty/deepseek_vl2/models/modeling_deepseek_vl_v2.py +697 -0
- xinference/thirdparty/deepseek_vl2/models/processing_deepseek_vl_v2.py +675 -0
- xinference/thirdparty/deepseek_vl2/models/siglip_vit.py +661 -0
- xinference/thirdparty/deepseek_vl2/serve/__init__.py +0 -0
- xinference/thirdparty/deepseek_vl2/serve/app_modules/__init__.py +0 -0
- xinference/thirdparty/deepseek_vl2/serve/app_modules/gradio_utils.py +83 -0
- xinference/thirdparty/deepseek_vl2/serve/app_modules/overwrites.py +81 -0
- xinference/thirdparty/deepseek_vl2/serve/app_modules/presets.py +115 -0
- xinference/thirdparty/deepseek_vl2/serve/app_modules/utils.py +333 -0
- xinference/thirdparty/deepseek_vl2/serve/assets/Kelpy-Codos.js +100 -0
- xinference/thirdparty/deepseek_vl2/serve/assets/avatar.png +0 -0
- xinference/thirdparty/deepseek_vl2/serve/assets/custom.css +355 -0
- xinference/thirdparty/deepseek_vl2/serve/assets/custom.js +22 -0
- xinference/thirdparty/deepseek_vl2/serve/assets/favicon.ico +0 -0
- xinference/thirdparty/deepseek_vl2/serve/assets/simsun.ttc +0 -0
- xinference/thirdparty/deepseek_vl2/serve/inference.py +197 -0
- xinference/thirdparty/deepseek_vl2/utils/__init__.py +18 -0
- xinference/thirdparty/deepseek_vl2/utils/io.py +80 -0
- xinference/thirdparty/megatts3/__init__.py +0 -0
- xinference/thirdparty/megatts3/tts/frontend_function.py +175 -0
- xinference/thirdparty/megatts3/tts/gradio_api.py +93 -0
- xinference/thirdparty/megatts3/tts/infer_cli.py +277 -0
- xinference/thirdparty/megatts3/tts/modules/aligner/whisper_small.py +318 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/ar_dur_predictor.py +362 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/layers.py +64 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/nar_tts_modules.py +73 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/rel_transformer.py +403 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/rot_transformer.py +649 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/seq_utils.py +342 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/transformer.py +767 -0
- xinference/thirdparty/megatts3/tts/modules/llm_dit/cfm.py +309 -0
- xinference/thirdparty/megatts3/tts/modules/llm_dit/dit.py +180 -0
- xinference/thirdparty/megatts3/tts/modules/llm_dit/time_embedding.py +44 -0
- xinference/thirdparty/megatts3/tts/modules/llm_dit/transformer.py +230 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/diag_gaussian.py +67 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/hifigan_modules.py +283 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/seanet_encoder.py +38 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/wavvae_v3.py +60 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/conv.py +154 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/lstm.py +51 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/seanet.py +126 -0
- xinference/thirdparty/megatts3/tts/utils/audio_utils/align.py +36 -0
- xinference/thirdparty/megatts3/tts/utils/audio_utils/io.py +95 -0
- xinference/thirdparty/megatts3/tts/utils/audio_utils/plot.py +90 -0
- xinference/thirdparty/megatts3/tts/utils/commons/ckpt_utils.py +171 -0
- xinference/thirdparty/megatts3/tts/utils/commons/hparams.py +215 -0
- xinference/thirdparty/megatts3/tts/utils/text_utils/dict.json +1 -0
- xinference/thirdparty/megatts3/tts/utils/text_utils/ph_tone_convert.py +94 -0
- xinference/thirdparty/megatts3/tts/utils/text_utils/split_text.py +90 -0
- xinference/thirdparty/megatts3/tts/utils/text_utils/text_encoder.py +280 -0
- xinference/types.py +10 -0
- xinference/utils.py +54 -0
- xinference/web/ui/build/asset-manifest.json +6 -6
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/css/main.0f6523be.css +2 -0
- xinference/web/ui/build/static/css/main.0f6523be.css.map +1 -0
- xinference/web/ui/build/static/js/main.58bd483c.js +3 -0
- xinference/web/ui/build/static/js/main.58bd483c.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/3bff8cbe9141f937f4d98879a9771b0f48e0e4e0dbee8e647adbfe23859e7048.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/4500b1a622a031011f0a291701e306b87e08cbc749c50e285103536b85b6a914.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/51709f5d3e53bcf19e613662ef9b91fb9174942c5518987a248348dd4e1e0e02.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/69081049f0c7447544b7cfd73dd13d8846c02fe5febe4d81587e95c89a412d5b.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/b8551e9775a01b28ae674125c688febe763732ea969ae344512e64ea01bf632e.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/bf2b211b0d1b6465eff512d64c869d748f803c5651a7c24e48de6ea3484a7bfe.json +1 -0
- xinference/web/ui/src/locales/en.json +2 -1
- xinference/web/ui/src/locales/zh.json +2 -1
- {xinference-1.4.0.dist-info → xinference-1.5.0.dist-info}/METADATA +128 -115
- {xinference-1.4.0.dist-info → xinference-1.5.0.dist-info}/RECORD +124 -63
- {xinference-1.4.0.dist-info → xinference-1.5.0.dist-info}/WHEEL +1 -1
- xinference/web/ui/build/static/css/main.b494ae7e.css +0 -2
- xinference/web/ui/build/static/css/main.b494ae7e.css.map +0 -1
- xinference/web/ui/build/static/js/main.3cea968e.js +0 -3
- xinference/web/ui/build/static/js/main.3cea968e.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/27bcada3ee8f89d21184b359f022fc965f350ffaca52c9814c29f1fc37121173.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/7f59e45e3f268ab8a4788b6fb024cf8dab088736dff22f5a3a39c122a83ab930.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/dcd60488509450bfff37bfff56de2c096d51de17dd00ec60d4db49c8b483ada1.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/e547bbb18abb4a474b675a8d5782d25617566bea0af8caa9b836ce5649e2250a.json +0 -1
- /xinference/web/ui/build/static/js/{main.3cea968e.js.LICENSE.txt → main.58bd483c.js.LICENSE.txt} +0 -0
- {xinference-1.4.0.dist-info → xinference-1.5.0.dist-info}/entry_points.txt +0 -0
- {xinference-1.4.0.dist-info → xinference-1.5.0.dist-info/licenses}/LICENSE +0 -0
- {xinference-1.4.0.dist-info → xinference-1.5.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,767 @@
|
|
|
1
|
+
# Copyright 2025 ByteDance and/or its affiliates.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import math
|
|
16
|
+
import torch
|
|
17
|
+
from torch import nn
|
|
18
|
+
from torch.nn import Parameter, Linear
|
|
19
|
+
from tts.modules.ar_dur.commons.layers import LayerNorm, Embedding
|
|
20
|
+
from tts.modules.ar_dur.commons.seq_utils import get_incremental_state, set_incremental_state, softmax, make_positions
|
|
21
|
+
import torch.nn.functional as F
|
|
22
|
+
|
|
23
|
+
DEFAULT_MAX_SOURCE_POSITIONS = 3000
|
|
24
|
+
DEFAULT_MAX_TARGET_POSITIONS = 3000
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class SinusoidalPositionalEmbedding(nn.Module):
    """Produce sinusoidal positional embeddings of any length.

    A lookup table of sin/cos features is built up front and re-built in
    ``forward`` whenever a longer sequence is seen. The row at
    ``padding_idx`` is zeroed, so padding symbols get an all-zero embedding.
    """

    def __init__(self, embedding_dim, padding_idx, init_size=1024):
        """
        Args:
            embedding_dim: size of each positional embedding vector.
            padding_idx: index of the table row that is forced to zero.
            init_size: number of rows in the initially built table.
        """
        super().__init__()
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        # Kept as a plain attribute because forward() replaces it with a
        # larger table on demand.
        self.weights = SinusoidalPositionalEmbedding.get_embedding(
            init_size,
            embedding_dim,
            padding_idx,
        )
        # Used in forward() as the device/dtype reference for `weights`.
        self.register_buffer('_float_tensor', torch.FloatTensor(1))

    @staticmethod
    def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
        """Build sinusoidal embeddings.

        This matches the implementation in tensor2tensor, but differs slightly
        from the description in Section 3.5 of "Attention Is All You Need".

        Args:
            num_embeddings: number of rows (positions) in the table.
            embedding_dim: number of columns; the first half holds sines,
                the second half cosines, plus one zero column if odd.
            padding_idx: optional row index to zero out.

        Returns:
            Float tensor of shape ``(num_embeddings, embedding_dim)``.
        """
        half_dim = embedding_dim // 2
        # With half_dim == 1 (embedding_dim of 2 or 3) the original divisor
        # `half_dim - 1` is zero; clamp it so a single frequency of 1 is used.
        # For half_dim >= 2 the divisor is unchanged.
        emb = math.log(10000) / max(half_dim - 1, 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
        if embedding_dim % 2 == 1:
            # zero pad
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0
        return emb

    def forward(self, input, incremental_state=None, timestep=None, positions=None, **kwargs):
        """Look up positional embeddings for ``input``.

        Input is expected to be of size [bsz x seqlen].

        Args:
            input: tensor whose first two dimensions are (bsz, seq_len).
            incremental_state: when not None, a single embedding is returned
                for the current decoding step, expanded to ``(bsz, 1, dim)``.
            timestep: optional tensor; its first element + 1 is used as the
                step position in incremental mode (otherwise ``seq_len``).
            positions: optional precomputed position indices; when None they
                are derived with ``make_positions(input, padding_idx)``.

        Returns:
            ``(bsz, 1, dim)`` in incremental mode, else a detached
            ``(bsz, seq_len, dim)`` tensor.
        """
        bsz, seq_len = input.shape[:2]
        max_pos = self.padding_idx + 1 + seq_len
        if self.weights is None or max_pos > self.weights.size(0):
            # recompute/expand embeddings if needed
            self.weights = SinusoidalPositionalEmbedding.get_embedding(
                max_pos,
                self.embedding_dim,
                self.padding_idx,
            )
        self.weights = self.weights.to(self._float_tensor)

        if incremental_state is not None:
            # positions is the same for every token when decoding a single step
            pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
            return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)

        positions = make_positions(input, self.padding_idx) if positions is None else positions
        return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()

    def max_positions(self):
        """Maximum number of supported positions."""
        return int(1e5)  # an arbitrary large number
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
class TransformerFFNLayer(nn.Module):
    """Position-wise feed-forward block: Conv1d -> activation -> dropout -> Linear.

    The first projection is a 1-D convolution over time, padded either
    symmetrically (``'SAME'``) or only on the left (``'LEFT'``).
    """

    def __init__(self, hidden_size, filter_size, padding="SAME", kernel_size=1, dropout=0., act='gelu', bias=True):
        """
        Args:
            hidden_size: channel size of the input and output.
            filter_size: channel size of the intermediate projection.
            padding: ``'SAME'`` or ``'LEFT'``.
            kernel_size: convolution kernel width.
            dropout: dropout probability applied after the activation.
            act: ``'gelu'`` or ``'relu'``; any other value applies no
                activation.
            bias: whether both projections carry a bias term.

        Raises:
            ValueError: if ``padding`` is neither ``'SAME'`` nor ``'LEFT'``.
        """
        super().__init__()
        self.kernel_size = kernel_size
        self.dropout = dropout
        self.act = act
        if padding == 'SAME':
            self.ffn_1 = nn.Conv1d(hidden_size, filter_size, kernel_size,
                                   padding=kernel_size // 2, bias=bias)
        elif padding == 'LEFT':
            self.ffn_1 = nn.Sequential(
                nn.ConstantPad1d((kernel_size - 1, 0), 0.0),
                nn.Conv1d(hidden_size, filter_size, kernel_size, bias=bias)
            )
        else:
            # Fail fast: previously `ffn_1` was silently left undefined and
            # the error only surfaced as an AttributeError inside forward().
            raise ValueError(
                "Unsupported padding {!r}; expected 'SAME' or 'LEFT'".format(padding))
        self.ffn_2 = Linear(filter_size, hidden_size, bias=bias)

    def forward(self, x, incremental_state=None):
        """Apply the feed-forward block.

        Args:
            x: tensor laid out as T x B x C.
            incremental_state: optional state container; when given, the last
                ``kernel_size`` input frames are cached in it and only the
                output for the newest frame is returned.

        Returns:
            Tensor with ``hidden_size`` channels in the last dimension.
        """
        # x: T x B x C
        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if 'prev_input' in saved_state:
                prev_input = saved_state['prev_input']
                x = torch.cat((prev_input, x), dim=0)
            # Keep only the frames the convolution window can still reach.
            x = x[-self.kernel_size:]
            saved_state['prev_input'] = x
            self._set_input_buffer(incremental_state, saved_state)

        # Conv1d wants B x C x T; permute there and back to T x B x C.
        x = self.ffn_1(x.permute(1, 2, 0)).permute(2, 0, 1)
        x = x * self.kernel_size ** -0.5

        if incremental_state is not None:
            x = x[-1:]
        if self.act == 'gelu':
            x = F.gelu(x)
        if self.act == 'relu':
            x = F.relu(x)
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.ffn_2(x)
        return x

    def _get_input_buffer(self, incremental_state):
        """Return this module's cached state dict (key ``'f'``), or ``{}``."""
        return get_incremental_state(
            self,
            incremental_state,
            'f',
        ) or {}

    def _set_input_buffer(self, incremental_state, buffer):
        """Store ``buffer`` as this module's cached state under key ``'f'``."""
        set_incremental_state(
            self,
            incremental_state,
            'f',
            buffer,
        )

    def clear_buffer(self, incremental_state):
        """Drop the cached ``'prev_input'`` frames, if any."""
        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if 'prev_input' in saved_state:
                del saved_state['prev_input']
            self._set_input_buffer(incremental_state, saved_state)
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
class MultiheadAttention(nn.Module):
    """Multi-head attention over ``Time x Batch x Channel`` tensors.

    Supports self-attention, encoder-decoder attention with cached (static)
    keys/values, and incremental decoding through an external
    ``incremental_state`` container.
    """

    def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True,
                 add_bias_kv=False, add_zero_attn=False, self_attention=False,
                 encoder_decoder_attention=False):
        """
        Args:
            embed_dim: model/query dimension; must be divisible by num_heads.
            num_heads: number of attention heads.
            kdim: key input dimension (defaults to embed_dim).
            vdim: value input dimension (defaults to embed_dim).
            dropout: dropout probability applied to attention probabilities.
            bias: whether the input and output projections carry a bias.
            add_bias_kv: append a learned bias row to keys and values.
            add_zero_attn: append an all-zero key/value row.
            self_attention: query, key and value all come from ``query``.
            encoder_decoder_attention: key and value both come from ``key``.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim ** -0.5

        self.self_attention = self_attention
        self.encoder_decoder_attention = encoder_decoder_attention

        assert not self.self_attention or self.qkv_same_dim, 'Self-attention requires query, key and ' \
                                                             'value to be of the same size'

        if self.qkv_same_dim:
            # One packed weight; rows [0:E) -> q, [E:2E) -> k, [2E:3E) -> v.
            self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
        else:
            self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
            self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
            self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))

        if bias:
            self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
        else:
            self.register_parameter('in_proj_bias', None)

        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        if add_bias_kv:
            self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
            self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self.reset_parameters()

        # When True, forward() may delegate to F.multi_head_attention_forward.
        self.enable_torch_version = False
        # Attention probabilities remembered for the `reset_attn_weight` option.
        self.last_attn_probs = None

    def reset_parameters(self):
        """Xavier-initialise the projection weights and zero the biases."""
        if self.qkv_same_dim:
            nn.init.xavier_uniform_(self.in_proj_weight)
        else:
            nn.init.xavier_uniform_(self.k_proj_weight)
            nn.init.xavier_uniform_(self.v_proj_weight)
            nn.init.xavier_uniform_(self.q_proj_weight)

        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.in_proj_bias is not None:
            nn.init.constant_(self.in_proj_bias, 0.)
            nn.init.constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            nn.init.xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            nn.init.xavier_normal_(self.bias_v)

    def forward(
            self,
            query, key, value,
            key_padding_mask=None,
            incremental_state=None,
            need_weights=True,
            static_kv=False,
            attn_mask=None,
            before_softmax=False,
            need_head_weights=False,
            enc_dec_attn_constraint_mask=None,
            reset_attn_weight=None
    ):
        """Input shape: Time x Batch x Channel

        Args:
            key_padding_mask (BoolTensor or ByteTensor, optional): mask to
                exclude keys that are pads, of shape `(batch, src_len)`,
                where padding elements are indicated by 1s / True.
            incremental_state (optional): container holding cached keys,
                values and padding mask from previous decoding steps.
            need_weights (bool, optional): return the attention weights,
                averaged over heads (default: True).
            static_kv (bool, optional): reuse cached keys/values instead of
                recomputing them (encoder-decoder attention only).
            attn_mask (Tensor, optional): additive mask, typically used to
                implement causal attention, where the mask prevents the
                attention from looking forward in time (default: None).
            before_softmax (bool, optional): return the raw attention
                weights and values before the attention softmax.
            need_head_weights (bool, optional): return the attention
                weights for each head. Implies *need_weights*. Default:
                return the average attention weights over all heads.
            enc_dec_attn_constraint_mask (Tensor, optional): mask whose
                truthy entries are filled with -1e8 in the attention logits.
            reset_attn_weight (bool, optional): if True, remember this call's
                attention probabilities; if False, reuse the remembered ones
                instead of the freshly computed ones.

        Returns:
            ``(attn_weights, v)`` when *before_softmax* is set; otherwise
            ``(attn, (attn_weights, attn_logits))`` where ``attn`` is
            ``(tgt_len, bsz, embed_dim)``, ``attn_logits`` is
            ``(bsz, num_heads, tgt_len, src_len)`` and ``attn_weights`` is
            None unless *need_weights* is set.
        """
        if need_head_weights:
            need_weights = True

        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]

        if self.enable_torch_version and incremental_state is None and not static_kv and reset_attn_weight is None:
            if self.qkv_same_dim:
                return F.multi_head_attention_forward(query, key, value,
                                                      self.embed_dim, self.num_heads,
                                                      self.in_proj_weight,
                                                      self.in_proj_bias, self.bias_k, self.bias_v,
                                                      self.add_zero_attn, self.dropout,
                                                      self.out_proj.weight, self.out_proj.bias,
                                                      self.training, key_padding_mask, need_weights,
                                                      attn_mask)
            else:
                return F.multi_head_attention_forward(query, key, value,
                                                      self.embed_dim, self.num_heads,
                                                      torch.empty([0]),
                                                      self.in_proj_bias, self.bias_k, self.bias_v,
                                                      self.add_zero_attn, self.dropout,
                                                      self.out_proj.weight, self.out_proj.bias,
                                                      self.training, key_padding_mask, need_weights,
                                                      attn_mask, use_separate_proj_weight=True,
                                                      q_proj_weight=self.q_proj_weight,
                                                      k_proj_weight=self.k_proj_weight,
                                                      v_proj_weight=self.v_proj_weight)

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if 'prev_key' in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert self.encoder_decoder_attention and not self.self_attention
                    key = value = None
        else:
            saved_state = None

        if self.self_attention:
            # self-attention
            q, k, v = self.in_proj_qkv(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.in_proj_q(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                # Both keys and values are projected from `key` here.
                k = self.in_proj_k(key)
                v = self.in_proj_v(key)

        else:
            q = self.in_proj_q(query)
            k = self.in_proj_k(key)
            v = self.in_proj_v(value)
        q = q * self.scaling

        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)

        # Fold heads into the batch dimension: (bsz * num_heads, len, head_dim).
        q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if k is not None:
            k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if v is not None:
            v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)

        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
            if 'prev_key' in saved_state:
                prev_key = saved_state['prev_key'].view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    k = prev_key
                else:
                    k = torch.cat((prev_key, k), dim=1)
            if 'prev_value' in saved_state:
                prev_value = saved_state['prev_value'].view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    v = prev_value
                else:
                    v = torch.cat((prev_value, v), dim=1)
            if 'prev_key_padding_mask' in saved_state and saved_state['prev_key_padding_mask'] is not None:
                prev_key_padding_mask = saved_state['prev_key_padding_mask']
                if static_kv:
                    key_padding_mask = prev_key_padding_mask
                else:
                    key_padding_mask = torch.cat((prev_key_padding_mask, key_padding_mask), dim=1)

            saved_state['prev_key'] = k.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state['prev_value'] = v.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state['prev_key_padding_mask'] = key_padding_mask

            self._set_input_buffer(incremental_state, saved_state)

        src_len = k.size(1)

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]):
            key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            src_len += 1
            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1)

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)

        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            if len(attn_mask.shape) == 2:
                attn_mask = attn_mask.unsqueeze(0)
            elif len(attn_mask.shape) == 3:
                # Per-batch mask: replicate it for every head.
                attn_mask = attn_mask[:, None].repeat([1, self.num_heads, 1, 1]).reshape(
                    bsz * self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights + attn_mask

        if enc_dec_attn_constraint_mask is not None:  # bs x head x L_kv
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.masked_fill(
                enc_dec_attn_constraint_mask.unsqueeze(2).bool(),
                -1e8,
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            # Cast to bool so byte/int/float 0-1 masks are accepted as well,
            # matching the handling of the constraint mask above.
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2).bool(),
                -1e8,
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_logits = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)

        if before_softmax:
            return attn_weights, v

        attn_weights_float = softmax(attn_weights, dim=-1)
        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)

        if reset_attn_weight is not None:
            if reset_attn_weight:
                self.last_attn_probs = attn_probs.detach()
            else:
                assert self.last_attn_probs is not None
                attn_probs = self.last_attn_probs
        attn = torch.bmm(attn_probs, v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)

        if need_weights:
            attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
            if not need_head_weights:
                # average attention weights over heads
                attn_weights = attn_weights.mean(dim=0)
        else:
            attn_weights = None

        return attn, (attn_weights, attn_logits)

    def in_proj_qkv(self, query):
        """Project ``query`` with the packed weight and split into q, k, v."""
        return self._in_proj(query).chunk(3, dim=-1)

    def in_proj_q(self, query):
        """Project ``query`` into the query space."""
        if self.qkv_same_dim:
            return self._in_proj(query, end=self.embed_dim)
        else:
            bias = self.in_proj_bias
            if bias is not None:
                bias = bias[:self.embed_dim]
            return F.linear(query, self.q_proj_weight, bias)

    def in_proj_k(self, key):
        """Project ``key`` into the key space."""
        if self.qkv_same_dim:
            return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)
        else:
            weight = self.k_proj_weight
            bias = self.in_proj_bias
            if bias is not None:
                bias = bias[self.embed_dim:2 * self.embed_dim]
            return F.linear(key, weight, bias)

    def in_proj_v(self, value):
        """Project ``value`` into the value space."""
        if self.qkv_same_dim:
            return self._in_proj(value, start=2 * self.embed_dim)
        else:
            weight = self.v_proj_weight
            bias = self.in_proj_bias
            if bias is not None:
                bias = bias[2 * self.embed_dim:]
            return F.linear(value, weight, bias)

    def _in_proj(self, input, start=0, end=None):
        """Apply rows ``[start:end)`` of the packed projection to ``input``."""
        weight = self.in_proj_weight
        bias = self.in_proj_bias
        weight = weight[start:end, :]
        if bias is not None:
            bias = bias[start:end]
        return F.linear(input, weight, bias)

    def _get_input_buffer(self, incremental_state):
        """Return this module's cached ``'attn_state'`` dict, or ``{}``."""
        return get_incremental_state(
            self,
            incremental_state,
            'attn_state',
        ) or {}

    def _set_input_buffer(self, incremental_state, buffer):
        """Store ``buffer`` as this module's ``'attn_state'``."""
        set_incremental_state(
            self,
            incremental_state,
            'attn_state',
            buffer,
        )

    def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz):
        """Hook for sparse masking; returns ``attn_weights`` unchanged."""
        return attn_weights

    def clear_buffer(self, incremental_state=None):
        """Drop the cached ``'prev_key'`` and ``'prev_value'`` entries, if any."""
        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if 'prev_key' in saved_state:
                del saved_state['prev_key']
            if 'prev_value' in saved_state:
                del saved_state['prev_value']
            self._set_input_buffer(incremental_state, saved_state)
|
|
498
|
+
|
|
499
|
+
|
|
500
|
+
class EncSALayer(nn.Module):
    """Pre-norm encoder layer: optional self-attention followed by a conv FFN.

    Each sub-layer is ``LayerNorm -> sub-layer -> dropout -> residual add``.
    When a padding mask is supplied, padded positions are zeroed after each
    residual add.
    """

    def __init__(self, c, num_heads, dropout, attention_dropout=0.1,
                 relu_dropout=0.1, kernel_size=9, padding='SAME', act='gelu',
                 ffn_hidden_size=1024):
        """
        Args:
            c: channel (hidden) size.
            num_heads: attention heads; ``0`` disables the self-attention
                sub-layer entirely.
            dropout: dropout applied to each sub-layer output.
            attention_dropout: dropout inside the attention module.
            relu_dropout: dropout inside the feed-forward module.
            kernel_size: convolution kernel width of the feed-forward module.
            padding: padding mode forwarded to the feed-forward module.
            act: activation name forwarded to the feed-forward module.
            ffn_hidden_size: intermediate size of the feed-forward module.
        """
        super().__init__()
        self.c = c
        self.dropout = dropout
        self.num_heads = num_heads
        if num_heads > 0:
            self.layer_norm1 = LayerNorm(c)
            self.self_attn = MultiheadAttention(
                self.c, num_heads, self_attention=True, dropout=attention_dropout, bias=False)
        self.layer_norm2 = LayerNorm(c)
        self.ffn = TransformerFFNLayer(
            c, ffn_hidden_size, kernel_size=kernel_size, dropout=relu_dropout, padding=padding, act=act)

    def forward(self, x, encoder_padding_mask=None, **kwargs):
        """Run the layer.

        Args:
            x: input tensor; indexed as time-major (the mask is transposed
                to line up with it).
            encoder_padding_mask: optional mask with truthy values at padded
                positions; it is also passed to self-attention as
                ``key_padding_mask``. When None, no positions are zeroed.
            **kwargs: ``layer_norm_training`` (optional bool) overrides the
                ``training`` flag of the layer norms.

        Returns:
            Tensor of the same shape as ``x``.
        """
        layer_norm_training = kwargs.get('layer_norm_training', None)
        if layer_norm_training is not None:
            # layer_norm1 only exists when the attention sub-layer does.
            if self.num_heads > 0:
                self.layer_norm1.training = layer_norm_training
            self.layer_norm2.training = layer_norm_training

        # 1 at real positions, 0 at padding, with a trailing broadcast axis.
        # Left as None when no mask is given (the default), in which case
        # the masking multiplications are skipped instead of crashing.
        keep = None
        if encoder_padding_mask is not None:
            keep = (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]

        if self.num_heads > 0:
            residual = x
            x = self.layer_norm1(x)
            x, _, = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=encoder_padding_mask
            )
            x = F.dropout(x, self.dropout, training=self.training)
            x = residual + x
            if keep is not None:
                x = x * keep

        residual = x
        x = self.layer_norm2(x)
        x = self.ffn(x)
        x = F.dropout(x, self.dropout, training=self.training)
        x = residual + x
        if keep is not None:
            x = x * keep
        return x
|
|
541
|
+
|
|
542
|
+
|
|
543
|
+
class DecSALayer(nn.Module):
    """Decoder layer: self-attention, optional encoder attention, then a feed-forward block.

    ``post_ln`` selects where the layer norms run: before each sub-layer
    when ``False``, after the residual addition when ``True``.
    """

    def __init__(self, c, num_heads, dropout, attention_dropout=0.1, relu_dropout=0.1,
                 kernel_size=9, ffn_hidden_size=1024, act='gelu', post_ln=False):
        """Build the three sub-layers and their layer norms.

        :param c: channel size handed to the norms, both attentions and the FFN
        :param num_heads: number of heads for both attention modules
        :param dropout: dropout probability applied after each sub-layer
        :param attention_dropout: dropout forwarded to both attention modules
        :param relu_dropout: dropout forwarded to ``TransformerFFNLayer``
        :param kernel_size: forwarded to ``TransformerFFNLayer``
        :param ffn_hidden_size: forwarded to ``TransformerFFNLayer``
        :param act: forwarded to ``TransformerFFNLayer``
        :param post_ln: apply layer norms after the residual instead of before
        """
        super().__init__()
        self.c = c
        self.dropout = dropout
        self.layer_norm1 = LayerNorm(c)
        self.self_attn = MultiheadAttention(
            c, num_heads, self_attention=True, dropout=attention_dropout, bias=False
        )
        self.layer_norm2 = LayerNorm(c)
        self.encoder_attn = MultiheadAttention(
            c, num_heads, encoder_decoder_attention=True, dropout=attention_dropout, bias=False,
        )
        self.layer_norm3 = LayerNorm(c)
        # The FFN is built with padding='LEFT' (the encoder layer defaults to 'SAME').
        self.ffn = TransformerFFNLayer(
            c, ffn_hidden_size, padding='LEFT', kernel_size=kernel_size, dropout=relu_dropout, act=act)
        self.post_ln = post_ln

    def forward(
            self,
            x,
            encoder_out=None,
            encoder_padding_mask=None,
            incremental_state=None,
            self_attn_mask=None,
            self_attn_padding_mask=None,
            attn_out=None,
            reset_attn_weight=None,
            **kwargs,
    ):
        """Run the decoder layer.

        :param x: decoder input features
        :param encoder_out: if given, used as key/value for encoder attention
        :param encoder_padding_mask: key padding mask for encoder attention
        :param incremental_state: state passed to both attentions and the FFN
        :param self_attn_mask: attention mask for self-attention
        :param self_attn_padding_mask: key padding mask for self-attention
        :param attn_out: used only when ``encoder_out`` is ``None``; it is
            projected with ``encoder_attn.in_proj_v`` in place of running attention
        :param reset_attn_weight: forwarded to encoder attention
        :param kwargs: only ``layer_norm_training`` is read; when given, it
            overrides the ``training`` flag of all three layer norms
        :return: ``(x, attn_logits)``; ``attn_logits`` is element 1 of the
            second value returned by encoder attention, or ``None`` when
            encoder attention did not run
        """
        layer_norm_training = kwargs.get('layer_norm_training', None)
        if layer_norm_training is not None:
            self.layer_norm1.training = layer_norm_training
            self.layer_norm2.training = layer_norm_training
            self.layer_norm3.training = layer_norm_training
        # --- self-attention sub-layer ---
        residual = x
        if not self.post_ln:
            x = self.layer_norm1(x)
        x, _ = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            attn_mask=self_attn_mask
        )
        x = F.dropout(x, self.dropout, training=self.training)
        x = residual + x
        if self.post_ln:
            x = self.layer_norm1(x)

        # --- encoder-attention sub-layer (skipped when neither input is given) ---
        attn_logits = None
        if encoder_out is not None or attn_out is not None:
            residual = x
            if not self.post_ln:
                x = self.layer_norm2(x)
        if encoder_out is not None:
            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                enc_dec_attn_constraint_mask=get_incremental_state(self, incremental_state,
                                                                   'enc_dec_attn_constraint_mask'),
                reset_attn_weight=reset_attn_weight
            )
            attn_logits = attn[1]
        elif attn_out is not None:
            x = self.encoder_attn.in_proj_v(attn_out)
        if encoder_out is not None or attn_out is not None:
            x = F.dropout(x, self.dropout, training=self.training)
            x = residual + x
            # NOTE(review): post-LN is applied here only when this sub-layer
            # actually ran -- confirm the nesting against upstream.
            if self.post_ln:
                x = self.layer_norm2(x)

        # --- feed-forward sub-layer ---
        residual = x
        if not self.post_ln:
            x = self.layer_norm3(x)
        x = self.ffn(x, incremental_state=incremental_state)
        x = F.dropout(x, self.dropout, training=self.training)
        x = residual + x
        if self.post_ln:
            x = self.layer_norm3(x)
        return x, attn_logits

    def clear_buffer(self, input, encoder_out=None, encoder_padding_mask=None, incremental_state=None):
        """Clear the incremental buffers of the encoder attention and the FFN.

        ``input``, ``encoder_out`` and ``encoder_padding_mask`` are accepted
        but unused. The self-attention buffer is not touched here.
        """
        self.encoder_attn.clear_buffer(incremental_state)
        self.ffn.clear_buffer(incremental_state)

    def set_buffer(self, name, tensor, incremental_state):
        """Store ``tensor`` under ``name`` for this module via ``set_incremental_state``."""
        return set_incremental_state(self, incremental_state, name, tensor)
|
|
637
|
+
|
|
638
|
+
|
|
639
|
+
class TransformerEncoderLayer(nn.Module):
    """Thin wrapper that owns one ``EncSALayer`` (``self.op``) and delegates to it."""

    def __init__(self, hidden_size, dropout, kernel_size=9, num_heads=2, ffn_hidden_size=1024):
        """Create the wrapped ``EncSALayer``.

        ``dropout`` is used both as the residual dropout and as the FFN
        (``relu_dropout``) dropout; attention dropout is fixed at 0.0.
        """
        super().__init__()
        self.hidden_size, self.dropout, self.num_heads = hidden_size, dropout, num_heads
        sa_layer = EncSALayer(
            hidden_size,
            num_heads,
            dropout=dropout,
            attention_dropout=0.0,
            relu_dropout=dropout,
            kernel_size=kernel_size,
            ffn_hidden_size=ffn_hidden_size,
        )
        self.op = sa_layer

    def forward(self, x, **kwargs):
        """Call the wrapped layer with ``x`` and every keyword argument."""
        wrapped = self.op
        return wrapped(x, **kwargs)
|
|
652
|
+
|
|
653
|
+
|
|
654
|
+
class TransformerDecoderLayer(nn.Module):
    """Thin wrapper that owns one ``DecSALayer`` (``self.op``) and delegates to it."""

    def __init__(self, hidden_size, dropout, kernel_size=9, num_heads=2, ffn_hidden_size=1024, post_ln=False):
        """Create the wrapped ``DecSALayer``.

        ``dropout`` is used both as the residual dropout and as the FFN
        (``relu_dropout``) dropout; attention dropout is fixed at 0.0.
        """
        super().__init__()
        self.hidden_size, self.dropout, self.num_heads = hidden_size, dropout, num_heads
        sa_layer = DecSALayer(
            hidden_size,
            num_heads,
            dropout=dropout,
            attention_dropout=0.0,
            relu_dropout=dropout,
            kernel_size=kernel_size,
            ffn_hidden_size=ffn_hidden_size,
            post_ln=post_ln,
        )
        self.op = sa_layer

    def forward(self, x, **kwargs):
        """Call the wrapped layer with ``x`` and every keyword argument."""
        wrapped = self.op
        return wrapped(x, **kwargs)

    def clear_buffer(self, *args):
        """Forward the positional arguments to the wrapped layer's ``clear_buffer``."""
        wrapped = self.op
        return wrapped.clear_buffer(*args)

    def set_buffer(self, *args):
        """Forward the positional arguments to the wrapped layer's ``set_buffer``."""
        wrapped = self.op
        return wrapped.set_buffer(*args)
|
|
674
|
+
|
|
675
|
+
|
|
676
|
+
class FFTBlocks(nn.Module):
    """Stack of ``TransformerEncoderLayer`` blocks.

    Optionally adds sinusoidal positional embeddings to the input and
    applies a final ``nn.LayerNorm`` to the output.
    """

    def __init__(self, hidden_size, num_layers, ffn_kernel_size=9, dropout=0.0,
                 num_heads=2, use_pos_embed=True, use_last_norm=True,
                 use_pos_embed_alpha=True, ffn_hidden_size=1024):
        """Build the positional embedding (optional), the layers and the final norm.

        :param hidden_size: channel size of every layer
        :param num_layers: number of ``TransformerEncoderLayer`` blocks
        :param ffn_kernel_size: ``kernel_size`` of each block
        :param dropout: dropout probability used by the blocks and on the input
        :param num_heads: attention heads per block
        :param use_pos_embed: add sinusoidal positional embeddings in ``forward``
        :param use_last_norm: apply a final ``nn.LayerNorm``
        :param use_pos_embed_alpha: scale positions by a learnable scalar
            (initialised to 1) instead of the constant 1
        :param ffn_hidden_size: hidden size of each block's FFN
        """
        super().__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.dropout = dropout
        self.use_pos_embed = use_pos_embed
        self.use_last_norm = use_last_norm

        if use_pos_embed:
            self.max_source_positions = DEFAULT_MAX_TARGET_POSITIONS
            self.padding_idx = 0
            if use_pos_embed_alpha:
                self.pos_embed_alpha = nn.Parameter(torch.Tensor([1]))
            else:
                self.pos_embed_alpha = 1
            self.embed_positions = SinusoidalPositionalEmbedding(
                hidden_size, self.padding_idx, init_size=DEFAULT_MAX_TARGET_POSITIONS,
            )

        encoder_layers = [
            TransformerEncoderLayer(
                self.hidden_size,
                self.dropout,
                kernel_size=ffn_kernel_size,
                num_heads=num_heads,
                ffn_hidden_size=ffn_hidden_size,
            )
            for _ in range(self.num_layers)
        ]
        self.layers = nn.ModuleList(encoder_layers)

        self.layer_norm = nn.LayerNorm(hidden_size) if self.use_last_norm else None

    def forward(self, x, padding_mask=None, attn_mask=None, return_hiddens=False):
        """
        :param x: [B, T, C]
        :param padding_mask: [B, T]; derived from all-zero frames of ``x`` when omitted
        :param attn_mask: handed to every layer as a keyword argument
        :param return_hiddens: return the stacked per-layer outputs instead of the final one
        :return: [B, T, C] or [L, B, T, C]
        """
        if padding_mask is None:
            # Frames whose features are all zero are treated as padding.
            padding_mask = x.abs().sum(-1).eq(0).data
        keep_TB = 1 - padding_mask.transpose(0, 1).float()[:, :, None]  # [T, B, 1]

        if self.use_pos_embed:
            x = x + self.pos_embed_alpha * self.embed_positions(x[..., 0])
            x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C, with padded frames zeroed
        x = x.transpose(0, 1) * keep_TB

        per_layer = []
        for block in self.layers:
            x = block(x, encoder_padding_mask=padding_mask, attn_mask=attn_mask) * keep_TB
            per_layer.append(x)

        if self.use_last_norm:
            x = self.layer_norm(x) * keep_TB

        if not return_hiddens:
            return x.transpose(0, 1)  # [B, T, C]

        stacked = torch.stack(per_layer, 0)  # [L, T, B, C]
        return stacked.transpose(1, 2)  # [L, B, T, C]
|
|
732
|
+
|
|
733
|
+
|
|
734
|
+
class FastSpeechEncoder(FFTBlocks):
    """Token encoder: scaled token embedding followed by the ``FFTBlocks`` stack."""

    def __init__(self, dict_size, hidden_size=256, num_layers=4, kernel_size=9,
                 dropout=0.0, num_heads=2, ffn_hidden_size=1024):
        """Build the layer stack and the token / positional embeddings.

        The parent is initialised with ``use_pos_embed=False``, so
        ``forward_embedding`` does not add ``embed_positions`` unless
        ``use_pos_embed`` is switched on afterwards.
        """
        super().__init__(
            hidden_size,
            num_layers,
            kernel_size,
            num_heads=num_heads,
            use_pos_embed=False,
            dropout=dropout,
            ffn_hidden_size=ffn_hidden_size,
        )
        self.embed_tokens = Embedding(dict_size, hidden_size, 0)
        self.embed_scale = math.sqrt(hidden_size)
        self.padding_idx = 0
        self.embed_positions = SinusoidalPositionalEmbedding(
            hidden_size, self.padding_idx, init_size=DEFAULT_MAX_TARGET_POSITIONS,
        )

    def forward(self, txt_tokens, attn_mask=None, other_embeds=0):
        """
        :param txt_tokens: [B, T]
        :param attn_mask: forwarded to the layer stack
        :param other_embeds: added to the token embeddings before the layers
        :return: encoder output, [B x T x C]
        """
        # Positions holding the padding index are masked out in the stack.
        pad_mask = txt_tokens.eq(self.padding_idx).data
        hidden = self.forward_embedding(txt_tokens) + other_embeds  # [B, T, H]
        if self.num_layers <= 0:
            return hidden
        return super(FastSpeechEncoder, self).forward(hidden, pad_mask, attn_mask=attn_mask)

    def forward_embedding(self, txt_tokens):
        """Embed tokens (scaled by ``embed_scale``), optionally add positions, apply dropout."""
        embedded = self.embed_scale * self.embed_tokens(txt_tokens)
        if self.use_pos_embed:
            embedded = embedded + self.embed_positions(txt_tokens)
        return F.dropout(embedded, p=self.dropout, training=self.training)
|