minicpmo-utils 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cosyvoice/__init__.py +17 -0
- cosyvoice/bin/average_model.py +93 -0
- cosyvoice/bin/export_jit.py +103 -0
- cosyvoice/bin/export_onnx.py +120 -0
- cosyvoice/bin/inference_deprecated.py +126 -0
- cosyvoice/bin/train.py +195 -0
- cosyvoice/cli/__init__.py +0 -0
- cosyvoice/cli/cosyvoice.py +209 -0
- cosyvoice/cli/frontend.py +238 -0
- cosyvoice/cli/model.py +386 -0
- cosyvoice/dataset/__init__.py +0 -0
- cosyvoice/dataset/dataset.py +151 -0
- cosyvoice/dataset/processor.py +434 -0
- cosyvoice/flow/decoder.py +494 -0
- cosyvoice/flow/flow.py +281 -0
- cosyvoice/flow/flow_matching.py +227 -0
- cosyvoice/flow/length_regulator.py +70 -0
- cosyvoice/hifigan/discriminator.py +230 -0
- cosyvoice/hifigan/f0_predictor.py +58 -0
- cosyvoice/hifigan/generator.py +582 -0
- cosyvoice/hifigan/hifigan.py +67 -0
- cosyvoice/llm/llm.py +610 -0
- cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
- cosyvoice/tokenizer/tokenizer.py +279 -0
- cosyvoice/transformer/__init__.py +0 -0
- cosyvoice/transformer/activation.py +84 -0
- cosyvoice/transformer/attention.py +330 -0
- cosyvoice/transformer/convolution.py +145 -0
- cosyvoice/transformer/decoder.py +396 -0
- cosyvoice/transformer/decoder_layer.py +132 -0
- cosyvoice/transformer/embedding.py +302 -0
- cosyvoice/transformer/encoder.py +474 -0
- cosyvoice/transformer/encoder_layer.py +236 -0
- cosyvoice/transformer/label_smoothing_loss.py +96 -0
- cosyvoice/transformer/positionwise_feed_forward.py +115 -0
- cosyvoice/transformer/subsampling.py +383 -0
- cosyvoice/transformer/upsample_encoder.py +320 -0
- cosyvoice/utils/__init__.py +0 -0
- cosyvoice/utils/class_utils.py +83 -0
- cosyvoice/utils/common.py +186 -0
- cosyvoice/utils/executor.py +176 -0
- cosyvoice/utils/file_utils.py +129 -0
- cosyvoice/utils/frontend_utils.py +136 -0
- cosyvoice/utils/losses.py +57 -0
- cosyvoice/utils/mask.py +265 -0
- cosyvoice/utils/scheduler.py +738 -0
- cosyvoice/utils/train_utils.py +367 -0
- cosyvoice/vllm/cosyvoice2.py +103 -0
- matcha/__init__.py +0 -0
- matcha/app.py +357 -0
- matcha/cli.py +418 -0
- matcha/hifigan/__init__.py +0 -0
- matcha/hifigan/config.py +28 -0
- matcha/hifigan/denoiser.py +64 -0
- matcha/hifigan/env.py +17 -0
- matcha/hifigan/meldataset.py +217 -0
- matcha/hifigan/models.py +368 -0
- matcha/hifigan/xutils.py +60 -0
- matcha/models/__init__.py +0 -0
- matcha/models/baselightningmodule.py +209 -0
- matcha/models/components/__init__.py +0 -0
- matcha/models/components/decoder.py +443 -0
- matcha/models/components/flow_matching.py +132 -0
- matcha/models/components/text_encoder.py +410 -0
- matcha/models/components/transformer.py +316 -0
- matcha/models/matcha_tts.py +239 -0
- matcha/onnx/__init__.py +0 -0
- matcha/onnx/export.py +181 -0
- matcha/onnx/infer.py +168 -0
- matcha/text/__init__.py +53 -0
- matcha/text/cleaners.py +116 -0
- matcha/text/numbers.py +71 -0
- matcha/text/symbols.py +17 -0
- matcha/train.py +122 -0
- matcha/utils/__init__.py +5 -0
- matcha/utils/audio.py +82 -0
- matcha/utils/generate_data_statistics.py +111 -0
- matcha/utils/instantiators.py +56 -0
- matcha/utils/logging_utils.py +53 -0
- matcha/utils/model.py +90 -0
- matcha/utils/monotonic_align/__init__.py +22 -0
- matcha/utils/monotonic_align/setup.py +7 -0
- matcha/utils/pylogger.py +21 -0
- matcha/utils/rich_utils.py +101 -0
- matcha/utils/utils.py +219 -0
- minicpmo/__init__.py +24 -0
- minicpmo/utils.py +636 -0
- minicpmo/version.py +2 -0
- minicpmo_utils-0.1.0.dist-info/METADATA +72 -0
- minicpmo_utils-0.1.0.dist-info/RECORD +148 -0
- minicpmo_utils-0.1.0.dist-info/WHEEL +5 -0
- minicpmo_utils-0.1.0.dist-info/top_level.txt +5 -0
- s3tokenizer/__init__.py +153 -0
- s3tokenizer/assets/BAC009S0764W0121.wav +0 -0
- s3tokenizer/assets/BAC009S0764W0122.wav +0 -0
- s3tokenizer/assets/mel_filters.npz +0 -0
- s3tokenizer/cli.py +183 -0
- s3tokenizer/model.py +546 -0
- s3tokenizer/model_v2.py +605 -0
- s3tokenizer/utils.py +390 -0
- stepaudio2/__init__.py +40 -0
- stepaudio2/cosyvoice2/__init__.py +1 -0
- stepaudio2/cosyvoice2/flow/__init__.py +0 -0
- stepaudio2/cosyvoice2/flow/decoder_dit.py +585 -0
- stepaudio2/cosyvoice2/flow/flow.py +230 -0
- stepaudio2/cosyvoice2/flow/flow_matching.py +205 -0
- stepaudio2/cosyvoice2/transformer/__init__.py +0 -0
- stepaudio2/cosyvoice2/transformer/attention.py +328 -0
- stepaudio2/cosyvoice2/transformer/embedding.py +119 -0
- stepaudio2/cosyvoice2/transformer/encoder_layer.py +163 -0
- stepaudio2/cosyvoice2/transformer/positionwise_feed_forward.py +56 -0
- stepaudio2/cosyvoice2/transformer/subsampling.py +79 -0
- stepaudio2/cosyvoice2/transformer/upsample_encoder_v2.py +483 -0
- stepaudio2/cosyvoice2/utils/__init__.py +1 -0
- stepaudio2/cosyvoice2/utils/class_utils.py +41 -0
- stepaudio2/cosyvoice2/utils/common.py +101 -0
- stepaudio2/cosyvoice2/utils/mask.py +49 -0
- stepaudio2/flashcosyvoice/__init__.py +0 -0
- stepaudio2/flashcosyvoice/cli.py +424 -0
- stepaudio2/flashcosyvoice/config.py +80 -0
- stepaudio2/flashcosyvoice/cosyvoice2.py +160 -0
- stepaudio2/flashcosyvoice/cosyvoice3.py +1 -0
- stepaudio2/flashcosyvoice/engine/__init__.py +0 -0
- stepaudio2/flashcosyvoice/engine/block_manager.py +114 -0
- stepaudio2/flashcosyvoice/engine/llm_engine.py +125 -0
- stepaudio2/flashcosyvoice/engine/model_runner.py +310 -0
- stepaudio2/flashcosyvoice/engine/scheduler.py +77 -0
- stepaudio2/flashcosyvoice/engine/sequence.py +90 -0
- stepaudio2/flashcosyvoice/modules/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/flow.py +198 -0
- stepaudio2/flashcosyvoice/modules/flow_components/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/flow_components/estimator.py +974 -0
- stepaudio2/flashcosyvoice/modules/flow_components/upsample_encoder.py +998 -0
- stepaudio2/flashcosyvoice/modules/hifigan.py +249 -0
- stepaudio2/flashcosyvoice/modules/hifigan_components/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/hifigan_components/layers.py +433 -0
- stepaudio2/flashcosyvoice/modules/qwen2.py +92 -0
- stepaudio2/flashcosyvoice/modules/qwen2_components/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/qwen2_components/layers.py +616 -0
- stepaudio2/flashcosyvoice/modules/sampler.py +231 -0
- stepaudio2/flashcosyvoice/utils/__init__.py +0 -0
- stepaudio2/flashcosyvoice/utils/audio.py +77 -0
- stepaudio2/flashcosyvoice/utils/context.py +28 -0
- stepaudio2/flashcosyvoice/utils/loader.py +116 -0
- stepaudio2/flashcosyvoice/utils/memory.py +19 -0
- stepaudio2/stepaudio2.py +204 -0
- stepaudio2/token2wav.py +248 -0
- stepaudio2/utils.py +91 -0
cosyvoice/cli/cosyvoice.py
@@ -0,0 +1,209 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import time
+from typing import Generator
+from tqdm import tqdm
+from hyperpyyaml import load_hyperpyyaml
+from modelscope import snapshot_download
+import torch
+from cosyvoice.cli.frontend import CosyVoiceFrontEnd
+from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
+from cosyvoice.utils.file_utils import logging
+from cosyvoice.utils.class_utils import get_model_type
+
+
+class CosyVoice:
+
+    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1):
+        self.instruct = True if '-Instruct' in model_dir else False
+        self.model_dir = model_dir
+        self.fp16 = fp16
+        if not os.path.exists(model_dir):
+            model_dir = snapshot_download(model_dir)
+        hyper_yaml_path = '{}/cosyvoice.yaml'.format(model_dir)
+        if not os.path.exists(hyper_yaml_path):
+            raise ValueError('{} not found!'.format(hyper_yaml_path))
+        with open(hyper_yaml_path, 'r') as f:
+            configs = load_hyperpyyaml(f)
+        assert get_model_type(configs) != CosyVoice2Model, 'do not use {} for CosyVoice initialization!'.format(model_dir)
+        self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
+                                          configs['feat_extractor'],
+                                          '{}/campplus.onnx'.format(model_dir),
+                                          '{}/speech_tokenizer_v1.onnx'.format(model_dir),
+                                          '{}/spk2info.pt'.format(model_dir),
+                                          configs['allowed_special'])
+        self.sample_rate = configs['sample_rate']
+        if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
+            load_jit, load_trt, fp16 = False, False, False
+            logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
+        self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], fp16)
+        self.model.load('{}/llm.pt'.format(model_dir),
+                        '{}/flow.pt'.format(model_dir),
+                        '{}/hift.pt'.format(model_dir))
+        if load_jit:
+            self.model.load_jit('{}/llm.text_encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
+                                '{}/llm.llm.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
+                                '{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
+        if load_trt:
+            self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
+                                '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
+                                trt_concurrent,
+                                self.fp16)
+        del configs
+
+    def list_available_spks(self):
+        spks = list(self.frontend.spk2info.keys())
+        return spks
+
+    def add_zero_shot_spk(self, prompt_text, prompt_speech_16k, zero_shot_spk_id):
+        assert zero_shot_spk_id != '', 'do not use empty zero_shot_spk_id'
+        model_input = self.frontend.frontend_zero_shot('', prompt_text, prompt_speech_16k, self.sample_rate, '')
+        del model_input['text']
+        del model_input['text_len']
+        self.frontend.spk2info[zero_shot_spk_id] = model_input
+        return True
+
+    def save_spkinfo(self):
+        torch.save(self.frontend.spk2info, '{}/spk2info.pt'.format(self.model_dir))
+
+    def inference_sft(self, tts_text, spk_id, stream=False, speed=1.0, text_frontend=True):
+        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
+            model_input = self.frontend.frontend_sft(i, spk_id)
+            start_time = time.time()
+            logging.info('synthesis text {}'.format(i))
+            for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
+                speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
+                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+                yield model_output
+                start_time = time.time()
+
+    def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
+        prompt_text = self.frontend.text_normalize(prompt_text, split=False, text_frontend=text_frontend)
+        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
+            if (not isinstance(i, Generator)) and len(i) < 0.5 * len(prompt_text):
+                logging.warning('synthesis text {} too short than prompt text {}, this may lead to bad performance'.format(i, prompt_text))
+            model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k, self.sample_rate, zero_shot_spk_id)
+            start_time = time.time()
+            logging.info('synthesis text {}'.format(i))
+            for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
+                speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
+                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+                yield model_output
+                start_time = time.time()
+
+    def inference_cross_lingual(self, tts_text, prompt_speech_16k, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
+        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
+            model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k, self.sample_rate, zero_shot_spk_id)
+            start_time = time.time()
+            logging.info('synthesis text {}'.format(i))
+            for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
+                speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
+                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+                yield model_output
+                start_time = time.time()
+
+    def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False, speed=1.0, text_frontend=True):
+        assert isinstance(self.model, CosyVoiceModel), 'inference_instruct is only implemented for CosyVoice!'
+        if self.instruct is False:
+            raise ValueError('{} do not support instruct inference'.format(self.model_dir))
+        instruct_text = self.frontend.text_normalize(instruct_text, split=False, text_frontend=text_frontend)
+        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
+            model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
+            start_time = time.time()
+            logging.info('synthesis text {}'.format(i))
+            for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
+                speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
+                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+                yield model_output
+                start_time = time.time()
+
+    def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed=1.0):
+        model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k, self.sample_rate)
+        start_time = time.time()
+        for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
+            speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
+            logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+            yield model_output
+            start_time = time.time()
+
+    def token2wav(self, speech_token, speech_token_len, prompt_speech_16k, stream=False, speed=1.0):
+        model_input = self.frontend.frontend_token2wav(speech_token, speech_token_len, prompt_speech_16k, self.sample_rate)
+        start_time = time.time()
+        for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
+            speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
+            logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+            yield model_output
+            start_time = time.time()
+
+
+class CosyVoice2(CosyVoice):
+
+    def __init__(self, model_dir, config_path=None, load_jit=False, load_trt=False, load_vllm=False, fp16=False, trt_concurrent=1):
+        self.instruct = True if '-Instruct' in model_dir else False
+        self.model_dir = model_dir
+        self.fp16 = fp16
+        if not os.path.exists(model_dir):
+            model_dir = snapshot_download(model_dir)
+
+        if config_path is None:
+            config_path = f'{model_dir}/cosyvoice2.yaml'
+        hyper_yaml_path = config_path
+
+        print(f"config_path={config_path}")
+
+        if not os.path.exists(hyper_yaml_path):
+            raise ValueError('{} not found!'.format(hyper_yaml_path))
+        with open(hyper_yaml_path, 'r') as f:
+            configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
+        # assert get_model_type(configs) == CosyVoice2Model, 'do not use {} for CosyVoice2 initialization!'.format(model_dir)
+        self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
+                                          configs['feat_extractor'],
+                                          '{}/campplus.onnx'.format(model_dir),
+                                          '{}/speech_tokenizer_v2.onnx'.format(model_dir),
+                                          '{}/spk2info.pt'.format(model_dir),
+                                          configs['allowed_special'])
+        self.sample_rate = configs['sample_rate']
+        if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
+            load_jit, load_trt, fp16 = False, False, False
+            logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
+        self.model = CosyVoice2Model(None, configs['flow'], configs['hift'], fp16)
+        self.model.load('{}/llm.pt'.format(model_dir),
+                        '{}/flow.pt'.format(model_dir),
+                        '{}/hift.pt'.format(model_dir))
+        if load_vllm:
+            self.model.load_vllm('{}/vllm'.format(model_dir))
+        if load_jit:
+            self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
+        if load_trt:
+            self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
+                                '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
+                                trt_concurrent,
+                                self.fp16)
+        del configs
+
+    def inference_instruct(self, *args, **kwargs):
+        raise NotImplementedError('inference_instruct is not implemented for CosyVoice2!')
+
+    def inference_instruct2(self, tts_text, instruct_text, prompt_speech_16k, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
+        assert isinstance(self.model, CosyVoice2Model), 'inference_instruct2 is only implemented for CosyVoice2!'
+        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
+            model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_speech_16k, self.sample_rate, zero_shot_spk_id)
+            start_time = time.time()
+            logging.info('synthesis text {}'.format(i))
+            for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
+                speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
+                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+                yield model_output
+                start_time = time.time()
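For orientation, the `CosyVoice`/`CosyVoice2` classes added above are normally driven as a simple generator loop. The sketch below is a minimal, non-authoritative usage example: the model directory and prompt WAV paths are placeholders (not files shipped in this wheel), and `load_wav` is assumed to be available from `cosyvoice.utils.file_utils` as in upstream CosyVoice; only the constructor and `inference_zero_shot` generator shown in the diff are used.

```python
# Minimal usage sketch for the CosyVoice2 class added in cosyvoice/cli/cosyvoice.py.
# Paths are placeholders: point them at your own model directory and 16 kHz prompt audio.
import torchaudio

from cosyvoice.cli.cosyvoice import CosyVoice2
from cosyvoice.utils.file_utils import load_wav  # assumed helper, as in upstream CosyVoice

cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B',
                       load_jit=False, load_trt=False, fp16=False)
prompt_speech_16k = load_wav('./prompt.wav', 16000)  # reference voice, 16 kHz

# inference_zero_shot yields chunks; each chunk carries a 'tts_speech' tensor of shape [1, T].
for idx, chunk in enumerate(cosyvoice.inference_zero_shot(
        'Hello, this is a synthesis test.',          # text to speak
        'This is what the reference speaker said.',  # transcript of prompt.wav
        prompt_speech_16k, stream=False)):
    torchaudio.save(f'zero_shot_{idx}.wav', chunk['tts_speech'], cosyvoice.sample_rate)
```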
cosyvoice/cli/frontend.py
@@ -0,0 +1,238 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from functools import partial
+from typing import Generator
+import json
+import onnxruntime
+import torch
+import numpy as np
+import whisper
+from typing import Callable
+import torchaudio.compliance.kaldi as kaldi
+import torchaudio
+import os
+import re
+import inflect
+# try:
+#     import ttsfrd
+#     use_ttsfrd = True
+# except ImportError:
+#     print("failed to import ttsfrd, use WeTextProcessing instead")
+#     from tn.chinese.normalizer import Normalizer as ZhNormalizer
+#     from tn.english.normalizer import Normalizer as EnNormalizer
+#     use_ttsfrd = False
+from cosyvoice.utils.file_utils import logging
+from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation
+
+
+class CosyVoiceFrontEnd:
+
+    def __init__(self,
+                 get_tokenizer: Callable,
+                 feat_extractor: Callable,
+                 campplus_model: str,
+                 speech_tokenizer_model: str,
+                 spk2info: str = '',
+                 allowed_special: str = 'all'):
+        self.tokenizer = get_tokenizer()
+        self.feat_extractor = feat_extractor
+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        option = onnxruntime.SessionOptions()
+        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+        option.intra_op_num_threads = 1
+        self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
+        self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option,
+                                                                     providers=["CUDAExecutionProvider" if torch.cuda.is_available() else
+                                                                                "CPUExecutionProvider"])
+        if os.path.exists(spk2info):
+            self.spk2info = torch.load(spk2info, map_location=self.device)
+        else:
+            self.spk2info = {}
+        self.allowed_special = allowed_special
+        # self.use_ttsfrd = use_ttsfrd
+        # if self.use_ttsfrd:
+        #     self.frd = ttsfrd.TtsFrontendEngine()
+        #     ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+        #     assert self.frd.initialize('{}/../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)) is True, \
+        #         'failed to initialize ttsfrd resource'
+        #     self.frd.set_lang_type('pinyinvg')
+        # else:
+        #     self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False, overwrite_cache=True)
+        #     self.en_tn_model = EnNormalizer()
+        #     self.inflect_parser = inflect.engine()
+
+    def _extract_text_token(self, text):
+        if isinstance(text, Generator):
+            logging.info('get tts_text generator, will return _extract_text_token_generator!')
+            # NOTE add a dummy text_token_len for compatibility
+            return self._extract_text_token_generator(text), torch.tensor([0], dtype=torch.int32).to(self.device)
+        else:
+            text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
+            text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
+            text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
+            return text_token, text_token_len
+
+    def _extract_text_token_generator(self, text_generator):
+        for text in text_generator:
+            text_token, _ = self._extract_text_token(text)
+            for i in range(text_token.shape[1]):
+                yield text_token[:, i: i + 1]
+
+    def _extract_speech_token(self, speech):
+        assert speech.shape[1] / 16000 <= 30, 'do not support extract speech token for audio longer than 30s'
+        feat = whisper.log_mel_spectrogram(speech, n_mels=128)
+        speech_token = self.speech_tokenizer_session.run(None,
+                                                         {self.speech_tokenizer_session.get_inputs()[0].name:
+                                                          feat.detach().cpu().numpy(),
+                                                          self.speech_tokenizer_session.get_inputs()[1].name:
+                                                          np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
+        speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
+        speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
+        return speech_token, speech_token_len
+
+    def _extract_spk_embedding(self, speech):
+        feat = kaldi.fbank(speech,
+                           num_mel_bins=80,
+                           dither=0,
+                           sample_frequency=16000)
+        feat = feat - feat.mean(dim=0, keepdim=True)
+        embedding = self.campplus_session.run(None,
+                                              {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
+        embedding = torch.tensor([embedding]).to(self.device)
+        return embedding
+
+    def _extract_speech_feat(self, speech):
+        speech_feat = self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device)
+        speech_feat = speech_feat.unsqueeze(dim=0)
+        speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)
+        return speech_feat, speech_feat_len
+
+    def text_normalize(self, text, split=True, text_frontend=True):
+        if isinstance(text, Generator):
+            logging.info('get tts_text generator, will skip text_normalize!')
+            return [text]
+        if text_frontend is False or text == '':
+            return [text] if split is True else text
+        text = text.strip()
+        if self.use_ttsfrd:
+            texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
+            text = ''.join(texts)
+        else:
+            if contains_chinese(text):
+                text = self.zh_tn_model.normalize(text)
+                text = text.replace("\n", "")
+                text = replace_blank(text)
+                text = replace_corner_mark(text)
+                text = text.replace(".", "。")
+                text = text.replace(" - ", ",")
+                text = remove_bracket(text)
+                text = re.sub(r'[,,、]+$', '。', text)
+                texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
+                                             token_min_n=60, merge_len=20, comma_split=False))
+            else:
+                text = self.en_tn_model.normalize(text)
+                text = spell_out_number(text, self.inflect_parser)
+                texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
+                                             token_min_n=60, merge_len=20, comma_split=False))
+        texts = [i for i in texts if not is_only_punctuation(i)]
+        return texts if split is True else text
+
+    def frontend_sft(self, tts_text, spk_id):
+        tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
+        embedding = self.spk2info[spk_id]['embedding']
+        model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
+        return model_input
+
+    def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, resample_rate, zero_shot_spk_id):
+        tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
+        if zero_shot_spk_id == '':
+            prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
+            prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
+            speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
+            speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
+            if resample_rate == 24000:
+                # cosyvoice2, force speech_feat % speech_token = 2
+                token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
+                speech_feat, speech_feat_len[:] = speech_feat[:, :2 * token_len], 2 * token_len
+                speech_token, speech_token_len[:] = speech_token[:, :token_len], token_len
+            embedding = self._extract_spk_embedding(prompt_speech_16k)
+            model_input = {'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
+                           'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
+                           'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
+                           'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
+                           'llm_embedding': embedding, 'flow_embedding': embedding}
+        else:
+            model_input = self.spk2info[zero_shot_spk_id]
+        model_input['text'] = tts_text_token
+        model_input['text_len'] = tts_text_token_len
+        return model_input
+
+    def frontend_cross_lingual(self, tts_text, prompt_speech_16k, resample_rate, zero_shot_spk_id):
+        model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k, resample_rate, zero_shot_spk_id)
+        # in cross lingual mode, we remove prompt in llm
+        del model_input['prompt_text']
+        del model_input['prompt_text_len']
+        del model_input['llm_prompt_speech_token']
+        del model_input['llm_prompt_speech_token_len']
+        return model_input
+
+    def frontend_instruct(self, tts_text, spk_id, instruct_text):
+        model_input = self.frontend_sft(tts_text, spk_id)
+        # in instruct mode, we remove spk_embedding in llm due to information leakage
+        del model_input['llm_embedding']
+        instruct_text_token, instruct_text_token_len = self._extract_text_token(instruct_text + '<endofprompt>')
+        model_input['prompt_text'] = instruct_text_token
+        model_input['prompt_text_len'] = instruct_text_token_len
+        return model_input
+
+    def frontend_instruct2(self, tts_text, instruct_text, prompt_speech_16k, resample_rate, zero_shot_spk_id):
+        model_input = self.frontend_zero_shot(tts_text, instruct_text + '<|endofprompt|>', prompt_speech_16k, resample_rate, zero_shot_spk_id)
+        del model_input['llm_prompt_speech_token']
+        del model_input['llm_prompt_speech_token_len']
+        return model_input
+
+    def frontend_vc(self, source_speech_16k, prompt_speech_16k, resample_rate):
+        prompt_speech_token, prompt_speech_token_len = self._extract_speech_token(prompt_speech_16k)
+        prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
+        prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
+        embedding = self._extract_spk_embedding(prompt_speech_16k)
+        source_speech_token, source_speech_token_len = self._extract_speech_token(source_speech_16k)
+        model_input = {'source_speech_token': source_speech_token, 'source_speech_token_len': source_speech_token_len,
+                       'flow_prompt_speech_token': prompt_speech_token, 'flow_prompt_speech_token_len': prompt_speech_token_len,
+                       'prompt_speech_feat': prompt_speech_feat, 'prompt_speech_feat_len': prompt_speech_feat_len,
+                       'flow_embedding': embedding}
+        return model_input
+
+    def frontend_token2wav(self, speech_token, speech_token_len, prompt_speech_16k, resample_rate):
+        # hacked by xbk May 20th, 2025
+        prompt_speech_token, prompt_speech_token_len = self._extract_speech_token(prompt_speech_16k)
+        prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
+        prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
+        embedding = self._extract_spk_embedding(prompt_speech_16k)
+
+        # source_speech_token, source_speech_token_len = self._extract_speech_token(source_speech_16k)
+        # from IPython import embed; embed()
+        # In [2]: source_speech_token.shape
+        # Out[2]: torch.Size([1, 122])
+        # In [4]: source_speech_token_len
+        # Out[4]: tensor([122], device='cuda:0', dtype=torch.int32)
+
+        source_speech_token = speech_token
+        source_speech_token_len = speech_token_len
+
+        model_input = {'source_speech_token': source_speech_token, 'source_speech_token_len': source_speech_token_len,
+                       'flow_prompt_speech_token': prompt_speech_token, 'flow_prompt_speech_token_len': prompt_speech_token_len,
+                       'prompt_speech_feat': prompt_speech_feat, 'prompt_speech_feat_len': prompt_speech_feat_len,
+                       'flow_embedding': embedding}
+        return model_input
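The frontend caches prompt-derived features in `spk2info`: `add_zero_shot_spk` / `save_spkinfo` in cosyvoice/cli/cosyvoice.py write into it, and `frontend_zero_shot` reads it back whenever a non-empty `zero_shot_spk_id` is passed. The sketch below illustrates that round trip under the same assumptions as the previous example (placeholder paths, `load_wav` assumed from `cosyvoice.utils.file_utils`); it is an illustration, not an official recipe from this package.

```python
# Register a reference speaker once, persist it, then reuse it by id without re-sending the prompt.
import torchaudio

from cosyvoice.cli.cosyvoice import CosyVoice2
from cosyvoice.utils.file_utils import load_wav  # assumed helper, as in upstream CosyVoice

cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B')
prompt_speech_16k = load_wav('./prompt.wav', 16000)

# Extracts speech tokens, mel features and the speaker embedding via frontend_zero_shot('', ...)
# and stores them under the chosen id in frontend.spk2info.
cosyvoice.add_zero_shot_spk('This is what the reference speaker said.', prompt_speech_16k, 'my_speaker')
cosyvoice.save_spkinfo()  # persists {model_dir}/spk2info.pt

# Later calls skip prompt processing: frontend_zero_shot looks up the cached entry by id,
# so the prompt text and audio arguments can be left empty.
for idx, chunk in enumerate(cosyvoice.inference_zero_shot(
        'Hello again.', '', '', zero_shot_spk_id='my_speaker', stream=False)):
    torchaudio.save(f'cached_spk_{idx}.wav', chunk['tts_speech'], cosyvoice.sample_rate)
```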