minicpmo_utils-0.1.0-py3-none-any.whl
This diff shows the contents of publicly released package versions as published to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in the public registry.
- cosyvoice/__init__.py +17 -0
- cosyvoice/bin/average_model.py +93 -0
- cosyvoice/bin/export_jit.py +103 -0
- cosyvoice/bin/export_onnx.py +120 -0
- cosyvoice/bin/inference_deprecated.py +126 -0
- cosyvoice/bin/train.py +195 -0
- cosyvoice/cli/__init__.py +0 -0
- cosyvoice/cli/cosyvoice.py +209 -0
- cosyvoice/cli/frontend.py +238 -0
- cosyvoice/cli/model.py +386 -0
- cosyvoice/dataset/__init__.py +0 -0
- cosyvoice/dataset/dataset.py +151 -0
- cosyvoice/dataset/processor.py +434 -0
- cosyvoice/flow/decoder.py +494 -0
- cosyvoice/flow/flow.py +281 -0
- cosyvoice/flow/flow_matching.py +227 -0
- cosyvoice/flow/length_regulator.py +70 -0
- cosyvoice/hifigan/discriminator.py +230 -0
- cosyvoice/hifigan/f0_predictor.py +58 -0
- cosyvoice/hifigan/generator.py +582 -0
- cosyvoice/hifigan/hifigan.py +67 -0
- cosyvoice/llm/llm.py +610 -0
- cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
- cosyvoice/tokenizer/tokenizer.py +279 -0
- cosyvoice/transformer/__init__.py +0 -0
- cosyvoice/transformer/activation.py +84 -0
- cosyvoice/transformer/attention.py +330 -0
- cosyvoice/transformer/convolution.py +145 -0
- cosyvoice/transformer/decoder.py +396 -0
- cosyvoice/transformer/decoder_layer.py +132 -0
- cosyvoice/transformer/embedding.py +302 -0
- cosyvoice/transformer/encoder.py +474 -0
- cosyvoice/transformer/encoder_layer.py +236 -0
- cosyvoice/transformer/label_smoothing_loss.py +96 -0
- cosyvoice/transformer/positionwise_feed_forward.py +115 -0
- cosyvoice/transformer/subsampling.py +383 -0
- cosyvoice/transformer/upsample_encoder.py +320 -0
- cosyvoice/utils/__init__.py +0 -0
- cosyvoice/utils/class_utils.py +83 -0
- cosyvoice/utils/common.py +186 -0
- cosyvoice/utils/executor.py +176 -0
- cosyvoice/utils/file_utils.py +129 -0
- cosyvoice/utils/frontend_utils.py +136 -0
- cosyvoice/utils/losses.py +57 -0
- cosyvoice/utils/mask.py +265 -0
- cosyvoice/utils/scheduler.py +738 -0
- cosyvoice/utils/train_utils.py +367 -0
- cosyvoice/vllm/cosyvoice2.py +103 -0
- matcha/__init__.py +0 -0
- matcha/app.py +357 -0
- matcha/cli.py +418 -0
- matcha/hifigan/__init__.py +0 -0
- matcha/hifigan/config.py +28 -0
- matcha/hifigan/denoiser.py +64 -0
- matcha/hifigan/env.py +17 -0
- matcha/hifigan/meldataset.py +217 -0
- matcha/hifigan/models.py +368 -0
- matcha/hifigan/xutils.py +60 -0
- matcha/models/__init__.py +0 -0
- matcha/models/baselightningmodule.py +209 -0
- matcha/models/components/__init__.py +0 -0
- matcha/models/components/decoder.py +443 -0
- matcha/models/components/flow_matching.py +132 -0
- matcha/models/components/text_encoder.py +410 -0
- matcha/models/components/transformer.py +316 -0
- matcha/models/matcha_tts.py +239 -0
- matcha/onnx/__init__.py +0 -0
- matcha/onnx/export.py +181 -0
- matcha/onnx/infer.py +168 -0
- matcha/text/__init__.py +53 -0
- matcha/text/cleaners.py +116 -0
- matcha/text/numbers.py +71 -0
- matcha/text/symbols.py +17 -0
- matcha/train.py +122 -0
- matcha/utils/__init__.py +5 -0
- matcha/utils/audio.py +82 -0
- matcha/utils/generate_data_statistics.py +111 -0
- matcha/utils/instantiators.py +56 -0
- matcha/utils/logging_utils.py +53 -0
- matcha/utils/model.py +90 -0
- matcha/utils/monotonic_align/__init__.py +22 -0
- matcha/utils/monotonic_align/setup.py +7 -0
- matcha/utils/pylogger.py +21 -0
- matcha/utils/rich_utils.py +101 -0
- matcha/utils/utils.py +219 -0
- minicpmo/__init__.py +24 -0
- minicpmo/utils.py +636 -0
- minicpmo/version.py +2 -0
- minicpmo_utils-0.1.0.dist-info/METADATA +72 -0
- minicpmo_utils-0.1.0.dist-info/RECORD +148 -0
- minicpmo_utils-0.1.0.dist-info/WHEEL +5 -0
- minicpmo_utils-0.1.0.dist-info/top_level.txt +5 -0
- s3tokenizer/__init__.py +153 -0
- s3tokenizer/assets/BAC009S0764W0121.wav +0 -0
- s3tokenizer/assets/BAC009S0764W0122.wav +0 -0
- s3tokenizer/assets/mel_filters.npz +0 -0
- s3tokenizer/cli.py +183 -0
- s3tokenizer/model.py +546 -0
- s3tokenizer/model_v2.py +605 -0
- s3tokenizer/utils.py +390 -0
- stepaudio2/__init__.py +40 -0
- stepaudio2/cosyvoice2/__init__.py +1 -0
- stepaudio2/cosyvoice2/flow/__init__.py +0 -0
- stepaudio2/cosyvoice2/flow/decoder_dit.py +585 -0
- stepaudio2/cosyvoice2/flow/flow.py +230 -0
- stepaudio2/cosyvoice2/flow/flow_matching.py +205 -0
- stepaudio2/cosyvoice2/transformer/__init__.py +0 -0
- stepaudio2/cosyvoice2/transformer/attention.py +328 -0
- stepaudio2/cosyvoice2/transformer/embedding.py +119 -0
- stepaudio2/cosyvoice2/transformer/encoder_layer.py +163 -0
- stepaudio2/cosyvoice2/transformer/positionwise_feed_forward.py +56 -0
- stepaudio2/cosyvoice2/transformer/subsampling.py +79 -0
- stepaudio2/cosyvoice2/transformer/upsample_encoder_v2.py +483 -0
- stepaudio2/cosyvoice2/utils/__init__.py +1 -0
- stepaudio2/cosyvoice2/utils/class_utils.py +41 -0
- stepaudio2/cosyvoice2/utils/common.py +101 -0
- stepaudio2/cosyvoice2/utils/mask.py +49 -0
- stepaudio2/flashcosyvoice/__init__.py +0 -0
- stepaudio2/flashcosyvoice/cli.py +424 -0
- stepaudio2/flashcosyvoice/config.py +80 -0
- stepaudio2/flashcosyvoice/cosyvoice2.py +160 -0
- stepaudio2/flashcosyvoice/cosyvoice3.py +1 -0
- stepaudio2/flashcosyvoice/engine/__init__.py +0 -0
- stepaudio2/flashcosyvoice/engine/block_manager.py +114 -0
- stepaudio2/flashcosyvoice/engine/llm_engine.py +125 -0
- stepaudio2/flashcosyvoice/engine/model_runner.py +310 -0
- stepaudio2/flashcosyvoice/engine/scheduler.py +77 -0
- stepaudio2/flashcosyvoice/engine/sequence.py +90 -0
- stepaudio2/flashcosyvoice/modules/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/flow.py +198 -0
- stepaudio2/flashcosyvoice/modules/flow_components/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/flow_components/estimator.py +974 -0
- stepaudio2/flashcosyvoice/modules/flow_components/upsample_encoder.py +998 -0
- stepaudio2/flashcosyvoice/modules/hifigan.py +249 -0
- stepaudio2/flashcosyvoice/modules/hifigan_components/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/hifigan_components/layers.py +433 -0
- stepaudio2/flashcosyvoice/modules/qwen2.py +92 -0
- stepaudio2/flashcosyvoice/modules/qwen2_components/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/qwen2_components/layers.py +616 -0
- stepaudio2/flashcosyvoice/modules/sampler.py +231 -0
- stepaudio2/flashcosyvoice/utils/__init__.py +0 -0
- stepaudio2/flashcosyvoice/utils/audio.py +77 -0
- stepaudio2/flashcosyvoice/utils/context.py +28 -0
- stepaudio2/flashcosyvoice/utils/loader.py +116 -0
- stepaudio2/flashcosyvoice/utils/memory.py +19 -0
- stepaudio2/stepaudio2.py +204 -0
- stepaudio2/token2wav.py +248 -0
- stepaudio2/utils.py +91 -0
cosyvoice/__init__.py
ADDED
@@ -0,0 +1,17 @@
+"""
+CosyVoice: Text-to-Speech with Large Language Model
+"""
+
+__version__ = "0.1.0"
+
+# Lazy import to avoid requiring all dependencies at package import time
+def __getattr__(name):
+    if name in ('CosyVoice', 'CosyVoice2'):
+        from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
+        if name == 'CosyVoice':
+            return CosyVoice
+        elif name == 'CosyVoice2':
+            return CosyVoice2
+    raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
+
+__all__ = ['CosyVoice', 'CosyVoice2']
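The module-level __getattr__ hook above (PEP 562 lazy import) keeps "import cosyvoice" cheap: the heavy cosyvoice.cli dependencies are only loaded when CosyVoice or CosyVoice2 is first accessed. A minimal usage sketch, assuming a locally downloaded model directory (the path is illustrative, not part of the wheel):

    import cosyvoice

    print(cosyvoice.__version__)  # "0.1.0" -- resolved without touching cosyvoice.cli

    # First access to the attribute triggers __getattr__, which imports cosyvoice.cli.cosyvoice
    tts = cosyvoice.CosyVoice('pretrained_models/CosyVoice-300M')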

cosyvoice/bin/average_model.py
ADDED
@@ -0,0 +1,93 @@
+# Copyright (c) 2020 Mobvoi Inc (Di Wu)
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import argparse
+import glob
+
+import yaml
+import torch
+
+
+def get_args():
+    parser = argparse.ArgumentParser(description='average model')
+    parser.add_argument('--dst_model', required=True, help='averaged model')
+    parser.add_argument('--src_path',
+                        required=True,
+                        help='src model path for average')
+    parser.add_argument('--val_best',
+                        action="store_true",
+                        help='averaged model')
+    parser.add_argument('--num',
+                        default=5,
+                        type=int,
+                        help='nums for averaged model')
+
+    args = parser.parse_args()
+    print(args)
+    return args
+
+
+def main():
+    args = get_args()
+    val_scores = []
+    if args.val_best:
+        yamls = glob.glob('{}/*.yaml'.format(args.src_path))
+        yamls = [
+            f for f in yamls
+            if not (os.path.basename(f).startswith('train')
+                    or os.path.basename(f).startswith('init'))
+        ]
+        for y in yamls:
+            with open(y, 'r') as f:
+                dic_yaml = yaml.load(f, Loader=yaml.BaseLoader)
+                loss = float(dic_yaml['loss_dict']['loss'])
+                epoch = int(dic_yaml['epoch'])
+                step = int(dic_yaml['step'])
+                tag = dic_yaml['tag']
+                val_scores += [[epoch, step, loss, tag]]
+        sorted_val_scores = sorted(val_scores,
+                                   key=lambda x: x[2],
+                                   reverse=False)
+        print("best val (epoch, step, loss, tag) = " +
+              str(sorted_val_scores[:args.num]))
+        path_list = [
+            args.src_path + '/epoch_{}_whole.pt'.format(score[0])
+            for score in sorted_val_scores[:args.num]
+        ]
+        print(path_list)
+    avg = {}
+    num = args.num
+    assert num == len(path_list)
+    for path in path_list:
+        print('Processing {}'.format(path))
+        states = torch.load(path, map_location=torch.device('cpu'))
+        for k in states.keys():
+            if k not in ['step', 'epoch']:
+                if k not in avg.keys():
+                    avg[k] = states[k].clone()
+                else:
+                    avg[k] += states[k]
+    # average
+    for k in avg.keys():
+        if avg[k] is not None:
+            # pytorch 1.6 use true_divide instead of /=
+            avg[k] = torch.true_divide(avg[k], num)
+    print('Saving to {}'.format(args.dst_model))
+    torch.save(avg, args.dst_model)
+
+
+if __name__ == '__main__':
+    main()
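average_model.py scans the per-checkpoint yaml logs under --src_path, sorts them by validation loss in ascending order, and element-wise averages the state dicts of the --num lowest-loss epoch_{n}_whole.pt files into --dst_model (note that path_list is only built inside the --val_best branch, so the flag is effectively required). The same averaging rule, restated as a small standalone helper for clarity (a sketch, not part of the package):

    import torch

    def average_state_dicts(paths):
        # Skip bookkeeping keys, sum the tensors, then true_divide by the number of checkpoints.
        avg = {}
        for p in paths:
            states = torch.load(p, map_location='cpu')
            for k, v in states.items():
                if k in ('step', 'epoch'):
                    continue
                avg[k] = v.clone() if k not in avg else avg[k] + v
        return {k: torch.true_divide(v, len(paths)) for k, v in avg.items()}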

cosyvoice/bin/export_jit.py
ADDED
@@ -0,0 +1,103 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import argparse
+import logging
+logging.getLogger('matplotlib').setLevel(logging.WARNING)
+import os
+import sys
+import torch
+ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+sys.path.append('{}/../..'.format(ROOT_DIR))
+sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
+from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
+from cosyvoice.utils.file_utils import logging
+
+
+def get_args():
+    parser = argparse.ArgumentParser(description='export your model for deployment')
+    parser.add_argument('--model_dir',
+                        type=str,
+                        default='pretrained_models/CosyVoice-300M',
+                        help='local path')
+    args = parser.parse_args()
+    print(args)
+    return args
+
+
+def get_optimized_script(model, preserved_attrs=[]):
+    script = torch.jit.script(model)
+    if preserved_attrs != []:
+        script = torch.jit.freeze(script, preserved_attrs=preserved_attrs)
+    else:
+        script = torch.jit.freeze(script)
+    script = torch.jit.optimize_for_inference(script)
+    return script
+
+
+def main():
+    args = get_args()
+    logging.basicConfig(level=logging.DEBUG,
+                        format='%(asctime)s %(levelname)s %(message)s')
+
+    torch._C._jit_set_fusion_strategy([('STATIC', 1)])
+    torch._C._jit_set_profiling_mode(False)
+    torch._C._jit_set_profiling_executor(False)
+
+    try:
+        model = CosyVoice(args.model_dir)
+    except Exception:
+        try:
+            model = CosyVoice2(args.model_dir)
+        except Exception:
+            raise TypeError('no valid model_type!')
+
+    if not isinstance(model, CosyVoice2):
+        # 1. export llm text_encoder
+        llm_text_encoder = model.model.llm.text_encoder
+        script = get_optimized_script(llm_text_encoder)
+        script.save('{}/llm.text_encoder.fp32.zip'.format(args.model_dir))
+        script = get_optimized_script(llm_text_encoder.half())
+        script.save('{}/llm.text_encoder.fp16.zip'.format(args.model_dir))
+        logging.info('successfully export llm_text_encoder')
+
+        # 2. export llm llm
+        llm_llm = model.model.llm.llm
+        script = get_optimized_script(llm_llm, ['forward_chunk'])
+        script.save('{}/llm.llm.fp32.zip'.format(args.model_dir))
+        script = get_optimized_script(llm_llm.half(), ['forward_chunk'])
+        script.save('{}/llm.llm.fp16.zip'.format(args.model_dir))
+        logging.info('successfully export llm_llm')
+
+        # 3. export flow encoder
+        flow_encoder = model.model.flow.encoder
+        script = get_optimized_script(flow_encoder)
+        script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
+        script = get_optimized_script(flow_encoder.half())
+        script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
+        logging.info('successfully export flow_encoder')
+    else:
+        # 3. export flow encoder
+        flow_encoder = model.model.flow.encoder
+        script = get_optimized_script(flow_encoder)
+        script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
+        script = get_optimized_script(flow_encoder.half())
+        script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
+        logging.info('successfully export flow_encoder')
+
+
+if __name__ == '__main__':
+    main()
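get_optimized_script() applies the usual TorchScript pipeline: torch.jit.script, then torch.jit.freeze (preserving attributes such as forward_chunk that must stay callable on the frozen module), then torch.jit.optimize_for_inference. The exported *.zip artifacts can later be loaded without the Python class definitions; a minimal sketch, with an illustrative path based on the script's defaults:

    import torch

    # Load a frozen, inference-optimized module produced by export_jit.py.
    flow_encoder = torch.jit.load('pretrained_models/CosyVoice-300M/flow.encoder.fp32.zip',
                                  map_location='cpu')
    flow_encoder.eval()
    # Inputs must match the original encoder's forward signature (not reproduced here).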

cosyvoice/bin/export_onnx.py
ADDED
@@ -0,0 +1,120 @@
+# Copyright (c) 2024 Antgroup Inc (authors: Zhoubofan, hexisyztem@icloud.com)
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import argparse
+import logging
+logging.getLogger('matplotlib').setLevel(logging.WARNING)
+import os
+import sys
+import onnxruntime
+import random
+import torch
+from tqdm import tqdm
+ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+sys.path.append('{}/../..'.format(ROOT_DIR))
+sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
+from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
+from cosyvoice.utils.file_utils import logging
+
+
+def get_dummy_input(batch_size, seq_len, out_channels, device):
+    x = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
+    mask = torch.ones((batch_size, 1, seq_len), dtype=torch.float32, device=device)
+    mu = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
+    t = torch.rand((batch_size), dtype=torch.float32, device=device)
+    spks = torch.rand((batch_size, out_channels), dtype=torch.float32, device=device)
+    cond = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
+    return x, mask, mu, t, spks, cond
+
+
+def get_args():
+    parser = argparse.ArgumentParser(description='export your model for deployment')
+    parser.add_argument('--model_dir',
+                        type=str,
+                        default='pretrained_models/CosyVoice-300M',
+                        help='local path')
+    args = parser.parse_args()
+    print(args)
+    return args
+
+
+@torch.no_grad()
+def main():
+    args = get_args()
+    logging.basicConfig(level=logging.DEBUG,
+                        format='%(asctime)s %(levelname)s %(message)s')
+
+    try:
+        model = CosyVoice(args.model_dir)
+    except Exception:
+        try:
+            model = CosyVoice2(args.model_dir)
+        except Exception:
+            raise TypeError('no valid model_type!')
+
+    # 1. export flow decoder estimator
+    estimator = model.model.flow.decoder.estimator
+    estimator.eval()
+
+    device = model.model.device
+    batch_size, seq_len = 2, 256
+    out_channels = model.model.flow.decoder.estimator.out_channels
+    x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
+    torch.onnx.export(
+        estimator,
+        (x, mask, mu, t, spks, cond),
+        '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
+        export_params=True,
+        opset_version=18,
+        do_constant_folding=True,
+        input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
+        output_names=['estimator_out'],
+        dynamic_axes={
+            'x': {2: 'seq_len'},
+            'mask': {2: 'seq_len'},
+            'mu': {2: 'seq_len'},
+            'cond': {2: 'seq_len'},
+            'estimator_out': {2: 'seq_len'},
+        }
+    )
+
+    # 2. test computation consistency
+    option = onnxruntime.SessionOptions()
+    option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+    option.intra_op_num_threads = 1
+    providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
+    estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
+                                                  sess_options=option, providers=providers)
+
+    for _ in tqdm(range(10)):
+        x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device)
+        output_pytorch = estimator(x, mask, mu, t, spks, cond)
+        ort_inputs = {
+            'x': x.cpu().numpy(),
+            'mask': mask.cpu().numpy(),
+            'mu': mu.cpu().numpy(),
+            't': t.cpu().numpy(),
+            'spks': spks.cpu().numpy(),
+            'cond': cond.cpu().numpy()
+        }
+        output_onnx = estimator_onnx.run(None, ort_inputs)[0]
+        torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)
+    logging.info('successfully export estimator')
+
+
+if __name__ == "__main__":
+    main()
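The ONNX export covers only the flow decoder's estimator, with seq_len declared dynamic so one graph serves variable-length features, and the script then checks parity against PyTorch (torch.testing.assert_allclose is deprecated in newer PyTorch releases in favor of torch.testing.assert_close). Running the exported graph standalone looks roughly like the sketch below; the path is illustrative and the channel count is a placeholder that must match the checkpoint's out_channels:

    import numpy as np
    import onnxruntime

    sess = onnxruntime.InferenceSession(
        'pretrained_models/CosyVoice-300M/flow.decoder.estimator.fp32.onnx',
        providers=['CPUExecutionProvider'])

    batch, channels, seq_len = 2, 80, 128  # channels: placeholder for the model's out_channels
    feeds = {
        'x': np.random.rand(batch, channels, seq_len).astype(np.float32),
        'mask': np.ones((batch, 1, seq_len), dtype=np.float32),
        'mu': np.random.rand(batch, channels, seq_len).astype(np.float32),
        't': np.random.rand(batch).astype(np.float32),
        'spks': np.random.rand(batch, channels).astype(np.float32),
        'cond': np.random.rand(batch, channels, seq_len).astype(np.float32),
    }
    out = sess.run(['estimator_out'], feeds)[0]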

cosyvoice/bin/inference_deprecated.py
ADDED
@@ -0,0 +1,126 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import argparse
+import logging
+logging.getLogger('matplotlib').setLevel(logging.WARNING)
+import os
+import torch
+from torch.utils.data import DataLoader
+import torchaudio
+from hyperpyyaml import load_hyperpyyaml
+from tqdm import tqdm
+from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
+from cosyvoice.dataset.dataset import Dataset
+
+
+def get_args():
+    parser = argparse.ArgumentParser(description='inference with your model')
+    parser.add_argument('--config', required=True, help='config file')
+    parser.add_argument('--prompt_data', required=True, help='prompt data file')
+    parser.add_argument('--prompt_utt2data', required=True, help='prompt data file')
+    parser.add_argument('--tts_text', required=True, help='tts input file')
+    parser.add_argument('--qwen_pretrain_path', required=False, help='qwen pretrain path')
+    parser.add_argument('--llm_model', required=True, help='llm model file')
+    parser.add_argument('--flow_model', required=True, help='flow model file')
+    parser.add_argument('--hifigan_model', required=True, help='hifigan model file')
+    parser.add_argument('--gpu',
+                        type=int,
+                        default=-1,
+                        help='gpu id for this rank, -1 for cpu')
+    parser.add_argument('--mode',
+                        default='sft',
+                        choices=['sft', 'zero_shot'],
+                        help='inference mode')
+    parser.add_argument('--result_dir', required=True, help='asr result file')
+    args = parser.parse_args()
+    print(args)
+    return args
+
+
+def main():
+    args = get_args()
+    logging.basicConfig(level=logging.DEBUG,
+                        format='%(asctime)s %(levelname)s %(message)s')
+    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
+
+    # Init cosyvoice models from configs
+    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
+    device = torch.device('cuda' if use_cuda else 'cpu')
+    try:
+        with open(args.config, 'r') as f:
+            configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': args.qwen_pretrain_path})
+        model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'])
+    except Exception:
+        try:
+            with open(args.config, 'r') as f:
+                configs = load_hyperpyyaml(f)
+            model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
+        except Exception:
+            raise TypeError('no valid model_type!')
+
+    model.load(args.llm_model, args.flow_model, args.hifigan_model)
+
+    test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False,
+                           tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
+    test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
+
+    sample_rate = configs['sample_rate']
+    del configs
+    os.makedirs(args.result_dir, exist_ok=True)
+    fn = os.path.join(args.result_dir, 'wav.scp')
+    f = open(fn, 'w')
+    with torch.no_grad():
+        for _, batch in tqdm(enumerate(test_data_loader)):
+            utts = batch["utts"]
+            assert len(utts) == 1, "inference mode only support batchsize 1"
+            text_token = batch["text_token"].to(device)
+            text_token_len = batch["text_token_len"].to(device)
+            tts_index = batch["tts_index"]
+            tts_text_token = batch["tts_text_token"].to(device)
+            tts_text_token_len = batch["tts_text_token_len"].to(device)
+            speech_token = batch["speech_token"].to(device)
+            speech_token_len = batch["speech_token_len"].to(device)
+            speech_feat = batch["speech_feat"].to(device)
+            speech_feat_len = batch["speech_feat_len"].to(device)
+            utt_embedding = batch["utt_embedding"].to(device)
+            spk_embedding = batch["spk_embedding"].to(device)
+            if args.mode == 'sft':
+                model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
+                               'llm_embedding': spk_embedding, 'flow_embedding': spk_embedding}
+            else:
+                model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
+                               'prompt_text': text_token, 'prompt_text_len': text_token_len,
+                               'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
+                               'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
+                               'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
+                               'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
+            tts_speeches = []
+            for model_output in model.tts(**model_input):
+                tts_speeches.append(model_output['tts_speech'])
+            tts_speeches = torch.concat(tts_speeches, dim=1)
+            tts_key = '{}_{}'.format(utts[0], tts_index[0])
+            tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
+            torchaudio.save(tts_fn, tts_speeches, sample_rate=sample_rate, backend='soundfile')
+            f.write('{} {}\n'.format(tts_key, tts_fn))
+            f.flush()
+    f.close()
+    logging.info('Result wav.scp saved in {}'.format(fn))
+
+
+if __name__ == '__main__':
+    logging.warning('this code has been deprecated, please refer to README for CosyVoice inference usage!')
+    main()

cosyvoice/bin/train.py
ADDED
@@ -0,0 +1,195 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+import argparse
+import datetime
+import logging
+logging.getLogger('matplotlib').setLevel(logging.WARNING)
+from copy import deepcopy
+import os
+import torch
+import torch.distributed as dist
+import deepspeed
+
+from hyperpyyaml import load_hyperpyyaml
+
+from torch.distributed.elastic.multiprocessing.errors import record
+
+from cosyvoice.utils.losses import DPOLoss
+from cosyvoice.utils.executor import Executor
+from cosyvoice.utils.train_utils import (
+    init_distributed,
+    init_dataset_and_dataloader,
+    init_optimizer_and_scheduler,
+    init_summarywriter, save_model,
+    wrap_cuda_model, check_modify_and_save_config)
+
+
+def get_args():
+    parser = argparse.ArgumentParser(description='training your network')
+    parser.add_argument('--train_engine',
+                        default='torch_ddp',
+                        choices=['torch_ddp', 'deepspeed'],
+                        help='Engine for paralleled training')
+    parser.add_argument('--model', required=True, help='model which will be trained')
+    parser.add_argument('--ref_model', required=False, help='ref model used in dpo')
+    parser.add_argument('--config', required=True, help='config file')
+    parser.add_argument('--train_data', required=True, help='train data file')
+    parser.add_argument('--cv_data', required=True, help='cv data file')
+    parser.add_argument('--qwen_pretrain_path', required=False, help='qwen pretrain path')
+    parser.add_argument('--checkpoint', help='checkpoint model')
+    parser.add_argument('--model_dir', required=True, help='save model dir')
+    parser.add_argument('--tensorboard_dir',
+                        default='tensorboard',
+                        help='tensorboard log dir')
+    parser.add_argument('--ddp.dist_backend',
+                        dest='dist_backend',
+                        default='nccl',
+                        choices=['nccl', 'gloo'],
+                        help='distributed backend')
+    parser.add_argument('--num_workers',
+                        default=0,
+                        type=int,
+                        help='num of subprocess workers for reading')
+    parser.add_argument('--prefetch',
+                        default=100,
+                        type=int,
+                        help='prefetch number')
+    parser.add_argument('--pin_memory',
+                        action='store_true',
+                        default=False,
+                        help='Use pinned memory buffers used for reading')
+    parser.add_argument('--use_amp',
+                        action='store_true',
+                        default=False,
+                        help='Use automatic mixed precision training')
+    parser.add_argument('--dpo',
+                        action='store_true',
+                        default=False,
+                        help='Use Direct Preference Optimization')
+    parser.add_argument('--deepspeed.save_states',
+                        dest='save_states',
+                        default='model_only',
+                        choices=['model_only', 'model+optimizer'],
+                        help='save model/optimizer states')
+    parser.add_argument('--timeout',
+                        default=60,
+                        type=int,
+                        help='timeout (in seconds) of cosyvoice_join.')
+    parser = deepspeed.add_config_arguments(parser)
+    args = parser.parse_args()
+    return args
+
+
+@record
+def main():
+    args = get_args()
+    logging.basicConfig(level=logging.DEBUG,
+                        format='%(asctime)s %(levelname)s %(message)s')
+    # gan train has some special initialization logic
+    gan = True if args.model == 'hifigan' else False
+
+    override_dict = {k: None for k in ['llm', 'flow', 'hift', 'hifigan'] if k != args.model}
+    if gan is True:
+        override_dict.pop('hift')
+    try:
+        with open(args.config, 'r') as f:
+            configs = load_hyperpyyaml(f, overrides={**override_dict, 'qwen_pretrain_path': args.qwen_pretrain_path})
+    except Exception:
+        with open(args.config, 'r') as f:
+            configs = load_hyperpyyaml(f, overrides=override_dict)
+    if gan is True:
+        configs['train_conf'] = configs['train_conf_gan']
+    configs['train_conf'].update(vars(args))
+
+    # Init env for ddp
+    init_distributed(args)
+
+    # Get dataset & dataloader
+    train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
+        init_dataset_and_dataloader(args, configs, gan, args.dpo)
+
+    # Do some sanity checks and save config to arsg.model_dir
+    configs = check_modify_and_save_config(args, configs)
+
+    # Tensorboard summary
+    writer = init_summarywriter(args)
+
+    # load checkpoint
+    if args.dpo is True:
+        configs[args.model].forward = configs[args.model].forward_dpo
+    model = configs[args.model]
+    start_step, start_epoch = 0, -1
+    if args.checkpoint is not None:
+        if os.path.exists(args.checkpoint):
+            state_dict = torch.load(args.checkpoint, map_location='cpu')
+            model.load_state_dict(state_dict, strict=False)
+            if 'step' in state_dict:
+                start_step = state_dict['step']
+            if 'epoch' in state_dict:
+                start_epoch = state_dict['epoch']
+        else:
+            logging.warning('checkpoint {} do not exsist!'.format(args.checkpoint))
+
+    # Dispatch model from cpu to gpu
+    model = wrap_cuda_model(args, model)
+
+    # Get optimizer & scheduler
+    model, optimizer, scheduler, optimizer_d, scheduler_d = init_optimizer_and_scheduler(args, configs, model, gan)
+    scheduler.set_step(start_step)
+    if scheduler_d is not None:
+        scheduler_d.set_step(start_step)
+
+    # Save init checkpoints
+    info_dict = deepcopy(configs['train_conf'])
+    info_dict['step'] = start_step
+    info_dict['epoch'] = start_epoch
+    save_model(model, 'init', info_dict)
+
+    # DPO related
+    if args.dpo is True:
+        ref_model = deepcopy(configs[args.model])
+        state_dict = torch.load(args.ref_model, map_location='cpu')
+        ref_model.load_state_dict(state_dict, strict=False)
+        dpo_loss = DPOLoss(beta=0.01, label_smoothing=0.0, ipo=False)
+        # NOTE maybe it is not needed to wrap ref_model as ddp because its parameter is not updated
+        ref_model = wrap_cuda_model(args, ref_model)
+    else:
+        ref_model, dpo_loss = None, None
+
+    # Get executor
+    executor = Executor(gan=gan, ref_model=ref_model, dpo_loss=dpo_loss)
+    executor.step = start_step
+
+    # Init scaler, used for pytorch amp mixed precision training
+    scaler = torch.cuda.amp.GradScaler() if args.use_amp else None
+    print('start step {} start epoch {}'.format(start_step, start_epoch))
+
+    # Start training loop
+    for epoch in range(start_epoch + 1, info_dict['max_epoch']):
+        executor.epoch = epoch
+        train_dataset.set_epoch(epoch)
+        dist.barrier()
+        group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout))
+        if gan is True:
+            executor.train_one_epoc_gan(model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
+                                        writer, info_dict, scaler, group_join)
+        else:
+            executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join, ref_model=ref_model)
+        dist.destroy_process_group(group_join)
+
+
+if __name__ == '__main__':
+    main()
File without changes