xinference 1.8.1rc1__py3-none-any.whl → 1.9.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +2 -1
- xinference/core/model.py +8 -4
- xinference/core/supervisor.py +2 -3
- xinference/core/worker.py +7 -5
- xinference/deploy/cmdline.py +2 -0
- xinference/deploy/local.py +5 -0
- xinference/deploy/test/test_cmdline.py +1 -1
- xinference/deploy/worker.py +6 -0
- xinference/model/audio/cosyvoice.py +0 -1
- xinference/model/audio/model_spec.json +44 -20
- xinference/model/core.py +3 -0
- xinference/model/embedding/flag/core.py +5 -0
- xinference/model/embedding/llama_cpp/core.py +22 -19
- xinference/model/embedding/sentence_transformers/core.py +18 -4
- xinference/model/embedding/vllm/core.py +36 -9
- xinference/model/image/cache_manager.py +56 -0
- xinference/model/image/core.py +9 -0
- xinference/model/image/model_spec.json +178 -1
- xinference/model/image/stable_diffusion/core.py +155 -23
- xinference/model/llm/cache_manager.py +17 -3
- xinference/model/llm/harmony.py +245 -0
- xinference/model/llm/llama_cpp/core.py +41 -40
- xinference/model/llm/llm_family.json +688 -11
- xinference/model/llm/llm_family.py +1 -1
- xinference/model/llm/sglang/core.py +108 -5
- xinference/model/llm/transformers/core.py +20 -18
- xinference/model/llm/transformers/gemma3.py +1 -1
- xinference/model/llm/transformers/gpt_oss.py +91 -0
- xinference/model/llm/transformers/multimodal/core.py +1 -1
- xinference/model/llm/transformers/multimodal/gemma3.py +1 -1
- xinference/model/llm/transformers/multimodal/glm4_1v.py +2 -2
- xinference/model/llm/transformers/multimodal/ovis2.py +1 -1
- xinference/model/llm/transformers/multimodal/qwen-omni.py +7 -8
- xinference/model/llm/transformers/multimodal/qwen2_vl.py +9 -6
- xinference/model/llm/transformers/utils.py +1 -33
- xinference/model/llm/utils.py +61 -7
- xinference/model/llm/vllm/core.py +44 -8
- xinference/model/rerank/__init__.py +66 -23
- xinference/model/rerank/cache_manager.py +35 -0
- xinference/model/rerank/core.py +87 -339
- xinference/model/rerank/custom.py +33 -8
- xinference/model/rerank/model_spec.json +251 -212
- xinference/model/rerank/rerank_family.py +137 -0
- xinference/model/rerank/sentence_transformers/__init__.py +13 -0
- xinference/model/rerank/sentence_transformers/core.py +337 -0
- xinference/model/rerank/vllm/__init__.py +13 -0
- xinference/model/rerank/vllm/core.py +156 -0
- xinference/model/utils.py +108 -0
- xinference/model/video/model_spec.json +95 -1
- xinference/thirdparty/cosyvoice/bin/export_jit.py +3 -4
- xinference/thirdparty/cosyvoice/bin/export_onnx.py +49 -126
- xinference/thirdparty/cosyvoice/bin/{inference.py → inference_deprecated.py} +1 -0
- xinference/thirdparty/cosyvoice/bin/train.py +23 -3
- xinference/thirdparty/cosyvoice/cli/cosyvoice.py +8 -4
- xinference/thirdparty/cosyvoice/cli/frontend.py +4 -4
- xinference/thirdparty/cosyvoice/cli/model.py +53 -75
- xinference/thirdparty/cosyvoice/dataset/dataset.py +5 -18
- xinference/thirdparty/cosyvoice/dataset/processor.py +24 -25
- xinference/thirdparty/cosyvoice/flow/decoder.py +24 -433
- xinference/thirdparty/cosyvoice/flow/flow.py +6 -14
- xinference/thirdparty/cosyvoice/flow/flow_matching.py +33 -145
- xinference/thirdparty/cosyvoice/hifigan/generator.py +169 -1
- xinference/thirdparty/cosyvoice/llm/llm.py +108 -17
- xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +14 -115
- xinference/thirdparty/cosyvoice/utils/common.py +20 -0
- xinference/thirdparty/cosyvoice/utils/executor.py +8 -4
- xinference/thirdparty/cosyvoice/utils/file_utils.py +45 -1
- xinference/thirdparty/cosyvoice/utils/losses.py +37 -0
- xinference/thirdparty/cosyvoice/utils/mask.py +35 -1
- xinference/thirdparty/cosyvoice/utils/train_utils.py +24 -6
- xinference/thirdparty/cosyvoice/vllm/cosyvoice2.py +103 -0
- xinference/types.py +2 -0
- xinference/ui/gradio/chat_interface.py +2 -0
- xinference/ui/gradio/media_interface.py +353 -7
- xinference/ui/web/ui/build/asset-manifest.json +3 -3
- xinference/ui/web/ui/build/index.html +1 -1
- xinference/ui/web/ui/build/static/js/main.1086c759.js +3 -0
- xinference/ui/web/ui/build/static/js/main.1086c759.js.map +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/28012da921a51f1082549956d3ae82acd769a754b22afda9acddd98a4daf9ea4.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/3c5758bd12fa334294b1de0ff6b1a4bac8d963c45472eab9dc3e530d82aa6b3f.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/475936ebe725eca62a6f52ce182c06a19b2cef4df9545a05ed0591ee0c539d43.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/aee5aaba26f2b1e816a3ea9efa68bad8b95695a3d80adcfd8dd57a7bb17ac71a.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +1 -0
- xinference/ui/web/ui/src/locales/en.json +2 -0
- xinference/ui/web/ui/src/locales/ja.json +2 -0
- xinference/ui/web/ui/src/locales/ko.json +2 -0
- xinference/ui/web/ui/src/locales/zh.json +2 -0
- {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/METADATA +15 -10
- {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/RECORD +98 -89
- xinference/ui/web/ui/build/static/js/main.b969199a.js +0 -3
- xinference/ui/web/ui/build/static/js/main.b969199a.js.map +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/1409a96b9f9f9f5de99a89ab0f738f6da62b449521b0a8d3e4efcf7f5c23534d.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/3d2a89f0eccc1f90fc5036c9a1d587c2120e6a6b128aae31d1db7d6bad52722b.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/43b889c3a8e2634092ade463d52481c7c5581c72ded8f23bc5f012ea0ef8cea5.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/5d47532fb42128280d87f57c8a0b02bc1930f7ef764aa7e90579247df18bba83.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/830882bb275468a969614824a9ab8983f874b4581f2eb625e9c66426cdc65e5b.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/8e5cb82c2ff3299c6a44563fe6b1c5515c9750613c51bb63abee0b1d70fc5019.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/9df08abcb5a7c1e48a4eb25c5d5f5d7253ea6854a4397e6d74d1fd75a14acda1.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/b99034986a06445701accc7a4914bb9320947435e8d4e15793392ca4f679316c.json +0 -1
- /xinference/ui/web/ui/build/static/js/{main.b969199a.js.LICENSE.txt → main.1086c759.js.LICENSE.txt} +0 -0
- {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/WHEEL +0 -0
- {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/entry_points.txt +0 -0
- {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/licenses/LICENSE +0 -0
- {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/top_level.txt +0 -0
Selected hunks from the diff are reproduced below; they appear to come from xinference/thirdparty/cosyvoice/cli/model.py. A few removed lines were truncated by the diff viewer and are kept as shown; removed lines whose content was lost entirely are omitted.

@@ -1,4 +1,5 @@
 # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+# 2025 Alibaba Inc (authors: Xiang Lyu, Bofan Zhou)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -21,7 +22,8 @@ from torch.nn import functional as F
 from contextlib import nullcontext
 import uuid
 from cosyvoice.utils.common import fade_in_out
-from cosyvoice.utils.file_utils import convert_onnx_to_trt
+from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm
+from cosyvoice.utils.common import TrtContextWrapper


 class CosyVoiceModel:
@@ -80,30 +82,28 @@ class CosyVoiceModel:
         flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
         self.flow.encoder = flow_encoder

-    def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, fp16):
+    def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent, fp16):
         assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
-        if not os.path.exists(flow_decoder_estimator_model):
+        if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
             convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16)
-        if os.path.getsize(flow_decoder_estimator_model) == 0:
-            raise ValueError('{} is empty file, delete it and export again!'.format(flow_decoder_estimator_model))
         del self.flow.decoder.estimator
         import tensorrt as trt
         with open(flow_decoder_estimator_model, 'rb') as f:
-            assert
-            self.flow.decoder.estimator = self.
+            estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
+        assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
+        self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device)

     def get_trt_kwargs(self):
         min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
-        opt_shape = [(2, 80,
+        opt_shape = [(2, 80, 500), (2, 1, 500), (2, 80, 500), (2, 80, 500)]
         max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)]
         input_names = ["x", "mask", "mu", "cond"]
         return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}

     def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
-        with self.llm_context, torch.cuda.amp.autocast(self.fp16):
+        with self.llm_context, torch.cuda.amp.autocast(self.fp16 is True and hasattr(self.llm, 'vllm') is False):
             if isinstance(text, Generator):
-                assert isinstance(self, CosyVoice2Model), 'streaming input text is only implemented for CosyVoice2!'
+                assert isinstance(self, CosyVoice2Model) and not hasattr(self.llm, 'vllm'), 'streaming input text is only implemented for CosyVoice2 and do not support vllm!'
                 for i in self.llm.inference_bistream(text=text,
                                                      prompt_text=prompt_text.to(self.device),
                                                      prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
@@ -118,7 +118,8 @@ class CosyVoiceModel:
                                         prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
                                         prompt_speech_token=llm_prompt_speech_token.to(self.device),
                                         prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
-                                        embedding=llm_embedding.to(self.device)
+                                        embedding=llm_embedding.to(self.device),
+                                        uuid=uuid):
                 self.tts_speech_token_dict[uuid].append(i)
         self.llm_end_dict[uuid] = True

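The reworked load_trt above deserializes the TensorRT engine once and hands it to TrtContextWrapper together with a trt_concurrent count, rather than creating a single execution context inline. The wrapper itself is not part of this diff; the sketch below is only a guess at its shape, assuming it pools one TensorRT execution context per concurrent caller (the class name and methods are hypothetical):

    import queue

    class TrtContextWrapperSketch:
        """Hypothetical stand-in for cosyvoice.utils.common.TrtContextWrapper."""

        def __init__(self, engine, trt_concurrent=1, device='cuda:0'):
            self.engine = engine
            self.device = device
            self._contexts = queue.Queue()
            for _ in range(trt_concurrent):
                # create_execution_context() is the standard TensorRT call for
                # obtaining an IExecutionContext from a deserialized engine
                self._contexts.put(engine.create_execution_context())

        def acquire(self):
            # block until one of the pooled contexts is free
            return self._contexts.get()

        def release(self, context):
            self._contexts.put(context)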
@@ -231,7 +232,9 @@ class CosyVoiceModel:
             self.mel_overlap_dict.pop(this_uuid)
             self.hift_cache_dict.pop(this_uuid)
             self.flow_cache_dict.pop(this_uuid)
-        torch.cuda.
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            torch.cuda.current_stream().synchronize()


 class CosyVoice2Model(CosyVoiceModel):
@@ -240,20 +243,17 @@ class CosyVoice2Model(CosyVoiceModel):
                  llm: torch.nn.Module,
                  flow: torch.nn.Module,
                  hift: torch.nn.Module,
-                 fp16: bool = False
-                 use_flow_cache: bool = False):
+                 fp16: bool = False):
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         self.llm = llm
         self.flow = flow
         self.hift = hift
         self.fp16 = fp16
-        self.use_flow_cache = use_flow_cache
         if self.fp16 is True:
             self.llm.half()
             self.flow.half()
-        #
+        # NOTE must matching training static_chunk_size
         self.token_hop_len = 25
-        self.flow_decoder_required_cache_size = 0 if use_flow_cache is False else 1 * self.token_hop_len * self.flow.token_mel_ratio
         # hift cache
         self.mel_cache_len = 8
         self.source_cache_len = int(self.mel_cache_len * 480)
@@ -265,55 +265,35 @@ class CosyVoice2Model(CosyVoiceModel):
         # dict used to store session related variable
         self.tts_speech_token_dict = {}
         self.llm_end_dict = {}
-        self.flow_cache_dict = {}
         self.hift_cache_dict = {}

-    def init_flow_cache(self):
-        encoder_cache = {'offset': 0,
-                         'pre_lookahead_layer_conv2_cache': torch.zeros(1, 512, 2).to(self.device),
-                         'encoders_kv_cache': torch.zeros(6, 1, 8, 0, 64 * 2).to(self.device),
-                         'upsample_offset': 0,
-                         'upsample_conv_cache': torch.zeros(1, 512, 4).to(self.device),
-                         'upsample_kv_cache': torch.zeros(4, 1, 8, 0, 64 * 2).to(self.device)}
-        decoder_cache = {'offset': 0,
-                         'down_blocks_conv_cache': torch.zeros(10, 1, 2, 832, 2).to(self.device),
-                         'down_blocks_kv_cache': torch.zeros(10, 1, 4, 2, self.flow_decoder_required_cache_size, 512, 2).to(self.device),
-                         'mid_blocks_conv_cache': torch.zeros(10, 12, 2, 512, 2).to(self.device),
-                         'mid_blocks_kv_cache': torch.zeros(10, 12, 4, 2, self.flow_decoder_required_cache_size, 512, 2).to(self.device),
-                         'up_blocks_conv_cache': torch.zeros(10, 1, 2, 1024, 2).to(self.device),
-                         'up_blocks_kv_cache': torch.zeros(10, 1, 4, 2, self.flow_decoder_required_cache_size, 512, 2).to(self.device),
-                         'final_blocks_conv_cache': torch.zeros(10, 2, 256, 2).to(self.device)}
-        if self.fp16 is True:
-            for cache in [encoder_cache, decoder_cache]:
-                for k, v in cache.items():
-                    if isinstance(v, torch.Tensor):
-                        cache[k] = v.half()
-        cache = {'encoder_cache': encoder_cache, 'decoder_cache': decoder_cache}
-        return cache
-
     def load_jit(self, flow_encoder_model):
         flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
         self.flow.encoder = flow_encoder

-    def
+    def load_vllm(self, model_dir):
+        export_cosyvoice2_vllm(self.llm, model_dir, self.device)
+        from vllm import EngineArgs, LLMEngine
+        engine_args = EngineArgs(model=model_dir,
+                                 skip_tokenizer_init=True,
+                                 enable_prompt_embeds=True,
+                                 gpu_memory_utilization=0.2)
+        self.llm.vllm = LLMEngine.from_engine_args(engine_args)
+        self.llm.lock = threading.Lock()
+        del self.llm.llm.model.model.layers

-    def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
+    def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
         with torch.cuda.amp.autocast(self.fp16):
-            tts_mel,
+            tts_mel, _ = self.flow.inference(token=token.to(self.device),
+                                             token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
+                                             prompt_token=prompt_token.to(self.device),
+                                             prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
+                                             prompt_feat=prompt_feat.to(self.device),
+                                             prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
+                                             embedding=embedding.to(self.device),
+                                             streaming=stream,
+                                             finalize=finalize)
+            tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
         # append hift cache
         if self.hift_cache_dict[uuid] is not None:
             hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
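Two things change in the block above: load_vllm exports the CosyVoice2 text-to-speech-token LLM with export_cosyvoice2_vllm, starts a vLLM LLMEngine with prompt embeddings enabled, and deletes the in-process transformer layers to free GPU memory; token2wav now always runs the flow decoder over the full token prefix and uses token_offset to slice off the mel frames that earlier chunks already produced. A toy illustration of that slicing (token_mel_ratio is assumed to be 2 here; the real value comes from the flow model):

    import torch

    token_mel_ratio = 2                 # assumed mel frames per speech token
    token_offset = 25                   # tokens already rendered by earlier chunks
    tts_mel = torch.randn(1, 80, 120)   # mel for the whole token prefix

    # keep only frames belonging to tokens that have not been emitted yet
    new_mel = tts_mel[:, :, token_offset * token_mel_ratio:]
    print(new_mel.shape)                # torch.Size([1, 80, 70])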
@@ -348,34 +328,30 @@ class CosyVoice2Model(CosyVoiceModel):
         with self.lock:
             self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
             self.hift_cache_dict[this_uuid] = None
-            self.flow_cache_dict[this_uuid] = self.init_flow_cache()
         if source_speech_token.shape[1] == 0:
             p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
         else:
             p = threading.Thread(target=self.vc_job, args=(source_speech_token, this_uuid))
         p.start()
         if stream is True:
-            flow_prompt_speech_token = flow_prompt_speech_token[:, -int(self.flow_decoder_required_cache_size / self.flow.token_mel_ratio):]
-            prompt_speech_feat = prompt_speech_feat[:, -self.flow_decoder_required_cache_size:]
+            token_offset = 0
+            prompt_token_pad = int(np.ceil(flow_prompt_speech_token.shape[1] / self.token_hop_len) * self.token_hop_len - flow_prompt_speech_token.shape[1])
             while True:
                 time.sleep(0.1)
+                this_token_hop_len = self.token_hop_len + prompt_token_pad if token_offset == 0 else self.token_hop_len
+                if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= this_token_hop_len + self.flow.pre_lookahead_len:
+                    this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + this_token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0)
                     this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                                      prompt_token=flow_prompt_speech_token,
                                                      prompt_feat=prompt_speech_feat,
                                                      embedding=flow_embedding,
+                                                     token_offset=token_offset,
                                                      uuid=this_uuid,
+                                                     stream=stream,
                                                      finalize=False)
-                    flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int32).to(self.device)
-                    prompt_speech_feat = torch.zeros(1, 0, 80).to(self.device)
+                    token_offset += this_token_hop_len
                     yield {'tts_speech': this_tts_speech.cpu()}
-                    self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][self.token_hop_len:]
-                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < self.token_hop_len + self.flow.pre_lookahead_len:
+                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < this_token_hop_len + self.flow.pre_lookahead_len:
                     break
             p.join()
             # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
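The streaming branch above no longer trims prompt features to fit a flow cache; instead it tracks token_offset and pads the first hop so that the prompt tokens plus the first chunk land on a token_hop_len boundary. A small worked example of that scheduling arithmetic (pre_lookahead_len is an assumed value; the real one comes from self.flow):

    import numpy as np

    token_hop_len = 25      # matches self.token_hop_len set in __init__ above
    pre_lookahead_len = 3   # assumed; taken from self.flow.pre_lookahead_len at runtime
    prompt_len = 37         # example number of flow prompt tokens

    # pad the first hop so prompt + first chunk align to a hop boundary
    prompt_token_pad = int(np.ceil(prompt_len / token_hop_len) * token_hop_len - prompt_len)
    first_hop = token_hop_len + prompt_token_pad

    print(prompt_token_pad)                   # 13
    print(first_hop + pre_lookahead_len)      # 41 tokens needed before the first chunk
    print(token_hop_len + pre_lookahead_len)  # 28 new tokens beyond the offset per later chunk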
@@ -384,18 +360,19 @@ class CosyVoice2Model(CosyVoiceModel):
                                              prompt_token=flow_prompt_speech_token,
                                              prompt_feat=prompt_speech_feat,
                                              embedding=flow_embedding,
+                                             token_offset=token_offset,
                                              uuid=this_uuid,
                                              finalize=True)
             yield {'tts_speech': this_tts_speech.cpu()}
         else:
             # deal with all tokens
-            assert self.use_flow_cache is False, "set use_flow_cache=False for nonstream inference"
             p.join()
             this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
             this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                              prompt_token=flow_prompt_speech_token,
                                              prompt_feat=prompt_speech_feat,
                                              embedding=flow_embedding,
+                                             token_offset=0,
                                              uuid=this_uuid,
                                              finalize=True,
                                              speed=speed)
@@ -404,5 +381,6 @@ class CosyVoice2Model(CosyVoiceModel):
             self.tts_speech_token_dict.pop(this_uuid)
             self.llm_end_dict.pop(this_uuid)
             self.hift_cache_dict.pop(this_uuid)
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            torch.cuda.current_stream().synchronize()
The following hunks appear to come from xinference/thirdparty/cosyvoice/dataset/dataset.py.

@@ -14,14 +14,13 @@
 # limitations under the License.

 import random
-import json
 import math
 from functools import partial

 import torch
 import torch.distributed as dist
 from torch.utils.data import IterableDataset
-from cosyvoice.utils.file_utils import read_lists
+from cosyvoice.utils.file_utils import read_lists


 class Processor(IterableDataset):
@@ -127,10 +126,9 @@ def Dataset(data_list_file,
             data_pipeline,
             mode='train',
             gan=False,
+            dpo=False,
             shuffle=True,
-            partition=True
-            tts_file='',
-            prompt_utt2data=''):
+            partition=True):
     """ Construct dataset from arguments

     We have two shuffle stage in the Dataset. The first is global
@@ -142,23 +140,12 @@ def Dataset(data_list_file,
         tokenizer (BaseTokenizer): tokenizer to tokenize
         partition(bool): whether to do data partition in terms of rank
     """
-    assert mode in ['train', 'inference']
     lists = read_lists(data_list_file)
-    if mode == 'inference':
-        with open(tts_file) as f:
-            tts_data = json.load(f)
-        utt2lists = read_json_lists(prompt_utt2data)
-        # filter unnecessary file in inference mode
-        lists = list({utt2lists[utt] for utt in tts_data.keys() if utt2lists[utt] in lists})
     dataset = DataList(lists,
                        shuffle=shuffle,
                        partition=partition)
-        data_pipeline[0] = partial(data_pipeline[0], tts_data=tts_data)
-    if gan is True:
-        # map partial arg to padding func in gan mode
-        data_pipeline[-1] = partial(data_pipeline[-1], gan=gan)
+    # map partial arg to padding func
+    data_pipeline[-1] = partial(data_pipeline[-1], gan=gan, dpo=dpo)
     for func in data_pipeline:
         dataset = Processor(dataset, func, mode=mode)
     return dataset
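With the inference-only plumbing (tts_file, prompt_utt2data) gone, Dataset is now a pure training-data builder and simply forwards gan and the new dpo flag to the final padding stage. A hypothetical call might look like this (the list file name and data_pipeline are placeholders):

    # placeholders: 'train.data.list' and data_pipeline come from the training config
    train_dataset = Dataset('train.data.list',
                            data_pipeline=data_pipeline,
                            mode='train',
                            gan=False,
                            dpo=True,       # collate reject_speech_token for DPO-style training
                            shuffle=True,
                            partition=True)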
The remaining hunks appear to come from xinference/thirdparty/cosyvoice/dataset/processor.py.

@@ -43,8 +43,6 @@ def parquet_opener(data, mode='train', tts_data={}):
         for df in pq.ParquetFile(url).iter_batches(batch_size=64):
             df = df.to_pandas()
             for i in range(len(df)):
-                if mode == 'inference' and df.loc[i, 'utt'] not in tts_data:
-                    continue
                 sample.update(dict(df.loc[i]))
                 if mode == 'train':
                     # NOTE do not return sample directly, must initialize a new dict
@@ -100,6 +98,8 @@ def filter(data,
             continue
         if len(sample['speech_token']) == 0:
             continue
+        if 'reject_speech_token' in sample and len(sample['reject_speech_token']) == 0:
+            continue
         if num_frames != 0:
             if len(sample['text_token']) / num_frames < min_output_input_ratio:
                 continue
@@ -159,6 +159,7 @@ def truncate(data, truncate_length=24576, mode='train'):

 def compute_fbank(data,
                   feat_extractor,
+                  token_mel_ratio=0,
                   mode='train'):
     """ Extract fbank

@@ -174,8 +175,13 @@ def compute_fbank(data,
         assert 'utt' in sample
         assert 'text_token' in sample
         waveform = sample['speech']
+        feat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
+        if token_mel_ratio != 0:
+            # trim to align speech_token and speech_feat
+            token_len = int(min(feat.shape[0] / token_mel_ratio, sample["speech_token"].shape[0]))
+            feat = feat[:token_mel_ratio * token_len]
+            sample["speech_token"] = sample["speech_token"][:token_len]
+        sample['speech_feat'] = feat
         yield sample

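When token_mel_ratio is non-zero, compute_fbank trims both the fbank features and the speech tokens so that every retained token maps to exactly token_mel_ratio frames. Worked through with toy numbers:

    # toy numbers: 503 mel frames, 250 speech tokens, 2 mel frames per token (assumed ratio)
    feat_frames, n_tokens, token_mel_ratio = 503, 250, 2

    token_len = int(min(feat_frames / token_mel_ratio, n_tokens))   # min(251.5, 250) -> 250
    trimmed_frames = token_mel_ratio * token_len                    # 500
    print(token_len, trimmed_frames)                                # 250 500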
@@ -236,8 +242,6 @@ def tokenize(data, get_tokenizer, allowed_special, mode='train'):
     for sample in data:
         assert 'text' in sample
         sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
-        if mode == 'inference':
-            sample['tts_text_token'] = tokenizer.encode(sample['tts_text'], allowed_special=allowed_special)
         yield sample

@@ -345,18 +349,15 @@ def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
 def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'):
     """ Wrapper for static/dynamic batch
     """
-    if
-        return static_batch(data,
+    if batch_type == 'static':
+        return static_batch(data, batch_size)
+    elif batch_type == 'dynamic':
+        return dynamic_batch(data, max_frames_in_batch)
     else:
-            return static_batch(data, batch_size)
-        elif batch_type == 'dynamic':
-            return dynamic_batch(data, max_frames_in_batch)
-        else:
-            logging.fatal('Unsupported batch type {}'.format(batch_type))
+        logging.fatal('Unsupported batch type {}'.format(batch_type))


-def padding(data, use_spk_embedding, mode='train', gan=False):
+def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False):
     """ Padding the data into training data

     Args:
@@ -418,16 +419,14 @@ def padding(data, use_spk_embedding, mode='train', gan=False):
             # only gan train needs speech, delete it to save memory
             del batch["speech"]
             del batch["speech_len"]
-        if
-            batch
-                          'tts_text_token': tts_text_token,
-                          'tts_text_token_len': tts_text_token_len})
+        if dpo is True:
+            reject_speech_token = [torch.tensor(sample[i]['reject_speech_token']) for i in order]
+            reject_speech_token_len = torch.tensor([i.size(0) for i in reject_speech_token], dtype=torch.int32)
+            reject_speech_token = pad_sequence(reject_speech_token,
+                                               batch_first=True,
+                                               padding_value=0)
+            batch['reject_speech_token'] = reject_speech_token
+            batch['reject_speech_token_len'] = reject_speech_token_len
         if use_spk_embedding is True:
             batch["embedding"] = batch["spk_embedding"]
         else: