xinference 1.8.1rc1__py3-none-any.whl → 1.9.1__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic.

Files changed (108)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +2 -1
  3. xinference/core/model.py +8 -4
  4. xinference/core/supervisor.py +2 -3
  5. xinference/core/worker.py +7 -5
  6. xinference/deploy/cmdline.py +2 -0
  7. xinference/deploy/local.py +5 -0
  8. xinference/deploy/test/test_cmdline.py +1 -1
  9. xinference/deploy/worker.py +6 -0
  10. xinference/model/audio/cosyvoice.py +0 -1
  11. xinference/model/audio/model_spec.json +44 -20
  12. xinference/model/core.py +3 -0
  13. xinference/model/embedding/flag/core.py +5 -0
  14. xinference/model/embedding/llama_cpp/core.py +22 -19
  15. xinference/model/embedding/sentence_transformers/core.py +18 -4
  16. xinference/model/embedding/vllm/core.py +36 -9
  17. xinference/model/image/cache_manager.py +56 -0
  18. xinference/model/image/core.py +9 -0
  19. xinference/model/image/model_spec.json +178 -1
  20. xinference/model/image/stable_diffusion/core.py +155 -23
  21. xinference/model/llm/cache_manager.py +17 -3
  22. xinference/model/llm/harmony.py +245 -0
  23. xinference/model/llm/llama_cpp/core.py +41 -40
  24. xinference/model/llm/llm_family.json +688 -11
  25. xinference/model/llm/llm_family.py +1 -1
  26. xinference/model/llm/sglang/core.py +108 -5
  27. xinference/model/llm/transformers/core.py +20 -18
  28. xinference/model/llm/transformers/gemma3.py +1 -1
  29. xinference/model/llm/transformers/gpt_oss.py +91 -0
  30. xinference/model/llm/transformers/multimodal/core.py +1 -1
  31. xinference/model/llm/transformers/multimodal/gemma3.py +1 -1
  32. xinference/model/llm/transformers/multimodal/glm4_1v.py +2 -2
  33. xinference/model/llm/transformers/multimodal/ovis2.py +1 -1
  34. xinference/model/llm/transformers/multimodal/qwen-omni.py +7 -8
  35. xinference/model/llm/transformers/multimodal/qwen2_vl.py +9 -6
  36. xinference/model/llm/transformers/utils.py +1 -33
  37. xinference/model/llm/utils.py +61 -7
  38. xinference/model/llm/vllm/core.py +44 -8
  39. xinference/model/rerank/__init__.py +66 -23
  40. xinference/model/rerank/cache_manager.py +35 -0
  41. xinference/model/rerank/core.py +87 -339
  42. xinference/model/rerank/custom.py +33 -8
  43. xinference/model/rerank/model_spec.json +251 -212
  44. xinference/model/rerank/rerank_family.py +137 -0
  45. xinference/model/rerank/sentence_transformers/__init__.py +13 -0
  46. xinference/model/rerank/sentence_transformers/core.py +337 -0
  47. xinference/model/rerank/vllm/__init__.py +13 -0
  48. xinference/model/rerank/vllm/core.py +156 -0
  49. xinference/model/utils.py +108 -0
  50. xinference/model/video/model_spec.json +95 -1
  51. xinference/thirdparty/cosyvoice/bin/export_jit.py +3 -4
  52. xinference/thirdparty/cosyvoice/bin/export_onnx.py +49 -126
  53. xinference/thirdparty/cosyvoice/bin/{inference.py → inference_deprecated.py} +1 -0
  54. xinference/thirdparty/cosyvoice/bin/train.py +23 -3
  55. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +8 -4
  56. xinference/thirdparty/cosyvoice/cli/frontend.py +4 -4
  57. xinference/thirdparty/cosyvoice/cli/model.py +53 -75
  58. xinference/thirdparty/cosyvoice/dataset/dataset.py +5 -18
  59. xinference/thirdparty/cosyvoice/dataset/processor.py +24 -25
  60. xinference/thirdparty/cosyvoice/flow/decoder.py +24 -433
  61. xinference/thirdparty/cosyvoice/flow/flow.py +6 -14
  62. xinference/thirdparty/cosyvoice/flow/flow_matching.py +33 -145
  63. xinference/thirdparty/cosyvoice/hifigan/generator.py +169 -1
  64. xinference/thirdparty/cosyvoice/llm/llm.py +108 -17
  65. xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +14 -115
  66. xinference/thirdparty/cosyvoice/utils/common.py +20 -0
  67. xinference/thirdparty/cosyvoice/utils/executor.py +8 -4
  68. xinference/thirdparty/cosyvoice/utils/file_utils.py +45 -1
  69. xinference/thirdparty/cosyvoice/utils/losses.py +37 -0
  70. xinference/thirdparty/cosyvoice/utils/mask.py +35 -1
  71. xinference/thirdparty/cosyvoice/utils/train_utils.py +24 -6
  72. xinference/thirdparty/cosyvoice/vllm/cosyvoice2.py +103 -0
  73. xinference/types.py +2 -0
  74. xinference/ui/gradio/chat_interface.py +2 -0
  75. xinference/ui/gradio/media_interface.py +353 -7
  76. xinference/ui/web/ui/build/asset-manifest.json +3 -3
  77. xinference/ui/web/ui/build/index.html +1 -1
  78. xinference/ui/web/ui/build/static/js/main.1086c759.js +3 -0
  79. xinference/ui/web/ui/build/static/js/main.1086c759.js.map +1 -0
  80. xinference/ui/web/ui/node_modules/.cache/babel-loader/28012da921a51f1082549956d3ae82acd769a754b22afda9acddd98a4daf9ea4.json +1 -0
  81. xinference/ui/web/ui/node_modules/.cache/babel-loader/3c5758bd12fa334294b1de0ff6b1a4bac8d963c45472eab9dc3e530d82aa6b3f.json +1 -0
  82. xinference/ui/web/ui/node_modules/.cache/babel-loader/475936ebe725eca62a6f52ce182c06a19b2cef4df9545a05ed0591ee0c539d43.json +1 -0
  83. xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +1 -0
  84. xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +1 -0
  85. xinference/ui/web/ui/node_modules/.cache/babel-loader/aee5aaba26f2b1e816a3ea9efa68bad8b95695a3d80adcfd8dd57a7bb17ac71a.json +1 -0
  86. xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +1 -0
  87. xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +1 -0
  88. xinference/ui/web/ui/src/locales/en.json +2 -0
  89. xinference/ui/web/ui/src/locales/ja.json +2 -0
  90. xinference/ui/web/ui/src/locales/ko.json +2 -0
  91. xinference/ui/web/ui/src/locales/zh.json +2 -0
  92. {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/METADATA +15 -10
  93. {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/RECORD +98 -89
  94. xinference/ui/web/ui/build/static/js/main.b969199a.js +0 -3
  95. xinference/ui/web/ui/build/static/js/main.b969199a.js.map +0 -1
  96. xinference/ui/web/ui/node_modules/.cache/babel-loader/1409a96b9f9f9f5de99a89ab0f738f6da62b449521b0a8d3e4efcf7f5c23534d.json +0 -1
  97. xinference/ui/web/ui/node_modules/.cache/babel-loader/3d2a89f0eccc1f90fc5036c9a1d587c2120e6a6b128aae31d1db7d6bad52722b.json +0 -1
  98. xinference/ui/web/ui/node_modules/.cache/babel-loader/43b889c3a8e2634092ade463d52481c7c5581c72ded8f23bc5f012ea0ef8cea5.json +0 -1
  99. xinference/ui/web/ui/node_modules/.cache/babel-loader/5d47532fb42128280d87f57c8a0b02bc1930f7ef764aa7e90579247df18bba83.json +0 -1
  100. xinference/ui/web/ui/node_modules/.cache/babel-loader/830882bb275468a969614824a9ab8983f874b4581f2eb625e9c66426cdc65e5b.json +0 -1
  101. xinference/ui/web/ui/node_modules/.cache/babel-loader/8e5cb82c2ff3299c6a44563fe6b1c5515c9750613c51bb63abee0b1d70fc5019.json +0 -1
  102. xinference/ui/web/ui/node_modules/.cache/babel-loader/9df08abcb5a7c1e48a4eb25c5d5f5d7253ea6854a4397e6d74d1fd75a14acda1.json +0 -1
  103. xinference/ui/web/ui/node_modules/.cache/babel-loader/b99034986a06445701accc7a4914bb9320947435e8d4e15793392ca4f679316c.json +0 -1
  104. /xinference/ui/web/ui/build/static/js/{main.b969199a.js.LICENSE.txt → main.1086c759.js.LICENSE.txt} +0 -0
  105. {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/WHEEL +0 -0
  106. {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/entry_points.txt +0 -0
  107. {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/licenses/LICENSE +0 -0
  108. {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/top_level.txt +0 -0
xinference/thirdparty/cosyvoice/cli/model.py

@@ -1,4 +1,5 @@
 # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+# 2025 Alibaba Inc (authors: Xiang Lyu, Bofan Zhou)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -21,7 +22,8 @@ from torch.nn import functional as F
 from contextlib import nullcontext
 import uuid
 from cosyvoice.utils.common import fade_in_out
-from cosyvoice.utils.file_utils import convert_onnx_to_trt
+from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm
+from cosyvoice.utils.common import TrtContextWrapper


 class CosyVoiceModel:
@@ -80,30 +82,28 @@ class CosyVoiceModel:
         flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
         self.flow.encoder = flow_encoder

-    def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, fp16):
+    def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent, fp16):
         assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
-        if not os.path.exists(flow_decoder_estimator_model):
+        if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
             convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16)
-        if os.path.getsize(flow_decoder_estimator_model) == 0:
-            raise ValueError('{} is empty file, delete it and export again!'.format(flow_decoder_estimator_model))
         del self.flow.decoder.estimator
         import tensorrt as trt
         with open(flow_decoder_estimator_model, 'rb') as f:
-            self.flow.decoder.estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
-        assert self.flow.decoder.estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
-        self.flow.decoder.estimator = self.flow.decoder.estimator_engine.create_execution_context()
+            estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
+        assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
+        self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device)

     def get_trt_kwargs(self):
         min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
-        opt_shape = [(2, 80, 200), (2, 1, 200), (2, 80, 200), (2, 80, 200)]
+        opt_shape = [(2, 80, 500), (2, 1, 500), (2, 80, 500), (2, 80, 500)]
         max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)]
         input_names = ["x", "mask", "mu", "cond"]
         return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}

     def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
