xinference 0.13.2__py3-none-any.whl → 0.13.3__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Potentially problematic release: this version of xinference may be problematic.
- xinference/__init__.py +0 -1
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +26 -4
- xinference/client/restful/restful_client.py +16 -1
- xinference/core/chat_interface.py +2 -2
- xinference/core/model.py +8 -3
- xinference/core/scheduler.py +4 -4
- xinference/model/audio/core.py +5 -2
- xinference/model/audio/cosyvoice.py +136 -0
- xinference/model/audio/model_spec.json +24 -0
- xinference/model/audio/model_spec_modelscope.json +27 -0
- xinference/model/flexible/launchers/__init__.py +1 -0
- xinference/model/flexible/launchers/image_process_launcher.py +70 -0
- xinference/model/image/model_spec.json +7 -0
- xinference/model/image/stable_diffusion/core.py +6 -1
- xinference/model/llm/llm_family.json +802 -82
- xinference/model/llm/llm_family_csghub.json +39 -0
- xinference/model/llm/llm_family_modelscope.json +295 -47
- xinference/model/llm/pytorch/chatglm.py +243 -5
- xinference/model/llm/pytorch/cogvlm2.py +1 -1
- xinference/model/llm/utils.py +78 -1
- xinference/model/llm/vllm/core.py +8 -0
- xinference/thirdparty/cosyvoice/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/bin/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/bin/inference.py +114 -0
- xinference/thirdparty/cosyvoice/bin/train.py +136 -0
- xinference/thirdparty/cosyvoice/cli/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/cli/cosyvoice.py +83 -0
- xinference/thirdparty/cosyvoice/cli/frontend.py +168 -0
- xinference/thirdparty/cosyvoice/cli/model.py +60 -0
- xinference/thirdparty/cosyvoice/dataset/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/dataset/dataset.py +160 -0
- xinference/thirdparty/cosyvoice/dataset/processor.py +369 -0
- xinference/thirdparty/cosyvoice/flow/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/flow/decoder.py +222 -0
- xinference/thirdparty/cosyvoice/flow/flow.py +135 -0
- xinference/thirdparty/cosyvoice/flow/flow_matching.py +138 -0
- xinference/thirdparty/cosyvoice/flow/length_regulator.py +49 -0
- xinference/thirdparty/cosyvoice/hifigan/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/hifigan/f0_predictor.py +55 -0
- xinference/thirdparty/cosyvoice/hifigan/generator.py +391 -0
- xinference/thirdparty/cosyvoice/llm/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/llm/llm.py +206 -0
- xinference/thirdparty/cosyvoice/transformer/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/transformer/activation.py +84 -0
- xinference/thirdparty/cosyvoice/transformer/attention.py +326 -0
- xinference/thirdparty/cosyvoice/transformer/convolution.py +145 -0
- xinference/thirdparty/cosyvoice/transformer/decoder.py +396 -0
- xinference/thirdparty/cosyvoice/transformer/decoder_layer.py +132 -0
- xinference/thirdparty/cosyvoice/transformer/embedding.py +293 -0
- xinference/thirdparty/cosyvoice/transformer/encoder.py +472 -0
- xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +236 -0
- xinference/thirdparty/cosyvoice/transformer/label_smoothing_loss.py +96 -0
- xinference/thirdparty/cosyvoice/transformer/positionwise_feed_forward.py +115 -0
- xinference/thirdparty/cosyvoice/transformer/subsampling.py +383 -0
- xinference/thirdparty/cosyvoice/utils/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/utils/class_utils.py +70 -0
- xinference/thirdparty/cosyvoice/utils/common.py +103 -0
- xinference/thirdparty/cosyvoice/utils/executor.py +110 -0
- xinference/thirdparty/cosyvoice/utils/file_utils.py +41 -0
- xinference/thirdparty/cosyvoice/utils/frontend_utils.py +125 -0
- xinference/thirdparty/cosyvoice/utils/mask.py +227 -0
- xinference/thirdparty/cosyvoice/utils/scheduler.py +739 -0
- xinference/thirdparty/cosyvoice/utils/train_utils.py +289 -0
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.95c1d652.js → main.2ef0cfaf.js} +3 -3
- xinference/web/ui/build/static/js/main.2ef0cfaf.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/b6807ecc0c231fea699533518a0eb2a2bf68a081ce00d452be40600dbffa17a7.json +1 -0
- {xinference-0.13.2.dist-info → xinference-0.13.3.dist-info}/METADATA +16 -8
- {xinference-0.13.2.dist-info → xinference-0.13.3.dist-info}/RECORD +76 -32
- xinference/web/ui/build/static/js/main.95c1d652.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/709711edada3f1596b309d571285fd31f1c364d66f4425bc28723d0088cc351a.json +0 -1
- /xinference/web/ui/build/static/js/{main.95c1d652.js.LICENSE.txt → main.2ef0cfaf.js.LICENSE.txt} +0 -0
- {xinference-0.13.2.dist-info → xinference-0.13.3.dist-info}/LICENSE +0 -0
- {xinference-0.13.2.dist-info → xinference-0.13.3.dist-info}/WHEEL +0 -0
- {xinference-0.13.2.dist-info → xinference-0.13.3.dist-info}/entry_points.txt +0 -0
- {xinference-0.13.2.dist-info → xinference-0.13.3.dist-info}/top_level.txt +0 -0
xinference/thirdparty/cosyvoice/bin/train.py
@@ -0,0 +1,136 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+import argparse
+import datetime
+import logging
+logging.getLogger('matplotlib').setLevel(logging.WARNING)
+from copy import deepcopy
+import torch
+import torch.distributed as dist
+import deepspeed
+
+from hyperpyyaml import load_hyperpyyaml
+
+from torch.distributed.elastic.multiprocessing.errors import record
+
+from cosyvoice.utils.executor import Executor
+from cosyvoice.utils.train_utils import (
+    init_distributed,
+    init_dataset_and_dataloader,
+    init_optimizer_and_scheduler,
+    init_summarywriter, save_model,
+    wrap_cuda_model, check_modify_and_save_config)
+
+
+def get_args():
+    parser = argparse.ArgumentParser(description='training your network')
+    parser.add_argument('--train_engine',
+                        default='torch_ddp',
+                        choices=['torch_ddp', 'deepspeed'],
+                        help='Engine for paralleled training')
+    parser.add_argument('--model', required=True, help='model which will be trained')
+    parser.add_argument('--config', required=True, help='config file')
+    parser.add_argument('--train_data', required=True, help='train data file')
+    parser.add_argument('--cv_data', required=True, help='cv data file')
+    parser.add_argument('--checkpoint', help='checkpoint model')
+    parser.add_argument('--model_dir', required=True, help='save model dir')
+    parser.add_argument('--tensorboard_dir',
+                        default='tensorboard',
+                        help='tensorboard log dir')
+    parser.add_argument('--ddp.dist_backend',
+                        dest='dist_backend',
+                        default='nccl',
+                        choices=['nccl', 'gloo'],
+                        help='distributed backend')
+    parser.add_argument('--num_workers',
+                        default=0,
+                        type=int,
+                        help='num of subprocess workers for reading')
+    parser.add_argument('--prefetch',
+                        default=100,
+                        type=int,
+                        help='prefetch number')
+    parser.add_argument('--pin_memory',
+                        action='store_true',
+                        default=False,
+                        help='Use pinned memory buffers used for reading')
+    parser.add_argument('--deepspeed.save_states',
+                        dest='save_states',
+                        default='model_only',
+                        choices=['model_only', 'model+optimizer'],
+                        help='save model/optimizer states')
+    parser.add_argument('--timeout',
+                        default=30,
+                        type=int,
+                        help='timeout (in seconds) of cosyvoice_join.')
+    parser = deepspeed.add_config_arguments(parser)
+    args = parser.parse_args()
+    return args
+
+
+@record
+def main():
+    args = get_args()
+    logging.basicConfig(level=logging.DEBUG,
+                        format='%(asctime)s %(levelname)s %(message)s')
+
+    override_dict = {k: None for k in ['llm', 'flow', 'hift'] if k != args.model}
+    with open(args.config, 'r') as f:
+        configs = load_hyperpyyaml(f, overrides=override_dict)
+    configs['train_conf'].update(vars(args))
+
+    # Init env for ddp
+    init_distributed(args)
+
+    # Get dataset & dataloader
+    train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
+        init_dataset_and_dataloader(args, configs)
+
+    # Do some sanity checks and save config to arsg.model_dir
+    configs = check_modify_and_save_config(args, configs)
+
+    # Tensorboard summary
+    writer = init_summarywriter(args)
+
+    # load checkpoint
+    model = configs[args.model]
+    if args.checkpoint is not None:
+        model.load_state_dict(torch.load(args.checkpoint, map_location='cpu'))
+
+    # Dispatch model from cpu to gpu
+    model = wrap_cuda_model(args, model)
+
+    # Get optimizer & scheduler
+    model, optimizer, scheduler = init_optimizer_and_scheduler(args, configs, model)
+
+    # Save init checkpoints
+    info_dict = deepcopy(configs['train_conf'])
+    save_model(model, 'init', info_dict)
+
+    # Get executor
+    executor = Executor()
+
+    # Start training loop
+    for epoch in range(info_dict['max_epoch']):
+        executor.epoch = epoch
+        train_dataset.set_epoch(epoch)
+        dist.barrier()
+        group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout))
+        executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, group_join)
+        dist.destroy_process_group(group_join)
+
+if __name__ == '__main__':
+    main()
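A note on the `--model` flag above: train.py only builds the sub-model being trained. Every other top-level key among `llm`, `flow` and `hift` is overridden to `None` before the hyperpyyaml config is loaded, so the YAML never instantiates the unused components. A minimal, self-contained sketch of that override step (the selected value and the toy YAML are illustrative, not taken from this release):

# Sketch of the config-override trick used in bin/train.py (illustrative values only).
from hyperpyyaml import load_hyperpyyaml

selected = 'llm'  # stand-in for args.model
override_dict = {k: None for k in ['llm', 'flow', 'hift'] if k != selected}
print(override_dict)  # {'flow': None, 'hift': None}

toy_yaml = """
llm: 1
flow: 2
hift: 3
"""
configs = load_hyperpyyaml(toy_yaml, overrides=override_dict)
print(configs['llm'], configs['flow'], configs['hift'])  # 1 None None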
xinference/thirdparty/cosyvoice/cli/__init__.py (file without changes)
xinference/thirdparty/cosyvoice/cli/cosyvoice.py
@@ -0,0 +1,83 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import torch
+from hyperpyyaml import load_hyperpyyaml
+from modelscope import snapshot_download
+from cosyvoice.cli.frontend import CosyVoiceFrontEnd
+from cosyvoice.cli.model import CosyVoiceModel
+
+class CosyVoice:
+
+    def __init__(self, model_dir):
+        instruct = True if '-Instruct' in model_dir else False
+        self.model_dir = model_dir
+        if not os.path.exists(model_dir):
+            model_dir = snapshot_download(model_dir)
+        with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
+            configs = load_hyperpyyaml(f)
+        self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
+                                          configs['feat_extractor'],
+                                          '{}/campplus.onnx'.format(model_dir),
+                                          '{}/speech_tokenizer_v1.onnx'.format(model_dir),
+                                          '{}/spk2info.pt'.format(model_dir),
+                                          instruct,
+                                          configs['allowed_special'])
+        self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
+        self.model.load('{}/llm.pt'.format(model_dir),
+                        '{}/flow.pt'.format(model_dir),
+                        '{}/hift.pt'.format(model_dir))
+        del configs
+
+    def list_avaliable_spks(self):
+        spks = list(self.frontend.spk2info.keys())
+        return spks
+
+    def inference_sft(self, tts_text, spk_id):
+        tts_speeches = []
+        for i in self.frontend.text_normalize(tts_text, split=True):
+            model_input = self.frontend.frontend_sft(i, spk_id)
+            model_output = self.model.inference(**model_input)
+            tts_speeches.append(model_output['tts_speech'])
+        return {'tts_speech': torch.concat(tts_speeches, dim=1)}
+
+    def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k):
+        prompt_text = self.frontend.text_normalize(prompt_text, split=False)
+        tts_speeches = []
+        for i in self.frontend.text_normalize(tts_text, split=True):
+            model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k)
+            model_output = self.model.inference(**model_input)
+            tts_speeches.append(model_output['tts_speech'])
+        return {'tts_speech': torch.concat(tts_speeches, dim=1)}
+
+    def inference_cross_lingual(self, tts_text, prompt_speech_16k):
+        if self.frontend.instruct is True:
+            raise ValueError('{} do not support cross_lingual inference'.format(self.model_dir))
+        tts_speeches = []
+        for i in self.frontend.text_normalize(tts_text, split=True):
+            model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k)
+            model_output = self.model.inference(**model_input)
+            tts_speeches.append(model_output['tts_speech'])
+        return {'tts_speech': torch.concat(tts_speeches, dim=1)}
+
+    def inference_instruct(self, tts_text, spk_id, instruct_text):
+        if self.frontend.instruct is False:
+            raise ValueError('{} do not support instruct inference'.format(self.model_dir))
+        instruct_text = self.frontend.text_normalize(instruct_text, split=False)
+        tts_speeches = []
+        for i in self.frontend.text_normalize(tts_text, split=True):
+            model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
+            model_output = self.model.inference(**model_input)
+            tts_speeches.append(model_output['tts_speech'])
+        return {'tts_speech': torch.concat(tts_speeches, dim=1)}
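For orientation, `CosyVoice` is the high-level entry point that xinference's new cosyvoice audio model builds on. A hedged usage sketch of the SFT path follows; the model directory, the choice of speaker, and the 22050 Hz output rate are assumptions for illustration, not values fixed by this diff:

# Minimal sketch, assuming a CosyVoice-300M-SFT checkpoint directory is available
# and the vendored package (xinference/thirdparty/cosyvoice) is importable as `cosyvoice`.
import torchaudio
from cosyvoice.cli.cosyvoice import CosyVoice

tts = CosyVoice('pretrained_models/CosyVoice-300M-SFT')    # hypothetical local path
spks = tts.list_avaliable_spks()                           # speaker ids come from spk2info.pt
out = tts.inference_sft('Hello, this is a test.', spk_id=spks[0])
# 'tts_speech' is a (1, num_samples) waveform tensor; 22050 Hz is the assumed output rate.
torchaudio.save('sft_output.wav', out['tts_speech'], 22050)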
xinference/thirdparty/cosyvoice/cli/frontend.py
@@ -0,0 +1,168 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from functools import partial
+import onnxruntime
+import torch
+import numpy as np
+import whisper
+from typing import Callable
+import torchaudio.compliance.kaldi as kaldi
+import torchaudio
+import os
+import re
+import inflect
+try:
+    import ttsfrd
+    use_ttsfrd = True
+except ImportError:
+    print("failed to import ttsfrd, use WeTextProcessing instead")
+    from tn.chinese.normalizer import Normalizer as ZhNormalizer
+    from tn.english.normalizer import Normalizer as EnNormalizer
+    use_ttsfrd = False
+from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph
+
+
+class CosyVoiceFrontEnd:
+
+    def __init__(self,
+                 get_tokenizer: Callable,
+                 feat_extractor: Callable,
+                 campplus_model: str,
+                 speech_tokenizer_model: str,
+                 spk2info: str = '',
+                 instruct: bool = False,
+                 allowed_special: str = 'all'):
+        self.tokenizer = get_tokenizer()
+        self.feat_extractor = feat_extractor
+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        option = onnxruntime.SessionOptions()
+        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+        option.intra_op_num_threads = 1
+        self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
+        self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option, providers=["CUDAExecutionProvider" if torch.cuda.is_available() else "CPUExecutionProvider"])
+        if os.path.exists(spk2info):
+            self.spk2info = torch.load(spk2info, map_location=self.device)
+        self.instruct = instruct
+        self.allowed_special = allowed_special
+        self.inflect_parser = inflect.engine()
+        self.use_ttsfrd = use_ttsfrd
+        if self.use_ttsfrd:
+            self.frd = ttsfrd.TtsFrontendEngine()
+            ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+            assert self.frd.initialize('{}/../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)) is True, 'failed to initialize ttsfrd resource'
+            self.frd.set_lang_type('pinyin')
+            self.frd.enable_pinyin_mix(True)
+            self.frd.set_breakmodel_index(1)
+        else:
+            self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False)
+            self.en_tn_model = EnNormalizer()
+
+    def _extract_text_token(self, text):
+        text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
+        text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
+        text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
+        return text_token, text_token_len
+
+    def _extract_speech_token(self, speech):
+        feat = whisper.log_mel_spectrogram(speech, n_mels=128)
+        speech_token = self.speech_tokenizer_session.run(None, {self.speech_tokenizer_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
+                                                                self.speech_tokenizer_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
+        speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
+        speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
+        return speech_token, speech_token_len
+
+    def _extract_spk_embedding(self, speech):
+        feat = kaldi.fbank(speech,
+                           num_mel_bins=80,
+                           dither=0,
+                           sample_frequency=16000)
+        feat = feat - feat.mean(dim=0, keepdim=True)
+        embedding = self.campplus_session.run(None, {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
+        embedding = torch.tensor([embedding]).to(self.device)
+        return embedding
+
+    def _extract_speech_feat(self, speech):
+        speech_feat = self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device)
+        speech_feat = speech_feat.unsqueeze(dim=0)
+        speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)
+        return speech_feat, speech_feat_len
+
+    def text_normalize(self, text, split=True):
+        text = text.strip()
+        if contains_chinese(text):
+            if self.use_ttsfrd:
+                text = self.frd.get_frd_extra_info(text, 'input')
+            else:
+                text = self.zh_tn_model.normalize(text)
+            text = text.replace("\n", "")
+            text = replace_blank(text)
+            text = replace_corner_mark(text)
+            text = text.replace(".", "、")
+            text = text.replace(" - ", ",")
+            text = remove_bracket(text)
+            text = re.sub(r'[,,]+$', '。', text)
+            texts = [i for i in split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
                                                token_min_n=60, merge_len=20,
                                                comma_split=False)]
+        else:
+            if self.use_ttsfrd:
+                text = self.frd.get_frd_extra_info(text, 'input')
+            else:
+                text = self.en_tn_model.normalize(text)
+                text = spell_out_number(text, self.inflect_parser)
+            texts = [i for i in split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
                                                token_min_n=60, merge_len=20,
                                                comma_split=False)]
+        if split is False:
+            return text
+        return texts
+
+    def frontend_sft(self, tts_text, spk_id):
+        tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
+        embedding = self.spk2info[spk_id]['embedding']
+        model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
+        return model_input
+
+    def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k):
+        tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
+        prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
+        prompt_speech_22050 = torchaudio.transforms.Resample(orig_freq=16000, new_freq=22050)(prompt_speech_16k)
+        speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_22050)
+        speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
+        embedding = self._extract_spk_embedding(prompt_speech_16k)
+        model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
+                       'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
+                       'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
+                       'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
+                       'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
+                       'llm_embedding': embedding, 'flow_embedding': embedding}
+        return model_input
+
+    def frontend_cross_lingual(self, tts_text, prompt_speech_16k):
+        model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k)
+        # in cross lingual mode, we remove prompt in llm
+        del model_input['prompt_text']
+        del model_input['prompt_text_len']
+        del model_input['llm_prompt_speech_token']
+        del model_input['llm_prompt_speech_token_len']
+        return model_input
+
+    def frontend_instruct(self, tts_text, spk_id, instruct_text):
+        model_input = self.frontend_sft(tts_text, spk_id)
+        # in instruct mode, we remove spk_embedding in llm due to information leakage
+        del model_input['llm_embedding']
+        instruct_text_token, instruct_text_token_len = self._extract_text_token(instruct_text + '<endofprompt>')
+        model_input['prompt_text'] = instruct_text_token
+        model_input['prompt_text_len'] = instruct_text_token_len
+        return model_input
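Both the zero-shot and cross-lingual frontends take a 16 kHz mono prompt tensor (`prompt_speech_16k`); as the code above shows, the 22050 Hz resampling for mel features happens inside `frontend_zero_shot`, while the speech tokenizer and speaker-embedding extractor consume the 16 kHz signal directly. A small sketch of preparing such a prompt (the file name and its original sample rate are placeholders):

# Sketch: bring an arbitrary prompt recording to the (1, samples) 16 kHz layout
# expected by frontend_zero_shot / frontend_cross_lingual. 'prompt.wav' is a placeholder.
import torchaudio

waveform, sr = torchaudio.load('prompt.wav')            # (channels, samples)
waveform = waveform.mean(dim=0, keepdim=True)           # downmix to mono
if sr != 16000:
    waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)(waveform)
prompt_speech_16k = waveform                             # pass to inference_zero_shot(...)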
xinference/thirdparty/cosyvoice/cli/model.py
@@ -0,0 +1,60 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+
+class CosyVoiceModel:
+
+    def __init__(self,
+                 llm: torch.nn.Module,
+                 flow: torch.nn.Module,
+                 hift: torch.nn.Module):
+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        self.llm = llm
+        self.flow = flow
+        self.hift = hift
+
+    def load(self, llm_model, flow_model, hift_model):
+        self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
+        self.llm.to(self.device).eval()
+        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device))
+        self.flow.to(self.device).eval()
+        self.hift.load_state_dict(torch.load(hift_model, map_location=self.device))
+        self.hift.to(self.device).eval()
+
+    def inference(self, text, text_len, flow_embedding, llm_embedding=torch.zeros(0, 192),
+                  prompt_text=torch.zeros(1, 0, dtype=torch.int32), prompt_text_len=torch.zeros(1, dtype=torch.int32),
+                  llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), llm_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
+                  flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), flow_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
+                  prompt_speech_feat=torch.zeros(1, 0, 80), prompt_speech_feat_len=torch.zeros(1, dtype=torch.int32)):
+        tts_speech_token = self.llm.inference(text=text.to(self.device),
+                                              text_len=text_len.to(self.device),
+                                              prompt_text=prompt_text.to(self.device),
+                                              prompt_text_len=prompt_text_len.to(self.device),
+                                              prompt_speech_token=llm_prompt_speech_token.to(self.device),
+                                              prompt_speech_token_len=llm_prompt_speech_token_len.to(self.device),
+                                              embedding=llm_embedding.to(self.device),
+                                              beam_size=1,
+                                              sampling=25,
+                                              max_token_text_ratio=30,
+                                              min_token_text_ratio=3)
+        tts_mel = self.flow.inference(token=tts_speech_token,
+                                      token_len=torch.tensor([tts_speech_token.size(1)], dtype=torch.int32).to(self.device),
+                                      prompt_token=flow_prompt_speech_token.to(self.device),
+                                      prompt_token_len=flow_prompt_speech_token_len.to(self.device),
+                                      prompt_feat=prompt_speech_feat.to(self.device),
+                                      prompt_feat_len=prompt_speech_feat_len.to(self.device),
+                                      embedding=flow_embedding.to(self.device))
+        tts_speech = self.hift.inference(mel=tts_mel).cpu()
+        torch.cuda.empty_cache()
+        return {'tts_speech': tts_speech}
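Note how `inference` encodes "no prompt": instead of `None`, the keyword defaults are zero-length tensors, so the SFT path (text tokens plus embeddings only) and the prompted paths share one signature. A tiny sketch of how an SFT-style input dict would expand into that call (the dummy values are illustrative only):

# Illustrative shapes only; real tensors come from CosyVoiceFrontEnd.frontend_sft().
import torch

model_input = {
    'text': torch.zeros(1, 12, dtype=torch.int32),       # dummy text token ids, (1, seq_len)
    'text_len': torch.tensor([12], dtype=torch.int32),
    'llm_embedding': torch.zeros(1, 192),                 # dummy speaker embedding
    'flow_embedding': torch.zeros(1, 192),
}
# model.inference(**model_input) would leave every prompt_* argument at its
# zero-length default, e.g. torch.zeros(1, 0, dtype=torch.int32), meaning "no prompt".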
xinference/thirdparty/cosyvoice/dataset/__init__.py (file without changes)
xinference/thirdparty/cosyvoice/dataset/dataset.py
@@ -0,0 +1,160 @@
+# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
+#               2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+import json
+import math
+from functools import partial
+
+import torch
+import torch.distributed as dist
+from torch.utils.data import IterableDataset
+from cosyvoice.utils.file_utils import read_lists, read_json_lists
+
+
+class Processor(IterableDataset):
+
+    def __init__(self, source, f, *args, **kw):
+        assert callable(f)
+        self.source = source
+        self.f = f
+        self.args = args
+        self.kw = kw
+
+    def set_epoch(self, epoch):
+        self.source.set_epoch(epoch)
+
+    def __iter__(self):
+        """ Return an iterator over the source dataset processed by the
+            given processor.
+        """
+        assert self.source is not None
+        assert callable(self.f)
+        return self.f(iter(self.source), *self.args, **self.kw)
+
+    def apply(self, f):
+        assert callable(f)
+        return Processor(self, f, *self.args, **self.kw)
+
+
+class DistributedSampler:
+
+    def __init__(self, shuffle=True, partition=True):
+        self.epoch = -1
+        self.update()
+        self.shuffle = shuffle
+        self.partition = partition
+
+    def update(self):
+        assert dist.is_available()
+        if dist.is_initialized():
+            self.rank = dist.get_rank()
+            self.world_size = dist.get_world_size()
+        else:
+            self.rank = 0
+            self.world_size = 1
+        worker_info = torch.utils.data.get_worker_info()
+        if worker_info is None:
+            self.worker_id = 0
+            self.num_workers = 1
+        else:
+            self.worker_id = worker_info.id
+            self.num_workers = worker_info.num_workers
+        return dict(rank=self.rank,
+                    world_size=self.world_size,
+                    worker_id=self.worker_id,
+                    num_workers=self.num_workers)
+
+    def set_epoch(self, epoch):
+        self.epoch = epoch
+
+    def sample(self, data):
+        """ Sample data according to rank/world_size/num_workers
+
+            Args:
+                data(List): input data list
+
+            Returns:
+                List: data list after sample
+        """
+        data = list(range(len(data)))
+        # force datalist even
+        if self.partition:
+            if self.shuffle:
+                random.Random(self.epoch).shuffle(data)
+            if len(data) < self.world_size:
+                data = data * math.ceil(self.world_size / len(data))
+                data = data[:self.world_size]
+            data = data[self.rank::self.world_size]
+        if len(data) < self.num_workers:
+            data = data * math.ceil(self.num_workers / len(data))
+            data = data[:self.num_workers]
+        data = data[self.worker_id::self.num_workers]
+        return data
+
+
+class DataList(IterableDataset):
+
+    def __init__(self, lists, shuffle=True, partition=True):
+        self.lists = lists
+        self.sampler = DistributedSampler(shuffle, partition)
+
+    def set_epoch(self, epoch):
+        self.sampler.set_epoch(epoch)
+
+    def __iter__(self):
+        sampler_info = self.sampler.update()
+        indexes = self.sampler.sample(self.lists)
+        for index in indexes:
+            data = dict(src=self.lists[index])
+            data.update(sampler_info)
+            yield data
+
+
+def Dataset(data_list_file,
+            data_pipeline,
+            mode='train',
+            shuffle=True,
+            partition=True,
+            tts_file='',
+            prompt_utt2data=''):
+    """ Construct dataset from arguments
+
+        We have two shuffle stage in the Dataset. The first is global
+        shuffle at shards tar/raw file level. The second is global shuffle
+        at training samples level.
+
+        Args:
+            data_type(str): raw/shard
+            tokenizer (BaseTokenizer): tokenizer to tokenize
+            partition(bool): whether to do data partition in terms of rank
+    """
+    assert mode in ['train', 'inference']
+    lists = read_lists(data_list_file)
+    if mode == 'inference':
+        with open(tts_file) as f:
+            tts_data = json.load(f)
+        utt2lists = read_json_lists(prompt_utt2data)
+        # filter unnecessary file in inference mode
+        lists = list(set([utt2lists[utt] for utt in tts_data.keys() if utt2lists[utt] in lists]))
+    dataset = DataList(lists,
+                       shuffle=shuffle,
+                       partition=partition)
+    if mode == 'inference':
+        # map partial arg tts_data in inference mode
+        data_pipeline[0] = partial(data_pipeline[0], tts_data=tts_data)
+    for func in data_pipeline:
+        dataset = Processor(dataset, func, mode=mode)
+    return dataset
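`Dataset(...)` above is just composition: it wraps a `DataList` of shard paths and chains one `Processor` per stage of `data_pipeline`, each stage being a generator function that consumes the previous iterator. A self-contained sketch of that composition with a made-up stage (the shard names and `add_length` are illustrative; outside a DataLoader and without torch.distributed initialized, the sampler falls back to rank 0 and world size 1):

# Toy composition of DataList + Processor, mirroring how Dataset() builds its pipeline.
# Assumes the vendored package (xinference/thirdparty/cosyvoice) is importable as `cosyvoice`.
from cosyvoice.dataset.dataset import DataList, Processor

def add_length(samples, mode='train'):
    # Stand-in pipeline stage: each sample is a dict carrying 'src' plus sampler info.
    for sample in samples:
        sample['src_len'] = len(sample['src'])
        yield sample

dataset = DataList(['shard_000.tar', 'shard_001.tar'], shuffle=False, partition=False)
dataset = Processor(dataset, add_length, mode='train')   # Dataset() does this per stage
for sample in dataset:
    print(sample['src'], sample['src_len'], sample['rank'], sample['world_size'])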