xinference 0.11.3__py3-none-any.whl → 0.12.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic. Click here for more details.

Files changed (75) hide show
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +143 -6
  3. xinference/client/restful/restful_client.py +144 -5
  4. xinference/constants.py +5 -0
  5. xinference/core/cache_tracker.py +48 -28
  6. xinference/core/model.py +160 -19
  7. xinference/core/scheduler.py +446 -0
  8. xinference/core/supervisor.py +99 -24
  9. xinference/core/worker.py +68 -2
  10. xinference/deploy/cmdline.py +86 -2
  11. xinference/deploy/test/test_cmdline.py +19 -10
  12. xinference/isolation.py +9 -2
  13. xinference/model/audio/__init__.py +14 -1
  14. xinference/model/audio/chattts.py +84 -0
  15. xinference/model/audio/core.py +22 -4
  16. xinference/model/audio/custom.py +6 -4
  17. xinference/model/audio/model_spec.json +20 -0
  18. xinference/model/audio/model_spec_modelscope.json +20 -0
  19. xinference/model/llm/__init__.py +38 -2
  20. xinference/model/llm/llm_family.json +509 -1
  21. xinference/model/llm/llm_family.py +86 -1
  22. xinference/model/llm/llm_family_csghub.json +66 -0
  23. xinference/model/llm/llm_family_modelscope.json +411 -2
  24. xinference/model/llm/pytorch/chatglm.py +20 -13
  25. xinference/model/llm/pytorch/cogvlm2.py +76 -17
  26. xinference/model/llm/pytorch/core.py +141 -6
  27. xinference/model/llm/pytorch/glm4v.py +268 -0
  28. xinference/model/llm/pytorch/minicpmv25.py +232 -0
  29. xinference/model/llm/pytorch/qwen_vl.py +1 -1
  30. xinference/model/llm/pytorch/utils.py +405 -8
  31. xinference/model/llm/utils.py +14 -13
  32. xinference/model/llm/vllm/core.py +16 -4
  33. xinference/model/utils.py +8 -2
  34. xinference/thirdparty/ChatTTS/__init__.py +1 -0
  35. xinference/thirdparty/ChatTTS/core.py +200 -0
  36. xinference/thirdparty/ChatTTS/experimental/__init__.py +0 -0
  37. xinference/thirdparty/ChatTTS/experimental/llm.py +40 -0
  38. xinference/thirdparty/ChatTTS/infer/__init__.py +0 -0
  39. xinference/thirdparty/ChatTTS/infer/api.py +125 -0
  40. xinference/thirdparty/ChatTTS/model/__init__.py +0 -0
  41. xinference/thirdparty/ChatTTS/model/dvae.py +155 -0
  42. xinference/thirdparty/ChatTTS/model/gpt.py +265 -0
  43. xinference/thirdparty/ChatTTS/utils/__init__.py +0 -0
  44. xinference/thirdparty/ChatTTS/utils/gpu_utils.py +23 -0
  45. xinference/thirdparty/ChatTTS/utils/infer_utils.py +141 -0
  46. xinference/thirdparty/ChatTTS/utils/io_utils.py +14 -0
  47. xinference/types.py +3 -0
  48. xinference/web/ui/build/asset-manifest.json +6 -6
  49. xinference/web/ui/build/index.html +1 -1
  50. xinference/web/ui/build/static/css/main.074e2b31.css +2 -0
  51. xinference/web/ui/build/static/css/main.074e2b31.css.map +1 -0
  52. xinference/web/ui/build/static/js/main.a58ff436.js +3 -0
  53. xinference/web/ui/build/static/js/main.a58ff436.js.map +1 -0
  54. xinference/web/ui/node_modules/.cache/babel-loader/10262a281dec3bc2b185f4385ceb6846626f52d41cb4d46c7c649e719f979d4d.json +1 -0
  55. xinference/web/ui/node_modules/.cache/babel-loader/762a75a62daf3bec2cfc97ec8612798493fb34ef87087dcad6aad64ab7f14345.json +1 -0
  56. xinference/web/ui/node_modules/.cache/babel-loader/7f3bdb3a48fa00c046c8b185acd4da6f2e2940a20dbd77f9373d60de3fd6633e.json +1 -0
  57. xinference/web/ui/node_modules/.cache/babel-loader/f2f73bfdc13b12b02c8cbc4769b0b8e6367e9b6d8331c322d94318491a0b3653.json +1 -0
  58. xinference/web/ui/node_modules/.cache/babel-loader/f51bf63ddaa7afd125ef2254a105789333eecc1c94fdf5157a9b88ef7ad0a5bd.json +1 -0
  59. {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/METADATA +26 -9
  60. {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/RECORD +65 -47
  61. xinference/web/ui/build/static/css/main.54bca460.css +0 -2
  62. xinference/web/ui/build/static/css/main.54bca460.css.map +0 -1
  63. xinference/web/ui/build/static/js/main.551aa479.js +0 -3
  64. xinference/web/ui/build/static/js/main.551aa479.js.map +0 -1
  65. xinference/web/ui/node_modules/.cache/babel-loader/1e86938a0cdf706d21e99b21f5d868fa247c0c88b26807047e26dcdc4d9a9db3.json +0 -1
  66. xinference/web/ui/node_modules/.cache/babel-loader/1fa824d82b2af519de7700c594e50bde4bbca60d13bd3fabff576802e4070304.json +0 -1
  67. xinference/web/ui/node_modules/.cache/babel-loader/3e737bcdbcbc407ccd65b90e199ef0c3214b261e8e41dbf14d921384a717d9ee.json +0 -1
  68. xinference/web/ui/node_modules/.cache/babel-loader/46b6dd1f6d1109cd0e2455a0ea0be3e9bda1097cd4ebec9c4040070372671cfc.json +0 -1
  69. xinference/web/ui/node_modules/.cache/babel-loader/59ce49eae0f486af4c5034d4d2f9ca77c3ec3a32ecc560085caf5ef482b5f4c9.json +0 -1
  70. xinference/web/ui/node_modules/.cache/babel-loader/a6da6bc3d0d2191adebee87fb58ecebe82d071087bd2f7f3a9c7fdd2ada130f2.json +0 -1
  71. /xinference/web/ui/build/static/js/{main.551aa479.js.LICENSE.txt → main.a58ff436.js.LICENSE.txt} +0 -0
  72. {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/LICENSE +0 -0
  73. {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/WHEEL +0 -0
  74. {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/entry_points.txt +0 -0
  75. {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,200 @@
1
+
2
+ import os
3
+ import logging
4
+ from functools import partial
5
+ from omegaconf import OmegaConf
6
+
7
+ import torch
8
+ from vocos import Vocos
9
+ from .model.dvae import DVAE
10
+ from .model.gpt import GPT_warpper
11
+ from .utils.gpu_utils import select_device
12
+ from .utils.infer_utils import count_invalid_characters, detect_language, apply_character_map, apply_half2full_map
13
+ from .utils.io_utils import get_latest_modified_file
14
+ from .infer.api import refine_text, infer_code
15
+
16
+ from huggingface_hub import snapshot_download
17
+
18
# NOTE: configures the process-wide *root* logger (INFO level) as an import-time
# side effect of this module; per the logging docs it only takes effect if the
# root logger has no handlers yet.
logging.basicConfig(level=logging.INFO)
19
+
20
+
21
class Chat:
    """ChatTTS text-to-speech pipeline.

    ``pretrain_models`` maps component names ('vocos', 'dvae', 'gpt', 'spk_stat',
    'decoder', 'tokenizer') to loaded objects and is filled by ``load_models`` /
    ``_load``. ``normalizer`` maps a language code to a text-normalization
    callable and is filled lazily by ``init_normalizer``.
    """

    def __init__(self):
        self.pretrain_models = {}
        self.normalizer = {}
        self.logger = logging.getLogger(__name__)

    def check_model(self, level=logging.INFO, use_decoder=False):
        """Return True if every component needed for inference is loaded.

        Args:
            level: logging level for the "All initialized." message.
            use_decoder: require 'decoder' when True, otherwise 'dvae'
                ('vocos', 'gpt' and 'tokenizer' are always required).

        Each missing component is logged as a warning.
        """
        check_list = ['vocos', 'gpt', 'tokenizer', 'decoder' if use_decoder else 'dvae']
        missing = [module for module in check_list if module not in self.pretrain_models]
        for module in missing:
            self.logger.log(logging.WARNING, f'{module} not initialized.')
        if missing:
            return False
        self.logger.log(level, 'All initialized.')
        return True

    def load_models(self, source='huggingface', force_redownload=False, local_path='<LOCAL_PATH>', **kwargs):
        """Locate the ChatTTS checkpoint directory and load the components it lists.

        Args:
            source: 'huggingface' to use (or download into) the HF hub cache, or
                'local' to load from ``local_path``.
            force_redownload: with source='huggingface', download even when a
                cached snapshot is found.
            local_path: checkpoint directory used when source='local'.
            **kwargs: forwarded to ``_load`` (e.g. ``device``, ``compile``).

        The directory must contain ``config/path.yaml`` mapping ``_load`` argument
        names to paths relative to that directory.

        Raises:
            ValueError: if ``source`` is neither 'huggingface' nor 'local'.
        """
        if source == 'huggingface':
            hf_home = os.getenv('HF_HOME', os.path.expanduser("~/.cache/huggingface"))
            try:
                download_path = get_latest_modified_file(
                    os.path.join(hf_home, 'hub/models--2Noise--ChatTTS/snapshots'))
            except Exception:
                # Any failure of the cache lookup is treated as "no cached snapshot".
                download_path = None
            if download_path is None or force_redownload:
                self.logger.log(logging.INFO, 'Download from HF: https://huggingface.co/2Noise/ChatTTS')
                download_path = snapshot_download(repo_id="2Noise/ChatTTS", allow_patterns=["*.pt", "*.yaml"])
            else:
                self.logger.log(logging.INFO, f'Load from cache: {download_path}')
        elif source == 'local':
            self.logger.log(logging.INFO, f'Load from local: {local_path}')
            download_path = local_path
        else:
            # Previously fell through and failed with UnboundLocalError on download_path.
            raise ValueError(f"source must be 'huggingface' or 'local', got {source!r}")

        path_config = OmegaConf.load(os.path.join(download_path, 'config', 'path.yaml'))
        self._load(**{k: os.path.join(download_path, v) for k, v in path_config.items()}, **kwargs)

    def _load(
        self,
        vocos_config_path: str = None,
        vocos_ckpt_path: str = None,
        dvae_config_path: str = None,
        dvae_ckpt_path: str = None,
        gpt_config_path: str = None,
        gpt_ckpt_path: str = None,
        decoder_config_path: str = None,
        decoder_ckpt_path: str = None,
        tokenizer_path: str = None,
        device: str = None,
        compile: bool = True,
    ):
        """Instantiate and load each component whose config path is provided.

        Args:
            vocos_config_path, vocos_ckpt_path: Vocos hparams file and state dict.
            dvae_config_path, dvae_ckpt_path: DVAE config (OmegaConf) and state dict.
            gpt_config_path, gpt_ckpt_path: GPT_warpper config and state dict; a
                ``spk_stat.pt`` file must sit next to the GPT checkpoint.
            decoder_config_path, decoder_ckpt_path: config and state dict of the
                decoder (also built as a ``DVAE``).
            tokenizer_path: ``torch.load``-able tokenizer object.
            device: target device; chosen via ``select_device(4096)`` when falsy.
            compile: wrap ``gpt.gpt.forward`` with ``torch.compile`` on CUDA devices.

        Components whose config path is omitted are skipped. Loaded objects are
        stored in ``self.pretrain_models``.

        NOTE: ``torch.load`` unpickles arbitrary objects; only load trusted files.
        """
        if not device:
            device = select_device(4096)
            self.logger.log(logging.INFO, f'use {device}')

        if vocos_config_path:
            vocos = Vocos.from_hparams(vocos_config_path).to(device).eval()
            assert vocos_ckpt_path, 'vocos_ckpt_path should not be None'
            # Load onto CPU like the other components so GPU-saved checkpoints also
            # load on CPU-only hosts; load_state_dict copies into the placed module.
            vocos.load_state_dict(torch.load(vocos_ckpt_path, map_location='cpu'))
            self.pretrain_models['vocos'] = vocos
            self.logger.log(logging.INFO, 'vocos loaded.')

        if dvae_config_path:
            cfg = OmegaConf.load(dvae_config_path)
            dvae = DVAE(**cfg).to(device).eval()
            assert dvae_ckpt_path, 'dvae_ckpt_path should not be None'
            dvae.load_state_dict(torch.load(dvae_ckpt_path, map_location='cpu'))
            self.pretrain_models['dvae'] = dvae
            self.logger.log(logging.INFO, 'dvae loaded.')

        if gpt_config_path:
            cfg = OmegaConf.load(gpt_config_path)
            gpt = GPT_warpper(**cfg).to(device).eval()
            assert gpt_ckpt_path, 'gpt_ckpt_path should not be None'
            gpt.load_state_dict(torch.load(gpt_ckpt_path, map_location='cpu'))
            if compile and 'cuda' in str(device):
                gpt.gpt.forward = torch.compile(gpt.gpt.forward, backend='inductor', dynamic=True)
            self.pretrain_models['gpt'] = gpt
            spk_stat_path = os.path.join(os.path.dirname(gpt_ckpt_path), 'spk_stat.pt')
            assert os.path.exists(spk_stat_path), f'Missing spk_stat.pt: {spk_stat_path}'
            self.pretrain_models['spk_stat'] = torch.load(spk_stat_path, map_location='cpu').to(device)
            self.logger.log(logging.INFO, 'gpt loaded.')

        if decoder_config_path:
            cfg = OmegaConf.load(decoder_config_path)
            decoder = DVAE(**cfg).to(device).eval()
            assert decoder_ckpt_path, 'decoder_ckpt_path should not be None'
            decoder.load_state_dict(torch.load(decoder_ckpt_path, map_location='cpu'))
            self.pretrain_models['decoder'] = decoder
            self.logger.log(logging.INFO, 'decoder loaded.')

        if tokenizer_path:
            tokenizer = torch.load(tokenizer_path, map_location='cpu')
            tokenizer.padding_side = 'left'
            self.pretrain_models['tokenizer'] = tokenizer
            self.logger.log(logging.INFO, 'tokenizer loaded.')

        self.check_model()

    def infer(
        self,
        text,
        skip_refine_text=False,
        refine_text_only=False,
        params_refine_text=None,
        params_infer_code=None,
        use_decoder=True,
        do_text_normalization=True,
        lang=None,
    ):
        """Synthesize speech for one text or a list of texts.

        Args:
            text: a string or a list of strings (the caller's list is not modified).
            skip_refine_text: skip the GPT text-refinement pass.
            refine_text_only: return the refined texts instead of audio.
            params_refine_text: keyword arguments for ``refine_text``; default ``{}``.
            params_infer_code: keyword arguments for ``infer_code``; its optional
                'prompt' entry is prepended to every text. Default
                ``{'prompt': '[speed_5]'}``. The caller's dict is not modified.
            use_decoder: decode GPT hidden states with 'decoder' (True) or the
                generated ids with 'dvae' (False).
            do_text_normalization: apply the language's normalizer when its
                package is installed (skipped with a warning otherwise).
            lang: normalization language; detected per text when None.

        Returns:
            The refined texts if ``refine_text_only`` and refinement ran; otherwise
            one waveform per text, from ``vocos.decode(...).cpu().numpy()``.
        """
        assert self.check_model(use_decoder=use_decoder)

        # Work on copies: the old code mutated the caller's list/dict and, through
        # pop('prompt'), the shared default dict (losing '[speed_5]' after one call).
        text = list(text) if isinstance(text, list) else [text]
        params_refine_text = {} if params_refine_text is None else dict(params_refine_text)
        params_infer_code = {'prompt': '[speed_5]'} if params_infer_code is None else dict(params_infer_code)

        if do_text_normalization:
            for i, t in enumerate(text):
                _lang = detect_language(t) if lang is None else lang
                if self.init_normalizer(_lang):
                    text[i] = self.normalizer[_lang](t)
                    if _lang == 'zh':
                        text[i] = apply_half2full_map(text[i])

        for i, t in enumerate(text):
            invalid_characters = count_invalid_characters(t)
            if len(invalid_characters):
                self.logger.log(logging.WARNING, f'Invalid characters found! : {invalid_characters}')
                text[i] = apply_character_map(t)

        if not skip_refine_text:
            tokenizer = self.pretrain_models['tokenizer']
            text_tokens = refine_text(self.pretrain_models, text, **params_refine_text)['ids']
            # Keep only token ids below the '[break_0]' id (looked up once, not per item).
            break_0_id = tokenizer.convert_tokens_to_ids('[break_0]')
            text_tokens = [i[i < break_0_id] for i in text_tokens]
            text = tokenizer.batch_decode(text_tokens)
            if refine_text_only:
                return text

        prompt = params_infer_code.pop('prompt', '')
        text = [prompt + i for i in text]
        result = infer_code(self.pretrain_models, text, **params_infer_code, return_hidden=use_decoder)

        if use_decoder:
            mel_spec = [self.pretrain_models['decoder'](i[None].permute(0, 2, 1)) for i in result['hiddens']]
        else:
            mel_spec = [self.pretrain_models['dvae'](i[None].permute(0, 2, 1)) for i in result['ids']]

        wav = [self.pretrain_models['vocos'].decode(i).cpu().numpy() for i in mel_spec]

        return wav

    def sample_random_speaker(self):
        """Draw a random speaker embedding from the stored speaker statistics.

        'spk_stat' is split into (std, mean) halves; the result is
        ``randn(dim) * std + mean`` where ``dim`` is the input width of the first
        GPT layer's ``mlp.gate_proj``.
        """
        dim = self.pretrain_models['gpt'].gpt.layers[0].mlp.gate_proj.in_features
        std, mean = self.pretrain_models['spk_stat'].chunk(2)
        return torch.randn(dim, device=std.device) * std + mean

    def init_normalizer(self, lang):
        """Lazily create the text normalizer for ``lang``.

        'zh' uses WeTextProcessing (``tn``); any other language uses
        nemo_text_processing. If the needed package cannot be imported, a warning
        is logged and nothing is registered (previously this crashed with a
        NameError on ``Normalizer``).

        Returns:
            True if ``self.normalizer[lang]`` is available after the call, else False.
        """
        if lang in self.normalizer:
            return True
        if lang == 'zh':
            try:
                from tn.chinese.normalizer import Normalizer
            except ImportError:
                self.logger.log(logging.WARNING, f'Package WeTextProcessing not found! \
                        Run: conda install -c conda-forge pynini=2.1.5 && pip install WeTextProcessing')
                return False
            self.normalizer[lang] = Normalizer().normalize
        else:
            try:
                from nemo_text_processing.text_normalization.normalize import Normalizer
            except ImportError:
                self.logger.log(logging.WARNING, f'Package nemo_text_processing not found! \
                        Run: conda install -c conda-forge pynini=2.1.5 && pip install nemo_text_processing')
                return False
            self.normalizer[lang] = partial(Normalizer(input_case='cased', lang=lang).normalize, verbose=False, punct_post_process=True)
        return True
200
+
File without changes
@@ -0,0 +1,40 @@
1
+
2
+ from openai import OpenAI
3
+
4
# Canned conversations prepended to each request by ``llm_api.call``, keyed by
# ``prompt_version``. 'kimi' and 'deepseek' ask the model for short, colloquial
# replies meant to be spoken by a TTS model; 'deepseek_TN' is a few-shot text
# normalization prompt (numbers spelled out, only commas and periods).
prompt_dict = dict(
    kimi=[
        dict(role="system", content="你是 Kimi,由 Moonshot AI 提供的人工智能助手,你更擅长中文和英文的对话。"),
        dict(role="user", content="你好,请注意你现在生成的文字要按照人日常生活的口吻,你的回复将会后续用TTS模型转为语音,并且请把回答控制在100字以内。并且标点符号仅包含逗号和句号,将数字等转为文字回答。"),
        dict(role="assistant", content="好的,我现在生成的文字将按照人日常生活的口吻, 并且我会把回答控制在一百字以内, 标点符号仅包含逗号和句号,将阿拉伯数字等转为中文文字回答。下面请开始对话。"),
    ],
    deepseek=[
        dict(role="system", content="You are a helpful assistant"),
        dict(role="user", content="你好,请注意你现在生成的文字要按照人日常生活的口吻,你的回复将会后续用TTS模型转为语音,并且请把回答控制在100字以内。并且标点符号仅包含逗号和句号,将数字等转为文字回答。"),
        dict(role="assistant", content="好的,我现在生成的文字将按照人日常生活的口吻, 并且我会把回答控制在一百字以内, 标点符号仅包含逗号和句号,将阿拉伯数字等转为中文文字回答。下面请开始对话。"),
    ],
    deepseek_TN=[
        dict(role="system", content="You are a helpful assistant"),
        dict(role="user", content="你好,现在我们在处理TTS的文本输入,下面将会给你输入一段文本,请你将其中的阿拉伯数字等等转为文字表达,并且输出的文本里仅包含逗号和句号这两个标点符号"),
        dict(role="assistant", content="好的,我现在对TTS的文本输入进行处理。这一般叫做text normalization。下面请输入"),
        # Few-shot examples: raw input followed by its normalized form.
        dict(role="user", content="We paid $123 for this desk."),
        dict(role="assistant", content="We paid one hundred and twenty three dollars for this desk."),
        dict(role="user", content="详询请拨打010-724654"),
        dict(role="assistant", content="详询请拨打零幺零,七二四六五四"),
        dict(role="user", content="罗森宣布将于7月24日退市,在华门店超6000家!"),
        dict(role="assistant", content="罗森宣布将于七月二十四日退市,在华门店超过六千家。"),
    ],
)
24
+
25
class llm_api:
    """Thin client for an OpenAI-compatible chat-completions endpoint.

    Every question is sent after the canned conversation stored in
    ``prompt_dict[prompt_version]``.
    """

    def __init__(self, api_key, base_url, model):
        self.client = OpenAI(api_key=api_key, base_url=base_url)
        self.model = model

    def call(self, user_question, temperature=0.3, prompt_version='kimi', **kwargs):
        """Ask ``user_question`` and return the text of the first reply choice.

        Args:
            user_question: the user message appended after the canned prompt.
            temperature: sampling temperature passed to the API.
            prompt_version: key into ``prompt_dict`` selecting the canned prompt.
            **kwargs: extra arguments for ``chat.completions.create``.
        """
        # Build a fresh list so the shared prompt_dict entry is never modified.
        messages = [*prompt_dict[prompt_version], {"role": "user", "content": user_question}]
        completion = self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            temperature=temperature,
            **kwargs
        )
        first_choice = completion.choices[0]
        return first_choice.message.content
File without changes
@@ -0,0 +1,125 @@
1
+
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from transformers.generation import TopKLogitsWarper, TopPLogitsWarper
5
+ from ..utils.infer_utils import CustomRepetitionPenaltyLogitsProcessorRepeat
6
+
7
def infer_code(
    models,
    text,
    spk_emb=None,
    top_P=0.7,
    top_K=20,
    temperature=0.3,
    repetition_penalty=1.05,
    max_new_token=2048,
    **kwargs
):
    """Generate audio code ids with the GPT for one text or a list of texts.

    Args:
        models: mapping with at least 'gpt' and 'tokenizer' entries.
        text: a string or a list of strings.
        spk_emb: optional speaker embedding. When given, each prompt carries a
            '[spk_emb]' token whose embedding row is replaced by the L2-normalized
            ``spk_emb``; otherwise '[empty_spk]' is used.
        top_P, top_K: thresholds for TopP/TopK logits warpers
            (min_tokens_to_keep=3); None disables the respective warper.
        temperature: a scalar (repeated ``models['gpt'].num_vq`` times) or a list.
        repetition_penalty: enables CustomRepetitionPenaltyLogitsProcessorRepeat
            unless None or 1.
        max_new_token: generation length limit.
        **kwargs: forwarded to ``models['gpt'].generate`` (e.g. ``return_hidden``).

    Returns:
        Whatever ``models['gpt'].generate`` returns.
    """
    gpt = models['gpt']
    tokenizer = models['tokenizer']
    device = next(gpt.parameters()).device

    texts = text if isinstance(text, list) else [text]
    temperatures = temperature if isinstance(temperature, list) else [temperature] * gpt.num_vq

    if spk_emb is None:
        prompts = [f'[Stts][empty_spk]{i}[Ptts]' for i in texts]
    else:
        prompts = [f'[Stts][spk_emb]{i}[Ptts]' for i in texts]

    encoded = tokenizer(prompts, return_tensors='pt', add_special_tokens=False, padding=True).to(device)
    token_ids = encoded['input_ids']
    inputs = {
        # The same text token id is repeated across the num_vq codebook axis.
        'input_ids': token_ids[..., None].expand(-1, -1, gpt.num_vq),
        'text_mask': torch.ones(token_ids.shape, dtype=bool, device=device),
        'attention_mask': encoded['attention_mask'],
    }

    emb = gpt.get_emb(**inputs)
    if spk_emb is not None:
        speaker_vectors = spk_emb.to(device).to(emb.dtype)[None].expand(len(prompts), -1)
        normalized = F.normalize(speaker_vectors, p=2.0, dim=1, eps=1e-12)
        speaker_positions = inputs['input_ids'][..., 0] == tokenizer.convert_tokens_to_ids('[spk_emb]')
        emb[speaker_positions] = normalized

    # The last index of the first code embedding table is used as the EOS token.
    num_code = gpt.emb_code[0].num_embeddings - 1

    logits_warpers = []
    if top_P is not None:
        logits_warpers.append(TopPLogitsWarper(top_P, min_tokens_to_keep=3))
    if top_K is not None:
        logits_warpers.append(TopKLogitsWarper(top_K, min_tokens_to_keep=3))

    logits_processors = []
    if repetition_penalty is not None and repetition_penalty != 1:
        logits_processors.append(
            CustomRepetitionPenaltyLogitsProcessorRepeat(repetition_penalty, num_code, 16))

    return gpt.generate(
        emb,
        inputs['input_ids'],
        temperature=torch.tensor(temperatures, device=device),
        attention_mask=inputs['attention_mask'],
        LogitsWarpers=logits_warpers,
        LogitsProcessors=logits_processors,
        eos_token=num_code,
        max_new_token=max_new_token,
        infer_text=False,
        **kwargs
    )
73
+
74
+
75
def refine_text(
    models,
    text,
    top_P=0.7,
    top_K=20,
    temperature=0.7,
    repetition_penalty=1.0,
    max_new_token=384,
    prompt='',
    **kwargs
):
    """Run the GPT in text mode (``infer_text=True``) to refine the input text(s).

    Each input is wrapped as ``[Sbreak]{text}[Pbreak]{prompt}``, and generation
    stops at the '[Ebreak]' token.

    Args:
        models: mapping with at least 'gpt' and 'tokenizer' entries.
        text: a non-empty string or list of strings (AssertionError if empty).
        top_P, top_K: thresholds for TopP/TopK logits warpers
            (min_tokens_to_keep=3); None disables the respective warper.
        temperature: scalar sampling temperature.
        repetition_penalty: enables CustomRepetitionPenaltyLogitsProcessorRepeat
            over the whole tokenizer vocabulary unless None or 1.
        max_new_token: generation length limit.
        prompt: text appended after '[Pbreak]'.
        **kwargs: forwarded to ``models['gpt'].generate``.

    Returns:
        Whatever ``models['gpt'].generate`` returns (callers in this package
        read its 'ids' entry).
    """
    gpt = models['gpt']
    tokenizer = models['tokenizer']
    device = next(gpt.parameters()).device

    texts = text if isinstance(text, list) else [text]
    assert texts, 'text should not be empty'

    wrapped = [f"[Sbreak]{i}[Pbreak]{prompt}" for i in texts]
    encoded = tokenizer(wrapped, return_tensors='pt', add_special_tokens=False, padding=True).to(device)
    token_ids = encoded['input_ids']
    inputs = {
        'input_ids': token_ids[..., None].expand(-1, -1, gpt.num_vq),
        'text_mask': torch.ones(token_ids.shape, dtype=bool, device=device),
        'attention_mask': encoded['attention_mask'],
    }

    logits_warpers = []
    if top_P is not None:
        logits_warpers.append(TopPLogitsWarper(top_P, min_tokens_to_keep=3))
    if top_K is not None:
        logits_warpers.append(TopKLogitsWarper(top_K, min_tokens_to_keep=3))

    logits_processors = []
    if repetition_penalty is not None and repetition_penalty != 1:
        logits_processors.append(
            CustomRepetitionPenaltyLogitsProcessorRepeat(repetition_penalty, len(tokenizer), 16))

    emb = gpt.get_emb(**inputs)
    temperature_tensor = torch.tensor([temperature], device=device)
    eos_token = torch.tensor(tokenizer.convert_tokens_to_ids('[Ebreak]'), device=device)[None]
    return gpt.generate(
        emb,
        inputs['input_ids'],
        temperature=temperature_tensor,
        attention_mask=inputs['attention_mask'],
        LogitsWarpers=logits_warpers,
        LogitsProcessors=logits_processors,
        eos_token=eos_token,
        max_new_token=max_new_token,
        infer_text=True,
        **kwargs
    )
File without changes
@@ -0,0 +1,155 @@
1
+ import math
2
+ from einops import rearrange
3
+ from vector_quantize_pytorch import GroupedResidualFSQ
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
class ConvNeXtBlock(nn.Module):
    """Residual ConvNeXt block over (B, C, T) sequences (copied from Vocos).

    The block applies a depthwise dilated Conv1d, then LayerNorm,
    Linear(dim -> intermediate_dim), GELU and Linear(intermediate_dim -> dim).
    An optional per-channel scale ``gamma`` follows, and the result is added
    back to the input.
    """

    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        kernel, dilation,
        layer_scale_init_value: float = 1e-6,
    ):
        super().__init__()
        # Depthwise conv (groups=dim); this padding keeps T unchanged for odd kernels.
        self.dwconv = nn.Conv1d(
            dim,
            dim,
            kernel_size=kernel,
            padding=dilation * (kernel // 2),
            dilation=dilation,
            groups=dim,
        )
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        # Pointwise (1x1) convolutions implemented as Linear layers on (B, T, C).
        self.pwconv1 = nn.Linear(dim, intermediate_dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(intermediate_dim, dim)
        # Layer scale is disabled (None) when layer_scale_init_value <= 0.
        if layer_scale_init_value > 0:
            self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
        else:
            self.gamma = None

    def forward(self, x: torch.Tensor, cond=None) -> torch.Tensor:
        """Apply the block to ``x`` of shape (B, C, T); ``cond`` is accepted but unused."""
        h = self.dwconv(x).transpose(1, 2)  # (B, C, T) -> (B, T, C)
        h = self.pwconv2(self.act(self.pwconv1(self.norm(h))))
        if self.gamma is not None:
            h = self.gamma * h
        return x + h.transpose(1, 2)  # (B, T, C) -> (B, C, T), residual add
48
+
49
+
50
+
51
class GFSQ(nn.Module):
    """Wrapper around ``GroupedResidualFSQ`` (grouped residual finite scalar quantizer).

    Args:
        dim: feature dimension passed to the quantizer.
        levels: FSQ levels; ``n_ind`` (the codebook size) is their product.
        G: number of groups.
        R: number of residual quantizers per group.
        eps: small constant used in the perplexity computation.
        transpose: if True, dims 1 and 2 of the input are swapped before the
            quantizer, and those of the outputs are swapped back afterwards.
    """

    def __init__(self, dim, levels, G, R, eps=1e-5, transpose=True):
        super().__init__()
        self.quantizer = GroupedResidualFSQ(
            dim=dim,
            levels=levels,
            num_quantizers=R,
            groups=G,
        )
        self.n_ind = math.prod(levels)
        self.eps = eps
        self.transpose = transpose
        self.G = G
        self.R = R

    def _embed(self, x):
        """Turn code indices back into quantized features.

        After the optional transpose, the last dim of ``x`` holds G*R indices.
        They are regrouped to (G, B, T, R) for ``get_output_from_indices``.
        """
        if self.transpose:
            x = x.transpose(1, 2)
        grouped = rearrange(x, "b t (g r) -> g b t r", g=self.G, r=self.R)
        feat = self.quantizer.get_output_from_indices(grouped)
        if self.transpose:
            return feat.transpose(1, 2)
        return feat

    def forward(self, x):
        """Quantize ``x``.

        Returns:
            A 5-tuple: zeros shaped like the perplexity, the quantized features,
            the codebook-usage perplexity for each of the G*R index slots, None,
            and the code indices. Features and indices are transposed back when
            ``self.transpose`` is set.
        """
        if self.transpose:
            x = x.transpose(1, 2)
        feat, ind = self.quantizer(x)
        ind = rearrange(ind, "g b t r ->b t (g r)")

        # Average one-hot usage over batch and time, normalized per index slot.
        onehot = F.one_hot(ind.long(), self.n_ind).to(x.dtype)
        usage = torch.mean(onehot, dim=[0, 1])
        usage = usage / (usage.sum(dim=1) + self.eps).unsqueeze(1)
        perplexity = torch.exp(-torch.sum(usage * torch.log(usage + self.eps), dim=1))

        zeros = torch.zeros(perplexity.shape, dtype=x.dtype, device=x.device)
        if self.transpose:
            return zeros, feat.transpose(1, 2), perplexity, None, ind.transpose(1, 2)
        return zeros, feat, perplexity, None, ind
97
+
98
class DVAEDecoder(nn.Module):
    """Map (B, T, idim) features to (B, T, odim) through a stack of ConvNeXt blocks.

    The features pass through a conv stem (idim -> bn_dim -> hidden), then
    ``n_layer`` ConvNeXtBlocks of width ``hidden``, then a bias-free 1x1 conv
    to ``odim``.
    """

    def __init__(self, idim, odim,
                 n_layer=12, bn_dim=64, hidden=256,
                 kernel=7, dilation=2, up=False):
        super().__init__()
        self.up = up  # stored but not used by this class
        self.conv_in = nn.Sequential(
            nn.Conv1d(idim, bn_dim, 3, 1, 1),
            nn.GELU(),
            nn.Conv1d(bn_dim, hidden, 3, 1, 1),
        )
        self.decoder_block = nn.ModuleList(
            [ConvNeXtBlock(hidden, hidden * 4, kernel, dilation) for _ in range(n_layer)]
        )
        self.conv_out = nn.Conv1d(hidden, odim, kernel_size=1, bias=False)

    def forward(self, input, conditioning=None):
        """Decode ``input`` of shape (B, T, C); ``conditioning`` is passed to every block."""
        h = self.conv_in(input.transpose(1, 2))  # convs run channel-first: (B, C, T)
        for block in self.decoder_block:
            h = block(h, conditioning)
        return self.conv_out(h).transpose(1, 2)  # back to (B, T, odim)
123
+
124
+
125
class DVAE(nn.Module):
    """Decode code indices (through GFSQ) or continuous features into a 100-channel output.

    Args:
        decoder_config: keyword arguments for ``DVAEDecoder``.
        vq_config: keyword arguments for ``GFSQ``; when None the input is used
            directly as features.
        dim: channel count fed to ``out_conv``.
    """

    def __init__(self, decoder_config, vq_config, dim=512):
        super().__init__()
        # Per-channel output scale; random init, presumably overwritten by the checkpoint.
        self.register_buffer('coef', torch.randn(1, 100, 1))

        self.decoder = DVAEDecoder(**decoder_config)
        self.out_conv = nn.Conv1d(dim, 100, 3, 1, 1, bias=False)
        self.vq_layer = GFSQ(**vq_config) if vq_config is not None else None

    def forward(self, inp):
        """Return ``out_conv(decoder(features)) * coef`` (named ``mel`` here)."""
        if self.vq_layer is None:
            feats = inp.detach().clone()
        else:
            feats = self.vq_layer._embed(inp)

        # "Flatten trick": split dim 1 into two halves and interleave them along
        # the last axis. Assumes 3-D (B, C, T) features -> (B, C/2, 2T); TODO confirm.
        halves = torch.stack(torch.chunk(feats, 2, dim=1), -1)
        feats = halves.reshape(*halves.shape[:2], -1)

        decoded = self.decoder(input=feats.transpose(1, 2))
        mel = self.out_conv(decoded.transpose(1, 2)) * self.coef
        return mel