-        with self.llm_context, torch.cuda.amp.autocast(self.fp16):
+        with self.llm_context, torch.cuda.amp.autocast(self.fp16 is True and hasattr(self.llm, 'vllm') is False):
             if isinstance(text, Generator):
-                assert isinstance(self, CosyVoice2Model), 'streaming input text is only implemented for CosyVoice2!'
+                assert isinstance(self, CosyVoice2Model) and not hasattr(self.llm, 'vllm'), 'streaming input text is only implemented for CosyVoice2 and do not support vllm!'
                 for i in self.llm.inference_bistream(text=text,
                                                      prompt_text=prompt_text.to(self.device),
                                                      prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
@@ -118,7 +118,8 @@ class CosyVoiceModel:
                                            prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
                                            prompt_speech_token=llm_prompt_speech_token.to(self.device),
                                            prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
-                                           embedding=llm_embedding.to(self.device)):
+                                           embedding=llm_embedding.to(self.device),
+                                           uuid=uuid):
                     self.tts_speech_token_dict[uuid].append(i)
         self.llm_end_dict[uuid] = True

@@ -231,7 +232,9 @@ class CosyVoiceModel:
             self.mel_overlap_dict.pop(this_uuid)
             self.hift_cache_dict.pop(this_uuid)
             self.flow_cache_dict.pop(this_uuid)
