xinference 0.11.3__py3-none-any.whl → 0.12.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference has been flagged as potentially problematic; review the advisory details on the registry page before upgrading.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +143 -6
- xinference/client/restful/restful_client.py +144 -5
- xinference/constants.py +5 -0
- xinference/core/cache_tracker.py +48 -28
- xinference/core/model.py +160 -19
- xinference/core/scheduler.py +446 -0
- xinference/core/supervisor.py +99 -24
- xinference/core/worker.py +68 -2
- xinference/deploy/cmdline.py +86 -2
- xinference/deploy/test/test_cmdline.py +19 -10
- xinference/isolation.py +9 -2
- xinference/model/audio/__init__.py +14 -1
- xinference/model/audio/chattts.py +84 -0
- xinference/model/audio/core.py +22 -4
- xinference/model/audio/custom.py +6 -4
- xinference/model/audio/model_spec.json +20 -0
- xinference/model/audio/model_spec_modelscope.json +20 -0
- xinference/model/llm/__init__.py +38 -2
- xinference/model/llm/llm_family.json +509 -1
- xinference/model/llm/llm_family.py +86 -1
- xinference/model/llm/llm_family_csghub.json +66 -0
- xinference/model/llm/llm_family_modelscope.json +411 -2
- xinference/model/llm/pytorch/chatglm.py +20 -13
- xinference/model/llm/pytorch/cogvlm2.py +76 -17
- xinference/model/llm/pytorch/core.py +141 -6
- xinference/model/llm/pytorch/glm4v.py +268 -0
- xinference/model/llm/pytorch/minicpmv25.py +232 -0
- xinference/model/llm/pytorch/qwen_vl.py +1 -1
- xinference/model/llm/pytorch/utils.py +405 -8
- xinference/model/llm/utils.py +14 -13
- xinference/model/llm/vllm/core.py +16 -4
- xinference/model/utils.py +8 -2
- xinference/thirdparty/ChatTTS/__init__.py +1 -0
- xinference/thirdparty/ChatTTS/core.py +200 -0
- xinference/thirdparty/ChatTTS/experimental/__init__.py +0 -0
- xinference/thirdparty/ChatTTS/experimental/llm.py +40 -0
- xinference/thirdparty/ChatTTS/infer/__init__.py +0 -0
- xinference/thirdparty/ChatTTS/infer/api.py +125 -0
- xinference/thirdparty/ChatTTS/model/__init__.py +0 -0
- xinference/thirdparty/ChatTTS/model/dvae.py +155 -0
- xinference/thirdparty/ChatTTS/model/gpt.py +265 -0
- xinference/thirdparty/ChatTTS/utils/__init__.py +0 -0
- xinference/thirdparty/ChatTTS/utils/gpu_utils.py +23 -0
- xinference/thirdparty/ChatTTS/utils/infer_utils.py +141 -0
- xinference/thirdparty/ChatTTS/utils/io_utils.py +14 -0
- xinference/types.py +3 -0
- xinference/web/ui/build/asset-manifest.json +6 -6
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/css/main.074e2b31.css +2 -0
- xinference/web/ui/build/static/css/main.074e2b31.css.map +1 -0
- xinference/web/ui/build/static/js/main.a58ff436.js +3 -0
- xinference/web/ui/build/static/js/main.a58ff436.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/10262a281dec3bc2b185f4385ceb6846626f52d41cb4d46c7c649e719f979d4d.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/762a75a62daf3bec2cfc97ec8612798493fb34ef87087dcad6aad64ab7f14345.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/7f3bdb3a48fa00c046c8b185acd4da6f2e2940a20dbd77f9373d60de3fd6633e.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f2f73bfdc13b12b02c8cbc4769b0b8e6367e9b6d8331c322d94318491a0b3653.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f51bf63ddaa7afd125ef2254a105789333eecc1c94fdf5157a9b88ef7ad0a5bd.json +1 -0
- {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/METADATA +26 -9
- {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/RECORD +65 -47
- xinference/web/ui/build/static/css/main.54bca460.css +0 -2
- xinference/web/ui/build/static/css/main.54bca460.css.map +0 -1
- xinference/web/ui/build/static/js/main.551aa479.js +0 -3
- xinference/web/ui/build/static/js/main.551aa479.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/1e86938a0cdf706d21e99b21f5d868fa247c0c88b26807047e26dcdc4d9a9db3.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/1fa824d82b2af519de7700c594e50bde4bbca60d13bd3fabff576802e4070304.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/3e737bcdbcbc407ccd65b90e199ef0c3214b261e8e41dbf14d921384a717d9ee.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/46b6dd1f6d1109cd0e2455a0ea0be3e9bda1097cd4ebec9c4040070372671cfc.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/59ce49eae0f486af4c5034d4d2f9ca77c3ec3a32ecc560085caf5ef482b5f4c9.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/a6da6bc3d0d2191adebee87fb58ecebe82d071087bd2f7f3a9c7fdd2ada130f2.json +0 -1
- /xinference/web/ui/build/static/js/{main.551aa479.js.LICENSE.txt → main.a58ff436.js.LICENSE.txt} +0 -0
- {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/LICENSE +0 -0
- {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/WHEEL +0 -0
- {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/entry_points.txt +0 -0
- {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,200 @@
|
|
|
1
|
+
|
|
2
|
+
import os
|
|
3
|
+
import logging
|
|
4
|
+
from functools import partial
|
|
5
|
+
from omegaconf import OmegaConf
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
from vocos import Vocos
|
|
9
|
+
from .model.dvae import DVAE
|
|
10
|
+
from .model.gpt import GPT_warpper
|
|
11
|
+
from .utils.gpu_utils import select_device
|
|
12
|
+
from .utils.infer_utils import count_invalid_characters, detect_language, apply_character_map, apply_half2full_map
|
|
13
|
+
from .utils.io_utils import get_latest_modified_file
|
|
14
|
+
from .infer.api import refine_text, infer_code
|
|
15
|
+
|
|
16
|
+
from huggingface_hub import snapshot_download
|
|
17
|
+
|
|
18
|
+
logging.basicConfig(level = logging.INFO)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class Chat:
    """ChatTTS inference front-end.

    Loads the pretrained sub-models (vocos, dvae/decoder, gpt, tokenizer,
    speaker statistics) and exposes text-to-speech inference.
    """

    def __init__(self):
        # Registry of loaded sub-models, filled by _load().
        self.pretrain_models = {}
        # Per-language text normalizers, created lazily by init_normalizer().
        self.normalizer = {}
        self.logger = logging.getLogger(__name__)

    def check_model(self, level=logging.INFO, use_decoder=False):
        """Return True when all required sub-models are loaded.

        Args:
            level: log level used for the "All initialized." message.
            use_decoder: require the 'decoder' head instead of 'dvae'.
        """
        not_finish = False
        check_list = ['vocos', 'gpt', 'tokenizer']
        # 'decoder' and 'dvae' are alternative heads; only one is required.
        check_list.append('decoder' if use_decoder else 'dvae')

        for module in check_list:
            if module not in self.pretrain_models:
                self.logger.log(logging.WARNING, f'{module} not initialized.')
                not_finish = True

        if not not_finish:
            self.logger.log(level, 'All initialized.')

        return not not_finish

    def load_models(self, source='huggingface', force_redownload=False, local_path='<LOCAL_PATH>', **kwargs):
        """Resolve the checkpoint directory and load every sub-model.

        Args:
            source: 'huggingface' (use/download the HF cache) or 'local'.
            force_redownload: re-download even when a cached copy exists.
            local_path: checkpoint directory used when source == 'local'.

        Raises:
            ValueError: if *source* is not recognized.  BUG FIX: the original
                fell through with ``download_path`` unbound and crashed later
                with a confusing UnboundLocalError.
        """
        if source == 'huggingface':
            hf_home = os.getenv('HF_HOME', os.path.expanduser("~/.cache/huggingface"))
            try:
                download_path = get_latest_modified_file(os.path.join(hf_home, 'hub/models--2Noise--ChatTTS/snapshots'))
            except Exception:
                # Narrowed from a bare `except:` — any failure probing the
                # local cache simply means "not cached yet".
                download_path = None
            if download_path is None or force_redownload:
                self.logger.log(logging.INFO, 'Download from HF: https://huggingface.co/2Noise/ChatTTS')
                download_path = snapshot_download(repo_id="2Noise/ChatTTS", allow_patterns=["*.pt", "*.yaml"])
            else:
                self.logger.log(logging.INFO, f'Load from cache: {download_path}')
        elif source == 'local':
            self.logger.log(logging.INFO, f'Load from local: {local_path}')
            download_path = local_path
        else:
            raise ValueError(f"unknown source {source!r}; expected 'huggingface' or 'local'")

        # config/path.yaml maps checkpoint keys to paths relative to the repo root.
        self._load(**{k: os.path.join(download_path, v) for k, v in OmegaConf.load(os.path.join(download_path, 'config', 'path.yaml')).items()}, **kwargs)

    def _load(
        self,
        vocos_config_path: str = None,
        vocos_ckpt_path: str = None,
        dvae_config_path: str = None,
        dvae_ckpt_path: str = None,
        gpt_config_path: str = None,
        gpt_ckpt_path: str = None,
        decoder_config_path: str = None,
        decoder_ckpt_path: str = None,
        tokenizer_path: str = None,
        device: str = None,
        compile: bool = True,
    ):
        """Load each sub-model whose config path is provided.

        Args:
            device: target device; auto-selected via select_device(4096)
                when falsy.
            compile: torch.compile the GPT forward on CUDA devices.
        """
        if not device:
            device = select_device(4096)
            self.logger.log(logging.INFO, f'use {device}')

        if vocos_config_path:
            vocos = Vocos.from_hparams(vocos_config_path).to(device).eval()
            assert vocos_ckpt_path, 'vocos_ckpt_path should not be None'
            vocos.load_state_dict(torch.load(vocos_ckpt_path))
            self.pretrain_models['vocos'] = vocos
            self.logger.log(logging.INFO, 'vocos loaded.')

        if dvae_config_path:
            cfg = OmegaConf.load(dvae_config_path)
            dvae = DVAE(**cfg).to(device).eval()
            assert dvae_ckpt_path, 'dvae_ckpt_path should not be None'
            dvae.load_state_dict(torch.load(dvae_ckpt_path, map_location='cpu'))
            self.pretrain_models['dvae'] = dvae
            self.logger.log(logging.INFO, 'dvae loaded.')

        if gpt_config_path:
            cfg = OmegaConf.load(gpt_config_path)
            gpt = GPT_warpper(**cfg).to(device).eval()
            assert gpt_ckpt_path, 'gpt_ckpt_path should not be None'
            gpt.load_state_dict(torch.load(gpt_ckpt_path, map_location='cpu'))
            if compile and 'cuda' in str(device):
                gpt.gpt.forward = torch.compile(gpt.gpt.forward, backend='inductor', dynamic=True)
            self.pretrain_models['gpt'] = gpt
            # Speaker statistics ship alongside the GPT checkpoint.
            spk_stat_path = os.path.join(os.path.dirname(gpt_ckpt_path), 'spk_stat.pt')
            assert os.path.exists(spk_stat_path), f'Missing spk_stat.pt: {spk_stat_path}'
            self.pretrain_models['spk_stat'] = torch.load(spk_stat_path).to(device)
            self.logger.log(logging.INFO, 'gpt loaded.')

        if decoder_config_path:
            cfg = OmegaConf.load(decoder_config_path)
            decoder = DVAE(**cfg).to(device).eval()
            assert decoder_ckpt_path, 'decoder_ckpt_path should not be None'
            decoder.load_state_dict(torch.load(decoder_ckpt_path, map_location='cpu'))
            self.pretrain_models['decoder'] = decoder
            self.logger.log(logging.INFO, 'decoder loaded.')

        if tokenizer_path:
            tokenizer = torch.load(tokenizer_path, map_location='cpu')
            tokenizer.padding_side = 'left'
            self.pretrain_models['tokenizer'] = tokenizer
            self.logger.log(logging.INFO, 'tokenizer loaded.')

        self.check_model()

    def infer(
        self,
        text,
        skip_refine_text=False,
        refine_text_only=False,
        params_refine_text=None,
        params_infer_code=None,
        use_decoder=True,
        do_text_normalization=True,
        lang=None,
    ):
        """Synthesize speech (or refined text) for *text*.

        Args:
            text: a string or list of strings.
            skip_refine_text: skip the text-refinement GPT pass.
            refine_text_only: return the refined text instead of audio.
            params_refine_text: kwargs forwarded to refine_text().
            params_infer_code: kwargs forwarded to infer_code(); the
                'prompt' entry is prepended to each text.  Defaults to
                {'prompt': '[speed_5]'}.
            use_decoder: decode hidden states via 'decoder' instead of
                token ids via 'dvae'.
            do_text_normalization: run language-specific normalization.
            lang: force a language; auto-detected per item when None.

        Returns:
            List of numpy waveforms, or list of refined strings when
            refine_text_only is True.
        """
        # BUG FIX: the old mutable default ({'prompt': '[speed_5]'}) was
        # popped below, so the prompt silently vanished on every call after
        # the first.  Copying also protects a caller-supplied dict from
        # mutation.
        if params_refine_text is None:
            params_refine_text = {}
        params_infer_code = dict(params_infer_code) if params_infer_code is not None else {'prompt': '[speed_5]'}

        assert self.check_model(use_decoder=use_decoder)

        if not isinstance(text, list):
            text = [text]

        if do_text_normalization:
            for i, t in enumerate(text):
                _lang = detect_language(t) if lang is None else lang
                self.init_normalizer(_lang)
                text[i] = self.normalizer[_lang](t)
                if _lang == 'zh':
                    text[i] = apply_half2full_map(text[i])

        for i, t in enumerate(text):
            invalid_characters = count_invalid_characters(t)
            if len(invalid_characters):
                self.logger.log(logging.WARNING, f'Invalid characters found! : {invalid_characters}')
                text[i] = apply_character_map(t)

        if not skip_refine_text:
            text_tokens = refine_text(self.pretrain_models, text, **params_refine_text)['ids']
            # Keep only tokens before the first break token.
            text_tokens = [i[i < self.pretrain_models['tokenizer'].convert_tokens_to_ids('[break_0]')] for i in text_tokens]
            text = self.pretrain_models['tokenizer'].batch_decode(text_tokens)
            if refine_text_only:
                return text

        prompt = params_infer_code.pop('prompt', '')
        text = [prompt + i for i in text]
        result = infer_code(self.pretrain_models, text, **params_infer_code, return_hidden=use_decoder)

        if use_decoder:
            mel_spec = [self.pretrain_models['decoder'](i[None].permute(0, 2, 1)) for i in result['hiddens']]
        else:
            mel_spec = [self.pretrain_models['dvae'](i[None].permute(0, 2, 1)) for i in result['ids']]

        wav = [self.pretrain_models['vocos'].decode(i).cpu().numpy() for i in mel_spec]
        return wav

    def sample_random_speaker(self):
        """Draw a random speaker embedding from the stored (std, mean) stats."""
        dim = self.pretrain_models['gpt'].gpt.layers[0].mlp.gate_proj.in_features
        std, mean = self.pretrain_models['spk_stat'].chunk(2)
        return torch.randn(dim, device=std.device) * std + mean

    def init_normalizer(self, lang):
        """Create and cache the text normalizer for *lang* on first use.

        Raises:
            ImportError: when the required normalization package is missing.
                BUG FIX: the original only logged a warning and then crashed
                with a NameError on the unresolved ``Normalizer`` name;
                re-raising surfaces the real problem.
        """
        if lang in self.normalizer:
            return
        if lang == 'zh':
            try:
                from tn.chinese.normalizer import Normalizer
            except ImportError:
                self.logger.log(logging.WARNING, 'Package WeTextProcessing not found! Run: conda install -c conda-forge pynini=2.1.5 && pip install WeTextProcessing')
                raise
            self.normalizer[lang] = Normalizer().normalize
        else:
            try:
                from nemo_text_processing.text_normalization.normalize import Normalizer
            except ImportError:
                self.logger.log(logging.WARNING, 'Package nemo_text_processing not found! Run: conda install -c conda-forge pynini=2.1.5 && pip install nemo_text_processing')
                raise
            self.normalizer[lang] = partial(Normalizer(input_case='cased', lang=lang).normalize, verbose=False, punct_post_process=True)
|
|
200
|
+
|
|
File without changes
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
|
|
2
|
+
from openai import OpenAI
|
|
3
|
+
|
|
4
|
+
# Few-shot chat prefixes, keyed by preset name, for OpenAI-compatible LLMs.
# 'kimi' and 'deepseek' prime the model to answer in short, TTS-friendly prose
# (comma/period punctuation only, numerals spelled out, <= 100 chars);
# 'deepseek_TN' instead primes pure text normalization of a given input.
prompt_dict = {
    'kimi': [ {"role": "system", "content": "你是 Kimi,由 Moonshot AI 提供的人工智能助手,你更擅长中文和英文的对话。"},
              {"role": "user", "content": "你好,请注意你现在生成的文字要按照人日常生活的口吻,你的回复将会后续用TTS模型转为语音,并且请把回答控制在100字以内。并且标点符号仅包含逗号和句号,将数字等转为文字回答。"},
              {"role": "assistant", "content": "好的,我现在生成的文字将按照人日常生活的口吻, 并且我会把回答控制在一百字以内, 标点符号仅包含逗号和句号,将阿拉伯数字等转为中文文字回答。下面请开始对话。"},],
    'deepseek': [
        {"role": "system", "content": "You are a helpful assistant"},
        {"role": "user", "content": "你好,请注意你现在生成的文字要按照人日常生活的口吻,你的回复将会后续用TTS模型转为语音,并且请把回答控制在100字以内。并且标点符号仅包含逗号和句号,将数字等转为文字回答。"},
        {"role": "assistant", "content": "好的,我现在生成的文字将按照人日常生活的口吻, 并且我会把回答控制在一百字以内, 标点符号仅包含逗号和句号,将阿拉伯数字等转为中文文字回答。下面请开始对话。"},],
    'deepseek_TN': [
        {"role": "system", "content": "You are a helpful assistant"},
        {"role": "user", "content": "你好,现在我们在处理TTS的文本输入,下面将会给你输入一段文本,请你将其中的阿拉伯数字等等转为文字表达,并且输出的文本里仅包含逗号和句号这两个标点符号"},
        {"role": "assistant", "content": "好的,我现在对TTS的文本输入进行处理。这一般叫做text normalization。下面请输入"},
        {"role": "user", "content": "We paid $123 for this desk."},
        {"role": "assistant", "content": "We paid one hundred and twenty three dollars for this desk."},
        {"role": "user", "content": "详询请拨打010-724654"},
        {"role": "assistant", "content": "详询请拨打零幺零,七二四六五四"},
        {"role": "user", "content": "罗森宣布将于7月24日退市,在华门店超6000家!"},
        {"role": "assistant", "content": "罗森宣布将于七月二十四日退市,在华门店超过六千家。"},
        ],
    }
|
|
24
|
+
|
|
25
|
+
class llm_api:
    """Thin wrapper around an OpenAI-compatible chat endpoint, used to
    pre-process text for TTS with the few-shot presets in ``prompt_dict``."""

    def __init__(self, api_key, base_url, model):
        # One client per wrapper; the target model is fixed at construction.
        self.client = OpenAI(api_key=api_key, base_url=base_url)
        self.model = model

    def call(self, user_question, temperature=0.3, prompt_version='kimi', **kwargs):
        """Send *user_question* after the selected few-shot prefix and
        return the assistant's reply text."""
        messages = list(prompt_dict[prompt_version])
        messages.append({"role": "user", "content": user_question})

        response = self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            temperature=temperature,
            **kwargs,
        )
        return response.choices[0].message.content
|
|
File without changes
|
|
@@ -0,0 +1,125 @@
|
|
|
1
|
+
|
|
2
|
+
import torch
|
|
3
|
+
import torch.nn.functional as F
|
|
4
|
+
from transformers.generation import TopKLogitsWarper, TopPLogitsWarper
|
|
5
|
+
from ..utils.infer_utils import CustomRepetitionPenaltyLogitsProcessorRepeat
|
|
6
|
+
|
|
7
|
+
def infer_code(
    models,
    text,
    spk_emb = None,
    top_P = 0.7,
    top_K = 20,
    temperature = 0.3,
    repetition_penalty = 1.05,
    max_new_token = 2048,
    **kwargs
):
    """Generate discrete audio codes for *text* with the GPT model.

    Args:
        models: dict of loaded sub-models; uses 'gpt' and 'tokenizer'.
        text: a string or list of strings.
        spk_emb: optional speaker embedding injected at the [spk_emb] token.
        top_P / top_K: nucleus / top-k sampling (disabled when None).
        temperature: scalar or per-codebook list of temperatures.
        repetition_penalty: applied over the last 16 tokens when != 1.
        max_new_token: generation length cap.

    Returns:
        Whatever models['gpt'].generate returns (ids / hiddens).
    """
    gpt = models['gpt']
    tokenizer = models['tokenizer']
    device = next(gpt.parameters()).device

    if not isinstance(text, list):
        text = [text]
    if not isinstance(temperature, list):
        # One temperature per vector-quantizer codebook.
        temperature = [temperature] * gpt.num_vq

    # Wrap each prompt with start / speaker / end control tokens.
    speaker_tag = '[spk_emb]' if spk_emb is not None else '[empty_spk]'
    text = [f'[Stts]{speaker_tag}{i}[Ptts]' for i in text]

    tokenized = tokenizer(text, return_tensors='pt', add_special_tokens=False, padding=True).to(device)
    inputs = {
        'input_ids': tokenized['input_ids'][..., None].expand(-1, -1, gpt.num_vq),
        'text_mask': torch.ones(tokenized['input_ids'].shape, dtype=bool, device=device),
        'attention_mask': tokenized['attention_mask'],
    }

    emb = gpt.get_emb(**inputs)
    if spk_emb is not None:
        # Overwrite the embedding at every [spk_emb] position with the
        # L2-normalized speaker embedding, broadcast over the batch.
        spk_positions = inputs['input_ids'][..., 0] == tokenizer.convert_tokens_to_ids('[spk_emb]')
        emb[spk_positions] = F.normalize(
            spk_emb.to(device).to(emb.dtype)[None].expand(len(text), -1),
            p=2.0, dim=1, eps=1e-12,
        )

    # The last code-embedding index doubles as the EOS token.
    num_code = gpt.emb_code[0].num_embeddings - 1

    logits_warpers = []
    if top_P is not None:
        logits_warpers.append(TopPLogitsWarper(top_P, min_tokens_to_keep=3))
    if top_K is not None:
        logits_warpers.append(TopKLogitsWarper(top_K, min_tokens_to_keep=3))

    logits_processors = []
    if repetition_penalty is not None and repetition_penalty != 1:
        logits_processors.append(
            CustomRepetitionPenaltyLogitsProcessorRepeat(repetition_penalty, num_code, 16)
        )

    return gpt.generate(
        emb, inputs['input_ids'],
        temperature = torch.tensor(temperature, device=device),
        attention_mask = inputs['attention_mask'],
        LogitsWarpers = logits_warpers,
        LogitsProcessors = logits_processors,
        eos_token = num_code,
        max_new_token = max_new_token,
        infer_text = False,
        **kwargs
    )
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def refine_text(
    models,
    text,
    top_P = 0.7,
    top_K = 20,
    temperature = 0.7,
    repetition_penalty = 1.0,
    max_new_token = 384,
    prompt = '',
    **kwargs
):
    """Run the GPT model in text-refinement mode over *text*.

    Args:
        models: dict of loaded sub-models; uses 'gpt' and 'tokenizer'.
        text: a string or non-empty list of strings.
        top_P / top_K: nucleus / top-k sampling (disabled when None).
        temperature: sampling temperature.
        repetition_penalty: applied over the last 16 tokens when != 1.
        max_new_token: generation length cap.
        prompt: control-token suffix appended after [Pbreak].

    Returns:
        Whatever models['gpt'].generate returns (refined token ids).
    """
    gpt = models['gpt']
    tokenizer = models['tokenizer']
    device = next(gpt.parameters()).device

    if not isinstance(text, list):
        text = [text]
    assert len(text), 'text should not be empty'

    # Wrap each item with the refinement control tokens.
    text = [f"[Sbreak]{i}[Pbreak]{prompt}" for i in text]
    tokenized = tokenizer(text, return_tensors='pt', add_special_tokens=False, padding=True).to(device)

    inputs = {
        'input_ids': tokenized['input_ids'][..., None].expand(-1, -1, gpt.num_vq),
        'text_mask': torch.ones(tokenized['input_ids'].shape, dtype=bool, device=device),
        'attention_mask': tokenized['attention_mask'],
    }

    logits_warpers = []
    if top_P is not None:
        logits_warpers.append(TopPLogitsWarper(top_P, min_tokens_to_keep=3))
    if top_K is not None:
        logits_warpers.append(TopKLogitsWarper(top_K, min_tokens_to_keep=3))

    logits_processors = []
    if repetition_penalty is not None and repetition_penalty != 1:
        logits_processors.append(
            CustomRepetitionPenaltyLogitsProcessorRepeat(repetition_penalty, len(tokenizer), 16)
        )

    return gpt.generate(
        gpt.get_emb(**inputs), inputs['input_ids'],
        temperature = torch.tensor([temperature], device=device),
        attention_mask = inputs['attention_mask'],
        LogitsWarpers = logits_warpers,
        LogitsProcessors = logits_processors,
        # Generation stops at the [Ebreak] control token.
        eos_token = torch.tensor(tokenizer.convert_tokens_to_ids('[Ebreak]'), device=device)[None],
        max_new_token = max_new_token,
        infer_text = True,
        **kwargs
    )
|
|
File without changes
|
|
@@ -0,0 +1,155 @@
|
|
|
1
|
+
import math
|
|
2
|
+
from einops import rearrange
|
|
3
|
+
from vector_quantize_pytorch import GroupedResidualFSQ
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
import torch.nn as nn
|
|
7
|
+
import torch.nn.functional as F
|
|
8
|
+
|
|
9
|
+
class ConvNeXtBlock(nn.Module):
    """ConvNeXt residual block (copied from Vocos).

    Pipeline: depthwise conv -> LayerNorm -> pointwise MLP with GELU ->
    optional learnable layer scale -> residual add.  Sequence length is
    preserved by the conv padding.
    """

    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        kernel, dilation,
        layer_scale_init_value: float = 1e-6,
    ):
        super().__init__()
        # Depthwise conv (groups == channels); padding keeps T unchanged.
        self.dwconv = nn.Conv1d(
            dim, dim,
            kernel_size=kernel,
            padding=dilation * (kernel // 2),
            dilation=dilation,
            groups=dim,
        )
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        # Pointwise (1x1) convs implemented as Linear layers on (B, T, C).
        self.pwconv1 = nn.Linear(dim, intermediate_dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(intermediate_dim, dim)
        # Per-channel layer scale; disabled when the init value is <= 0.
        if layer_scale_init_value > 0:
            self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
        else:
            self.gamma = None

    def forward(self, x: torch.Tensor, cond = None) -> torch.Tensor:
        residual = x
        y = self.dwconv(x)
        y = y.transpose(1, 2)  # (B, C, T) -> (B, T, C)
        y = self.norm(y)
        y = self.pwconv2(self.act(self.pwconv1(y)))
        if self.gamma is not None:
            y = self.gamma * y
        y = y.transpose(1, 2)  # (B, T, C) -> (B, C, T)
        return residual + y
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class GFSQ(nn.Module):
    """Grouped finite-scalar quantizer wrapper around GroupedResidualFSQ.

    With ``transpose=True`` the public tensors are channel-first (B, C, T);
    internally the quantizer operates on (B, T, C).
    """

    def __init__(self,
        dim, levels, G, R, eps=1e-5, transpose = True
    ):
        super(GFSQ, self).__init__()
        self.quantizer = GroupedResidualFSQ(
            dim=dim,
            levels=levels,
            num_quantizers=R,
            groups=G,
        )
        # Codebook size per quantizer = product of per-dimension levels.
        self.n_ind = math.prod(levels)
        self.eps = eps
        self.transpose = transpose
        self.G = G  # number of groups
        self.R = R  # number of residual quantizers per group

    def _embed(self, x):
        """Decode flat indices *x* back into features."""
        if self.transpose:
            x = x.transpose(1, 2)
        grouped = rearrange(
            x, "b t (g r) -> g b t r", g = self.G, r = self.R,
        )
        feat = self.quantizer.get_output_from_indices(grouped)
        return feat.transpose(1, 2) if self.transpose else feat

    def forward(self, x,):
        if self.transpose:
            x = x.transpose(1, 2)
        feat, ind = self.quantizer(x)
        ind = rearrange(
            ind, "g b t r ->b t (g r)",
        )
        # Codebook-usage perplexity from one-hot index frequencies.
        embed_onehot = F.one_hot(ind.long(), self.n_ind).to(x.dtype)
        e_mean = torch.mean(embed_onehot, dim=[0, 1])
        e_mean = e_mean / (e_mean.sum(dim=1) + self.eps).unsqueeze(1)
        perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + self.eps), dim=1))

        return (
            # Zero placeholder in the VQ-loss slot of the result tuple.
            torch.zeros(perplexity.shape, dtype=x.dtype, device=x.device),
            feat.transpose(1, 2) if self.transpose else feat,
            perplexity,
            None,
            ind.transpose(1, 2) if self.transpose else ind,
        )
|
|
97
|
+
|
|
98
|
+
class DVAEDecoder(nn.Module):
    """Stack of ConvNeXt blocks mapping (B, T, idim) -> (B, T, odim)."""

    def __init__(self, idim, odim,
        n_layer = 12, bn_dim = 64, hidden = 256,
        kernel = 7, dilation = 2, up = False
    ):
        super().__init__()
        self.up = up  # stored for config compatibility; not read in forward
        # Input stem: idim -> bn_dim bottleneck -> hidden width.
        self.conv_in = nn.Sequential(
            nn.Conv1d(idim, bn_dim, 3, 1, 1),
            nn.GELU(),
            nn.Conv1d(bn_dim, hidden, 3, 1, 1),
        )
        self.decoder_block = nn.ModuleList(
            ConvNeXtBlock(hidden, hidden * 4, kernel, dilation,)
            for _ in range(n_layer)
        )
        self.conv_out = nn.Conv1d(hidden, odim, kernel_size=1, bias=False)

    def forward(self, input, conditioning=None):
        # Public layout is (B, T, C); the convs run on (B, C, T).
        feats = input.transpose(1, 2)
        feats = self.conv_in(feats)
        for block in self.decoder_block:
            feats = block(feats, conditioning)
        feats = self.conv_out(feats)
        return feats.transpose(1, 2)
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
class DVAE(nn.Module):
    """Decoder turning quantized indices (or raw latents) into a 100-bin
    mel spectrogram."""

    def __init__(
        self, decoder_config, vq_config, dim=512
    ):
        super().__init__()
        # Per-bin output scale; randomly initialized here and overwritten
        # when a checkpoint is loaded (it is a registered buffer).
        self.register_buffer('coef', torch.randn(1, 100, 1))

        self.decoder = DVAEDecoder(**decoder_config)
        self.out_conv = nn.Conv1d(dim, 100, 3, 1, 1, bias=False)
        self.vq_layer = GFSQ(**vq_config) if vq_config is not None else None

    def forward(self, inp):
        if self.vq_layer is not None:
            # Indices in -> features via the quantizer codebooks.
            vq_feats = self.vq_layer._embed(inp)
        else:
            vq_feats = inp.detach().clone()

        # "Flatten trick": split the channel dim in half and interleave the
        # halves along time — (B, C, T) becomes (B, C/2, 2T).
        halves = torch.chunk(vq_feats, 2, dim=1)
        stacked = torch.stack(halves, -1)
        vq_feats = stacked.reshape(*stacked.shape[:2], -1)

        dec_out = self.decoder(input=vq_feats.transpose(1, 2))
        mel = self.out_conv(dec_out.transpose(1, 2)) * self.coef
        return mel
|