xinference 1.4.1__py3-none-any.whl → 1.5.0.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of xinference has been flagged as possibly problematic.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +50 -1
- xinference/client/restful/restful_client.py +82 -2
- xinference/constants.py +3 -0
- xinference/core/chat_interface.py +297 -83
- xinference/core/model.py +1 -0
- xinference/core/progress_tracker.py +16 -8
- xinference/core/supervisor.py +45 -1
- xinference/core/worker.py +262 -37
- xinference/deploy/cmdline.py +33 -1
- xinference/model/audio/core.py +11 -1
- xinference/model/audio/megatts.py +105 -0
- xinference/model/audio/model_spec.json +24 -1
- xinference/model/audio/model_spec_modelscope.json +26 -1
- xinference/model/core.py +14 -0
- xinference/model/embedding/core.py +6 -1
- xinference/model/flexible/core.py +6 -1
- xinference/model/image/core.py +6 -1
- xinference/model/image/model_spec.json +17 -1
- xinference/model/image/model_spec_modelscope.json +17 -1
- xinference/model/llm/__init__.py +0 -4
- xinference/model/llm/core.py +4 -0
- xinference/model/llm/llama_cpp/core.py +40 -16
- xinference/model/llm/llm_family.json +415 -84
- xinference/model/llm/llm_family.py +24 -1
- xinference/model/llm/llm_family_modelscope.json +449 -0
- xinference/model/llm/mlx/core.py +16 -2
- xinference/model/llm/transformers/__init__.py +14 -0
- xinference/model/llm/transformers/core.py +30 -6
- xinference/model/llm/transformers/gemma3.py +17 -2
- xinference/model/llm/transformers/intern_vl.py +28 -18
- xinference/model/llm/transformers/minicpmv26.py +21 -2
- xinference/model/llm/transformers/qwen-omni.py +308 -0
- xinference/model/llm/transformers/qwen2_audio.py +1 -1
- xinference/model/llm/transformers/qwen2_vl.py +20 -4
- xinference/model/llm/utils.py +11 -1
- xinference/model/llm/vllm/core.py +35 -0
- xinference/model/llm/vllm/distributed_executor.py +8 -2
- xinference/model/rerank/core.py +6 -1
- xinference/model/utils.py +118 -1
- xinference/model/video/core.py +6 -1
- xinference/thirdparty/megatts3/__init__.py +0 -0
- xinference/thirdparty/megatts3/tts/frontend_function.py +175 -0
- xinference/thirdparty/megatts3/tts/gradio_api.py +93 -0
- xinference/thirdparty/megatts3/tts/infer_cli.py +277 -0
- xinference/thirdparty/megatts3/tts/modules/aligner/whisper_small.py +318 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/ar_dur_predictor.py +362 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/layers.py +64 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/nar_tts_modules.py +73 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/rel_transformer.py +403 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/rot_transformer.py +649 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/seq_utils.py +342 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/transformer.py +767 -0
- xinference/thirdparty/megatts3/tts/modules/llm_dit/cfm.py +309 -0
- xinference/thirdparty/megatts3/tts/modules/llm_dit/dit.py +180 -0
- xinference/thirdparty/megatts3/tts/modules/llm_dit/time_embedding.py +44 -0
- xinference/thirdparty/megatts3/tts/modules/llm_dit/transformer.py +230 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/diag_gaussian.py +67 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/hifigan_modules.py +283 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/seanet_encoder.py +38 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/wavvae_v3.py +60 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/conv.py +154 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/lstm.py +51 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/seanet.py +126 -0
- xinference/thirdparty/megatts3/tts/utils/audio_utils/align.py +36 -0
- xinference/thirdparty/megatts3/tts/utils/audio_utils/io.py +95 -0
- xinference/thirdparty/megatts3/tts/utils/audio_utils/plot.py +90 -0
- xinference/thirdparty/megatts3/tts/utils/commons/ckpt_utils.py +171 -0
- xinference/thirdparty/megatts3/tts/utils/commons/hparams.py +215 -0
- xinference/thirdparty/megatts3/tts/utils/text_utils/dict.json +1 -0
- xinference/thirdparty/megatts3/tts/utils/text_utils/ph_tone_convert.py +94 -0
- xinference/thirdparty/megatts3/tts/utils/text_utils/split_text.py +90 -0
- xinference/thirdparty/megatts3/tts/utils/text_utils/text_encoder.py +280 -0
- xinference/types.py +10 -0
- xinference/utils.py +54 -0
- xinference/web/ui/build/asset-manifest.json +6 -6
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/css/main.0f6523be.css +2 -0
- xinference/web/ui/build/static/css/main.0f6523be.css.map +1 -0
- xinference/web/ui/build/static/js/main.58bd483c.js +3 -0
- xinference/web/ui/build/static/js/main.58bd483c.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/3bff8cbe9141f937f4d98879a9771b0f48e0e4e0dbee8e647adbfe23859e7048.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/4500b1a622a031011f0a291701e306b87e08cbc749c50e285103536b85b6a914.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/51709f5d3e53bcf19e613662ef9b91fb9174942c5518987a248348dd4e1e0e02.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/69081049f0c7447544b7cfd73dd13d8846c02fe5febe4d81587e95c89a412d5b.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/b8551e9775a01b28ae674125c688febe763732ea969ae344512e64ea01bf632e.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/bf2b211b0d1b6465eff512d64c869d748f803c5651a7c24e48de6ea3484a7bfe.json +1 -0
- xinference/web/ui/src/locales/en.json +2 -1
- xinference/web/ui/src/locales/zh.json +2 -1
- {xinference-1.4.1.dist-info → xinference-1.5.0.post1.dist-info}/METADATA +129 -114
- {xinference-1.4.1.dist-info → xinference-1.5.0.post1.dist-info}/RECORD +96 -60
- {xinference-1.4.1.dist-info → xinference-1.5.0.post1.dist-info}/WHEEL +1 -1
- xinference/web/ui/build/static/css/main.b494ae7e.css +0 -2
- xinference/web/ui/build/static/css/main.b494ae7e.css.map +0 -1
- xinference/web/ui/build/static/js/main.5ca4eea1.js +0 -3
- xinference/web/ui/build/static/js/main.5ca4eea1.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/0f0967acaec5df1d45b80010949c258d64297ebbb0f44b8bb3afcbd45c6f0ec4.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/27bcada3ee8f89d21184b359f022fc965f350ffaca52c9814c29f1fc37121173.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/68249645124f37d01eef83b1d897e751f895bea919b6fb466f907c1f87cebc84.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/e547bbb18abb4a474b675a8d5782d25617566bea0af8caa9b836ce5649e2250a.json +0 -1
- /xinference/web/ui/build/static/js/{main.5ca4eea1.js.LICENSE.txt → main.58bd483c.js.LICENSE.txt} +0 -0
- {xinference-1.4.1.dist-info → xinference-1.5.0.post1.dist-info}/entry_points.txt +0 -0
- {xinference-1.4.1.dist-info → xinference-1.5.0.post1.dist-info/licenses}/LICENSE +0 -0
- {xinference-1.4.1.dist-info → xinference-1.5.0.post1.dist-info}/top_level.txt +0 -0
xinference/thirdparty/megatts3/tts/modules/llm_dit/transformer.py (new file)

```diff
@@ -0,0 +1,230 @@
+# Copyright 2025 ByteDance and/or its affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Any, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
+    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+    t = torch.arange(end, device=freqs.device)  # type: ignore
+    freqs = torch.outer(t, freqs).float()  # type: ignore
+    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
+    return freqs_cis
+
+
+def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
+    ndim = x.ndim
+    assert 0 <= 1 < ndim
+    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
+    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+    return freqs_cis.view(*shape)
+
+
+def apply_rotary_emb(
+    xq: torch.Tensor,
+    xk: torch.Tensor,
+    freqs_cis: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
+    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
+    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
+    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
+    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
+    return xq_out.type_as(xq), xk_out.type_as(xk)
+
+
+class AdaLNZero(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.silu = nn.SiLU()
+        self.linear = nn.Linear(dim, dim * 6)
+        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+
+    def forward(self, x, emb=None):
+        emb = self.linear(self.silu(emb))
+        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)
+        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
+        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
+
+
+class AdaLNZero_Out(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.silu = nn.SiLU()
+        self.linear = nn.Linear(dim, dim * 2)
+        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+
+    def forward(self, x, emb):
+        emb = self.linear(self.silu(emb))
+        scale, shift = torch.chunk(emb, 2, dim=1)
+        x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
+        return x
+
+
+class Attention(nn.Module):
+    def __init__(self, encoder_dim, encoder_n_heads, max_seq_len):
+        super().__init__()
+        self.encoder_n_kv_heads = encoder_n_heads
+        model_parallel_size = 1
+        self.n_local_heads = encoder_n_heads // model_parallel_size
+        self.n_local_kv_heads = self.encoder_n_kv_heads // model_parallel_size
+        self.n_rep = self.n_local_heads // self.n_local_kv_heads
+        self.head_dim = encoder_dim // encoder_n_heads
+
+        self.wq = nn.Linear(
+            encoder_dim,
+            encoder_n_heads * self.head_dim,
+        )
+        self.wk = nn.Linear(
+            encoder_dim,
+            self.encoder_n_kv_heads * self.head_dim,
+        )
+        self.wv = nn.Linear(
+            encoder_dim,
+            self.encoder_n_kv_heads * self.head_dim,
+        )
+        self.wo = nn.Linear(
+            encoder_n_heads * self.head_dim,
+            encoder_dim,
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        start_pos: int,
+        freqs_cis: torch.Tensor,
+        mask: Optional[torch.Tensor],
+    ):
+        bsz, seqlen, _ = x.shape
+        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
+        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
+        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
+
+        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
+        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
+        keys = xk.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
+        values = xv.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
+
+        output = F.scaled_dot_product_attention(xq, keys, values, mask[:, None, None, :], is_causal=False)
+        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
+        return self.wo(output)
+
+
+class FeedForward(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        hidden_dim: int,
+        multiple_of: int,
+        ffn_dim_multiplier: Optional[float],
+    ):
+        super().__init__()
+        if ffn_dim_multiplier is not None:
+            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
+        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+
+        self.w1 = nn.Linear(
+            dim, hidden_dim
+        )
+        self.w2 = nn.Linear(
+            hidden_dim, dim
+        )
+
+    def forward(self, x):
+        return self.w2(F.silu(self.w1(x)))
+
+
+class TransformerBlock(nn.Module):
+    def __init__(self, encoder_dim, encoder_n_heads, max_seq_len):
+        super().__init__()
+        self.encoder_n_heads = encoder_n_heads
+        self.encoder_dim = encoder_dim
+        self.head_dim = encoder_dim // encoder_n_heads
+        self.attention = Attention(encoder_dim, encoder_n_heads, max_seq_len)
+        self.feed_forward = FeedForward(
+            dim=encoder_dim,
+            hidden_dim=2 * encoder_dim,
+            multiple_of=256,
+            ffn_dim_multiplier=None,
+        )
+        self.attention_norm = AdaLNZero(encoder_dim)
+        self.ffn_norm = nn.LayerNorm(encoder_dim, elementwise_affine=False, eps=1e-6)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        t: torch.Tensor,
+        start_pos: int,
+        freqs_cis: torch.Tensor,
+        mask: Optional[torch.Tensor],
+    ):
+        """
+        Perform a forward pass through the TransformerBlock.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+            start_pos (int): Starting position for attention caching.
+            freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
+            mask (torch.Tensor, optional): Masking tensor for attention. Defaults to None.
+
+        Returns:
+            torch.Tensor: Output tensor after applying attention and feedforward layers.
+
+        """
+        # pre-norm & modulation for attention input
+        norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attention_norm(x, emb=t)
+
+        # attention
+        attn_output = self.attention(norm, start_pos, freqs_cis, mask=mask)
+
+        # process attention output for input x
+        h = x + gate_msa.unsqueeze(1) * attn_output
+
+        norm = self.ffn_norm(h) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+        ff_output = self.feed_forward(norm)
+        out = h + gate_mlp.unsqueeze(1) * ff_output
+
+        return out
+
+
+class Transformer(nn.Module):
+    def __init__(self, encoder_n_layers, encoder_dim, encoder_n_heads, max_seq_len):
+        super().__init__()
+        # Decoder
+        self.layers = torch.nn.ModuleList()
+        for _ in range(encoder_n_layers):
+            self.layers.append(TransformerBlock(encoder_dim, encoder_n_heads, max_seq_len))
+
+        self.norm = AdaLNZero_Out(encoder_dim)
+        self.out_proj = nn.Linear(encoder_dim, encoder_dim)
+
+        # Rope embedding
+        freqs_cis = precompute_freqs_cis(
+            encoder_dim // encoder_n_heads, max_seq_len
+        )
+        self.register_buffer("freqs_cis", torch.view_as_real(freqs_cis), persistent=False)
+
+    def forward(self, x, t, attn_mask, start_pos=0):
+        freqs_cis = torch.view_as_complex(self.freqs_cis.float())[start_pos: start_pos + x.size(1)]
+        for i, layer in enumerate(self.layers):
+            x = layer(x, t, start_pos, freqs_cis, attn_mask)
+        x = self.norm(x, t)
+        x = self.out_proj(x)
+        return x
```
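For orientation, here is a hypothetical usage sketch (not part of the diff) exercising the RoPE helpers above. It assumes the vendored `tts` package root is on `sys.path`, which is how the module's own `tts.modules...` imports resolve; the shapes follow the assertions in `reshape_for_broadcast`.

```python
import torch

# Assumed import path for the vendored module shown above.
from tts.modules.llm_dit.transformer import apply_rotary_emb, precompute_freqs_cis

bsz, seqlen, n_heads, head_dim = 2, 16, 8, 64

# One complex rotation per (position, dim-pair): shape (seqlen, head_dim // 2), complex64.
freqs_cis = precompute_freqs_cis(head_dim, seqlen)

xq = torch.randn(bsz, seqlen, n_heads, head_dim)
xk = torch.randn(bsz, seqlen, n_heads, head_dim)

# The rotation mixes adjacent (even, odd) feature pairs; shape and dtype are preserved.
xq_rot, xk_rot = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
assert xq_rot.shape == (bsz, seqlen, n_heads, head_dim)
```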
xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/diag_gaussian.py (new file)

```diff
@@ -0,0 +1,67 @@
+# Copyright 2025 ByteDance and/or its affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import numpy as np
+
+class DiagonalGaussianDistribution(object):
+    def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
+        self.parameters = parameters
+        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+        self.deterministic = deterministic
+        self.std = torch.exp(0.5 * self.logvar)
+        self.var = torch.exp(self.logvar)
+        if self.deterministic:
+            self.var = self.std = torch.zeros_like(
+                self.mean, device=self.parameters.device, dtype=self.parameters.dtype
+            )
+
+    def sample(self, generator=None) -> torch.Tensor:
+        # make sure sample is on the same device as the parameters and has same dtype
+        sample = torch.randn(
+            self.mean.shape,
+            generator=generator,
+            device=self.parameters.device,
+            dtype=self.parameters.dtype,
+        )
+        x = self.mean + self.std * sample
+        return x
+
+    def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
+        if self.deterministic:
+            return torch.Tensor([0.0])
+        else:
+            if other is None:
+                return 0.5 * torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar
+            else:
+                return 0.5 * (
+                    torch.pow(self.mean - other.mean, 2) / other.var
+                    + self.var / other.var
+                    - 1.0
+                    - self.logvar
+                    + other.logvar
+                )
+
+    def nll(self, sample, dims) -> torch.Tensor:
+        if self.deterministic:
+            return torch.Tensor([0.0])
+        logtwopi = np.log(2.0 * np.pi)
+        return 0.5 * torch.sum(
+            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
+            dim=dims,
+        )
+
+    def mode(self) -> torch.Tensor:
+        return self.mean
```
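A short illustrative sketch of how this posterior class is driven (the import path mirrors the one `wavvae_v3.py` below uses; the tensor sizes are made up):

```python
import torch

from tts.modules.wavvae.decoder.diag_gaussian import DiagonalGaussianDistribution

# `parameters` packs mean and log-variance along dim=1: (batch, 2 * latent, time).
params = torch.randn(4, 2 * 32, 100)
posterior = DiagonalGaussianDistribution(params)

z = posterior.sample()   # (4, 32, 100): mean + std * N(0, 1), reparameterized
kl = posterior.kl()      # elementwise KL-style penalty vs N(0, 1), as implemented above
print(z.shape, kl.mean().item())
```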
xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/hifigan_modules.py (new file; the Chinese comment is translated below)

```diff
@@ -0,0 +1,283 @@
+# Copyright 2025 ByteDance and/or its affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch.nn as nn
+import torch.nn.functional as F
+import torch
+import torch.utils.data
+from librosa.filters import mel as librosa_mel_fn
+from torch.nn.utils import weight_norm, remove_weight_norm
+from torch.nn import Conv1d
+import numpy as np
+
+
+def init_weights(m, mean=0.0, std=0.01):
+    classname = m.__class__.__name__
+    if classname.find("Conv") != -1:
+        m.weight.data.normal_(mean, std)
+
+
+def get_padding(kernel_size, dilation=1):
+    return int((kernel_size*dilation - dilation)/2)
+
+
+class Upsample(nn.Module):
+    def __init__(self, mult, r):
+        super(Upsample, self).__init__()
+        self.r = r
+        self.upsample = nn.Sequential(nn.Upsample(mode="nearest", scale_factor=r),
+                                      nn.LeakyReLU(0.2),
+                                      nn.ReflectionPad1d(3),
+                                      nn.utils.weight_norm(nn.Conv1d(mult, mult // 2, kernel_size=7, stride=1))
+                                      )
+        r_kernel = r if r >= 5 else 5
+        self.trans_upsample = nn.Sequential(nn.LeakyReLU(0.2),
+                                            nn.utils.weight_norm(nn.ConvTranspose1d(mult, mult // 2,
+                                                                                    kernel_size=r_kernel * 2, stride=r,
+                                                                                    padding=r_kernel - r // 2,
+                                                                                    output_padding=r % 2)
+                                            ))
+
+    def forward(self, x):
+        x = torch.sin(x) + x
+        out1 = self.upsample(x)
+        out2 = self.trans_upsample(x)
+        return out1 + out2
+
+
+class Downsample(nn.Module):
+    def __init__(self, mult, r):
+        super(Downsample, self).__init__()
+        self.r = r
+        r_kernel = r if r >= 5 else 5
+        self.trans_downsample = nn.Sequential(nn.LeakyReLU(0.2),
+                                              nn.utils.weight_norm(nn.Conv1d(mult, mult * 2,
+                                                                             kernel_size=r_kernel * 2, stride=r,
+                                                                             padding=r_kernel - r // 2)
+                                              ))
+
+    def forward(self, x):
+        out = self.trans_downsample(x)
+        return out
+
+
+def weights_init(m):
+    classname = m.__class__.__name__
+    if classname.find("Conv") != -1:
+        m.weight.data.normal_(0.0, 0.02)
+    elif classname.find("BatchNorm2d") != -1:
+        m.weight.data.normal_(1.0, 0.02)
+        m.bias.data.fill_(0)
+
+
+def weights_zero_init(m):
+    classname = m.__class__.__name__
+    if classname.find("Conv") != -1:
+        m.weight.data.fill_(0.0)
+        m.bias.data.fill_(0.0)
+
+
+def WNConv1d(*args, **kwargs):
+    return weight_norm(nn.Conv1d(*args, **kwargs))
+
+
+def WNConvTranspose1d(*args, **kwargs):
+    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
+
+
+class Audio2Mel(nn.Module):
+    def __init__(
+        self,
+        hop_length=300,
+        sampling_rate=24000,
+        n_mel_channels=80,
+        mel_fmin=0.,
+        mel_fmax=None,
+        frame_size=0.05,
+        device='cpu'
+    ):
+        super().__init__()
+        ##############################################
+        # FFT Parameters                              #
+        ##############################################
+
+        self.n_fft = int(np.power(2., np.ceil(np.log(sampling_rate * frame_size) / np.log(2))))
+        window = torch.hann_window(int(sampling_rate * frame_size)).float()
+        mel_basis = librosa_mel_fn(
+            sampling_rate, self.n_fft, n_mel_channels, mel_fmin, mel_fmax
+        )  # Mel filter (by librosa)
+        mel_basis = torch.from_numpy(mel_basis).float()
+        self.register_buffer("mel_basis", mel_basis)
+        self.register_buffer("window", window)
+
+        self.hop_length = hop_length
+        self.win_length = int(sampling_rate * frame_size)
+        self.sampling_rate = sampling_rate
+        self.n_mel_channels = n_mel_channels
+
+    def forward(self, audio):
+        fft = torch.stft(
+            audio.squeeze(1),
+            n_fft=self.n_fft,
+            hop_length=self.hop_length,
+            win_length=self.win_length,
+            window=self.window,
+            center=True,
+        )
+        real_part, imag_part = fft.unbind(-1)
+        magnitude = torch.sqrt(torch.clamp(real_part ** 2 + imag_part ** 2, min=1e-5))
+        mel_output = torch.matmul(self.mel_basis, magnitude)
+
+        log_mel_spec = 20 * torch.log10(torch.clamp(mel_output, min=1e-5)) - 20
+        norm_mel = (log_mel_spec + 115.) / 115.
+        mel_comp = torch.clamp(norm_mel * 8. - 4., -4., 4.)
+
+        return mel_comp
+
+
+class ResnetBlock(nn.Module):
+    def __init__(self, dim, dilation=1, dim_in=None):
+        super().__init__()
+        if dim_in is None:
+            dim_in = dim
+
+        self.block = nn.Sequential(
+            nn.LeakyReLU(0.2),
+            nn.ReflectionPad1d(dilation),
+            WNConv1d(dim_in, dim, kernel_size=3, dilation=dilation),
+            nn.LeakyReLU(0.2),
+            WNConv1d(dim, dim, kernel_size=1),
+        )
+        self.shortcut = WNConv1d(dim_in, dim, kernel_size=1)
+
+    def forward(self, x):
+        return self.shortcut(x) + self.block(x)
+
+
+'''
+Follows the HiFi-GAN (https://arxiv.org/pdf/2010.05646.pdf) v2 structure.
+Multi-scale mainly means different kernel_size: three parallel convolution modules, each applying different serial dilation sizes internally, interleaved with ordinary non-dilated convolution layers.
+'''
+
+
+class ResBlockMRFV2(torch.nn.Module):
+    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
+        super(ResBlockMRFV2, self).__init__()
+        self.convs1 = nn.ModuleList([
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
+                               padding=get_padding(kernel_size, dilation[0]))),
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
+                               padding=get_padding(kernel_size, dilation[1]))),
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
+                               padding=get_padding(kernel_size, dilation[2])))
+        ])
+        self.convs1.apply(init_weights)
+
+        self.convs2 = nn.ModuleList([
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+                               padding=get_padding(kernel_size, 1))),
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+                               padding=get_padding(kernel_size, 1))),
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+                               padding=get_padding(kernel_size, 1)))
+        ])
+        self.convs2.apply(init_weights)
+
+    def forward(self, x):
+        for c1, c2 in zip(self.convs1, self.convs2):
+            xt = F.leaky_relu(x, 0.2)
+            xt = c1(xt)
+            xt = F.leaky_relu(xt, 0.2)
+            xt = c2(xt)
+            x = xt + x
+        return x
+
+    def remove_weight_norm(self):
+        for l in self.convs1:
+            remove_weight_norm(l)
+        for l in self.convs2:
+            remove_weight_norm(l)
+
+
+class ResBlockMRFV2Inter(torch.nn.Module):
+    def __init__(self, channels, kernel_size=3):
+        super(ResBlockMRFV2Inter, self).__init__()
+        self.block1 = ResBlockMRFV2(channels)
+        self.block2 = ResBlockMRFV2(channels, 7)
+        self.block3 = ResBlockMRFV2(channels, 11)
+
+    def forward(self, x):
+        xs = self.block1(x)
+        xs += self.block2(x)
+        xs += self.block3(x)
+        x = xs / 3
+        return x
+
+
+class Generator(nn.Module):
+    def __init__(self, input_size_, ngf, n_residual_layers, num_band, args, ratios=[5, 5, 4, 3], onnx_export=False,
+                 device='cpu'):
+        super().__init__()
+        self.hop_length = args.frame_shift
+        self.args = args
+        self.onnx_export = onnx_export
+
+        # ------------- Define upsample layers ----------------
+        mult = int(2 ** len(ratios))
+        model_up = []
+        input_size = input_size_
+        model_up += [
+            nn.ReflectionPad1d(3),
+            WNConv1d(input_size, mult * ngf, kernel_size=7, padding=0),
+        ]
+
+        # Upsample to raw audio scale
+        for i, r in enumerate(ratios):
+            model_up += [Upsample(mult * ngf, r)]
+            model_up += [ResBlockMRFV2Inter(mult * ngf // 2)]
+            mult //= 2
+
+        model_up += [
+            nn.LeakyReLU(0.2),
+            nn.ReflectionPad1d(3),
+            WNConv1d(ngf, num_band, kernel_size=7, padding=0),
+            nn.Tanh(),
+        ]
+        if not args.use_tanh:
+            model_up[-1] = nn.Conv1d(num_band, num_band, 1)
+            model_up[-2].apply(weights_zero_init)
+
+        self.model_up = nn.Sequential(*model_up)
+
+        self.apply(weights_init)
+
+    def forward(self, mel, step=None):
+        # mel input: (batch_size, seq_num, 80)
+        if self.onnx_export:
+            mel = mel.transpose(1, 2)
+            # on onnx, for engineering, mel input: (batch_size, 80, seq_num)
+
+        # Between Down and up
+        x = mel
+
+        # Upsample pipline
+        cnt_after_upsample = 0
+
+        for i, m in enumerate(self.model_up):
+            x = m(x)
+
+            if type(m) == Upsample:
+                cnt_after_upsample += 1
+
+        return x
```
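An illustrative shape check for the dual-path `Upsample` block above (assuming `torch` and `librosa` are installed, since the module imports `librosa.filters` at top level): a nearest-neighbour branch and a transposed-convolution branch are summed, halving the channel count while upsampling time by the factor `r`.

```python
import torch

from tts.modules.wavvae.decoder.hifigan_modules import Upsample

up = Upsample(mult=320, r=4)
x = torch.randn(2, 320, 50)   # (batch, channels, frames)
y = up(x)                     # both branches land on (2, 160, 200) and are summed
print(y.shape)                # torch.Size([2, 160, 200])
```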
xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/seanet_encoder.py (new file)

```diff
@@ -0,0 +1,38 @@
+# Copyright 2025 ByteDance and/or its affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List
+
+import torch
+from torch import nn
+from tts.modules.wavvae.encoder.common_modules.seanet import SEANetEncoder
+
+class Encoder(nn.Module):
+    def __init__(
+        self,
+        dowmsamples: List[int] = [6, 5, 5, 4, 2],
+    ):
+        super().__init__()
+
+        # breakpoint()
+        self.frame_rate = 25  # not use
+        self.encoder = SEANetEncoder(causal=False, n_residual_layers=1, norm='weight_norm', pad_mode='reflect', lstm=2,
+                                     dimension=512, channels=1, n_filters=32, ratios=dowmsamples, activation='ELU',
+                                     kernel_size=7, residual_kernel_size=3, last_kernel_size=7, dilation_base=2,
+                                     true_skip=False, compress=2)
+
+    def forward(self, audio: torch.Tensor):
+        audio = audio.unsqueeze(1)  # audio(16,24000)
+        emb = self.encoder(audio)
+        return emb
```
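A shape sketch under the same import assumptions: with the default ratios [6, 5, 5, 4, 2] (1200x downsampling overall), one second of 24 kHz mono audio should become about 20 frames of 512 channels. Note that `WavVAE_V3` below passes [6, 5, 4, 4, 2] instead, which yields 25 Hz frames.

```python
import torch

from tts.modules.wavvae.decoder.seanet_encoder import Encoder

enc = Encoder()                    # default dowmsamples=[6, 5, 5, 4, 2]
emb = enc(torch.randn(2, 24000))   # (batch, samples) -> (batch, 512, frames)
print(emb.shape)                   # expected: torch.Size([2, 512, 20])
```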
xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/wavvae_v3.py (new file)

```diff
@@ -0,0 +1,60 @@
+# Copyright 2025 ByteDance and/or its affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from tts.modules.wavvae.decoder.seanet_encoder import Encoder
+from tts.modules.wavvae.decoder.diag_gaussian import DiagonalGaussianDistribution
+from tts.modules.wavvae.decoder.hifigan_modules import Generator, Upsample
+
+
+class WavVAE_V3(nn.Module):
+    def __init__(self, hparams=None):
+        super().__init__()
+        self.encoder = Encoder(dowmsamples=[6, 5, 4, 4, 2])
+        self.proj_to_z = nn.Linear(512, 64)
+        self.proj_to_decoder = nn.Linear(32, 320)
+
+        config_path = hparams['melgan_config']
+        args = argparse.Namespace()
+        args.__dict__.update(config_path)
+        self.latent_upsampler = Upsample(320, 4)
+        self.decoder = Generator(
+            input_size_=160, ngf=128, n_residual_layers=4,
+            num_band=1, args=args, ratios=[5,4,4,3])
+
+    ''' encode waveform into 25 hz latent representation '''
+    def encode_latent(self, audio):
+        posterior = self.encode(audio)
+        latent = posterior.sample().permute(0, 2, 1)  # (b,t,latent_channel)
+        return latent
+
+    def encode(self, audio):
+        x = self.encoder(audio).permute(0, 2, 1)
+        x = self.proj_to_z(x).permute(0, 2, 1)
+        poseterior = DiagonalGaussianDistribution(x)
+        return poseterior
+
+    def decode(self, latent):
+        latent = self.proj_to_decoder(latent).permute(0, 2, 1)
+        return self.decoder(self.latent_upsampler(latent))
+
+    def forward(self, audio):
+        posterior = self.encode(audio)
+        latent = posterior.sample().permute(0, 2, 1)  # (b, t, latent_channel)
+        recon_wav = self.decode(latent)
+        return recon_wav, posterior
```
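Finally, a hedged round-trip sketch for `WavVAE_V3`. The `melgan_config` keys below (`frame_shift`, `use_tanh`) are the only ones `Generator` visibly reads in this diff; the values here are assumptions for illustration, not the shipped MegaTTS 3 configuration, which comes from the model's checkpoint hparams.

```python
import torch

from tts.modules.wavvae.decoder.wavvae_v3 import WavVAE_V3

# Assumed config values; the real hparams are loaded with the checkpoint.
hparams = {"melgan_config": {"frame_shift": 240, "use_tanh": True}}
vae = WavVAE_V3(hparams=hparams)

audio = torch.randn(1, 24000)       # (batch, samples) at 24 kHz
latent = vae.encode_latent(audio)   # (1, 25, 32): 25 Hz frames, 32 latent channels
recon, posterior = vae(audio)       # recon: (1, 1, 24000)
print(latent.shape, recon.shape)
```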