-        torch.cuda.empty_cache()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            torch.cuda.current_stream().synchronize()


 class CosyVoice2Model(CosyVoiceModel):
@@ -240,20 +243,17 @@ class CosyVoice2Model(CosyVoiceModel):
                  llm: torch.nn.Module,
                  flow: torch.nn.Module,
                  hift: torch.nn.Module,
-                 fp16: bool = False,
-                 use_flow_cache: bool = False):
+                 fp16: bool = False):
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         self.llm = llm
         self.flow = flow
         self.hift = hift
         self.fp16 = fp16
-        self.use_flow_cache = use_flow_cache
         if self.fp16 is True:
             self.llm.half()
             self.flow.half()
-        # stream related params, check examples/libritts/cosyvoice2/conf/cosyvoice2.yaml
+        # NOTE must matching training static_chunk_size
         self.token_hop_len = 25
-        self.flow_decoder_required_cache_size = 0 if use_flow_cache is False else 1 * self.token_hop_len * self.flow.token_mel_ratio
         # hift cache
         self.mel_cache_len = 8
         self.source_cache_len = int(self.mel_cache_len * 480)
@@ -265,55 +265,35 @@ class CosyVoice2Model(CosyVoiceModel):
         # dict used to store session related variable
         self.tts_speech_token_dict = {}
         self.llm_end_dict = {}
