xinference-0.14.2-py3-none-any.whl → xinference-0.14.4-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
Potentially problematic release: this version of xinference might be problematic.
- xinference/_version.py +3 -3
- xinference/core/chat_interface.py +1 -1
- xinference/core/image_interface.py +9 -0
- xinference/core/model.py +4 -1
- xinference/core/worker.py +60 -44
- xinference/model/audio/chattts.py +25 -9
- xinference/model/audio/core.py +8 -2
- xinference/model/audio/cosyvoice.py +4 -3
- xinference/model/audio/custom.py +4 -5
- xinference/model/audio/fish_speech.py +228 -0
- xinference/model/audio/model_spec.json +8 -0
- xinference/model/embedding/core.py +25 -1
- xinference/model/embedding/custom.py +4 -5
- xinference/model/flexible/core.py +5 -1
- xinference/model/image/custom.py +4 -5
- xinference/model/image/model_spec.json +2 -1
- xinference/model/image/model_spec_modelscope.json +2 -1
- xinference/model/image/stable_diffusion/core.py +66 -3
- xinference/model/llm/__init__.py +6 -0
- xinference/model/llm/llm_family.json +54 -9
- xinference/model/llm/llm_family.py +7 -6
- xinference/model/llm/llm_family_modelscope.json +56 -10
- xinference/model/llm/lmdeploy/__init__.py +0 -0
- xinference/model/llm/lmdeploy/core.py +557 -0
- xinference/model/llm/sglang/core.py +7 -1
- xinference/model/llm/transformers/cogvlm2.py +4 -45
- xinference/model/llm/transformers/cogvlm2_video.py +524 -0
- xinference/model/llm/transformers/core.py +3 -0
- xinference/model/llm/transformers/glm4v.py +2 -23
- xinference/model/llm/transformers/intern_vl.py +94 -11
- xinference/model/llm/transformers/minicpmv25.py +2 -23
- xinference/model/llm/transformers/minicpmv26.py +2 -22
- xinference/model/llm/transformers/yi_vl.py +2 -24
- xinference/model/llm/utils.py +13 -1
- xinference/model/llm/vllm/core.py +1 -34
- xinference/model/rerank/custom.py +4 -5
- xinference/model/utils.py +41 -1
- xinference/model/video/core.py +3 -1
- xinference/model/video/diffusers.py +41 -38
- xinference/model/video/model_spec.json +24 -1
- xinference/model/video/model_spec_modelscope.json +25 -1
- xinference/thirdparty/fish_speech/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/callbacks/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/callbacks/grad_norm.py +113 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/conversation.py +2 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/concat_repeat.py +53 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_pb2.py +33 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_stream.py +36 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/semantic.py +496 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/vqgan.py +147 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/core.py +40 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +123 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +133 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/scan.py +122 -0
- xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lit_module.py +202 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +779 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lora.py +92 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +442 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +44 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +625 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +139 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +115 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +225 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/utils.py +94 -0
- xinference/thirdparty/fish_speech/fish_speech/scheduler.py +40 -0
- xinference/thirdparty/fish_speech/fish_speech/text/__init__.py +4 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_class.py +172 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_constant.py +30 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_util.py +342 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/cardinal.py +32 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/date.py +75 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/digit.py +32 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/fraction.py +35 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/money.py +43 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/percentage.py +33 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/telephone.py +51 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +177 -0
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +69 -0
- xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +130 -0
- xinference/thirdparty/fish_speech/fish_speech/train.py +139 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +23 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/braceexpand.py +217 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/context.py +13 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/file.py +16 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/instantiators.py +50 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/logger.py +55 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/logging_utils.py +48 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/rich_utils.py +100 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/spectrogram.py +122 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +114 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +120 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1237 -0
- xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/api.py +495 -0
- xinference/thirdparty/fish_speech/tools/auto_rerank.py +159 -0
- xinference/thirdparty/fish_speech/tools/download_models.py +55 -0
- xinference/thirdparty/fish_speech/tools/extract_model.py +21 -0
- xinference/thirdparty/fish_speech/tools/file.py +108 -0
- xinference/thirdparty/fish_speech/tools/gen_ref.py +36 -0
- xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +169 -0
- xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +171 -0
- xinference/thirdparty/fish_speech/tools/llama/generate.py +698 -0
- xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +95 -0
- xinference/thirdparty/fish_speech/tools/llama/quantize.py +497 -0
- xinference/thirdparty/fish_speech/tools/llama/rebuild_tokenizer.py +57 -0
- xinference/thirdparty/fish_speech/tools/merge_asr_files.py +55 -0
- xinference/thirdparty/fish_speech/tools/post_api.py +164 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/auto_model.py +573 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +332 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/vad_utils.py +61 -0
- xinference/thirdparty/fish_speech/tools/smart_pad.py +47 -0
- xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/vqgan/create_train_split.py +83 -0
- xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +227 -0
- xinference/thirdparty/fish_speech/tools/vqgan/inference.py +120 -0
- xinference/thirdparty/fish_speech/tools/webui.py +619 -0
- xinference/thirdparty/fish_speech/tools/whisper_asr.py +176 -0
- xinference/thirdparty/matcha/__init__.py +0 -0
- xinference/thirdparty/matcha/app.py +357 -0
- xinference/thirdparty/matcha/cli.py +419 -0
- xinference/thirdparty/matcha/data/__init__.py +0 -0
- xinference/thirdparty/matcha/data/components/__init__.py +0 -0
- xinference/thirdparty/matcha/data/text_mel_datamodule.py +274 -0
- xinference/thirdparty/matcha/hifigan/__init__.py +0 -0
- xinference/thirdparty/matcha/hifigan/config.py +28 -0
- xinference/thirdparty/matcha/hifigan/denoiser.py +64 -0
- xinference/thirdparty/matcha/hifigan/env.py +17 -0
- xinference/thirdparty/matcha/hifigan/meldataset.py +217 -0
- xinference/thirdparty/matcha/hifigan/models.py +368 -0
- xinference/thirdparty/matcha/hifigan/xutils.py +60 -0
- xinference/thirdparty/matcha/models/__init__.py +0 -0
- xinference/thirdparty/matcha/models/baselightningmodule.py +210 -0
- xinference/thirdparty/matcha/models/components/__init__.py +0 -0
- xinference/thirdparty/matcha/models/components/decoder.py +443 -0
- xinference/thirdparty/matcha/models/components/flow_matching.py +132 -0
- xinference/thirdparty/matcha/models/components/text_encoder.py +410 -0
- xinference/thirdparty/matcha/models/components/transformer.py +316 -0
- xinference/thirdparty/matcha/models/matcha_tts.py +244 -0
- xinference/thirdparty/matcha/onnx/__init__.py +0 -0
- xinference/thirdparty/matcha/onnx/export.py +181 -0
- xinference/thirdparty/matcha/onnx/infer.py +168 -0
- xinference/thirdparty/matcha/text/__init__.py +53 -0
- xinference/thirdparty/matcha/text/cleaners.py +121 -0
- xinference/thirdparty/matcha/text/numbers.py +71 -0
- xinference/thirdparty/matcha/text/symbols.py +17 -0
- xinference/thirdparty/matcha/train.py +122 -0
- xinference/thirdparty/matcha/utils/__init__.py +5 -0
- xinference/thirdparty/matcha/utils/audio.py +82 -0
- xinference/thirdparty/matcha/utils/generate_data_statistics.py +112 -0
- xinference/thirdparty/matcha/utils/get_durations_from_trained_model.py +195 -0
- xinference/thirdparty/matcha/utils/instantiators.py +56 -0
- xinference/thirdparty/matcha/utils/logging_utils.py +53 -0
- xinference/thirdparty/matcha/utils/model.py +90 -0
- xinference/thirdparty/matcha/utils/monotonic_align/__init__.py +22 -0
- xinference/thirdparty/matcha/utils/monotonic_align/core.pyx +47 -0
- xinference/thirdparty/matcha/utils/monotonic_align/setup.py +7 -0
- xinference/thirdparty/matcha/utils/pylogger.py +21 -0
- xinference/thirdparty/matcha/utils/rich_utils.py +101 -0
- xinference/thirdparty/matcha/utils/utils.py +259 -0
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.ffc26121.js → main.661c7b0a.js} +3 -3
- xinference/web/ui/build/static/js/main.661c7b0a.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +1 -0
- {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/METADATA +31 -11
- {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/RECORD +189 -49
- xinference/web/ui/build/static/js/main.ffc26121.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/2f40209b32e7e46a2eab6b8c8a355eb42c3caa8bc3228dd929f32fd2b3940294.json +0 -1
- /xinference/web/ui/build/static/js/{main.ffc26121.js.LICENSE.txt → main.661c7b0a.js.LICENSE.txt} +0 -0
- {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/LICENSE +0 -0
- {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/WHEEL +0 -0
- {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/entry_points.txt +0 -0
- {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/top_level.txt +0 -0
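Before the per-file hunks, one practical note: the headline change in this release pair is the vendoring of fish_speech and Matcha-TTS under xinference/thirdparty/ and a new FishSpeech audio model (xinference/model/audio/fish_speech.py plus new model_spec.json entries). The sketch below shows how such an audio model would be driven through xinference's Python client; it is not part of the diff, and the model name is an assumption inferred from the new spec entries, so check your local registrations for the exact registered name.

```python
from xinference.client import Client

# Connect to a running Xinference endpoint (address is an assumption).
client = Client("http://localhost:9997")

# Model name below is an assumption based on the new audio spec entries
# in this diff; verify the registered name on your installation.
model_uid = client.launch_model(model_name="FishSpeech-1.2-SFT", model_type="audio")
model = client.get_model(model_uid)

# Audio models expose an OpenAI-style speech method returning raw audio bytes.
audio = model.speech("Hello from the newly vendored fish_speech backend!")
with open("output.mp3", "wb") as f:
    f.write(audio)
```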
--- /dev/null
+++ b/xinference/thirdparty/matcha/models/components/transformer.py
@@ -0,0 +1,316 @@
+from typing import Any, Dict, Optional
+
+import torch
+import torch.nn as nn
+from diffusers.models.attention import (
+    GEGLU,
+    GELU,
+    AdaLayerNorm,
+    AdaLayerNormZero,
+    ApproximateGELU,
+)
+from diffusers.models.attention_processor import Attention
+from diffusers.models.lora import LoRACompatibleLinear
+from diffusers.utils.torch_utils import maybe_allow_in_graph
+
+
+class SnakeBeta(nn.Module):
+    """
+    A modified Snake function which uses separate parameters for the magnitude of the periodic components
+    Shape:
+        - Input: (B, C, T)
+        - Output: (B, C, T), same shape as the input
+    Parameters:
+        - alpha - trainable parameter that controls frequency
+        - beta - trainable parameter that controls magnitude
+    References:
+        - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
+          https://arxiv.org/abs/2006.08195
+    Examples:
+        >>> a1 = snakebeta(256)
+        >>> x = torch.randn(256)
+        >>> x = a1(x)
+    """
+
+    def __init__(self, in_features, out_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
+        """
+        Initialization.
+        INPUT:
+            - in_features: shape of the input
+            - alpha - trainable parameter that controls frequency
+            - beta - trainable parameter that controls magnitude
+            alpha is initialized to 1 by default, higher values = higher-frequency.
+            beta is initialized to 1 by default, higher values = higher-magnitude.
+            alpha will be trained along with the rest of your model.
+        """
+        super().__init__()
+        self.in_features = out_features if isinstance(out_features, list) else [out_features]
+        self.proj = LoRACompatibleLinear(in_features, out_features)
+
+        # initialize alpha
+        self.alpha_logscale = alpha_logscale
+        if self.alpha_logscale:  # log scale alphas initialized to zeros
+            self.alpha = nn.Parameter(torch.zeros(self.in_features) * alpha)
+            self.beta = nn.Parameter(torch.zeros(self.in_features) * alpha)
+        else:  # linear scale alphas initialized to ones
+            self.alpha = nn.Parameter(torch.ones(self.in_features) * alpha)
+            self.beta = nn.Parameter(torch.ones(self.in_features) * alpha)
+
+        self.alpha.requires_grad = alpha_trainable
+        self.beta.requires_grad = alpha_trainable
+
+        self.no_div_by_zero = 0.000000001
+
+    def forward(self, x):
+        """
+        Forward pass of the function.
+        Applies the function to the input elementwise.
+        SnakeBeta ∶= x + 1/b * sin^2 (xa)
+        """
+        x = self.proj(x)
+        if self.alpha_logscale:
+            alpha = torch.exp(self.alpha)
+            beta = torch.exp(self.beta)
+        else:
+            alpha = self.alpha
+            beta = self.beta
+
+        x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(torch.sin(x * alpha), 2)
+
+        return x
+
+
+class FeedForward(nn.Module):
+    r"""
+    A feed-forward layer.
+
+    Parameters:
+        dim (`int`): The number of channels in the input.
+        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
+        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
+        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+        final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        dim_out: Optional[int] = None,
+        mult: int = 4,
+        dropout: float = 0.0,
+        activation_fn: str = "geglu",
+        final_dropout: bool = False,
+    ):
+        super().__init__()
+        inner_dim = int(dim * mult)
+        dim_out = dim_out if dim_out is not None else dim
+
+        if activation_fn == "gelu":
+            act_fn = GELU(dim, inner_dim)
+        if activation_fn == "gelu-approximate":
+            act_fn = GELU(dim, inner_dim, approximate="tanh")
+        elif activation_fn == "geglu":
+            act_fn = GEGLU(dim, inner_dim)
+        elif activation_fn == "geglu-approximate":
+            act_fn = ApproximateGELU(dim, inner_dim)
+        elif activation_fn == "snakebeta":
+            act_fn = SnakeBeta(dim, inner_dim)
+
+        self.net = nn.ModuleList([])
+        # project in
+        self.net.append(act_fn)
+        # project dropout
+        self.net.append(nn.Dropout(dropout))
+        # project out
+        self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
+        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
+        if final_dropout:
+            self.net.append(nn.Dropout(dropout))
+
+    def forward(self, hidden_states):
+        for module in self.net:
+            hidden_states = module(hidden_states)
+        return hidden_states
+
+
+@maybe_allow_in_graph
+class BasicTransformerBlock(nn.Module):
+    r"""
+    A basic Transformer block.
+
+    Parameters:
+        dim (`int`): The number of channels in the input and output.
+        num_attention_heads (`int`): The number of heads to use for multi-head attention.
+        attention_head_dim (`int`): The number of channels in each head.
+        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
+        only_cross_attention (`bool`, *optional*):
+            Whether to use only cross-attention layers. In this case two cross attention layers are used.
+        double_self_attention (`bool`, *optional*):
+            Whether to use two self-attention layers. In this case no cross attention layers are used.
+        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+        num_embeds_ada_norm (:
+            obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
+        attention_bias (:
+            obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        dropout=0.0,
+        cross_attention_dim: Optional[int] = None,
+        activation_fn: str = "geglu",
+        num_embeds_ada_norm: Optional[int] = None,
+        attention_bias: bool = False,
+        only_cross_attention: bool = False,
+        double_self_attention: bool = False,
+        upcast_attention: bool = False,
+        norm_elementwise_affine: bool = True,
+        norm_type: str = "layer_norm",
+        final_dropout: bool = False,
+    ):
+        super().__init__()
+        self.only_cross_attention = only_cross_attention
+
+        self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
+        self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
+
+        if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
+            raise ValueError(
+                f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
+                f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
+            )
+
+        # Define 3 blocks. Each block has its own normalization layer.
+        # 1. Self-Attn
+        if self.use_ada_layer_norm:
+            self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
+        elif self.use_ada_layer_norm_zero:
+            self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
+        else:
+            self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
+        self.attn1 = Attention(
+            query_dim=dim,
+            heads=num_attention_heads,
+            dim_head=attention_head_dim,
+            dropout=dropout,
+            bias=attention_bias,
+            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+            upcast_attention=upcast_attention,
+        )
+
+        # 2. Cross-Attn
+        if cross_attention_dim is not None or double_self_attention:
+            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
+            # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
+            # the second cross attention block.
+            self.norm2 = (
+                AdaLayerNorm(dim, num_embeds_ada_norm)
+                if self.use_ada_layer_norm
+                else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
+            )
+            self.attn2 = Attention(
+                query_dim=dim,
+                cross_attention_dim=cross_attention_dim if not double_self_attention else None,
+                heads=num_attention_heads,
+                dim_head=attention_head_dim,
+                dropout=dropout,
+                bias=attention_bias,
+                upcast_attention=upcast_attention,
+                # scale_qk=False, # uncomment this to not to use flash attention
+            )  # is self-attn if encoder_hidden_states is none
+        else:
+            self.norm2 = None
+            self.attn2 = None
+
+        # 3. Feed-forward
+        self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
+        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
+
+        # let chunk size default to None
+        self._chunk_size = None
+        self._chunk_dim = 0
+
+    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
+        # Sets chunk feed-forward
+        self._chunk_size = chunk_size
+        self._chunk_dim = dim
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        timestep: Optional[torch.LongTensor] = None,
+        cross_attention_kwargs: Dict[str, Any] = None,
+        class_labels: Optional[torch.LongTensor] = None,
+    ):
+        # Notice that normalization is always applied before the real computation in the following blocks.
+        # 1. Self-Attention
+        if self.use_ada_layer_norm:
+            norm_hidden_states = self.norm1(hidden_states, timestep)
+        elif self.use_ada_layer_norm_zero:
+            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
+                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
+            )
+        else:
+            norm_hidden_states = self.norm1(hidden_states)
+
+        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+
+        attn_output = self.attn1(
+            norm_hidden_states,
+            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+            attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask,
+            **cross_attention_kwargs,
+        )
+        if self.use_ada_layer_norm_zero:
+            attn_output = gate_msa.unsqueeze(1) * attn_output
+        hidden_states = attn_output + hidden_states
+
+        # 2. Cross-Attention
+        if self.attn2 is not None:
+            norm_hidden_states = (
+                self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
+            )
+
+            attn_output = self.attn2(
+                norm_hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                attention_mask=encoder_attention_mask,
+                **cross_attention_kwargs,
+            )
+            hidden_states = attn_output + hidden_states
+
+        # 3. Feed-forward
+        norm_hidden_states = self.norm3(hidden_states)
+
+        if self.use_ada_layer_norm_zero:
+            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+        if self._chunk_size is not None:
+            # "feed_forward_chunk_size" can be used to save memory
+            if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
+                raise ValueError(
+                    f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
+                )
+
+            num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
+            ff_output = torch.cat(
+                [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
+                dim=self._chunk_dim,
+            )
+        else:
+            ff_output = self.ff(norm_hidden_states)
+
+        if self.use_ada_layer_norm_zero:
+            ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+        hidden_states = ff_output + hidden_states
+
+        return hidden_states
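The SnakeBeta module in the hunk above fuses the feed-forward "project in" linear with a periodic activation: after the projection it computes x + (1/β)·sin²(α·x), with α and β stored on a log scale so that the zero initialization yields exp(0) = 1. A minimal standalone sketch, not part of the diff, with a plain nn.Linear standing in for diffusers' LoRACompatibleLinear:

```python
import torch
import torch.nn as nn

class SnakeBetaSketch(nn.Module):
    """Illustrative re-statement of the SnakeBeta activation vendored above."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.proj = nn.Linear(in_features, out_features)
        # Log-scale parameters: zeros here mean exp(0) == 1 at initialization.
        self.alpha = nn.Parameter(torch.zeros(out_features))
        self.beta = nn.Parameter(torch.zeros(out_features))
        self.eps = 1e-9  # corresponds to `no_div_by_zero` in the vendored code

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj(x)
        alpha, beta = torch.exp(self.alpha), torch.exp(self.beta)
        # Snake-style periodic activation: x + (1/beta) * sin(alpha * x)^2
        return x + (1.0 / (beta + self.eps)) * torch.sin(x * alpha) ** 2

x = torch.randn(2, 64, 256)          # (batch, time, channels), as used in FeedForward
out = SnakeBetaSketch(256, 1024)(x)  # -> (2, 64, 1024)
```

The ε term keeps the 1/β factor finite if β is driven toward zero during training, which is why the vendored code carries `no_div_by_zero` instead of dividing by β directly.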
--- /dev/null
+++ b/xinference/thirdparty/matcha/models/matcha_tts.py
@@ -0,0 +1,244 @@
+import datetime as dt
+import math
+import random
+
+import torch
+
+import matcha.utils.monotonic_align as monotonic_align
+from matcha import utils
+from matcha.models.baselightningmodule import BaseLightningClass
+from matcha.models.components.flow_matching import CFM
+from matcha.models.components.text_encoder import TextEncoder
+from matcha.utils.model import (
+    denormalize,
+    duration_loss,
+    fix_len_compatibility,
+    generate_path,
+    sequence_mask,
+)
+
+log = utils.get_pylogger(__name__)
+
+
+class MatchaTTS(BaseLightningClass):  # 🍵
+    def __init__(
+        self,
+        n_vocab,
+        n_spks,
+        spk_emb_dim,
+        n_feats,
+        encoder,
+        decoder,
+        cfm,
+        data_statistics,
+        out_size,
+        optimizer=None,
+        scheduler=None,
+        prior_loss=True,
+        use_precomputed_durations=False,
+    ):
+        super().__init__()
+
+        self.save_hyperparameters(logger=False)
+
+        self.n_vocab = n_vocab
+        self.n_spks = n_spks
+        self.spk_emb_dim = spk_emb_dim
+        self.n_feats = n_feats
+        self.out_size = out_size
+        self.prior_loss = prior_loss
+        self.use_precomputed_durations = use_precomputed_durations
+
+        if n_spks > 1:
+            self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim)
+
+        self.encoder = TextEncoder(
+            encoder.encoder_type,
+            encoder.encoder_params,
+            encoder.duration_predictor_params,
+            n_vocab,
+            n_spks,
+            spk_emb_dim,
+        )
+
+        self.decoder = CFM(
+            in_channels=2 * encoder.encoder_params.n_feats,
+            out_channel=encoder.encoder_params.n_feats,
+            cfm_params=cfm,
+            decoder_params=decoder,
+            n_spks=n_spks,
+            spk_emb_dim=spk_emb_dim,
+        )
+
+        self.update_data_statistics(data_statistics)
+
+    @torch.inference_mode()
+    def synthesise(self, x, x_lengths, n_timesteps, temperature=1.0, spks=None, length_scale=1.0):
+        """
+        Generates mel-spectrogram from text. Returns:
+            1. encoder outputs
+            2. decoder outputs
+            3. generated alignment
+
+        Args:
+            x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids.
+                shape: (batch_size, max_text_length)
+            x_lengths (torch.Tensor): lengths of texts in batch.
+                shape: (batch_size,)
+            n_timesteps (int): number of steps to use for reverse diffusion in decoder.
+            temperature (float, optional): controls variance of terminal distribution.
+            spks (bool, optional): speaker ids.
+                shape: (batch_size,)
+            length_scale (float, optional): controls speech pace.
+                Increase value to slow down generated speech and vice versa.
+
+        Returns:
+            dict: {
+                "encoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length),
+                # Average mel spectrogram generated by the encoder
+                "decoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length),
+                # Refined mel spectrogram improved by the CFM
+                "attn": torch.Tensor, shape: (batch_size, max_text_length, max_mel_length),
+                # Alignment map between text and mel spectrogram
+                "mel": torch.Tensor, shape: (batch_size, n_feats, max_mel_length),
+                # Denormalized mel spectrogram
+                "mel_lengths": torch.Tensor, shape: (batch_size,),
+                # Lengths of mel spectrograms
+                "rtf": float,
+                # Real-time factor
+        """
+        # For RTF computation
+        t = dt.datetime.now()
+
+        if self.n_spks > 1:
+            # Get speaker embedding
+            spks = self.spk_emb(spks.long())
+
+        # Get encoder_outputs `mu_x` and log-scaled token durations `logw`
+        mu_x, logw, x_mask = self.encoder(x, x_lengths, spks)
+
+        w = torch.exp(logw) * x_mask
+        w_ceil = torch.ceil(w) * length_scale
+        y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
+        y_max_length = y_lengths.max()
+        y_max_length_ = fix_len_compatibility(y_max_length)
+
+        # Using obtained durations `w` construct alignment map `attn`
+        y_mask = sequence_mask(y_lengths, y_max_length_).unsqueeze(1).to(x_mask.dtype)
+        attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
+        attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)
+
+        # Align encoded text and get mu_y
+        mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
+        mu_y = mu_y.transpose(1, 2)
+        encoder_outputs = mu_y[:, :, :y_max_length]
+
+        # Generate sample tracing the probability flow
+        decoder_outputs = self.decoder(mu_y, y_mask, n_timesteps, temperature, spks)
+        decoder_outputs = decoder_outputs[:, :, :y_max_length]
+
+        t = (dt.datetime.now() - t).total_seconds()
+        rtf = t * 22050 / (decoder_outputs.shape[-1] * 256)
+
+        return {
+            "encoder_outputs": encoder_outputs,
+            "decoder_outputs": decoder_outputs,
+            "attn": attn[:, :, :y_max_length],
+            "mel": denormalize(decoder_outputs, self.mel_mean, self.mel_std),
+            "mel_lengths": y_lengths,
+            "rtf": rtf,
+        }
+
+    def forward(self, x, x_lengths, y, y_lengths, spks=None, out_size=None, cond=None, durations=None):
+        """
+        Computes 3 losses:
+            1. duration loss: loss between predicted token durations and those extracted by Monotonic Alignment Search (MAS).
+            2. prior loss: loss between mel-spectrogram and encoder outputs.
+            3. flow matching loss: loss between mel-spectrogram and decoder outputs.
+
+        Args:
+            x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids.
+                shape: (batch_size, max_text_length)
+            x_lengths (torch.Tensor): lengths of texts in batch.
+                shape: (batch_size,)
+            y (torch.Tensor): batch of corresponding mel-spectrograms.
+                shape: (batch_size, n_feats, max_mel_length)
+            y_lengths (torch.Tensor): lengths of mel-spectrograms in batch.
+                shape: (batch_size,)
+            out_size (int, optional): length (in mel's sampling rate) of segment to cut, on which decoder will be trained.
+                Should be divisible by 2^{num of UNet downsamplings}. Needed to increase batch size.
+            spks (torch.Tensor, optional): speaker ids.
+                shape: (batch_size,)
+        """
+        if self.n_spks > 1:
+            # Get speaker embedding
+            spks = self.spk_emb(spks)
+
+        # Get encoder_outputs `mu_x` and log-scaled token durations `logw`
+        mu_x, logw, x_mask = self.encoder(x, x_lengths, spks)
+        y_max_length = y.shape[-1]
+
+        y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask)
+        attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
+
+        if self.use_precomputed_durations:
+            attn = generate_path(durations.squeeze(1), attn_mask.squeeze(1))
+        else:
+            # Use MAS to find most likely alignment `attn` between text and mel-spectrogram
+            with torch.no_grad():
+                const = -0.5 * math.log(2 * math.pi) * self.n_feats
+                factor = -0.5 * torch.ones(mu_x.shape, dtype=mu_x.dtype, device=mu_x.device)
+                y_square = torch.matmul(factor.transpose(1, 2), y**2)
+                y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y)
+                mu_square = torch.sum(factor * (mu_x**2), 1).unsqueeze(-1)
+                log_prior = y_square - y_mu_double + mu_square + const
+
+                attn = monotonic_align.maximum_path(log_prior, attn_mask.squeeze(1))
+                attn = attn.detach()  # b, t_text, T_mel
+
+        # Compute loss between predicted log-scaled durations and those obtained from MAS
+        # referred to as prior loss in the paper
+        logw_ = torch.log(1e-8 + torch.sum(attn.unsqueeze(1), -1)) * x_mask
+        dur_loss = duration_loss(logw, logw_, x_lengths)
+
+        # Cut a small segment of mel-spectrogram in order to increase batch size
+        #   - "Hack" taken from Grad-TTS, in case of Grad-TTS, we cannot train batch size 32 on a 24GB GPU without it
+        #   - Do not need this hack for Matcha-TTS, but it works with it as well
+        if not isinstance(out_size, type(None)):
+            max_offset = (y_lengths - out_size).clamp(0)
+            offset_ranges = list(zip([0] * max_offset.shape[0], max_offset.cpu().numpy()))
+            out_offset = torch.LongTensor(
+                [torch.tensor(random.choice(range(start, end)) if end > start else 0) for start, end in offset_ranges]
+            ).to(y_lengths)
+            attn_cut = torch.zeros(attn.shape[0], attn.shape[1], out_size, dtype=attn.dtype, device=attn.device)
+            y_cut = torch.zeros(y.shape[0], self.n_feats, out_size, dtype=y.dtype, device=y.device)
+
+            y_cut_lengths = []
+            for i, (y_, out_offset_) in enumerate(zip(y, out_offset)):
+                y_cut_length = out_size + (y_lengths[i] - out_size).clamp(None, 0)
+                y_cut_lengths.append(y_cut_length)
+                cut_lower, cut_upper = out_offset_, out_offset_ + y_cut_length
+                y_cut[i, :, :y_cut_length] = y_[:, cut_lower:cut_upper]
+                attn_cut[i, :, :y_cut_length] = attn[i, :, cut_lower:cut_upper]
+
+            y_cut_lengths = torch.LongTensor(y_cut_lengths)
+            y_cut_mask = sequence_mask(y_cut_lengths).unsqueeze(1).to(y_mask)
+
+            attn = attn_cut
+            y = y_cut
+            y_mask = y_cut_mask
+
+        # Align encoded text with mel-spectrogram and get mu_y segment
+        mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
+        mu_y = mu_y.transpose(1, 2)
+
+        # Compute loss of the decoder
+        diff_loss, _ = self.decoder.compute_loss(x1=y, mask=y_mask, mu=mu_y, spks=spks, cond=cond)
+
+        if self.prior_loss:
+            prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask)
+            prior_loss = prior_loss / (torch.sum(y_mask) * self.n_feats)
+        else:
+            prior_loss = 0
+
+        return dur_loss, prior_loss, diff_loss, attn
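In MatchaTTS.forward above, the MAS log-prior is assembled from three matmul/sum terms instead of broadcasting over a (batch, n_feats, T_text, T_mel) tensor, which keeps peak memory at O(T_text × T_mel) per batch item. A quick numeric check, not part of the diff, that the factored form equals the direct isotropic Gaussian log-likelihood Σ_f log N(y_f; μ_f, 1):

```python
import math
import torch

# Illustrative shapes only.
B, n_feats, T_text, T_mel = 2, 80, 5, 7
mu_x = torch.randn(B, n_feats, T_text)
y = torch.randn(B, n_feats, T_mel)

# Factored computation, exactly as in the vendored forward() above.
const = -0.5 * math.log(2 * math.pi) * n_feats
factor = -0.5 * torch.ones_like(mu_x)
y_square = torch.matmul(factor.transpose(1, 2), y**2)           # -0.5 * sum_f y^2
y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y)  # -sum_f mu*y
mu_square = torch.sum(factor * (mu_x**2), 1).unsqueeze(-1)      # -0.5 * sum_f mu^2
log_prior = y_square - y_mu_double + mu_square + const          # (B, T_text, T_mel)

# Direct computation: -0.5 * sum_f (y - mu)^2 + const, broadcast over all pairs.
direct = -0.5 * ((y.unsqueeze(2) - mu_x.unsqueeze(3)) ** 2).sum(dim=1) + const
assert torch.allclose(log_prior, direct, atol=1e-4)
```

The -0.5·log(2π)·n_feats constant shifts every entry equally; since every monotonic path through the (T_text, T_mel) grid covers the same number of cells, it does not change the alignment that maximum_path selects, and is kept only so that log_prior is a true log-likelihood.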