xinference 1.4.1__py3-none-any.whl → 1.5.0.post1__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +50 -1
- xinference/client/restful/restful_client.py +82 -2
- xinference/constants.py +3 -0
- xinference/core/chat_interface.py +297 -83
- xinference/core/model.py +1 -0
- xinference/core/progress_tracker.py +16 -8
- xinference/core/supervisor.py +45 -1
- xinference/core/worker.py +262 -37
- xinference/deploy/cmdline.py +33 -1
- xinference/model/audio/core.py +11 -1
- xinference/model/audio/megatts.py +105 -0
- xinference/model/audio/model_spec.json +24 -1
- xinference/model/audio/model_spec_modelscope.json +26 -1
- xinference/model/core.py +14 -0
- xinference/model/embedding/core.py +6 -1
- xinference/model/flexible/core.py +6 -1
- xinference/model/image/core.py +6 -1
- xinference/model/image/model_spec.json +17 -1
- xinference/model/image/model_spec_modelscope.json +17 -1
- xinference/model/llm/__init__.py +0 -4
- xinference/model/llm/core.py +4 -0
- xinference/model/llm/llama_cpp/core.py +40 -16
- xinference/model/llm/llm_family.json +415 -84
- xinference/model/llm/llm_family.py +24 -1
- xinference/model/llm/llm_family_modelscope.json +449 -0
- xinference/model/llm/mlx/core.py +16 -2
- xinference/model/llm/transformers/__init__.py +14 -0
- xinference/model/llm/transformers/core.py +30 -6
- xinference/model/llm/transformers/gemma3.py +17 -2
- xinference/model/llm/transformers/intern_vl.py +28 -18
- xinference/model/llm/transformers/minicpmv26.py +21 -2
- xinference/model/llm/transformers/qwen-omni.py +308 -0
- xinference/model/llm/transformers/qwen2_audio.py +1 -1
- xinference/model/llm/transformers/qwen2_vl.py +20 -4
- xinference/model/llm/utils.py +11 -1
- xinference/model/llm/vllm/core.py +35 -0
- xinference/model/llm/vllm/distributed_executor.py +8 -2
- xinference/model/rerank/core.py +6 -1
- xinference/model/utils.py +118 -1
- xinference/model/video/core.py +6 -1
- xinference/thirdparty/megatts3/__init__.py +0 -0
- xinference/thirdparty/megatts3/tts/frontend_function.py +175 -0
- xinference/thirdparty/megatts3/tts/gradio_api.py +93 -0
- xinference/thirdparty/megatts3/tts/infer_cli.py +277 -0
- xinference/thirdparty/megatts3/tts/modules/aligner/whisper_small.py +318 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/ar_dur_predictor.py +362 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/layers.py +64 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/nar_tts_modules.py +73 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/rel_transformer.py +403 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/rot_transformer.py +649 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/seq_utils.py +342 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/transformer.py +767 -0
- xinference/thirdparty/megatts3/tts/modules/llm_dit/cfm.py +309 -0
- xinference/thirdparty/megatts3/tts/modules/llm_dit/dit.py +180 -0
- xinference/thirdparty/megatts3/tts/modules/llm_dit/time_embedding.py +44 -0
- xinference/thirdparty/megatts3/tts/modules/llm_dit/transformer.py +230 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/diag_gaussian.py +67 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/hifigan_modules.py +283 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/seanet_encoder.py +38 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/wavvae_v3.py +60 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/conv.py +154 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/lstm.py +51 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/seanet.py +126 -0
- xinference/thirdparty/megatts3/tts/utils/audio_utils/align.py +36 -0
- xinference/thirdparty/megatts3/tts/utils/audio_utils/io.py +95 -0
- xinference/thirdparty/megatts3/tts/utils/audio_utils/plot.py +90 -0
- xinference/thirdparty/megatts3/tts/utils/commons/ckpt_utils.py +171 -0
- xinference/thirdparty/megatts3/tts/utils/commons/hparams.py +215 -0
- xinference/thirdparty/megatts3/tts/utils/text_utils/dict.json +1 -0
- xinference/thirdparty/megatts3/tts/utils/text_utils/ph_tone_convert.py +94 -0
- xinference/thirdparty/megatts3/tts/utils/text_utils/split_text.py +90 -0
- xinference/thirdparty/megatts3/tts/utils/text_utils/text_encoder.py +280 -0
- xinference/types.py +10 -0
- xinference/utils.py +54 -0
- xinference/web/ui/build/asset-manifest.json +6 -6
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/css/main.0f6523be.css +2 -0
- xinference/web/ui/build/static/css/main.0f6523be.css.map +1 -0
- xinference/web/ui/build/static/js/main.58bd483c.js +3 -0
- xinference/web/ui/build/static/js/main.58bd483c.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/3bff8cbe9141f937f4d98879a9771b0f48e0e4e0dbee8e647adbfe23859e7048.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/4500b1a622a031011f0a291701e306b87e08cbc749c50e285103536b85b6a914.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/51709f5d3e53bcf19e613662ef9b91fb9174942c5518987a248348dd4e1e0e02.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/69081049f0c7447544b7cfd73dd13d8846c02fe5febe4d81587e95c89a412d5b.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/b8551e9775a01b28ae674125c688febe763732ea969ae344512e64ea01bf632e.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/bf2b211b0d1b6465eff512d64c869d748f803c5651a7c24e48de6ea3484a7bfe.json +1 -0
- xinference/web/ui/src/locales/en.json +2 -1
- xinference/web/ui/src/locales/zh.json +2 -1
- {xinference-1.4.1.dist-info → xinference-1.5.0.post1.dist-info}/METADATA +129 -114
- {xinference-1.4.1.dist-info → xinference-1.5.0.post1.dist-info}/RECORD +96 -60
- {xinference-1.4.1.dist-info → xinference-1.5.0.post1.dist-info}/WHEEL +1 -1
- xinference/web/ui/build/static/css/main.b494ae7e.css +0 -2
- xinference/web/ui/build/static/css/main.b494ae7e.css.map +0 -1
- xinference/web/ui/build/static/js/main.5ca4eea1.js +0 -3
- xinference/web/ui/build/static/js/main.5ca4eea1.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/0f0967acaec5df1d45b80010949c258d64297ebbb0f44b8bb3afcbd45c6f0ec4.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/27bcada3ee8f89d21184b359f022fc965f350ffaca52c9814c29f1fc37121173.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/68249645124f37d01eef83b1d897e751f895bea919b6fb466f907c1f87cebc84.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/e547bbb18abb4a474b675a8d5782d25617566bea0af8caa9b836ce5649e2250a.json +0 -1
- /xinference/web/ui/build/static/js/{main.5ca4eea1.js.LICENSE.txt → main.58bd483c.js.LICENSE.txt} +0 -0
- {xinference-1.4.1.dist-info → xinference-1.5.0.post1.dist-info}/entry_points.txt +0 -0
- {xinference-1.4.1.dist-info → xinference-1.5.0.post1.dist-info/licenses}/LICENSE +0 -0
- {xinference-1.4.1.dist-info → xinference-1.5.0.post1.dist-info}/top_level.txt +0 -0
xinference/thirdparty/megatts3/tts/modules/ar_dur/ar_dur_predictor.py

```diff
@@ -0,0 +1,362 @@
+# Copyright 2025 ByteDance and/or its affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+from copy import deepcopy
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.nn import Linear
+from tqdm import tqdm
+
+from tts.modules.ar_dur.commons.layers import Embedding, LayerNorm
+from tts.modules.ar_dur.commons.nar_tts_modules import PosEmb
+from tts.modules.ar_dur.commons.rot_transformer import RotTransformerDecoderLayer
+from tts.modules.ar_dur.commons.transformer import SinusoidalPositionalEmbedding
+from tts.modules.ar_dur.commons.rel_transformer import RelTransformerEncoder
+
+FS_ENCODERS = {
+    'rel_fft': lambda hp, dict_size: RelTransformerEncoder(
+        dict_size, hp['hidden_size'], hp['hidden_size'],
+        hp['ffn_hidden_size'], hp['num_heads'], hp['enc_layers'],
+        hp['enc_ffn_kernel_size'], hp['dropout'], prenet=hp['enc_prenet'], pre_ln=hp['enc_pre_ln']),
+}
+
+def fill_with_neg_inf2(t):
+    """FP16-compatible function that fills a tensor with -inf."""
+    return t.float().fill_(-1e8).type_as(t)
+
+def expand_states(h, mel2token):
+    h = F.pad(h, [0, 0, 1, 0])
+    mel2token_ = mel2token[..., None].repeat([1, 1, h.shape[-1]])
+    h = torch.gather(h, 1, mel2token_)  # [B, T, H]
+    return h
+
+
+class CodePredictor(nn.Module):
+    def __init__(self, hparams, hidden_size, dec_hidden_size, lm_num_layers, dict_size, code_size):
+        super().__init__()
+        self.hparams = deepcopy(hparams)
+        self.hparams['hidden_size'] = hidden_size
+        self.hidden_size = hidden_size
+        char_dict_size = hparams.get('char_dict_size', 4000)
+        if not hparams.get('lm_use_enc'):
+            self.encoder = nn.Embedding(dict_size, self.hidden_size, padding_idx=0)
+            if hparams.get('mega_use_char', True):
+                self.char_encoder = nn.Embedding(char_dict_size,
+                                                 self.hidden_size, padding_idx=0)
+        else:
+            self.encoder = FS_ENCODERS[self.hparams['encoder_type']](self.hparams, dict_size)
+            if hparams.get('mega_use_char', True):
+                self.char_encoder = FS_ENCODERS[self.hparams['encoder_type']](self.hparams, char_dict_size)
+            if hparams['use_ph_pos_embed']:
+                self.ph_pos_embed = PosEmb(self.hidden_size)
+
+        self.char_empty_embed = nn.Embedding(1, self.hidden_size)
+        if hparams.get('use_bert_input'):
+            self.bert_input_proj = nn.Linear(768, self.hidden_size)
+        self.ling_label_embed_layers = nn.ModuleDict()
+        for k, s in zip(hparams['ling_labels'], hparams['ling_label_dict_size']):
+            self.ling_label_embed_layers[k] = Embedding(s + 3, self.hidden_size, padding_idx=0)
+
+        self.dec_hidden_size = dec_hidden_size
+        self.enc_proj = nn.Linear(self.hidden_size, dec_hidden_size)
+        self.code_emb = Embedding(code_size + 2, dec_hidden_size, 0)
+        self.use_pos_embed = hparams.get('use_pos_embed', False)
+        if self.use_pos_embed:
+            self.embed_positions = SinusoidalPositionalEmbedding(dec_hidden_size, 0, init_size=1024)
+        self.use_post_ln = hparams.get('use_post_ln', False)
+        self.layers = None
+        if not self.use_post_ln:
+            self.layer_norm = LayerNorm(dec_hidden_size)
+        self.code_size = code_size
+        self.project_out_dim = Linear(dec_hidden_size, code_size + 1, bias=True)
+
+    def forward_ling_encoder(
+            self, txt_tokens, ling_feas, char_tokens, ph2char, bert_embed, spk_id, spk_embed, mels_timbre):
+        ph_tokens = txt_tokens
+        hparams = self.hparams
+        ph_nonpadding = (ph_tokens > 0).float()[:, :, None]  # [B, T_phone, 1]
+        x_spk = self.forward_style_embed(spk_embed, spk_id, mels_timbre)
+
+        # enc_ph
+        if not hparams.get('lm_use_enc'):
+            x_ph = self.encoder(ph_tokens)
+            x_ph = x_ph + sum(
+                [self.ling_label_embed_layers[k](ling_feas[k]) for k in hparams['ling_labels']]) \
+                if len(hparams['ling_labels']) > 0 else 0
+            x_ph = x_ph + x_spk
+        else:
+            # enc_ph
+            ph_enc_oembed = sum(
+                [self.ling_label_embed_layers[k](ling_feas[k]) for k in hparams['ling_labels']]) \
+                if len(hparams['ling_labels']) > 0 else 0
+            ph_enc_oembed = ph_enc_oembed + self.ph_pos_embed(
+                torch.arange(0, ph_tokens.shape[1])[None,].to(ph_tokens.device))
+            ph_enc_oembed = ph_enc_oembed + x_spk
+            ph_enc_oembed = ph_enc_oembed * ph_nonpadding
+            x_ph = self.encoder(ph_tokens, other_embeds=ph_enc_oembed)
+
+        # enc_char
+        if char_tokens is not None and ph2char is not None:
+            char_nonpadding = (char_tokens > 0).float()[:, :, None]
+            x_char = self.char_encoder(char_tokens)
+            empty_char = (ph2char > 100000).long()
+            ph2char = ph2char * (1 - empty_char)
+            x_char_phlevel = \
+                expand_states(x_char * char_nonpadding, ph2char) \
+                * (1 - empty_char)[..., None] + \
+                self.char_empty_embed(torch.zeros_like(ph_tokens)) * empty_char[..., None]
+        else:
+            x_char_phlevel = 0
+        # x_ling
+        x_ling = x_ph + x_char_phlevel
+        x_ling = x_ling * ph_nonpadding
+        x_ling = self.enc_proj(x_ling)
+        return x_ling
+
+    def sample_one_step(self, vq_pred):
+        hparams = self.hparams
+        if hparams.get('infer_top_k'):
+            top_k = hparams.get('infer_top_k')
+            temperature = hparams.get('infer_temperature', 1)
+            vq_pred = vq_pred[:, -1] / temperature
+            # optionally crop the logits to only the top k options
+            if top_k is not None:
+                v, _ = torch.topk(vq_pred, min(top_k, vq_pred.size(-1)))
+                vq_pred[vq_pred < v[:, [-1]]] = -float('Inf')
+            # apply softmax to convert logits to (normalized) probabilities
+            probs = F.softmax(vq_pred, dim=-1)
+            # sample from the distribution
+            vq_pred = torch.multinomial(probs, num_samples=1)
+        else:
+            vq_pred = torch.argmax(F.softmax(vq_pred[:, -1], dim=-1), 1)
+        return vq_pred
+
+    def forward_style_embed(self, spk_embed=None, spk_id=None, mel_ref=None):
+        # add spk embed
+        style_embed = 0
+        if self.hparams['use_spk_embed']:
+            style_embed = style_embed + self.spk_embed_proj(spk_embed)[:, None, :]
+        if self.hparams['use_spk_id']:
+            style_embed = style_embed + self.spk_id_proj(spk_id)[:, None, :]
+        if self.hparams['use_spk_enc']:
+            style_embed = style_embed + self.spk_enc(mel_ref)[:, None, :]
+        return style_embed
+
+    def buffered_future_mask(self, tensor):
+        dim = tensor.size(0)
+        if (
+                not hasattr(self, '_future_mask')
+                or self._future_mask is None
+                or self._future_mask.device != tensor.device
+                or self._future_mask.size(0) < dim
+        ):
+            self._future_mask = torch.triu(fill_with_neg_inf2(tensor.new(dim, dim)), 1)
+        return self._future_mask[:dim, :dim]
+
+
+class ARDurPredictor(CodePredictor):
+    def __init__(self, hparams, hidden_size, dec_hidden_size, lm_num_layers, dict_size, code_size, use_rot_embed=True,
+                 op_version=1):
+        super().__init__(hparams, hidden_size, dec_hidden_size, lm_num_layers, dict_size, code_size)
+        self.use_rot_embed = use_rot_embed
+        bias = hparams.get('lm_bias', True)
+        if self.use_rot_embed:
+            self.layers = nn.ModuleList([])
+            self.layers.extend([
+                RotTransformerDecoderLayer(
+                    dec_hidden_size, 0.0, kernel_size=1, ffn_hidden_size=dec_hidden_size * 4,
+                    post_ln=self.use_post_ln, op_version=op_version, bias=bias)
+                for _ in range(lm_num_layers)
+            ])
+        if hparams['dur_model_type'] == 'ar_mse':
+            self.project_out_dim = nn.Sequential(torch.nn.Linear(dec_hidden_size, 1), nn.Softplus())
+        else:
+            self.project_out_dim = torch.nn.Linear(dec_hidden_size, code_size + 1)
+
+    def forward(self, txt_tokens, ling_feas, char_tokens, ph2char, bert_embed,
+                prev_code, spk_id=None, spk_embed=None, mels_timbre=None, mel2ph=None,
+                incremental_state=None, x_ling=None, attn_mask=None, spk_pos_ids_flat=None,
+                prompt_length=None, cache_size=20, streaming=False):
+        x = self.code_emb(prev_code)
+        if x_ling is None:
+            x_ling = self.forward_ling_encoder(
+                txt_tokens, ling_feas, char_tokens, ph2char, bert_embed, spk_id, spk_embed, mels_timbre)
+            x_ling = x_ling.flatten(0, 1)
+            txt_tokens = txt_tokens.flatten(0, 1)
+            x_ling = x_ling[txt_tokens > 0][None]
+
+        # run decoder
+        self_attn_padding_mask = None
+        if self.use_pos_embed:
+            positions = self.embed_positions(
+                prev_code,
+                incremental_state=incremental_state
+            )
+        if incremental_state is not None:
+            x_ling = x_ling[:, x.shape[1] - 1:x.shape[1]]
+            if spk_pos_ids_flat is not None:
+                spk_pos_ids_flat = spk_pos_ids_flat[:, x.shape[1] - 1:x.shape[1]]
+            x = x[:, -1:]
+            if self.use_pos_embed:
+                positions = positions[:, -1:]
+            if streaming:
+                # Shift Pos: query pos is min(cache_size, idx)
+                spk_pos_ids_flat = torch.min(torch.LongTensor([prompt_length + cache_size]).to(x.device),
+                                             spk_pos_ids_flat)
+
+        # # B x T x C -> T x B x C
+        if self.use_pos_embed:
+            x = x + positions
+        x_ling = x_ling[:, :self.hparams['max_tokens']].contiguous()
+        T = min(self.hparams.get('max_tokens_per_item', 1e9), x_ling.shape[1])
+        x_ling = x_ling.reshape(-1, T, x_ling.shape[-1])
+        x = x + x_ling
+        x = x.transpose(0, 1)
+
+        for idx, layer in enumerate(self.layers):
+            if incremental_state is None:
+                self_attn_mask = self.buffered_future_mask(x)
+                if attn_mask is not None:
+                    self_attn_mask = self_attn_mask + (1 - attn_mask.float()) * -1e8
+                self_attn_mask = self_attn_mask.clamp_min(-1e8)
+            else:
+                self_attn_mask = None
+
+            x, attn_weights = layer(
+                x,
+                incremental_state=incremental_state,
+                self_attn_mask=self_attn_mask,
+                self_attn_padding_mask=self_attn_padding_mask,
+                spk_pos_ids_flat=spk_pos_ids_flat
+            )
+
+            if streaming and incremental_state != {}:
+                for k, v in incremental_state.items():
+                    if 'attn_state' in k:
+                        prev_key, prev_value = incremental_state[k]['prev_key'], incremental_state[k]['prev_value']
+                        cur_length = prev_key.shape[2]
+                        if cur_length - prompt_length > cache_size:
+                            prev_key = torch.cat((prev_key[:, :, :prompt_length], prev_key[:, :, -cache_size:]), dim=2)
+                            prev_value = torch.cat((prev_value[:, :, :prompt_length], prev_value[:, :, -cache_size:]),
+                                                   dim=2)
+                        incremental_state[k]['prev_key'], incremental_state[k]['prev_value'] = prev_key, prev_value
+
+        if not self.use_post_ln:
+            x = self.layer_norm(x)
+        # T x B x C -> B x T x C
+        x = x.transpose(0, 1)
+        x = self.project_out_dim(x)
+        return x
+
+    def infer(self, txt_tokens, ling_feas, char_tokens, ph2char, bert_embed,
+              spk_id=None, spk_embed=None, mels_timbre=None,
+              incremental_state=None, ctx_vqcodes=None, spk_pos_ids_flat=None, return_state=False,
+              first_step_min=0, return_probs=False, first_decoder_inp=None, dur_disturb=0.0, **kwargs):
+        if incremental_state is None:
+            incremental_state = {}
+        x_ling = self.forward_ling_encoder(
+            txt_tokens, ling_feas, char_tokens, ph2char, bert_embed,
+            spk_id, spk_embed, mels_timbre)
+        x_ling = x_ling.flatten(0, 1)
+        txt_tokens_ori = txt_tokens
+        txt_tokens_withpad = txt_tokens = txt_tokens.flatten(0, 1)
+        x_ling = x_ling[txt_tokens > 0][None]
+        txt_tokens = txt_tokens[txt_tokens > 0][None]
+
+        decoded = torch.zeros_like(txt_tokens)
+        decoded = F.pad(decoded, [1, 0], value=self.code_size + 1)
+        if incremental_state != {}:
+            if first_decoder_inp is None:
+                assert ctx_vqcodes is not None
+                decoded[:, :ctx_vqcodes.shape[1]] = ctx_vqcodes
+                ctx_vqcodes = None
+            else:
+                decoded[:, :1] = first_decoder_inp
+        probs = []
+        for step in range(decoded.shape[1] - 1):
+            vq_pred = self(txt_tokens, None, None, None, None,
+                           decoded[:, :step + 1], None, None, None,
+                           incremental_state=incremental_state, x_ling=x_ling,
+                           spk_pos_ids_flat=spk_pos_ids_flat, **kwargs)
+            probs.append(vq_pred.cpu())
+            if ctx_vqcodes is None or step >= ctx_vqcodes.shape[1]:
+                if self.hparams['dur_model_type'] == 'ar_mse':
+                    d = vq_pred[:, -1, 0]
+                    if dur_disturb > 0 and step >= 1:
+                        if random.random() > 0.5:
+                            d = d * (1 + random.random() * dur_disturb)
+                        else:
+                            d = d / (1 + random.random() * dur_disturb)
+                    d = torch.clamp_max(d, self.code_size - 1)
+                    vq_pred = torch.round(d).long()
+                else:
+                    vq_pred = self.sample_one_step(vq_pred)
+                decoded[:, step + 1] = torch.clamp_min(vq_pred, 1)
+                if step == 0:
+                    decoded[:, step + 1] = torch.clamp_min(vq_pred, first_step_min)
+            else:
+                decoded[:, step + 1] = ctx_vqcodes[:, step]
+        decoded = decoded[:, 1:]
+        decoded_2d = torch.zeros_like(txt_tokens_ori)
+        decoded_2d.flatten(0, 1)[txt_tokens_withpad > 0] = decoded
+        if return_state:
+            return decoded_2d, incremental_state
+        if return_probs:
+            return decoded_2d, torch.cat(probs, 1)
+        return decoded_2d
+
+    def streaming_infer(self, txt_tokens, ling_feas, char_tokens, ph2char, bert_embed,
+                        spk_id=None, spk_embed=None, mels_timbre=None,
+                        incremental_state=None, ctx_vqcodes=None, spk_pos_ids_flat=None, return_state=False,
+                        **kwargs):
+        if incremental_state is None:
+            incremental_state = {}
+        x_ling = self.forward_ling_encoder(
+            txt_tokens, ling_feas, char_tokens, ph2char, bert_embed,
+            spk_id, spk_embed, mels_timbre)
+        x_ling = x_ling.flatten(0, 1)
+        txt_tokens_ori = txt_tokens
+        txt_tokens_withpad = txt_tokens = txt_tokens.flatten(0, 1)
+        x_ling = x_ling[txt_tokens > 0][None]
+        txt_tokens = txt_tokens[txt_tokens > 0][None]
+
+        vq_decoded = torch.zeros_like(txt_tokens)
+        vq_decoded = F.pad(vq_decoded, [1, 0], value=self.code_size + 1)
+        if incremental_state != {}:
+            assert ctx_vqcodes is not None
+            vq_decoded[:, :ctx_vqcodes.shape[1]] = ctx_vqcodes
+            ctx_vqcodes = None
+        prompt_length = list(incremental_state.items())[0][1]['prev_key'].shape[2]
+        for step in tqdm(range(vq_decoded.shape[1] - 1), desc='AR Duration Predictor inference...'):
+            vq_pred = self(txt_tokens, None, None, None, None,
+                           vq_decoded[:, :step + 1], None, None, None,
+                           incremental_state=incremental_state, x_ling=x_ling,
+                           spk_pos_ids_flat=spk_pos_ids_flat, prompt_length=prompt_length, streaming=True, **kwargs)
+            if ctx_vqcodes is None or step >= ctx_vqcodes.shape[1]:
+                if self.hparams['dur_model_type'] == 'ar_mse':
+                    vq_pred = torch.round(vq_pred[:, -1, 0]).long()
+                else:
+                    vq_pred = self.sample_one_step(vq_pred)
+                vq_decoded[:, step + 1] = vq_pred
+            else:
+                vq_decoded[:, step + 1] = ctx_vqcodes[:, step]
+        vq_decoded = vq_decoded[:, 1:]
+        vq_decoded_2d = torch.zeros_like(txt_tokens_ori)
+        vq_decoded_2d.flatten(0, 1)[txt_tokens_withpad > 0] = vq_decoded
+        if return_state:
+            return vq_decoded_2d, incremental_state
+        return vq_decoded_2d
```
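As context for the hunk above: `expand_states` is the helper the predictor uses to broadcast token-level hidden states out to frame level through a 1-based `mel2token` index map; the `F.pad` call prepends a zero row so index 0 gathers an all-zero padding state. A minimal standalone sketch with made-up tensor values, assuming only `torch`:

```python
import torch
import torch.nn.functional as F

def expand_states(h, mel2token):
    # Prepend a zero row so index 0 gathers an all-zero (padding) state.
    h = F.pad(h, [0, 0, 1, 0])
    mel2token_ = mel2token[..., None].repeat([1, 1, h.shape[-1]])
    return torch.gather(h, 1, mel2token_)  # [B, T_frames, H]

h = torch.arange(6.0).reshape(1, 3, 2)       # 3 token states, hidden size 2
mel2token = torch.tensor([[1, 1, 2, 3, 3]])  # 1-based token index per frame
print(expand_states(h, mel2token))
# tensor([[[0., 1.], [0., 1.], [2., 3.], [4., 5.], [4., 5.]]])
```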
xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/layers.py

```diff
@@ -0,0 +1,64 @@
+# Copyright 2025 ByteDance and/or its affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from torch import nn
+
+
+class LayerNorm(torch.nn.LayerNorm):
+    """Layer normalization module.
+    :param int nout: output dim size
+    :param int dim: dimension to be normalized
+    """
+
+    def __init__(self, nout, dim=-1, eps=1e-5):
+        """Construct an LayerNorm object."""
+        super(LayerNorm, self).__init__(nout, eps=eps)
+        self.dim = dim
+
+    def forward(self, x):
+        """Apply layer normalization.
+        :param torch.Tensor x: input tensor
+        :return: layer normalized tensor
+        :rtype torch.Tensor
+        """
+        if self.dim == -1:
+            return super(LayerNorm, self).forward(x)
+        return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1)
+
+
+class Reshape(nn.Module):
+    def __init__(self, *args):
+        super(Reshape, self).__init__()
+        self.shape = args
+
+    def forward(self, x):
+        return x.view(self.shape)
+
+
+class Permute(nn.Module):
+    def __init__(self, *args):
+        super(Permute, self).__init__()
+        self.args = args
+
+    def forward(self, x):
+        return x.permute(self.args)
+
+
+def Embedding(num_embeddings, embedding_dim, padding_idx=None):
+    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
+    nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
+    if padding_idx is not None:
+        nn.init.constant_(m.weight[padding_idx], 0)
+    return m
```
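For reference, the `LayerNorm` subclass above only differs from `torch.nn.LayerNorm` when `dim != -1`: it transposes so that channels-first inputs are normalized over the channel axis. A small sketch of the equivalent behavior with dummy shapes, assuming only `torch`:

```python
import torch
from torch import nn

x = torch.randn(2, 8, 16)  # [B, C, T], channels-first

# What LayerNorm(8, dim=1) from the hunk above computes: normalize over C.
y = nn.LayerNorm(8)(x.transpose(1, -1)).transpose(1, -1)
print(y.shape)  # torch.Size([2, 8, 16])

# Embedding() wraps nn.Embedding with scaled-normal init and a zeroed padding row.
emb = nn.Embedding(10, 8, padding_idx=0)
nn.init.normal_(emb.weight, mean=0, std=8 ** -0.5)
nn.init.constant_(emb.weight[0], 0)
print(emb(torch.tensor([[0, 3]])).shape)  # torch.Size([1, 2, 8]); index 0 embeds to zeros
```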
xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/nar_tts_modules.py

```diff
@@ -0,0 +1,73 @@
+# Copyright 2025 ByteDance and/or its affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+
+import torch
+from torch import nn
+
+import torch.nn.functional as F
+
+
+class LengthRegulator(torch.nn.Module):
+    def __init__(self, pad_value=0.0):
+        super(LengthRegulator, self).__init__()
+        self.pad_value = pad_value
+
+    def forward(self, dur, dur_padding=None, alpha=1.0):
+        """
+        Example (no batch dim version):
+            1. dur = [2,2,3]
+            2. token_idx = [[1],[2],[3]], dur_cumsum = [2,4,7], dur_cumsum_prev = [0,2,4]
+            3. token_mask = [[1,1,0,0,0,0,0],
+                             [0,0,1,1,0,0,0],
+                             [0,0,0,0,1,1,1]]
+            4. token_idx * token_mask = [[1,1,0,0,0,0,0],
+                                         [0,0,2,2,0,0,0],
+                                         [0,0,0,0,3,3,3]]
+            5. (token_idx * token_mask).sum(0) = [1,1,2,2,3,3,3]
+
+        :param dur: Batch of durations of each frame (B, T_txt)
+        :param dur_padding: Batch of padding of each frame (B, T_txt)
+        :param alpha: duration rescale coefficient
+        :return:
+            mel2ph (B, T_speech)
+        assert alpha > 0
+        """
+        dur = torch.round(dur.float() * alpha).long()
+        if dur_padding is not None:
+            dur = dur * (1 - dur_padding.long())
+        token_idx = torch.arange(1, dur.shape[1] + 1)[None, :, None].to(dur.device)
+        dur_cumsum = torch.cumsum(dur, 1)
+        dur_cumsum_prev = F.pad(dur_cumsum, [1, -1], mode='constant', value=0)
+
+        pos_idx = torch.arange(dur.sum(-1).max())[None, None].to(dur.device)
+        token_mask = (pos_idx >= dur_cumsum_prev[:, :, None]) & (pos_idx < dur_cumsum[:, :, None])
+        mel2token = (token_idx * token_mask.long()).sum(1)
+        return mel2token
+
+
+class PosEmb(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.dim = dim
+        half_dim = self.dim // 2
+        emb = math.log(10000) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim) * -emb)
+        self.emb = emb  # TODO
+
+    def forward(self, x):
+        emb = x[:, :, None] * self.emb[None, None, :].to(x.device)
+        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
+        return emb
```
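The docstring example in `LengthRegulator.forward` can be checked directly: a standalone sketch of the same computation (assuming only `torch`), where `dur = [2, 2, 3]` expands to the frame-level map `[1, 1, 2, 2, 3, 3, 3]`:

```python
import torch
import torch.nn.functional as F

dur = torch.tensor([[2, 2, 3]])                               # [B, T_txt]
token_idx = torch.arange(1, dur.shape[1] + 1)[None, :, None]  # [[1], [2], [3]]
dur_cumsum = torch.cumsum(dur, 1)                             # [[2, 4, 7]]
dur_cumsum_prev = F.pad(dur_cumsum, [1, -1], value=0)         # [[0, 2, 4]]

pos_idx = torch.arange(dur.sum(-1).max())[None, None]         # frame positions 0..6
token_mask = (pos_idx >= dur_cumsum_prev[:, :, None]) & (pos_idx < dur_cumsum[:, :, None])
mel2token = (token_idx * token_mask.long()).sum(1)
print(mel2token)  # tensor([[1, 1, 2, 2, 3, 3, 3]])
```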