-        self.flow_cache_dict = {}
         self.hift_cache_dict = {}

-    def init_flow_cache(self):
-        encoder_cache = {'offset': 0,
-                         'pre_lookahead_layer_conv2_cache': torch.zeros(1, 512, 2).to(self.device),
-                         'encoders_kv_cache': torch.zeros(6, 1, 8, 0, 64 * 2).to(self.device),
-                         'upsample_offset': 0,
-                         'upsample_conv_cache': torch.zeros(1, 512, 4).to(self.device),
-                         'upsample_kv_cache': torch.zeros(4, 1, 8, 0, 64 * 2).to(self.device)}
-        decoder_cache = {'offset': 0,
-                         'down_blocks_conv_cache': torch.zeros(10, 1, 2, 832, 2).to(self.device),
-                         'down_blocks_kv_cache': torch.zeros(10, 1, 4, 2, self.flow_decoder_required_cache_size, 512, 2).to(self.device),
-                         'mid_blocks_conv_cache': torch.zeros(10, 12, 2, 512, 2).to(self.device),
-                         'mid_blocks_kv_cache': torch.zeros(10, 12, 4, 2, self.flow_decoder_required_cache_size, 512, 2).to(self.device),
-                         'up_blocks_conv_cache': torch.zeros(10, 1, 2, 1024, 2).to(self.device),
-                         'up_blocks_kv_cache': torch.zeros(10, 1, 4, 2, self.flow_decoder_required_cache_size, 512, 2).to(self.device),
-                         'final_blocks_conv_cache': torch.zeros(10, 2, 256, 2).to(self.device)}
-        if self.fp16 is True:
-            for cache in [encoder_cache, decoder_cache]:
-                for k, v in cache.items():
-                    if isinstance(v, torch.Tensor):
-                        cache[k] = v.half()
-        cache = {'encoder_cache': encoder_cache, 'decoder_cache': decoder_cache}
-        return cache
-
     def load_jit(self, flow_encoder_model):
         flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
         self.flow.encoder = flow_encoder

