xinference 1.0.1__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (170)
  1. xinference/_compat.py +2 -0
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +28 -6
  4. xinference/core/utils.py +10 -6
  5. xinference/deploy/cmdline.py +3 -1
  6. xinference/deploy/test/test_cmdline.py +56 -0
  7. xinference/isolation.py +24 -0
  8. xinference/model/audio/core.py +10 -0
  9. xinference/model/audio/cosyvoice.py +25 -3
  10. xinference/model/audio/f5tts.py +200 -0
  11. xinference/model/audio/f5tts_mlx.py +260 -0
  12. xinference/model/audio/fish_speech.py +36 -111
  13. xinference/model/audio/model_spec.json +27 -3
  14. xinference/model/audio/model_spec_modelscope.json +18 -0
  15. xinference/model/audio/utils.py +32 -0
  16. xinference/model/embedding/core.py +203 -142
  17. xinference/model/embedding/model_spec.json +7 -0
  18. xinference/model/embedding/model_spec_modelscope.json +8 -0
  19. xinference/model/image/core.py +69 -1
  20. xinference/model/image/model_spec.json +127 -4
  21. xinference/model/image/model_spec_modelscope.json +130 -4
  22. xinference/model/image/stable_diffusion/core.py +45 -13
  23. xinference/model/llm/__init__.py +2 -2
  24. xinference/model/llm/llm_family.json +219 -53
  25. xinference/model/llm/llm_family.py +15 -36
  26. xinference/model/llm/llm_family_modelscope.json +167 -20
  27. xinference/model/llm/mlx/core.py +287 -51
  28. xinference/model/llm/sglang/core.py +1 -0
  29. xinference/model/llm/transformers/chatglm.py +9 -5
  30. xinference/model/llm/transformers/core.py +1 -0
  31. xinference/model/llm/transformers/qwen2_vl.py +2 -0
  32. xinference/model/llm/transformers/utils.py +16 -8
  33. xinference/model/llm/utils.py +5 -1
  34. xinference/model/llm/vllm/core.py +16 -2
  35. xinference/thirdparty/cosyvoice/bin/average_model.py +92 -0
  36. xinference/thirdparty/cosyvoice/bin/export_jit.py +12 -2
  37. xinference/thirdparty/cosyvoice/bin/export_onnx.py +112 -0
  38. xinference/thirdparty/cosyvoice/bin/export_trt.sh +9 -0
  39. xinference/thirdparty/cosyvoice/bin/inference.py +5 -7
  40. xinference/thirdparty/cosyvoice/bin/train.py +42 -8
  41. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +96 -25
  42. xinference/thirdparty/cosyvoice/cli/frontend.py +77 -30
  43. xinference/thirdparty/cosyvoice/cli/model.py +330 -80
  44. xinference/thirdparty/cosyvoice/dataset/dataset.py +6 -2
  45. xinference/thirdparty/cosyvoice/dataset/processor.py +76 -14
  46. xinference/thirdparty/cosyvoice/flow/decoder.py +92 -13
  47. xinference/thirdparty/cosyvoice/flow/flow.py +99 -9
  48. xinference/thirdparty/cosyvoice/flow/flow_matching.py +110 -13
  49. xinference/thirdparty/cosyvoice/flow/length_regulator.py +5 -4
  50. xinference/thirdparty/cosyvoice/hifigan/discriminator.py +140 -0
  51. xinference/thirdparty/cosyvoice/hifigan/generator.py +58 -42
  52. xinference/thirdparty/cosyvoice/hifigan/hifigan.py +67 -0
  53. xinference/thirdparty/cosyvoice/llm/llm.py +139 -6
  54. xinference/thirdparty/cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
  55. xinference/thirdparty/cosyvoice/tokenizer/tokenizer.py +279 -0
  56. xinference/thirdparty/cosyvoice/transformer/embedding.py +2 -2
  57. xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +7 -7
  58. xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +318 -0
  59. xinference/thirdparty/cosyvoice/utils/common.py +28 -1
  60. xinference/thirdparty/cosyvoice/utils/executor.py +69 -7
  61. xinference/thirdparty/cosyvoice/utils/file_utils.py +2 -12
  62. xinference/thirdparty/cosyvoice/utils/frontend_utils.py +9 -5
  63. xinference/thirdparty/cosyvoice/utils/losses.py +20 -0
  64. xinference/thirdparty/cosyvoice/utils/scheduler.py +1 -2
  65. xinference/thirdparty/cosyvoice/utils/train_utils.py +101 -45
  66. xinference/thirdparty/f5_tts/api.py +166 -0
  67. xinference/thirdparty/f5_tts/configs/E2TTS_Base_train.yaml +44 -0
  68. xinference/thirdparty/f5_tts/configs/E2TTS_Small_train.yaml +44 -0
  69. xinference/thirdparty/f5_tts/configs/F5TTS_Base_train.yaml +46 -0
  70. xinference/thirdparty/f5_tts/configs/F5TTS_Small_train.yaml +46 -0
  71. xinference/thirdparty/f5_tts/eval/README.md +49 -0
  72. xinference/thirdparty/f5_tts/eval/ecapa_tdnn.py +330 -0
  73. xinference/thirdparty/f5_tts/eval/eval_infer_batch.py +207 -0
  74. xinference/thirdparty/f5_tts/eval/eval_infer_batch.sh +13 -0
  75. xinference/thirdparty/f5_tts/eval/eval_librispeech_test_clean.py +84 -0
  76. xinference/thirdparty/f5_tts/eval/eval_seedtts_testset.py +84 -0
  77. xinference/thirdparty/f5_tts/eval/utils_eval.py +405 -0
  78. xinference/thirdparty/f5_tts/infer/README.md +191 -0
  79. xinference/thirdparty/f5_tts/infer/SHARED.md +74 -0
  80. xinference/thirdparty/f5_tts/infer/examples/basic/basic.toml +11 -0
  81. xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_en.wav +0 -0
  82. xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_zh.wav +0 -0
  83. xinference/thirdparty/f5_tts/infer/examples/multi/country.flac +0 -0
  84. xinference/thirdparty/f5_tts/infer/examples/multi/main.flac +0 -0
  85. xinference/thirdparty/f5_tts/infer/examples/multi/story.toml +19 -0
  86. xinference/thirdparty/f5_tts/infer/examples/multi/story.txt +1 -0
  87. xinference/thirdparty/f5_tts/infer/examples/multi/town.flac +0 -0
  88. xinference/thirdparty/f5_tts/infer/examples/vocab.txt +2545 -0
  89. xinference/thirdparty/f5_tts/infer/infer_cli.py +226 -0
  90. xinference/thirdparty/f5_tts/infer/infer_gradio.py +851 -0
  91. xinference/thirdparty/f5_tts/infer/speech_edit.py +193 -0
  92. xinference/thirdparty/f5_tts/infer/utils_infer.py +538 -0
  93. xinference/thirdparty/f5_tts/model/__init__.py +10 -0
  94. xinference/thirdparty/f5_tts/model/backbones/README.md +20 -0
  95. xinference/thirdparty/f5_tts/model/backbones/dit.py +163 -0
  96. xinference/thirdparty/f5_tts/model/backbones/mmdit.py +146 -0
  97. xinference/thirdparty/f5_tts/model/backbones/unett.py +219 -0
  98. xinference/thirdparty/f5_tts/model/cfm.py +285 -0
  99. xinference/thirdparty/f5_tts/model/dataset.py +319 -0
  100. xinference/thirdparty/f5_tts/model/modules.py +658 -0
  101. xinference/thirdparty/f5_tts/model/trainer.py +366 -0
  102. xinference/thirdparty/f5_tts/model/utils.py +185 -0
  103. xinference/thirdparty/f5_tts/scripts/count_max_epoch.py +33 -0
  104. xinference/thirdparty/f5_tts/scripts/count_params_gflops.py +39 -0
  105. xinference/thirdparty/f5_tts/socket_server.py +159 -0
  106. xinference/thirdparty/f5_tts/train/README.md +77 -0
  107. xinference/thirdparty/f5_tts/train/datasets/prepare_csv_wavs.py +139 -0
  108. xinference/thirdparty/f5_tts/train/datasets/prepare_emilia.py +230 -0
  109. xinference/thirdparty/f5_tts/train/datasets/prepare_libritts.py +92 -0
  110. xinference/thirdparty/f5_tts/train/datasets/prepare_ljspeech.py +65 -0
  111. xinference/thirdparty/f5_tts/train/datasets/prepare_wenetspeech4tts.py +125 -0
  112. xinference/thirdparty/f5_tts/train/finetune_cli.py +174 -0
  113. xinference/thirdparty/f5_tts/train/finetune_gradio.py +1846 -0
  114. xinference/thirdparty/f5_tts/train/train.py +75 -0
  115. xinference/thirdparty/fish_speech/fish_speech/conversation.py +94 -83
  116. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +63 -20
  117. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +1 -26
  118. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +1 -1
  119. xinference/thirdparty/fish_speech/fish_speech/tokenizer.py +152 -0
  120. xinference/thirdparty/fish_speech/fish_speech/train.py +2 -2
  121. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1 -1
  122. xinference/thirdparty/fish_speech/tools/{post_api.py → api_client.py} +7 -13
  123. xinference/thirdparty/fish_speech/tools/api_server.py +98 -0
  124. xinference/thirdparty/fish_speech/tools/download_models.py +5 -5
  125. xinference/thirdparty/fish_speech/tools/fish_e2e.py +2 -2
  126. xinference/thirdparty/fish_speech/tools/inference_engine/__init__.py +192 -0
  127. xinference/thirdparty/fish_speech/tools/inference_engine/reference_loader.py +125 -0
  128. xinference/thirdparty/fish_speech/tools/inference_engine/utils.py +39 -0
  129. xinference/thirdparty/fish_speech/tools/inference_engine/vq_manager.py +57 -0
  130. xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +2 -2
  131. xinference/thirdparty/fish_speech/tools/llama/generate.py +117 -89
  132. xinference/thirdparty/fish_speech/tools/run_webui.py +104 -0
  133. xinference/thirdparty/fish_speech/tools/schema.py +11 -28
  134. xinference/thirdparty/fish_speech/tools/server/agent/__init__.py +57 -0
  135. xinference/thirdparty/fish_speech/tools/server/agent/generate.py +119 -0
  136. xinference/thirdparty/fish_speech/tools/server/agent/generation_utils.py +122 -0
  137. xinference/thirdparty/fish_speech/tools/server/agent/pre_generation_utils.py +72 -0
  138. xinference/thirdparty/fish_speech/tools/server/api_utils.py +75 -0
  139. xinference/thirdparty/fish_speech/tools/server/exception_handler.py +27 -0
  140. xinference/thirdparty/fish_speech/tools/server/inference.py +45 -0
  141. xinference/thirdparty/fish_speech/tools/server/model_manager.py +122 -0
  142. xinference/thirdparty/fish_speech/tools/server/model_utils.py +129 -0
  143. xinference/thirdparty/fish_speech/tools/server/views.py +246 -0
  144. xinference/thirdparty/fish_speech/tools/webui/__init__.py +173 -0
  145. xinference/thirdparty/fish_speech/tools/webui/inference.py +91 -0
  146. xinference/thirdparty/fish_speech/tools/webui/variables.py +14 -0
  147. xinference/thirdparty/matcha/utils/utils.py +2 -2
  148. xinference/web/ui/build/asset-manifest.json +3 -3
  149. xinference/web/ui/build/index.html +1 -1
  150. xinference/web/ui/build/static/js/{main.2f269bb3.js → main.4eb4ee80.js} +3 -3
  151. xinference/web/ui/build/static/js/main.4eb4ee80.js.map +1 -0
  152. xinference/web/ui/node_modules/.cache/babel-loader/8c5eeb02f772d02cbe8b89c05428d0dd41a97866f75f7dc1c2164a67f5a1cf98.json +1 -0
  153. {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/METADATA +41 -17
  154. {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/RECORD +160 -88
  155. xinference/thirdparty/cosyvoice/bin/export_trt.py +0 -8
  156. xinference/thirdparty/cosyvoice/flow/__init__.py +0 -0
  157. xinference/thirdparty/cosyvoice/hifigan/__init__.py +0 -0
  158. xinference/thirdparty/cosyvoice/llm/__init__.py +0 -0
  159. xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
  160. xinference/thirdparty/fish_speech/tools/api.py +0 -943
  161. xinference/thirdparty/fish_speech/tools/msgpack_api.py +0 -95
  162. xinference/thirdparty/fish_speech/tools/webui.py +0 -548
  163. xinference/web/ui/build/static/js/main.2f269bb3.js.map +0 -1
  164. xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +0 -1
  165. /xinference/thirdparty/{cosyvoice/bin → f5_tts}/__init__.py +0 -0
  166. /xinference/web/ui/build/static/js/{main.2f269bb3.js.LICENSE.txt → main.4eb4ee80.js.LICENSE.txt} +0 -0
  167. {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/LICENSE +0 -0
  168. {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/WHEEL +0 -0
  169. {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/entry_points.txt +0 -0
  170. {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/top_level.txt +0 -0
xinference/thirdparty/cosyvoice/cli/model.py

@@ -15,6 +15,7 @@ import torch
 import numpy as np
 import threading
 import time
+from torch.nn import functional as F
 from contextlib import nullcontext
 import uuid
 from cosyvoice.utils.common import fade_in_out
@@ -25,100 +26,134 @@ class CosyVoiceModel:
     def __init__(self,
                  llm: torch.nn.Module,
                  flow: torch.nn.Module,
-                 hift: torch.nn.Module):
+                 hift: torch.nn.Module,
+                 fp16: bool):
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         self.llm = llm
         self.flow = flow
         self.hift = hift
-        self.token_min_hop_len = 100
-        self.token_max_hop_len = 200
+        self.fp16 = fp16
+        self.token_min_hop_len = 2 * self.flow.input_frame_rate
+        self.token_max_hop_len = 4 * self.flow.input_frame_rate
         self.token_overlap_len = 20
+        # here we fix set flow.decoder.estimator.static_chunk_size = 0 for compatibability
+        self.flow.decoder.estimator.static_chunk_size = 0
         # mel fade in out
-        self.mel_overlap_len = 34
+        self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256)
         self.mel_window = np.hamming(2 * self.mel_overlap_len)
         # hift cache
         self.mel_cache_len = 20
         self.source_cache_len = int(self.mel_cache_len * 256)
+        # speech fade in out
+        self.speech_window = np.hamming(2 * self.source_cache_len)
         # rtf and decoding related
         self.stream_scale_factor = 1
         assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
         self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
-        self.flow_hift_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
         self.lock = threading.Lock()
         # dict used to store session related variable
         self.tts_speech_token_dict = {}
         self.llm_end_dict = {}
         self.mel_overlap_dict = {}
+        self.flow_cache_dict = {}
         self.hift_cache_dict = {}
 
     def load(self, llm_model, flow_model, hift_model):
-        self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
+        self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
         self.llm.to(self.device).eval()
-        self.llm.half()
-        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device))
+        if self.fp16 is True:
+            self.llm.half()
+        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
         self.flow.to(self.device).eval()
-        self.hift.load_state_dict(torch.load(hift_model, map_location=self.device))
+        # in case hift_model is a hifigan model
+        hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
+        self.hift.load_state_dict(hift_state_dict, strict=True)
         self.hift.to(self.device).eval()
 
-    def load_jit(self, llm_text_encoder_model, llm_llm_model):
-        llm_text_encoder = torch.jit.load(llm_text_encoder_model)
+    def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
+        assert self.fp16 is True, "we only provide fp16 jit model, set fp16=True if you want to use jit model"
+        llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
         self.llm.text_encoder = llm_text_encoder
-        llm_llm = torch.jit.load(llm_llm_model)
+        llm_llm = torch.jit.load(llm_llm_model, map_location=self.device)
         self.llm.llm = llm_llm
+        flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
+        self.flow.encoder = flow_encoder
+
+    def load_onnx(self, flow_decoder_estimator_model):
+        import onnxruntime
+        option = onnxruntime.SessionOptions()
+        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+        option.intra_op_num_threads = 1
+        providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
+        del self.flow.decoder.estimator
+        self.flow.decoder.estimator = onnxruntime.InferenceSession(flow_decoder_estimator_model, sess_options=option, providers=providers)
 
     def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
+        if self.fp16 is True:
+            llm_embedding = llm_embedding.half()
         with self.llm_context:
             for i in self.llm.inference(text=text.to(self.device),
-                                        text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
-                                        prompt_text=prompt_text.to(self.device),
-                                        prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
-                                        prompt_speech_token=llm_prompt_speech_token.to(self.device),
-                                        prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
-                                        embedding=llm_embedding.to(self.device).half(),
-                                        sampling=25,
-                                        max_token_text_ratio=30,
-                                        min_token_text_ratio=3):
+                                        text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
+                                        prompt_text=prompt_text.to(self.device),
+                                        prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
+                                        prompt_speech_token=llm_prompt_speech_token.to(self.device),
+                                        prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
+                                        embedding=llm_embedding.to(self.device)):
                 self.tts_speech_token_dict[uuid].append(i)
         self.llm_end_dict[uuid] = True
 
-    def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False):
-        with self.flow_hift_context:
-            tts_mel = self.flow.inference(token=token.to(self.device),
-                                          token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
-                                          prompt_token=prompt_token.to(self.device),
-                                          prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
-                                          prompt_feat=prompt_feat.to(self.device),
-                                          prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
-                                          embedding=embedding.to(self.device))
-            # mel overlap fade in out
-            # if self.mel_overlap_dict[uuid] is not None:
-            #     tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
-            # append hift cache
+    def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
+        tts_mel, flow_cache = self.flow.inference(token=token.to(self.device),
+                                                  token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
+                                                  prompt_token=prompt_token.to(self.device),
+                                                  prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
+                                                  prompt_feat=prompt_feat.to(self.device),
+                                                  prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
+                                                  embedding=embedding.to(self.device),
+                                                  flow_cache=self.flow_cache_dict[uuid])
+        self.flow_cache_dict[uuid] = flow_cache
+
+        # mel overlap fade in out
+        if self.mel_overlap_dict[uuid].shape[2] != 0:
+            tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
+        # append hift cache
+        if self.hift_cache_dict[uuid] is not None:
+            hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
+            tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
+        else:
+            hift_cache_source = torch.zeros(1, 1, 0)
+        # keep overlap mel and hift cache
+        if finalize is False:
+            self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
+            tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
+            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
+            if self.hift_cache_dict[uuid] is not None:
+                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
+            self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
+                                          'source': tts_source[:, :, -self.source_cache_len:],
+                                          'speech': tts_speech[:, -self.source_cache_len:]}
+            tts_speech = tts_speech[:, :-self.source_cache_len]
+        else:
+            if speed != 1.0:
+                assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
+                tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
+            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
             if self.hift_cache_dict[uuid] is not None:
-                hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
-                tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
-            else:
-                hift_cache_source = torch.zeros(1, 1, 0)
-            # keep overlap mel and hift cache
-            if finalize is False:
-                self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
-                tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
-                tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
-                self.hift_cache_dict[uuid] = {'source': tts_source[:, :, -self.source_cache_len:], 'mel': tts_mel[:, :, -self.mel_cache_len:]}
-                tts_speech = tts_speech[:, :-self.source_cache_len]
-            else:
-                tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
+                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
         return tts_speech
 
-    def inference(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
-                  prompt_text=torch.zeros(1, 0, dtype=torch.int32),
-                  llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
-                  flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
-                  prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, **kwargs):
+    def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
+            prompt_text=torch.zeros(1, 0, dtype=torch.int32),
+            llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
+            flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
+            prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
         # this_uuid is used to track variables related to this inference thread
         this_uuid = str(uuid.uuid1())
         with self.lock:
-            self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid], self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = [], False, None, None
+            self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
+            self.hift_cache_dict[this_uuid] = None
+            self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
+            self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
         p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
         p.start()
         if stream is True:
@@ -126,15 +161,15 @@ class CosyVoiceModel:
             while True:
                 time.sleep(0.1)
                 if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
-                    this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len], dim=1)
-                    with self.flow_hift_context:
-                        this_tts_speech = self.token2wav(token=this_tts_speech_token,
-                                                         prompt_token=flow_prompt_speech_token,
-                                                         prompt_feat=prompt_speech_feat,
-                                                         embedding=flow_embedding,
-                                                         uuid=this_uuid,
-                                                         finalize=False)
-                    yield {'tts_speech': this_tts_speech.cpu()}
+                    this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
+                        .unsqueeze(dim=0)
+                    this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                                     prompt_token=flow_prompt_speech_token,
+                                                     prompt_feat=prompt_speech_feat,
+                                                     embedding=flow_embedding,
+                                                     uuid=this_uuid,
+                                                     finalize=False)
+                    yield {'tts_speech': this_tts_speech.cpu()}
                     with self.lock:
                         self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
                     # increase token_hop_len for better speech quality
@@ -143,31 +178,246 @@ class CosyVoiceModel:
                     break
             p.join()
             # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
-            this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
-            with self.flow_hift_context:
-                this_tts_speech = self.token2wav(token=this_tts_speech_token,
-                                                 prompt_token=flow_prompt_speech_token,
-                                                 prompt_feat=prompt_speech_feat,
-                                                 embedding=flow_embedding,
-                                                 uuid=this_uuid,
-                                                 finalize=True)
+            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
+            this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                             prompt_token=flow_prompt_speech_token,
+                                             prompt_feat=prompt_speech_feat,
+                                             embedding=flow_embedding,
+                                             uuid=this_uuid,
+                                             finalize=True)
             yield {'tts_speech': this_tts_speech.cpu()}
         else:
             # deal with all tokens
             p.join()
-            this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
-            with self.flow_hift_context:
-                this_tts_speech = self.token2wav(token=this_tts_speech_token,
-                                                 prompt_token=flow_prompt_speech_token,
-                                                 prompt_feat=prompt_speech_feat,
-                                                 embedding=flow_embedding,
-                                                 uuid=this_uuid,
-                                                 finalize=True)
+            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
+            this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                             prompt_token=flow_prompt_speech_token,
+                                             prompt_feat=prompt_speech_feat,
+                                             embedding=flow_embedding,
+                                             uuid=this_uuid,
+                                             finalize=True,
+                                             speed=speed)
             yield {'tts_speech': this_tts_speech.cpu()}
         with self.lock:
             self.tts_speech_token_dict.pop(this_uuid)
             self.llm_end_dict.pop(this_uuid)
             self.mel_overlap_dict.pop(this_uuid)
             self.hift_cache_dict.pop(this_uuid)
-        if torch.cuda.is_initialized():
-            torch.cuda.synchronize()
+            self.flow_cache_dict.pop(this_uuid)
+
+    def vc(self, source_speech_token, flow_prompt_speech_token, prompt_speech_feat, flow_embedding, stream=False, speed=1.0, **kwargs):
+        # this_uuid is used to track variables related to this inference thread
+        this_uuid = str(uuid.uuid1())
+        with self.lock:
+            self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = source_speech_token.flatten().tolist(), True
+            self.hift_cache_dict[this_uuid] = None
+            self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
+            self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
+        if stream is True:
+            token_hop_len = self.token_min_hop_len
+            while True:
+                if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
+                    this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
+                        .unsqueeze(dim=0)
+                    this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                                     prompt_token=flow_prompt_speech_token,
+                                                     prompt_feat=prompt_speech_feat,
+                                                     embedding=flow_embedding,
+                                                     uuid=this_uuid,
+                                                     finalize=False)
+                    yield {'tts_speech': this_tts_speech.cpu()}
+                    with self.lock:
+                        self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
+                    # increase token_hop_len for better speech quality
+                    token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
+                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
+                    break
+            # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
+            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
+            this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                             prompt_token=flow_prompt_speech_token,
+                                             prompt_feat=prompt_speech_feat,
+                                             embedding=flow_embedding,
+                                             uuid=this_uuid,
+                                             finalize=True)
+            yield {'tts_speech': this_tts_speech.cpu()}
+        else:
+            # deal with all tokens
+            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
+            this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                             prompt_token=flow_prompt_speech_token,
+                                             prompt_feat=prompt_speech_feat,
+                                             embedding=flow_embedding,
+                                             uuid=this_uuid,
+                                             finalize=True,
+                                             speed=speed)
+            yield {'tts_speech': this_tts_speech.cpu()}
+        with self.lock:
+            self.tts_speech_token_dict.pop(this_uuid)
+            self.llm_end_dict.pop(this_uuid)
+            self.mel_overlap_dict.pop(this_uuid)
+            self.hift_cache_dict.pop(this_uuid)
+
+
+class CosyVoice2Model:
+
+    def __init__(self,
+                 llm: torch.nn.Module,
+                 flow: torch.nn.Module,
+                 hift: torch.nn.Module):
+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        self.llm = llm
+        self.flow = flow
+        self.hift = hift
+        self.token_hop_len = 2 * self.flow.input_frame_rate
+        # here we fix flow encoder/decoder decoding_chunk_size, in the future we will send it as arguments, or use cache
+        self.flow.encoder.static_chunk_size = 2 * self.flow.input_frame_rate
+        self.flow.decoder.estimator.static_chunk_size = 2 * self.flow.input_frame_rate * self.flow.token_mel_ratio
+        # hift cache
+        self.mel_cache_len = 8
+        self.source_cache_len = int(self.mel_cache_len * 480)
+        # speech fade in out
+        self.speech_window = np.hamming(2 * self.source_cache_len)
+        # rtf and decoding related
+        self.stream_scale_factor = 1
+        self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
+        self.lock = threading.Lock()
+        # dict used to store session related variable
+        self.tts_speech_token_dict = {}
+        self.llm_end_dict = {}
+        self.hift_cache_dict = {}
+
+    def load(self, llm_model, flow_model, hift_model):
+        self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
+        self.llm.to(self.device).eval()
+        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
+        self.flow.to(self.device).eval()
+        self.flow.decoder.fp16 = False
+        # in case hift_model is a hifigan model
+        hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
+        self.hift.load_state_dict(hift_state_dict, strict=True)
+        self.hift.to(self.device).eval()
+
+    def load_jit(self, flow_encoder_model):
+        flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
+        self.flow.encoder = flow_encoder
+
+    def load_onnx(self, flow_decoder_estimator_model):
+        import onnxruntime
+        option = onnxruntime.SessionOptions()
+        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+        option.intra_op_num_threads = 1
+        providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
+        del self.flow.decoder.estimator
+        self.flow.decoder.estimator = onnxruntime.InferenceSession(flow_decoder_estimator_model, sess_options=option, providers=providers)
+
+    def load_trt(self, flow_decoder_estimator_model):
+        del self.flow.decoder.estimator
+        import tensorrt as trt
+        with open(flow_decoder_estimator_model, 'rb') as f:
+            self.flow.decoder.estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
+        self.flow.decoder.estimator = self.flow.decoder.estimator_engine.create_execution_context()
+        self.flow.decoder.fp16 = True
+
+    def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
+        with self.llm_context:
+            for i in self.llm.inference(text=text.to(self.device),
+                                        text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
+                                        prompt_text=prompt_text.to(self.device),
+                                        prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
+                                        prompt_speech_token=llm_prompt_speech_token.to(self.device),
+                                        prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
+                                        embedding=llm_embedding.to(self.device)):
+                self.tts_speech_token_dict[uuid].append(i)
+        self.llm_end_dict[uuid] = True
+
+    def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, token_offset, finalize=False, speed=1.0):
+        tts_mel, _ = self.flow.inference(token=token.to(self.device),
+                                         token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
+                                         prompt_token=prompt_token.to(self.device),
+                                         prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
+                                         prompt_feat=prompt_feat.to(self.device),
+                                         prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
+                                         embedding=embedding.to(self.device),
+                                         finalize=finalize)
+        tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
+        # append hift cache
+        if self.hift_cache_dict[uuid] is not None:
+            hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
+            tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
+        else:
+            hift_cache_source = torch.zeros(1, 1, 0)
+        # keep overlap mel and hift cache
+        if finalize is False:
+            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
+            if self.hift_cache_dict[uuid] is not None:
+                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
+            self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
+                                          'source': tts_source[:, :, -self.source_cache_len:],
+                                          'speech': tts_speech[:, -self.source_cache_len:]}
+            tts_speech = tts_speech[:, :-self.source_cache_len]
+        else:
+            if speed != 1.0:
+                assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
+                tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
+            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
+            if self.hift_cache_dict[uuid] is not None:
+                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
+        return tts_speech
+
+    def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
+            prompt_text=torch.zeros(1, 0, dtype=torch.int32),
+            llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
+            flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
+            prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
+        # this_uuid is used to track variables related to this inference thread
+        this_uuid = str(uuid.uuid1())
+        with self.lock:
+            self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
+            self.hift_cache_dict[this_uuid] = None
+        p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
+        p.start()
+        if stream is True:
+            token_offset = 0
+            while True:
+                time.sleep(0.1)
+                if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= self.token_hop_len + self.flow.pre_lookahead_len:
+                    this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + self.token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0)
+                    this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                                     prompt_token=flow_prompt_speech_token,
+                                                     prompt_feat=prompt_speech_feat,
+                                                     embedding=flow_embedding,
+                                                     uuid=this_uuid,
+                                                     token_offset=token_offset,
+                                                     finalize=False)
+                    token_offset += self.token_hop_len
+                    yield {'tts_speech': this_tts_speech.cpu()}
+                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < self.token_hop_len + self.flow.pre_lookahead_len:
+                    break
+            p.join()
+            # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
+            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
+            this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                             prompt_token=flow_prompt_speech_token,
+                                             prompt_feat=prompt_speech_feat,
+                                             embedding=flow_embedding,
+                                             uuid=this_uuid,
+                                             token_offset=token_offset,
+                                             finalize=True)
+            yield {'tts_speech': this_tts_speech.cpu()}
+        else:
+            # deal with all tokens
+            p.join()
+            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
+            this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                             prompt_token=flow_prompt_speech_token,
+                                             prompt_feat=prompt_speech_feat,
+                                             embedding=flow_embedding,
+                                             uuid=this_uuid,
+                                             token_offset=0,
+                                             finalize=True,
+                                             speed=speed)
+            yield {'tts_speech': this_tts_speech.cpu()}
+        with self.lock:
+            self.tts_speech_token_dict.pop(this_uuid)
+            self.llm_end_dict.pop(this_uuid)
xinference/thirdparty/cosyvoice/dataset/dataset.py

@@ -126,6 +126,7 @@ class DataList(IterableDataset):
 def Dataset(data_list_file,
             data_pipeline,
             mode='train',
+            gan=False,
             shuffle=True,
             partition=True,
             tts_file='',
@@ -148,13 +149,16 @@ def Dataset(data_list_file,
             tts_data = json.load(f)
         utt2lists = read_json_lists(prompt_utt2data)
         # filter unnecessary file in inference mode
-        lists = list(set([utt2lists[utt] for utt in tts_data.keys() if utt2lists[utt] in lists]))
+        lists = list({utt2lists[utt] for utt in tts_data.keys() if utt2lists[utt] in lists})
     dataset = DataList(lists,
                        shuffle=shuffle,
                        partition=partition)
     if mode == 'inference':
-        # map partial arg tts_data in inference mode
+        # map partial arg to parquet_opener func in inference mode
         data_pipeline[0] = partial(data_pipeline[0], tts_data=tts_data)
+    if gan is True:
+        # map partial arg to padding func in gan mode
+        data_pipeline[-1] = partial(data_pipeline[-1], gan=gan)
     for func in data_pipeline:
         dataset = Processor(dataset, func, mode=mode)
     return dataset
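
The new gan flag above only rebinds the last element of data_pipeline with functools.partial, mirroring how tts_data is bound to the first element in inference mode. A small, self-contained sketch of that rebinding pattern follows; the two pipeline stages are hypothetical placeholders, not the real CosyVoice processors.

from functools import partial

# Hypothetical two-stage pipeline standing in for the real CosyVoice processors.
def opener(sample, tts_data=None):
    # first stage: attach any inference-time prompt data
    return {**sample, 'tts_data': tts_data}

def padding(sample, gan=False):
    # last stage: in GAN mode the padding step would also emit discriminator targets
    return {**sample, 'gan': gan}

data_pipeline = [opener, padding]
gan = True
if gan is True:
    # same pattern as the change above: only the final stage learns about gan mode
    data_pipeline[-1] = partial(data_pipeline[-1], gan=gan)

print(data_pipeline[-1]({'utt': 'demo'}))  # -> {'utt': 'demo', 'gan': True}
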