xinference 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic.
- xinference/_compat.py +1 -0
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +54 -1
- xinference/client/restful/restful_client.py +82 -2
- xinference/constants.py +3 -0
- xinference/core/chat_interface.py +297 -83
- xinference/core/model.py +24 -3
- xinference/core/progress_tracker.py +16 -8
- xinference/core/supervisor.py +51 -1
- xinference/core/worker.py +315 -47
- xinference/deploy/cmdline.py +33 -1
- xinference/model/audio/core.py +11 -1
- xinference/model/audio/megatts.py +105 -0
- xinference/model/audio/model_spec.json +24 -1
- xinference/model/audio/model_spec_modelscope.json +26 -1
- xinference/model/core.py +14 -0
- xinference/model/embedding/core.py +6 -1
- xinference/model/flexible/core.py +6 -1
- xinference/model/image/core.py +6 -1
- xinference/model/image/model_spec.json +17 -1
- xinference/model/image/model_spec_modelscope.json +17 -1
- xinference/model/llm/__init__.py +4 -6
- xinference/model/llm/core.py +5 -0
- xinference/model/llm/llama_cpp/core.py +46 -17
- xinference/model/llm/llm_family.json +530 -85
- xinference/model/llm/llm_family.py +24 -1
- xinference/model/llm/llm_family_modelscope.json +572 -1
- xinference/model/llm/mlx/core.py +16 -2
- xinference/model/llm/reasoning_parser.py +3 -3
- xinference/model/llm/sglang/core.py +111 -13
- xinference/model/llm/transformers/__init__.py +14 -0
- xinference/model/llm/transformers/core.py +31 -6
- xinference/model/llm/transformers/deepseek_vl.py +1 -1
- xinference/model/llm/transformers/deepseek_vl2.py +287 -0
- xinference/model/llm/transformers/gemma3.py +17 -2
- xinference/model/llm/transformers/intern_vl.py +28 -18
- xinference/model/llm/transformers/minicpmv26.py +21 -2
- xinference/model/llm/transformers/qwen-omni.py +308 -0
- xinference/model/llm/transformers/qwen2_audio.py +1 -1
- xinference/model/llm/transformers/qwen2_vl.py +20 -4
- xinference/model/llm/utils.py +37 -15
- xinference/model/llm/vllm/core.py +184 -8
- xinference/model/llm/vllm/distributed_executor.py +320 -0
- xinference/model/rerank/core.py +22 -12
- xinference/model/utils.py +118 -1
- xinference/model/video/core.py +6 -1
- xinference/thirdparty/deepseek_vl2/__init__.py +31 -0
- xinference/thirdparty/deepseek_vl2/models/__init__.py +26 -0
- xinference/thirdparty/deepseek_vl2/models/configuration_deepseek.py +210 -0
- xinference/thirdparty/deepseek_vl2/models/conversation.py +310 -0
- xinference/thirdparty/deepseek_vl2/models/modeling_deepseek.py +1975 -0
- xinference/thirdparty/deepseek_vl2/models/modeling_deepseek_vl_v2.py +697 -0
- xinference/thirdparty/deepseek_vl2/models/processing_deepseek_vl_v2.py +675 -0
- xinference/thirdparty/deepseek_vl2/models/siglip_vit.py +661 -0
- xinference/thirdparty/deepseek_vl2/serve/__init__.py +0 -0
- xinference/thirdparty/deepseek_vl2/serve/app_modules/__init__.py +0 -0
- xinference/thirdparty/deepseek_vl2/serve/app_modules/gradio_utils.py +83 -0
- xinference/thirdparty/deepseek_vl2/serve/app_modules/overwrites.py +81 -0
- xinference/thirdparty/deepseek_vl2/serve/app_modules/presets.py +115 -0
- xinference/thirdparty/deepseek_vl2/serve/app_modules/utils.py +333 -0
- xinference/thirdparty/deepseek_vl2/serve/assets/Kelpy-Codos.js +100 -0
- xinference/thirdparty/deepseek_vl2/serve/assets/avatar.png +0 -0
- xinference/thirdparty/deepseek_vl2/serve/assets/custom.css +355 -0
- xinference/thirdparty/deepseek_vl2/serve/assets/custom.js +22 -0
- xinference/thirdparty/deepseek_vl2/serve/assets/favicon.ico +0 -0
- xinference/thirdparty/deepseek_vl2/serve/assets/simsun.ttc +0 -0
- xinference/thirdparty/deepseek_vl2/serve/inference.py +197 -0
- xinference/thirdparty/deepseek_vl2/utils/__init__.py +18 -0
- xinference/thirdparty/deepseek_vl2/utils/io.py +80 -0
- xinference/thirdparty/megatts3/__init__.py +0 -0
- xinference/thirdparty/megatts3/tts/frontend_function.py +175 -0
- xinference/thirdparty/megatts3/tts/gradio_api.py +93 -0
- xinference/thirdparty/megatts3/tts/infer_cli.py +277 -0
- xinference/thirdparty/megatts3/tts/modules/aligner/whisper_small.py +318 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/ar_dur_predictor.py +362 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/layers.py +64 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/nar_tts_modules.py +73 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/rel_transformer.py +403 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/rot_transformer.py +649 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/seq_utils.py +342 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/transformer.py +767 -0
- xinference/thirdparty/megatts3/tts/modules/llm_dit/cfm.py +309 -0
- xinference/thirdparty/megatts3/tts/modules/llm_dit/dit.py +180 -0
- xinference/thirdparty/megatts3/tts/modules/llm_dit/time_embedding.py +44 -0
- xinference/thirdparty/megatts3/tts/modules/llm_dit/transformer.py +230 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/diag_gaussian.py +67 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/hifigan_modules.py +283 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/seanet_encoder.py +38 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/wavvae_v3.py +60 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/conv.py +154 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/lstm.py +51 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/seanet.py +126 -0
- xinference/thirdparty/megatts3/tts/utils/audio_utils/align.py +36 -0
- xinference/thirdparty/megatts3/tts/utils/audio_utils/io.py +95 -0
- xinference/thirdparty/megatts3/tts/utils/audio_utils/plot.py +90 -0
- xinference/thirdparty/megatts3/tts/utils/commons/ckpt_utils.py +171 -0
- xinference/thirdparty/megatts3/tts/utils/commons/hparams.py +215 -0
- xinference/thirdparty/megatts3/tts/utils/text_utils/dict.json +1 -0
- xinference/thirdparty/megatts3/tts/utils/text_utils/ph_tone_convert.py +94 -0
- xinference/thirdparty/megatts3/tts/utils/text_utils/split_text.py +90 -0
- xinference/thirdparty/megatts3/tts/utils/text_utils/text_encoder.py +280 -0
- xinference/types.py +10 -0
- xinference/utils.py +54 -0
- xinference/web/ui/build/asset-manifest.json +6 -6
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/css/main.0f6523be.css +2 -0
- xinference/web/ui/build/static/css/main.0f6523be.css.map +1 -0
- xinference/web/ui/build/static/js/main.58bd483c.js +3 -0
- xinference/web/ui/build/static/js/main.58bd483c.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/3bff8cbe9141f937f4d98879a9771b0f48e0e4e0dbee8e647adbfe23859e7048.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/4500b1a622a031011f0a291701e306b87e08cbc749c50e285103536b85b6a914.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/51709f5d3e53bcf19e613662ef9b91fb9174942c5518987a248348dd4e1e0e02.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/69081049f0c7447544b7cfd73dd13d8846c02fe5febe4d81587e95c89a412d5b.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/b8551e9775a01b28ae674125c688febe763732ea969ae344512e64ea01bf632e.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/bf2b211b0d1b6465eff512d64c869d748f803c5651a7c24e48de6ea3484a7bfe.json +1 -0
- xinference/web/ui/src/locales/en.json +2 -1
- xinference/web/ui/src/locales/zh.json +2 -1
- {xinference-1.4.0.dist-info → xinference-1.5.0.dist-info}/METADATA +128 -115
- {xinference-1.4.0.dist-info → xinference-1.5.0.dist-info}/RECORD +124 -63
- {xinference-1.4.0.dist-info → xinference-1.5.0.dist-info}/WHEEL +1 -1
- xinference/web/ui/build/static/css/main.b494ae7e.css +0 -2
- xinference/web/ui/build/static/css/main.b494ae7e.css.map +0 -1
- xinference/web/ui/build/static/js/main.3cea968e.js +0 -3
- xinference/web/ui/build/static/js/main.3cea968e.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/27bcada3ee8f89d21184b359f022fc965f350ffaca52c9814c29f1fc37121173.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/7f59e45e3f268ab8a4788b6fb024cf8dab088736dff22f5a3a39c122a83ab930.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/dcd60488509450bfff37bfff56de2c096d51de17dd00ec60d4db49c8b483ada1.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/e547bbb18abb4a474b675a8d5782d25617566bea0af8caa9b836ce5649e2250a.json +0 -1
- /xinference/web/ui/build/static/js/{main.3cea968e.js.LICENSE.txt → main.58bd483c.js.LICENSE.txt} +0 -0
- {xinference-1.4.0.dist-info → xinference-1.5.0.dist-info}/entry_points.txt +0 -0
- {xinference-1.4.0.dist-info → xinference-1.5.0.dist-info/licenses}/LICENSE +0 -0
- {xinference-1.4.0.dist-info → xinference-1.5.0.dist-info}/top_level.txt +0 -0
xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/layers.py
@@ -0,0 +1,64 @@
+# Copyright 2025 ByteDance and/or its affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from torch import nn
+
+
+class LayerNorm(torch.nn.LayerNorm):
+    """Layer normalization module.
+    :param int nout: output dim size
+    :param int dim: dimension to be normalized
+    """
+
+    def __init__(self, nout, dim=-1, eps=1e-5):
+        """Construct an LayerNorm object."""
+        super(LayerNorm, self).__init__(nout, eps=eps)
+        self.dim = dim
+
+    def forward(self, x):
+        """Apply layer normalization.
+        :param torch.Tensor x: input tensor
+        :return: layer normalized tensor
+        :rtype torch.Tensor
+        """
+        if self.dim == -1:
+            return super(LayerNorm, self).forward(x)
+        return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1)
+
+
+class Reshape(nn.Module):
+    def __init__(self, *args):
+        super(Reshape, self).__init__()
+        self.shape = args
+
+    def forward(self, x):
+        return x.view(self.shape)
+
+
+class Permute(nn.Module):
+    def __init__(self, *args):
+        super(Permute, self).__init__()
+        self.args = args
+
+    def forward(self, x):
+        return x.permute(self.args)
+
+
+def Embedding(num_embeddings, embedding_dim, padding_idx=None):
+    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
+    nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
+    if padding_idx is not None:
+        nn.init.constant_(m.weight[padding_idx], 0)
+    return m
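
For orientation only (not part of the diff), a minimal sketch of how the helpers in this new vendored module behave. It assumes the megatts3 directory is on sys.path, as the vendored code itself imports through the tts.* namespace; the sizes and token ids are illustrative.

import torch
from tts.modules.ar_dur.commons.layers import Embedding, LayerNorm

emb = Embedding(num_embeddings=100, embedding_dim=64, padding_idx=0)  # normal init, zeroed padding row
tokens = torch.tensor([[3, 7, 0, 0]])                                 # index 0 is padding
vecs = emb(tokens)                                                    # (1, 4, 64); padding rows are all zeros
norm = LayerNorm(64, dim=1)                                           # normalizes a channel-first (B, C, T) tensor
out = norm(vecs.transpose(1, 2))                                      # (1, 64, 4)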
xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/nar_tts_modules.py
@@ -0,0 +1,73 @@
+# Copyright 2025 ByteDance and/or its affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+
+import torch
+from torch import nn
+
+import torch.nn.functional as F
+
+
+class LengthRegulator(torch.nn.Module):
+    def __init__(self, pad_value=0.0):
+        super(LengthRegulator, self).__init__()
+        self.pad_value = pad_value
+
+    def forward(self, dur, dur_padding=None, alpha=1.0):
+        """
+        Example (no batch dim version):
+            1. dur = [2,2,3]
+            2. token_idx = [[1],[2],[3]], dur_cumsum = [2,4,7], dur_cumsum_prev = [0,2,4]
+            3. token_mask = [[1,1,0,0,0,0,0],
+                             [0,0,1,1,0,0,0],
+                             [0,0,0,0,1,1,1]]
+            4. token_idx * token_mask = [[1,1,0,0,0,0,0],
+                                         [0,0,2,2,0,0,0],
+                                         [0,0,0,0,3,3,3]]
+            5. (token_idx * token_mask).sum(0) = [1,1,2,2,3,3,3]
+
+        :param dur: Batch of durations of each frame (B, T_txt)
+        :param dur_padding: Batch of padding of each frame (B, T_txt)
+        :param alpha: duration rescale coefficient
+        :return:
+            mel2ph (B, T_speech)
+        assert alpha > 0
+        """
+        dur = torch.round(dur.float() * alpha).long()
+        if dur_padding is not None:
+            dur = dur * (1 - dur_padding.long())
+        token_idx = torch.arange(1, dur.shape[1] + 1)[None, :, None].to(dur.device)
+        dur_cumsum = torch.cumsum(dur, 1)
+        dur_cumsum_prev = F.pad(dur_cumsum, [1, -1], mode='constant', value=0)
+
+        pos_idx = torch.arange(dur.sum(-1).max())[None, None].to(dur.device)
+        token_mask = (pos_idx >= dur_cumsum_prev[:, :, None]) & (pos_idx < dur_cumsum[:, :, None])
+        mel2token = (token_idx * token_mask.long()).sum(1)
+        return mel2token
+
+
+class PosEmb(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.dim = dim
+        half_dim = self.dim // 2
+        emb = math.log(10000) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim) * -emb)
+        self.emb = emb  # TODO
+
+    def forward(self, x):
+        emb = x[:, :, None] * self.emb[None, None, :].to(x.device)
+        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
+        return emb
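
Again for orientation only (not part of the diff), a minimal sketch of LengthRegulator, which expands per-token durations into a frame-to-token map. The input mirrors the example in the docstring above and assumes the same sys.path setup as before.

import torch
from tts.modules.ar_dur.commons.nar_tts_modules import LengthRegulator

lr = LengthRegulator()
dur = torch.tensor([[2, 2, 3]])   # (B, T_txt) durations in frames
mel2token = lr(dur)               # tensor([[1, 1, 2, 2, 3, 3, 3]]), shape (B, T_speech)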
xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/rel_transformer.py
@@ -0,0 +1,403 @@
+# Copyright 2025 ByteDance and/or its affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from tts.modules.ar_dur.commons.layers import Embedding
+
+
+def convert_pad_shape(pad_shape):
+    l = pad_shape[::-1]
+    pad_shape = [item for sublist in l for item in sublist]
+    return pad_shape
+
+
+def shift_1d(x):
+    x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
+    return x
+
+
+def sequence_mask(length, max_length=None):
+    if max_length is None:
+        max_length = length.max()
+    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
+    return x.unsqueeze(0) < length.unsqueeze(1)
+
+
+class Encoder(nn.Module):
+    def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0.,
+                 window_size=None, block_length=None, pre_ln=False, **kwargs):
+        super().__init__()
+        self.hidden_channels = hidden_channels
+        self.filter_channels = filter_channels
+        self.n_heads = n_heads
+        self.n_layers = n_layers
+        self.kernel_size = kernel_size
+        self.p_dropout = p_dropout
+        self.window_size = window_size
+        self.block_length = block_length
+        self.pre_ln = pre_ln
+
+        self.drop = nn.Dropout(p_dropout)
+        self.attn_layers = nn.ModuleList()
+        self.norm_layers_1 = nn.ModuleList()
+        self.ffn_layers = nn.ModuleList()
+        self.norm_layers_2 = nn.ModuleList()
+        for i in range(self.n_layers):
+            self.attn_layers.append(
+                MultiHeadAttention(hidden_channels, hidden_channels, n_heads, window_size=window_size,
+                                   p_dropout=p_dropout, block_length=block_length))
+            self.norm_layers_1.append(LayerNorm(hidden_channels))
+            self.ffn_layers.append(
+                FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout))
+            self.norm_layers_2.append(LayerNorm(hidden_channels))
+        if pre_ln:
+            self.last_ln = LayerNorm(hidden_channels)
+
+    def forward(self, x, x_mask, attn_mask=1):
+        if isinstance(attn_mask, torch.Tensor):
+            attn_mask = attn_mask[:, None]
+        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) * attn_mask
+        for i in range(self.n_layers):
+            x = x * x_mask
+            x_ = x
+            if self.pre_ln:
+                x = self.norm_layers_1[i](x)
+            y = self.attn_layers[i](x, x, attn_mask)
+            y = self.drop(y)
+            x = x_ + y
+            if not self.pre_ln:
+                x = self.norm_layers_1[i](x)
+
+            x_ = x
+            if self.pre_ln:
+                x = self.norm_layers_2[i](x)
+            y = self.ffn_layers[i](x, x_mask)
+            y = self.drop(y)
+            x = x_ + y
+            if not self.pre_ln:
+                x = self.norm_layers_2[i](x)
+        if self.pre_ln:
+            x = self.last_ln(x)
+        x = x * x_mask
+        return x
+
+
+class MultiHeadAttention(nn.Module):
+    def __init__(self, channels, out_channels, n_heads, window_size=None, heads_share=True, p_dropout=0.,
+                 block_length=None, proximal_bias=False, proximal_init=False):
+        super().__init__()
+        assert channels % n_heads == 0
+
+        self.channels = channels
+        self.out_channels = out_channels
+        self.n_heads = n_heads
+        self.window_size = window_size
+        self.heads_share = heads_share
+        self.block_length = block_length
+        self.proximal_bias = proximal_bias
+        self.p_dropout = p_dropout
+        self.attn = None
+
+        self.k_channels = channels // n_heads
+        self.conv_q = nn.Conv1d(channels, channels, 1)
+        self.conv_k = nn.Conv1d(channels, channels, 1)
+        self.conv_v = nn.Conv1d(channels, channels, 1)
+        if window_size is not None:
+            n_heads_rel = 1 if heads_share else n_heads
+            rel_stddev = self.k_channels ** -0.5
+            self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
+            self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
+        self.conv_o = nn.Conv1d(channels, out_channels, 1)
+        self.drop = nn.Dropout(p_dropout)
+
+        nn.init.xavier_uniform_(self.conv_q.weight)
+        nn.init.xavier_uniform_(self.conv_k.weight)
+        if proximal_init:
+            self.conv_k.weight.data.copy_(self.conv_q.weight.data)
+            self.conv_k.bias.data.copy_(self.conv_q.bias.data)
+        nn.init.xavier_uniform_(self.conv_v.weight)
+
+    def forward(self, x, c, attn_mask=None):
+        q = self.conv_q(x)
+        k = self.conv_k(c)
+        v = self.conv_v(c)
+
+        x, self.attn = self.attention(q, k, v, mask=attn_mask)
+
+        x = self.conv_o(x)
+        return x
+
+    def attention(self, query, key, value, mask=None):
+        # reshape [b, d, t] -> [b, n_h, t, d_k]
+        b, d, t_s, t_t = (*key.size(), query.size(2))
+        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
+        key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
+        value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
+
+        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels)
+        if self.window_size is not None:
+            assert t_s == t_t, "Relative attention is only available for self-attention."
+            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
+            rel_logits = self._matmul_with_relative_keys(query, key_relative_embeddings)
+            rel_logits = self._relative_position_to_absolute_position(rel_logits)
+            scores_local = rel_logits / math.sqrt(self.k_channels)
+            scores = scores + scores_local
+        if self.proximal_bias:
+            assert t_s == t_t, "Proximal bias is only available for self-attention."
+            scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
+        if mask is not None:
+            scores = scores.masked_fill(mask == 0, -1e4)
+            if self.block_length is not None:
+                block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)
+                scores = scores * block_mask + -1e4 * (1 - block_mask)
+        p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
+        p_attn = self.drop(p_attn)
+        output = torch.matmul(p_attn, value)
+        if self.window_size is not None:
+            relative_weights = self._absolute_position_to_relative_position(p_attn)
+            value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
+            output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
+        output = output.transpose(2, 3).contiguous().view(b, d, t_t)  # [b, n_h, t_t, d_k] -> [b, d, t_t]
+        return output, p_attn
+
+    def _matmul_with_relative_values(self, x, y):
+        """
+        x: [b, h, l, m]
+        y: [h or 1, m, d]
+        ret: [b, h, l, d]
+        """
+        ret = torch.matmul(x, y.unsqueeze(0))
+        return ret
+
+    def _matmul_with_relative_keys(self, x, y):
+        """
+        x: [b, h, l, d]
+        y: [h or 1, m, d]
+        ret: [b, h, l, m]
+        """
+        ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
+        return ret
+
+    def _get_relative_embeddings(self, relative_embeddings, length):
+        max_relative_position = 2 * self.window_size + 1
+        # Pad first before slice to avoid using cond ops.
+        pad_length = max(length - (self.window_size + 1), 0)
+        slice_start_position = max((self.window_size + 1) - length, 0)
+        slice_end_position = slice_start_position + 2 * length - 1
+        if pad_length > 0:
+            padded_relative_embeddings = F.pad(
+                relative_embeddings,
+                convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]))
+        else:
+            padded_relative_embeddings = relative_embeddings
+        used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position]
+        return used_relative_embeddings
+
+    def _relative_position_to_absolute_position(self, x):
+        """
+        x: [b, h, l, 2*l-1]
+        ret: [b, h, l, l]
+        """
+        batch, heads, length, _ = x.size()
+        # Concat columns of pad to shift from relative to absolute indexing.
+        x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
+
+        # Concat extra elements so to add up to shape (len+1, 2*len-1).
+        x_flat = x.view([batch, heads, length * 2 * length])
+        x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
+
+        # Reshape and slice out the padded elements.
+        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1:]
+        return x_final
+
+    def _absolute_position_to_relative_position(self, x):
+        """
+        x: [b, h, l, l]
+        ret: [b, h, l, 2*l-1]
+        """
+        batch, heads, length, _ = x.size()
+        # padd along column
+        x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
+        x_flat = x.view([batch, heads, -1])
+        # add 0's in the beginning that will skew the elements after reshape
+        x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
+        x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
+        return x_final
+
+    def _attention_bias_proximal(self, length):
+        """Bias for self-attention to encourage attention to close positions.
+        Args:
+            length: an integer scalar.
+        Returns:
+            a Tensor with shape [1, 1, length, length]
+        """
+        r = torch.arange(length, dtype=torch.float32)
+        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
+        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
+
+
+class FFN(nn.Module):
+    def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., activation=None):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.filter_channels = filter_channels
+        self.kernel_size = kernel_size
+        self.p_dropout = p_dropout
+        self.activation = activation
+
+        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
+        self.conv_2 = nn.Conv1d(filter_channels, out_channels, 1)
+        self.drop = nn.Dropout(p_dropout)
+
+    def forward(self, x, x_mask):
+        x = self.conv_1(x * x_mask)
+        if self.activation == "gelu":
+            x = x * torch.sigmoid(1.702 * x)
+        else:
+            x = torch.relu(x)
+        x = self.drop(x)
+        x = self.conv_2(x * x_mask)
+        return x * x_mask
+
+
+class LayerNorm(nn.Module):
+    def __init__(self, channels, eps=1e-4):
+        super().__init__()
+        self.channels = channels
+        self.eps = eps
+
+        self.gamma = nn.Parameter(torch.ones(channels))
+        self.beta = nn.Parameter(torch.zeros(channels))
+
+    def forward(self, x):
+        n_dims = len(x.shape)
+        mean = torch.mean(x, 1, keepdim=True)
+        variance = torch.mean((x - mean) ** 2, 1, keepdim=True)
+
+        x = (x - mean) * torch.rsqrt(variance + self.eps)
+
+        shape = [1, -1] + [1] * (n_dims - 2)
+        x = x * self.gamma.view(*shape) + self.beta.view(*shape)
+        return x
+
+
+class ConvReluNorm(nn.Module):
+    def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
+        super().__init__()
+        self.in_channels = in_channels
+        self.hidden_channels = hidden_channels
+        self.out_channels = out_channels
+        self.kernel_size = kernel_size
+        self.n_layers = n_layers
+        self.p_dropout = p_dropout
+        assert n_layers > 1, "Number of layers should be larger than 0."
+
+        self.conv_layers = nn.ModuleList()
+        self.norm_layers = nn.ModuleList()
+        self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
+        self.norm_layers.append(LayerNorm(hidden_channels))
+        self.relu_drop = nn.Sequential(
+            nn.ReLU(),
+            nn.Dropout(p_dropout))
+        for _ in range(n_layers - 1):
+            self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
+            self.norm_layers.append(LayerNorm(hidden_channels))
+        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
+        self.proj.weight.data.zero_()
+        self.proj.bias.data.zero_()
+
+    def forward(self, x, x_mask):
+        x_org = x
+        for i in range(self.n_layers):
+            x = self.conv_layers[i](x * x_mask)
+            x = self.norm_layers[i](x)
+            x = self.relu_drop(x)
+        x = x_org + self.proj(x)
+        return x * x_mask
+
+
+class RelTransformerEncoder(nn.Module):
+    def __init__(self,
+                 n_vocab,
+                 out_channels,
+                 hidden_channels,
+                 filter_channels,
+                 n_heads,
+                 n_layers,
+                 kernel_size,
+                 p_dropout=0.0,
+                 window_size=4,
+                 block_length=None,
+                 in_channels=None,
+                 prenet=True,
+                 pre_ln=True,
+                 ):
+
+        super().__init__()
+
+        self.n_vocab = n_vocab
+        self.out_channels = out_channels
+        self.hidden_channels = hidden_channels
+        self.filter_channels = filter_channels
+        self.n_heads = n_heads
+        self.n_layers = n_layers
+        self.kernel_size = kernel_size
+        self.p_dropout = p_dropout
+        self.window_size = window_size
+        self.block_length = block_length
+        self.prenet = prenet
+        if n_vocab > 0:
+            self.emb = Embedding(n_vocab, hidden_channels, padding_idx=0)
+
+        if prenet:
+            if in_channels is None:
+                in_channels = hidden_channels
+            self.pre = ConvReluNorm(in_channels, in_channels, in_channels,
+                                    kernel_size=5, n_layers=3, p_dropout=0)
+        if in_channels is not None and in_channels != hidden_channels:
+            self.encoder_inp_proj = nn.Conv1d(in_channels, hidden_channels, 1)
+        self.encoder = Encoder(
+            hidden_channels,
+            filter_channels,
+            n_heads,
+            n_layers,
+            kernel_size,
+            p_dropout,
+            window_size=window_size,
+            block_length=block_length,
+            pre_ln=pre_ln,
+        )
+
+    def forward(self, x, x_mask=None, other_embeds=0, attn_mask=1):
+        if self.n_vocab > 0:
+            x_lengths = (x > 0).long().sum(-1)
+            x = self.emb(x) * math.sqrt(self.hidden_channels)  # [b, t, h]
+        else:
+            x_lengths = (x.abs().sum(-1) > 0).long().sum(-1)
+        x = x + other_embeds
+        x = torch.transpose(x, 1, -1)  # [b, h, t]
+        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
+
+        if self.prenet:
+            x = self.pre(x, x_mask)
+            self.prenet_out = x.transpose(1, 2)
+        if hasattr(self, 'encoder_inp_proj'):
+            x = self.encoder_inp_proj(x) * x_mask
+        x = self.encoder(x, x_mask, attn_mask)
+        return x.transpose(1, 2)
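
A last orientation-only sketch (not part of the diff) of the RelTransformerEncoder added above: a relative-position Transformer text encoder with an optional convolutional prenet. The hyperparameters below are illustrative and assume the same sys.path setup as the earlier sketches.

import torch
from tts.modules.ar_dur.commons.rel_transformer import RelTransformerEncoder

enc = RelTransformerEncoder(
    n_vocab=100, out_channels=192, hidden_channels=192, filter_channels=768,
    n_heads=2, n_layers=4, kernel_size=3, p_dropout=0.1, window_size=4)
tokens = torch.randint(1, 100, (2, 50))   # (B, T_txt); index 0 is treated as padding
h = enc(tokens)                           # (2, 50, 192) contextual token encodings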