-    def get_trt_kwargs(self):
-        min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4), (1, 4, 2, 0, 512, 2), (12, 4, 2, 0, 512, 2), (1, 4, 2, 0, 512, 2)]
-        opt_shape = [(2, 80, 200), (2, 1, 200), (2, 80, 200), (2, 80, 200), (1, 4, 2, 100, 512, 2), (12, 4, 2, 100, 512, 2), (1, 4, 2, 100, 512, 2)]
-        max_shape = [(2, 80, 1500), (2, 1, 1500), (2, 80, 1500), (2, 80, 1500), (1, 4, 2, 200, 512, 2), (12, 4, 2, 200, 512, 2), (1, 4, 2, 200, 512, 2)]
-        input_names = ["x", "mask", "mu", "cond", 'down_blocks_kv_cache', 'mid_blocks_kv_cache', 'up_blocks_kv_cache']
-        assert self.use_flow_cache is True, "get_trt_kwargs is set for flow cache mode. If you want to use trt with use_flow_cache=False, please set higher max_shape"
-        return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
+    def load_vllm(self, model_dir):
+        export_cosyvoice2_vllm(self.llm, model_dir, self.device)
+        from vllm import EngineArgs, LLMEngine
+        engine_args = EngineArgs(model=model_dir,
+                                 skip_tokenizer_init=True,
+                                 enable_prompt_embeds=True,
+                                 gpu_memory_utilization=0.2)
+        self.llm.vllm = LLMEngine.from_engine_args(engine_args)
+        self.llm.lock = threading.Lock()
+        del self.llm.llm.model.model.layers

-    def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
+    def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
         with torch.cuda.amp.autocast(self.fp16):
-            tts_mel, self.flow_cache_dict[uuid] = self.flow.inference(token=token.to(self.device),
-                                                                      token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
-                                                                      prompt_token=prompt_token.to(self.device),
-                                                                      prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
-                                                                      prompt_feat=prompt_feat.to(self.device),
-                                                                      prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
-                                                                      embedding=embedding.to(self.device),
-                                                                      cache=self.flow_cache_dict[uuid],
-                                                                      finalize=finalize)
+            tts_mel, _ = self.flow.inference(token=token.to(self.device),
+                                             token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
+                                             prompt_token=prompt_token.to(self.device),
+                                             prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
+                                             prompt_feat=prompt_feat.to(self.device),
+                                             prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
+                                             embedding=embedding.to(self.device),
+                                             streaming=stream,
+                                             finalize=finalize)
+            tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
         # append hift cache
         if self.hift_cache_dict[uuid] is not None:
             hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
@@ -348,34 +328,30 @@ class CosyVoice2Model(CosyVoiceModel):
         with self.lock:
             self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
             self.hift_cache_dict[this_uuid] = None
-            self.flow_cache_dict[this_uuid] = self.init_flow_cache()
         if source_speech_token.shape[1] == 0:
             p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
         else:
             p = threading.Thread(target=self.vc_job, args=(source_speech_token, this_uuid))
         p.start()
         if stream is True:
-            assert self.use_flow_cache is True, "set use_flow_cache=True if you want to use stream inference to avoid OOM"
-            # NOTE in cache mode, trim flow_prompt to same size as flow_decoder_required_cache_size
-            flow_prompt_speech_token = flow_prompt_speech_token[:, -int(self.flow_decoder_required_cache_size / self.flow.token_mel_ratio):]
-            prompt_speech_feat = prompt_speech_feat[:, -self.flow_decoder_required_cache_size:]
+            token_offset = 0
+            prompt_token_pad = int(np.ceil(flow_prompt_speech_token.shape[1] / self.token_hop_len) * self.token_hop_len - flow_prompt_speech_token.shape[1])
             while True:
                 time.sleep(0.1)
-                if len(self.tts_speech_token_dict[this_uuid]) >= self.token_hop_len + self.flow.pre_lookahead_len:
-                    this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:self.token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0)
+                this_token_hop_len = self.token_hop_len + prompt_token_pad if token_offset == 0 else self.token_hop_len
+                if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= this_token_hop_len + self.flow.pre_lookahead_len:
+                    this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + this_token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0)
                     this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                                      prompt_token=flow_prompt_speech_token,
                                                      prompt_feat=prompt_speech_feat,
                                                      embedding=flow_embedding,
+                                                     token_offset=token_offset,
                                                      uuid=this_uuid,
+                                                     stream=stream,
                                                      finalize=False)
