xinference 1.5.0.post2__py3-none-any.whl → 1.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +107 -11
- xinference/client/restful/restful_client.py +51 -11
- xinference/constants.py +5 -1
- xinference/core/media_interface.py +758 -0
- xinference/core/model.py +49 -9
- xinference/core/supervisor.py +1 -1
- xinference/core/utils.py +1 -1
- xinference/core/worker.py +33 -39
- xinference/deploy/cmdline.py +17 -0
- xinference/deploy/utils.py +0 -3
- xinference/model/audio/__init__.py +16 -27
- xinference/model/audio/core.py +2 -1
- xinference/model/audio/cosyvoice.py +4 -2
- xinference/model/audio/model_spec.json +63 -46
- xinference/model/audio/model_spec_modelscope.json +31 -14
- xinference/model/embedding/__init__.py +16 -24
- xinference/model/image/__init__.py +15 -25
- xinference/model/llm/__init__.py +40 -115
- xinference/model/llm/core.py +29 -6
- xinference/model/llm/llama_cpp/core.py +30 -347
- xinference/model/llm/llm_family.json +1674 -2203
- xinference/model/llm/llm_family.py +71 -7
- xinference/model/llm/llm_family_csghub.json +0 -32
- xinference/model/llm/llm_family_modelscope.json +1838 -2016
- xinference/model/llm/llm_family_openmind_hub.json +19 -325
- xinference/model/llm/lmdeploy/core.py +7 -2
- xinference/model/llm/mlx/core.py +23 -7
- xinference/model/llm/reasoning_parser.py +281 -5
- xinference/model/llm/sglang/core.py +39 -11
- xinference/model/llm/transformers/chatglm.py +9 -2
- xinference/model/llm/transformers/cogagent.py +10 -12
- xinference/model/llm/transformers/cogvlm2.py +6 -3
- xinference/model/llm/transformers/cogvlm2_video.py +3 -6
- xinference/model/llm/transformers/core.py +58 -60
- xinference/model/llm/transformers/deepseek_v2.py +4 -2
- xinference/model/llm/transformers/deepseek_vl.py +10 -4
- xinference/model/llm/transformers/deepseek_vl2.py +9 -4
- xinference/model/llm/transformers/gemma3.py +4 -5
- xinference/model/llm/transformers/glm4v.py +3 -21
- xinference/model/llm/transformers/glm_edge_v.py +3 -20
- xinference/model/llm/transformers/intern_vl.py +3 -6
- xinference/model/llm/transformers/internlm2.py +1 -1
- xinference/model/llm/transformers/minicpmv25.py +4 -2
- xinference/model/llm/transformers/minicpmv26.py +5 -3
- xinference/model/llm/transformers/omnilmm.py +1 -1
- xinference/model/llm/transformers/opt.py +1 -1
- xinference/model/llm/transformers/ovis2.py +302 -0
- xinference/model/llm/transformers/qwen-omni.py +8 -1
- xinference/model/llm/transformers/qwen2_audio.py +3 -1
- xinference/model/llm/transformers/qwen2_vl.py +5 -1
- xinference/model/llm/transformers/qwen_vl.py +5 -2
- xinference/model/llm/utils.py +96 -45
- xinference/model/llm/vllm/core.py +108 -24
- xinference/model/llm/vllm/distributed_executor.py +8 -7
- xinference/model/llm/vllm/xavier/allocator.py +1 -1
- xinference/model/llm/vllm/xavier/block_manager.py +1 -1
- xinference/model/llm/vllm/xavier/block_tracker.py +3 -3
- xinference/model/llm/vllm/xavier/executor.py +1 -1
- xinference/model/llm/vllm/xavier/test/test_xavier.py +2 -11
- xinference/model/rerank/__init__.py +13 -24
- xinference/model/video/__init__.py +15 -25
- xinference/model/video/core.py +3 -3
- xinference/model/video/diffusers.py +157 -13
- xinference/model/video/model_spec.json +100 -0
- xinference/model/video/model_spec_modelscope.json +104 -0
- xinference/thirdparty/cosyvoice/bin/average_model.py +5 -4
- xinference/thirdparty/cosyvoice/bin/export_jit.py +50 -20
- xinference/thirdparty/cosyvoice/bin/export_onnx.py +136 -51
- xinference/thirdparty/cosyvoice/bin/inference.py +15 -5
- xinference/thirdparty/cosyvoice/bin/train.py +7 -2
- xinference/thirdparty/cosyvoice/cli/cosyvoice.py +72 -52
- xinference/thirdparty/cosyvoice/cli/frontend.py +58 -58
- xinference/thirdparty/cosyvoice/cli/model.py +140 -155
- xinference/thirdparty/cosyvoice/dataset/processor.py +9 -5
- xinference/thirdparty/cosyvoice/flow/decoder.py +656 -54
- xinference/thirdparty/cosyvoice/flow/flow.py +69 -11
- xinference/thirdparty/cosyvoice/flow/flow_matching.py +167 -63
- xinference/thirdparty/cosyvoice/flow/length_regulator.py +1 -0
- xinference/thirdparty/cosyvoice/hifigan/discriminator.py +91 -1
- xinference/thirdparty/cosyvoice/hifigan/f0_predictor.py +4 -1
- xinference/thirdparty/cosyvoice/hifigan/generator.py +4 -1
- xinference/thirdparty/cosyvoice/hifigan/hifigan.py +2 -2
- xinference/thirdparty/cosyvoice/llm/llm.py +198 -18
- xinference/thirdparty/cosyvoice/transformer/embedding.py +12 -4
- xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +124 -21
- xinference/thirdparty/cosyvoice/utils/class_utils.py +13 -0
- xinference/thirdparty/cosyvoice/utils/common.py +1 -1
- xinference/thirdparty/cosyvoice/utils/file_utils.py +40 -2
- xinference/thirdparty/cosyvoice/utils/frontend_utils.py +7 -0
- xinference/thirdparty/cosyvoice/utils/mask.py +4 -0
- xinference/thirdparty/cosyvoice/utils/train_utils.py +5 -1
- xinference/thirdparty/matcha/hifigan/xutils.py +3 -3
- xinference/types.py +2 -71
- xinference/web/ui/build/asset-manifest.json +6 -6
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/css/{main.0f6523be.css → main.337afe76.css} +2 -2
- xinference/web/ui/build/static/css/main.337afe76.css.map +1 -0
- xinference/web/ui/build/static/js/main.ae579a97.js +3 -0
- xinference/web/ui/build/static/js/main.ae579a97.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/0196a4b09e3264614e54360d5f832c46b31d964ec58296765ebff191ace6adbf.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/12e02ee790dbf57ead09a241a93bb5f893393aa36628ca741d44390e836a103f.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/18fa271456b31cded36c05c4c71c6b2b1cf4e4128c1e32f0e45d8b9f21764397.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/2fdc61dcb6a9d1fbcb44be592d0e87d8c3f21297a7327559ef5345665f8343f7.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/3d596a3e8dd6430d7ce81d164e32c31f8d47cfa5f725c328a298754d78563e14.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/5c08e2cd07809ed3e41486b16652253404cbb63a3ff8d0366ee50f57e2413cea.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/6798e126f3bc5f95a4c16a9c2ad52ffe77970c62406d83e20604dfda7ffd2247.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/8472e58a31720892d534f3febda31f746b25ec4aa60787eef34217b074e67965.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/b617f7d21a95045fc57b26a9373551740f1978a826134cbf705c3a1bf8714a93.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/c1506cb142151366074975f30fa1ff9cd6e5e978b62a4b074dfc16fe08d70d75.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/c5c7c2cd1b863ce41adff2c4737bba06eef3a1acf28288cb83d992060f6b8923.json +1 -0
- xinference/web/ui/src/locales/en.json +7 -4
- xinference/web/ui/src/locales/zh.json +7 -4
- {xinference-1.5.0.post2.dist-info → xinference-1.6.0.dist-info}/METADATA +56 -36
- {xinference-1.5.0.post2.dist-info → xinference-1.6.0.dist-info}/RECORD +120 -121
- {xinference-1.5.0.post2.dist-info → xinference-1.6.0.dist-info}/WHEEL +1 -1
- xinference/core/image_interface.py +0 -377
- xinference/model/llm/transformers/compression.py +0 -258
- xinference/model/llm/transformers/yi_vl.py +0 -239
- xinference/thirdparty/cosyvoice/bin/export_trt.sh +0 -9
- xinference/web/ui/build/static/css/main.0f6523be.css.map +0 -1
- xinference/web/ui/build/static/js/main.4b67a723.js +0 -3
- xinference/web/ui/build/static/js/main.4b67a723.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/0f0adb2283a8f469d097a7a0ebb754624fa52414c83b83696c41f2e6a737ceda.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/51709f5d3e53bcf19e613662ef9b91fb9174942c5518987a248348dd4e1e0e02.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/8157db83995c671eb57abc316c337f867d1dc63fb83520bb4ff351fee57dcce2.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/8f9af2979e45d4648f0cfae108363e58ee421c29a9d4e7329b6f06d9adfd4133.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/9c8b1a86e7c65b2b2599a205e30920652d6c2105f926508ef5bcf29a3ef4ce76.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/b8551e9775a01b28ae674125c688febe763732ea969ae344512e64ea01bf632e.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/e4ba658c6b3b0490910acdae0c535a892257efb61539a24adf8038fc653bd22f.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/efe7cd132c27a8f9fd5352a394c491fd5fb0da0348cf9fcbd923164a32365eab.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/f04f666b77b44d7be3e16034d6b0074de2ba9c254f1fae15222b3148608fa8b3.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/f199e8173f6409a5802ed44acb95f218388131136504b2e9132129e150c92f9a.json +0 -1
- /xinference/web/ui/build/static/js/{main.4b67a723.js.LICENSE.txt → main.ae579a97.js.LICENSE.txt} +0 -0
- {xinference-1.5.0.post2.dist-info → xinference-1.6.0.dist-info}/entry_points.txt +0 -0
- {xinference-1.5.0.post2.dist-info → xinference-1.6.0.dist-info}/licenses/LICENSE +0 -0
- {xinference-1.5.0.post2.dist-info → xinference-1.6.0.dist-info}/top_level.txt +0 -0
xinference/thirdparty/cosyvoice/flow/flow.py

@@ -91,6 +91,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
         conds = conds.transpose(1, 2)
 
         mask = (~make_pad_mask(feat_len)).to(h)
+        # NOTE this is unnecessary, feat/h already same shape
         feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
         loss, _ = self.decoder.compute_loss(
             feat.transpose(1, 2).contiguous(),

@@ -116,7 +117,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
         embedding = F.normalize(embedding, dim=1)
         embedding = self.spk_embed_affine_layer(embedding)
 
-        # concat
+        # concat speech token and prompt speech token
         token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
         token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
         mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)

@@ -129,7 +130,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
         h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate)
 
         # get conditions
-        conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device)
+        conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
         conds[:, :mel_len1] = prompt_feat
         conds = conds.transpose(1, 2)
 

@@ -141,11 +142,11 @@ class MaskedDiffWithXvec(torch.nn.Module):
             cond=conds,
             n_timesteps=10,
             prompt_len=mel_len1,
-
+            cache=flow_cache
         )
         feat = feat[:, :, mel_len1:]
         assert feat.shape[2] == mel_len2
-        return feat, flow_cache
+        return feat.float(), flow_cache
 
 
 class CausalMaskedDiffWithXvec(torch.nn.Module):

@@ -186,6 +187,53 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
         self.token_mel_ratio = token_mel_ratio
         self.pre_lookahead_len = pre_lookahead_len
 
+    def forward(
+            self,
+            batch: dict,
+            device: torch.device,
+    ) -> Dict[str, Optional[torch.Tensor]]:
+        token = batch['speech_token'].to(device)
+        token_len = batch['speech_token_len'].to(device)
+        feat = batch['speech_feat'].to(device)
+        feat_len = batch['speech_feat_len'].to(device)
+        embedding = batch['embedding'].to(device)
+
+        # NOTE unified training, static_chunk_size > 0 or = 0
+        streaming = True if random.random() < 0.5 else False
+
+        # xvec projection
+        embedding = F.normalize(embedding, dim=1)
+        embedding = self.spk_embed_affine_layer(embedding)
+
+        # concat text and prompt_text
+        mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
+        token = self.input_embedding(torch.clamp(token, min=0)) * mask
+
+        # text encode
+        h, h_lengths = self.encoder(token, token_len, streaming=streaming)
+        h = self.encoder_proj(h)
+
+        # get conditions
+        feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
+        conds = torch.zeros(feat.shape, device=token.device)
+        for i, j in enumerate(feat_len):
+            if random.random() < 0.5:
+                continue
+            index = random.randint(0, int(0.3 * j))
+            conds[i, :index] = feat[i, :index]
+        conds = conds.transpose(1, 2)
+
+        mask = (~make_pad_mask(h_lengths.sum(dim=-1).squeeze(dim=1))).to(h)
+        loss, _ = self.decoder.compute_loss(
+            feat.transpose(1, 2).contiguous(),
+            mask.unsqueeze(1),
+            h.transpose(1, 2).contiguous(),
+            embedding,
+            cond=conds,
+            streaming=streaming,
+        )
+        return {'loss': loss}
+
     @torch.inference_mode()
     def inference(self,
                   token,

@@ -195,6 +243,7 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
                   prompt_feat,
                   prompt_feat_len,
                   embedding,
+                  cache,
                   finalize):
         assert token.shape[0] == 1
         # xvec projection

@@ -207,25 +256,34 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
         token = self.input_embedding(torch.clamp(token, min=0)) * mask
 
         # text encode
-
-
-
+        if finalize is True:
+            h, h_lengths, encoder_cache = self.encoder.forward_chunk(token, token_len, **cache['encoder_cache'])
+        else:
+            token, context = token[:, :-self.pre_lookahead_len], token[:, -self.pre_lookahead_len:]
+            h, h_lengths, encoder_cache = self.encoder.forward_chunk(token, token_len, context=context, **cache['encoder_cache'])
+        cache['encoder_cache']['offset'] = encoder_cache[0]
+        cache['encoder_cache']['pre_lookahead_layer_conv2_cache'] = encoder_cache[1]
+        cache['encoder_cache']['encoders_kv_cache'] = encoder_cache[2]
+        cache['encoder_cache']['upsample_offset'] = encoder_cache[3]
+        cache['encoder_cache']['upsample_conv_cache'] = encoder_cache[4]
+        cache['encoder_cache']['upsample_kv_cache'] = encoder_cache[5]
         mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
         h = self.encoder_proj(h)
 
         # get conditions
-        conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device)
+        conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
         conds[:, :mel_len1] = prompt_feat
         conds = conds.transpose(1, 2)
 
         mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
-        feat,
+        feat, cache['decoder_cache'] = self.decoder(
             mu=h.transpose(1, 2).contiguous(),
             mask=mask.unsqueeze(1),
             spks=embedding,
             cond=conds,
-            n_timesteps=10
+            n_timesteps=10,
+            cache=cache['decoder_cache']
         )
         feat = feat[:, :, mel_len1:]
         assert feat.shape[2] == mel_len2
-        return feat,
+        return feat.float(), cache
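The `.to(h.dtype)` and `feat.float()` changes above keep the condition buffer in the same precision as the encoder output and hand a full-precision mel back to the caller. A minimal standalone sketch of the mismatch being avoided (the shapes and the fp16 setting are assumptions for illustration, not taken from the diff):

```python
import torch

h = torch.randn(1, 120, 80, dtype=torch.float16)           # encoder output, e.g. when the flow runs in fp16
prompt_feat = torch.randn(1, 40, 80, dtype=torch.float16)   # prompt mel features

conds_old = torch.zeros([1, 120, 80])                       # defaults to float32
conds_new = torch.zeros([1, 120, 80]).to(h.dtype)           # matches the encoder dtype, as in the new code
conds_new[:, :40] = prompt_feat

assert conds_old.dtype == torch.float32                     # would get mixed with fp16 tensors downstream
assert conds_new.dtype == h.dtype                           # consistent input for an fp16 estimator
```

Returning `feat.float()` appears to serve the opposite boundary: whatever precision the decoder ran in, the caller (and ultimately the vocoder) receives float32.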
xinference/thirdparty/cosyvoice/flow/flow_matching.py

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import
+import threading
 import torch
 import torch.nn.functional as F
 from matcha.models.components.flow_matching import BASECFM

@@ -31,9 +31,10 @@ class ConditionalCFM(BASECFM):
         in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
         # Just change the architecture of the estimator here
         self.estimator = estimator
+        self.lock = threading.Lock()
 
     @torch.inference_mode()
-    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0,
+    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, cache=torch.zeros(1, 80, 0, 2)):
         """Forward diffusion
 
         Args:

@@ -52,20 +53,20 @@ class ConditionalCFM(BASECFM):
             shape: (batch_size, n_feats, mel_timesteps)
         """
 
-        z = torch.randn_like(mu) * temperature
-        cache_size =
+        z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature
+        cache_size = cache.shape[2]
         # fix prompt and overlap part mu and z
         if cache_size != 0:
-            z[:, :, :cache_size] =
-            mu[:, :, :cache_size] =
+            z[:, :, :cache_size] = cache[:, :, :, 0]
+            mu[:, :, :cache_size] = cache[:, :, :, 1]
         z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -34:]], dim=2)
         mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], dim=2)
-
+        cache = torch.stack([z_cache, mu_cache], dim=-1)
 
         t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
         if self.t_scheduler == 'cosine':
             t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
-        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond),
+        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), cache
 
     def solve_euler(self, x, t_span, mu, mask, spks, cond):
         """

@@ -89,36 +90,29 @@ class ConditionalCFM(BASECFM):
         # Or in future might add like a return_all_steps flag
         sol = []
 
-
-
-
-
-
-
-
-            cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
-        else:
-            x_in, mask_in, mu_in, t_in, spks_in, cond_in = x, mask, mu, t, spks, cond
+        # Do not use concat, it may cause memory format changed and trt infer with wrong results!
+        x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
+        mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
+        mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
+        t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
+        spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
+        cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
         for step in range(1, len(t_span)):
             # Classifier-Free Guidance inference introduced in VoiceBox
-
-
-
-
-
-
-            cond_in[0] = cond
-            else:
-                x_in, mask_in, mu_in, t_in, spks_in, cond_in = x, mask, mu, t, spks, cond
+            x_in[:] = x
+            mask_in[:] = mask
+            mu_in[0] = mu
+            t_in[:] = t.unsqueeze(0)
+            spks_in[0] = spks
+            cond_in[0] = cond
             dphi_dt = self.forward_estimator(
                 x_in, mask_in,
                 mu_in, t_in,
                 spks_in,
                 cond_in
             )
-
-
-            dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
+            dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
+            dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
             x = x + dt * dphi_dt
             t = t + dt
             sol.append(x)

@@ -129,36 +123,26 @@ class ConditionalCFM(BASECFM):
 
     def forward_estimator(self, x, mask, mu, t, spks, cond):
         if isinstance(self.estimator, torch.nn.Module):
-            return self.estimator
-        elif isinstance(self.estimator, onnxruntime.InferenceSession):
-            ort_inputs = {
-                'x': x.cpu().numpy(),
-                'mask': mask.cpu().numpy(),
-                'mu': mu.cpu().numpy(),
-                't': t.cpu().numpy(),
-                'spks': spks.cpu().numpy(),
-                'cond': cond.cpu().numpy()
-            }
-            output = self.estimator.run(None, ort_inputs)[0]
-            return torch.tensor(output, dtype=x.dtype, device=x.device)
+            return self.estimator(x, mask, mu, t, spks, cond)
         else:
-            self.
-
-
-
-
-
-
-
-
-
-
-
-
-
+            with self.lock:
+                self.estimator.set_input_shape('x', (2, 80, x.size(2)))
+                self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
+                self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
+                self.estimator.set_input_shape('t', (2,))
+                self.estimator.set_input_shape('spks', (2, 80))
+                self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
+                # run trt engine
+                assert self.estimator.execute_v2([x.contiguous().data_ptr(),
+                                                  mask.contiguous().data_ptr(),
+                                                  mu.contiguous().data_ptr(),
+                                                  t.contiguous().data_ptr(),
+                                                  spks.contiguous().data_ptr(),
+                                                  cond.contiguous().data_ptr(),
+                                                  x.data_ptr()]) is True
         return x
 
-    def compute_loss(self, x1, mask, mu, spks=None, cond=None):
+    def compute_loss(self, x1, mask, mu, spks=None, cond=None, streaming=False):
         """Computes diffusion loss
 
         Args:

@@ -195,7 +179,7 @@ class ConditionalCFM(BASECFM):
             spks = spks * cfg_mask.view(-1, 1)
             cond = cond * cfg_mask.view(-1, 1, 1)
 
-        pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
+        pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond, streaming=streaming)
         loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
         return loss, y
 

@@ -206,7 +190,7 @@ class CausalConditionalCFM(ConditionalCFM):
         self.rand_noise = torch.randn([1, 80, 50 * 300])
 
     @torch.inference_mode()
-    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
+    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, cache={}):
         """Forward diffusion
 
         Args:

@@ -225,11 +209,131 @@ class CausalConditionalCFM(ConditionalCFM):
             shape: (batch_size, n_feats, mel_timesteps)
         """
 
-
-
-
+        offset = cache.pop('offset')
+        z = self.rand_noise[:, :, :mu.size(2) + offset].to(mu.device).to(mu.dtype) * temperature
+        z = z[:, :, offset:]
+        offset += mu.size(2)
         # fix prompt and overlap part mu and z
         t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
         if self.t_scheduler == 'cosine':
             t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
-
+        mel, cache = self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond, cache=cache)
+        cache['offset'] = offset
+        return mel, cache
+
+    def solve_euler(self, x, t_span, mu, mask, spks, cond, cache):
+        """
+        Fixed euler solver for ODEs.
+        Args:
+            x (torch.Tensor): random noise
+            t_span (torch.Tensor): n_timesteps interpolated
+                shape: (n_timesteps + 1,)
+            mu (torch.Tensor): output of encoder
+                shape: (batch_size, n_feats, mel_timesteps)
+            mask (torch.Tensor): output_mask
+                shape: (batch_size, 1, mel_timesteps)
+            spks (torch.Tensor, optional): speaker ids. Defaults to None.
+                shape: (batch_size, spk_emb_dim)
+            cond: Not used but kept for future purposes
+        """
+        t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
+        t = t.unsqueeze(dim=0)
+
+        # I am storing this because I can later plot it by putting a debugger here and saving it to a file
+        # Or in future might add like a return_all_steps flag
+        sol = []
+
+        # Do not use concat, it may cause memory format changed and trt infer with wrong results!
+        x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
+        mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
+        mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
+        t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
+        spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
+        cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
+        flow_cache_size = cache['down_blocks_kv_cache'].shape[4]
+        for step in range(1, len(t_span)):
+            # Classifier-Free Guidance inference introduced in VoiceBox
+            x_in[:] = x
+            mask_in[:] = mask
+            mu_in[0] = mu
+            t_in[:] = t.unsqueeze(0)
+            spks_in[0] = spks
+            cond_in[0] = cond
+            cache_step = {k: v[step - 1] for k, v in cache.items()}
+            dphi_dt, cache_step = self.forward_estimator(
+                x_in, mask_in,
+                mu_in, t_in,
+                spks_in,
+                cond_in,
+                cache_step
+            )
+            # NOTE if smaller than flow_cache_size, means last chunk, no need to cache
+            if flow_cache_size != 0 and x_in.shape[2] >= flow_cache_size:
+                cache['down_blocks_conv_cache'][step - 1] = cache_step[0]
+                cache['down_blocks_kv_cache'][step - 1] = cache_step[1][:, :, :, -flow_cache_size:]
+                cache['mid_blocks_conv_cache'][step - 1] = cache_step[2]
+                cache['mid_blocks_kv_cache'][step - 1] = cache_step[3][:, :, :, -flow_cache_size:]
+                cache['up_blocks_conv_cache'][step - 1] = cache_step[4]
+                cache['up_blocks_kv_cache'][step - 1] = cache_step[5][:, :, :, -flow_cache_size:]
+                cache['final_blocks_conv_cache'][step - 1] = cache_step[6]
+            dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
+            dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
+            x = x + dt * dphi_dt
+            t = t + dt
+            sol.append(x)
+            if step < len(t_span) - 1:
+                dt = t_span[step + 1] - t
+        return sol[-1].float(), cache
+
+    def forward_estimator(self, x, mask, mu, t, spks, cond, cache):
+        if isinstance(self.estimator, torch.nn.Module):
+            x, cache1, cache2, cache3, cache4, cache5, cache6, cache7 = self.estimator.forward_chunk(x, mask, mu, t, spks, cond, **cache)
+            cache = (cache1, cache2, cache3, cache4, cache5, cache6, cache7)
+        else:
+            with self.lock:
+                self.estimator.set_input_shape('x', (2, 80, x.size(2)))
+                self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
+                self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
+                self.estimator.set_input_shape('t', (2,))
+                self.estimator.set_input_shape('spks', (2, 80))
+                self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
+                self.estimator.set_input_shape('down_blocks_conv_cache', cache['down_blocks_conv_cache'].shape)
+                self.estimator.set_input_shape('down_blocks_kv_cache', cache['down_blocks_kv_cache'].shape)
+                self.estimator.set_input_shape('mid_blocks_conv_cache', cache['mid_blocks_conv_cache'].shape)
+                self.estimator.set_input_shape('mid_blocks_kv_cache', cache['mid_blocks_kv_cache'].shape)
+                self.estimator.set_input_shape('up_blocks_conv_cache', cache['up_blocks_conv_cache'].shape)
+                self.estimator.set_input_shape('up_blocks_kv_cache', cache['up_blocks_kv_cache'].shape)
+                self.estimator.set_input_shape('final_blocks_conv_cache', cache['final_blocks_conv_cache'].shape)
+                # run trt engine
+                down_blocks_kv_cache_out = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x)
+                mid_blocks_kv_cache_out = torch.zeros(12, 4, 2, x.size(2), 512, 2).to(x)
+                up_blocks_kv_cache_out = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x)
+                assert self.estimator.execute_v2([x.contiguous().data_ptr(),
+                                                  mask.contiguous().data_ptr(),
+                                                  mu.contiguous().data_ptr(),
+                                                  t.contiguous().data_ptr(),
+                                                  spks.contiguous().data_ptr(),
+                                                  cond.contiguous().data_ptr(),
+                                                  cache['down_blocks_conv_cache'].contiguous().data_ptr(),
+                                                  cache['down_blocks_kv_cache'].contiguous().data_ptr(),
+                                                  cache['mid_blocks_conv_cache'].contiguous().data_ptr(),
+                                                  cache['mid_blocks_kv_cache'].contiguous().data_ptr(),
+                                                  cache['up_blocks_conv_cache'].contiguous().data_ptr(),
+                                                  cache['up_blocks_kv_cache'].contiguous().data_ptr(),
+                                                  cache['final_blocks_conv_cache'].contiguous().data_ptr(),
+                                                  x.data_ptr(),
+                                                  cache['down_blocks_conv_cache'].data_ptr(),
+                                                  down_blocks_kv_cache_out.data_ptr(),
+                                                  cache['mid_blocks_conv_cache'].data_ptr(),
+                                                  mid_blocks_kv_cache_out.data_ptr(),
+                                                  cache['up_blocks_conv_cache'].data_ptr(),
+                                                  up_blocks_kv_cache_out.data_ptr(),
+                                                  cache['final_blocks_conv_cache'].data_ptr()]) is True
+            cache = (cache['down_blocks_conv_cache'],
+                     down_blocks_kv_cache_out,
+                     cache['mid_blocks_conv_cache'],
+                     mid_blocks_kv_cache_out,
+                     cache['up_blocks_conv_cache'],
+                     up_blocks_kv_cache_out,
+                     cache['final_blocks_conv_cache'])
+        return x, cache
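The rewritten `solve_euler` above folds classifier-free guidance into a single estimator call: the conditional inputs occupy batch row 0, row 1 is left at zero as the unconditional branch, and the two outputs are recombined with `inference_cfg_rate`. Pre-allocating the two-row buffers instead of `torch.concat` is the stated workaround for TensorRT memory-format issues. A self-contained sketch of just that pattern (the toy estimator and the guidance-rate value are assumptions, not part of the diff):

```python
import torch

def cfg_step(estimator, x, mask, mu, t, spks, cond, cfg_rate=0.7):
    # Assumes batch size 1, as the inference path upstream does (assert token.shape[0] == 1).
    x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
    mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
    mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
    t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
    spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
    cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)

    x_in[:] = x          # noisy state shared by both branches
    mask_in[:] = mask
    mu_in[0] = mu        # row 0: conditioned on encoder output, speaker and prompt
    t_in[:] = t
    spks_in[0] = spks
    cond_in[0] = cond    # row 1 stays zero: the unconditional branch
    out = estimator(x_in, mask_in, mu_in, t_in, spks_in, cond_in)
    cond_out, uncond_out = torch.split(out, [1, 1], dim=0)
    return (1.0 + cfg_rate) * cond_out - cfg_rate * uncond_out

# toy estimator standing in for the U-Net / TRT engine
toy = lambda x, m, mu, t, s, c: x + mu
dphi_dt = cfg_step(toy, torch.randn(1, 80, 120), torch.ones(1, 1, 120),
                   torch.randn(1, 80, 120), torch.tensor(0.5),
                   torch.randn(1, 80), torch.randn(1, 80, 120))
print(dphi_dt.shape)  # torch.Size([1, 80, 120])
```

The `threading.Lock` added alongside this fits the same picture: `set_input_shape` followed by `execute_v2` on a shared TensorRT context is not safe to interleave across threads, so the engine calls are serialized.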
xinference/thirdparty/cosyvoice/flow/length_regulator.py

@@ -51,6 +51,7 @@ class InterpolateRegulator(nn.Module):
 
     def inference(self, x1, x2, mel_len1, mel_len2, input_frame_rate=50):
         # in inference mode, interploate prompt token and token(head/mid/tail) seprately, so we can get a clear separation point of mel
+        # NOTE 20 corresponds to token_overlap_len in cosyvoice/cli/model.py
         # x in (B, T, D)
         if x2.shape[1] > 40:
             x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
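The added NOTE ties the hard-coded 20 back to `token_overlap_len` in cosyvoice/cli/model.py. The same arithmetic appears to explain the fixed 34-frame overlap kept in `ConditionalCFM.forward` (`z[:, :, -34:]`): at 50 speech tokens per second, 22.05 kHz audio and a 256-sample hop, 20 tokens cover 34 mel frames.

```python
# 20 overlap tokens -> mel frames, using the 22050 Hz / hop 256 / 50 token-per-second values above
token_overlap_len = 20
overlap_frames = int(token_overlap_len / 50 * 22050 / 256)
print(overlap_frames)  # 34
```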
xinference/thirdparty/cosyvoice/hifigan/discriminator.py

@@ -1,10 +1,16 @@
 import torch
 import torch.nn as nn
-
+import torch.nn.functional as F
+try:
+    from torch.nn.utils.parametrizations import weight_norm, spectral_norm
+except ImportError:
+    from torch.nn.utils import weight_norm, spectral_norm
 from typing import List, Optional, Tuple
 from einops import rearrange
 from torchaudio.transforms import Spectrogram
 
+LRELU_SLOPE = 0.1
+
 
 class MultipleDiscriminator(nn.Module):
     def __init__(

@@ -138,3 +144,87 @@ class DiscriminatorR(nn.Module):
         x += h
 
         return x, fmap
+
+
+class MultiResSpecDiscriminator(torch.nn.Module):
+
+    def __init__(self,
+                 fft_sizes=[1024, 2048, 512],
+                 hop_sizes=[120, 240, 50],
+                 win_lengths=[600, 1200, 240],
+                 window="hann_window"):
+
+        super(MultiResSpecDiscriminator, self).__init__()
+        self.discriminators = nn.ModuleList([
+            SpecDiscriminator(fft_sizes[0], hop_sizes[0], win_lengths[0], window),
+            SpecDiscriminator(fft_sizes[1], hop_sizes[1], win_lengths[1], window),
+            SpecDiscriminator(fft_sizes[2], hop_sizes[2], win_lengths[2], window)])
+
+    def forward(self, y, y_hat):
+        y_d_rs = []
+        y_d_gs = []
+        fmap_rs = []
+        fmap_gs = []
+        for _, d in enumerate(self.discriminators):
+            y_d_r, fmap_r = d(y)
+            y_d_g, fmap_g = d(y_hat)
+            y_d_rs.append(y_d_r)
+            fmap_rs.append(fmap_r)
+            y_d_gs.append(y_d_g)
+            fmap_gs.append(fmap_g)
+
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+def stft(x, fft_size, hop_size, win_length, window):
+    """Perform STFT and convert to magnitude spectrogram.
+    Args:
+        x (Tensor): Input signal tensor (B, T).
+        fft_size (int): FFT size.
+        hop_size (int): Hop size.
+        win_length (int): Window length.
+        window (str): Window function type.
+    Returns:
+        Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
+    """
+    x_stft = torch.stft(x, fft_size, hop_size, win_length, window, return_complex=True)
+
+    # NOTE(kan-bayashi): clamp is needed to avoid nan or inf
+    return torch.abs(x_stft).transpose(2, 1)
+
+
+class SpecDiscriminator(nn.Module):
+    """docstring for Discriminator."""
+
+    def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window", use_spectral_norm=False):
+        super(SpecDiscriminator, self).__init__()
+        norm_f = weight_norm if use_spectral_norm is False else spectral_norm
+        self.fft_size = fft_size
+        self.shift_size = shift_size
+        self.win_length = win_length
+        self.window = getattr(torch, window)(win_length)
+        self.discriminators = nn.ModuleList([
+            norm_f(nn.Conv2d(1, 32, kernel_size=(3, 9), padding=(1, 4))),
+            norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
+            norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
+            norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
+            norm_f(nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))),
+        ])
+
+        self.out = norm_f(nn.Conv2d(32, 1, 3, 1, 1))
+
+    def forward(self, y):
+
+        fmap = []
+        y = y.squeeze(1)
+        y = stft(y, self.fft_size, self.shift_size, self.win_length, self.window.to(y.device))
+        y = y.unsqueeze(1)
+        for _, d in enumerate(self.discriminators):
+            y = d(y)
+            y = F.leaky_relu(y, LRELU_SLOPE)
+            fmap.append(y)
+
+        y = self.out(y)
+        fmap.append(y)
+
+        return torch.flatten(y, 1, -1), fmap
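A quick smoke test of the newly added multi-resolution spectrogram discriminator. The import path is an assumption (it presumes the bundled cosyvoice package is importable under that name, as the other thirdparty modules here expect), and the waveform shapes are illustrative only:

```python
import torch
from cosyvoice.hifigan.discriminator import MultiResSpecDiscriminator  # assumed import path

disc = MultiResSpecDiscriminator()      # three SpecDiscriminators at 1024/2048/512-point FFTs
real = torch.randn(2, 1, 22050)         # (B, 1, T): one second of 22.05 kHz audio per item
fake = torch.randn(2, 1, 22050)

y_d_rs, y_d_gs, fmap_rs, fmap_gs = disc(real, fake)
print(len(y_d_rs), len(y_d_gs), len(fmap_rs), len(fmap_gs))  # 3 3 3 3, one entry per resolution
```

The outputs match what the HiFiGan wrapper consumes below: per-resolution logits for real and generated speech, plus the feature maps used by the feature-matching loss.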
xinference/thirdparty/cosyvoice/hifigan/f0_predictor.py

@@ -13,7 +13,10 @@
 # limitations under the License.
 import torch
 import torch.nn as nn
-
+try:
+    from torch.nn.utils.parametrizations import weight_norm
+except ImportError:
+    from torch.nn.utils import weight_norm
 
 
 class ConvRNNF0Predictor(nn.Module):
xinference/thirdparty/cosyvoice/hifigan/generator.py

@@ -23,7 +23,10 @@ import torch.nn.functional as F
 from torch.nn import Conv1d
 from torch.nn import ConvTranspose1d
 from torch.nn.utils import remove_weight_norm
-
+try:
+    from torch.nn.utils.parametrizations import weight_norm
+except ImportError:
+    from torch.nn.utils import weight_norm
 from torch.distributions.uniform import Uniform
 
 from cosyvoice.transformer.activation import Snake
xinference/thirdparty/cosyvoice/hifigan/hifigan.py

@@ -41,7 +41,7 @@ class HiFiGan(nn.Module):
         loss_fm = feature_loss(fmap_rs, fmap_gs)
         loss_mel = mel_loss(real_speech, generated_speech, self.mel_spec_transform)
         if self.tpr_loss_weight != 0:
-            loss_tpr = tpr_loss(
+            loss_tpr = tpr_loss(y_d_gs, y_d_rs, self.tpr_loss_tau)
         else:
             loss_tpr = torch.zeros(1).to(device)
         loss_f0 = F.l1_loss(generated_f0, pitch_feat)

@@ -56,7 +56,7 @@ class HiFiGan(nn.Module):
         with torch.no_grad():
             generated_speech, generated_f0 = self.generator(batch, device)
         # 2. calculate discriminator outputs
-        y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech)
+        y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech.detach())
         # 3. calculate discriminator losses, tpr losses [Optional]
         loss_disc, _, _ = discriminator_loss(y_d_rs, y_d_gs)
         if self.tpr_loss_weight != 0: