minicpmo-utils 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cosyvoice/__init__.py +17 -0
- cosyvoice/bin/average_model.py +93 -0
- cosyvoice/bin/export_jit.py +103 -0
- cosyvoice/bin/export_onnx.py +120 -0
- cosyvoice/bin/inference_deprecated.py +126 -0
- cosyvoice/bin/train.py +195 -0
- cosyvoice/cli/__init__.py +0 -0
- cosyvoice/cli/cosyvoice.py +209 -0
- cosyvoice/cli/frontend.py +238 -0
- cosyvoice/cli/model.py +386 -0
- cosyvoice/dataset/__init__.py +0 -0
- cosyvoice/dataset/dataset.py +151 -0
- cosyvoice/dataset/processor.py +434 -0
- cosyvoice/flow/decoder.py +494 -0
- cosyvoice/flow/flow.py +281 -0
- cosyvoice/flow/flow_matching.py +227 -0
- cosyvoice/flow/length_regulator.py +70 -0
- cosyvoice/hifigan/discriminator.py +230 -0
- cosyvoice/hifigan/f0_predictor.py +58 -0
- cosyvoice/hifigan/generator.py +582 -0
- cosyvoice/hifigan/hifigan.py +67 -0
- cosyvoice/llm/llm.py +610 -0
- cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
- cosyvoice/tokenizer/tokenizer.py +279 -0
- cosyvoice/transformer/__init__.py +0 -0
- cosyvoice/transformer/activation.py +84 -0
- cosyvoice/transformer/attention.py +330 -0
- cosyvoice/transformer/convolution.py +145 -0
- cosyvoice/transformer/decoder.py +396 -0
- cosyvoice/transformer/decoder_layer.py +132 -0
- cosyvoice/transformer/embedding.py +302 -0
- cosyvoice/transformer/encoder.py +474 -0
- cosyvoice/transformer/encoder_layer.py +236 -0
- cosyvoice/transformer/label_smoothing_loss.py +96 -0
- cosyvoice/transformer/positionwise_feed_forward.py +115 -0
- cosyvoice/transformer/subsampling.py +383 -0
- cosyvoice/transformer/upsample_encoder.py +320 -0
- cosyvoice/utils/__init__.py +0 -0
- cosyvoice/utils/class_utils.py +83 -0
- cosyvoice/utils/common.py +186 -0
- cosyvoice/utils/executor.py +176 -0
- cosyvoice/utils/file_utils.py +129 -0
- cosyvoice/utils/frontend_utils.py +136 -0
- cosyvoice/utils/losses.py +57 -0
- cosyvoice/utils/mask.py +265 -0
- cosyvoice/utils/scheduler.py +738 -0
- cosyvoice/utils/train_utils.py +367 -0
- cosyvoice/vllm/cosyvoice2.py +103 -0
- matcha/__init__.py +0 -0
- matcha/app.py +357 -0
- matcha/cli.py +418 -0
- matcha/hifigan/__init__.py +0 -0
- matcha/hifigan/config.py +28 -0
- matcha/hifigan/denoiser.py +64 -0
- matcha/hifigan/env.py +17 -0
- matcha/hifigan/meldataset.py +217 -0
- matcha/hifigan/models.py +368 -0
- matcha/hifigan/xutils.py +60 -0
- matcha/models/__init__.py +0 -0
- matcha/models/baselightningmodule.py +209 -0
- matcha/models/components/__init__.py +0 -0
- matcha/models/components/decoder.py +443 -0
- matcha/models/components/flow_matching.py +132 -0
- matcha/models/components/text_encoder.py +410 -0
- matcha/models/components/transformer.py +316 -0
- matcha/models/matcha_tts.py +239 -0
- matcha/onnx/__init__.py +0 -0
- matcha/onnx/export.py +181 -0
- matcha/onnx/infer.py +168 -0
- matcha/text/__init__.py +53 -0
- matcha/text/cleaners.py +116 -0
- matcha/text/numbers.py +71 -0
- matcha/text/symbols.py +17 -0
- matcha/train.py +122 -0
- matcha/utils/__init__.py +5 -0
- matcha/utils/audio.py +82 -0
- matcha/utils/generate_data_statistics.py +111 -0
- matcha/utils/instantiators.py +56 -0
- matcha/utils/logging_utils.py +53 -0
- matcha/utils/model.py +90 -0
- matcha/utils/monotonic_align/__init__.py +22 -0
- matcha/utils/monotonic_align/setup.py +7 -0
- matcha/utils/pylogger.py +21 -0
- matcha/utils/rich_utils.py +101 -0
- matcha/utils/utils.py +219 -0
- minicpmo/__init__.py +24 -0
- minicpmo/utils.py +636 -0
- minicpmo/version.py +2 -0
- minicpmo_utils-0.1.0.dist-info/METADATA +72 -0
- minicpmo_utils-0.1.0.dist-info/RECORD +148 -0
- minicpmo_utils-0.1.0.dist-info/WHEEL +5 -0
- minicpmo_utils-0.1.0.dist-info/top_level.txt +5 -0
- s3tokenizer/__init__.py +153 -0
- s3tokenizer/assets/BAC009S0764W0121.wav +0 -0
- s3tokenizer/assets/BAC009S0764W0122.wav +0 -0
- s3tokenizer/assets/mel_filters.npz +0 -0
- s3tokenizer/cli.py +183 -0
- s3tokenizer/model.py +546 -0
- s3tokenizer/model_v2.py +605 -0
- s3tokenizer/utils.py +390 -0
- stepaudio2/__init__.py +40 -0
- stepaudio2/cosyvoice2/__init__.py +1 -0
- stepaudio2/cosyvoice2/flow/__init__.py +0 -0
- stepaudio2/cosyvoice2/flow/decoder_dit.py +585 -0
- stepaudio2/cosyvoice2/flow/flow.py +230 -0
- stepaudio2/cosyvoice2/flow/flow_matching.py +205 -0
- stepaudio2/cosyvoice2/transformer/__init__.py +0 -0
- stepaudio2/cosyvoice2/transformer/attention.py +328 -0
- stepaudio2/cosyvoice2/transformer/embedding.py +119 -0
- stepaudio2/cosyvoice2/transformer/encoder_layer.py +163 -0
- stepaudio2/cosyvoice2/transformer/positionwise_feed_forward.py +56 -0
- stepaudio2/cosyvoice2/transformer/subsampling.py +79 -0
- stepaudio2/cosyvoice2/transformer/upsample_encoder_v2.py +483 -0
- stepaudio2/cosyvoice2/utils/__init__.py +1 -0
- stepaudio2/cosyvoice2/utils/class_utils.py +41 -0
- stepaudio2/cosyvoice2/utils/common.py +101 -0
- stepaudio2/cosyvoice2/utils/mask.py +49 -0
- stepaudio2/flashcosyvoice/__init__.py +0 -0
- stepaudio2/flashcosyvoice/cli.py +424 -0
- stepaudio2/flashcosyvoice/config.py +80 -0
- stepaudio2/flashcosyvoice/cosyvoice2.py +160 -0
- stepaudio2/flashcosyvoice/cosyvoice3.py +1 -0
- stepaudio2/flashcosyvoice/engine/__init__.py +0 -0
- stepaudio2/flashcosyvoice/engine/block_manager.py +114 -0
- stepaudio2/flashcosyvoice/engine/llm_engine.py +125 -0
- stepaudio2/flashcosyvoice/engine/model_runner.py +310 -0
- stepaudio2/flashcosyvoice/engine/scheduler.py +77 -0
- stepaudio2/flashcosyvoice/engine/sequence.py +90 -0
- stepaudio2/flashcosyvoice/modules/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/flow.py +198 -0
- stepaudio2/flashcosyvoice/modules/flow_components/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/flow_components/estimator.py +974 -0
- stepaudio2/flashcosyvoice/modules/flow_components/upsample_encoder.py +998 -0
- stepaudio2/flashcosyvoice/modules/hifigan.py +249 -0
- stepaudio2/flashcosyvoice/modules/hifigan_components/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/hifigan_components/layers.py +433 -0
- stepaudio2/flashcosyvoice/modules/qwen2.py +92 -0
- stepaudio2/flashcosyvoice/modules/qwen2_components/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/qwen2_components/layers.py +616 -0
- stepaudio2/flashcosyvoice/modules/sampler.py +231 -0
- stepaudio2/flashcosyvoice/utils/__init__.py +0 -0
- stepaudio2/flashcosyvoice/utils/audio.py +77 -0
- stepaudio2/flashcosyvoice/utils/context.py +28 -0
- stepaudio2/flashcosyvoice/utils/loader.py +116 -0
- stepaudio2/flashcosyvoice/utils/memory.py +19 -0
- stepaudio2/stepaudio2.py +204 -0
- stepaudio2/token2wav.py +248 -0
- stepaudio2/utils.py +91 -0
stepaudio2/cosyvoice2/flow/flow.py
@@ -0,0 +1,230 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from torch.nn import functional as F

from stepaudio2.cosyvoice2.utils.mask import make_pad_mask
from stepaudio2.cosyvoice2.flow.flow_matching import CausalConditionalCFM
from stepaudio2.cosyvoice2.transformer.upsample_encoder_v2 import UpsampleConformerEncoderV2


class CausalMaskedDiffWithXvec(torch.nn.Module):
    def __init__(self,
                 input_size: int = 512,
                 output_size: int = 80,
                 spk_embed_dim: int = 192,
                 output_type: str = "mel",
                 vocab_size: int = 5121,
                 encoder: UpsampleConformerEncoderV2 = None,
                 decoder: CausalConditionalCFM = None,
                 input_embedding: torch.nn.Module = None,
                 ):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.vocab_size = vocab_size
        self.output_type = output_type
        self.pre_lookahead_len = int(encoder.pre_lookahead_layer.pre_lookahead_len)
        self.up_rate = int(encoder.up_layer.stride)
        if input_embedding is None:
            self.input_embedding = nn.Embedding(vocab_size, input_size)
        else:
            self.input_embedding = input_embedding
        self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
        self.encoder = encoder
        self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
        self.decoder = decoder

        # xvec projection with CUDA Graph optimization
        # initialize CUDA Graph related variables
        self.enable_cuda_graph = False
        self.static_embedding = None
        self.static_output = None
        self.graph = None
        self.embedding_shape = None

    def scatter_cuda_graph(self, enable_cuda_graph: bool):
        self.enable_cuda_graph = enable_cuda_graph
        if self.enable_cuda_graph:
            # self.encoder.scatter_cuda_graph(enable_cuda_graph)
            self.decoder.scatter_cuda_graph(enable_cuda_graph)

    @torch.inference_mode()
    def inference(self,
                  token,
                  token_len,
                  prompt_token,
                  prompt_token_len,
                  prompt_feat,
                  prompt_feat_len,
                  embedding,
                  n_timesteps: int = 10,
                  ):
        assert token.shape[0] == 1

        # xvec projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)

        # concat text and prompt_text
        token_len = prompt_token_len + token_len
        token = torch.concat([prompt_token, token], dim=1)

        mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
        token = self.input_embedding(torch.clamp(token, min=0)) * mask

        # token encode
        h, _ = self.encoder.forward(token, token_len)
        h = self.encoder_proj(h)

        # condition
        mel_len1 = prompt_feat.shape[1]
        mel_len2 = h.shape[1] - prompt_feat.shape[1]

        conds = torch.zeros_like(h)
        conds[:, :mel_len1] = prompt_feat
        conds = conds.transpose(1, 2).contiguous()

        mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)

        feat = self.decoder.forward(
            mu=h.transpose(1, 2).contiguous(),
            mask=mask.unsqueeze(1),
            spks=embedding,
            cond=conds,
            n_timesteps=n_timesteps,
        )

        feat = feat[:, :, mel_len1:]
        assert feat.shape[2] == mel_len2
        return feat

    @torch.inference_mode()
    def setup_cache(self,
                    token: torch.Tensor,
                    mel: torch.Tensor,
                    spk: torch.Tensor,
                    n_timesteps: int = 10,
                    ):
        """
        Args:
            token: shape (b, t), with look-ahead tokens
            mel: shape (b, t, c), ground-truth mel
            spk: shape (b, 192), speaker embedding
        Returns:
            cache: dict {
                'conformer_cnn_cache': xxx, 'conformer_att_cache': xxx,
                'estimator_cnn_cache': xxx, 'estimator_att_cache': xxx
            }
        """
        # check that the look-ahead tokens are included
        assert (token.shape[1] - self.pre_lookahead_len) * self.up_rate == mel.shape[1], (token.shape, mel.shape)

        # xvec projection
        spk = F.normalize(spk, dim=1)
        spk = self.spk_embed_affine_layer(spk)

        token = self.input_embedding(token)
        # NOTE encoder.forward_chunk will strip the look-ahead part
        h, conformer_cnn_cache, conformer_att_cache = self.encoder.forward_chunk(
            xs=token,
            last_chunk=False,
            cnn_cache=None,
            att_cache=None,
        )
        h = self.encoder_proj(h)

        feat, estimator_cnn_cache, estimator_att_cache = self.decoder.forward_chunk(
            mu=h.transpose(1, 2).contiguous(),
            spks=spk,
            cond=mel.transpose(1, 2).contiguous(),
            n_timesteps=n_timesteps,
            temperature=1.0,
            cnn_cache=None,
            att_cache=None,
        )

        cache = {
            'conformer_cnn_cache': conformer_cnn_cache,
            'conformer_att_cache': conformer_att_cache,
            'estimator_cnn_cache': estimator_cnn_cache,
            'estimator_att_cache': estimator_att_cache,
        }

        # print("examining flow cache")
        # from IPython import embed; embed()

        return cache

    @torch.inference_mode()
    def inference_chunk(self,
                        token: torch.Tensor,
                        spk: torch.Tensor,
                        cache: dict,
                        last_chunk: bool = False,
                        n_timesteps: int = 10,
                        ):
        """
        Args:
            token: shape (b, t), with look-ahead tokens
            spk: shape (b, 192), speaker embedding
            cache: dict {
                'conformer_cnn_cache': xxx,
                ...
            }
        """
        # unpack cache
        conformer_cnn_cache = cache['conformer_cnn_cache']
        conformer_att_cache = cache['conformer_att_cache']
        estimator_cnn_cache = cache['estimator_cnn_cache']
        estimator_att_cache = cache['estimator_att_cache']

        # xvec projection
        spk = F.normalize(spk, dim=1)
        spk = self.spk_embed_affine_layer(spk)

        token = self.input_embedding(token)
        # if this is not the last chunk, h is shorter than xs by lookahead_length * stride (6) frames
        h, conformer_cnn_cache, conformer_att_cache = self.encoder.forward_chunk(
            xs=token,
            last_chunk=last_chunk,
            cnn_cache=conformer_cnn_cache,
            att_cache=conformer_att_cache,
        )
        h = self.encoder_proj(h)

        cond = torch.zeros_like(h)
        # forward estimator
        feat, estimator_cnn_cache, estimator_att_cache = self.decoder.forward_chunk(
            mu=h.transpose(1, 2).contiguous(),
            spks=spk,
            cond=cond.transpose(1, 2).contiguous(),
            n_timesteps=n_timesteps,
            temperature=1.0,
            cnn_cache=estimator_cnn_cache,
            att_cache=estimator_att_cache,
        )

        new_cache = {
            'conformer_cnn_cache': conformer_cnn_cache,
            'conformer_att_cache': conformer_att_cache,
            'estimator_cnn_cache': estimator_cnn_cache,
            'estimator_att_cache': estimator_att_cache,
        }

        return feat, new_cache
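For readers tracing the prompt-conditioning bookkeeping in CausalMaskedDiffWithXvec.inference, here is a minimal, self-contained sketch of just that layout with toy sizes (all shapes and the random stand-in tensors below are hypothetical, not taken from the package):

import torch

b, c = 1, 80                    # batch, mel channels (output_size)
mel_len1, mel_len2 = 6, 10      # prompt mel frames, frames to generate

h = torch.randn(b, mel_len1 + mel_len2, c)    # stand-in for encoder_proj output
prompt_feat = torch.randn(b, mel_len1, c)     # stand-in for the prompt mel

# the condition carries the prompt mel on its first mel_len1 frames, zeros elsewhere
conds = torch.zeros_like(h)
conds[:, :mel_len1] = prompt_feat
conds = conds.transpose(1, 2).contiguous()    # (b, c, t), the layout the CFM expects

# the decoder output covers prompt + new frames; the prompt region is trimmed off
feat = torch.randn(b, c, mel_len1 + mel_len2)  # stand-in for decoder.forward output
feat = feat[:, :, mel_len1:]
assert feat.shape[2] == mel_len2

This is why inference returns only newly generated mel frames: the prompt occupies the first mel_len1 positions throughout and is sliced away at the end.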
stepaudio2/cosyvoice2/flow/flow_matching.py
@@ -0,0 +1,205 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
import onnxruntime
import torch
import torch.nn.functional as F

from stepaudio2.cosyvoice2.flow.decoder_dit import DiT
from stepaudio2.cosyvoice2.utils.mask import make_pad_mask


"""
Inference wrapper
"""
class CausalConditionalCFM(torch.nn.Module):
    def __init__(self, estimator: DiT, inference_cfg_rate: float = 0.7):
        super().__init__()
        self.estimator = estimator
        self.inference_cfg_rate = inference_cfg_rate
        self.out_channels = estimator.out_channels
        # pre-sampled noise for a maximum of 600 s
        self.register_buffer('rand_noise', torch.randn([1, self.out_channels, 50 * 600]), persistent=False)

        self.register_buffer('cnn_cache_buffer', torch.zeros(16, 16, 2, 1024, 2), persistent=False)
        self.register_buffer('att_cache_buffer', torch.zeros(16, 16, 2, 8, 1000, 128), persistent=False)

    def scatter_cuda_graph(self, enable_cuda_graph: bool):
        if enable_cuda_graph:
            self.estimator._init_cuda_graph_all()

    def solve_euler(self, x, t_span, mu, mask, spks, cond):
        """
        Fixed-step Euler solver for ODEs.
        Args:
            x (torch.Tensor): random noise
            t_span (torch.Tensor): interpolated time steps
                shape: (n_timesteps + 1,)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            spks (torch.Tensor): speaker embedding
                shape: (batch_size, spk_emb_dim)
            cond (torch.Tensor): conditioning mel
                shape: (batch_size, n_feats, mel_timesteps)
        """
        t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
        t = t.unsqueeze(dim=0)
        assert self.inference_cfg_rate > 0, 'inference_cfg_rate must be > 0'

        # constant during denoising
        mask_in = torch.cat([mask, mask], dim=0)
        mu_in = torch.cat([mu, torch.zeros_like(mu)], dim=0)
        spks_in = torch.cat([spks, torch.zeros_like(spks)], dim=0)
        cond_in = torch.cat([cond, torch.zeros_like(cond)], dim=0)

        for step in range(1, len(t_span)):

            x_in = torch.cat([x, x], dim=0)
            t_in = torch.cat([t, t], dim=0)

            dphi_dt = self.estimator.forward(
                x_in,
                mask_in,
                mu_in,
                t_in,
                spks_in,
                cond_in,
            )
            dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
            dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
            x = x + dt * dphi_dt
            t = t + dt
            if step < len(t_span) - 1:
                dt = t_span[step + 1] - t

        return x

    @torch.inference_mode()
    def forward(self, mu, mask, spks, cond, n_timesteps=10, temperature=1.0):
        z = self.rand_noise[:, :, :mu.size(2)] * temperature
        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
        # cosine scheduling
        t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
        return self.solve_euler(z, t_span, mu, mask, spks, cond)

    def solve_euler_chunk(self,
                          x: torch.Tensor,
                          t_span: torch.Tensor,
                          mu: torch.Tensor,
                          spks: torch.Tensor,
                          cond: torch.Tensor,
                          cnn_cache: torch.Tensor = None,
                          att_cache: torch.Tensor = None,
                          ):
        """
        Fixed-step Euler solver for ODEs, one chunk at a time.
        Args:
            x (torch.Tensor): random noise
            t_span (torch.Tensor): interpolated time steps
                shape: (n_timesteps + 1,)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            spks (torch.Tensor): speaker embedding
                shape: (batch_size, spk_emb_dim)
            cond (torch.Tensor): conditioning mel
                shape: (batch_size, n_feats, mel_timesteps)
            cnn_cache: shape (n_time, depth, b, c1+c2, 2)
            att_cache: shape (n_time, depth, b, nh, t, c * 2)
        """
        assert self.inference_cfg_rate > 0, 'inference_cfg_rate must be > 0'

        t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
        t = t.unsqueeze(dim=0)  # (b,)

        # setup initial cache: one slot per timestep, holding the cache
        # the next chunk will consume at that timestep
        if cnn_cache is None:
            cnn_cache = [None for _ in range(len(t_span) - 1)]
        if att_cache is None:
            att_cache = [None for _ in range(len(t_span) - 1)]

        if att_cache[0] is not None:
            last_att_len = att_cache.shape[4]
        else:
            last_att_len = 0

        # constant during denoising
        mu_in = torch.cat([mu, torch.zeros_like(mu)], dim=0)
        spks_in = torch.cat([spks, torch.zeros_like(spks)], dim=0)
        cond_in = torch.cat([cond, torch.zeros_like(cond)], dim=0)
        for step in range(1, len(t_span)):
            # torch.cuda.memory._record_memory_history(max_entries=100000)
            this_att_cache = att_cache[step - 1]
            this_cnn_cache = cnn_cache[step - 1]

            dphi_dt, this_new_cnn_cache, this_new_att_cache = self.estimator.forward_chunk(
                x=x.repeat(2, 1, 1),
                mu=mu_in,
                t=t.repeat(2),
                spks=spks_in,
                cond=cond_in,
                cnn_cache=this_cnn_cache,
                att_cache=this_att_cache,
            )
            dphi_dt, cfg_dphi_dt = dphi_dt.chunk(2, dim=0)
            dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
            x = x + dt * dphi_dt
            t = t + dt
            if step < len(t_span) - 1:
                dt = t_span[step + 1] - t

            self.cnn_cache_buffer[step - 1] = this_new_cnn_cache
            self.att_cache_buffer[step - 1][:, :, :, :x.shape[2] + last_att_len, :] = this_new_att_cache

        cnn_cache = self.cnn_cache_buffer
        att_cache = self.att_cache_buffer[:, :, :, :, :x.shape[2] + last_att_len, :]
        return x, cnn_cache, att_cache

    @torch.inference_mode()
    def forward_chunk(self,
                      mu: torch.Tensor,
                      spks: torch.Tensor,
                      cond: torch.Tensor,
                      n_timesteps: int = 10,
                      temperature: float = 1.0,
                      cnn_cache: torch.Tensor = None,
                      att_cache: torch.Tensor = None,
                      ):
        """
        Args:
            mu (torch.Tensor): shape (b, c, t)
            spks (torch.Tensor): shape (b, 192)
            cond (torch.Tensor): shape (b, c, t)
            cnn_cache: shape (n_time, depth, b, c1+c2, 2)
            att_cache: shape (n_time, depth, b, nh, t, c * 2)
        """
        # get offset from att_cache
        offset = att_cache.shape[4] if att_cache is not None else 0
        z = self.rand_noise[:, :, offset:offset + mu.size(2)] * temperature
        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
        # cosine scheduling
        t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
        x, new_cnn_cache, new_att_cache = self.solve_euler_chunk(
            x=z,
            t_span=t_span,
            mu=mu,
            spks=spks,
            cond=cond,
            att_cache=att_cache,
            cnn_cache=cnn_cache,
        )
        return x, new_cnn_cache, new_att_cache
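The heart of CausalConditionalCFM is a fixed-step Euler loop over a cosine-warped time grid with classifier-free guidance (CFG): the conditional and unconditional inputs are stacked along the batch axis so one estimator call yields both velocities. Below is a self-contained sketch of just that update rule, with a toy vector field standing in for the DiT estimator (the toy function and all sizes are illustrative assumptions, not the package's API):

import torch

def toy_estimator(x_in, t_in):
    # hypothetical stand-in for estimator.forward(x_in, mask_in, mu_in, t_in, spks_in, cond_in)
    return -x_in * t_in.view(-1, 1, 1)

def solve_euler_sketch(x, n_timesteps=10, cfg_rate=0.7):
    t_span = torch.linspace(0, 1, n_timesteps + 1)
    t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)   # cosine scheduling
    t, dt = t_span[0].unsqueeze(0), t_span[1] - t_span[0]
    for step in range(1, len(t_span)):
        x_in = torch.cat([x, x], dim=0)               # conditional + unconditional batch
        t_in = torch.cat([t, t], dim=0)
        v = toy_estimator(x_in, t_in)
        cond_v, uncond_v = v.chunk(2, dim=0)
        v = (1.0 + cfg_rate) * cond_v - cfg_rate * uncond_v  # CFG mix
        x = x + dt * v                                # Euler step
        t = t + dt
        if step < len(t_span) - 1:
            dt = t_span[step + 1] - t
    return x

mel = solve_euler_sketch(torch.randn(1, 80, 50))      # (b, n_feats, frames)

Two details carry over from the real class: the starting noise comes from a fixed rand_noise buffer (sliced by an offset in forward_chunk) so that consecutive chunks denoise a continuous noise sequence, and the per-timestep cnn_cache/att_cache tensors let each Euler step resume the estimator's receptive field across chunks.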