xinference 0.14.4.post1__py3-none-any.whl → 0.15.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.
Files changed (194)
  1. xinference/_compat.py +51 -0
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +209 -40
  4. xinference/client/restful/restful_client.py +7 -26
  5. xinference/conftest.py +1 -1
  6. xinference/constants.py +5 -0
  7. xinference/core/cache_tracker.py +1 -1
  8. xinference/core/chat_interface.py +8 -14
  9. xinference/core/event.py +1 -1
  10. xinference/core/image_interface.py +28 -0
  11. xinference/core/model.py +110 -31
  12. xinference/core/scheduler.py +37 -37
  13. xinference/core/status_guard.py +1 -1
  14. xinference/core/supervisor.py +17 -10
  15. xinference/core/utils.py +80 -22
  16. xinference/core/worker.py +17 -16
  17. xinference/deploy/cmdline.py +8 -16
  18. xinference/deploy/local.py +1 -1
  19. xinference/deploy/supervisor.py +1 -1
  20. xinference/deploy/utils.py +1 -1
  21. xinference/deploy/worker.py +1 -1
  22. xinference/model/audio/cosyvoice.py +86 -41
  23. xinference/model/audio/fish_speech.py +9 -9
  24. xinference/model/audio/model_spec.json +9 -9
  25. xinference/model/audio/whisper.py +4 -1
  26. xinference/model/embedding/core.py +52 -31
  27. xinference/model/image/core.py +2 -1
  28. xinference/model/image/model_spec.json +16 -4
  29. xinference/model/image/model_spec_modelscope.json +16 -4
  30. xinference/model/image/sdapi.py +136 -0
  31. xinference/model/image/stable_diffusion/core.py +164 -19
  32. xinference/model/llm/__init__.py +29 -11
  33. xinference/model/llm/llama_cpp/core.py +16 -33
  34. xinference/model/llm/llm_family.json +1011 -1296
  35. xinference/model/llm/llm_family.py +34 -53
  36. xinference/model/llm/llm_family_csghub.json +18 -35
  37. xinference/model/llm/llm_family_modelscope.json +981 -1122
  38. xinference/model/llm/lmdeploy/core.py +56 -88
  39. xinference/model/llm/mlx/core.py +46 -69
  40. xinference/model/llm/sglang/core.py +36 -18
  41. xinference/model/llm/transformers/chatglm.py +168 -306
  42. xinference/model/llm/transformers/cogvlm2.py +36 -63
  43. xinference/model/llm/transformers/cogvlm2_video.py +33 -223
  44. xinference/model/llm/transformers/core.py +55 -50
  45. xinference/model/llm/transformers/deepseek_v2.py +340 -0
  46. xinference/model/llm/transformers/deepseek_vl.py +53 -96
  47. xinference/model/llm/transformers/glm4v.py +55 -111
  48. xinference/model/llm/transformers/intern_vl.py +39 -70
  49. xinference/model/llm/transformers/internlm2.py +32 -54
  50. xinference/model/llm/transformers/minicpmv25.py +22 -55
  51. xinference/model/llm/transformers/minicpmv26.py +158 -68
  52. xinference/model/llm/transformers/omnilmm.py +5 -28
  53. xinference/model/llm/transformers/qwen2_audio.py +168 -0
  54. xinference/model/llm/transformers/qwen2_vl.py +234 -0
  55. xinference/model/llm/transformers/qwen_vl.py +34 -86
  56. xinference/model/llm/transformers/utils.py +32 -38
  57. xinference/model/llm/transformers/yi_vl.py +32 -72
  58. xinference/model/llm/utils.py +280 -554
  59. xinference/model/llm/vllm/core.py +161 -100
  60. xinference/model/rerank/core.py +41 -8
  61. xinference/model/rerank/model_spec.json +7 -0
  62. xinference/model/rerank/model_spec_modelscope.json +7 -1
  63. xinference/model/utils.py +1 -31
  64. xinference/thirdparty/cosyvoice/bin/export_jit.py +64 -0
  65. xinference/thirdparty/cosyvoice/bin/export_trt.py +8 -0
  66. xinference/thirdparty/cosyvoice/bin/inference.py +5 -2
  67. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +38 -22
  68. xinference/thirdparty/cosyvoice/cli/model.py +139 -26
  69. xinference/thirdparty/cosyvoice/flow/flow.py +15 -9
  70. xinference/thirdparty/cosyvoice/flow/length_regulator.py +20 -1
  71. xinference/thirdparty/cosyvoice/hifigan/generator.py +8 -4
  72. xinference/thirdparty/cosyvoice/llm/llm.py +14 -13
  73. xinference/thirdparty/cosyvoice/transformer/attention.py +7 -3
  74. xinference/thirdparty/cosyvoice/transformer/decoder.py +1 -1
  75. xinference/thirdparty/cosyvoice/transformer/embedding.py +4 -3
  76. xinference/thirdparty/cosyvoice/transformer/encoder.py +4 -2
  77. xinference/thirdparty/cosyvoice/utils/common.py +36 -0
  78. xinference/thirdparty/cosyvoice/utils/file_utils.py +16 -0
  79. xinference/thirdparty/deepseek_vl/serve/assets/Kelpy-Codos.js +100 -0
  80. xinference/thirdparty/deepseek_vl/serve/assets/avatar.png +0 -0
  81. xinference/thirdparty/deepseek_vl/serve/assets/custom.css +355 -0
  82. xinference/thirdparty/deepseek_vl/serve/assets/custom.js +22 -0
  83. xinference/thirdparty/deepseek_vl/serve/assets/favicon.ico +0 -0
  84. xinference/thirdparty/deepseek_vl/serve/examples/app.png +0 -0
  85. xinference/thirdparty/deepseek_vl/serve/examples/chart.png +0 -0
  86. xinference/thirdparty/deepseek_vl/serve/examples/mirror.png +0 -0
  87. xinference/thirdparty/deepseek_vl/serve/examples/pipeline.png +0 -0
  88. xinference/thirdparty/deepseek_vl/serve/examples/puzzle.png +0 -0
  89. xinference/thirdparty/deepseek_vl/serve/examples/rap.jpeg +0 -0
  90. xinference/thirdparty/fish_speech/fish_speech/configs/base.yaml +87 -0
  91. xinference/thirdparty/fish_speech/fish_speech/configs/firefly_gan_vq.yaml +33 -0
  92. xinference/thirdparty/fish_speech/fish_speech/configs/lora/r_8_alpha_16.yaml +4 -0
  93. xinference/thirdparty/fish_speech/fish_speech/configs/text2semantic_finetune.yaml +83 -0
  94. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text-data.proto +24 -0
  95. xinference/thirdparty/fish_speech/fish_speech/i18n/README.md +27 -0
  96. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +1 -1
  97. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +1 -1
  98. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +1 -1
  99. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +1 -1
  100. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +1 -1
  101. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +2 -2
  102. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +0 -3
  103. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +169 -198
  104. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +4 -27
  105. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/.gitignore +114 -0
  106. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/README.md +36 -0
  107. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +9 -47
  108. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +2 -2
  109. xinference/thirdparty/fish_speech/fish_speech/train.py +2 -0
  110. xinference/thirdparty/fish_speech/fish_speech/webui/css/style.css +161 -0
  111. xinference/thirdparty/fish_speech/fish_speech/webui/html/footer.html +11 -0
  112. xinference/thirdparty/fish_speech/fish_speech/webui/js/animate.js +69 -0
  113. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +12 -10
  114. xinference/thirdparty/fish_speech/tools/api.py +79 -134
  115. xinference/thirdparty/fish_speech/tools/commons.py +35 -0
  116. xinference/thirdparty/fish_speech/tools/download_models.py +3 -3
  117. xinference/thirdparty/fish_speech/tools/file.py +17 -0
  118. xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +1 -1
  119. xinference/thirdparty/fish_speech/tools/llama/generate.py +29 -24
  120. xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +1 -1
  121. xinference/thirdparty/fish_speech/tools/llama/quantize.py +2 -2
  122. xinference/thirdparty/fish_speech/tools/msgpack_api.py +34 -0
  123. xinference/thirdparty/fish_speech/tools/post_api.py +85 -44
  124. xinference/thirdparty/fish_speech/tools/sensevoice/README.md +59 -0
  125. xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +1 -1
  126. xinference/thirdparty/fish_speech/tools/smart_pad.py +16 -3
  127. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +2 -2
  128. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +4 -2
  129. xinference/thirdparty/fish_speech/tools/webui.py +12 -146
  130. xinference/thirdparty/matcha/VERSION +1 -0
  131. xinference/thirdparty/matcha/hifigan/LICENSE +21 -0
  132. xinference/thirdparty/matcha/hifigan/README.md +101 -0
  133. xinference/thirdparty/omnilmm/LICENSE +201 -0
  134. xinference/thirdparty/whisper/__init__.py +156 -0
  135. xinference/thirdparty/whisper/__main__.py +3 -0
  136. xinference/thirdparty/whisper/assets/gpt2.tiktoken +50256 -0
  137. xinference/thirdparty/whisper/assets/mel_filters.npz +0 -0
  138. xinference/thirdparty/whisper/assets/multilingual.tiktoken +50257 -0
  139. xinference/thirdparty/whisper/audio.py +157 -0
  140. xinference/thirdparty/whisper/decoding.py +826 -0
  141. xinference/thirdparty/whisper/model.py +314 -0
  142. xinference/thirdparty/whisper/normalizers/__init__.py +2 -0
  143. xinference/thirdparty/whisper/normalizers/basic.py +76 -0
  144. xinference/thirdparty/whisper/normalizers/english.json +1741 -0
  145. xinference/thirdparty/whisper/normalizers/english.py +550 -0
  146. xinference/thirdparty/whisper/timing.py +386 -0
  147. xinference/thirdparty/whisper/tokenizer.py +395 -0
  148. xinference/thirdparty/whisper/transcribe.py +605 -0
  149. xinference/thirdparty/whisper/triton_ops.py +109 -0
  150. xinference/thirdparty/whisper/utils.py +316 -0
  151. xinference/thirdparty/whisper/version.py +1 -0
  152. xinference/types.py +14 -53
  153. xinference/web/ui/build/asset-manifest.json +6 -6
  154. xinference/web/ui/build/index.html +1 -1
  155. xinference/web/ui/build/static/css/{main.4bafd904.css → main.5061c4c3.css} +2 -2
  156. xinference/web/ui/build/static/css/main.5061c4c3.css.map +1 -0
  157. xinference/web/ui/build/static/js/main.754740c0.js +3 -0
  158. xinference/web/ui/build/static/js/{main.eb13fe95.js.LICENSE.txt → main.754740c0.js.LICENSE.txt} +2 -0
  159. xinference/web/ui/build/static/js/main.754740c0.js.map +1 -0
  160. xinference/web/ui/node_modules/.cache/babel-loader/10c69dc7a296779fcffedeff9393d832dfcb0013c36824adf623d3c518b801ff.json +1 -0
  161. xinference/web/ui/node_modules/.cache/babel-loader/68bede6d95bb5ef0b35bbb3ec5b8c937eaf6862c6cdbddb5ef222a7776aaf336.json +1 -0
  162. xinference/web/ui/node_modules/.cache/babel-loader/77d50223f3e734d4485cca538cb098a8c3a7a0a1a9f01f58cdda3af42fe1adf5.json +1 -0
  163. xinference/web/ui/node_modules/.cache/babel-loader/a56d5a642409a84988891089c98ca28ad0546432dfbae8aaa51bc5a280e1cdd2.json +1 -0
  164. xinference/web/ui/node_modules/.cache/babel-loader/cd90b08d177025dfe84209596fc51878f8a86bcaa6a240848a3d2e5fd4c7ff24.json +1 -0
  165. xinference/web/ui/node_modules/.cache/babel-loader/d9ff696a3e3471f01b46c63d18af32e491eb5dc0e43cb30202c96871466df57f.json +1 -0
  166. xinference/web/ui/node_modules/.cache/babel-loader/e42b72d4cc1ea412ebecbb8d040dc6c6bfee462c33903c2f1f3facb602ad742e.json +1 -0
  167. xinference/web/ui/node_modules/.cache/babel-loader/f5039ddbeb815c51491a1989532006b96fc3ae49c6c60e3c097f875b4ae915ae.json +1 -0
  168. xinference/web/ui/node_modules/.package-lock.json +37 -0
  169. xinference/web/ui/node_modules/a-sync-waterfall/package.json +21 -0
  170. xinference/web/ui/node_modules/nunjucks/node_modules/commander/package.json +48 -0
  171. xinference/web/ui/node_modules/nunjucks/package.json +112 -0
  172. xinference/web/ui/package-lock.json +38 -0
  173. xinference/web/ui/package.json +1 -0
  174. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/METADATA +16 -10
  175. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/RECORD +179 -127
  176. xinference/model/llm/transformers/llama_2.py +0 -108
  177. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +0 -442
  178. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +0 -44
  179. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +0 -115
  180. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +0 -225
  181. xinference/thirdparty/fish_speech/tools/auto_rerank.py +0 -159
  182. xinference/thirdparty/fish_speech/tools/gen_ref.py +0 -36
  183. xinference/thirdparty/fish_speech/tools/merge_asr_files.py +0 -55
  184. xinference/web/ui/build/static/css/main.4bafd904.css.map +0 -1
  185. xinference/web/ui/build/static/js/main.eb13fe95.js +0 -3
  186. xinference/web/ui/build/static/js/main.eb13fe95.js.map +0 -1
  187. xinference/web/ui/node_modules/.cache/babel-loader/0b11a5339468c13b2d31ac085e7effe4303259b2071abd46a0a8eb8529233a5e.json +0 -1
  188. xinference/web/ui/node_modules/.cache/babel-loader/213b5913e164773c2b0567455377765715f5f07225fbac77ad8e1e9dc9648a47.json +0 -1
  189. xinference/web/ui/node_modules/.cache/babel-loader/5c26a23b5eacf5b752a08531577ae3840bb247745ef9a39583dc2d05ba93a82a.json +0 -1
  190. xinference/web/ui/node_modules/.cache/babel-loader/978b57d1a04a701bc3fcfebc511f5f274eed6ed7eade67f6fb76c27d5fd9ecc8.json +0 -1
  191. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/LICENSE +0 -0
  192. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/WHEEL +0 -0
  193. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/entry_points.txt +0 -0
  194. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/top_level.txt +0 -0

xinference/thirdparty/cosyvoice/bin/export_jit.py

@@ -0,0 +1,64 @@
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from __future__ import print_function
+
+ import argparse
+ import logging
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
+ import os
+ import sys
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+ sys.path.append('{}/../..'.format(ROOT_DIR))
+ sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
+ import torch
+ from cosyvoice.cli.cosyvoice import CosyVoice
+
+ def get_args():
+     parser = argparse.ArgumentParser(description='export your model for deployment')
+     parser.add_argument('--model_dir',
+                         type=str,
+                         default='pretrained_models/CosyVoice-300M',
+                         help='local path')
+     args = parser.parse_args()
+     print(args)
+     return args
+
+ def main():
+     args = get_args()
+     logging.basicConfig(level=logging.DEBUG,
+                         format='%(asctime)s %(levelname)s %(message)s')
+
+     torch._C._jit_set_fusion_strategy([('STATIC', 1)])
+     torch._C._jit_set_profiling_mode(False)
+     torch._C._jit_set_profiling_executor(False)
+
+     cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_trt=False)
+
+     # 1. export llm text_encoder
+     llm_text_encoder = cosyvoice.model.llm.text_encoder.half()
+     script = torch.jit.script(llm_text_encoder)
+     script = torch.jit.freeze(script)
+     script = torch.jit.optimize_for_inference(script)
+     script.save('{}/llm.text_encoder.fp16.zip'.format(args.model_dir))
+
+     # 2. export llm llm
+     llm_llm = cosyvoice.model.llm.llm.half()
+     script = torch.jit.script(llm_llm)
+     script = torch.jit.freeze(script, preserved_attrs=['forward_chunk'])
+     script = torch.jit.optimize_for_inference(script)
+     script.save('{}/llm.llm.fp16.zip'.format(args.model_dir))
+
+ if __name__ == '__main__':
+     main()
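
For context, the export above follows the standard TorchScript pipeline: script the module, freeze it, optimize it for inference, then save. A minimal sketch of the same pipeline on a stand-in module (the TinyEncoder class and the output filename are illustrative, not part of the package):

import torch

class TinyEncoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(16, 16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.proj(x))

model = TinyEncoder().half().eval()                  # freeze requires eval mode
script = torch.jit.script(model)                     # compile to TorchScript
script = torch.jit.freeze(script)                    # inline weights and attributes
script = torch.jit.optimize_for_inference(script)    # apply inference-only graph optimizations
script.save('tiny_encoder.fp16.zip')                 # later restored with torch.jit.load(...)

The exported archives are what CosyVoiceModel.load_jit (further down in this diff) reads back with torch.jit.load.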

xinference/thirdparty/cosyvoice/bin/export_trt.py

@@ -0,0 +1,8 @@
+ # TODO: same logic as export_jit, complete the ONNX export of the estimator in the flow part.
+ # Document the TensorRT installation steps here as a hint; if it is not installed, do not run this script but tell the user to install it first, with no alternative offered.
+ try:
+     import tensorrt
+ except ImportError:
+     print('step 1: download\n step 2: unpack and install the whl')
+ # In the install/export command, pass the TensorRT root directory via an environment variable, e.g. os.environ['tensorrt_root_dir']/bin/exetrace, then run the export command from Python via subprocess.
+ # The run command will later be provided in run.sh: tensorrt_root_dir=xxxx python cosyvoice/bin/export_trt.py --model_dir xxx

xinference/thirdparty/cosyvoice/bin/inference.py

@@ -100,10 +100,13 @@ def main():
                         'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
                         'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
                         'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
-         model_output = model.inference(**model_input)
+         tts_speeches = []
+         for model_output in model.inference(**model_input):
+             tts_speeches.append(model_output['tts_speech'])
+         tts_speeches = torch.concat(tts_speeches, dim=1)
          tts_key = '{}_{}'.format(utts[0], tts_index[0])
          tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
-         torchaudio.save(tts_fn, model_output['tts_speech'], sample_rate=22050)
+         torchaudio.save(tts_fn, tts_speeches, sample_rate=22050)
          f.write('{} {}\n'.format(tts_key, tts_fn))
          f.flush()
      f.close()

xinference/thirdparty/cosyvoice/cli/cosyvoice.py

@@ -12,15 +12,16 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  import os
- import torch
+ import time
  from hyperpyyaml import load_hyperpyyaml
  from modelscope import snapshot_download
  from cosyvoice.cli.frontend import CosyVoiceFrontEnd
  from cosyvoice.cli.model import CosyVoiceModel
+ from cosyvoice.utils.file_utils import logging

  class CosyVoice:

-     def __init__(self, model_dir):
+     def __init__(self, model_dir, load_jit=True):
          instruct = True if '-Instruct' in model_dir else False
          self.model_dir = model_dir
          if not os.path.exists(model_dir):
@@ -38,46 +39,61 @@ class CosyVoice:
          self.model.load('{}/llm.pt'.format(model_dir),
                          '{}/flow.pt'.format(model_dir),
                          '{}/hift.pt'.format(model_dir))
+         if load_jit:
+             self.model.load_jit('{}/llm.text_encoder.fp16.zip'.format(model_dir),
+                                 '{}/llm.llm.fp16.zip'.format(model_dir))
          del configs

      def list_avaliable_spks(self):
          spks = list(self.frontend.spk2info.keys())
          return spks

-     def inference_sft(self, tts_text, spk_id):
-         tts_speeches = []
+     def inference_sft(self, tts_text, spk_id, stream=False):
          for i in self.frontend.text_normalize(tts_text, split=True):
              model_input = self.frontend.frontend_sft(i, spk_id)
-             model_output = self.model.inference(**model_input)
-             tts_speeches.append(model_output['tts_speech'])
-         return {'tts_speech': torch.concat(tts_speeches, dim=1)}
+             start_time = time.time()
+             logging.info('synthesis text {}'.format(i))
+             for model_output in self.model.inference(**model_input, stream=stream):
+                 speech_len = model_output['tts_speech'].shape[1] / 22050
+                 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+                 yield model_output
+                 start_time = time.time()

-     def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k):
+     def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False):
          prompt_text = self.frontend.text_normalize(prompt_text, split=False)
-         tts_speeches = []
          for i in self.frontend.text_normalize(tts_text, split=True):
              model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k)
-             model_output = self.model.inference(**model_input)
-             tts_speeches.append(model_output['tts_speech'])
-         return {'tts_speech': torch.concat(tts_speeches, dim=1)}
+             start_time = time.time()
+             logging.info('synthesis text {}'.format(i))
+             for model_output in self.model.inference(**model_input, stream=stream):
+                 speech_len = model_output['tts_speech'].shape[1] / 22050
+                 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+                 yield model_output
+                 start_time = time.time()

-     def inference_cross_lingual(self, tts_text, prompt_speech_16k):
+     def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False):
          if self.frontend.instruct is True:
              raise ValueError('{} do not support cross_lingual inference'.format(self.model_dir))
-         tts_speeches = []
          for i in self.frontend.text_normalize(tts_text, split=True):
              model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k)
-             model_output = self.model.inference(**model_input)
-             tts_speeches.append(model_output['tts_speech'])
-         return {'tts_speech': torch.concat(tts_speeches, dim=1)}
+             start_time = time.time()
+             logging.info('synthesis text {}'.format(i))
+             for model_output in self.model.inference(**model_input, stream=stream):
+                 speech_len = model_output['tts_speech'].shape[1] / 22050
+                 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+                 yield model_output
+                 start_time = time.time()

-     def inference_instruct(self, tts_text, spk_id, instruct_text):
+     def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False):
          if self.frontend.instruct is False:
              raise ValueError('{} do not support instruct inference'.format(self.model_dir))
          instruct_text = self.frontend.text_normalize(instruct_text, split=False)
-         tts_speeches = []
          for i in self.frontend.text_normalize(tts_text, split=True):
              model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
-             model_output = self.model.inference(**model_input)
-             tts_speeches.append(model_output['tts_speech'])
-         return {'tts_speech': torch.concat(tts_speeches, dim=1)}
+             start_time = time.time()
+             logging.info('synthesis text {}'.format(i))
+             for model_output in self.model.inference(**model_input, stream=stream):
+                 speech_len = model_output['tts_speech'].shape[1] / 22050
+                 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+                 yield model_output
+                 start_time = time.time()
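
These entry points are now generators, so callers collect or stream the yielded chunks instead of receiving one concatenated tensor. A minimal consumption sketch, assuming a locally downloaded CosyVoice-300M-SFT checkpoint and a speaker id that exists in that checkpoint (both are assumptions, adjust to your setup):

import torch
import torchaudio
from cosyvoice.cli.cosyvoice import CosyVoice

# model directory and speaker id are assumptions; load_jit=False avoids needing the exported zips
cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-SFT', load_jit=False)
chunks = []
# inference_sft now yields dicts; each carries a 'tts_speech' tensor at 22050 Hz
for chunk in cosyvoice.inference_sft('Hello, this is streamed synthesis.', '中文女', stream=True):
    chunks.append(chunk['tts_speech'])
torchaudio.save('streamed.wav', torch.concat(chunks, dim=1), sample_rate=22050)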

xinference/thirdparty/cosyvoice/cli/model.py

@@ -12,6 +12,13 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  import torch
+ import numpy as np
+ import threading
+ import time
+ from contextlib import nullcontext
+ import uuid
+ from cosyvoice.utils.common import fade_in_out
+

  class CosyVoiceModel:

@@ -23,38 +30,144 @@ class CosyVoiceModel:
          self.llm = llm
          self.flow = flow
          self.hift = hift
+         self.token_min_hop_len = 100
+         self.token_max_hop_len = 200
+         self.token_overlap_len = 20
+         # mel fade in out
+         self.mel_overlap_len = 34
+         self.mel_window = np.hamming(2 * self.mel_overlap_len)
+         # hift cache
+         self.mel_cache_len = 20
+         self.source_cache_len = int(self.mel_cache_len * 256)
+         # rtf and decoding related
+         self.stream_scale_factor = 1
+         assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
+         self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
+         self.flow_hift_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
+         self.lock = threading.Lock()
+         # dict used to store session related variable
+         self.tts_speech_token_dict = {}
+         self.llm_end_dict = {}
+         self.mel_overlap_dict = {}
+         self.hift_cache_dict = {}

      def load(self, llm_model, flow_model, hift_model):
          self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
          self.llm.to(self.device).eval()
+         self.llm.half()
          self.flow.load_state_dict(torch.load(flow_model, map_location=self.device))
          self.flow.to(self.device).eval()
          self.hift.load_state_dict(torch.load(hift_model, map_location=self.device))
          self.hift.to(self.device).eval()

-     def inference(self, text, text_len, flow_embedding, llm_embedding=torch.zeros(0, 192),
-                   prompt_text=torch.zeros(1, 0, dtype=torch.int32), prompt_text_len=torch.zeros(1, dtype=torch.int32),
-                   llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), llm_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
-                   flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), flow_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
-                   prompt_speech_feat=torch.zeros(1, 0, 80), prompt_speech_feat_len=torch.zeros(1, dtype=torch.int32)):
-         tts_speech_token = self.llm.inference(text=text.to(self.device),
-                                               text_len=text_len.to(self.device),
-                                               prompt_text=prompt_text.to(self.device),
-                                               prompt_text_len=prompt_text_len.to(self.device),
-                                               prompt_speech_token=llm_prompt_speech_token.to(self.device),
-                                               prompt_speech_token_len=llm_prompt_speech_token_len.to(self.device),
-                                               embedding=llm_embedding.to(self.device),
-                                               beam_size=1,
-                                               sampling=25,
-                                               max_token_text_ratio=30,
-                                               min_token_text_ratio=3)
-         tts_mel = self.flow.inference(token=tts_speech_token,
-                                       token_len=torch.tensor([tts_speech_token.size(1)], dtype=torch.int32).to(self.device),
-                                       prompt_token=flow_prompt_speech_token.to(self.device),
-                                       prompt_token_len=flow_prompt_speech_token_len.to(self.device),
-                                       prompt_feat=prompt_speech_feat.to(self.device),
-                                       prompt_feat_len=prompt_speech_feat_len.to(self.device),
-                                       embedding=flow_embedding.to(self.device))
-         tts_speech = self.hift.inference(mel=tts_mel).cpu()
-         torch.cuda.empty_cache()
-         return {'tts_speech': tts_speech}
+     def load_jit(self, llm_text_encoder_model, llm_llm_model):
+         llm_text_encoder = torch.jit.load(llm_text_encoder_model)
+         self.llm.text_encoder = llm_text_encoder
+         llm_llm = torch.jit.load(llm_llm_model)
+         self.llm.llm = llm_llm
+
+     def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
+         with self.llm_context:
+             for i in self.llm.inference(text=text.to(self.device),
+                                         text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
+                                         prompt_text=prompt_text.to(self.device),
+                                         prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
+                                         prompt_speech_token=llm_prompt_speech_token.to(self.device),
+                                         prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
+                                         embedding=llm_embedding.to(self.device).half(),
+                                         sampling=25,
+                                         max_token_text_ratio=30,
+                                         min_token_text_ratio=3):
+                 self.tts_speech_token_dict[uuid].append(i)
+         self.llm_end_dict[uuid] = True
+
+     def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False):
+         with self.flow_hift_context:
+             tts_mel = self.flow.inference(token=token.to(self.device),
+                                           token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
+                                           prompt_token=prompt_token.to(self.device),
+                                           prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
+                                           prompt_feat=prompt_feat.to(self.device),
+                                           prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
+                                           embedding=embedding.to(self.device))
+             # mel overlap fade in out
+             # if self.mel_overlap_dict[uuid] is not None:
+             #     tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
+             # append hift cache
+             if self.hift_cache_dict[uuid] is not None:
+                 hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
+                 tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
+             else:
+                 hift_cache_source = torch.zeros(1, 1, 0)
+             # keep overlap mel and hift cache
+             if finalize is False:
+                 self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
+                 tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
+                 tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
+                 self.hift_cache_dict[uuid] = {'source': tts_source[:, :, -self.source_cache_len:], 'mel': tts_mel[:, :, -self.mel_cache_len:]}
+                 tts_speech = tts_speech[:, :-self.source_cache_len]
+             else:
+                 tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
+         return tts_speech
+
+     def inference(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
+                   prompt_text=torch.zeros(1, 0, dtype=torch.int32),
+                   llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
+                   flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
+                   prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, **kwargs):
+         # this_uuid is used to track variables related to this inference thread
+         this_uuid = str(uuid.uuid1())
+         with self.lock:
+             self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid], self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = [], False, None, None
+         p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
+         p.start()
+         if stream is True:
+             token_hop_len = self.token_min_hop_len
+             while True:
+                 time.sleep(0.1)
+                 if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
+                     this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len], dim=1)
+                     with self.flow_hift_context:
+                         this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                                          prompt_token=flow_prompt_speech_token,
+                                                          prompt_feat=prompt_speech_feat,
+                                                          embedding=flow_embedding,
+                                                          uuid=this_uuid,
+                                                          finalize=False)
+                     yield {'tts_speech': this_tts_speech.cpu()}
+                     with self.lock:
+                         self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
+                     # increase token_hop_len for better speech quality
+                     token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
+                 if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
+                     break
+             p.join()
+             # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
+             this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
+             with self.flow_hift_context:
+                 this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                                  prompt_token=flow_prompt_speech_token,
+                                                  prompt_feat=prompt_speech_feat,
+                                                  embedding=flow_embedding,
+                                                  uuid=this_uuid,
+                                                  finalize=True)
+             yield {'tts_speech': this_tts_speech.cpu()}
+         else:
+             # deal with all tokens
+             p.join()
+             this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
+             with self.flow_hift_context:
+                 this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                                  prompt_token=flow_prompt_speech_token,
+                                                  prompt_feat=prompt_speech_feat,
+                                                  embedding=flow_embedding,
+                                                  uuid=this_uuid,
+                                                  finalize=True)
+             yield {'tts_speech': this_tts_speech.cpu()}
+         with self.lock:
+             self.tts_speech_token_dict.pop(this_uuid)
+             self.llm_end_dict.pop(this_uuid)
+             self.mel_overlap_dict.pop(this_uuid)
+             self.hift_cache_dict.pop(this_uuid)
+         if torch.cuda.is_initialized():
+             torch.cuda.synchronize()
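
CosyVoiceModel.inference now runs the token LLM in a background thread (llm_job) and synthesizes audio chunk by chunk once enough tokens are buffered, keeping a small token overlap between chunks. A stripped-down sketch of that producer/consumer pattern with stand-ins for token generation and vocoding (nothing here calls the real model, and unlike the real code the shared list is not guarded by a lock):

import threading
import time

def produce_tokens(buf, done, n=300):
    # stand-in for llm_job: a worker thread appends one "token" at a time
    for i in range(n):
        buf.append(i)
        time.sleep(0.001)
    done['finished'] = True

def stream_chunks(hop=100, overlap=20):
    buf, done = [], {'finished': False}
    worker = threading.Thread(target=produce_tokens, args=(buf, done))
    worker.start()
    while True:
        time.sleep(0.01)
        if len(buf) >= hop + overlap:
            yield buf[:hop + overlap]   # stand-in for token2wav on a buffered chunk
            del buf[:hop]               # keep `overlap` tokens so adjacent chunks overlap
        if done['finished'] and len(buf) < hop + overlap:
            break
    worker.join()
    if buf:
        yield buf[:]                    # final chunk (finalize=True in the real code)

for chunk in stream_chunks():
    print('chunk of', len(chunk), 'tokens')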

xinference/thirdparty/cosyvoice/flow/flow.py

@@ -12,6 +12,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  import logging
+ import random
  from typing import Dict, Optional
  import torch
  import torch.nn as nn
@@ -77,6 +78,11 @@ class MaskedDiffWithXvec(torch.nn.Module):

          # get conditions
          conds = torch.zeros(feat.shape, device=token.device)
+         for i, j in enumerate(feat_len):
+             if random.random() < 0.5:
+                 continue
+             index = random.randint(0, int(0.3 * j))
+             conds[i, :index] = feat[i, :index]
          conds = conds.transpose(1, 2)

          mask = (~make_pad_mask(feat_len)).to(h)
@@ -105,6 +111,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
          embedding = self.spk_embed_affine_layer(embedding)

          # concat text and prompt_text
+         token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
          token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
          mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(embedding)
          token = self.input_embedding(torch.clamp(token, min=0)) * mask
@@ -112,17 +119,16 @@ class MaskedDiffWithXvec(torch.nn.Module):
          # text encode
          h, h_lengths = self.encoder(token, token_len)
          h = self.encoder_proj(h)
-         feat_len = (token_len / 50 * 22050 / 256).int()
-         h, h_lengths = self.length_regulator(h, feat_len)
+         mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / 50 * 22050 / 256)
+         h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2)

          # get conditions
-         conds = torch.zeros([1, feat_len.max().item(), self.output_size], device=token.device)
-         if prompt_feat.shape[1] != 0:
-             for i, j in enumerate(prompt_feat_len):
-                 conds[i, :j] = prompt_feat[i]
+         conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device)
+         conds[:, :mel_len1] = prompt_feat
          conds = conds.transpose(1, 2)

-         mask = (~make_pad_mask(feat_len)).to(h)
+         # mask = (~make_pad_mask(feat_len)).to(h)
+         mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
          feat = self.decoder(
              mu=h.transpose(1, 2).contiguous(),
              mask=mask.unsqueeze(1),
@@ -130,6 +136,6 @@ class MaskedDiffWithXvec(torch.nn.Module):
              cond=conds,
              n_timesteps=10
          )
-         if prompt_feat.shape[1] != 0:
-             feat = feat[:, :, prompt_feat.shape[1]:]
+         feat = feat[:, :, mel_len1:]
+         assert feat.shape[2] == mel_len2
          return feat

xinference/thirdparty/cosyvoice/flow/length_regulator.py

@@ -13,6 +13,7 @@
  # limitations under the License.
  from typing import Tuple
  import torch.nn as nn
+ import torch
  from torch.nn import functional as F
  from cosyvoice.utils.mask import make_pad_mask

@@ -43,7 +44,25 @@ class InterpolateRegulator(nn.Module):
      def forward(self, x, ylens=None):
          # x in (B, T, D)
          mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1)
-         x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
+         x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='linear')
          out = self.model(x).transpose(1, 2).contiguous()
          olens = ylens
          return out * mask, olens
+
+     def inference(self, x1, x2, mel_len1, mel_len2):
+         # in inference mode, interpolate the prompt token and the token (head/mid/tail) separately, so we can get a clear separation point in the mel
+         # x in (B, T, D)
+         if x2.shape[1] > 40:
+             x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=34, mode='linear')
+             x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - 34 * 2, mode='linear')
+             x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=34, mode='linear')
+             x2 = torch.concat([x2_head, x2_mid, x2_tail], dim=2)
+         else:
+             x2 = F.interpolate(x2.transpose(1, 2).contiguous(), size=mel_len2, mode='linear')
+         if x1.shape[1] != 0:
+             x1 = F.interpolate(x1.transpose(1, 2).contiguous(), size=mel_len1, mode='linear')
+             x = torch.concat([x1, x2], dim=2)
+         else:
+             x = x2
+         out = self.model(x).transpose(1, 2).contiguous()
+         return out, mel_len1 + mel_len2
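
The regulator resamples a (B, T, D) sequence along the time axis with F.interpolate in 'linear' mode, and the new inference path handles the prompt and the head/mid/tail of the target tokens separately so the prompt/target boundary in the mel stays clean. A minimal sketch of the core resampling step (shapes and the target length are made up for illustration):

import torch
import torch.nn.functional as F

x = torch.randn(1, 25, 80)          # B=1, T=25 encoder frames, D=80 channels
target_len = 120                    # e.g. int(token_len / 50 * 22050 / 256) in the code above
# F.interpolate works on (B, C, T), so transpose, resample, then transpose back
y = F.interpolate(x.transpose(1, 2).contiguous(), size=target_len, mode='linear')
print(y.transpose(1, 2).shape)      # torch.Size([1, 120, 80])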

xinference/thirdparty/cosyvoice/hifigan/generator.py

@@ -335,10 +335,14 @@ class HiFTGenerator(nn.Module):
          inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
          return inverse_transform

-     def forward(self, x: torch.Tensor) -> torch.Tensor:
+     def forward(self, x: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
          f0 = self.f0_predictor(x)
          s = self._f02source(f0)

+         # use cache_source to avoid glitch
+         if cache_source.shape[2] != 0:
+             s[:, :, :cache_source.shape[2]] = cache_source
+
          s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
          s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)

@@ -370,7 +374,7 @@ class HiFTGenerator(nn.Module):

          x = self._istft(magnitude, phase)
          x = torch.clamp(x, -self.audio_limit, self.audio_limit)
-         return x
+         return x, s

      def remove_weight_norm(self):
          print('Removing weight norm...')
@@ -387,5 +391,5 @@ class HiFTGenerator(nn.Module):
              l.remove_weight_norm()

      @torch.inference_mode()
-     def inference(self, mel: torch.Tensor) -> torch.Tensor:
-         return self.forward(x=mel)
+     def inference(self, mel: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
+         return self.forward(x=mel, cache_source=cache_source)

xinference/thirdparty/cosyvoice/llm/llm.py

@@ -11,7 +11,7 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
- from typing import Dict, Optional, Union
+ from typing import Dict, Optional, Callable, List, Generator
  import torch
  from torch import nn
  import torch.nn.functional as F
@@ -31,6 +31,7 @@ class TransformerLM(torch.nn.Module):
              speech_token_size: int,
              text_encoder: torch.nn.Module,
              llm: torch.nn.Module,
+             sampling: Callable,
              length_normalized_loss: bool = True,
              lsm_weight: float = 0.0,
              spk_embed_dim: int = 192,
@@ -63,6 +64,9 @@ class TransformerLM(torch.nn.Module):
          self.speech_embedding = torch.nn.Embedding(speech_token_size, llm_input_size)
          self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, llm_input_size)

+         # 4. sampling method
+         self.sampling = sampling
+
      def encode(
              self,
              text: torch.Tensor,
@@ -132,14 +136,12 @@ class TransformerLM(torch.nn.Module):
      def sampling_ids(
              self,
              weighted_scores: torch.Tensor,
-             sampling: Union[bool, int, float] = True,
-             beam_size: int = 1,
+             decoded_tokens: List,
+             sampling: int,
              ignore_eos: bool = True,
      ):
          while True:
-             prob, indices = weighted_scores.softmax(dim=-1).topk(sampling)
-             top_ids = prob.multinomial(beam_size, replacement=True)
-             top_ids = indices[top_ids]
+             top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
              if (not ignore_eos) or (self.speech_token_size not in top_ids):
                  break
          return top_ids
@@ -154,11 +156,10 @@ class TransformerLM(torch.nn.Module):
              prompt_speech_token: torch.Tensor,
              prompt_speech_token_len: torch.Tensor,
              embedding: torch.Tensor,
-             beam_size: int = 1,
              sampling: int = 25,
              max_token_text_ratio: float = 20,
              min_token_text_ratio: float = 2,
-     ) -> torch.Tensor:
+     ) -> Generator[torch.Tensor, None, None]:
          device = text.device
          text = torch.concat([prompt_text, text], dim=1)
          text_len += prompt_text_len
@@ -173,7 +174,7 @@ class TransformerLM(torch.nn.Module):
              embedding = self.spk_embed_affine_layer(embedding)
              embedding = embedding.unsqueeze(dim=1)
          else:
-             embedding = torch.zeros(1, 0, self.llm_input_size).to(device)
+             embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)

          # 3. concat llm_input
          sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
@@ -181,7 +182,7 @@ class TransformerLM(torch.nn.Module):
          if prompt_speech_token_len != 0:
              prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
          else:
-             prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size).to(device)
+             prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
          lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)

          # 4. cal min/max_length
@@ -196,11 +197,11 @@ class TransformerLM(torch.nn.Module):
              y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=0, required_cache_size=-1, att_cache=att_cache, cnn_cache=cnn_cache,
                                                                    att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool))
              logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
-             top_ids = self.sampling_ids(logp.squeeze(dim=0), sampling, beam_size, ignore_eos=True if i < min_len else False).item()
+             top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
              if top_ids == self.speech_token_size:
                  break
+             # in stream mode, yield token one by one
+             yield torch.tensor([[top_ids]], dtype=torch.int64, device=device)
              out_tokens.append(top_ids)
              offset += lm_input.size(1)
              lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
-
-         return torch.tensor([out_tokens], dtype=torch.int64, device=device)
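
TransformerLM now receives its sampling strategy as a callable and sampling_ids delegates to self.sampling(weighted_scores, decoded_tokens, sampling). The sketch below is a plain top-k sampler with that signature, mirroring the removed inline code; it is an illustrative stand-in, not the sampler CosyVoice actually wires in through its config (which may also use decoded_tokens, e.g. for repetition-aware sampling):

import torch

def topk_sampling(weighted_scores: torch.Tensor, decoded_tokens: list, sampling: int) -> torch.Tensor:
    # weighted_scores: 1-D log-probabilities over the speech-token vocabulary
    # decoded_tokens is unused here; a smarter sampler could use it to penalize repeats
    prob, indices = weighted_scores.softmax(dim=-1).topk(sampling)
    top_ids = indices[prob.multinomial(1, replacement=True)]
    return top_ids

print(topk_sampling(torch.randn(4096), decoded_tokens=[], sampling=25))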

xinference/thirdparty/cosyvoice/transformer/attention.py

@@ -222,7 +222,7 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
          torch.nn.init.xavier_uniform_(self.pos_bias_u)
          torch.nn.init.xavier_uniform_(self.pos_bias_v)

-     def rel_shift(self, x):
+     def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
          """Compute relative positional encoding.

          Args:
@@ -233,10 +233,14 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
              torch.Tensor: Output tensor.

          """
-         zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
+         zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
+                                device=x.device,
+                                dtype=x.dtype)
          x_padded = torch.cat([zero_pad, x], dim=-1)

-         x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
+         x_padded = x_padded.view(x.size()[0],
+                                  x.size()[1],
+                                  x.size(3) + 1, x.size(2))
          x = x_padded[:, :, 1:].view_as(x)[
              :, :, :, : x.size(-1) // 2 + 1
          ]  # only keep the positions from 0 to time2

xinference/thirdparty/cosyvoice/transformer/decoder.py

@@ -174,7 +174,7 @@ class TransformerDecoder(torch.nn.Module):
                                  memory_mask)
          return x

-     @torch.jit.ignore(drop=True)
+     @torch.jit.unused
      def forward_layers_checkpointed(self, x: torch.Tensor,
                                      tgt_mask: torch.Tensor,
                                      memory: torch.Tensor,