-                    # NOTE in cache inference mode, we only use flow_prompt_speech_token/prompt_speech_feat in first chunk
-                    flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int32).to(self.device)
-                    prompt_speech_feat = torch.zeros(1, 0, 80).to(self.device)
+                    token_offset += this_token_hop_len
                     yield {'tts_speech': this_tts_speech.cpu()}
-                    with self.lock:
-                        self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][self.token_hop_len:]
-                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < self.token_hop_len + self.flow.pre_lookahead_len:
+                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < this_token_hop_len + self.flow.pre_lookahead_len:
                     break
             p.join()
             # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
@@ -384,18 +360,19 @@ class CosyVoice2Model(CosyVoiceModel):
                                              prompt_token=flow_prompt_speech_token,
                                              prompt_feat=prompt_speech_feat,
                                              embedding=flow_embedding,
+                                             token_offset=token_offset,
                                              uuid=this_uuid,
                                              finalize=True)
             yield {'tts_speech': this_tts_speech.cpu()}
         else:
             # deal with all tokens
-            assert self.use_flow_cache is False, "set use_flow_cache=False for nonstream inference"
             p.join()
             this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
             this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                              prompt_token=flow_prompt_speech_token,
                                              prompt_feat=prompt_speech_feat,
                                              embedding=flow_embedding,
+                                             token_offset=0,
                                              uuid=this_uuid,
                                              finalize=True,
                                              speed=speed)
@@ -404,5 +381,6 @@ class CosyVoice2Model(CosyVoiceModel):
             self.tts_speech_token_dict.pop(this_uuid)
             self.llm_end_dict.pop(this_uuid)
             self.hift_cache_dict.pop(this_uuid)
-            self.flow_cache_dict.pop(this_uuid)
-        torch.cuda.empty_cache()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            torch.cuda.current_stream().synchronize()
xinference/thirdparty/cosyvoice/dataset/dataset.py

@@ -14,14 +14,13 @@
 # limitations under the License.

 import random
-import json
 import math
 from functools import partial

 import torch
 import torch.distributed as dist
 from torch.utils.data import IterableDataset
-from cosyvoice.utils.file_utils import read_lists, read_json_lists
+from cosyvoice.utils.file_utils import read_lists


 class Processor(IterableDataset):
@@ -127,10 +126,9 @@ def Dataset(data_list_file,
             data_pipeline,
             mode='train',
             gan=False,
+            dpo=False,
             shuffle=True,
-            partition=True,
-            tts_file='',
-            prompt_utt2data=''):
+            partition=True):
     """ Construct dataset from arguments

     We have two shuffle stage in the Dataset. The first is global
@@ -142,23 +140,12 @@ def Dataset(data_list_file,
         tokenizer (BaseTokenizer): tokenizer to tokenize
         partition(bool): whether to do data partition in terms of rank
     """
-    assert mode in ['train', 'inference']
     lists = read_lists(data_list_file)
-    if mode == 'inference':
-        with open(tts_file) as f:
-            tts_data = json.load(f)
-        utt2lists = read_json_lists(prompt_utt2data)
-        # filter unnecessary file in inference mode
-        lists = list({utt2lists[utt] for utt in tts_data.keys() if utt2lists[utt] in lists})
     dataset = DataList(lists,
                        shuffle=shuffle,
                        partition=partition)
-    if mode == 'inference':
-        # map partial arg to parquet_opener func in inference mode
-        data_pipeline[0] = partial(data_pipeline[0], tts_data=tts_data)
-    if gan is True:
-        # map partial arg to padding func in gan mode
-        data_pipeline[-1] = partial(data_pipeline[-1], gan=gan)
+    # map partial arg to padding func
+    data_pipeline[-1] = partial(data_pipeline[-1], gan=gan, dpo=dpo)
     for func in data_pipeline:
         dataset = Processor(dataset, func, mode=mode)
     return dataset
xinference/thirdparty/cosyvoice/dataset/processor.py

