xinference 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of xinference has been flagged as potentially problematic.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +50 -1
- xinference/client/restful/restful_client.py +82 -2
- xinference/constants.py +3 -0
- xinference/core/chat_interface.py +297 -83
- xinference/core/model.py +1 -0
- xinference/core/progress_tracker.py +16 -8
- xinference/core/supervisor.py +45 -1
- xinference/core/worker.py +262 -37
- xinference/deploy/cmdline.py +33 -1
- xinference/model/audio/core.py +11 -1
- xinference/model/audio/megatts.py +105 -0
- xinference/model/audio/model_spec.json +24 -1
- xinference/model/audio/model_spec_modelscope.json +26 -1
- xinference/model/core.py +14 -0
- xinference/model/embedding/core.py +6 -1
- xinference/model/flexible/core.py +6 -1
- xinference/model/image/core.py +6 -1
- xinference/model/image/model_spec.json +17 -1
- xinference/model/image/model_spec_modelscope.json +17 -1
- xinference/model/llm/__init__.py +0 -4
- xinference/model/llm/core.py +4 -0
- xinference/model/llm/llama_cpp/core.py +40 -16
- xinference/model/llm/llm_family.json +413 -84
- xinference/model/llm/llm_family.py +24 -1
- xinference/model/llm/llm_family_modelscope.json +447 -0
- xinference/model/llm/mlx/core.py +16 -2
- xinference/model/llm/transformers/__init__.py +14 -0
- xinference/model/llm/transformers/core.py +30 -6
- xinference/model/llm/transformers/gemma3.py +17 -2
- xinference/model/llm/transformers/intern_vl.py +28 -18
- xinference/model/llm/transformers/minicpmv26.py +21 -2
- xinference/model/llm/transformers/qwen-omni.py +308 -0
- xinference/model/llm/transformers/qwen2_audio.py +1 -1
- xinference/model/llm/transformers/qwen2_vl.py +20 -4
- xinference/model/llm/utils.py +11 -1
- xinference/model/llm/vllm/core.py +35 -0
- xinference/model/llm/vllm/distributed_executor.py +8 -2
- xinference/model/rerank/core.py +6 -1
- xinference/model/utils.py +118 -1
- xinference/model/video/core.py +6 -1
- xinference/thirdparty/megatts3/__init__.py +0 -0
- xinference/thirdparty/megatts3/tts/frontend_function.py +175 -0
- xinference/thirdparty/megatts3/tts/gradio_api.py +93 -0
- xinference/thirdparty/megatts3/tts/infer_cli.py +277 -0
- xinference/thirdparty/megatts3/tts/modules/aligner/whisper_small.py +318 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/ar_dur_predictor.py +362 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/layers.py +64 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/nar_tts_modules.py +73 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/rel_transformer.py +403 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/rot_transformer.py +649 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/seq_utils.py +342 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/transformer.py +767 -0
- xinference/thirdparty/megatts3/tts/modules/llm_dit/cfm.py +309 -0
- xinference/thirdparty/megatts3/tts/modules/llm_dit/dit.py +180 -0
- xinference/thirdparty/megatts3/tts/modules/llm_dit/time_embedding.py +44 -0
- xinference/thirdparty/megatts3/tts/modules/llm_dit/transformer.py +230 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/diag_gaussian.py +67 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/hifigan_modules.py +283 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/seanet_encoder.py +38 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/wavvae_v3.py +60 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/conv.py +154 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/lstm.py +51 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/seanet.py +126 -0
- xinference/thirdparty/megatts3/tts/utils/audio_utils/align.py +36 -0
- xinference/thirdparty/megatts3/tts/utils/audio_utils/io.py +95 -0
- xinference/thirdparty/megatts3/tts/utils/audio_utils/plot.py +90 -0
- xinference/thirdparty/megatts3/tts/utils/commons/ckpt_utils.py +171 -0
- xinference/thirdparty/megatts3/tts/utils/commons/hparams.py +215 -0
- xinference/thirdparty/megatts3/tts/utils/text_utils/dict.json +1 -0
- xinference/thirdparty/megatts3/tts/utils/text_utils/ph_tone_convert.py +94 -0
- xinference/thirdparty/megatts3/tts/utils/text_utils/split_text.py +90 -0
- xinference/thirdparty/megatts3/tts/utils/text_utils/text_encoder.py +280 -0
- xinference/types.py +10 -0
- xinference/utils.py +54 -0
- xinference/web/ui/build/asset-manifest.json +6 -6
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/css/main.0f6523be.css +2 -0
- xinference/web/ui/build/static/css/main.0f6523be.css.map +1 -0
- xinference/web/ui/build/static/js/main.58bd483c.js +3 -0
- xinference/web/ui/build/static/js/main.58bd483c.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/3bff8cbe9141f937f4d98879a9771b0f48e0e4e0dbee8e647adbfe23859e7048.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/4500b1a622a031011f0a291701e306b87e08cbc749c50e285103536b85b6a914.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/51709f5d3e53bcf19e613662ef9b91fb9174942c5518987a248348dd4e1e0e02.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/69081049f0c7447544b7cfd73dd13d8846c02fe5febe4d81587e95c89a412d5b.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/b8551e9775a01b28ae674125c688febe763732ea969ae344512e64ea01bf632e.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/bf2b211b0d1b6465eff512d64c869d748f803c5651a7c24e48de6ea3484a7bfe.json +1 -0
- xinference/web/ui/src/locales/en.json +2 -1
- xinference/web/ui/src/locales/zh.json +2 -1
- {xinference-1.4.1.dist-info → xinference-1.5.0.dist-info}/METADATA +127 -114
- {xinference-1.4.1.dist-info → xinference-1.5.0.dist-info}/RECORD +96 -60
- {xinference-1.4.1.dist-info → xinference-1.5.0.dist-info}/WHEEL +1 -1
- xinference/web/ui/build/static/css/main.b494ae7e.css +0 -2
- xinference/web/ui/build/static/css/main.b494ae7e.css.map +0 -1
- xinference/web/ui/build/static/js/main.5ca4eea1.js +0 -3
- xinference/web/ui/build/static/js/main.5ca4eea1.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/0f0967acaec5df1d45b80010949c258d64297ebbb0f44b8bb3afcbd45c6f0ec4.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/27bcada3ee8f89d21184b359f022fc965f350ffaca52c9814c29f1fc37121173.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/68249645124f37d01eef83b1d897e751f895bea919b6fb466f907c1f87cebc84.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/e547bbb18abb4a474b675a8d5782d25617566bea0af8caa9b836ce5649e2250a.json +0 -1
- /xinference/web/ui/build/static/js/{main.5ca4eea1.js.LICENSE.txt → main.58bd483c.js.LICENSE.txt} +0 -0
- {xinference-1.4.1.dist-info → xinference-1.5.0.dist-info}/entry_points.txt +0 -0
- {xinference-1.4.1.dist-info → xinference-1.5.0.dist-info/licenses}/LICENSE +0 -0
- {xinference-1.4.1.dist-info → xinference-1.5.0.dist-info}/top_level.txt +0 -0
xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/rel_transformer.py
@@ -0,0 +1,403 @@
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import torch
from torch import nn
from torch.nn import functional as F

from tts.modules.ar_dur.commons.layers import Embedding


def convert_pad_shape(pad_shape):
    l = pad_shape[::-1]
    pad_shape = [item for sublist in l for item in sublist]
    return pad_shape


def shift_1d(x):
    x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
    return x


def sequence_mask(length, max_length=None):
    if max_length is None:
        max_length = length.max()
    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return x.unsqueeze(0) < length.unsqueeze(1)
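For orientation (an illustration, not part of the diff), this is what the two helpers above produce; the tensor values are arbitrary:

# illustration only -- assumes the definitions above are in scope
lengths = torch.tensor([3, 5])
sequence_mask(lengths, max_length=6)
# tensor([[ True,  True,  True, False, False, False],
#         [ True,  True,  True,  True,  True, False]])

convert_pad_shape([[0, 0], [0, 0], [1, 0]])
# [1, 0, 0, 0, 0, 0] -- flattened in the last-dim-first order that F.pad expects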
class Encoder(nn.Module):
    def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0.,
                 window_size=None, block_length=None, pre_ln=False, **kwargs):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size
        self.block_length = block_length
        self.pre_ln = pre_ln

        self.drop = nn.Dropout(p_dropout)
        self.attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()
        for i in range(self.n_layers):
            self.attn_layers.append(
                MultiHeadAttention(hidden_channels, hidden_channels, n_heads, window_size=window_size,
                                   p_dropout=p_dropout, block_length=block_length))
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(
                FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout))
            self.norm_layers_2.append(LayerNorm(hidden_channels))
        if pre_ln:
            self.last_ln = LayerNorm(hidden_channels)

    def forward(self, x, x_mask, attn_mask=1):
        if isinstance(attn_mask, torch.Tensor):
            attn_mask = attn_mask[:, None]
        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) * attn_mask
        for i in range(self.n_layers):
            x = x * x_mask
            x_ = x
            if self.pre_ln:
                x = self.norm_layers_1[i](x)
            y = self.attn_layers[i](x, x, attn_mask)
            y = self.drop(y)
            x = x_ + y
            if not self.pre_ln:
                x = self.norm_layers_1[i](x)

            x_ = x
            if self.pre_ln:
                x = self.norm_layers_2[i](x)
            y = self.ffn_layers[i](x, x_mask)
            y = self.drop(y)
            x = x_ + y
            if not self.pre_ln:
                x = self.norm_layers_2[i](x)
        if self.pre_ln:
            x = self.last_ln(x)
        x = x * x_mask
        return x
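A minimal shape check for the Encoder block (sizes are made up for illustration; it relies on MultiHeadAttention, FFN and LayerNorm defined later in this file):

# illustration only -- arbitrary sizes
enc = Encoder(hidden_channels=192, filter_channels=768, n_heads=2, n_layers=4,
              kernel_size=5, window_size=4, pre_ln=True)
x = torch.randn(2, 192, 50)     # [batch, hidden, time]
x_mask = torch.ones(2, 1, 50)   # 1 for real frames, 0 for padding
y = enc(x, x_mask)              # -> [2, 192, 50], same layout as the input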
class MultiHeadAttention(nn.Module):
    def __init__(self, channels, out_channels, n_heads, window_size=None, heads_share=True, p_dropout=0.,
                 block_length=None, proximal_bias=False, proximal_init=False):
        super().__init__()
        assert channels % n_heads == 0

        self.channels = channels
        self.out_channels = out_channels
        self.n_heads = n_heads
        self.window_size = window_size
        self.heads_share = heads_share
        self.block_length = block_length
        self.proximal_bias = proximal_bias
        self.p_dropout = p_dropout
        self.attn = None

        self.k_channels = channels // n_heads
        self.conv_q = nn.Conv1d(channels, channels, 1)
        self.conv_k = nn.Conv1d(channels, channels, 1)
        self.conv_v = nn.Conv1d(channels, channels, 1)
        if window_size is not None:
            n_heads_rel = 1 if heads_share else n_heads
            rel_stddev = self.k_channels ** -0.5
            self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
            self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
        self.conv_o = nn.Conv1d(channels, out_channels, 1)
        self.drop = nn.Dropout(p_dropout)

        nn.init.xavier_uniform_(self.conv_q.weight)
        nn.init.xavier_uniform_(self.conv_k.weight)
        if proximal_init:
            self.conv_k.weight.data.copy_(self.conv_q.weight.data)
            self.conv_k.bias.data.copy_(self.conv_q.bias.data)
        nn.init.xavier_uniform_(self.conv_v.weight)

    def forward(self, x, c, attn_mask=None):
        q = self.conv_q(x)
        k = self.conv_k(c)
        v = self.conv_v(c)

        x, self.attn = self.attention(q, k, v, mask=attn_mask)

        x = self.conv_o(x)
        return x

    def attention(self, query, key, value, mask=None):
        # reshape [b, d, t] -> [b, n_h, t, d_k]
        b, d, t_s, t_t = (*key.size(), query.size(2))
        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
        key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
        value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)

        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels)
        if self.window_size is not None:
            assert t_s == t_t, "Relative attention is only available for self-attention."
            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
            rel_logits = self._matmul_with_relative_keys(query, key_relative_embeddings)
            rel_logits = self._relative_position_to_absolute_position(rel_logits)
            scores_local = rel_logits / math.sqrt(self.k_channels)
            scores = scores + scores_local
        if self.proximal_bias:
            assert t_s == t_t, "Proximal bias is only available for self-attention."
            scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e4)
            if self.block_length is not None:
                block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)
                scores = scores * block_mask + -1e4 * (1 - block_mask)
        p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
        p_attn = self.drop(p_attn)
        output = torch.matmul(p_attn, value)
        if self.window_size is not None:
            relative_weights = self._absolute_position_to_relative_position(p_attn)
            value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
            output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
        output = output.transpose(2, 3).contiguous().view(b, d, t_t)  # [b, n_h, t_t, d_k] -> [b, d, t_t]
        return output, p_attn

    def _matmul_with_relative_values(self, x, y):
        """
        x: [b, h, l, m]
        y: [h or 1, m, d]
        ret: [b, h, l, d]
        """
        ret = torch.matmul(x, y.unsqueeze(0))
        return ret

    def _matmul_with_relative_keys(self, x, y):
        """
        x: [b, h, l, d]
        y: [h or 1, m, d]
        ret: [b, h, l, m]
        """
        ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
        return ret

    def _get_relative_embeddings(self, relative_embeddings, length):
        max_relative_position = 2 * self.window_size + 1
        # Pad first before slice to avoid using cond ops.
        pad_length = max(length - (self.window_size + 1), 0)
        slice_start_position = max((self.window_size + 1) - length, 0)
        slice_end_position = slice_start_position + 2 * length - 1
        if pad_length > 0:
            padded_relative_embeddings = F.pad(
                relative_embeddings,
                convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]))
        else:
            padded_relative_embeddings = relative_embeddings
        used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position]
        return used_relative_embeddings

    def _relative_position_to_absolute_position(self, x):
        """
        x: [b, h, l, 2*l-1]
        ret: [b, h, l, l]
        """
        batch, heads, length, _ = x.size()
        # Concat columns of pad to shift from relative to absolute indexing.
        x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))

        # Concat extra elements so to add up to shape (len+1, 2*len-1).
        x_flat = x.view([batch, heads, length * 2 * length])
        x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))

        # Reshape and slice out the padded elements.
        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1:]
        return x_final

    def _absolute_position_to_relative_position(self, x):
        """
        x: [b, h, l, l]
        ret: [b, h, l, 2*l-1]
        """
        batch, heads, length, _ = x.size()
        # padd along column
        x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
        x_flat = x.view([batch, heads, -1])
        # add 0's in the beginning that will skew the elements after reshape
        x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
        x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
        return x_final

    def _attention_bias_proximal(self, length):
        """Bias for self-attention to encourage attention to close positions.
        Args:
          length: an integer scalar.
        Returns:
          a Tensor with shape [1, 1, length, length]
        """
        r = torch.arange(length, dtype=torch.float32)
        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
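When window_size is set, the class above adds learned relative-position embeddings clipped to ±window_size on top of the usual scaled dot-product scores. A smoke test with made-up sizes (illustration, not part of the diff):

# illustration only -- arbitrary sizes
attn = MultiHeadAttention(channels=192, out_channels=192, n_heads=2, window_size=4)
x = torch.randn(2, 192, 50)       # self-attention: queries and keys/values both come from x
mask = torch.ones(2, 1, 50, 50)   # [b, 1, t_t, t_s], broadcast over heads
y = attn(x, x, attn_mask=mask)    # -> [2, 192, 50]; attn.attn holds the [2, 2, 50, 50] attention map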
class FFN(nn.Module):
    def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., activation=None):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.activation = activation

        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
        self.conv_2 = nn.Conv1d(filter_channels, out_channels, 1)
        self.drop = nn.Dropout(p_dropout)

    def forward(self, x, x_mask):
        x = self.conv_1(x * x_mask)
        if self.activation == "gelu":
            x = x * torch.sigmoid(1.702 * x)
        else:
            x = torch.relu(x)
        x = self.drop(x)
        x = self.conv_2(x * x_mask)
        return x * x_mask


class LayerNorm(nn.Module):
    def __init__(self, channels, eps=1e-4):
        super().__init__()
        self.channels = channels
        self.eps = eps

        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        n_dims = len(x.shape)
        mean = torch.mean(x, 1, keepdim=True)
        variance = torch.mean((x - mean) ** 2, 1, keepdim=True)

        x = (x - mean) * torch.rsqrt(variance + self.eps)

        shape = [1, -1] + [1] * (n_dims - 2)
        x = x * self.gamma.view(*shape) + self.beta.view(*shape)
        return x
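Note that this LayerNorm normalizes over the channel dimension (dim 1) of [batch, channels, time] tensors, unlike torch.nn.LayerNorm, which normalizes over the trailing dimension. A small sketch for illustration:

# illustration only
ln = LayerNorm(192)
x = torch.randn(2, 192, 50)
y = ln(x)   # each (batch, time) position is normalized across its 192 channels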
class ConvReluNorm(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout
        assert n_layers > 1, "Number of layers should be larger than 0."

        self.conv_layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()
        self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
        self.norm_layers.append(LayerNorm(hidden_channels))
        self.relu_drop = nn.Sequential(
            nn.ReLU(),
            nn.Dropout(p_dropout))
        for _ in range(n_layers - 1):
            self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
            self.norm_layers.append(LayerNorm(hidden_channels))
        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, x, x_mask):
        x_org = x
        for i in range(self.n_layers):
            x = self.conv_layers[i](x * x_mask)
            x = self.norm_layers[i](x)
            x = self.relu_drop(x)
        x = x_org + self.proj(x)
        return x * x_mask


class RelTransformerEncoder(nn.Module):
    def __init__(self,
                 n_vocab,
                 out_channels,
                 hidden_channels,
                 filter_channels,
                 n_heads,
                 n_layers,
                 kernel_size,
                 p_dropout=0.0,
                 window_size=4,
                 block_length=None,
                 in_channels=None,
                 prenet=True,
                 pre_ln=True,
                 ):

        super().__init__()

        self.n_vocab = n_vocab
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size
        self.block_length = block_length
        self.prenet = prenet
        if n_vocab > 0:
            self.emb = Embedding(n_vocab, hidden_channels, padding_idx=0)

        if prenet:
            if in_channels is None:
                in_channels = hidden_channels
            self.pre = ConvReluNorm(in_channels, in_channels, in_channels,
                                    kernel_size=5, n_layers=3, p_dropout=0)
        if in_channels is not None and in_channels != hidden_channels:
            self.encoder_inp_proj = nn.Conv1d(in_channels, hidden_channels, 1)
        self.encoder = Encoder(
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            window_size=window_size,
            block_length=block_length,
            pre_ln=pre_ln,
        )

    def forward(self, x, x_mask=None, other_embeds=0, attn_mask=1):
        if self.n_vocab > 0:
            x_lengths = (x > 0).long().sum(-1)
            x = self.emb(x) * math.sqrt(self.hidden_channels)  # [b, t, h]
        else:
            x_lengths = (x.abs().sum(-1) > 0).long().sum(-1)
        x = x + other_embeds
        x = torch.transpose(x, 1, -1)  # [b, h, t]
        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)

        if self.prenet:
            x = self.pre(x, x_mask)
            self.prenet_out = x.transpose(1, 2)
        if hasattr(self, 'encoder_inp_proj'):
            x = self.encoder_inp_proj(x) * x_mask
        x = self.encoder(x, x_mask, attn_mask)
        return x.transpose(1, 2)
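A sketch of how the vendored encoder might be driven end to end. The hyperparameters are illustrative, and the sys.path line is an assumption about how the file's top-level `tts` imports resolve inside the wheel; none of it is taken from the diff:

# illustration only -- assumed import setup and arbitrary hyperparameters
import sys
import torch

sys.path.insert(0, "xinference/thirdparty/megatts3")   # assumed: makes the `tts.` package importable
from tts.modules.ar_dur.commons.rel_transformer import RelTransformerEncoder

encoder = RelTransformerEncoder(
    n_vocab=100, out_channels=192, hidden_channels=192, filter_channels=768,
    n_heads=2, n_layers=4, kernel_size=5, p_dropout=0.0, window_size=4,
)
tokens = torch.tensor([[5, 8, 13, 2, 0, 0]])   # 0 is padding; lengths are inferred from non-zero ids
hidden = encoder(tokens)                       # -> [1, 6, 192]; padded positions are zeroed by the mask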