minicpmo_utils-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cosyvoice/__init__.py +17 -0
- cosyvoice/bin/average_model.py +93 -0
- cosyvoice/bin/export_jit.py +103 -0
- cosyvoice/bin/export_onnx.py +120 -0
- cosyvoice/bin/inference_deprecated.py +126 -0
- cosyvoice/bin/train.py +195 -0
- cosyvoice/cli/__init__.py +0 -0
- cosyvoice/cli/cosyvoice.py +209 -0
- cosyvoice/cli/frontend.py +238 -0
- cosyvoice/cli/model.py +386 -0
- cosyvoice/dataset/__init__.py +0 -0
- cosyvoice/dataset/dataset.py +151 -0
- cosyvoice/dataset/processor.py +434 -0
- cosyvoice/flow/decoder.py +494 -0
- cosyvoice/flow/flow.py +281 -0
- cosyvoice/flow/flow_matching.py +227 -0
- cosyvoice/flow/length_regulator.py +70 -0
- cosyvoice/hifigan/discriminator.py +230 -0
- cosyvoice/hifigan/f0_predictor.py +58 -0
- cosyvoice/hifigan/generator.py +582 -0
- cosyvoice/hifigan/hifigan.py +67 -0
- cosyvoice/llm/llm.py +610 -0
- cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
- cosyvoice/tokenizer/tokenizer.py +279 -0
- cosyvoice/transformer/__init__.py +0 -0
- cosyvoice/transformer/activation.py +84 -0
- cosyvoice/transformer/attention.py +330 -0
- cosyvoice/transformer/convolution.py +145 -0
- cosyvoice/transformer/decoder.py +396 -0
- cosyvoice/transformer/decoder_layer.py +132 -0
- cosyvoice/transformer/embedding.py +302 -0
- cosyvoice/transformer/encoder.py +474 -0
- cosyvoice/transformer/encoder_layer.py +236 -0
- cosyvoice/transformer/label_smoothing_loss.py +96 -0
- cosyvoice/transformer/positionwise_feed_forward.py +115 -0
- cosyvoice/transformer/subsampling.py +383 -0
- cosyvoice/transformer/upsample_encoder.py +320 -0
- cosyvoice/utils/__init__.py +0 -0
- cosyvoice/utils/class_utils.py +83 -0
- cosyvoice/utils/common.py +186 -0
- cosyvoice/utils/executor.py +176 -0
- cosyvoice/utils/file_utils.py +129 -0
- cosyvoice/utils/frontend_utils.py +136 -0
- cosyvoice/utils/losses.py +57 -0
- cosyvoice/utils/mask.py +265 -0
- cosyvoice/utils/scheduler.py +738 -0
- cosyvoice/utils/train_utils.py +367 -0
- cosyvoice/vllm/cosyvoice2.py +103 -0
- matcha/__init__.py +0 -0
- matcha/app.py +357 -0
- matcha/cli.py +418 -0
- matcha/hifigan/__init__.py +0 -0
- matcha/hifigan/config.py +28 -0
- matcha/hifigan/denoiser.py +64 -0
- matcha/hifigan/env.py +17 -0
- matcha/hifigan/meldataset.py +217 -0
- matcha/hifigan/models.py +368 -0
- matcha/hifigan/xutils.py +60 -0
- matcha/models/__init__.py +0 -0
- matcha/models/baselightningmodule.py +209 -0
- matcha/models/components/__init__.py +0 -0
- matcha/models/components/decoder.py +443 -0
- matcha/models/components/flow_matching.py +132 -0
- matcha/models/components/text_encoder.py +410 -0
- matcha/models/components/transformer.py +316 -0
- matcha/models/matcha_tts.py +239 -0
- matcha/onnx/__init__.py +0 -0
- matcha/onnx/export.py +181 -0
- matcha/onnx/infer.py +168 -0
- matcha/text/__init__.py +53 -0
- matcha/text/cleaners.py +116 -0
- matcha/text/numbers.py +71 -0
- matcha/text/symbols.py +17 -0
- matcha/train.py +122 -0
- matcha/utils/__init__.py +5 -0
- matcha/utils/audio.py +82 -0
- matcha/utils/generate_data_statistics.py +111 -0
- matcha/utils/instantiators.py +56 -0
- matcha/utils/logging_utils.py +53 -0
- matcha/utils/model.py +90 -0
- matcha/utils/monotonic_align/__init__.py +22 -0
- matcha/utils/monotonic_align/setup.py +7 -0
- matcha/utils/pylogger.py +21 -0
- matcha/utils/rich_utils.py +101 -0
- matcha/utils/utils.py +219 -0
- minicpmo/__init__.py +24 -0
- minicpmo/utils.py +636 -0
- minicpmo/version.py +2 -0
- minicpmo_utils-0.1.0.dist-info/METADATA +72 -0
- minicpmo_utils-0.1.0.dist-info/RECORD +148 -0
- minicpmo_utils-0.1.0.dist-info/WHEEL +5 -0
- minicpmo_utils-0.1.0.dist-info/top_level.txt +5 -0
- s3tokenizer/__init__.py +153 -0
- s3tokenizer/assets/BAC009S0764W0121.wav +0 -0
- s3tokenizer/assets/BAC009S0764W0122.wav +0 -0
- s3tokenizer/assets/mel_filters.npz +0 -0
- s3tokenizer/cli.py +183 -0
- s3tokenizer/model.py +546 -0
- s3tokenizer/model_v2.py +605 -0
- s3tokenizer/utils.py +390 -0
- stepaudio2/__init__.py +40 -0
- stepaudio2/cosyvoice2/__init__.py +1 -0
- stepaudio2/cosyvoice2/flow/__init__.py +0 -0
- stepaudio2/cosyvoice2/flow/decoder_dit.py +585 -0
- stepaudio2/cosyvoice2/flow/flow.py +230 -0
- stepaudio2/cosyvoice2/flow/flow_matching.py +205 -0
- stepaudio2/cosyvoice2/transformer/__init__.py +0 -0
- stepaudio2/cosyvoice2/transformer/attention.py +328 -0
- stepaudio2/cosyvoice2/transformer/embedding.py +119 -0
- stepaudio2/cosyvoice2/transformer/encoder_layer.py +163 -0
- stepaudio2/cosyvoice2/transformer/positionwise_feed_forward.py +56 -0
- stepaudio2/cosyvoice2/transformer/subsampling.py +79 -0
- stepaudio2/cosyvoice2/transformer/upsample_encoder_v2.py +483 -0
- stepaudio2/cosyvoice2/utils/__init__.py +1 -0
- stepaudio2/cosyvoice2/utils/class_utils.py +41 -0
- stepaudio2/cosyvoice2/utils/common.py +101 -0
- stepaudio2/cosyvoice2/utils/mask.py +49 -0
- stepaudio2/flashcosyvoice/__init__.py +0 -0
- stepaudio2/flashcosyvoice/cli.py +424 -0
- stepaudio2/flashcosyvoice/config.py +80 -0
- stepaudio2/flashcosyvoice/cosyvoice2.py +160 -0
- stepaudio2/flashcosyvoice/cosyvoice3.py +1 -0
- stepaudio2/flashcosyvoice/engine/__init__.py +0 -0
- stepaudio2/flashcosyvoice/engine/block_manager.py +114 -0
- stepaudio2/flashcosyvoice/engine/llm_engine.py +125 -0
- stepaudio2/flashcosyvoice/engine/model_runner.py +310 -0
- stepaudio2/flashcosyvoice/engine/scheduler.py +77 -0
- stepaudio2/flashcosyvoice/engine/sequence.py +90 -0
- stepaudio2/flashcosyvoice/modules/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/flow.py +198 -0
- stepaudio2/flashcosyvoice/modules/flow_components/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/flow_components/estimator.py +974 -0
- stepaudio2/flashcosyvoice/modules/flow_components/upsample_encoder.py +998 -0
- stepaudio2/flashcosyvoice/modules/hifigan.py +249 -0
- stepaudio2/flashcosyvoice/modules/hifigan_components/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/hifigan_components/layers.py +433 -0
- stepaudio2/flashcosyvoice/modules/qwen2.py +92 -0
- stepaudio2/flashcosyvoice/modules/qwen2_components/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/qwen2_components/layers.py +616 -0
- stepaudio2/flashcosyvoice/modules/sampler.py +231 -0
- stepaudio2/flashcosyvoice/utils/__init__.py +0 -0
- stepaudio2/flashcosyvoice/utils/audio.py +77 -0
- stepaudio2/flashcosyvoice/utils/context.py +28 -0
- stepaudio2/flashcosyvoice/utils/loader.py +116 -0
- stepaudio2/flashcosyvoice/utils/memory.py +19 -0
- stepaudio2/stepaudio2.py +204 -0
- stepaudio2/token2wav.py +248 -0
- stepaudio2/utils.py +91 -0
matcha/models/components/flow_matching.py
@@ -0,0 +1,132 @@
+from abc import ABC
+
+import torch
+import torch.nn.functional as F
+
+from matcha.models.components.decoder import Decoder
+from matcha.utils.pylogger import get_pylogger
+
+log = get_pylogger(__name__)
+
+
+class BASECFM(torch.nn.Module, ABC):
+    def __init__(
+        self,
+        n_feats,
+        cfm_params,
+        n_spks=1,
+        spk_emb_dim=128,
+    ):
+        super().__init__()
+        self.n_feats = n_feats
+        self.n_spks = n_spks
+        self.spk_emb_dim = spk_emb_dim
+        self.solver = cfm_params.solver
+        if hasattr(cfm_params, "sigma_min"):
+            self.sigma_min = cfm_params.sigma_min
+        else:
+            self.sigma_min = 1e-4
+
+        self.estimator = None
+
+    @torch.inference_mode()
+    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
+        """Forward diffusion
+
+        Args:
+            mu (torch.Tensor): output of encoder
+                shape: (batch_size, n_feats, mel_timesteps)
+            mask (torch.Tensor): output_mask
+                shape: (batch_size, 1, mel_timesteps)
+            n_timesteps (int): number of diffusion steps
+            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
+            spks (torch.Tensor, optional): speaker ids. Defaults to None.
+                shape: (batch_size, spk_emb_dim)
+            cond: Not used but kept for future purposes
+
+        Returns:
+            sample: generated mel-spectrogram
+                shape: (batch_size, n_feats, mel_timesteps)
+        """
+        z = torch.randn_like(mu) * temperature
+        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
+        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)
+
+    def solve_euler(self, x, t_span, mu, mask, spks, cond):
+        """
+        Fixed euler solver for ODEs.
+        Args:
+            x (torch.Tensor): random noise
+            t_span (torch.Tensor): n_timesteps interpolated
+                shape: (n_timesteps + 1,)
+            mu (torch.Tensor): output of encoder
+                shape: (batch_size, n_feats, mel_timesteps)
+            mask (torch.Tensor): output_mask
+                shape: (batch_size, 1, mel_timesteps)
+            spks (torch.Tensor, optional): speaker ids. Defaults to None.
+                shape: (batch_size, spk_emb_dim)
+            cond: Not used but kept for future purposes
+        """
+        t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
+
+        # I am storing this because I can later plot it by putting a debugger here and saving it to a file
+        # Or in future might add like a return_all_steps flag
+        sol = []
+
+        for step in range(1, len(t_span)):
+            dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
+
+            x = x + dt * dphi_dt
+            t = t + dt
+            sol.append(x)
+            if step < len(t_span) - 1:
+                dt = t_span[step + 1] - t
+
+        return sol[-1]
+
+    def compute_loss(self, x1, mask, mu, spks=None, cond=None):
+        """Computes diffusion loss
+
+        Args:
+            x1 (torch.Tensor): Target
+                shape: (batch_size, n_feats, mel_timesteps)
+            mask (torch.Tensor): target mask
+                shape: (batch_size, 1, mel_timesteps)
+            mu (torch.Tensor): output of encoder
+                shape: (batch_size, n_feats, mel_timesteps)
+            spks (torch.Tensor, optional): speaker embedding. Defaults to None.
+                shape: (batch_size, spk_emb_dim)
+
+        Returns:
+            loss: conditional flow matching loss
+            y: conditional flow
+                shape: (batch_size, n_feats, mel_timesteps)
+        """
+        b, _, t = mu.shape
+
+        # random timestep
+        t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
+        # sample noise p(x_0)
+        z = torch.randn_like(x1)
+
+        y = (1 - (1 - self.sigma_min) * t) * z + t * x1
+        u = x1 - (1 - self.sigma_min) * z
+
+        loss = F.mse_loss(self.estimator(y, mask, mu, t.squeeze(), spks), u, reduction="sum") / (
+            torch.sum(mask) * u.shape[1]
+        )
+        return loss, y
+
+
+class CFM(BASECFM):
+    def __init__(self, in_channels, out_channel, cfm_params, decoder_params, n_spks=1, spk_emb_dim=64):
+        super().__init__(
+            n_feats=in_channels,
+            cfm_params=cfm_params,
+            n_spks=n_spks,
+            spk_emb_dim=spk_emb_dim,
+        )
+
+        in_channels = in_channels + (spk_emb_dim if n_spks > 1 else 0)
+        # Just change the architecture of the estimator here
+        self.estimator = Decoder(in_channels=in_channels, out_channels=out_channel, **decoder_params)
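Note: the snippet below is not part of the package. It is a standalone sketch of the optimal-transport conditional-flow-matching step that BASECFM.compute_loss implements above, namely y_t = (1 - (1 - sigma_min) * t) * z + t * x1 with target velocity u = x1 - (1 - sigma_min) * z; the tensor sizes and the perfect-estimator stand-in are assumptions chosen for illustration.

import torch
import torch.nn.functional as F

# Illustrative shapes only: (batch, n_feats, mel_timesteps), matching the docstrings above.
sigma_min = 1e-4
b, n_feats, T = 2, 80, 100
x1 = torch.randn(b, n_feats, T)      # target mel-spectrogram
z = torch.randn_like(x1)             # noise sample, x_0 ~ N(0, I)
t = torch.rand(b, 1, 1)              # one random timestep per example

y = (1 - (1 - sigma_min) * t) * z + t * x1   # point on the conditional flow at time t
u = x1 - (1 - sigma_min) * z                 # constant target velocity along that flow

# The estimator is trained to predict u from (y, t); a perfect prediction gives zero loss.
mask = torch.ones(b, 1, T)
pred = u                                     # stand-in for self.estimator(y, mask, mu, t, spks)
loss = F.mse_loss(pred, u, reduction="sum") / (torch.sum(mask) * n_feats)
print(loss.item())                           # 0.0

With the Decoder plugged in as the estimator, these are exactly the quantities the masked MSE in compute_loss compares, and solve_euler integrates the learned velocity field from t=0 to t=1 to synthesize a mel-spectrogram.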
matcha/models/components/text_encoder.py
@@ -0,0 +1,410 @@
+""" from https://github.com/jaywalnut310/glow-tts """
+
+import math
+
+import torch
+import torch.nn as nn
+from einops import rearrange
+
+import matcha.utils as utils
+from matcha.utils.model import sequence_mask
+
+log = utils.get_pylogger(__name__)
+
+
+class LayerNorm(nn.Module):
+    def __init__(self, channels, eps=1e-4):
+        super().__init__()
+        self.channels = channels
+        self.eps = eps
+
+        self.gamma = torch.nn.Parameter(torch.ones(channels))
+        self.beta = torch.nn.Parameter(torch.zeros(channels))
+
+    def forward(self, x):
+        n_dims = len(x.shape)
+        mean = torch.mean(x, 1, keepdim=True)
+        variance = torch.mean((x - mean) ** 2, 1, keepdim=True)
+
+        x = (x - mean) * torch.rsqrt(variance + self.eps)
+
+        shape = [1, -1] + [1] * (n_dims - 2)
+        x = x * self.gamma.view(*shape) + self.beta.view(*shape)
+        return x
+
+
+class ConvReluNorm(nn.Module):
+    def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
+        super().__init__()
+        self.in_channels = in_channels
+        self.hidden_channels = hidden_channels
+        self.out_channels = out_channels
+        self.kernel_size = kernel_size
+        self.n_layers = n_layers
+        self.p_dropout = p_dropout
+
+        self.conv_layers = torch.nn.ModuleList()
+        self.norm_layers = torch.nn.ModuleList()
+        self.conv_layers.append(torch.nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
+        self.norm_layers.append(LayerNorm(hidden_channels))
+        self.relu_drop = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Dropout(p_dropout))
+        for _ in range(n_layers - 1):
+            self.conv_layers.append(
+                torch.nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2)
+            )
+            self.norm_layers.append(LayerNorm(hidden_channels))
+        self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1)
+        self.proj.weight.data.zero_()
+        self.proj.bias.data.zero_()
+
+    def forward(self, x, x_mask):
+        x_org = x
+        for i in range(self.n_layers):
+            x = self.conv_layers[i](x * x_mask)
+            x = self.norm_layers[i](x)
+            x = self.relu_drop(x)
+        x = x_org + self.proj(x)
+        return x * x_mask
+
+
+class DurationPredictor(nn.Module):
+    def __init__(self, in_channels, filter_channels, kernel_size, p_dropout):
+        super().__init__()
+        self.in_channels = in_channels
+        self.filter_channels = filter_channels
+        self.p_dropout = p_dropout
+
+        self.drop = torch.nn.Dropout(p_dropout)
+        self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
+        self.norm_1 = LayerNorm(filter_channels)
+        self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
+        self.norm_2 = LayerNorm(filter_channels)
+        self.proj = torch.nn.Conv1d(filter_channels, 1, 1)
+
+    def forward(self, x, x_mask):
+        x = self.conv_1(x * x_mask)
+        x = torch.relu(x)
+        x = self.norm_1(x)
+        x = self.drop(x)
+        x = self.conv_2(x * x_mask)
+        x = torch.relu(x)
+        x = self.norm_2(x)
+        x = self.drop(x)
+        x = self.proj(x * x_mask)
+        return x * x_mask
+
+
+class RotaryPositionalEmbeddings(nn.Module):
+    """
+    ## RoPE module
+
+    Rotary encoding transforms pairs of features by rotating in the 2D plane.
+    That is, it organizes the $d$ features as $\frac{d}{2}$ pairs.
+    Each pair can be considered a coordinate in a 2D plane, and the encoding will rotate it
+    by an angle depending on the position of the token.
+    """
+
+    def __init__(self, d: int, base: int = 10_000):
+        r"""
+        * `d` is the number of features $d$
+        * `base` is the constant used for calculating $\Theta$
+        """
+        super().__init__()
+
+        self.base = base
+        self.d = int(d)
+        self.cos_cached = None
+        self.sin_cached = None
+
+    def _build_cache(self, x: torch.Tensor):
+        r"""
+        Cache $\cos$ and $\sin$ values
+        """
+        # Return if cache is already built
+        if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]:
+            return
+
+        # Get sequence length
+        seq_len = x.shape[0]
+
+        # $\Theta = {\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
+        theta = 1.0 / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device)
+
+        # Create position indexes `[0, 1, ..., seq_len - 1]`
+        seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device)
+
+        # Calculate the product of position index and $\theta_i$
+        idx_theta = torch.einsum("n,d->nd", seq_idx, theta)
+
+        # Concatenate so that for row $m$ we have
+        # $[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}]$
+        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)
+
+        # Cache them
+        self.cos_cached = idx_theta2.cos()[:, None, None, :]
+        self.sin_cached = idx_theta2.sin()[:, None, None, :]
+
+    def _neg_half(self, x: torch.Tensor):
+        # $\frac{d}{2}$
+        d_2 = self.d // 2
+
+        # Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
+        return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)
+
+    def forward(self, x: torch.Tensor):
+        """
+        * `x` is the Tensor at the head of a key or a query with shape `[seq_len, batch_size, n_heads, d]`
+        """
+        # Cache $\cos$ and $\sin$ values
+        x = rearrange(x, "b h t d -> t b h d")
+
+        self._build_cache(x)
+
+        # Split the features, we can choose to apply rotary embeddings only to a partial set of features.
+        x_rope, x_pass = x[..., : self.d], x[..., self.d :]
+
+        # Calculate
+        # $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
+        neg_half_x = self._neg_half(x_rope)
+
+        x_rope = (x_rope * self.cos_cached[: x.shape[0]]) + (neg_half_x * self.sin_cached[: x.shape[0]])
+
+        return rearrange(torch.cat((x_rope, x_pass), dim=-1), "t b h d -> b h t d")
+
+
+class MultiHeadAttention(nn.Module):
+    def __init__(
+        self,
+        channels,
+        out_channels,
+        n_heads,
+        heads_share=True,
+        p_dropout=0.0,
+        proximal_bias=False,
+        proximal_init=False,
+    ):
+        super().__init__()
+        assert channels % n_heads == 0
+
+        self.channels = channels
+        self.out_channels = out_channels
+        self.n_heads = n_heads
+        self.heads_share = heads_share
+        self.proximal_bias = proximal_bias
+        self.p_dropout = p_dropout
+        self.attn = None
+
+        self.k_channels = channels // n_heads
+        self.conv_q = torch.nn.Conv1d(channels, channels, 1)
+        self.conv_k = torch.nn.Conv1d(channels, channels, 1)
+        self.conv_v = torch.nn.Conv1d(channels, channels, 1)
+
+        # from https://nn.labml.ai/transformers/rope/index.html
+        self.query_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5)
+        self.key_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5)
+
+        self.conv_o = torch.nn.Conv1d(channels, out_channels, 1)
+        self.drop = torch.nn.Dropout(p_dropout)
+
+        torch.nn.init.xavier_uniform_(self.conv_q.weight)
+        torch.nn.init.xavier_uniform_(self.conv_k.weight)
+        if proximal_init:
+            self.conv_k.weight.data.copy_(self.conv_q.weight.data)
+            self.conv_k.bias.data.copy_(self.conv_q.bias.data)
+        torch.nn.init.xavier_uniform_(self.conv_v.weight)
+
+    def forward(self, x, c, attn_mask=None):
+        q = self.conv_q(x)
+        k = self.conv_k(c)
+        v = self.conv_v(c)
+
+        x, self.attn = self.attention(q, k, v, mask=attn_mask)
+
+        x = self.conv_o(x)
+        return x
+
+    def attention(self, query, key, value, mask=None):
+        b, d, t_s, t_t = (*key.size(), query.size(2))
+        query = rearrange(query, "b (h c) t-> b h t c", h=self.n_heads)
+        key = rearrange(key, "b (h c) t-> b h t c", h=self.n_heads)
+        value = rearrange(value, "b (h c) t-> b h t c", h=self.n_heads)
+
+        query = self.query_rotary_pe(query)
+        key = self.key_rotary_pe(key)
+
+        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels)
+
+        if self.proximal_bias:
+            assert t_s == t_t, "Proximal bias is only available for self-attention."
+            scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
+        if mask is not None:
+            scores = scores.masked_fill(mask == 0, -1e4)
+        p_attn = torch.nn.functional.softmax(scores, dim=-1)
+        p_attn = self.drop(p_attn)
+        output = torch.matmul(p_attn, value)
+        output = output.transpose(2, 3).contiguous().view(b, d, t_t)
+        return output, p_attn
+
+    @staticmethod
+    def _attention_bias_proximal(length):
+        r = torch.arange(length, dtype=torch.float32)
+        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
+        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
+
+
+class FFN(nn.Module):
+    def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.filter_channels = filter_channels
+        self.kernel_size = kernel_size
+        self.p_dropout = p_dropout
+
+        self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
+        self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size, padding=kernel_size // 2)
+        self.drop = torch.nn.Dropout(p_dropout)
+
+    def forward(self, x, x_mask):
+        x = self.conv_1(x * x_mask)
+        x = torch.relu(x)
+        x = self.drop(x)
+        x = self.conv_2(x * x_mask)
+        return x * x_mask
+
+
+class Encoder(nn.Module):
+    def __init__(
+        self,
+        hidden_channels,
+        filter_channels,
+        n_heads,
+        n_layers,
+        kernel_size=1,
+        p_dropout=0.0,
+        **kwargs,
+    ):
+        super().__init__()
+        self.hidden_channels = hidden_channels
+        self.filter_channels = filter_channels
+        self.n_heads = n_heads
+        self.n_layers = n_layers
+        self.kernel_size = kernel_size
+        self.p_dropout = p_dropout
+
+        self.drop = torch.nn.Dropout(p_dropout)
+        self.attn_layers = torch.nn.ModuleList()
+        self.norm_layers_1 = torch.nn.ModuleList()
+        self.ffn_layers = torch.nn.ModuleList()
+        self.norm_layers_2 = torch.nn.ModuleList()
+        for _ in range(self.n_layers):
+            self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout))
+            self.norm_layers_1.append(LayerNorm(hidden_channels))
+            self.ffn_layers.append(
+                FFN(
+                    hidden_channels,
+                    hidden_channels,
+                    filter_channels,
+                    kernel_size,
+                    p_dropout=p_dropout,
+                )
+            )
+            self.norm_layers_2.append(LayerNorm(hidden_channels))
+
+    def forward(self, x, x_mask):
+        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
+        for i in range(self.n_layers):
+            x = x * x_mask
+            y = self.attn_layers[i](x, x, attn_mask)
+            y = self.drop(y)
+            x = self.norm_layers_1[i](x + y)
+            y = self.ffn_layers[i](x, x_mask)
+            y = self.drop(y)
+            x = self.norm_layers_2[i](x + y)
+        x = x * x_mask
+        return x
+
+
+class TextEncoder(nn.Module):
+    def __init__(
+        self,
+        encoder_type,
+        encoder_params,
+        duration_predictor_params,
+        n_vocab,
+        n_spks=1,
+        spk_emb_dim=128,
+    ):
+        super().__init__()
+        self.encoder_type = encoder_type
+        self.n_vocab = n_vocab
+        self.n_feats = encoder_params.n_feats
+        self.n_channels = encoder_params.n_channels
+        self.spk_emb_dim = spk_emb_dim
+        self.n_spks = n_spks
+
+        self.emb = torch.nn.Embedding(n_vocab, self.n_channels)
+        torch.nn.init.normal_(self.emb.weight, 0.0, self.n_channels**-0.5)
+
+        if encoder_params.prenet:
+            self.prenet = ConvReluNorm(
+                self.n_channels,
+                self.n_channels,
+                self.n_channels,
+                kernel_size=5,
+                n_layers=3,
+                p_dropout=0.5,
+            )
+        else:
+            self.prenet = lambda x, x_mask: x
+
+        self.encoder = Encoder(
+            encoder_params.n_channels + (spk_emb_dim if n_spks > 1 else 0),
+            encoder_params.filter_channels,
+            encoder_params.n_heads,
+            encoder_params.n_layers,
+            encoder_params.kernel_size,
+            encoder_params.p_dropout,
+        )
+
+        self.proj_m = torch.nn.Conv1d(self.n_channels + (spk_emb_dim if n_spks > 1 else 0), self.n_feats, 1)
+        self.proj_w = DurationPredictor(
+            self.n_channels + (spk_emb_dim if n_spks > 1 else 0),
+            duration_predictor_params.filter_channels_dp,
+            duration_predictor_params.kernel_size,
+            duration_predictor_params.p_dropout,
+        )
+
+    def forward(self, x, x_lengths, spks=None):
+        """Run forward pass to the transformer based encoder and duration predictor
+
+        Args:
+            x (torch.Tensor): text input
+                shape: (batch_size, max_text_length)
+            x_lengths (torch.Tensor): text input lengths
+                shape: (batch_size,)
+            spks (torch.Tensor, optional): speaker ids. Defaults to None.
+                shape: (batch_size,)
+
+        Returns:
+            mu (torch.Tensor): average output of the encoder
+                shape: (batch_size, n_feats, max_text_length)
+            logw (torch.Tensor): log duration predicted by the duration predictor
+                shape: (batch_size, 1, max_text_length)
+            x_mask (torch.Tensor): mask for the text input
+                shape: (batch_size, 1, max_text_length)
+        """
+        x = self.emb(x) * math.sqrt(self.n_channels)
+        x = torch.transpose(x, 1, -1)
+        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
+
+        x = self.prenet(x, x_mask)
+        if self.n_spks > 1:
+            x = torch.cat([x, spks.unsqueeze(-1).repeat(1, 1, x.shape[-1])], dim=1)
+        x = self.encoder(x, x_mask)
+        mu = self.proj_m(x) * x_mask
+
+        x_dp = torch.detach(x)
+        logw = self.proj_w(x_dp, x_mask)
+
+        return mu, logw, x_mask
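Note: the snippet below is likewise not part of the package. It is a standalone sanity check of the rotation that RotaryPositionalEmbeddings applies to queries and keys above, using the same theta_i = base^(-2i/d) schedule and half-split trick; the toy tensor sizes are assumptions, and for simplicity it rotates the whole feature vector, whereas the module above rotates only the first self.d features and passes the rest through.

import torch

# Toy query of shape (batch, n_heads, seq_len, d); sizes are illustrative only.
b, h, t, d = 1, 2, 5, 8
base = 10_000
q = torch.randn(b, h, t, d)

# Same cache construction as _build_cache: angles m * theta_i, duplicated over both halves.
theta = 1.0 / (base ** (torch.arange(0, d, 2).float() / d))            # (d/2,)
idx_theta = torch.einsum("n,d->nd", torch.arange(t).float(), theta)    # (t, d/2)
angles = torch.cat([idx_theta, idx_theta], dim=1)                      # (t, d)
cos, sin = angles.cos()[None, None], angles.sin()[None, None]          # broadcast over (b, h)

# Same pairing as _neg_half: feature i is rotated together with feature i + d/2.
neg_half = torch.cat([-q[..., d // 2:], q[..., : d // 2]], dim=-1)
q_rot = q * cos + neg_half * sin

# The rotation is norm-preserving, so only relative positions affect the q.k attention scores.
assert torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5)

In MultiHeadAttention.attention this rotation is applied to both query and key after the rearrange to (b, h, t, c), so the scaled dot-product scores depend on relative token positions rather than absolute ones.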