minicpmo-utils 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cosyvoice/__init__.py +17 -0
- cosyvoice/bin/average_model.py +93 -0
- cosyvoice/bin/export_jit.py +103 -0
- cosyvoice/bin/export_onnx.py +120 -0
- cosyvoice/bin/inference_deprecated.py +126 -0
- cosyvoice/bin/train.py +195 -0
- cosyvoice/cli/__init__.py +0 -0
- cosyvoice/cli/cosyvoice.py +209 -0
- cosyvoice/cli/frontend.py +238 -0
- cosyvoice/cli/model.py +386 -0
- cosyvoice/dataset/__init__.py +0 -0
- cosyvoice/dataset/dataset.py +151 -0
- cosyvoice/dataset/processor.py +434 -0
- cosyvoice/flow/decoder.py +494 -0
- cosyvoice/flow/flow.py +281 -0
- cosyvoice/flow/flow_matching.py +227 -0
- cosyvoice/flow/length_regulator.py +70 -0
- cosyvoice/hifigan/discriminator.py +230 -0
- cosyvoice/hifigan/f0_predictor.py +58 -0
- cosyvoice/hifigan/generator.py +582 -0
- cosyvoice/hifigan/hifigan.py +67 -0
- cosyvoice/llm/llm.py +610 -0
- cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
- cosyvoice/tokenizer/tokenizer.py +279 -0
- cosyvoice/transformer/__init__.py +0 -0
- cosyvoice/transformer/activation.py +84 -0
- cosyvoice/transformer/attention.py +330 -0
- cosyvoice/transformer/convolution.py +145 -0
- cosyvoice/transformer/decoder.py +396 -0
- cosyvoice/transformer/decoder_layer.py +132 -0
- cosyvoice/transformer/embedding.py +302 -0
- cosyvoice/transformer/encoder.py +474 -0
- cosyvoice/transformer/encoder_layer.py +236 -0
- cosyvoice/transformer/label_smoothing_loss.py +96 -0
- cosyvoice/transformer/positionwise_feed_forward.py +115 -0
- cosyvoice/transformer/subsampling.py +383 -0
- cosyvoice/transformer/upsample_encoder.py +320 -0
- cosyvoice/utils/__init__.py +0 -0
- cosyvoice/utils/class_utils.py +83 -0
- cosyvoice/utils/common.py +186 -0
- cosyvoice/utils/executor.py +176 -0
- cosyvoice/utils/file_utils.py +129 -0
- cosyvoice/utils/frontend_utils.py +136 -0
- cosyvoice/utils/losses.py +57 -0
- cosyvoice/utils/mask.py +265 -0
- cosyvoice/utils/scheduler.py +738 -0
- cosyvoice/utils/train_utils.py +367 -0
- cosyvoice/vllm/cosyvoice2.py +103 -0
- matcha/__init__.py +0 -0
- matcha/app.py +357 -0
- matcha/cli.py +418 -0
- matcha/hifigan/__init__.py +0 -0
- matcha/hifigan/config.py +28 -0
- matcha/hifigan/denoiser.py +64 -0
- matcha/hifigan/env.py +17 -0
- matcha/hifigan/meldataset.py +217 -0
- matcha/hifigan/models.py +368 -0
- matcha/hifigan/xutils.py +60 -0
- matcha/models/__init__.py +0 -0
- matcha/models/baselightningmodule.py +209 -0
- matcha/models/components/__init__.py +0 -0
- matcha/models/components/decoder.py +443 -0
- matcha/models/components/flow_matching.py +132 -0
- matcha/models/components/text_encoder.py +410 -0
- matcha/models/components/transformer.py +316 -0
- matcha/models/matcha_tts.py +239 -0
- matcha/onnx/__init__.py +0 -0
- matcha/onnx/export.py +181 -0
- matcha/onnx/infer.py +168 -0
- matcha/text/__init__.py +53 -0
- matcha/text/cleaners.py +116 -0
- matcha/text/numbers.py +71 -0
- matcha/text/symbols.py +17 -0
- matcha/train.py +122 -0
- matcha/utils/__init__.py +5 -0
- matcha/utils/audio.py +82 -0
- matcha/utils/generate_data_statistics.py +111 -0
- matcha/utils/instantiators.py +56 -0
- matcha/utils/logging_utils.py +53 -0
- matcha/utils/model.py +90 -0
- matcha/utils/monotonic_align/__init__.py +22 -0
- matcha/utils/monotonic_align/setup.py +7 -0
- matcha/utils/pylogger.py +21 -0
- matcha/utils/rich_utils.py +101 -0
- matcha/utils/utils.py +219 -0
- minicpmo/__init__.py +24 -0
- minicpmo/utils.py +636 -0
- minicpmo/version.py +2 -0
- minicpmo_utils-0.1.0.dist-info/METADATA +72 -0
- minicpmo_utils-0.1.0.dist-info/RECORD +148 -0
- minicpmo_utils-0.1.0.dist-info/WHEEL +5 -0
- minicpmo_utils-0.1.0.dist-info/top_level.txt +5 -0
- s3tokenizer/__init__.py +153 -0
- s3tokenizer/assets/BAC009S0764W0121.wav +0 -0
- s3tokenizer/assets/BAC009S0764W0122.wav +0 -0
- s3tokenizer/assets/mel_filters.npz +0 -0
- s3tokenizer/cli.py +183 -0
- s3tokenizer/model.py +546 -0
- s3tokenizer/model_v2.py +605 -0
- s3tokenizer/utils.py +390 -0
- stepaudio2/__init__.py +40 -0
- stepaudio2/cosyvoice2/__init__.py +1 -0
- stepaudio2/cosyvoice2/flow/__init__.py +0 -0
- stepaudio2/cosyvoice2/flow/decoder_dit.py +585 -0
- stepaudio2/cosyvoice2/flow/flow.py +230 -0
- stepaudio2/cosyvoice2/flow/flow_matching.py +205 -0
- stepaudio2/cosyvoice2/transformer/__init__.py +0 -0
- stepaudio2/cosyvoice2/transformer/attention.py +328 -0
- stepaudio2/cosyvoice2/transformer/embedding.py +119 -0
- stepaudio2/cosyvoice2/transformer/encoder_layer.py +163 -0
- stepaudio2/cosyvoice2/transformer/positionwise_feed_forward.py +56 -0
- stepaudio2/cosyvoice2/transformer/subsampling.py +79 -0
- stepaudio2/cosyvoice2/transformer/upsample_encoder_v2.py +483 -0
- stepaudio2/cosyvoice2/utils/__init__.py +1 -0
- stepaudio2/cosyvoice2/utils/class_utils.py +41 -0
- stepaudio2/cosyvoice2/utils/common.py +101 -0
- stepaudio2/cosyvoice2/utils/mask.py +49 -0
- stepaudio2/flashcosyvoice/__init__.py +0 -0
- stepaudio2/flashcosyvoice/cli.py +424 -0
- stepaudio2/flashcosyvoice/config.py +80 -0
- stepaudio2/flashcosyvoice/cosyvoice2.py +160 -0
- stepaudio2/flashcosyvoice/cosyvoice3.py +1 -0
- stepaudio2/flashcosyvoice/engine/__init__.py +0 -0
- stepaudio2/flashcosyvoice/engine/block_manager.py +114 -0
- stepaudio2/flashcosyvoice/engine/llm_engine.py +125 -0
- stepaudio2/flashcosyvoice/engine/model_runner.py +310 -0
- stepaudio2/flashcosyvoice/engine/scheduler.py +77 -0
- stepaudio2/flashcosyvoice/engine/sequence.py +90 -0
- stepaudio2/flashcosyvoice/modules/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/flow.py +198 -0
- stepaudio2/flashcosyvoice/modules/flow_components/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/flow_components/estimator.py +974 -0
- stepaudio2/flashcosyvoice/modules/flow_components/upsample_encoder.py +998 -0
- stepaudio2/flashcosyvoice/modules/hifigan.py +249 -0
- stepaudio2/flashcosyvoice/modules/hifigan_components/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/hifigan_components/layers.py +433 -0
- stepaudio2/flashcosyvoice/modules/qwen2.py +92 -0
- stepaudio2/flashcosyvoice/modules/qwen2_components/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/qwen2_components/layers.py +616 -0
- stepaudio2/flashcosyvoice/modules/sampler.py +231 -0
- stepaudio2/flashcosyvoice/utils/__init__.py +0 -0
- stepaudio2/flashcosyvoice/utils/audio.py +77 -0
- stepaudio2/flashcosyvoice/utils/context.py +28 -0
- stepaudio2/flashcosyvoice/utils/loader.py +116 -0
- stepaudio2/flashcosyvoice/utils/memory.py +19 -0
- stepaudio2/stepaudio2.py +204 -0
- stepaudio2/token2wav.py +248 -0
- stepaudio2/utils.py +91 -0
cosyvoice/flow/flow.py
ADDED
@@ -0,0 +1,281 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
from typing import Dict, Optional
import torch
import torch.nn as nn
from torch.nn import functional as F
from omegaconf import DictConfig
from cosyvoice.utils.mask import make_pad_mask


class MaskedDiffWithXvec(torch.nn.Module):
    def __init__(self,
                 input_size: int = 512,
                 output_size: int = 80,
                 spk_embed_dim: int = 192,
                 output_type: str = "mel",
                 vocab_size: int = 4096,
                 input_frame_rate: int = 50,
                 only_mask_loss: bool = True,
                 encoder: torch.nn.Module = None,
                 length_regulator: torch.nn.Module = None,
                 decoder: torch.nn.Module = None,
                 decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
                                       'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
                                                                 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
                                       'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
                                                          'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
                 mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
                                        'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.decoder_conf = decoder_conf
        self.mel_feat_conf = mel_feat_conf
        self.vocab_size = vocab_size
        self.output_type = output_type
        self.input_frame_rate = input_frame_rate
        logging.info(f"input frame rate={self.input_frame_rate}")
        self.input_embedding = nn.Embedding(vocab_size, input_size)
        self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
        self.encoder = encoder
        self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
        self.decoder = decoder
        self.length_regulator = length_regulator
        self.only_mask_loss = only_mask_loss

    def forward(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        token = batch['speech_token'].to(device)
        token_len = batch['speech_token_len'].to(device)
        feat = batch['speech_feat'].to(device)
        feat_len = batch['speech_feat_len'].to(device)
        embedding = batch['embedding'].to(device)

        # xvec projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)

        # concat text and prompt_text
        mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
        token = self.input_embedding(torch.clamp(token, min=0)) * mask

        # text encode
        h, h_lengths = self.encoder(token, token_len)
        h = self.encoder_proj(h)
        h, h_lengths = self.length_regulator(h, feat_len)

        # get conditions
        conds = torch.zeros(feat.shape, device=token.device)
        for i, j in enumerate(feat_len):
            if random.random() < 0.5:
                continue
            index = random.randint(0, int(0.3 * j))
            conds[i, :index] = feat[i, :index]
        conds = conds.transpose(1, 2)

        mask = (~make_pad_mask(feat_len)).to(h)
        # NOTE this is unnecessary, feat/h already same shape
        loss, _ = self.decoder.compute_loss(
            feat.transpose(1, 2).contiguous(),
            mask.unsqueeze(1),
            h.transpose(1, 2).contiguous(),
            embedding,
            cond=conds
        )
        return {'loss': loss}

    @torch.inference_mode()
    def inference(self,
                  token,
                  token_len,
                  prompt_token,
                  prompt_token_len,
                  prompt_feat,
                  prompt_feat_len,
                  embedding,
                  flow_cache):
        assert token.shape[0] == 1
        # xvec projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)

        # concat speech token and prompt speech token
        token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
        token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
        mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
        token = self.input_embedding(torch.clamp(token, min=0)) * mask

        # text encode
        h, h_lengths = self.encoder(token, token_len)
        h = self.encoder_proj(h)
        mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / self.input_frame_rate * 22050 / 256)
        h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate)

        # get conditions
        conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
        conds[:, :mel_len1] = prompt_feat
        conds = conds.transpose(1, 2)

        mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
        feat, flow_cache = self.decoder(
            mu=h.transpose(1, 2).contiguous(),
            mask=mask.unsqueeze(1),
            spks=embedding,
            cond=conds,
            n_timesteps=10,
            prompt_len=mel_len1,
            cache=flow_cache
        )
        feat = feat[:, :, mel_len1:]
        assert feat.shape[2] == mel_len2
        return feat.float(), flow_cache


class CausalMaskedDiffWithXvec(torch.nn.Module):
    def __init__(self,
                 input_size: int = 512,
                 output_size: int = 80,
                 spk_embed_dim: int = 192,
                 output_type: str = "mel",
                 vocab_size: int = 4096,
                 input_frame_rate: int = 50,
                 only_mask_loss: bool = True,
                 token_mel_ratio: int = 2,
                 pre_lookahead_len: int = 3,
                 encoder: torch.nn.Module = None,
                 decoder: torch.nn.Module = None,
                 decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
                                       'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
                                                                 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
                                       'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
                                                          'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
                 mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
                                        'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.decoder_conf = decoder_conf
        self.mel_feat_conf = mel_feat_conf
        self.vocab_size = vocab_size
        self.output_type = output_type
        self.input_frame_rate = input_frame_rate
        logging.info(f"input frame rate={self.input_frame_rate}")
        self.input_embedding = nn.Embedding(vocab_size, input_size)
        self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
        self.encoder = encoder
        self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
        self.decoder = decoder
        self.only_mask_loss = only_mask_loss
        self.token_mel_ratio = token_mel_ratio
        self.pre_lookahead_len = pre_lookahead_len

    def forward(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        token = batch['speech_token'].to(device)
        token_len = batch['speech_token_len'].to(device)
        feat = batch['speech_feat'].to(device)
        feat_len = batch['speech_feat_len'].to(device)
        embedding = batch['embedding'].to(device)

        # NOTE unified training, static_chunk_size > 0 or = 0
        streaming = True if random.random() < 0.5 else False

        # xvec projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)

        # concat text and prompt_text
        mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
        token = self.input_embedding(torch.clamp(token, min=0)) * mask

        # text encode
        h, h_lengths = self.encoder(token, token_len, streaming=streaming)
        h = self.encoder_proj(h)

        # get conditions
        conds = torch.zeros(feat.shape, device=token.device)
        for i, j in enumerate(feat_len):
            if random.random() < 0.5:
                continue
            index = random.randint(0, int(0.3 * j))
            conds[i, :index] = feat[i, :index]
        conds = conds.transpose(1, 2)

        mask = (~make_pad_mask(h_lengths.sum(dim=-1).squeeze(dim=1))).to(h)
        loss, _ = self.decoder.compute_loss(
            feat.transpose(1, 2).contiguous(),
            mask.unsqueeze(1),
            h.transpose(1, 2).contiguous(),
            embedding,
            cond=conds,
            streaming=streaming,
        )
        return {'loss': loss}

    @torch.inference_mode()
    def inference(self,
                  token,
                  token_len,
                  prompt_token,
                  prompt_token_len,
                  prompt_feat,
                  prompt_feat_len,
                  embedding,
                  streaming,
                  finalize):
        assert token.shape[0] == 1
        # xvec projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)

        # concat text and prompt_text
        token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
        mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
        token = self.input_embedding(torch.clamp(token, min=0)) * mask

        # text encode
        if finalize is True:
            h, h_lengths = self.encoder(token, token_len, streaming=streaming)
        else:
            token, context = token[:, :-self.pre_lookahead_len], token[:, -self.pre_lookahead_len:]
            h, h_lengths = self.encoder(token, token_len, context=context, streaming=streaming)
        mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
        h = self.encoder_proj(h)

        # get conditions
        conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
        conds[:, :mel_len1] = prompt_feat
        conds = conds.transpose(1, 2)

        mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
        feat, _ = self.decoder(
            mu=h.transpose(1, 2).contiguous(),
            mask=mask.unsqueeze(1),
            spks=embedding,
            cond=conds,
            n_timesteps=10,
            streaming=streaming
        )
        feat = feat[:, :, mel_len1:]
        assert feat.shape[2] == mel_len2
        return feat.float(), None
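For reference, a minimal sketch (not part of the package) of the token-to-mel length arithmetic used in MaskedDiffWithXvec.inference above: speech tokens arrive at input_frame_rate (50 Hz by default) while the mel target assumes a 22050 Hz sampling rate with a 256-sample hop, so mel_len2 = int(token_len2 / input_frame_rate * 22050 / 256). The helper name below is hypothetical.

def expected_mel_frames(num_tokens: int, input_frame_rate: int = 50,
                        sampling_rate: int = 22050, hop_size: int = 256) -> int:
    # mirrors mel_len2 = int(token_len2 / self.input_frame_rate * 22050 / 256) above
    return int(num_tokens / input_frame_rate * sampling_rate / hop_size)

print(expected_mel_frames(100))  # 100 tokens at 50 Hz -> 172 mel frames (~1.72 frames per token)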
cosyvoice/flow/flow_matching.py
ADDED
@@ -0,0 +1,227 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#                    2025 Alibaba Inc (authors: Xiang Lyu, Bofan Zhou)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
from matcha.models.components.flow_matching import BASECFM
from cosyvoice.utils.common import set_all_random_seed


class ConditionalCFM(BASECFM):
    def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
        super().__init__(
            n_feats=in_channels,
            cfm_params=cfm_params,
            n_spks=n_spks,
            spk_emb_dim=spk_emb_dim,
        )
        self.t_scheduler = cfm_params.t_scheduler
        self.training_cfg_rate = cfm_params.training_cfg_rate
        self.inference_cfg_rate = cfm_params.inference_cfg_rate
        in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
        # Just change the architecture of the estimator here
        self.estimator = estimator

    @torch.inference_mode()
    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, cache=torch.zeros(1, 80, 0, 2)):
        """Forward diffusion

        Args:
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            n_timesteps (int): number of diffusion steps
            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
            spks (torch.Tensor, optional): speaker ids. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: Not used but kept for future purposes

        Returns:
            sample: generated mel-spectrogram
                shape: (batch_size, n_feats, mel_timesteps)
        """

        z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature
        cache_size = cache.shape[2]
        # fix prompt and overlap part mu and z
        if cache_size != 0:
            z[:, :, :cache_size] = cache[:, :, :, 0]
            mu[:, :, :cache_size] = cache[:, :, :, 1]
        z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -34:]], dim=2)
        mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], dim=2)
        cache = torch.stack([z_cache, mu_cache], dim=-1)

        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
        if self.t_scheduler == 'cosine':
            t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), cache

    def solve_euler(self, x, t_span, mu, mask, spks, cond, streaming=False):
        """
        Fixed euler solver for ODEs.
        Args:
            x (torch.Tensor): random noise
            t_span (torch.Tensor): n_timesteps interpolated
                shape: (n_timesteps + 1,)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            spks (torch.Tensor, optional): speaker ids. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: Not used but kept for future purposes
        """
        t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
        t = t.unsqueeze(dim=0)

        # I am storing this because I can later plot it by putting a debugger here and saving it to a file
        # Or in future might add like a return_all_steps flag
        sol = []

        # Do not use concat, it may cause memory format changed and trt infer with wrong results!
        x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
        mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
        mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
        t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
        spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
        cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
        for step in range(1, len(t_span)):
            # Classifier-Free Guidance inference introduced in VoiceBox
            x_in[:] = x
            mask_in[:] = mask
            mu_in[0] = mu
            t_in[:] = t.unsqueeze(0)
            spks_in[0] = spks
            cond_in[0] = cond
            dphi_dt = self.forward_estimator(
                x_in, mask_in,
                mu_in, t_in,
                spks_in,
                cond_in,
                streaming
            )
            dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
            dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
            x = x + dt * dphi_dt
            t = t + dt
            sol.append(x)
            if step < len(t_span) - 1:
                dt = t_span[step + 1] - t

        return sol[-1].float()

    def forward_estimator(self, x, mask, mu, t, spks, cond, streaming=False):
        if isinstance(self.estimator, torch.nn.Module):
            return self.estimator(x, mask, mu, t, spks, cond, streaming=streaming)
        else:
            [estimator, stream], trt_engine = self.estimator.acquire_estimator()
            # NOTE need to synchronize when switching stream
            torch.cuda.current_stream().synchronize()
            with stream:
                estimator.set_input_shape('x', (2, 80, x.size(2)))
                estimator.set_input_shape('mask', (2, 1, x.size(2)))
                estimator.set_input_shape('mu', (2, 80, x.size(2)))
                estimator.set_input_shape('t', (2,))
                estimator.set_input_shape('spks', (2, 80))
                estimator.set_input_shape('cond', (2, 80, x.size(2)))
                data_ptrs = [x.contiguous().data_ptr(),
                             mask.contiguous().data_ptr(),
                             mu.contiguous().data_ptr(),
                             t.contiguous().data_ptr(),
                             spks.contiguous().data_ptr(),
                             cond.contiguous().data_ptr(),
                             x.data_ptr()]
                for i, j in enumerate(data_ptrs):
                    estimator.set_tensor_address(trt_engine.get_tensor_name(i), j)
                # run trt engine
                assert estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
            torch.cuda.current_stream().synchronize()
            self.estimator.release_estimator(estimator, stream)
            return x

    def compute_loss(self, x1, mask, mu, spks=None, cond=None, streaming=False):
        """Computes diffusion loss

        Args:
            x1 (torch.Tensor): Target
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): target mask
                shape: (batch_size, 1, mel_timesteps)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            spks (torch.Tensor, optional): speaker embedding. Defaults to None.
                shape: (batch_size, spk_emb_dim)

        Returns:
            loss: conditional flow matching loss
            y: conditional flow
                shape: (batch_size, n_feats, mel_timesteps)
        """
        b, _, t = mu.shape

        # random timestep
        t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
        if self.t_scheduler == 'cosine':
            t = 1 - torch.cos(t * 0.5 * torch.pi)
        # sample noise p(x_0)
        z = torch.randn_like(x1)

        y = (1 - (1 - self.sigma_min) * t) * z + t * x1
        u = x1 - (1 - self.sigma_min) * z

        # during training, we randomly drop condition to trade off mode coverage and sample fidelity
        if self.training_cfg_rate > 0:
            cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
            mu = mu * cfg_mask.view(-1, 1, 1)
            spks = spks * cfg_mask.view(-1, 1)
            cond = cond * cfg_mask.view(-1, 1, 1)

        pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond, streaming=streaming)
        loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
        return loss, y


class CausalConditionalCFM(ConditionalCFM):
    def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
        super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator)
        set_all_random_seed(0)
        self.rand_noise = torch.randn([1, 80, 50 * 300])

    @torch.inference_mode()
    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, streaming=False):
        """Forward diffusion

        Args:
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            n_timesteps (int): number of diffusion steps
            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
            spks (torch.Tensor, optional): speaker ids. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: Not used but kept for future purposes

        Returns:
            sample: generated mel-spectrogram
                shape: (batch_size, n_feats, mel_timesteps)
        """

        z = self.rand_noise[:, :, :mu.size(2)].to(mu.device).to(mu.dtype) * temperature
        # fix prompt and overlap part mu and z
        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
        if self.t_scheduler == 'cosine':
            t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond, streaming=streaming), None
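The two core pieces of ConditionalCFM above can be summarized in a short, self-contained sketch: the straight-line interpolation and velocity target regressed in compute_loss, and one classifier-free-guidance Euler step as applied in solve_euler. The toy_estimator below is a stand-in for the real estimator network and is purely illustrative; values are arbitrary.

import torch

sigma_min, inference_cfg_rate, dt = 1e-6, 0.7, 0.1

x1 = torch.randn(1, 80, 10)   # target mel (batch, n_feats, frames)
z = torch.randn_like(x1)      # noise sample x_0
t = torch.rand(1, 1, 1)       # random timestep

# training side (compute_loss): a point on the noise-to-data path and the velocity to regress
y = (1 - (1 - sigma_min) * t) * z + t * x1
u = x1 - (1 - sigma_min) * z

# inference side (solve_euler): one Euler step with classifier-free guidance
def toy_estimator(x, conditioned):
    # stand-in for self.estimator(x, mask, mu, t, spks, cond, ...)
    return (x1 - x) if conditioned else torch.zeros_like(x)

x = z.clone()
v_cond, v_uncond = toy_estimator(x, True), toy_estimator(x, False)
x = x + dt * ((1.0 + inference_cfg_rate) * v_cond - inference_cfg_rate * v_uncond)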
cosyvoice/flow/length_regulator.py
ADDED
@@ -0,0 +1,70 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
import torch.nn as nn
import torch
from torch.nn import functional as F
from cosyvoice.utils.mask import make_pad_mask


class InterpolateRegulator(nn.Module):
    def __init__(
            self,
            channels: int,
            sampling_ratios: Tuple,
            out_channels: int = None,
            groups: int = 1,
    ):
        super().__init__()
        self.sampling_ratios = sampling_ratios
        out_channels = out_channels or channels
        model = nn.ModuleList([])
        if len(sampling_ratios) > 0:
            for _ in sampling_ratios:
                module = nn.Conv1d(channels, channels, 3, 1, 1)
                norm = nn.GroupNorm(groups, channels)
                act = nn.Mish()
                model.extend([module, norm, act])
        model.append(
            nn.Conv1d(channels, out_channels, 1, 1)
        )
        self.model = nn.Sequential(*model)

    def forward(self, x, ylens=None):
        # x in (B, T, D)
        mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1)
        x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='linear')
        out = self.model(x).transpose(1, 2).contiguous()
        olens = ylens
        return out * mask, olens

    def inference(self, x1, x2, mel_len1, mel_len2, input_frame_rate=50):
        # in inference mode, interpolate prompt token and token (head/mid/tail) separately, so we can get a clear separation point of mel
        # NOTE 20 corresponds to token_overlap_len in cosyvoice/cli/model.py
        # x in (B, T, D)
        if x2.shape[1] > 40:
            x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
            x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - int(20 / input_frame_rate * 22050 / 256) * 2,
                                   mode='linear')
            x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
            x2 = torch.concat([x2_head, x2_mid, x2_tail], dim=2)
        else:
            x2 = F.interpolate(x2.transpose(1, 2).contiguous(), size=mel_len2, mode='linear')
        if x1.shape[1] != 0:
            x1 = F.interpolate(x1.transpose(1, 2).contiguous(), size=mel_len1, mode='linear')
            x = torch.concat([x1, x2], dim=2)
        else:
            x = x2
        out = self.model(x).transpose(1, 2).contiguous()
        return out, mel_len1 + mel_len2
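A minimal usage sketch of InterpolateRegulator, assuming the wheel is installed so cosyvoice is importable; the channel count and lengths below are illustrative, not the package's configured values.

import torch
from cosyvoice.flow.length_regulator import InterpolateRegulator

regulator = InterpolateRegulator(channels=512, sampling_ratios=[1, 1, 1, 1], out_channels=512)
x = torch.randn(2, 60, 512)        # (B, T, D) encoder output
ylens = torch.tensor([100, 80])    # target mel lengths per item
out, olens = regulator(x, ylens=ylens)
print(out.shape)                   # torch.Size([2, 100, 512]); padded frames of the second item are zeroed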