xinference 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff reflects the content of the publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.

Potentially problematic release.
Files changed (132)
  1. xinference/_compat.py +1 -0
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +54 -1
  4. xinference/client/restful/restful_client.py +82 -2
  5. xinference/constants.py +3 -0
  6. xinference/core/chat_interface.py +297 -83
  7. xinference/core/model.py +24 -3
  8. xinference/core/progress_tracker.py +16 -8
  9. xinference/core/supervisor.py +51 -1
  10. xinference/core/worker.py +315 -47
  11. xinference/deploy/cmdline.py +33 -1
  12. xinference/model/audio/core.py +11 -1
  13. xinference/model/audio/megatts.py +105 -0
  14. xinference/model/audio/model_spec.json +24 -1
  15. xinference/model/audio/model_spec_modelscope.json +26 -1
  16. xinference/model/core.py +14 -0
  17. xinference/model/embedding/core.py +6 -1
  18. xinference/model/flexible/core.py +6 -1
  19. xinference/model/image/core.py +6 -1
  20. xinference/model/image/model_spec.json +17 -1
  21. xinference/model/image/model_spec_modelscope.json +17 -1
  22. xinference/model/llm/__init__.py +4 -6
  23. xinference/model/llm/core.py +5 -0
  24. xinference/model/llm/llama_cpp/core.py +46 -17
  25. xinference/model/llm/llm_family.json +530 -85
  26. xinference/model/llm/llm_family.py +24 -1
  27. xinference/model/llm/llm_family_modelscope.json +572 -1
  28. xinference/model/llm/mlx/core.py +16 -2
  29. xinference/model/llm/reasoning_parser.py +3 -3
  30. xinference/model/llm/sglang/core.py +111 -13
  31. xinference/model/llm/transformers/__init__.py +14 -0
  32. xinference/model/llm/transformers/core.py +31 -6
  33. xinference/model/llm/transformers/deepseek_vl.py +1 -1
  34. xinference/model/llm/transformers/deepseek_vl2.py +287 -0
  35. xinference/model/llm/transformers/gemma3.py +17 -2
  36. xinference/model/llm/transformers/intern_vl.py +28 -18
  37. xinference/model/llm/transformers/minicpmv26.py +21 -2
  38. xinference/model/llm/transformers/qwen-omni.py +308 -0
  39. xinference/model/llm/transformers/qwen2_audio.py +1 -1
  40. xinference/model/llm/transformers/qwen2_vl.py +20 -4
  41. xinference/model/llm/utils.py +37 -15
  42. xinference/model/llm/vllm/core.py +184 -8
  43. xinference/model/llm/vllm/distributed_executor.py +320 -0
  44. xinference/model/rerank/core.py +22 -12
  45. xinference/model/utils.py +118 -1
  46. xinference/model/video/core.py +6 -1
  47. xinference/thirdparty/deepseek_vl2/__init__.py +31 -0
  48. xinference/thirdparty/deepseek_vl2/models/__init__.py +26 -0
  49. xinference/thirdparty/deepseek_vl2/models/configuration_deepseek.py +210 -0
  50. xinference/thirdparty/deepseek_vl2/models/conversation.py +310 -0
  51. xinference/thirdparty/deepseek_vl2/models/modeling_deepseek.py +1975 -0
  52. xinference/thirdparty/deepseek_vl2/models/modeling_deepseek_vl_v2.py +697 -0
  53. xinference/thirdparty/deepseek_vl2/models/processing_deepseek_vl_v2.py +675 -0
  54. xinference/thirdparty/deepseek_vl2/models/siglip_vit.py +661 -0
  55. xinference/thirdparty/deepseek_vl2/serve/__init__.py +0 -0
  56. xinference/thirdparty/deepseek_vl2/serve/app_modules/__init__.py +0 -0
  57. xinference/thirdparty/deepseek_vl2/serve/app_modules/gradio_utils.py +83 -0
  58. xinference/thirdparty/deepseek_vl2/serve/app_modules/overwrites.py +81 -0
  59. xinference/thirdparty/deepseek_vl2/serve/app_modules/presets.py +115 -0
  60. xinference/thirdparty/deepseek_vl2/serve/app_modules/utils.py +333 -0
  61. xinference/thirdparty/deepseek_vl2/serve/assets/Kelpy-Codos.js +100 -0
  62. xinference/thirdparty/deepseek_vl2/serve/assets/avatar.png +0 -0
  63. xinference/thirdparty/deepseek_vl2/serve/assets/custom.css +355 -0
  64. xinference/thirdparty/deepseek_vl2/serve/assets/custom.js +22 -0
  65. xinference/thirdparty/deepseek_vl2/serve/assets/favicon.ico +0 -0
  66. xinference/thirdparty/deepseek_vl2/serve/assets/simsun.ttc +0 -0
  67. xinference/thirdparty/deepseek_vl2/serve/inference.py +197 -0
  68. xinference/thirdparty/deepseek_vl2/utils/__init__.py +18 -0
  69. xinference/thirdparty/deepseek_vl2/utils/io.py +80 -0
  70. xinference/thirdparty/megatts3/__init__.py +0 -0
  71. xinference/thirdparty/megatts3/tts/frontend_function.py +175 -0
  72. xinference/thirdparty/megatts3/tts/gradio_api.py +93 -0
  73. xinference/thirdparty/megatts3/tts/infer_cli.py +277 -0
  74. xinference/thirdparty/megatts3/tts/modules/aligner/whisper_small.py +318 -0
  75. xinference/thirdparty/megatts3/tts/modules/ar_dur/ar_dur_predictor.py +362 -0
  76. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/layers.py +64 -0
  77. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/nar_tts_modules.py +73 -0
  78. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/rel_transformer.py +403 -0
  79. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/rot_transformer.py +649 -0
  80. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/seq_utils.py +342 -0
  81. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/transformer.py +767 -0
  82. xinference/thirdparty/megatts3/tts/modules/llm_dit/cfm.py +309 -0
  83. xinference/thirdparty/megatts3/tts/modules/llm_dit/dit.py +180 -0
  84. xinference/thirdparty/megatts3/tts/modules/llm_dit/time_embedding.py +44 -0
  85. xinference/thirdparty/megatts3/tts/modules/llm_dit/transformer.py +230 -0
  86. xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/diag_gaussian.py +67 -0
  87. xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/hifigan_modules.py +283 -0
  88. xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/seanet_encoder.py +38 -0
  89. xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/wavvae_v3.py +60 -0
  90. xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/conv.py +154 -0
  91. xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/lstm.py +51 -0
  92. xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/seanet.py +126 -0
  93. xinference/thirdparty/megatts3/tts/utils/audio_utils/align.py +36 -0
  94. xinference/thirdparty/megatts3/tts/utils/audio_utils/io.py +95 -0
  95. xinference/thirdparty/megatts3/tts/utils/audio_utils/plot.py +90 -0
  96. xinference/thirdparty/megatts3/tts/utils/commons/ckpt_utils.py +171 -0
  97. xinference/thirdparty/megatts3/tts/utils/commons/hparams.py +215 -0
  98. xinference/thirdparty/megatts3/tts/utils/text_utils/dict.json +1 -0
  99. xinference/thirdparty/megatts3/tts/utils/text_utils/ph_tone_convert.py +94 -0
  100. xinference/thirdparty/megatts3/tts/utils/text_utils/split_text.py +90 -0
  101. xinference/thirdparty/megatts3/tts/utils/text_utils/text_encoder.py +280 -0
  102. xinference/types.py +10 -0
  103. xinference/utils.py +54 -0
  104. xinference/web/ui/build/asset-manifest.json +6 -6
  105. xinference/web/ui/build/index.html +1 -1
  106. xinference/web/ui/build/static/css/main.0f6523be.css +2 -0
  107. xinference/web/ui/build/static/css/main.0f6523be.css.map +1 -0
  108. xinference/web/ui/build/static/js/main.58bd483c.js +3 -0
  109. xinference/web/ui/build/static/js/main.58bd483c.js.map +1 -0
  110. xinference/web/ui/node_modules/.cache/babel-loader/3bff8cbe9141f937f4d98879a9771b0f48e0e4e0dbee8e647adbfe23859e7048.json +1 -0
  111. xinference/web/ui/node_modules/.cache/babel-loader/4500b1a622a031011f0a291701e306b87e08cbc749c50e285103536b85b6a914.json +1 -0
  112. xinference/web/ui/node_modules/.cache/babel-loader/51709f5d3e53bcf19e613662ef9b91fb9174942c5518987a248348dd4e1e0e02.json +1 -0
  113. xinference/web/ui/node_modules/.cache/babel-loader/69081049f0c7447544b7cfd73dd13d8846c02fe5febe4d81587e95c89a412d5b.json +1 -0
  114. xinference/web/ui/node_modules/.cache/babel-loader/b8551e9775a01b28ae674125c688febe763732ea969ae344512e64ea01bf632e.json +1 -0
  115. xinference/web/ui/node_modules/.cache/babel-loader/bf2b211b0d1b6465eff512d64c869d748f803c5651a7c24e48de6ea3484a7bfe.json +1 -0
  116. xinference/web/ui/src/locales/en.json +2 -1
  117. xinference/web/ui/src/locales/zh.json +2 -1
  118. {xinference-1.4.0.dist-info → xinference-1.5.0.dist-info}/METADATA +128 -115
  119. {xinference-1.4.0.dist-info → xinference-1.5.0.dist-info}/RECORD +124 -63
  120. {xinference-1.4.0.dist-info → xinference-1.5.0.dist-info}/WHEEL +1 -1
  121. xinference/web/ui/build/static/css/main.b494ae7e.css +0 -2
  122. xinference/web/ui/build/static/css/main.b494ae7e.css.map +0 -1
  123. xinference/web/ui/build/static/js/main.3cea968e.js +0 -3
  124. xinference/web/ui/build/static/js/main.3cea968e.js.map +0 -1
  125. xinference/web/ui/node_modules/.cache/babel-loader/27bcada3ee8f89d21184b359f022fc965f350ffaca52c9814c29f1fc37121173.json +0 -1
  126. xinference/web/ui/node_modules/.cache/babel-loader/7f59e45e3f268ab8a4788b6fb024cf8dab088736dff22f5a3a39c122a83ab930.json +0 -1
  127. xinference/web/ui/node_modules/.cache/babel-loader/dcd60488509450bfff37bfff56de2c096d51de17dd00ec60d4db49c8b483ada1.json +0 -1
  128. xinference/web/ui/node_modules/.cache/babel-loader/e547bbb18abb4a474b675a8d5782d25617566bea0af8caa9b836ce5649e2250a.json +0 -1
  129. /xinference/web/ui/build/static/js/{main.3cea968e.js.LICENSE.txt → main.58bd483c.js.LICENSE.txt} +0 -0
  130. {xinference-1.4.0.dist-info → xinference-1.5.0.dist-info}/entry_points.txt +0 -0
  131. {xinference-1.4.0.dist-info → xinference-1.5.0.dist-info/licenses}/LICENSE +0 -0
  132. {xinference-1.4.0.dist-info → xinference-1.5.0.dist-info}/top_level.txt +0 -0

xinference/thirdparty/megatts3/tts/modules/llm_dit/transformer.py
@@ -0,0 +1,230 @@
+ # Copyright 2025 ByteDance and/or its affiliates.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ from typing import Any, Optional, Tuple
+
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+
+
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
+     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+     t = torch.arange(end, device=freqs.device)  # type: ignore
+     freqs = torch.outer(t, freqs).float()  # type: ignore
+     freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
+     return freqs_cis
+
+
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
+     ndim = x.ndim
+     assert 0 <= 1 < ndim
+     assert freqs_cis.shape == (x.shape[1], x.shape[-1])
+     shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+     return freqs_cis.view(*shape)
+
+
+ def apply_rotary_emb(
+     xq: torch.Tensor,
+     xk: torch.Tensor,
+     freqs_cis: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
+     xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
+     freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
+     xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
+     xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
+     return xq_out.type_as(xq), xk_out.type_as(xk)
+
+
+ class AdaLNZero(nn.Module):
+     def __init__(self, dim):
+         super().__init__()
+         self.silu = nn.SiLU()
+         self.linear = nn.Linear(dim, dim * 6)
+         self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+
+     def forward(self, x, emb=None):
+         emb = self.linear(self.silu(emb))
+         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)
+         x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
+         return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
+
+
+ class AdaLNZero_Out(nn.Module):
+     def __init__(self, dim):
+         super().__init__()
+         self.silu = nn.SiLU()
+         self.linear = nn.Linear(dim, dim * 2)
+         self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+
+     def forward(self, x, emb):
+         emb = self.linear(self.silu(emb))
+         scale, shift = torch.chunk(emb, 2, dim=1)
+         x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
+         return x
+
+
+ class Attention(nn.Module):
+     def __init__(self, encoder_dim, encoder_n_heads, max_seq_len):
+         super().__init__()
+         self.encoder_n_kv_heads = encoder_n_heads
+         model_parallel_size = 1
+         self.n_local_heads = encoder_n_heads // model_parallel_size
+         self.n_local_kv_heads = self.encoder_n_kv_heads // model_parallel_size
+         self.n_rep = self.n_local_heads // self.n_local_kv_heads
+         self.head_dim = encoder_dim // encoder_n_heads
+
+         self.wq = nn.Linear(
+             encoder_dim,
+             encoder_n_heads * self.head_dim,
+         )
+         self.wk = nn.Linear(
+             encoder_dim,
+             self.encoder_n_kv_heads * self.head_dim,
+         )
+         self.wv = nn.Linear(
+             encoder_dim,
+             self.encoder_n_kv_heads * self.head_dim,
+         )
+         self.wo = nn.Linear(
+             encoder_n_heads * self.head_dim,
+             encoder_dim,
+         )
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         start_pos: int,
+         freqs_cis: torch.Tensor,
+         mask: Optional[torch.Tensor],
+     ):
+         bsz, seqlen, _ = x.shape
+         xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
+         xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+         xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
+         xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
+
+         xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
+         xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
+         keys = xk.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
+         values = xv.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
+
+         output = F.scaled_dot_product_attention(xq, keys, values, mask[:, None, None, :], is_causal=False)
+         output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
+         return self.wo(output)
+
+
+ class FeedForward(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         hidden_dim: int,
+         multiple_of: int,
+         ffn_dim_multiplier: Optional[float],
+     ):
+         super().__init__()
+         if ffn_dim_multiplier is not None:
+             hidden_dim = int(ffn_dim_multiplier * hidden_dim)
+         hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+
+         self.w1 = nn.Linear(
+             dim, hidden_dim
+         )
+         self.w2 = nn.Linear(
+             hidden_dim, dim
+         )
+
+     def forward(self, x):
+         return self.w2(F.silu(self.w1(x)))
+
+
+ class TransformerBlock(nn.Module):
+     def __init__(self, encoder_dim, encoder_n_heads, max_seq_len):
+         super().__init__()
+         self.encoder_n_heads = encoder_n_heads
+         self.encoder_dim = encoder_dim
+         self.head_dim = encoder_dim // encoder_n_heads
+         self.attention = Attention(encoder_dim, encoder_n_heads, max_seq_len)
+         self.feed_forward = FeedForward(
+             dim=encoder_dim,
+             hidden_dim=2 * encoder_dim,
+             multiple_of=256,
+             ffn_dim_multiplier=None,
+         )
+         self.attention_norm = AdaLNZero(encoder_dim)
+         self.ffn_norm = nn.LayerNorm(encoder_dim, elementwise_affine=False, eps=1e-6)
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         t: torch.Tensor,
+         start_pos: int,
+         freqs_cis: torch.Tensor,
+         mask: Optional[torch.Tensor],
+     ):
+         """
+         Perform a forward pass through the TransformerBlock.
+
+         Args:
+             x (torch.Tensor): Input tensor.
+             start_pos (int): Starting position for attention caching.
+             freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
+             mask (torch.Tensor, optional): Masking tensor for attention. Defaults to None.
+
+         Returns:
+             torch.Tensor: Output tensor after applying attention and feedforward layers.
+
+         """
+         # pre-norm & modulation for attention input
+         norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attention_norm(x, emb=t)
+
+         # attention
+         attn_output = self.attention(norm, start_pos, freqs_cis, mask=mask)
+
+         # process attention output for input x
+         h = x + gate_msa.unsqueeze(1) * attn_output
+
+         norm = self.ffn_norm(h) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+         ff_output = self.feed_forward(norm)
+         out = h + gate_mlp.unsqueeze(1) * ff_output
+
+         return out
+
+
+ class Transformer(nn.Module):
+     def __init__(self, encoder_n_layers, encoder_dim, encoder_n_heads, max_seq_len):
+         super().__init__()
+         # Decoder
+         self.layers = torch.nn.ModuleList()
+         for _ in range(encoder_n_layers):
+             self.layers.append(TransformerBlock(encoder_dim, encoder_n_heads, max_seq_len))
+
+         self.norm = AdaLNZero_Out(encoder_dim)
+         self.out_proj = nn.Linear(encoder_dim, encoder_dim)
+
+         # Rope embedding
+         freqs_cis = precompute_freqs_cis(
+             encoder_dim // encoder_n_heads, max_seq_len
+         )
+         self.register_buffer("freqs_cis", torch.view_as_real(freqs_cis), persistent=False)
+
+     def forward(self, x, t, attn_mask, start_pos=0):
+         freqs_cis = torch.view_as_complex(self.freqs_cis.float())[start_pos: start_pos + x.size(1)]
+         for i, layer in enumerate(self.layers):
+             x = layer(x, t, start_pos, freqs_cis, attn_mask)
+         x = self.norm(x, t)
+         x = self.out_proj(x)
+         return x
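
For orientation: this new file under megatts3's `llm_dit` module is a rotary-embedding transformer with AdaLN-Zero conditioning, i.e. a DiT-style backbone. Below is a minimal, hypothetical smoke test of the `Transformer` class defined above, assuming those classes are importable; the dimensions, batch size, and the conditioning tensor `t` are illustrative assumptions, not values shipped in this release.

```python
import torch

# Assumed toy sizes; real MegaTTS 3 checkpoints define their own dimensions.
model = Transformer(encoder_n_layers=2, encoder_dim=256, encoder_n_heads=4, max_seq_len=512)

x = torch.randn(2, 100, 256)                      # latent sequence: (batch, seq_len, encoder_dim)
t = torch.randn(2, 256)                           # per-sample conditioning embedding for AdaLN-Zero
attn_mask = torch.ones(2, 100, dtype=torch.bool)  # True = this position may be attended to

out = model(x, t, attn_mask)                      # same shape as x: (2, 100, 256)
```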

xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/diag_gaussian.py
@@ -0,0 +1,67 @@
+ # Copyright 2025 ByteDance and/or its affiliates.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import torch
+ import numpy as np
+
+ class DiagonalGaussianDistribution(object):
+     def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
+         self.parameters = parameters
+         self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+         self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+         self.deterministic = deterministic
+         self.std = torch.exp(0.5 * self.logvar)
+         self.var = torch.exp(self.logvar)
+         if self.deterministic:
+             self.var = self.std = torch.zeros_like(
+                 self.mean, device=self.parameters.device, dtype=self.parameters.dtype
+             )
+
+     def sample(self, generator=None) -> torch.Tensor:
+         # make sure sample is on the same device as the parameters and has same dtype
+         sample = torch.randn(
+             self.mean.shape,
+             generator=generator,
+             device=self.parameters.device,
+             dtype=self.parameters.dtype,
+         )
+         x = self.mean + self.std * sample
+         return x
+
+     def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
+         if self.deterministic:
+             return torch.Tensor([0.0])
+         else:
+             if other is None:
+                 return 0.5 * torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar
+             else:
+                 return 0.5 * (
+                     torch.pow(self.mean - other.mean, 2) / other.var
+                     + self.var / other.var
+                     - 1.0
+                     - self.logvar
+                     + other.logvar
+                 )
+
+     def nll(self, sample, dims) -> torch.Tensor:
+         if self.deterministic:
+             return torch.Tensor([0.0])
+         logtwopi = np.log(2.0 * np.pi)
+         return 0.5 * torch.sum(
+             logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
+             dim=dims,
+         )
+
+     def mode(self) -> torch.Tensor:
+         return self.mean
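
`DiagonalGaussianDistribution` is the usual VAE posterior helper: the parameter tensor is split along dim 1 into mean and log-variance, and `sample()` draws with the reparameterization trick. A short usage sketch, assuming the class above is in scope; the channel counts are illustrative only.

```python
import torch

params = torch.randn(4, 16, 50)                  # (batch, 2 * latent_channels, time): mean ‖ logvar
posterior = DiagonalGaussianDistribution(params)

z = posterior.sample()                           # (4, 8, 50) reparameterized draw: mean + std * eps
z_mode = posterior.mode()                        # deterministic code, i.e. the mean
kl_term = posterior.kl().mean()                  # KL-style regularizer against a unit Gaussian
```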

xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/hifigan_modules.py
@@ -0,0 +1,283 @@
+ # Copyright 2025 ByteDance and/or its affiliates.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch
+ import torch.utils.data
+ from librosa.filters import mel as librosa_mel_fn
+ from torch.nn.utils import weight_norm, remove_weight_norm
+ from torch.nn import Conv1d
+ import numpy as np
+
+
+ def init_weights(m, mean=0.0, std=0.01):
+     classname = m.__class__.__name__
+     if classname.find("Conv") != -1:
+         m.weight.data.normal_(mean, std)
+
+
+ def get_padding(kernel_size, dilation=1):
+     return int((kernel_size*dilation - dilation)/2)
+
+
+ class Upsample(nn.Module):
+     def __init__(self, mult, r):
+         super(Upsample, self).__init__()
+         self.r = r
+         self.upsample = nn.Sequential(nn.Upsample(mode="nearest", scale_factor=r),
+                                       nn.LeakyReLU(0.2),
+                                       nn.ReflectionPad1d(3),
+                                       nn.utils.weight_norm(nn.Conv1d(mult, mult // 2, kernel_size=7, stride=1))
+                                       )
+         r_kernel = r if r >= 5 else 5
+         self.trans_upsample = nn.Sequential(nn.LeakyReLU(0.2),
+                                             nn.utils.weight_norm(nn.ConvTranspose1d(mult, mult // 2,
+                                                                                     kernel_size=r_kernel * 2, stride=r,
+                                                                                     padding=r_kernel - r // 2,
+                                                                                     output_padding=r % 2)
+                                                                  ))
+
+     def forward(self, x):
+         x = torch.sin(x) + x
+         out1 = self.upsample(x)
+         out2 = self.trans_upsample(x)
+         return out1 + out2
+
+
+ class Downsample(nn.Module):
+     def __init__(self, mult, r):
+         super(Downsample, self).__init__()
+         self.r = r
+         r_kernel = r if r >= 5 else 5
+         self.trans_downsample = nn.Sequential(nn.LeakyReLU(0.2),
+                                               nn.utils.weight_norm(nn.Conv1d(mult, mult * 2,
+                                                                              kernel_size=r_kernel * 2, stride=r,
+                                                                              padding=r_kernel - r // 2)
+                                                                    ))
+
+     def forward(self, x):
+         out = self.trans_downsample(x)
+         return out
+
+
+ def weights_init(m):
+     classname = m.__class__.__name__
+     if classname.find("Conv") != -1:
+         m.weight.data.normal_(0.0, 0.02)
+     elif classname.find("BatchNorm2d") != -1:
+         m.weight.data.normal_(1.0, 0.02)
+         m.bias.data.fill_(0)
+
+
+ def weights_zero_init(m):
+     classname = m.__class__.__name__
+     if classname.find("Conv") != -1:
+         m.weight.data.fill_(0.0)
+         m.bias.data.fill_(0.0)
+
+
+ def WNConv1d(*args, **kwargs):
+     return weight_norm(nn.Conv1d(*args, **kwargs))
+
+
+ def WNConvTranspose1d(*args, **kwargs):
+     return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
+
+
+ class Audio2Mel(nn.Module):
+     def __init__(
+         self,
+         hop_length=300,
+         sampling_rate=24000,
+         n_mel_channels=80,
+         mel_fmin=0.,
+         mel_fmax=None,
+         frame_size=0.05,
+         device='cpu'
+     ):
+         super().__init__()
+         ##############################################
+         #              FFT Parameters                #
+         ##############################################
+
+         self.n_fft = int(np.power(2., np.ceil(np.log(sampling_rate * frame_size) / np.log(2))))
+         window = torch.hann_window(int(sampling_rate * frame_size)).float()
+         mel_basis = librosa_mel_fn(
+             sampling_rate, self.n_fft, n_mel_channels, mel_fmin, mel_fmax
+         )  # Mel filter (by librosa)
+         mel_basis = torch.from_numpy(mel_basis).float()
+         self.register_buffer("mel_basis", mel_basis)
+         self.register_buffer("window", window)
+
+         self.hop_length = hop_length
+         self.win_length = int(sampling_rate * frame_size)
+         self.sampling_rate = sampling_rate
+         self.n_mel_channels = n_mel_channels
+
+     def forward(self, audio):
+         fft = torch.stft(
+             audio.squeeze(1),
+             n_fft=self.n_fft,
+             hop_length=self.hop_length,
+             win_length=self.win_length,
+             window=self.window,
+             center=True,
+         )
+         real_part, imag_part = fft.unbind(-1)
+         magnitude = torch.sqrt(torch.clamp(real_part ** 2 + imag_part ** 2, min=1e-5))
+         mel_output = torch.matmul(self.mel_basis, magnitude)
+
+         log_mel_spec = 20 * torch.log10(torch.clamp(mel_output, min=1e-5)) - 20
+         norm_mel = (log_mel_spec + 115.) / 115.
+         mel_comp = torch.clamp(norm_mel * 8. - 4., -4., 4.)
+
+         return mel_comp
+
+
+ class ResnetBlock(nn.Module):
+     def __init__(self, dim, dilation=1, dim_in=None):
+         super().__init__()
+         if dim_in is None:
+             dim_in = dim
+
+         self.block = nn.Sequential(
+             nn.LeakyReLU(0.2),
+             nn.ReflectionPad1d(dilation),
+             WNConv1d(dim_in, dim, kernel_size=3, dilation=dilation),
+             nn.LeakyReLU(0.2),
+             WNConv1d(dim, dim, kernel_size=1),
+         )
+         self.shortcut = WNConv1d(dim_in, dim, kernel_size=1)
+
+     def forward(self, x):
+         return self.shortcut(x) + self.block(x)
+
+
+ '''
+ Follows the HiFi-GAN v2 structure (https://arxiv.org/pdf/2010.05646.pdf):
+ the multi-scale part differs mainly in kernel_size, with three parallel convolution modules;
+ each module uses a different series of dilation sizes internally,
+ interleaved with ordinary convolution layers without dilation.
+ '''
+
+
+ class ResBlockMRFV2(torch.nn.Module):
+     def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
+         super(ResBlockMRFV2, self).__init__()
+         self.convs1 = nn.ModuleList([
+             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
+                                padding=get_padding(kernel_size, dilation[0]))),
+             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
+                                padding=get_padding(kernel_size, dilation[1]))),
+             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
+                                padding=get_padding(kernel_size, dilation[2])))
+         ])
+         self.convs1.apply(init_weights)
+
+         self.convs2 = nn.ModuleList([
+             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+                                padding=get_padding(kernel_size, 1))),
+             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+                                padding=get_padding(kernel_size, 1))),
+             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+                                padding=get_padding(kernel_size, 1)))
+         ])
+         self.convs2.apply(init_weights)
+
+     def forward(self, x):
+         for c1, c2 in zip(self.convs1, self.convs2):
+             xt = F.leaky_relu(x, 0.2)
+             xt = c1(xt)
+             xt = F.leaky_relu(xt, 0.2)
+             xt = c2(xt)
+             x = xt + x
+         return x
+
+     def remove_weight_norm(self):
+         for l in self.convs1:
+             remove_weight_norm(l)
+         for l in self.convs2:
+             remove_weight_norm(l)
+
+
+ class ResBlockMRFV2Inter(torch.nn.Module):
+     def __init__(self, channels, kernel_size=3):
+         super(ResBlockMRFV2Inter, self).__init__()
+         self.block1 = ResBlockMRFV2(channels)
+         self.block2 = ResBlockMRFV2(channels, 7)
+         self.block3 = ResBlockMRFV2(channels, 11)
+
+     def forward(self, x):
+         xs = self.block1(x)
+         xs += self.block2(x)
+         xs += self.block3(x)
+         x = xs / 3
+         return x
+
+
+ class Generator(nn.Module):
+     def __init__(self, input_size_, ngf, n_residual_layers, num_band, args, ratios=[5, 5, 4, 3], onnx_export=False,
+                  device='cpu'):
+         super().__init__()
+         self.hop_length = args.frame_shift
+         self.args = args
+         self.onnx_export = onnx_export
+
+         # ------------- Define upsample layers ----------------
+         mult = int(2 ** len(ratios))
+         model_up = []
+         input_size = input_size_
+         model_up += [
+             nn.ReflectionPad1d(3),
+             WNConv1d(input_size, mult * ngf, kernel_size=7, padding=0),
+         ]
+
+         # Upsample to raw audio scale
+         for i, r in enumerate(ratios):
+             model_up += [Upsample(mult * ngf, r)]
+             model_up += [ResBlockMRFV2Inter(mult * ngf // 2)]
+             mult //= 2
+
+         model_up += [
+             nn.LeakyReLU(0.2),
+             nn.ReflectionPad1d(3),
+             WNConv1d(ngf, num_band, kernel_size=7, padding=0),
+             nn.Tanh(),
+         ]
+         if not args.use_tanh:
+             model_up[-1] = nn.Conv1d(num_band, num_band, 1)
+             model_up[-2].apply(weights_zero_init)
+
+         self.model_up = nn.Sequential(*model_up)
+
+         self.apply(weights_init)
+
+     def forward(self, mel, step=None):
+         # mel input: (batch_size, seq_num, 80)
+         if self.onnx_export:
+             mel = mel.transpose(1, 2)
+             # on onnx, for engineering, mel input: (batch_size, 80, seq_num)
+
+         # Between Down and up
+         x = mel
+
+         # Upsample pipeline
+         cnt_after_upsample = 0
+
+         for i, m in enumerate(self.model_up):
+             x = m(x)
+
+             if type(m) == Upsample:
+                 cnt_after_upsample += 1
+
+         return x
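
`Generator` reads only two attributes off the `args` namespace it is given, `frame_shift` and `use_tanh`; the rest of the upsampling stack is fixed by the constructor arguments, with the total upsampling factor being the product of `ratios`. A hypothetical standalone instantiation is sketched below, assuming the classes above are in scope; all sizes and config values here are illustrative, not the ones used by the MegaTTS 3 checkpoint.

```python
import argparse
import torch

args = argparse.Namespace(frame_shift=300, use_tanh=True)    # assumed minimal config
gen = Generator(input_size_=80, ngf=32, n_residual_layers=4,
                num_band=1, args=args, ratios=[5, 4, 4, 3])   # 5*4*4*3 = 240x upsampling

feat = torch.randn(1, 80, 40)     # (batch, input_size_, frames)
wav = gen(feat)                   # (1, 1, 40 * 240) = (1, 1, 9600) waveform-rate output
```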

xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/seanet_encoder.py
@@ -0,0 +1,38 @@
+ # Copyright 2025 ByteDance and/or its affiliates.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import List
+
+ import torch
+ from torch import nn
+ from tts.modules.wavvae.encoder.common_modules.seanet import SEANetEncoder
+
+ class Encoder(nn.Module):
+     def __init__(
+         self,
+         dowmsamples: List[int] = [6, 5, 5, 4, 2],
+     ):
+         super().__init__()
+
+         # breakpoint()
+         self.frame_rate = 25  # not use
+         self.encoder = SEANetEncoder(causal=False, n_residual_layers=1, norm='weight_norm', pad_mode='reflect', lstm=2,
+                                      dimension=512, channels=1, n_filters=32, ratios=dowmsamples, activation='ELU',
+                                      kernel_size=7, residual_kernel_size=3, last_kernel_size=7, dilation_base=2,
+                                      true_skip=False, compress=2)
+
+     def forward(self, audio: torch.Tensor):
+         audio = audio.unsqueeze(1)  # audio(16,24000)
+         emb = self.encoder(audio)
+         return emb
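
This `Encoder` wraps the SEANet encoder bundled in the same thirdparty package and only adds a channel dimension to the raw waveform. A brief sketch, assuming the vendored `tts` package is importable so that `SEANetEncoder` resolves; with the default ratios the total hop is 6*5*5*4*2 = 1200 samples.

```python
import torch

enc = Encoder()                    # default ratios [6, 5, 5, 4, 2]
audio = torch.randn(2, 24000)      # one second of 24 kHz audio per batch item (assumed sample rate)
emb = enc(audio)                   # roughly (2, 512, 20): 512-dim frames at 24000 / 1200 = 20 Hz
```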

xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/wavvae_v3.py
@@ -0,0 +1,60 @@
+ # Copyright 2025 ByteDance and/or its affiliates.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import argparse
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+
+ from tts.modules.wavvae.decoder.seanet_encoder import Encoder
+ from tts.modules.wavvae.decoder.diag_gaussian import DiagonalGaussianDistribution
+ from tts.modules.wavvae.decoder.hifigan_modules import Generator, Upsample
+
+
+ class WavVAE_V3(nn.Module):
+     def __init__(self, hparams=None):
+         super().__init__()
+         self.encoder = Encoder(dowmsamples=[6, 5, 4, 4, 2])
+         self.proj_to_z = nn.Linear(512, 64)
+         self.proj_to_decoder = nn.Linear(32, 320)
+
+         config_path = hparams['melgan_config']
+         args = argparse.Namespace()
+         args.__dict__.update(config_path)
+         self.latent_upsampler = Upsample(320, 4)
+         self.decoder = Generator(
+             input_size_=160, ngf=128, n_residual_layers=4,
+             num_band=1, args=args, ratios=[5,4,4,3])
+
+     ''' encode waveform into 25 hz latent representation '''
+     def encode_latent(self, audio):
+         posterior = self.encode(audio)
+         latent = posterior.sample().permute(0, 2, 1)  # (b,t,latent_channel)
+         return latent
+
+     def encode(self, audio):
+         x = self.encoder(audio).permute(0, 2, 1)
+         x = self.proj_to_z(x).permute(0, 2, 1)
+         poseterior = DiagonalGaussianDistribution(x)
+         return poseterior
+
+     def decode(self, latent):
+         latent = self.proj_to_decoder(latent).permute(0, 2, 1)
+         return self.decoder(self.latent_upsampler(latent))
+
+     def forward(self, audio):
+         posterior = self.encode(audio)
+         latent = posterior.sample().permute(0, 2, 1)  # (b, t, latent_channel)
+         recon_wav = self.decode(latent)
+         return recon_wav, posterior
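
`WavVAE_V3` ties the pieces above together: SEANet encoder, a 32-channel diagonal-Gaussian latent at a low frame rate, and the HiFi-GAN-style `Generator` as decoder. Its constructor reads `hparams['melgan_config']` and splats it onto an `argparse.Namespace`, so at minimum that dict must carry the `frame_shift` and `use_tanh` keys consumed by `Generator`. A hedged round-trip sketch follows; the config values and the 24 kHz sample rate are assumptions, and real MegaTTS 3 checkpoints ship their own hparams.

```python
import torch

hparams = {"melgan_config": {"frame_shift": 300, "use_tanh": True}}   # assumed minimal config
vae = WavVAE_V3(hparams=hparams)

wav = torch.randn(1, 24000)            # (batch, samples): one second of assumed 24 kHz audio
latent = vae.encode_latent(wav)        # (1, 25, 32): ~25 Hz latent sequence
recon, posterior = vae(wav)            # reconstructed waveform plus the Gaussian posterior
```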