xinference 0.14.4.post1__py3-none-any.whl → 0.15.1__py3-none-any.whl
This diff covers two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only and reflects the changes between those published versions.
Potentially problematic release: this version of xinference might be problematic.
- xinference/_compat.py +51 -0
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +209 -40
- xinference/client/restful/restful_client.py +7 -26
- xinference/conftest.py +1 -1
- xinference/constants.py +5 -0
- xinference/core/cache_tracker.py +1 -1
- xinference/core/chat_interface.py +8 -14
- xinference/core/event.py +1 -1
- xinference/core/image_interface.py +28 -0
- xinference/core/model.py +110 -31
- xinference/core/scheduler.py +37 -37
- xinference/core/status_guard.py +1 -1
- xinference/core/supervisor.py +17 -10
- xinference/core/utils.py +80 -22
- xinference/core/worker.py +17 -16
- xinference/deploy/cmdline.py +8 -16
- xinference/deploy/local.py +1 -1
- xinference/deploy/supervisor.py +1 -1
- xinference/deploy/utils.py +1 -1
- xinference/deploy/worker.py +1 -1
- xinference/model/audio/cosyvoice.py +86 -41
- xinference/model/audio/fish_speech.py +9 -9
- xinference/model/audio/model_spec.json +9 -9
- xinference/model/audio/whisper.py +4 -1
- xinference/model/embedding/core.py +52 -31
- xinference/model/image/core.py +2 -1
- xinference/model/image/model_spec.json +16 -4
- xinference/model/image/model_spec_modelscope.json +16 -4
- xinference/model/image/sdapi.py +136 -0
- xinference/model/image/stable_diffusion/core.py +164 -19
- xinference/model/llm/__init__.py +29 -11
- xinference/model/llm/llama_cpp/core.py +16 -33
- xinference/model/llm/llm_family.json +1011 -1296
- xinference/model/llm/llm_family.py +34 -53
- xinference/model/llm/llm_family_csghub.json +18 -35
- xinference/model/llm/llm_family_modelscope.json +981 -1122
- xinference/model/llm/lmdeploy/core.py +56 -88
- xinference/model/llm/mlx/core.py +46 -69
- xinference/model/llm/sglang/core.py +36 -18
- xinference/model/llm/transformers/chatglm.py +168 -306
- xinference/model/llm/transformers/cogvlm2.py +36 -63
- xinference/model/llm/transformers/cogvlm2_video.py +33 -223
- xinference/model/llm/transformers/core.py +55 -50
- xinference/model/llm/transformers/deepseek_v2.py +340 -0
- xinference/model/llm/transformers/deepseek_vl.py +53 -96
- xinference/model/llm/transformers/glm4v.py +55 -111
- xinference/model/llm/transformers/intern_vl.py +39 -70
- xinference/model/llm/transformers/internlm2.py +32 -54
- xinference/model/llm/transformers/minicpmv25.py +22 -55
- xinference/model/llm/transformers/minicpmv26.py +158 -68
- xinference/model/llm/transformers/omnilmm.py +5 -28
- xinference/model/llm/transformers/qwen2_audio.py +168 -0
- xinference/model/llm/transformers/qwen2_vl.py +234 -0
- xinference/model/llm/transformers/qwen_vl.py +34 -86
- xinference/model/llm/transformers/utils.py +32 -38
- xinference/model/llm/transformers/yi_vl.py +32 -72
- xinference/model/llm/utils.py +280 -554
- xinference/model/llm/vllm/core.py +161 -100
- xinference/model/rerank/core.py +41 -8
- xinference/model/rerank/model_spec.json +7 -0
- xinference/model/rerank/model_spec_modelscope.json +7 -1
- xinference/model/utils.py +1 -31
- xinference/thirdparty/cosyvoice/bin/export_jit.py +64 -0
- xinference/thirdparty/cosyvoice/bin/export_trt.py +8 -0
- xinference/thirdparty/cosyvoice/bin/inference.py +5 -2
- xinference/thirdparty/cosyvoice/cli/cosyvoice.py +38 -22
- xinference/thirdparty/cosyvoice/cli/model.py +139 -26
- xinference/thirdparty/cosyvoice/flow/flow.py +15 -9
- xinference/thirdparty/cosyvoice/flow/length_regulator.py +20 -1
- xinference/thirdparty/cosyvoice/hifigan/generator.py +8 -4
- xinference/thirdparty/cosyvoice/llm/llm.py +14 -13
- xinference/thirdparty/cosyvoice/transformer/attention.py +7 -3
- xinference/thirdparty/cosyvoice/transformer/decoder.py +1 -1
- xinference/thirdparty/cosyvoice/transformer/embedding.py +4 -3
- xinference/thirdparty/cosyvoice/transformer/encoder.py +4 -2
- xinference/thirdparty/cosyvoice/utils/common.py +36 -0
- xinference/thirdparty/cosyvoice/utils/file_utils.py +16 -0
- xinference/thirdparty/deepseek_vl/serve/assets/Kelpy-Codos.js +100 -0
- xinference/thirdparty/deepseek_vl/serve/assets/avatar.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/assets/custom.css +355 -0
- xinference/thirdparty/deepseek_vl/serve/assets/custom.js +22 -0
- xinference/thirdparty/deepseek_vl/serve/assets/favicon.ico +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/app.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/chart.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/mirror.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/pipeline.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/puzzle.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/rap.jpeg +0 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/base.yaml +87 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/firefly_gan_vq.yaml +33 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/lora/r_8_alpha_16.yaml +4 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/text2semantic_finetune.yaml +83 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text-data.proto +24 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/README.md +27 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +2 -2
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +0 -3
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +169 -198
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +4 -27
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/.gitignore +114 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/README.md +36 -0
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +9 -47
- xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +2 -2
- xinference/thirdparty/fish_speech/fish_speech/train.py +2 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/css/style.css +161 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/html/footer.html +11 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/js/animate.js +69 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +12 -10
- xinference/thirdparty/fish_speech/tools/api.py +79 -134
- xinference/thirdparty/fish_speech/tools/commons.py +35 -0
- xinference/thirdparty/fish_speech/tools/download_models.py +3 -3
- xinference/thirdparty/fish_speech/tools/file.py +17 -0
- xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +1 -1
- xinference/thirdparty/fish_speech/tools/llama/generate.py +29 -24
- xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +1 -1
- xinference/thirdparty/fish_speech/tools/llama/quantize.py +2 -2
- xinference/thirdparty/fish_speech/tools/msgpack_api.py +34 -0
- xinference/thirdparty/fish_speech/tools/post_api.py +85 -44
- xinference/thirdparty/fish_speech/tools/sensevoice/README.md +59 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +1 -1
- xinference/thirdparty/fish_speech/tools/smart_pad.py +16 -3
- xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +2 -2
- xinference/thirdparty/fish_speech/tools/vqgan/inference.py +4 -2
- xinference/thirdparty/fish_speech/tools/webui.py +12 -146
- xinference/thirdparty/matcha/VERSION +1 -0
- xinference/thirdparty/matcha/hifigan/LICENSE +21 -0
- xinference/thirdparty/matcha/hifigan/README.md +101 -0
- xinference/thirdparty/omnilmm/LICENSE +201 -0
- xinference/thirdparty/whisper/__init__.py +156 -0
- xinference/thirdparty/whisper/__main__.py +3 -0
- xinference/thirdparty/whisper/assets/gpt2.tiktoken +50256 -0
- xinference/thirdparty/whisper/assets/mel_filters.npz +0 -0
- xinference/thirdparty/whisper/assets/multilingual.tiktoken +50257 -0
- xinference/thirdparty/whisper/audio.py +157 -0
- xinference/thirdparty/whisper/decoding.py +826 -0
- xinference/thirdparty/whisper/model.py +314 -0
- xinference/thirdparty/whisper/normalizers/__init__.py +2 -0
- xinference/thirdparty/whisper/normalizers/basic.py +76 -0
- xinference/thirdparty/whisper/normalizers/english.json +1741 -0
- xinference/thirdparty/whisper/normalizers/english.py +550 -0
- xinference/thirdparty/whisper/timing.py +386 -0
- xinference/thirdparty/whisper/tokenizer.py +395 -0
- xinference/thirdparty/whisper/transcribe.py +605 -0
- xinference/thirdparty/whisper/triton_ops.py +109 -0
- xinference/thirdparty/whisper/utils.py +316 -0
- xinference/thirdparty/whisper/version.py +1 -0
- xinference/types.py +14 -53
- xinference/web/ui/build/asset-manifest.json +6 -6
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/css/{main.4bafd904.css → main.5061c4c3.css} +2 -2
- xinference/web/ui/build/static/css/main.5061c4c3.css.map +1 -0
- xinference/web/ui/build/static/js/main.754740c0.js +3 -0
- xinference/web/ui/build/static/js/{main.eb13fe95.js.LICENSE.txt → main.754740c0.js.LICENSE.txt} +2 -0
- xinference/web/ui/build/static/js/main.754740c0.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/10c69dc7a296779fcffedeff9393d832dfcb0013c36824adf623d3c518b801ff.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/68bede6d95bb5ef0b35bbb3ec5b8c937eaf6862c6cdbddb5ef222a7776aaf336.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/77d50223f3e734d4485cca538cb098a8c3a7a0a1a9f01f58cdda3af42fe1adf5.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/a56d5a642409a84988891089c98ca28ad0546432dfbae8aaa51bc5a280e1cdd2.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/cd90b08d177025dfe84209596fc51878f8a86bcaa6a240848a3d2e5fd4c7ff24.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/d9ff696a3e3471f01b46c63d18af32e491eb5dc0e43cb30202c96871466df57f.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/e42b72d4cc1ea412ebecbb8d040dc6c6bfee462c33903c2f1f3facb602ad742e.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f5039ddbeb815c51491a1989532006b96fc3ae49c6c60e3c097f875b4ae915ae.json +1 -0
- xinference/web/ui/node_modules/.package-lock.json +37 -0
- xinference/web/ui/node_modules/a-sync-waterfall/package.json +21 -0
- xinference/web/ui/node_modules/nunjucks/node_modules/commander/package.json +48 -0
- xinference/web/ui/node_modules/nunjucks/package.json +112 -0
- xinference/web/ui/package-lock.json +38 -0
- xinference/web/ui/package.json +1 -0
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/METADATA +16 -10
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/RECORD +179 -127
- xinference/model/llm/transformers/llama_2.py +0 -108
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +0 -442
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +0 -44
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +0 -115
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +0 -225
- xinference/thirdparty/fish_speech/tools/auto_rerank.py +0 -159
- xinference/thirdparty/fish_speech/tools/gen_ref.py +0 -36
- xinference/thirdparty/fish_speech/tools/merge_asr_files.py +0 -55
- xinference/web/ui/build/static/css/main.4bafd904.css.map +0 -1
- xinference/web/ui/build/static/js/main.eb13fe95.js +0 -3
- xinference/web/ui/build/static/js/main.eb13fe95.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/0b11a5339468c13b2d31ac085e7effe4303259b2071abd46a0a8eb8529233a5e.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/213b5913e164773c2b0567455377765715f5f07225fbac77ad8e1e9dc9648a47.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/5c26a23b5eacf5b752a08531577ae3840bb247745ef9a39583dc2d05ba93a82a.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/978b57d1a04a701bc3fcfebc511f5f274eed6ed7eade67f6fb76c27d5fd9ecc8.json +0 -1
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/LICENSE +0 -0
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/WHEEL +0 -0
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/entry_points.txt +0 -0
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/top_level.txt +0 -0
xinference/thirdparty/cosyvoice/bin/export_jit.py (new file)

```diff
@@ -0,0 +1,64 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import argparse
+import logging
+logging.getLogger('matplotlib').setLevel(logging.WARNING)
+import os
+import sys
+ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+sys.path.append('{}/../..'.format(ROOT_DIR))
+sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
+import torch
+from cosyvoice.cli.cosyvoice import CosyVoice
+
+def get_args():
+    parser = argparse.ArgumentParser(description='export your model for deployment')
+    parser.add_argument('--model_dir',
+                        type=str,
+                        default='pretrained_models/CosyVoice-300M',
+                        help='local path')
+    args = parser.parse_args()
+    print(args)
+    return args
+
+def main():
+    args = get_args()
+    logging.basicConfig(level=logging.DEBUG,
+                        format='%(asctime)s %(levelname)s %(message)s')
+
+    torch._C._jit_set_fusion_strategy([('STATIC', 1)])
+    torch._C._jit_set_profiling_mode(False)
+    torch._C._jit_set_profiling_executor(False)
+
+    cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_trt=False)
+
+    # 1. export llm text_encoder
+    llm_text_encoder = cosyvoice.model.llm.text_encoder.half()
+    script = torch.jit.script(llm_text_encoder)
+    script = torch.jit.freeze(script)
+    script = torch.jit.optimize_for_inference(script)
+    script.save('{}/llm.text_encoder.fp16.zip'.format(args.model_dir))
+
+    # 2. export llm llm
+    llm_llm = cosyvoice.model.llm.llm.half()
+    script = torch.jit.script(llm_llm)
+    script = torch.jit.freeze(script, preserved_attrs=['forward_chunk'])
+    script = torch.jit.optimize_for_inference(script)
+    script.save('{}/llm.llm.fp16.zip'.format(args.model_dir))
+
+if __name__ == '__main__':
+    main()
```
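The script above writes two TorchScript archives that `CosyVoiceModel.load_jit` (added in the cli/model.py hunk further down) loads back with `torch.jit.load`. Below is a minimal, self-contained sketch of that script/freeze/optimize/save round trip, using a toy module rather than anything from the package:

```python
# Minimal, self-contained sketch (not the package code) of the TorchScript
# export pattern used by export_jit.py: script -> freeze -> optimize_for_inference
# -> save, then load the artifact back with torch.jit.load.
import torch

class TinyEncoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.proj(x))

module = TinyEncoder().eval()          # freeze() requires eval mode
script = torch.jit.script(module)      # compile to TorchScript
script = torch.jit.freeze(script)      # inline weights, drop training-only code
script = torch.jit.optimize_for_inference(script)
script.save('tiny_encoder.zip')        # same .zip container format as llm.llm.fp16.zip

reloaded = torch.jit.load('tiny_encoder.zip')
print(reloaded(torch.randn(2, 8)).shape)  # torch.Size([2, 8])
```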
xinference/thirdparty/cosyvoice/bin/export_trt.py (new file)

```diff
@@ -0,0 +1,8 @@
+# TODO Same logic as export_jit: finish the ONNX export of the estimator in the flow part.
+# Document the TensorRT installation steps here. If it is not installed, do not run this script; tell the user to install it first, with no other option.
+try:
+    import tensorrt
+except ImportError:
+    print('step1, download\n step2. unzip and install the whl,')
+# In the install command, pass the TensorRT root directory via an environment variable, e.g. os.environ['tensorrt_root_dir']/bin/exetrace, then run the export command from Python via subprocess.
+# Later the run command will be written in run.sh: tensorrt_root_dir=xxxx python cosyvoice/bin/export_trt.py --model_dir xxx
```
xinference/thirdparty/cosyvoice/bin/inference.py

```diff
@@ -100,10 +100,13 @@ def main():
                        'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
                        'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
                        'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
-
+        tts_speeches = []
+        for model_output in model.inference(**model_input):
+            tts_speeches.append(model_output['tts_speech'])
+        tts_speeches = torch.concat(tts_speeches, dim=1)
         tts_key = '{}_{}'.format(utts[0], tts_index[0])
         tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
-        torchaudio.save(tts_fn,
+        torchaudio.save(tts_fn, tts_speeches, sample_rate=22050)
         f.write('{} {}\n'.format(tts_key, tts_fn))
         f.flush()
     f.close()
```
xinference/thirdparty/cosyvoice/cli/cosyvoice.py

```diff
@@ -12,15 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
-import
+import time
 from hyperpyyaml import load_hyperpyyaml
 from modelscope import snapshot_download
 from cosyvoice.cli.frontend import CosyVoiceFrontEnd
 from cosyvoice.cli.model import CosyVoiceModel
+from cosyvoice.utils.file_utils import logging
 
 class CosyVoice:
 
-    def __init__(self, model_dir):
+    def __init__(self, model_dir, load_jit=True):
         instruct = True if '-Instruct' in model_dir else False
         self.model_dir = model_dir
         if not os.path.exists(model_dir):
@@ -38,46 +39,61 @@ class CosyVoice:
         self.model.load('{}/llm.pt'.format(model_dir),
                         '{}/flow.pt'.format(model_dir),
                         '{}/hift.pt'.format(model_dir))
+        if load_jit:
+            self.model.load_jit('{}/llm.text_encoder.fp16.zip'.format(model_dir),
+                                '{}/llm.llm.fp16.zip'.format(model_dir))
         del configs
 
     def list_avaliable_spks(self):
         spks = list(self.frontend.spk2info.keys())
         return spks
 
-    def inference_sft(self, tts_text, spk_id):
-        tts_speeches = []
+    def inference_sft(self, tts_text, spk_id, stream=False):
         for i in self.frontend.text_normalize(tts_text, split=True):
             model_input = self.frontend.frontend_sft(i, spk_id)
-
-
-
+            start_time = time.time()
+            logging.info('synthesis text {}'.format(i))
+            for model_output in self.model.inference(**model_input, stream=stream):
+                speech_len = model_output['tts_speech'].shape[1] / 22050
+                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+                yield model_output
+                start_time = time.time()
 
-    def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k):
+    def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False):
         prompt_text = self.frontend.text_normalize(prompt_text, split=False)
-        tts_speeches = []
         for i in self.frontend.text_normalize(tts_text, split=True):
             model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k)
-
-
-
+            start_time = time.time()
+            logging.info('synthesis text {}'.format(i))
+            for model_output in self.model.inference(**model_input, stream=stream):
+                speech_len = model_output['tts_speech'].shape[1] / 22050
+                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+                yield model_output
+                start_time = time.time()
 
-    def inference_cross_lingual(self, tts_text, prompt_speech_16k):
+    def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False):
         if self.frontend.instruct is True:
             raise ValueError('{} do not support cross_lingual inference'.format(self.model_dir))
-        tts_speeches = []
         for i in self.frontend.text_normalize(tts_text, split=True):
             model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k)
-
-
-
+            start_time = time.time()
+            logging.info('synthesis text {}'.format(i))
+            for model_output in self.model.inference(**model_input, stream=stream):
+                speech_len = model_output['tts_speech'].shape[1] / 22050
+                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+                yield model_output
+                start_time = time.time()
 
-    def inference_instruct(self, tts_text, spk_id, instruct_text):
+    def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False):
         if self.frontend.instruct is False:
             raise ValueError('{} do not support instruct inference'.format(self.model_dir))
         instruct_text = self.frontend.text_normalize(instruct_text, split=False)
-        tts_speeches = []
         for i in self.frontend.text_normalize(tts_text, split=True):
             model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
-
-
-
+            start_time = time.time()
+            logging.info('synthesis text {}'.format(i))
+            for model_output in self.model.inference(**model_input, stream=stream):
+                speech_len = model_output['tts_speech'].shape[1] / 22050
+                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+                yield model_output
+                start_time = time.time()
```
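All four `inference_*` methods are now generators that take a `stream` flag and yield audio chunks instead of returning one waveform. A hedged usage sketch of consuming the new API follows; the model directory, speaker id, and input text are illustrative assumptions, and the 22050 Hz rate and `torch.concat` stitching mirror the bin/inference.py hunk above:

```python
# Hedged usage sketch, not from the package: consume the generator-style API
# added above. Model directory, speaker id, and text are illustrative assumptions.
import torch
import torchaudio
from cosyvoice.cli.cosyvoice import CosyVoice

cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-SFT', load_jit=False)
chunks = []
for out in cosyvoice.inference_sft('Hello, streaming world.', spk_id='中文女', stream=True):
    chunks.append(out['tts_speech'])   # each chunk is (1, n_samples) at 22050 Hz
full = torch.concat(chunks, dim=1)     # same stitching as bin/inference.py above
torchaudio.save('sft_stream.wav', full, sample_rate=22050)
```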
xinference/thirdparty/cosyvoice/cli/model.py

```diff
@@ -12,6 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import torch
+import numpy as np
+import threading
+import time
+from contextlib import nullcontext
+import uuid
+from cosyvoice.utils.common import fade_in_out
+
 
 class CosyVoiceModel:
 
@@ -23,38 +30,144 @@ class CosyVoiceModel:
         self.llm = llm
         self.flow = flow
         self.hift = hift
+        self.token_min_hop_len = 100
+        self.token_max_hop_len = 200
+        self.token_overlap_len = 20
+        # mel fade in out
+        self.mel_overlap_len = 34
+        self.mel_window = np.hamming(2 * self.mel_overlap_len)
+        # hift cache
+        self.mel_cache_len = 20
+        self.source_cache_len = int(self.mel_cache_len * 256)
+        # rtf and decoding related
+        self.stream_scale_factor = 1
+        assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
+        self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
+        self.flow_hift_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
+        self.lock = threading.Lock()
+        # dict used to store session related variable
+        self.tts_speech_token_dict = {}
+        self.llm_end_dict = {}
+        self.mel_overlap_dict = {}
+        self.hift_cache_dict = {}
 
     def load(self, llm_model, flow_model, hift_model):
         self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
         self.llm.to(self.device).eval()
+        self.llm.half()
         self.flow.load_state_dict(torch.load(flow_model, map_location=self.device))
         self.flow.to(self.device).eval()
         self.hift.load_state_dict(torch.load(hift_model, map_location=self.device))
         self.hift.to(self.device).eval()
 
-    def
- … (removed lines 36-60 not shown)
+    def load_jit(self, llm_text_encoder_model, llm_llm_model):
+        llm_text_encoder = torch.jit.load(llm_text_encoder_model)
+        self.llm.text_encoder = llm_text_encoder
+        llm_llm = torch.jit.load(llm_llm_model)
+        self.llm.llm = llm_llm
+
+    def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
+        with self.llm_context:
+            for i in self.llm.inference(text=text.to(self.device),
+                                        text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
+                                        prompt_text=prompt_text.to(self.device),
+                                        prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
+                                        prompt_speech_token=llm_prompt_speech_token.to(self.device),
+                                        prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
+                                        embedding=llm_embedding.to(self.device).half(),
+                                        sampling=25,
+                                        max_token_text_ratio=30,
+                                        min_token_text_ratio=3):
+                self.tts_speech_token_dict[uuid].append(i)
+        self.llm_end_dict[uuid] = True
+
+    def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False):
+        with self.flow_hift_context:
+            tts_mel = self.flow.inference(token=token.to(self.device),
+                                          token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
+                                          prompt_token=prompt_token.to(self.device),
+                                          prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
+                                          prompt_feat=prompt_feat.to(self.device),
+                                          prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
+                                          embedding=embedding.to(self.device))
+            # mel overlap fade in out
+            # if self.mel_overlap_dict[uuid] is not None:
+            #     tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
+            # append hift cache
+            if self.hift_cache_dict[uuid] is not None:
+                hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
+                tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
+            else:
+                hift_cache_source = torch.zeros(1, 1, 0)
+            # keep overlap mel and hift cache
+            if finalize is False:
+                self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
+                tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
+                tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
+                self.hift_cache_dict[uuid] = {'source': tts_source[:, :, -self.source_cache_len:], 'mel': tts_mel[:, :, -self.mel_cache_len:]}
+                tts_speech = tts_speech[:, :-self.source_cache_len]
+            else:
+                tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
+        return tts_speech
+
+    def inference(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
+                  prompt_text=torch.zeros(1, 0, dtype=torch.int32),
+                  llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
+                  flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
+                  prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, **kwargs):
+        # this_uuid is used to track variables related to this inference thread
+        this_uuid = str(uuid.uuid1())
+        with self.lock:
+            self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid], self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = [], False, None, None
+        p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
+        p.start()
+        if stream is True:
+            token_hop_len = self.token_min_hop_len
+            while True:
+                time.sleep(0.1)
+                if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
+                    this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len], dim=1)
+                    with self.flow_hift_context:
+                        this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                                         prompt_token=flow_prompt_speech_token,
+                                                         prompt_feat=prompt_speech_feat,
+                                                         embedding=flow_embedding,
+                                                         uuid=this_uuid,
+                                                         finalize=False)
+                    yield {'tts_speech': this_tts_speech.cpu()}
+                    with self.lock:
+                        self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
+                    # increase token_hop_len for better speech quality
+                    token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
+                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
+                    break
+            p.join()
+            # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
+            this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
+            with self.flow_hift_context:
+                this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                                 prompt_token=flow_prompt_speech_token,
+                                                 prompt_feat=prompt_speech_feat,
+                                                 embedding=flow_embedding,
+                                                 uuid=this_uuid,
+                                                 finalize=True)
+            yield {'tts_speech': this_tts_speech.cpu()}
+        else:
+            # deal with all tokens
+            p.join()
+            this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
+            with self.flow_hift_context:
+                this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                                 prompt_token=flow_prompt_speech_token,
+                                                 prompt_feat=prompt_speech_feat,
+                                                 embedding=flow_embedding,
+                                                 uuid=this_uuid,
+                                                 finalize=True)
+            yield {'tts_speech': this_tts_speech.cpu()}
+        with self.lock:
+            self.tts_speech_token_dict.pop(this_uuid)
+            self.llm_end_dict.pop(this_uuid)
+            self.mel_overlap_dict.pop(this_uuid)
+            self.hift_cache_dict.pop(this_uuid)
+        if torch.cuda.is_initialized():
+            torch.cuda.synchronize()
```
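`CosyVoiceModel.inference` now runs the token LLM in a background thread (`llm_job`) while the main generator drains the per-session token list in hops of `token_hop_len` plus a `token_overlap_len` overlap. The following is a minimal, generic sketch of that producer/consumer hop pattern, not the package code (the real implementation additionally guards the shared list with `self.lock`):

```python
# Minimal, self-contained sketch (not the package code) of the producer/consumer
# pattern used by CosyVoiceModel.inference: a background thread produces tokens,
# the generator drains them in hops with a small overlap kept between windows.
import threading
import time

tokens, done = [], False

def producer(n=500):
    global done
    for t in range(n):
        tokens.append(t)          # stands in for one speech token
        time.sleep(0.001)
    done = True

def consume(hop=100, overlap=20):
    threading.Thread(target=producer).start()
    while True:
        time.sleep(0.05)
        if len(tokens) >= hop + overlap:
            yield tokens[:hop + overlap]      # decode this window to audio
            del tokens[:hop]                  # keep `overlap` tokens for the next window
        if done and len(tokens) < hop + overlap:
            break
    yield tokens[:]                           # final partial chunk

for chunk in consume():
    print(len(chunk), 'tokens -> one audio chunk')
```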
xinference/thirdparty/cosyvoice/flow/flow.py

```diff
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import random
 from typing import Dict, Optional
 import torch
 import torch.nn as nn
@@ -77,6 +78,11 @@ class MaskedDiffWithXvec(torch.nn.Module):
 
         # get conditions
         conds = torch.zeros(feat.shape, device=token.device)
+        for i, j in enumerate(feat_len):
+            if random.random() < 0.5:
+                continue
+            index = random.randint(0, int(0.3 * j))
+            conds[i, :index] = feat[i, :index]
         conds = conds.transpose(1, 2)
 
         mask = (~make_pad_mask(feat_len)).to(h)
@@ -105,6 +111,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
         embedding = self.spk_embed_affine_layer(embedding)
 
         # concat text and prompt_text
+        token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
         token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
         mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(embedding)
         token = self.input_embedding(torch.clamp(token, min=0)) * mask
@@ -112,17 +119,16 @@ class MaskedDiffWithXvec(torch.nn.Module):
         # text encode
         h, h_lengths = self.encoder(token, token_len)
         h = self.encoder_proj(h)
-
-        h, h_lengths = self.length_regulator(h,
+        mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / 50 * 22050 / 256)
+        h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2)
 
         # get conditions
-        conds = torch.zeros([1,
-
-        for i, j in enumerate(prompt_feat_len):
-            conds[i, :j] = prompt_feat[i]
+        conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device)
+        conds[:, :mel_len1] = prompt_feat
         conds = conds.transpose(1, 2)
 
-        mask = (~make_pad_mask(feat_len)).to(h)
+        # mask = (~make_pad_mask(feat_len)).to(h)
+        mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
         feat = self.decoder(
             mu=h.transpose(1, 2).contiguous(),
             mask=mask.unsqueeze(1),
@@ -130,6 +136,6 @@ class MaskedDiffWithXvec(torch.nn.Module):
             cond=conds,
             n_timesteps=10
         )
-
-
+        feat = feat[:, :, mel_len1:]
+        assert feat.shape[2] == mel_len2
         return feat
```
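The new `mel_len2 = int(token_len2 / 50 * 22050 / 256)` line converts a speech-token count into a mel-frame count: the divisor 50 implies a rate of 50 tokens per second, and 22050 / 256 is the mel frame rate of the 22.05 kHz vocoder with a 256-sample hop. A small worked check:

```python
# Worked check of the mel-length formula from the hunk above:
# tokens at 50 per second are mapped to mel frames at 22050 / 256 per second.
token_len2 = 200                               # 200 tokens == 4.0 s of speech
seconds = token_len2 / 50                      # 4.0
mel_len2 = int(token_len2 / 50 * 22050 / 256)  # int(4.0 * 86.1328...) == 344
print(seconds, mel_len2)                       # 4.0 344
```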
xinference/thirdparty/cosyvoice/flow/length_regulator.py

```diff
@@ -13,6 +13,7 @@
 # limitations under the License.
 from typing import Tuple
 import torch.nn as nn
+import torch
 from torch.nn import functional as F
 from cosyvoice.utils.mask import make_pad_mask
 
@@ -43,7 +44,25 @@ class InterpolateRegulator(nn.Module):
     def forward(self, x, ylens=None):
         # x in (B, T, D)
         mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1)
-        x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='
+        x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='linear')
         out = self.model(x).transpose(1, 2).contiguous()
         olens = ylens
         return out * mask, olens
+
+    def inference(self, x1, x2, mel_len1, mel_len2):
+        # in inference mode, interploate prompt token and token(head/mid/tail) seprately, so we can get a clear separation point of mel
+        # x in (B, T, D)
+        if x2.shape[1] > 40:
+            x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=34, mode='linear')
+            x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - 34 * 2, mode='linear')
+            x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=34, mode='linear')
+            x2 = torch.concat([x2_head, x2_mid, x2_tail], dim=2)
+        else:
+            x2 = F.interpolate(x2.transpose(1, 2).contiguous(), size=mel_len2, mode='linear')
+        if x1.shape[1] != 0:
+            x1 = F.interpolate(x1.transpose(1, 2).contiguous(), size=mel_len1, mode='linear')
+            x = torch.concat([x1, x2], dim=2)
+        else:
+            x = x2
+        out = self.model(x).transpose(1, 2).contiguous()
+        return out, mel_len1 + mel_len2
```
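`InterpolateRegulator.inference` resizes the prompt features and the generated features to their target mel lengths separately, so the prompt/generation boundary lands on a known frame. A self-contained sketch of the underlying `F.interpolate` length change on a (B, T, D) sequence, with illustrative sizes (not the package code):

```python
# Self-contained sketch (illustrative sizes) of resizing a (B, T, D) feature
# sequence to a target number of mel frames with F.interpolate, as
# InterpolateRegulator.inference does above.
import torch
import torch.nn.functional as F

x = torch.randn(1, 120, 80)                        # (B, T=120 tokens, D=80 dims)
mel_len = 344                                      # target frame count
y = F.interpolate(x.transpose(1, 2).contiguous(),  # interpolate along the time axis
                  size=mel_len, mode='linear')
print(y.shape)                                     # torch.Size([1, 80, 344])
```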
xinference/thirdparty/cosyvoice/hifigan/generator.py

```diff
@@ -335,10 +335,14 @@ class HiFTGenerator(nn.Module):
         inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
         return inverse_transform
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
         f0 = self.f0_predictor(x)
         s = self._f02source(f0)
 
+        # use cache_source to avoid glitch
+        if cache_source.shape[2] == 0:
+            s[:, :, :cache_source.shape[2]] = cache_source
+
         s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
         s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
 
@@ -370,7 +374,7 @@ class HiFTGenerator(nn.Module):
 
         x = self._istft(magnitude, phase)
         x = torch.clamp(x, -self.audio_limit, self.audio_limit)
-        return x
+        return x, s
 
     def remove_weight_norm(self):
         print('Removing weight norm...')
@@ -387,5 +391,5 @@ class HiFTGenerator(nn.Module):
             l.remove_weight_norm()
 
     @torch.inference_mode()
-    def inference(self, mel: torch.Tensor) -> torch.Tensor:
-        return self.forward(x=mel)
+    def inference(self, mel: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
+        return self.forward(x=mel, cache_source=cache_source)
```
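`HiFTGenerator.forward` now accepts a `cache_source` excitation from the previous chunk and returns the source signal `s` alongside the waveform, so consecutive streaming chunks can be stitched without glitches; cli/model.py likewise keeps a Hamming `mel_window` for a (currently commented-out) overlap crossfade. A minimal sketch of that window-based crossfade idea, not the package code:

```python
# Minimal sketch (not the package code) of crossfading two overlapping chunks
# with a Hamming window, the idea behind mel_window / fade_in_out in cli/model.py
# and the cache handling added to HiFTGenerator above.
import numpy as np

overlap = 34
window = np.hamming(2 * overlap)          # rises over [:overlap], falls over [overlap:]
prev_tail = np.random.randn(overlap)      # end of the previous chunk
next_head = np.random.randn(overlap)      # start of the next chunk
blended = prev_tail * window[overlap:] + next_head * window[:overlap]
print(blended.shape)                      # (34,)
```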
xinference/thirdparty/cosyvoice/llm/llm.py

```diff
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Dict, Optional,
+from typing import Dict, Optional, Callable, List, Generator
 import torch
 from torch import nn
 import torch.nn.functional as F
@@ -31,6 +31,7 @@ class TransformerLM(torch.nn.Module):
             speech_token_size: int,
             text_encoder: torch.nn.Module,
             llm: torch.nn.Module,
+            sampling: Callable,
             length_normalized_loss: bool = True,
             lsm_weight: float = 0.0,
             spk_embed_dim: int = 192,
@@ -63,6 +64,9 @@ class TransformerLM(torch.nn.Module):
         self.speech_embedding = torch.nn.Embedding(speech_token_size, llm_input_size)
         self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, llm_input_size)
 
+        # 4. sampling method
+        self.sampling = sampling
+
     def encode(
             self,
             text: torch.Tensor,
@@ -132,14 +136,12 @@ class TransformerLM(torch.nn.Module):
     def sampling_ids(
             self,
             weighted_scores: torch.Tensor,
-
-
+            decoded_tokens: List,
+            sampling: int,
             ignore_eos: bool = True,
     ):
         while True:
-
-            top_ids = prob.multinomial(beam_size, replacement=True)
-            top_ids = indices[top_ids]
+            top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
             if (not ignore_eos) or (self.speech_token_size not in top_ids):
                 break
         return top_ids
@@ -154,11 +156,10 @@ class TransformerLM(torch.nn.Module):
             prompt_speech_token: torch.Tensor,
             prompt_speech_token_len: torch.Tensor,
             embedding: torch.Tensor,
-            beam_size: int = 1,
             sampling: int = 25,
             max_token_text_ratio: float = 20,
             min_token_text_ratio: float = 2,
-    ) -> torch.Tensor:
+    ) -> Generator[torch.Tensor, None, None]:
         device = text.device
         text = torch.concat([prompt_text, text], dim=1)
         text_len += prompt_text_len
@@ -173,7 +174,7 @@ class TransformerLM(torch.nn.Module):
             embedding = self.spk_embed_affine_layer(embedding)
             embedding = embedding.unsqueeze(dim=1)
         else:
-            embedding = torch.zeros(1, 0, self.llm_input_size).to(device)
+            embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
 
         # 3. concat llm_input
         sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
@@ -181,7 +182,7 @@ class TransformerLM(torch.nn.Module):
         if prompt_speech_token_len != 0:
             prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
         else:
-            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size).to(device)
+            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
         lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
 
         # 4. cal min/max_length
@@ -196,11 +197,11 @@ class TransformerLM(torch.nn.Module):
             y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=0, required_cache_size=-1, att_cache=att_cache, cnn_cache=cnn_cache,
                                                                   att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool))
             logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
-            top_ids = self.sampling_ids(logp.squeeze(dim=0),
+            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
             if top_ids == self.speech_token_size:
                 break
+            # in stream mode, yield token one by one
+            yield torch.tensor([[top_ids]], dtype=torch.int64, device=device)
             out_tokens.append(top_ids)
            offset += lm_input.size(1)
             lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
-
-        return torch.tensor([out_tokens], dtype=torch.int64, device=device)
```
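`TransformerLM` now receives a `sampling: Callable` at construction, and `sampling_ids` delegates to `self.sampling(weighted_scores, decoded_tokens, sampling)` instead of the old multinomial draw over a beam. A hedged sketch of a callable with that interface follows; this plain top-k draw is only an illustration, and the function actually wired in by the package may differ:

```python
# Hedged sketch of a sampling callable matching the interface used by
# sampling_ids above: (weighted_scores, decoded_tokens, sampling) -> token ids.
# A plain top-k draw for illustration; the package's real function may differ
# (e.g. nucleus sampling with repetition handling).
import torch

def top_k_sampling(weighted_scores: torch.Tensor, decoded_tokens: list, sampling: int) -> torch.Tensor:
    # weighted_scores: (vocab,) log-probabilities for the next speech token
    top_logp, top_indices = weighted_scores.topk(sampling)
    choice = top_logp.softmax(dim=-1).multinomial(1)
    return top_indices[choice]

logp = torch.randn(4096).log_softmax(dim=-1)       # fake next-token scores
print(top_k_sampling(logp, decoded_tokens=[], sampling=25))
```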
xinference/thirdparty/cosyvoice/transformer/attention.py

```diff
@@ -222,7 +222,7 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
         torch.nn.init.xavier_uniform_(self.pos_bias_u)
         torch.nn.init.xavier_uniform_(self.pos_bias_v)
 
-    def rel_shift(self, x):
+    def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
         """Compute relative positional encoding.
 
         Args:
@@ -233,10 +233,14 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
             torch.Tensor: Output tensor.
 
         """
-        zero_pad = torch.zeros((
+        zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
+                               device=x.device,
+                               dtype=x.dtype)
         x_padded = torch.cat([zero_pad, x], dim=-1)
 
-        x_padded = x_padded.view(
+        x_padded = x_padded.view(x.size()[0],
+                                 x.size()[1],
+                                 x.size(3) + 1, x.size(2))
         x = x_padded[:, :, 1:].view_as(x)[
             :, :, :, : x.size(-1) // 2 + 1
         ]  # only keep the positions from 0 to time2
```
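`rel_shift` above gains explicit shapes plus device/dtype handling but keeps the usual relative-shift trick: pad one zero column, reshape so the last two axes swap, then drop the first row. A small, self-contained demonstration with toy sizes (not the package code):

```python
# Self-contained demonstration (toy sizes, not the package code) of the
# relative-shift trick used in rel_shift above: pad one zero column, reshape,
# drop the first row, and keep only the non-negative relative positions.
import torch

batch, heads, time1, time2 = 1, 1, 3, 5   # time2 = 2 * time1 - 1 relative positions
x = torch.arange(batch * heads * time1 * time2, dtype=torch.float32)
x = x.view(batch, heads, time1, time2)

zero_pad = torch.zeros((batch, heads, time1, 1), dtype=x.dtype)
x_padded = torch.cat([zero_pad, x], dim=-1)               # (B, H, T1, T2 + 1)
x_padded = x_padded.view(batch, heads, time2 + 1, time1)  # swap the last two axes
shifted = x_padded[:, :, 1:].view_as(x)[:, :, :, : x.size(-1) // 2 + 1]
print(shifted.shape)  # torch.Size([1, 1, 3, 3])
```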