@@ -43,8 +43,6 @@ def parquet_opener(data, mode='train', tts_data={}):
         for df in pq.ParquetFile(url).iter_batches(batch_size=64):
             df = df.to_pandas()
             for i in range(len(df)):
-                if mode == 'inference' and df.loc[i, 'utt'] not in tts_data:
-                    continue
                 sample.update(dict(df.loc[i]))
                 if mode == 'train':
                     # NOTE do not return sample directly, must initialize a new dict
@@ -100,6 +98,8 @@ def filter(data,
             continue
         if len(sample['speech_token']) == 0:
             continue
+        if 'reject_speech_token' in sample and len(sample['reject_speech_token']) == 0:
+            continue
         if num_frames != 0:
             if len(sample['text_token']) / num_frames < min_output_input_ratio:
                 continue
@@ -159,6 +159,7 @@ def truncate(data, truncate_length=24576, mode='train'):

 def compute_fbank(data,
                   feat_extractor,
+                  token_mel_ratio=0,
                   mode='train'):
     """ Extract fbank

@@ -174,8 +175,13 @@ def compute_fbank(data,
         assert 'utt' in sample
         assert 'text_token' in sample
         waveform = sample['speech']
-        mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
-        sample['speech_feat'] = mat
+        feat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
+        if token_mel_ratio != 0:
+            # trim to align speech_token and speech_feat
+            token_len = int(min(feat.shape[0] / token_mel_ratio, sample["speech_token"].shape[0]))
+            feat = feat[:token_mel_ratio * token_len]
+            sample["speech_token"] = sample["speech_token"][:token_len]
+        sample['speech_feat'] = feat
         yield sample


@@ -236,8 +242,6 @@ def tokenize(data, get_tokenizer, allowed_special, mode='train'):
     for sample in data:
         assert 'text' in sample
         sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
-        if mode == 'inference':
-            sample['tts_text_token'] = tokenizer.encode(sample['tts_text'], allowed_special=allowed_special)
         yield sample


@@ -345,18 +349,15 @@ def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
 def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'):
     """ Wrapper for static/dynamic batch
     """
-    if mode == 'inference':
-        return static_batch(data, 1)
+    if batch_type == 'static':
+        return static_batch(data, batch_size)
+    elif batch_type == 'dynamic':
+        return dynamic_batch(data, max_frames_in_batch)
     else:
-        if batch_type == 'static':
-            return static_batch(data, batch_size)
-        elif batch_type == 'dynamic':
-            return dynamic_batch(data, max_frames_in_batch)
-        else:
-            logging.fatal('Unsupported batch type {}'.format(batch_type))
+        logging.fatal('Unsupported batch type {}'.format(batch_type))


-def padding(data, use_spk_embedding, mode='train', gan=False):
+def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False):
     """ Padding the data into training data

     Args:
@@ -418,16 +419,14 @@ def padding(data, use_spk_embedding, mode='train', gan=False):
         # only gan train needs speech, delete it to save memory
         del batch["speech"]
         del batch["speech_len"]
-        if mode == 'inference':
-            tts_text = [sample[i]['tts_text'] for i in order]
-            tts_index = [sample[i]['tts_index'] for i in order]
-            tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order]
-            tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32)
-            tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1)
-            batch.update({'tts_text': tts_text,
-                          'tts_index': tts_index,
-                          'tts_text_token': tts_text_token,
-                          'tts_text_token_len': tts_text_token_len})
+        if dpo is True:
+            reject_speech_token = [torch.tensor(sample[i]['reject_speech_token']) for i in order]
+            reject_speech_token_len = torch.tensor([i.size(0) for i in reject_speech_token], dtype=torch.int32)
+            reject_speech_token = pad_sequence(reject_speech_token,
+                                               batch_first=True,
+                                               padding_value=0)
+            batch['reject_speech_token'] = reject_speech_token
+            batch['reject_speech_token_len'] = reject_speech_token_len
         if use_spk_embedding is True:
             batch["embedding"] = batch["spk_embedding"]
         else: