minicpmo-utils 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cosyvoice/__init__.py +17 -0
- cosyvoice/bin/average_model.py +93 -0
- cosyvoice/bin/export_jit.py +103 -0
- cosyvoice/bin/export_onnx.py +120 -0
- cosyvoice/bin/inference_deprecated.py +126 -0
- cosyvoice/bin/train.py +195 -0
- cosyvoice/cli/__init__.py +0 -0
- cosyvoice/cli/cosyvoice.py +209 -0
- cosyvoice/cli/frontend.py +238 -0
- cosyvoice/cli/model.py +386 -0
- cosyvoice/dataset/__init__.py +0 -0
- cosyvoice/dataset/dataset.py +151 -0
- cosyvoice/dataset/processor.py +434 -0
- cosyvoice/flow/decoder.py +494 -0
- cosyvoice/flow/flow.py +281 -0
- cosyvoice/flow/flow_matching.py +227 -0
- cosyvoice/flow/length_regulator.py +70 -0
- cosyvoice/hifigan/discriminator.py +230 -0
- cosyvoice/hifigan/f0_predictor.py +58 -0
- cosyvoice/hifigan/generator.py +582 -0
- cosyvoice/hifigan/hifigan.py +67 -0
- cosyvoice/llm/llm.py +610 -0
- cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
- cosyvoice/tokenizer/tokenizer.py +279 -0
- cosyvoice/transformer/__init__.py +0 -0
- cosyvoice/transformer/activation.py +84 -0
- cosyvoice/transformer/attention.py +330 -0
- cosyvoice/transformer/convolution.py +145 -0
- cosyvoice/transformer/decoder.py +396 -0
- cosyvoice/transformer/decoder_layer.py +132 -0
- cosyvoice/transformer/embedding.py +302 -0
- cosyvoice/transformer/encoder.py +474 -0
- cosyvoice/transformer/encoder_layer.py +236 -0
- cosyvoice/transformer/label_smoothing_loss.py +96 -0
- cosyvoice/transformer/positionwise_feed_forward.py +115 -0
- cosyvoice/transformer/subsampling.py +383 -0
- cosyvoice/transformer/upsample_encoder.py +320 -0
- cosyvoice/utils/__init__.py +0 -0
- cosyvoice/utils/class_utils.py +83 -0
- cosyvoice/utils/common.py +186 -0
- cosyvoice/utils/executor.py +176 -0
- cosyvoice/utils/file_utils.py +129 -0
- cosyvoice/utils/frontend_utils.py +136 -0
- cosyvoice/utils/losses.py +57 -0
- cosyvoice/utils/mask.py +265 -0
- cosyvoice/utils/scheduler.py +738 -0
- cosyvoice/utils/train_utils.py +367 -0
- cosyvoice/vllm/cosyvoice2.py +103 -0
- matcha/__init__.py +0 -0
- matcha/app.py +357 -0
- matcha/cli.py +418 -0
- matcha/hifigan/__init__.py +0 -0
- matcha/hifigan/config.py +28 -0
- matcha/hifigan/denoiser.py +64 -0
- matcha/hifigan/env.py +17 -0
- matcha/hifigan/meldataset.py +217 -0
- matcha/hifigan/models.py +368 -0
- matcha/hifigan/xutils.py +60 -0
- matcha/models/__init__.py +0 -0
- matcha/models/baselightningmodule.py +209 -0
- matcha/models/components/__init__.py +0 -0
- matcha/models/components/decoder.py +443 -0
- matcha/models/components/flow_matching.py +132 -0
- matcha/models/components/text_encoder.py +410 -0
- matcha/models/components/transformer.py +316 -0
- matcha/models/matcha_tts.py +239 -0
- matcha/onnx/__init__.py +0 -0
- matcha/onnx/export.py +181 -0
- matcha/onnx/infer.py +168 -0
- matcha/text/__init__.py +53 -0
- matcha/text/cleaners.py +116 -0
- matcha/text/numbers.py +71 -0
- matcha/text/symbols.py +17 -0
- matcha/train.py +122 -0
- matcha/utils/__init__.py +5 -0
- matcha/utils/audio.py +82 -0
- matcha/utils/generate_data_statistics.py +111 -0
- matcha/utils/instantiators.py +56 -0
- matcha/utils/logging_utils.py +53 -0
- matcha/utils/model.py +90 -0
- matcha/utils/monotonic_align/__init__.py +22 -0
- matcha/utils/monotonic_align/setup.py +7 -0
- matcha/utils/pylogger.py +21 -0
- matcha/utils/rich_utils.py +101 -0
- matcha/utils/utils.py +219 -0
- minicpmo/__init__.py +24 -0
- minicpmo/utils.py +636 -0
- minicpmo/version.py +2 -0
- minicpmo_utils-0.1.0.dist-info/METADATA +72 -0
- minicpmo_utils-0.1.0.dist-info/RECORD +148 -0
- minicpmo_utils-0.1.0.dist-info/WHEEL +5 -0
- minicpmo_utils-0.1.0.dist-info/top_level.txt +5 -0
- s3tokenizer/__init__.py +153 -0
- s3tokenizer/assets/BAC009S0764W0121.wav +0 -0
- s3tokenizer/assets/BAC009S0764W0122.wav +0 -0
- s3tokenizer/assets/mel_filters.npz +0 -0
- s3tokenizer/cli.py +183 -0
- s3tokenizer/model.py +546 -0
- s3tokenizer/model_v2.py +605 -0
- s3tokenizer/utils.py +390 -0
- stepaudio2/__init__.py +40 -0
- stepaudio2/cosyvoice2/__init__.py +1 -0
- stepaudio2/cosyvoice2/flow/__init__.py +0 -0
- stepaudio2/cosyvoice2/flow/decoder_dit.py +585 -0
- stepaudio2/cosyvoice2/flow/flow.py +230 -0
- stepaudio2/cosyvoice2/flow/flow_matching.py +205 -0
- stepaudio2/cosyvoice2/transformer/__init__.py +0 -0
- stepaudio2/cosyvoice2/transformer/attention.py +328 -0
- stepaudio2/cosyvoice2/transformer/embedding.py +119 -0
- stepaudio2/cosyvoice2/transformer/encoder_layer.py +163 -0
- stepaudio2/cosyvoice2/transformer/positionwise_feed_forward.py +56 -0
- stepaudio2/cosyvoice2/transformer/subsampling.py +79 -0
- stepaudio2/cosyvoice2/transformer/upsample_encoder_v2.py +483 -0
- stepaudio2/cosyvoice2/utils/__init__.py +1 -0
- stepaudio2/cosyvoice2/utils/class_utils.py +41 -0
- stepaudio2/cosyvoice2/utils/common.py +101 -0
- stepaudio2/cosyvoice2/utils/mask.py +49 -0
- stepaudio2/flashcosyvoice/__init__.py +0 -0
- stepaudio2/flashcosyvoice/cli.py +424 -0
- stepaudio2/flashcosyvoice/config.py +80 -0
- stepaudio2/flashcosyvoice/cosyvoice2.py +160 -0
- stepaudio2/flashcosyvoice/cosyvoice3.py +1 -0
- stepaudio2/flashcosyvoice/engine/__init__.py +0 -0
- stepaudio2/flashcosyvoice/engine/block_manager.py +114 -0
- stepaudio2/flashcosyvoice/engine/llm_engine.py +125 -0
- stepaudio2/flashcosyvoice/engine/model_runner.py +310 -0
- stepaudio2/flashcosyvoice/engine/scheduler.py +77 -0
- stepaudio2/flashcosyvoice/engine/sequence.py +90 -0
- stepaudio2/flashcosyvoice/modules/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/flow.py +198 -0
- stepaudio2/flashcosyvoice/modules/flow_components/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/flow_components/estimator.py +974 -0
- stepaudio2/flashcosyvoice/modules/flow_components/upsample_encoder.py +998 -0
- stepaudio2/flashcosyvoice/modules/hifigan.py +249 -0
- stepaudio2/flashcosyvoice/modules/hifigan_components/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/hifigan_components/layers.py +433 -0
- stepaudio2/flashcosyvoice/modules/qwen2.py +92 -0
- stepaudio2/flashcosyvoice/modules/qwen2_components/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/qwen2_components/layers.py +616 -0
- stepaudio2/flashcosyvoice/modules/sampler.py +231 -0
- stepaudio2/flashcosyvoice/utils/__init__.py +0 -0
- stepaudio2/flashcosyvoice/utils/audio.py +77 -0
- stepaudio2/flashcosyvoice/utils/context.py +28 -0
- stepaudio2/flashcosyvoice/utils/loader.py +116 -0
- stepaudio2/flashcosyvoice/utils/memory.py +19 -0
- stepaudio2/stepaudio2.py +204 -0
- stepaudio2/token2wav.py +248 -0
- stepaudio2/utils.py +91 -0
stepaudio2/token2wav.py
ADDED
@@ -0,0 +1,248 @@

import io
import sys
import types

import torch
import torchaudio
import s3tokenizer
import onnxruntime
import numpy as np
from copy import deepcopy

import torchaudio.compliance.kaldi as kaldi
from stepaudio2.flashcosyvoice.modules.hifigan import HiFTGenerator
from stepaudio2.flashcosyvoice.utils.audio import mel_spectrogram
from hyperpyyaml import load_hyperpyyaml


def _setup_cosyvoice2_alias():
    """Register cosyvoice2.* compatibility aliases for hyperpyyaml without modifying flow.yaml."""
    if 'cosyvoice2.flow.flow' in sys.modules:
        # Already registered; reuse it
        return

    # Import the real implementations from stepaudio2
    import stepaudio2.cosyvoice2.flow.flow as _step_flow
    import stepaudio2.cosyvoice2.flow.flow_matching as _step_flow_matching
    import stepaudio2.cosyvoice2.flow.decoder_dit as _step_decoder_dit
    import stepaudio2.cosyvoice2.transformer.upsample_encoder_v2 as _step_upsample

    # Create the top-level cosyvoice2 package and its subpackages
    cosyvoice2_pkg = types.ModuleType('cosyvoice2')
    cosyvoice2_flow_pkg = types.ModuleType('cosyvoice2.flow')
    cosyvoice2_transformer_pkg = types.ModuleType('cosyvoice2.transformer')

    # Attach the submodules
    cosyvoice2_flow_pkg.flow = _step_flow
    cosyvoice2_flow_pkg.flow_matching = _step_flow_matching
    cosyvoice2_flow_pkg.decoder_dit = _step_decoder_dit

    cosyvoice2_transformer_pkg.upsample_encoder_v2 = _step_upsample

    cosyvoice2_pkg.flow = cosyvoice2_flow_pkg
    cosyvoice2_pkg.transformer = cosyvoice2_transformer_pkg

    # Register in sys.modules so paths like `cosyvoice2.flow.flow.*` can be imported
    sys.modules['cosyvoice2'] = cosyvoice2_pkg
    sys.modules['cosyvoice2.flow'] = cosyvoice2_flow_pkg
    sys.modules['cosyvoice2.flow.flow'] = _step_flow
    sys.modules['cosyvoice2.flow.flow_matching'] = _step_flow_matching
    sys.modules['cosyvoice2.flow.decoder_dit'] = _step_decoder_dit
    sys.modules['cosyvoice2.transformer'] = cosyvoice2_transformer_pkg
    sys.modules['cosyvoice2.transformer.upsample_encoder_v2'] = _step_upsample


def fade_in_out(fade_in_mel: torch.Tensor, fade_out_mel: torch.Tensor, window: torch.Tensor):
    """Perform fade-in/fade-out in tensor style."""
    mel_overlap_len = int(window.shape[0] / 2)
    fade_in_mel = fade_in_mel.clone()
    fade_in_mel[..., :mel_overlap_len] = \
        fade_in_mel[..., :mel_overlap_len] * window[:mel_overlap_len] + \
        fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:]
    return fade_in_mel


class Token2wav():

    def __init__(self, model_path, float16=False, n_timesteps=10):
        self.float16 = float16
        self.n_timesteps = n_timesteps

        # Register the cosyvoice2 aliases before flow.yaml is loaded
        _setup_cosyvoice2_alias()

        self.audio_tokenizer = s3tokenizer.load_model(f"{model_path}/speech_tokenizer_v2_25hz.onnx").cuda().eval()

        option = onnxruntime.SessionOptions()
        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        option.intra_op_num_threads = 1
        self.spk_model = onnxruntime.InferenceSession(f"{model_path}/campplus.onnx", sess_options=option, providers=["CPUExecutionProvider"])

        with open(f"{model_path}/flow.yaml", "r") as f:
            configs = load_hyperpyyaml(f)
        self.flow = configs['flow']
        if float16:
            self.flow.half()
        self.flow.load_state_dict(torch.load(f"{model_path}/flow.pt", map_location="cpu", weights_only=True), strict=True)
        self.flow.cuda().eval()

        self.hift = HiFTGenerator()
        hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(f"{model_path}/hift.pt", map_location="cpu", weights_only=True).items()}
        self.hift.load_state_dict(hift_state_dict, strict=True)
        self.hift.cuda().eval()

        self.cache = None

        # stream conf
        self.mel_cache_len = 8  # hard-coded, 160ms
        self.source_cache_len = int(self.mel_cache_len * 480)  # 50hz mel -> 24kHz wave
        self.speech_window = torch.from_numpy(np.hamming(2 * self.source_cache_len)).cuda()

        # hifigan cache
        self.hift_cache_dict = {}

    def _prepare_prompt(self, prompt_wav):
        audio = s3tokenizer.load_audio(prompt_wav, sr=16000)  # [T]
        # TODO: pad audio with 3 tokens' worth (0.12 s) of trailing audio
        mels = s3tokenizer.log_mel_spectrogram(audio)
        mels, mels_lens = s3tokenizer.padding([mels])
        prompt_speech_tokens, prompt_speech_tokens_lens = self.audio_tokenizer.quantize(mels.cuda(), mels_lens.cuda())
        # TODO: manually set the last 3 tokens to the silence token 4218

        spk_feat = kaldi.fbank(audio.unsqueeze(0), num_mel_bins=80, dither=0, sample_frequency=16000)
        spk_feat = spk_feat - spk_feat.mean(dim=0, keepdim=True)
        spk_emb = torch.tensor(self.spk_model.run(
            None, {self.spk_model.get_inputs()[0].name: spk_feat.unsqueeze(dim=0).cpu().numpy()}
        )[0], device='cuda')

        audio, sample_rate = torchaudio.load(prompt_wav, backend='soundfile')
        audio = audio.mean(dim=0, keepdim=True)  # [1, T]
        if sample_rate != 24000:
            audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=24000)(audio)
        prompt_mel = mel_spectrogram(audio).transpose(1, 2).squeeze(0)  # [T, num_mels]
        prompt_mels = prompt_mel.unsqueeze(0).cuda()
        prompt_mels_lens = torch.tensor([prompt_mels.shape[1]], dtype=torch.int32, device='cuda')
        prompt_mels = torch.nn.functional.pad(prompt_mels, (0, 0, 0, prompt_speech_tokens.shape[1] * self.flow.up_rate - prompt_mels.shape[1]), mode='replicate')
        return prompt_speech_tokens, prompt_speech_tokens_lens, spk_emb, prompt_mels, prompt_mels_lens

    def __call__(self, generated_speech_tokens, prompt_wav):
        if self.cache is None:
            self.cache = self._prepare_prompt(prompt_wav)
        prompt_speech_tokens, prompt_speech_tokens_lens, spk_emb, prompt_mels, prompt_mels_lens = self.cache

        generated_speech_tokens = torch.tensor([generated_speech_tokens], dtype=torch.int32, device='cuda')
        generated_speech_tokens_lens = torch.tensor([generated_speech_tokens.shape[1]], dtype=torch.int32, device='cuda')

        with torch.amp.autocast("cuda", dtype=torch.float16 if self.float16 else torch.float32):
            mel = self.flow.inference(generated_speech_tokens, generated_speech_tokens_lens,
                                      prompt_speech_tokens, prompt_speech_tokens_lens,
                                      prompt_mels, prompt_mels_lens, spk_emb, 10)

        wav, _ = self.hift(speech_feat=mel)
        output = io.BytesIO()
        torchaudio.save(output, wav.cpu(), sample_rate=24000, format='wav')

        return output.getvalue()

    def set_stream_cache(self, prompt_wav):
        self.cache = self._prepare_prompt(prompt_wav)
        prompt_speech_tokens, prompt_speech_tokens_lens, spk_emb, prompt_mels, prompt_mels_lens = self.cache

        print("prompt_speech_tokens_lens", prompt_speech_tokens_lens)
        print("prompt_mels.shape", prompt_mels.shape)
        print("prompt_mels_lens", prompt_mels_lens)

        right_pad_speech_tokens = torch.ones(1, 3, device=prompt_speech_tokens.device, dtype=prompt_speech_tokens.dtype) * 4218

        # self.stream_cache = self.flow.setup_cache(
        #     torch.cat([prompt_speech_tokens, prompt_speech_tokens[:, :3]], dim=1),
        #     prompt_mels, spk_emb, n_timesteps=self.n_timesteps)

        stream_cache = self.flow.setup_cache(
            torch.cat([prompt_speech_tokens, right_pad_speech_tokens], dim=1),
            prompt_mels, spk_emb, n_timesteps=self.n_timesteps)

        # hift cache
        hift_cache_dict = dict(
            mel=torch.zeros(1, prompt_mels.shape[2], 0, device='cuda'),
            source=torch.zeros(1, 1, 0, device='cuda'),
            speech=torch.zeros(1, 0, device='cuda'),
        )

        return stream_cache, hift_cache_dict

    def stream(self, generated_speech_tokens, prompt_wav, last_chunk=False, return_waveform=False):
        if self.cache is None:
            self.cache = self._prepare_prompt(prompt_wav)
        prompt_speech_tokens, prompt_speech_tokens_lens, spk_emb, prompt_mels, prompt_mels_lens = self.cache

        generated_speech_tokens = torch.tensor([generated_speech_tokens], dtype=torch.int32, device='cuda')
        generated_speech_tokens_lens = torch.tensor([generated_speech_tokens.shape[1]], dtype=torch.int32, device='cuda')

        if self.stream_cache is None:
            raise ValueError("stream_cache is not set")

        with torch.amp.autocast("cuda", dtype=torch.float16 if self.float16 else torch.float32):
            chunk_mel, self.stream_cache = self.flow.inference_chunk(
                token=generated_speech_tokens,
                spk=spk_emb,
                cache=self.stream_cache,
                last_chunk=last_chunk,
                n_timesteps=self.n_timesteps,
            )
            if self.stream_cache['estimator_att_cache'].shape[4] > (prompt_mels.shape[1] + 100):
                self.stream_cache['estimator_att_cache'] = torch.cat([
                    self.stream_cache['estimator_att_cache'][:, :, :, :, :prompt_mels.shape[1]],
                    self.stream_cache['estimator_att_cache'][:, :, :, :, -100:],
                ], dim=4)

            # bug fix - 20260107
            # Truncate conformer_att_cache as well, so unbounded growth doesn't desync the position embedding
            if self.stream_cache['conformer_att_cache'].shape[3] > (prompt_mels.shape[1] + 100):
                self.stream_cache['conformer_att_cache'] = torch.cat([
                    self.stream_cache['conformer_att_cache'][:, :, :, :prompt_mels.shape[1], :],
                    self.stream_cache['conformer_att_cache'][:, :, :, -100:, :],
                ], dim=3)

        # vocoder cache
        hift_cache_mel = self.hift_cache_dict['mel']
        hift_cache_source = self.hift_cache_dict['source']
        hift_cache_speech = self.hift_cache_dict['speech']
        mel = torch.concat([hift_cache_mel, chunk_mel], dim=2)

        speech, source = self.hift(mel, hift_cache_source)

        # overlap speech smooth
        if hift_cache_speech.shape[-1] > 0:
            speech = fade_in_out(speech, hift_cache_speech, self.speech_window)

        # Detect the first chunk (no valid speech cache yet)
        is_first_chunk = hift_cache_speech.shape[-1] == 0

        # update vocoder cache
        self.hift_cache_dict = dict(
            mel=mel[..., -self.mel_cache_len:].clone().detach(),
            source=source[:, :, -self.source_cache_len:].clone().detach(),
            speech=speech[:, -self.source_cache_len:].clone().detach(),
        )

        if not last_chunk:
            if is_first_chunk:
                # First chunk: trim the tail and prepend silence to preserve the length;
                # no fade-in, so the original audio content is untouched
                silence_padding = torch.zeros(1, self.source_cache_len, device=speech.device)
                speech = torch.cat([silence_padding, speech[:, :-self.source_cache_len]], dim=1)
            else:
                # Subsequent chunks: trim the tail as usual (the next chunk's fade-in compensates)
                speech = speech[:, :-self.source_cache_len]

        wav_np = speech.cpu().numpy()
        if return_waveform:
            return wav_np

        # Clip to [-1, 1] to avoid overflow, then scale to int16
        wav_np = np.clip(wav_np, -1.0, 1.0)
        wav_int16 = (wav_np * 32767.0).astype('<i2')  # 16-bit little-endian PCM
        pcm_bytes = wav_int16.tobytes()
        return pcm_bytes
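Token2wav converts discrete speech tokens into 24 kHz audio via the flow-matching mel decoder and the HiFT vocoder, with an offline path (`__call__`, returning WAV bytes) and a chunked streaming path (`stream`, returning raw PCM). The sketch below is a hypothetical usage example, not part of the package: it assumes a CUDA device, a model directory containing the five files loaded in `__init__`, and placeholder token IDs. Note that, as written, `set_stream_cache` returns the flow and vocoder caches rather than storing them, so the caller assigns them before streaming.

# Hypothetical usage sketch (not part of the package). Assumes a CUDA device and a
# model directory with speech_tokenizer_v2_25hz.onnx, campplus.onnx, flow.yaml,
# flow.pt and hift.pt; the token IDs below are placeholders.
from stepaudio2.token2wav import Token2wav

t2w = Token2wav("/path/to/token2wav", float16=True, n_timesteps=10)

# Offline path: a full list of speech token IDs -> WAV container bytes.
wav_bytes = t2w([101, 2023, 4218], prompt_wav="prompt.wav")
with open("out.wav", "wb") as f:
    f.write(wav_bytes)

# Streaming path: set_stream_cache() returns the caches instead of storing them,
# so the caller assigns them before the first stream() call.
t2w.stream_cache, t2w.hift_cache_dict = t2w.set_stream_cache("prompt.wav")
token_chunks = [[101, 2023, 4218]] * 3  # placeholder 25 Hz token chunks
pcm = b""
for i, chunk in enumerate(token_chunks):
    pcm += t2w.stream(chunk, "prompt.wav", last_chunk=(i == len(token_chunks) - 1))
# pcm now holds 16-bit little-endian mono PCM at 24 kHz.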
stepaudio2/utils.py
ADDED
@@ -0,0 +1,91 @@

import librosa
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
import torchaudio
from typing import List


def _mel_filters(n_mels: int) -> torch.Tensor:
    """Load the mel filterbank matrix for projecting STFT into a Mel spectrogram."""
    assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
    if n_mels == 128:
        return torch.from_numpy(librosa.filters.mel(sr=16000, n_fft=400, n_mels=128))
    else:
        return torch.from_numpy(librosa.filters.mel(sr=16000, n_fft=400, n_mels=80))


def load_audio(file_path, target_rate=16000, max_length=None):
    """
    Open an audio file and read it as a mono waveform, resampling as necessary.
    If max_length is provided, truncate the audio to that length.
    """
    waveform, sample_rate = torchaudio.load(file_path)
    if sample_rate != target_rate:
        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_rate)(waveform)
    audio = waveform[0]  # get the first channel

    # Truncate audio if it exceeds max_length
    if max_length is not None and audio.shape[0] > max_length:
        audio = audio[:max_length]

    return audio


def log_mel_spectrogram(audio, n_mels=128, padding=479, device=None):
    """
    Compute the log-Mel spectrogram with specific padding for StepAudio.
    """
    if not torch.is_tensor(audio):
        if isinstance(audio, str):
            audio = load_audio(audio)
        audio = torch.from_numpy(audio)
    if device is not None:
        audio = audio.to(device)
    if padding > 0:
        audio = F.pad(audio, (0, padding))
    window = torch.hann_window(400).to(audio.device)
    stft = torch.stft(audio, 400, 160, window=window, return_complex=True)
    magnitudes = stft[..., :-1].abs() ** 2
    filters = _mel_filters(n_mels)
    mel_spec = filters @ magnitudes

    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
    log_spec = (log_spec + 4.0) / 4.0
    return log_spec


def compute_token_num(max_feature_len):
    # First, audio features go through the encoder:
    # 1. conv1: kernel=3, stride=1, padding=1 -> size unchanged
    # 2. conv2: kernel=3, stride=2, padding=1 -> size/2
    # 3. avg_pooler: kernel=2, stride=2 -> size/2
    max_feature_len = max_feature_len - 2  # remove padding
    encoder_output_dim = (max_feature_len + 1) // 2 // 2  # after conv2 and avg_pooler

    # Then through the adaptor (parameters from the config file):
    padding = 1
    kernel_size = 3  # from config: audio_encoder_config.kernel_size
    stride = 2  # from config: audio_encoder_config.adapter_stride
    adapter_output_dim = (encoder_output_dim + 2 * padding - kernel_size) // stride + 1
    return adapter_output_dim


def padding_mels(data: List[torch.Tensor]):
    """Pad the data into a batch.

    Parameters
    ----------
    data: List[Tensor], shape of Tensor (128, T)

    Returns
    -------
    feats, feats lengths
    """
    sample = data
    assert isinstance(sample, list)
    feats_lengths = torch.tensor([s.size(1) - 2 for s in sample],
                                 dtype=torch.int32)
    feats = [s.t() for s in sample]
    padded_feats = pad_sequence(feats,
                                batch_first=True,
                                padding_value=0)

    return padded_feats.transpose(1, 2), feats_lengths
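These helpers form a Whisper-style feature path: load_audio resamples to 16 kHz mono, log_mel_spectrogram pads 479 samples and computes a 128-bin log-mel with a 400-sample window and 160-sample hop (dropping the last STFT frame), and compute_token_num derives how many adaptor-level tokens the encoder produces. Below is a hypothetical walkthrough, not part of the package, tracing one second of audio through these functions.

# Hypothetical walkthrough (not part of the package): one second of 16 kHz audio.
import torch
from stepaudio2.utils import log_mel_spectrogram, padding_mels, compute_token_num

audio = torch.zeros(16000)               # stand-in for load_audio("prompt.wav")
mel = log_mel_spectrogram(audio)         # (128, 102): (16000 + 479) // 160 frames after dropping the last
feats, feats_lens = padding_mels([mel])  # feats: (1, 128, 102), feats_lens: tensor([100], dtype=int32)
n_tokens = compute_token_num(mel.shape[1])
# (102 - 2 + 1) // 2 // 2 = 25 encoder frames, then (25 + 2*1 - 3) // 2 + 1 = 13
print(n_tokens)  # 13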