xinference-0.13.1-py3-none-any.whl → xinference-0.13.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic.

Files changed (82)
  1. xinference/__init__.py +0 -1
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +99 -5
  4. xinference/client/restful/restful_client.py +98 -1
  5. xinference/core/chat_interface.py +2 -2
  6. xinference/core/model.py +85 -26
  7. xinference/core/scheduler.py +4 -4
  8. xinference/model/audio/chattts.py +40 -8
  9. xinference/model/audio/core.py +5 -2
  10. xinference/model/audio/cosyvoice.py +136 -0 (see the client sketch after this list)
  11. xinference/model/audio/model_spec.json +24 -0
  12. xinference/model/audio/model_spec_modelscope.json +27 -0
  13. xinference/model/flexible/launchers/__init__.py +1 -0
  14. xinference/model/flexible/launchers/image_process_launcher.py +70 -0
  15. xinference/model/image/core.py +3 -0
  16. xinference/model/image/model_spec.json +21 -0
  17. xinference/model/image/stable_diffusion/core.py +49 -7
  18. xinference/model/llm/llm_family.json +1065 -106
  19. xinference/model/llm/llm_family.py +26 -6
  20. xinference/model/llm/llm_family_csghub.json +39 -0
  21. xinference/model/llm/llm_family_modelscope.json +460 -47
  22. xinference/model/llm/pytorch/chatglm.py +243 -5
  23. xinference/model/llm/pytorch/cogvlm2.py +1 -1
  24. xinference/model/llm/sglang/core.py +7 -2
  25. xinference/model/llm/utils.py +78 -1
  26. xinference/model/llm/vllm/core.py +11 -0
  27. xinference/thirdparty/cosyvoice/__init__.py +0 -0
  28. xinference/thirdparty/cosyvoice/bin/__init__.py +0 -0
  29. xinference/thirdparty/cosyvoice/bin/inference.py +114 -0
  30. xinference/thirdparty/cosyvoice/bin/train.py +136 -0
  31. xinference/thirdparty/cosyvoice/cli/__init__.py +0 -0
  32. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +83 -0
  33. xinference/thirdparty/cosyvoice/cli/frontend.py +168 -0
  34. xinference/thirdparty/cosyvoice/cli/model.py +60 -0
  35. xinference/thirdparty/cosyvoice/dataset/__init__.py +0 -0
  36. xinference/thirdparty/cosyvoice/dataset/dataset.py +160 -0
  37. xinference/thirdparty/cosyvoice/dataset/processor.py +369 -0
  38. xinference/thirdparty/cosyvoice/flow/__init__.py +0 -0
  39. xinference/thirdparty/cosyvoice/flow/decoder.py +222 -0
  40. xinference/thirdparty/cosyvoice/flow/flow.py +135 -0
  41. xinference/thirdparty/cosyvoice/flow/flow_matching.py +138 -0
  42. xinference/thirdparty/cosyvoice/flow/length_regulator.py +49 -0
  43. xinference/thirdparty/cosyvoice/hifigan/__init__.py +0 -0
  44. xinference/thirdparty/cosyvoice/hifigan/f0_predictor.py +55 -0
  45. xinference/thirdparty/cosyvoice/hifigan/generator.py +391 -0
  46. xinference/thirdparty/cosyvoice/llm/__init__.py +0 -0
  47. xinference/thirdparty/cosyvoice/llm/llm.py +206 -0
  48. xinference/thirdparty/cosyvoice/transformer/__init__.py +0 -0
  49. xinference/thirdparty/cosyvoice/transformer/activation.py +84 -0
  50. xinference/thirdparty/cosyvoice/transformer/attention.py +326 -0
  51. xinference/thirdparty/cosyvoice/transformer/convolution.py +145 -0
  52. xinference/thirdparty/cosyvoice/transformer/decoder.py +396 -0
  53. xinference/thirdparty/cosyvoice/transformer/decoder_layer.py +132 -0
  54. xinference/thirdparty/cosyvoice/transformer/embedding.py +293 -0
  55. xinference/thirdparty/cosyvoice/transformer/encoder.py +472 -0
  56. xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +236 -0
  57. xinference/thirdparty/cosyvoice/transformer/label_smoothing_loss.py +96 -0
  58. xinference/thirdparty/cosyvoice/transformer/positionwise_feed_forward.py +115 -0
  59. xinference/thirdparty/cosyvoice/transformer/subsampling.py +383 -0
  60. xinference/thirdparty/cosyvoice/utils/__init__.py +0 -0
  61. xinference/thirdparty/cosyvoice/utils/class_utils.py +70 -0
  62. xinference/thirdparty/cosyvoice/utils/common.py +103 -0
  63. xinference/thirdparty/cosyvoice/utils/executor.py +110 -0
  64. xinference/thirdparty/cosyvoice/utils/file_utils.py +41 -0
  65. xinference/thirdparty/cosyvoice/utils/frontend_utils.py +125 -0
  66. xinference/thirdparty/cosyvoice/utils/mask.py +227 -0
  67. xinference/thirdparty/cosyvoice/utils/scheduler.py +739 -0
  68. xinference/thirdparty/cosyvoice/utils/train_utils.py +289 -0
  69. xinference/web/ui/build/asset-manifest.json +3 -3
  70. xinference/web/ui/build/index.html +1 -1
  71. xinference/web/ui/build/static/js/{main.95c1d652.js → main.2ef0cfaf.js} +3 -3
  72. xinference/web/ui/build/static/js/main.2ef0cfaf.js.map +1 -0
  73. xinference/web/ui/node_modules/.cache/babel-loader/b6807ecc0c231fea699533518a0eb2a2bf68a081ce00d452be40600dbffa17a7.json +1 -0
  74. {xinference-0.13.1.dist-info → xinference-0.13.3.dist-info}/METADATA +18 -8
  75. {xinference-0.13.1.dist-info → xinference-0.13.3.dist-info}/RECORD +80 -36
  76. xinference/web/ui/build/static/js/main.95c1d652.js.map +0 -1
  77. xinference/web/ui/node_modules/.cache/babel-loader/709711edada3f1596b309d571285fd31f1c364d66f4425bc28723d0088cc351a.json +0 -1
  78. /xinference/web/ui/build/static/js/{main.95c1d652.js.LICENSE.txt → main.2ef0cfaf.js.LICENSE.txt} +0 -0
  79. {xinference-0.13.1.dist-info → xinference-0.13.3.dist-info}/LICENSE +0 -0
  80. {xinference-0.13.1.dist-info → xinference-0.13.3.dist-info}/WHEEL +0 -0
  81. {xinference-0.13.1.dist-info → xinference-0.13.3.dist-info}/entry_points.txt +0 -0
  82. {xinference-0.13.1.dist-info → xinference-0.13.3.dist-info}/top_level.txt +0 -0
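
The headline change in this diff is CosyVoice text-to-speech support: a new audio model spec plus the vendored xinference/thirdparty/cosyvoice package whose source is reproduced in the hunks below. For orientation, a client-side sketch of what that enables might look like the following; the model name, the speech() call and its arguments, and the voice id are assumptions inferred from the files touched here, not verified against the 0.13.3 API reference.

from xinference.client import RESTfulClient

# Hypothetical usage sketch; model name, speech() arguments and the voice id
# are assumptions. Consult the xinference documentation for the exact API.
client = RESTfulClient("http://localhost:9997")
model_uid = client.launch_model(model_name="CosyVoice-300M-SFT", model_type="audio")
model = client.get_model(model_uid)
audio_bytes = model.speech("Hello from CosyVoice!", voice="英文女")
with open("cosyvoice_demo.mp3", "wb") as f:
    f.write(audio_bytes)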
xinference/thirdparty/cosyvoice/bin/train.py
@@ -0,0 +1,136 @@
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from __future__ import print_function
+ import argparse
+ import datetime
+ import logging
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
+ from copy import deepcopy
+ import torch
+ import torch.distributed as dist
+ import deepspeed
+
+ from hyperpyyaml import load_hyperpyyaml
+
+ from torch.distributed.elastic.multiprocessing.errors import record
+
+ from cosyvoice.utils.executor import Executor
+ from cosyvoice.utils.train_utils import (
+     init_distributed,
+     init_dataset_and_dataloader,
+     init_optimizer_and_scheduler,
+     init_summarywriter, save_model,
+     wrap_cuda_model, check_modify_and_save_config)
+
+
+ def get_args():
+     parser = argparse.ArgumentParser(description='training your network')
+     parser.add_argument('--train_engine',
+                         default='torch_ddp',
+                         choices=['torch_ddp', 'deepspeed'],
+                         help='Engine for paralleled training')
+     parser.add_argument('--model', required=True, help='model which will be trained')
+     parser.add_argument('--config', required=True, help='config file')
+     parser.add_argument('--train_data', required=True, help='train data file')
+     parser.add_argument('--cv_data', required=True, help='cv data file')
+     parser.add_argument('--checkpoint', help='checkpoint model')
+     parser.add_argument('--model_dir', required=True, help='save model dir')
+     parser.add_argument('--tensorboard_dir',
+                         default='tensorboard',
+                         help='tensorboard log dir')
+     parser.add_argument('--ddp.dist_backend',
+                         dest='dist_backend',
+                         default='nccl',
+                         choices=['nccl', 'gloo'],
+                         help='distributed backend')
+     parser.add_argument('--num_workers',
+                         default=0,
+                         type=int,
+                         help='num of subprocess workers for reading')
+     parser.add_argument('--prefetch',
+                         default=100,
+                         type=int,
+                         help='prefetch number')
+     parser.add_argument('--pin_memory',
+                         action='store_true',
+                         default=False,
+                         help='Use pinned memory buffers used for reading')
+     parser.add_argument('--deepspeed.save_states',
+                         dest='save_states',
+                         default='model_only',
+                         choices=['model_only', 'model+optimizer'],
+                         help='save model/optimizer states')
+     parser.add_argument('--timeout',
+                         default=30,
+                         type=int,
+                         help='timeout (in seconds) of cosyvoice_join.')
+     parser = deepspeed.add_config_arguments(parser)
+     args = parser.parse_args()
+     return args
+
+
+ @record
+ def main():
+     args = get_args()
+     logging.basicConfig(level=logging.DEBUG,
+                         format='%(asctime)s %(levelname)s %(message)s')
+
+     override_dict = {k: None for k in ['llm', 'flow', 'hift'] if k != args.model}
+     with open(args.config, 'r') as f:
+         configs = load_hyperpyyaml(f, overrides=override_dict)
+     configs['train_conf'].update(vars(args))
+
+     # Init env for ddp
+     init_distributed(args)
+
+     # Get dataset & dataloader
+     train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
+         init_dataset_and_dataloader(args, configs)
+
+     # Do some sanity checks and save config to arsg.model_dir
+     configs = check_modify_and_save_config(args, configs)
+
+     # Tensorboard summary
+     writer = init_summarywriter(args)
+
+     # load checkpoint
+     model = configs[args.model]
+     if args.checkpoint is not None:
+         model.load_state_dict(torch.load(args.checkpoint, map_location='cpu'))
+
+     # Dispatch model from cpu to gpu
+     model = wrap_cuda_model(args, model)
+
+     # Get optimizer & scheduler
+     model, optimizer, scheduler = init_optimizer_and_scheduler(args, configs, model)
+
+     # Save init checkpoints
+     info_dict = deepcopy(configs['train_conf'])
+     save_model(model, 'init', info_dict)
+
+     # Get executor
+     executor = Executor()
+
+     # Start training loop
+     for epoch in range(info_dict['max_epoch']):
+         executor.epoch = epoch
+         train_dataset.set_epoch(epoch)
+         dist.barrier()
+         group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout))
+         executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, group_join)
+         dist.destroy_process_group(group_join)
+
+ if __name__ == '__main__':
+     main()
xinference/thirdparty/cosyvoice/cli/cosyvoice.py
@@ -0,0 +1,83 @@
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import os
+ import torch
+ from hyperpyyaml import load_hyperpyyaml
+ from modelscope import snapshot_download
+ from cosyvoice.cli.frontend import CosyVoiceFrontEnd
+ from cosyvoice.cli.model import CosyVoiceModel
+
+ class CosyVoice:
+
+     def __init__(self, model_dir):
+         instruct = True if '-Instruct' in model_dir else False
+         self.model_dir = model_dir
+         if not os.path.exists(model_dir):
+             model_dir = snapshot_download(model_dir)
+         with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
+             configs = load_hyperpyyaml(f)
+         self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
+                                           configs['feat_extractor'],
+                                           '{}/campplus.onnx'.format(model_dir),
+                                           '{}/speech_tokenizer_v1.onnx'.format(model_dir),
+                                           '{}/spk2info.pt'.format(model_dir),
+                                           instruct,
+                                           configs['allowed_special'])
+         self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
+         self.model.load('{}/llm.pt'.format(model_dir),
+                         '{}/flow.pt'.format(model_dir),
+                         '{}/hift.pt'.format(model_dir))
+         del configs
+
+     def list_avaliable_spks(self):
+         spks = list(self.frontend.spk2info.keys())
+         return spks
+
+     def inference_sft(self, tts_text, spk_id):
+         tts_speeches = []
+         for i in self.frontend.text_normalize(tts_text, split=True):
+             model_input = self.frontend.frontend_sft(i, spk_id)
+             model_output = self.model.inference(**model_input)
+             tts_speeches.append(model_output['tts_speech'])
+         return {'tts_speech': torch.concat(tts_speeches, dim=1)}
+
+     def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k):
+         prompt_text = self.frontend.text_normalize(prompt_text, split=False)
+         tts_speeches = []
+         for i in self.frontend.text_normalize(tts_text, split=True):
+             model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k)
+             model_output = self.model.inference(**model_input)
+             tts_speeches.append(model_output['tts_speech'])
+         return {'tts_speech': torch.concat(tts_speeches, dim=1)}
+
+     def inference_cross_lingual(self, tts_text, prompt_speech_16k):
+         if self.frontend.instruct is True:
+             raise ValueError('{} do not support cross_lingual inference'.format(self.model_dir))
+         tts_speeches = []
+         for i in self.frontend.text_normalize(tts_text, split=True):
+             model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k)
+             model_output = self.model.inference(**model_input)
+             tts_speeches.append(model_output['tts_speech'])
+         return {'tts_speech': torch.concat(tts_speeches, dim=1)}
+
+     def inference_instruct(self, tts_text, spk_id, instruct_text):
+         if self.frontend.instruct is False:
+             raise ValueError('{} do not support instruct inference'.format(self.model_dir))
+         instruct_text = self.frontend.text_normalize(instruct_text, split=False)
+         tts_speeches = []
+         for i in self.frontend.text_normalize(tts_text, split=True):
+             model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
+             model_output = self.model.inference(**model_input)
+             tts_speeches.append(model_output['tts_speech'])
+         return {'tts_speech': torch.concat(tts_speeches, dim=1)}
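
For orientation, here is a minimal sketch of driving the CosyVoice wrapper above for SFT (built-in speaker) synthesis. The model id, speaker id, and the 22050 Hz output rate are assumptions rather than facts taken from this diff; the frontend shown later resamples prompts to 22050 Hz, which is what suggests that output rate.

import torchaudio
from cosyvoice.cli.cosyvoice import CosyVoice

# Minimal SFT synthesis sketch; model id, speaker id and output sample rate are assumptions.
cosyvoice = CosyVoice('iic/CosyVoice-300M-SFT')      # fetched via modelscope snapshot_download if not a local path
print(cosyvoice.list_avaliable_spks())               # built-in speaker ids loaded from spk2info.pt
output = cosyvoice.inference_sft('Hello, this is a CosyVoice test.', spk_id='英文女')
torchaudio.save('sft.wav', output['tts_speech'], 22050)   # tts_speech is a (1, num_samples) tensor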
xinference/thirdparty/cosyvoice/cli/frontend.py
@@ -0,0 +1,168 @@
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from functools import partial
+ import onnxruntime
+ import torch
+ import numpy as np
+ import whisper
+ from typing import Callable
+ import torchaudio.compliance.kaldi as kaldi
+ import torchaudio
+ import os
+ import re
+ import inflect
+ try:
+     import ttsfrd
+     use_ttsfrd = True
+ except ImportError:
+     print("failed to import ttsfrd, use WeTextProcessing instead")
+     from tn.chinese.normalizer import Normalizer as ZhNormalizer
+     from tn.english.normalizer import Normalizer as EnNormalizer
+     use_ttsfrd = False
+ from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph
+
+
+ class CosyVoiceFrontEnd:
+
+     def __init__(self,
+                  get_tokenizer: Callable,
+                  feat_extractor: Callable,
+                  campplus_model: str,
+                  speech_tokenizer_model: str,
+                  spk2info: str = '',
+                  instruct: bool = False,
+                  allowed_special: str = 'all'):
+         self.tokenizer = get_tokenizer()
+         self.feat_extractor = feat_extractor
+         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+         option = onnxruntime.SessionOptions()
+         option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+         option.intra_op_num_threads = 1
+         self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
+         self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option, providers=["CUDAExecutionProvider" if torch.cuda.is_available() else "CPUExecutionProvider"])
+         if os.path.exists(spk2info):
+             self.spk2info = torch.load(spk2info, map_location=self.device)
+         self.instruct = instruct
+         self.allowed_special = allowed_special
+         self.inflect_parser = inflect.engine()
+         self.use_ttsfrd = use_ttsfrd
+         if self.use_ttsfrd:
+             self.frd = ttsfrd.TtsFrontendEngine()
+             ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+             assert self.frd.initialize('{}/../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)) is True, 'failed to initialize ttsfrd resource'
+             self.frd.set_lang_type('pinyin')
+             self.frd.enable_pinyin_mix(True)
+             self.frd.set_breakmodel_index(1)
+         else:
+             self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False)
+             self.en_tn_model = EnNormalizer()
+
+     def _extract_text_token(self, text):
+         text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
+         text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
+         text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
+         return text_token, text_token_len
+
+     def _extract_speech_token(self, speech):
+         feat = whisper.log_mel_spectrogram(speech, n_mels=128)
+         speech_token = self.speech_tokenizer_session.run(None, {self.speech_tokenizer_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
+                                                                 self.speech_tokenizer_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
+         speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
+         speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
+         return speech_token, speech_token_len
+
+     def _extract_spk_embedding(self, speech):
+         feat = kaldi.fbank(speech,
+                            num_mel_bins=80,
+                            dither=0,
+                            sample_frequency=16000)
+         feat = feat - feat.mean(dim=0, keepdim=True)
+         embedding = self.campplus_session.run(None, {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
+         embedding = torch.tensor([embedding]).to(self.device)
+         return embedding
+
+     def _extract_speech_feat(self, speech):
+         speech_feat = self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device)
+         speech_feat = speech_feat.unsqueeze(dim=0)
+         speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)
+         return speech_feat, speech_feat_len
+
+     def text_normalize(self, text, split=True):
+         text = text.strip()
+         if contains_chinese(text):
+             if self.use_ttsfrd:
+                 text = self.frd.get_frd_extra_info(text, 'input')
+             else:
+                 text = self.zh_tn_model.normalize(text)
+             text = text.replace("\n", "")
+             text = replace_blank(text)
+             text = replace_corner_mark(text)
+             text = text.replace(".", "、")
+             text = text.replace(" - ", ",")
+             text = remove_bracket(text)
+             text = re.sub(r'[,,]+$', '。', text)
+             texts = [i for i in split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
+                                                 token_min_n=60, merge_len=20,
+                                                 comma_split=False)]
+         else:
+             if self.use_ttsfrd:
+                 text = self.frd.get_frd_extra_info(text, 'input')
+             else:
+                 text = self.en_tn_model.normalize(text)
+                 text = spell_out_number(text, self.inflect_parser)
+             texts = [i for i in split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
+                                                 token_min_n=60, merge_len=20,
+                                                 comma_split=False)]
+         if split is False:
+             return text
+         return texts
+
+     def frontend_sft(self, tts_text, spk_id):
+         tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
+         embedding = self.spk2info[spk_id]['embedding']
+         model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
+         return model_input
+
+     def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k):
+         tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
+         prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
+         prompt_speech_22050 = torchaudio.transforms.Resample(orig_freq=16000, new_freq=22050)(prompt_speech_16k)
+         speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_22050)
+         speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
+         embedding = self._extract_spk_embedding(prompt_speech_16k)
+         model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
+                        'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
+                        'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
+                        'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
+                        'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
+                        'llm_embedding': embedding, 'flow_embedding': embedding}
+         return model_input
+
+     def frontend_cross_lingual(self, tts_text, prompt_speech_16k):
+         model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k)
+         # in cross lingual mode, we remove prompt in llm
+         del model_input['prompt_text']
+         del model_input['prompt_text_len']
+         del model_input['llm_prompt_speech_token']
+         del model_input['llm_prompt_speech_token_len']
+         return model_input
+
+     def frontend_instruct(self, tts_text, spk_id, instruct_text):
+         model_input = self.frontend_sft(tts_text, spk_id)
+         # in instruct mode, we remove spk_embedding in llm due to information leakage
+         del model_input['llm_embedding']
+         instruct_text_token, instruct_text_token_len = self._extract_text_token(instruct_text + '<endofprompt>')
+         model_input['prompt_text'] = instruct_text_token
+         model_input['prompt_text_len'] = instruct_text_token_len
+         return model_input
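
The frontend above packs text tokens, a 16 kHz speech-token prompt, a 22050 Hz mel-feature prompt, and a CAM++ speaker embedding into a single model_input dict. A hedged zero-shot (voice-cloning) sketch built on it through CosyVoice.inference_zero_shot could look like this; the checkpoint id and the prompt wav path are placeholders, and the 22050 Hz output rate is an assumption.

import torchaudio
from cosyvoice.cli.cosyvoice import CosyVoice

# Zero-shot cloning sketch; checkpoint id and prompt path are placeholders,
# and the frontend expects a mono 16 kHz reference waveform.
cosyvoice = CosyVoice('iic/CosyVoice-300M')
prompt, sr = torchaudio.load('prompt.wav')
prompt_16k = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)(prompt)
output = cosyvoice.inference_zero_shot(
    'Text to synthesize in the cloned voice.',
    'Transcript of the reference clip.',
    prompt_16k)
torchaudio.save('zero_shot.wav', output['tts_speech'], 22050)   # assumed output rate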
xinference/thirdparty/cosyvoice/cli/model.py
@@ -0,0 +1,60 @@
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import torch
+
+ class CosyVoiceModel:
+
+     def __init__(self,
+                  llm: torch.nn.Module,
+                  flow: torch.nn.Module,
+                  hift: torch.nn.Module):
+         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+         self.llm = llm
+         self.flow = flow
+         self.hift = hift
+
+     def load(self, llm_model, flow_model, hift_model):
+         self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
+         self.llm.to(self.device).eval()
+         self.flow.load_state_dict(torch.load(flow_model, map_location=self.device))
+         self.flow.to(self.device).eval()
+         self.hift.load_state_dict(torch.load(hift_model, map_location=self.device))
+         self.hift.to(self.device).eval()
+
+     def inference(self, text, text_len, flow_embedding, llm_embedding=torch.zeros(0, 192),
+                   prompt_text=torch.zeros(1, 0, dtype=torch.int32), prompt_text_len=torch.zeros(1, dtype=torch.int32),
+                   llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), llm_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
+                   flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), flow_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
+                   prompt_speech_feat=torch.zeros(1, 0, 80), prompt_speech_feat_len=torch.zeros(1, dtype=torch.int32)):
+         tts_speech_token = self.llm.inference(text=text.to(self.device),
+                                               text_len=text_len.to(self.device),
+                                               prompt_text=prompt_text.to(self.device),
+                                               prompt_text_len=prompt_text_len.to(self.device),
+                                               prompt_speech_token=llm_prompt_speech_token.to(self.device),
+                                               prompt_speech_token_len=llm_prompt_speech_token_len.to(self.device),
+                                               embedding=llm_embedding.to(self.device),
+                                               beam_size=1,
+                                               sampling=25,
+                                               max_token_text_ratio=30,
+                                               min_token_text_ratio=3)
+         tts_mel = self.flow.inference(token=tts_speech_token,
+                                       token_len=torch.tensor([tts_speech_token.size(1)], dtype=torch.int32).to(self.device),
+                                       prompt_token=flow_prompt_speech_token.to(self.device),
+                                       prompt_token_len=flow_prompt_speech_token_len.to(self.device),
+                                       prompt_feat=prompt_speech_feat.to(self.device),
+                                       prompt_feat_len=prompt_speech_feat_len.to(self.device),
+                                       embedding=flow_embedding.to(self.device))
+         tts_speech = self.hift.inference(mel=tts_mel).cpu()
+         torch.cuda.empty_cache()
+         return {'tts_speech': tts_speech}
xinference/thirdparty/cosyvoice/dataset/dataset.py
@@ -0,0 +1,160 @@
+ # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
+ #               2024 Alibaba Inc (authors: Xiang Lyu)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import random
+ import json
+ import math
+ from functools import partial
+
+ import torch
+ import torch.distributed as dist
+ from torch.utils.data import IterableDataset
+ from cosyvoice.utils.file_utils import read_lists, read_json_lists
+
+
+ class Processor(IterableDataset):
+
+     def __init__(self, source, f, *args, **kw):
+         assert callable(f)
+         self.source = source
+         self.f = f
+         self.args = args
+         self.kw = kw
+
+     def set_epoch(self, epoch):
+         self.source.set_epoch(epoch)
+
+     def __iter__(self):
+         """ Return an iterator over the source dataset processed by the
+             given processor.
+         """
+         assert self.source is not None
+         assert callable(self.f)
+         return self.f(iter(self.source), *self.args, **self.kw)
+
+     def apply(self, f):
+         assert callable(f)
+         return Processor(self, f, *self.args, **self.kw)
+
+
+ class DistributedSampler:
+
+     def __init__(self, shuffle=True, partition=True):
+         self.epoch = -1
+         self.update()
+         self.shuffle = shuffle
+         self.partition = partition
+
+     def update(self):
+         assert dist.is_available()
+         if dist.is_initialized():
+             self.rank = dist.get_rank()
+             self.world_size = dist.get_world_size()
+         else:
+             self.rank = 0
+             self.world_size = 1
+         worker_info = torch.utils.data.get_worker_info()
+         if worker_info is None:
+             self.worker_id = 0
+             self.num_workers = 1
+         else:
+             self.worker_id = worker_info.id
+             self.num_workers = worker_info.num_workers
+         return dict(rank=self.rank,
+                     world_size=self.world_size,
+                     worker_id=self.worker_id,
+                     num_workers=self.num_workers)
+
+     def set_epoch(self, epoch):
+         self.epoch = epoch
+
+     def sample(self, data):
+         """ Sample data according to rank/world_size/num_workers
+
+             Args:
+                 data(List): input data list
+
+             Returns:
+                 List: data list after sample
+         """
+         data = list(range(len(data)))
+         # force datalist even
+         if self.partition:
+             if self.shuffle:
+                 random.Random(self.epoch).shuffle(data)
+             if len(data) < self.world_size:
+                 data = data * math.ceil(self.world_size / len(data))
+                 data = data[:self.world_size]
+             data = data[self.rank::self.world_size]
+         if len(data) < self.num_workers:
+             data = data * math.ceil(self.num_workers / len(data))
+             data = data[:self.num_workers]
+         data = data[self.worker_id::self.num_workers]
+         return data
+
+
+ class DataList(IterableDataset):
+
+     def __init__(self, lists, shuffle=True, partition=True):
+         self.lists = lists
+         self.sampler = DistributedSampler(shuffle, partition)
+
+     def set_epoch(self, epoch):
+         self.sampler.set_epoch(epoch)
+
+     def __iter__(self):
+         sampler_info = self.sampler.update()
+         indexes = self.sampler.sample(self.lists)
+         for index in indexes:
+             data = dict(src=self.lists[index])
+             data.update(sampler_info)
+             yield data
+
+
+ def Dataset(data_list_file,
+             data_pipeline,
+             mode='train',
+             shuffle=True,
+             partition=True,
+             tts_file='',
+             prompt_utt2data=''):
+     """ Construct dataset from arguments
+
+         We have two shuffle stage in the Dataset. The first is global
+         shuffle at shards tar/raw file level. The second is global shuffle
+         at training samples level.
+
+         Args:
+             data_type(str): raw/shard
+             tokenizer (BaseTokenizer): tokenizer to tokenize
+             partition(bool): whether to do data partition in terms of rank
+     """
+     assert mode in ['train', 'inference']
+     lists = read_lists(data_list_file)
+     if mode == 'inference':
+         with open(tts_file) as f:
+             tts_data = json.load(f)
+         utt2lists = read_json_lists(prompt_utt2data)
+         # filter unnecessary file in inference mode
+         lists = list(set([utt2lists[utt] for utt in tts_data.keys() if utt2lists[utt] in lists]))
+     dataset = DataList(lists,
+                        shuffle=shuffle,
+                        partition=partition)
+     if mode == 'inference':
+         # map partial arg tts_data in inference mode
+         data_pipeline[0] = partial(data_pipeline[0], tts_data=tts_data)
+     for func in data_pipeline:
+         dataset = Processor(dataset, func, mode=mode)
+     return dataset
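
To see how the pieces above compose: DataList yields dicts carrying 'src' plus the rank/world_size/worker_id/num_workers info from DistributedSampler, and each Processor wraps the previous iterator with a generator function, so a pipeline is just repeated wrapping. A toy sketch follows; the add_length and filter_long processors are illustrative stand-ins, not the real CosyVoice pipeline.

# Toy pipeline sketch built on the DataList/Processor classes above.
def add_length(source, mode='train'):
    for sample in source:                     # each sample is the dict yielded by DataList
        sample['length'] = len(sample['src'])
        yield sample

def filter_long(source, max_len=20, mode='train'):
    for sample in source:
        if sample['length'] <= max_len:
            yield sample

data_pipeline = [add_length, filter_long]
dataset = DataList(['utt1.tar', 'utt2.tar'], shuffle=True, partition=True)
for func in data_pipeline:
    dataset = Processor(dataset, func, mode='train')
dataset.set_epoch(0)                          # seeds the epoch-based shuffle in DistributedSampler
for item in dataset:
    print(item)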