xinference 1.1.0__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xinference/_compat.py +2 -0
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +23 -1
- xinference/core/model.py +1 -6
- xinference/core/utils.py +10 -6
- xinference/model/audio/core.py +5 -0
- xinference/model/audio/cosyvoice.py +25 -3
- xinference/model/audio/f5tts.py +15 -10
- xinference/model/audio/f5tts_mlx.py +260 -0
- xinference/model/audio/fish_speech.py +35 -111
- xinference/model/audio/model_spec.json +19 -3
- xinference/model/audio/model_spec_modelscope.json +9 -0
- xinference/model/audio/utils.py +32 -0
- xinference/model/image/core.py +69 -1
- xinference/model/image/model_spec.json +127 -4
- xinference/model/image/model_spec_modelscope.json +130 -4
- xinference/model/image/stable_diffusion/core.py +45 -13
- xinference/model/llm/llm_family.json +47 -0
- xinference/model/llm/llm_family.py +15 -36
- xinference/model/llm/llm_family_modelscope.json +49 -0
- xinference/model/llm/mlx/core.py +68 -13
- xinference/model/llm/transformers/core.py +1 -0
- xinference/model/llm/transformers/qwen2_vl.py +2 -0
- xinference/model/llm/utils.py +1 -0
- xinference/model/llm/vllm/core.py +11 -2
- xinference/thirdparty/cosyvoice/bin/average_model.py +92 -0
- xinference/thirdparty/cosyvoice/bin/export_jit.py +12 -2
- xinference/thirdparty/cosyvoice/bin/export_onnx.py +112 -0
- xinference/thirdparty/cosyvoice/bin/export_trt.sh +9 -0
- xinference/thirdparty/cosyvoice/bin/inference.py +5 -7
- xinference/thirdparty/cosyvoice/bin/train.py +42 -8
- xinference/thirdparty/cosyvoice/cli/cosyvoice.py +96 -25
- xinference/thirdparty/cosyvoice/cli/frontend.py +77 -30
- xinference/thirdparty/cosyvoice/cli/model.py +330 -80
- xinference/thirdparty/cosyvoice/dataset/dataset.py +6 -2
- xinference/thirdparty/cosyvoice/dataset/processor.py +76 -14
- xinference/thirdparty/cosyvoice/flow/decoder.py +92 -13
- xinference/thirdparty/cosyvoice/flow/flow.py +99 -9
- xinference/thirdparty/cosyvoice/flow/flow_matching.py +110 -13
- xinference/thirdparty/cosyvoice/flow/length_regulator.py +5 -4
- xinference/thirdparty/cosyvoice/hifigan/discriminator.py +140 -0
- xinference/thirdparty/cosyvoice/hifigan/generator.py +58 -42
- xinference/thirdparty/cosyvoice/hifigan/hifigan.py +67 -0
- xinference/thirdparty/cosyvoice/llm/llm.py +139 -6
- xinference/thirdparty/cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
- xinference/thirdparty/cosyvoice/tokenizer/tokenizer.py +279 -0
- xinference/thirdparty/cosyvoice/transformer/embedding.py +2 -2
- xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +7 -7
- xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +318 -0
- xinference/thirdparty/cosyvoice/utils/common.py +28 -1
- xinference/thirdparty/cosyvoice/utils/executor.py +69 -7
- xinference/thirdparty/cosyvoice/utils/file_utils.py +2 -12
- xinference/thirdparty/cosyvoice/utils/frontend_utils.py +9 -5
- xinference/thirdparty/cosyvoice/utils/losses.py +20 -0
- xinference/thirdparty/cosyvoice/utils/scheduler.py +1 -2
- xinference/thirdparty/cosyvoice/utils/train_utils.py +101 -45
- xinference/thirdparty/fish_speech/fish_speech/conversation.py +94 -83
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +63 -20
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +1 -26
- xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +1 -1
- xinference/thirdparty/fish_speech/fish_speech/tokenizer.py +152 -0
- xinference/thirdparty/fish_speech/fish_speech/train.py +2 -2
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1 -1
- xinference/thirdparty/fish_speech/tools/{post_api.py → api_client.py} +7 -13
- xinference/thirdparty/fish_speech/tools/api_server.py +98 -0
- xinference/thirdparty/fish_speech/tools/download_models.py +5 -5
- xinference/thirdparty/fish_speech/tools/fish_e2e.py +2 -2
- xinference/thirdparty/fish_speech/tools/inference_engine/__init__.py +192 -0
- xinference/thirdparty/fish_speech/tools/inference_engine/reference_loader.py +125 -0
- xinference/thirdparty/fish_speech/tools/inference_engine/utils.py +39 -0
- xinference/thirdparty/fish_speech/tools/inference_engine/vq_manager.py +57 -0
- xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +2 -2
- xinference/thirdparty/fish_speech/tools/llama/generate.py +117 -89
- xinference/thirdparty/fish_speech/tools/run_webui.py +104 -0
- xinference/thirdparty/fish_speech/tools/schema.py +11 -28
- xinference/thirdparty/fish_speech/tools/server/agent/__init__.py +57 -0
- xinference/thirdparty/fish_speech/tools/server/agent/generate.py +119 -0
- xinference/thirdparty/fish_speech/tools/server/agent/generation_utils.py +122 -0
- xinference/thirdparty/fish_speech/tools/server/agent/pre_generation_utils.py +72 -0
- xinference/thirdparty/fish_speech/tools/server/api_utils.py +75 -0
- xinference/thirdparty/fish_speech/tools/server/exception_handler.py +27 -0
- xinference/thirdparty/fish_speech/tools/server/inference.py +45 -0
- xinference/thirdparty/fish_speech/tools/server/model_manager.py +122 -0
- xinference/thirdparty/fish_speech/tools/server/model_utils.py +129 -0
- xinference/thirdparty/fish_speech/tools/server/views.py +246 -0
- xinference/thirdparty/fish_speech/tools/webui/__init__.py +173 -0
- xinference/thirdparty/fish_speech/tools/webui/inference.py +91 -0
- xinference/thirdparty/fish_speech/tools/webui/variables.py +14 -0
- xinference/thirdparty/matcha/utils/utils.py +2 -2
- {xinference-1.1.0.dist-info → xinference-1.1.1.dist-info}/METADATA +11 -6
- {xinference-1.1.0.dist-info → xinference-1.1.1.dist-info}/RECORD +95 -74
- xinference/thirdparty/cosyvoice/bin/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/bin/export_trt.py +0 -8
- xinference/thirdparty/cosyvoice/flow/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/hifigan/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/llm/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/api.py +0 -943
- xinference/thirdparty/fish_speech/tools/msgpack_api.py +0 -95
- xinference/thirdparty/fish_speech/tools/webui.py +0 -548
- {xinference-1.1.0.dist-info → xinference-1.1.1.dist-info}/LICENSE +0 -0
- {xinference-1.1.0.dist-info → xinference-1.1.1.dist-info}/WHEEL +0 -0
- {xinference-1.1.0.dist-info → xinference-1.1.1.dist-info}/entry_points.txt +0 -0
- {xinference-1.1.0.dist-info → xinference-1.1.1.dist-info}/top_level.txt +0 -0
xinference/thirdparty/cosyvoice/cli/model.py

@@ -15,6 +15,7 @@ import torch
 import numpy as np
 import threading
 import time
+from torch.nn import functional as F
 from contextlib import nullcontext
 import uuid
 from cosyvoice.utils.common import fade_in_out
@@ -25,100 +26,134 @@ class CosyVoiceModel:
     def __init__(self,
                  llm: torch.nn.Module,
                  flow: torch.nn.Module,
-                 hift: torch.nn.Module):
+                 hift: torch.nn.Module,
+                 fp16: bool):
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         self.llm = llm
         self.flow = flow
         self.hift = hift
-        self.
-        self.
+        self.fp16 = fp16
+        self.token_min_hop_len = 2 * self.flow.input_frame_rate
+        self.token_max_hop_len = 4 * self.flow.input_frame_rate
         self.token_overlap_len = 20
+        # here we fix set flow.decoder.estimator.static_chunk_size = 0 for compatibability
+        self.flow.decoder.estimator.static_chunk_size = 0
         # mel fade in out
-        self.mel_overlap_len =
+        self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256)
         self.mel_window = np.hamming(2 * self.mel_overlap_len)
         # hift cache
         self.mel_cache_len = 20
         self.source_cache_len = int(self.mel_cache_len * 256)
+        # speech fade in out
+        self.speech_window = np.hamming(2 * self.source_cache_len)
         # rtf and decoding related
         self.stream_scale_factor = 1
         assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
         self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
-        self.flow_hift_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
         self.lock = threading.Lock()
         # dict used to store session related variable
         self.tts_speech_token_dict = {}
         self.llm_end_dict = {}
         self.mel_overlap_dict = {}
+        self.flow_cache_dict = {}
         self.hift_cache_dict = {}

     def load(self, llm_model, flow_model, hift_model):
-        self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
+        self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
         self.llm.to(self.device).eval()
-        self.
-
+        if self.fp16 is True:
+            self.llm.half()
+        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
         self.flow.to(self.device).eval()
-
+        # in case hift_model is a hifigan model
+        hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
+        self.hift.load_state_dict(hift_state_dict, strict=True)
         self.hift.to(self.device).eval()

-    def load_jit(self, llm_text_encoder_model, llm_llm_model):
-
+    def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
+        assert self.fp16 is True, "we only provide fp16 jit model, set fp16=True if you want to use jit model"
+        llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
         self.llm.text_encoder = llm_text_encoder
-        llm_llm = torch.jit.load(llm_llm_model)
+        llm_llm = torch.jit.load(llm_llm_model, map_location=self.device)
         self.llm.llm = llm_llm
+        flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
+        self.flow.encoder = flow_encoder
+
+    def load_onnx(self, flow_decoder_estimator_model):
+        import onnxruntime
+        option = onnxruntime.SessionOptions()
+        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+        option.intra_op_num_threads = 1
+        providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
+        del self.flow.decoder.estimator
+        self.flow.decoder.estimator = onnxruntime.InferenceSession(flow_decoder_estimator_model, sess_options=option, providers=providers)

     def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
+        if self.fp16 is True:
+            llm_embedding = llm_embedding.half()
         with self.llm_context:
             for i in self.llm.inference(text=text.to(self.device),
-
-
-
-
-
-
-                                        sampling=25,
-                                        max_token_text_ratio=30,
-                                        min_token_text_ratio=3):
+                                        text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
+                                        prompt_text=prompt_text.to(self.device),
+                                        prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
+                                        prompt_speech_token=llm_prompt_speech_token.to(self.device),
+                                        prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
+                                        embedding=llm_embedding.to(self.device)):
                 self.tts_speech_token_dict[uuid].append(i)
         self.llm_end_dict[uuid] = True

-    def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False):
-
-
-
-
-
-
-
-
-
-
-
-
+    def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
+        tts_mel, flow_cache = self.flow.inference(token=token.to(self.device),
+                                                  token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
+                                                  prompt_token=prompt_token.to(self.device),
+                                                  prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
+                                                  prompt_feat=prompt_feat.to(self.device),
+                                                  prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
+                                                  embedding=embedding.to(self.device),
+                                                  flow_cache=self.flow_cache_dict[uuid])
+        self.flow_cache_dict[uuid] = flow_cache
+
+        # mel overlap fade in out
+        if self.mel_overlap_dict[uuid].shape[2] != 0:
+            tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
+        # append hift cache
+        if self.hift_cache_dict[uuid] is not None:
+            hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
+            tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
+        else:
+            hift_cache_source = torch.zeros(1, 1, 0)
+        # keep overlap mel and hift cache
+        if finalize is False:
+            self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
+            tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
+            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
+            if self.hift_cache_dict[uuid] is not None:
+                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
+            self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
+                                          'source': tts_source[:, :, -self.source_cache_len:],
+                                          'speech': tts_speech[:, -self.source_cache_len:]}
+            tts_speech = tts_speech[:, :-self.source_cache_len]
+        else:
+            if speed != 1.0:
+                assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
+                tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
+            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
             if self.hift_cache_dict[uuid] is not None:
-
-            tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
-        else:
-            hift_cache_source = torch.zeros(1, 1, 0)
-        # keep overlap mel and hift cache
-        if finalize is False:
-            self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
-            tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
-            tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
-            self.hift_cache_dict[uuid] = {'source': tts_source[:, :, -self.source_cache_len:], 'mel': tts_mel[:, :, -self.mel_cache_len:]}
-            tts_speech = tts_speech[:, :-self.source_cache_len]
-        else:
-            tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
+                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
         return tts_speech

-    def
-
-
-
-
+    def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
+            prompt_text=torch.zeros(1, 0, dtype=torch.int32),
+            llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
+            flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
+            prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
         # this_uuid is used to track variables related to this inference thread
         this_uuid = str(uuid.uuid1())
         with self.lock:
-            self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid]
+            self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
+            self.hift_cache_dict[this_uuid] = None
+            self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
+            self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
         p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
         p.start()
         if stream is True:
@@ -126,15 +161,15 @@ class CosyVoiceModel:
             while True:
                 time.sleep(0.1)
                 if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
-                    this_tts_speech_token = torch.
-
-
-
-
-
-
-
-                    yield
+                    this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
+                        .unsqueeze(dim=0)
+                    this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                                     prompt_token=flow_prompt_speech_token,
+                                                     prompt_feat=prompt_speech_feat,
+                                                     embedding=flow_embedding,
+                                                     uuid=this_uuid,
+                                                     finalize=False)
+                    yield {'tts_speech': this_tts_speech.cpu()}
                     with self.lock:
                         self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
                     # increase token_hop_len for better speech quality
@@ -143,31 +178,246 @@ class CosyVoiceModel:
                     break
             p.join()
            # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
-            this_tts_speech_token = torch.
-
-
-
-
-
-
-                                             finalize=True)
+            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
+            this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                             prompt_token=flow_prompt_speech_token,
+                                             prompt_feat=prompt_speech_feat,
+                                             embedding=flow_embedding,
+                                             uuid=this_uuid,
+                                             finalize=True)
             yield {'tts_speech': this_tts_speech.cpu()}
         else:
             # deal with all tokens
             p.join()
-            this_tts_speech_token = torch.
-
-
-
-
-
-
-
+            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
+            this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                             prompt_token=flow_prompt_speech_token,
+                                             prompt_feat=prompt_speech_feat,
+                                             embedding=flow_embedding,
+                                             uuid=this_uuid,
+                                             finalize=True,
+                                             speed=speed)
             yield {'tts_speech': this_tts_speech.cpu()}
         with self.lock:
             self.tts_speech_token_dict.pop(this_uuid)
             self.llm_end_dict.pop(this_uuid)
             self.mel_overlap_dict.pop(this_uuid)
             self.hift_cache_dict.pop(this_uuid)
-
-
+            self.flow_cache_dict.pop(this_uuid)
+
+    def vc(self, source_speech_token, flow_prompt_speech_token, prompt_speech_feat, flow_embedding, stream=False, speed=1.0, **kwargs):
+        # this_uuid is used to track variables related to this inference thread
+        this_uuid = str(uuid.uuid1())
+        with self.lock:
+            self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = source_speech_token.flatten().tolist(), True
+            self.hift_cache_dict[this_uuid] = None
+            self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
+            self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
+        if stream is True:
+            token_hop_len = self.token_min_hop_len
+            while True:
+                if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
+                    this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
+                        .unsqueeze(dim=0)
+                    this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                                     prompt_token=flow_prompt_speech_token,
+                                                     prompt_feat=prompt_speech_feat,
+                                                     embedding=flow_embedding,
+                                                     uuid=this_uuid,
+                                                     finalize=False)
+                    yield {'tts_speech': this_tts_speech.cpu()}
+                    with self.lock:
+                        self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
+                    # increase token_hop_len for better speech quality
+                    token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
+                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
+                    break
+            # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
+            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
+            this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                             prompt_token=flow_prompt_speech_token,
+                                             prompt_feat=prompt_speech_feat,
+                                             embedding=flow_embedding,
+                                             uuid=this_uuid,
+                                             finalize=True)
+            yield {'tts_speech': this_tts_speech.cpu()}
+        else:
+            # deal with all tokens
+            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
+            this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                             prompt_token=flow_prompt_speech_token,
+                                             prompt_feat=prompt_speech_feat,
+                                             embedding=flow_embedding,
+                                             uuid=this_uuid,
+                                             finalize=True,
+                                             speed=speed)
+            yield {'tts_speech': this_tts_speech.cpu()}
+        with self.lock:
+            self.tts_speech_token_dict.pop(this_uuid)
+            self.llm_end_dict.pop(this_uuid)
+            self.mel_overlap_dict.pop(this_uuid)
+            self.hift_cache_dict.pop(this_uuid)
+
+
+class CosyVoice2Model:
+
+    def __init__(self,
+                 llm: torch.nn.Module,
+                 flow: torch.nn.Module,
+                 hift: torch.nn.Module):
+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        self.llm = llm
+        self.flow = flow
+        self.hift = hift
+        self.token_hop_len = 2 * self.flow.input_frame_rate
+        # here we fix flow encoder/decoder decoding_chunk_size, in the future we will send it as arguments, or use cache
+        self.flow.encoder.static_chunk_size = 2 * self.flow.input_frame_rate
+        self.flow.decoder.estimator.static_chunk_size = 2 * self.flow.input_frame_rate * self.flow.token_mel_ratio
+        # hift cache
+        self.mel_cache_len = 8
+        self.source_cache_len = int(self.mel_cache_len * 480)
+        # speech fade in out
+        self.speech_window = np.hamming(2 * self.source_cache_len)
+        # rtf and decoding related
+        self.stream_scale_factor = 1
+        self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
+        self.lock = threading.Lock()
+        # dict used to store session related variable
+        self.tts_speech_token_dict = {}
+        self.llm_end_dict = {}
+        self.hift_cache_dict = {}
+
+    def load(self, llm_model, flow_model, hift_model):
+        self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
+        self.llm.to(self.device).eval()
+        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
+        self.flow.to(self.device).eval()
+        self.flow.decoder.fp16 = False
+        # in case hift_model is a hifigan model
+        hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
+        self.hift.load_state_dict(hift_state_dict, strict=True)
+        self.hift.to(self.device).eval()
+
+    def load_jit(self, flow_encoder_model):
+        flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
+        self.flow.encoder = flow_encoder
+
+    def load_onnx(self, flow_decoder_estimator_model):
+        import onnxruntime
+        option = onnxruntime.SessionOptions()
+        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+        option.intra_op_num_threads = 1
+        providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
+        del self.flow.decoder.estimator
+        self.flow.decoder.estimator = onnxruntime.InferenceSession(flow_decoder_estimator_model, sess_options=option, providers=providers)
+
+    def load_trt(self, flow_decoder_estimator_model):
+        del self.flow.decoder.estimator
+        import tensorrt as trt
+        with open(flow_decoder_estimator_model, 'rb') as f:
+            self.flow.decoder.estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
+        self.flow.decoder.estimator = self.flow.decoder.estimator_engine.create_execution_context()
+        self.flow.decoder.fp16 = True
+
+    def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
+        with self.llm_context:
+            for i in self.llm.inference(text=text.to(self.device),
+                                        text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
+                                        prompt_text=prompt_text.to(self.device),
+                                        prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
+                                        prompt_speech_token=llm_prompt_speech_token.to(self.device),
+                                        prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
+                                        embedding=llm_embedding.to(self.device)):
+                self.tts_speech_token_dict[uuid].append(i)
+        self.llm_end_dict[uuid] = True
+
+    def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, token_offset, finalize=False, speed=1.0):
+        tts_mel, _ = self.flow.inference(token=token.to(self.device),
+                                         token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
+                                         prompt_token=prompt_token.to(self.device),
+                                         prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
+                                         prompt_feat=prompt_feat.to(self.device),
+                                         prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
+                                         embedding=embedding.to(self.device),
+                                         finalize=finalize)
+        tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
+        # append hift cache
+        if self.hift_cache_dict[uuid] is not None:
+            hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
+            tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
+        else:
+            hift_cache_source = torch.zeros(1, 1, 0)
+        # keep overlap mel and hift cache
+        if finalize is False:
+            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
+            if self.hift_cache_dict[uuid] is not None:
+                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
+            self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
+                                          'source': tts_source[:, :, -self.source_cache_len:],
+                                          'speech': tts_speech[:, -self.source_cache_len:]}
+            tts_speech = tts_speech[:, :-self.source_cache_len]
+        else:
+            if speed != 1.0:
+                assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
+                tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
+            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
+            if self.hift_cache_dict[uuid] is not None:
+                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
+        return tts_speech
+
+    def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
+            prompt_text=torch.zeros(1, 0, dtype=torch.int32),
+            llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
+            flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
+            prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
+        # this_uuid is used to track variables related to this inference thread
+        this_uuid = str(uuid.uuid1())
+        with self.lock:
+            self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
+            self.hift_cache_dict[this_uuid] = None
+        p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
+        p.start()
+        if stream is True:
+            token_offset = 0
+            while True:
+                time.sleep(0.1)
+                if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= self.token_hop_len + self.flow.pre_lookahead_len:
+                    this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + self.token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0)
+                    this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                                     prompt_token=flow_prompt_speech_token,
+                                                     prompt_feat=prompt_speech_feat,
+                                                     embedding=flow_embedding,
+                                                     uuid=this_uuid,
+                                                     token_offset=token_offset,
+                                                     finalize=False)
+                    token_offset += self.token_hop_len
+                    yield {'tts_speech': this_tts_speech.cpu()}
+                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < self.token_hop_len + self.flow.pre_lookahead_len:
+                    break
+            p.join()
+            # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
+            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
+            this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                             prompt_token=flow_prompt_speech_token,
+                                             prompt_feat=prompt_speech_feat,
+                                             embedding=flow_embedding,
+                                             uuid=this_uuid,
+                                             token_offset=token_offset,
+                                             finalize=True)
+            yield {'tts_speech': this_tts_speech.cpu()}
+        else:
+            # deal with all tokens
+            p.join()
+            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
+            this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                             prompt_token=flow_prompt_speech_token,
+                                             prompt_feat=prompt_speech_feat,
+                                             embedding=flow_embedding,
+                                             uuid=this_uuid,
+                                             token_offset=0,
+                                             finalize=True,
+                                             speed=speed)
+            yield {'tts_speech': this_tts_speech.cpu()}
+        with self.lock:
+            self.tts_speech_token_dict.pop(this_uuid)
+            self.llm_end_dict.pop(this_uuid)
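For context, a minimal sketch of how the reworked CosyVoiceModel API above might be driven. This is not code from the release: the llm/flow/hift modules, checkpoint paths, and input tensors are placeholders, and in xinference the model is assembled by cosyvoice/cli/cosyvoice.py rather than by hand.

# Illustrative only: exercises the constructor's new fp16 argument and the
# stream/speed arguments of tts() shown in the diff above.
import torch
import torchaudio  # assumed available for writing the waveform to disk

model = CosyVoiceModel(llm, flow, hift, fp16=torch.cuda.is_available())
model.load('llm.pt', 'flow.pt', 'hift.pt')  # placeholder checkpoint paths

chunks = []
for out in model.tts(text, flow_embedding=spk_embedding, stream=True):
    # with stream=True each yielded dict carries a partial waveform while the
    # LLM thread keeps producing speech tokens; stream=False yields once at the end
    chunks.append(out['tts_speech'])
torchaudio.save('output.wav', torch.cat(chunks, dim=1), 22050)  # 22.05 kHz, per the mel math above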
xinference/thirdparty/cosyvoice/dataset/dataset.py

@@ -126,6 +126,7 @@ class DataList(IterableDataset):
 def Dataset(data_list_file,
             data_pipeline,
             mode='train',
+            gan=False,
             shuffle=True,
             partition=True,
             tts_file='',
@@ -148,13 +149,16 @@ def Dataset(data_list_file,
             tts_data = json.load(f)
         utt2lists = read_json_lists(prompt_utt2data)
         # filter unnecessary file in inference mode
-        lists = list(
+        lists = list({utt2lists[utt] for utt in tts_data.keys() if utt2lists[utt] in lists})
     dataset = DataList(lists,
                        shuffle=shuffle,
                        partition=partition)
     if mode == 'inference':
-        # map partial arg
+        # map partial arg to parquet_opener func in inference mode
         data_pipeline[0] = partial(data_pipeline[0], tts_data=tts_data)
+    if gan is True:
+        # map partial arg to padding func in gan mode
+        data_pipeline[-1] = partial(data_pipeline[-1], gan=gan)
     for func in data_pipeline:
         dataset = Processor(dataset, func, mode=mode)
     return dataset