xinference 1.9.0__py3-none-any.whl → 1.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (92)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +415 -1
  3. xinference/constants.py +2 -0
  4. xinference/core/model.py +3 -4
  5. xinference/core/supervisor.py +29 -1
  6. xinference/core/worker.py +4 -1
  7. xinference/deploy/cmdline.py +2 -0
  8. xinference/deploy/test/test_cmdline.py +1 -1
  9. xinference/model/audio/core.py +5 -0
  10. xinference/model/audio/cosyvoice.py +0 -1
  11. xinference/model/audio/kokoro.py +1 -1
  12. xinference/model/audio/kokoro_zh.py +124 -0
  13. xinference/model/audio/model_spec.json +64 -20
  14. xinference/model/embedding/flag/core.py +5 -0
  15. xinference/model/embedding/llama_cpp/core.py +22 -19
  16. xinference/model/embedding/sentence_transformers/core.py +19 -4
  17. xinference/model/embedding/vllm/core.py +40 -8
  18. xinference/model/image/cache_manager.py +56 -0
  19. xinference/model/image/core.py +9 -0
  20. xinference/model/image/model_spec.json +116 -9
  21. xinference/model/image/stable_diffusion/core.py +141 -31
  22. xinference/model/llm/core.py +10 -0
  23. xinference/model/llm/llama_cpp/core.py +42 -40
  24. xinference/model/llm/llm_family.json +435 -23
  25. xinference/model/llm/llm_family.py +1 -0
  26. xinference/model/llm/mlx/core.py +52 -33
  27. xinference/model/llm/sglang/core.py +2 -44
  28. xinference/model/llm/tool_parsers/__init__.py +58 -0
  29. xinference/model/llm/tool_parsers/abstract_tool_parser.py +33 -0
  30. xinference/model/llm/tool_parsers/deepseek_r1_tool_parser.py +128 -0
  31. xinference/model/llm/tool_parsers/deepseek_v3_tool_parser.py +145 -0
  32. xinference/model/llm/tool_parsers/glm4_tool_parser.py +123 -0
  33. xinference/model/llm/tool_parsers/llama3_tool_parser.py +77 -0
  34. xinference/model/llm/tool_parsers/qwen_tool_parser.py +320 -0
  35. xinference/model/llm/transformers/core.py +6 -12
  36. xinference/model/llm/utils.py +128 -46
  37. xinference/model/llm/vllm/core.py +8 -61
  38. xinference/model/rerank/core.py +3 -0
  39. xinference/model/rerank/sentence_transformers/core.py +1 -1
  40. xinference/model/rerank/vllm/core.py +56 -6
  41. xinference/model/utils.py +1 -2
  42. xinference/model/video/model_spec.json +95 -1
  43. xinference/thirdparty/cosyvoice/bin/export_jit.py +3 -4
  44. xinference/thirdparty/cosyvoice/bin/export_onnx.py +49 -126
  45. xinference/thirdparty/cosyvoice/bin/{inference.py → inference_deprecated.py} +1 -0
  46. xinference/thirdparty/cosyvoice/bin/train.py +23 -3
  47. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +8 -4
  48. xinference/thirdparty/cosyvoice/cli/frontend.py +4 -4
  49. xinference/thirdparty/cosyvoice/cli/model.py +53 -75
  50. xinference/thirdparty/cosyvoice/dataset/dataset.py +5 -18
  51. xinference/thirdparty/cosyvoice/dataset/processor.py +24 -25
  52. xinference/thirdparty/cosyvoice/flow/decoder.py +24 -433
  53. xinference/thirdparty/cosyvoice/flow/flow.py +6 -14
  54. xinference/thirdparty/cosyvoice/flow/flow_matching.py +33 -145
  55. xinference/thirdparty/cosyvoice/hifigan/generator.py +169 -1
  56. xinference/thirdparty/cosyvoice/llm/llm.py +108 -17
  57. xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +14 -115
  58. xinference/thirdparty/cosyvoice/utils/common.py +20 -0
  59. xinference/thirdparty/cosyvoice/utils/executor.py +8 -4
  60. xinference/thirdparty/cosyvoice/utils/file_utils.py +45 -1
  61. xinference/thirdparty/cosyvoice/utils/losses.py +37 -0
  62. xinference/thirdparty/cosyvoice/utils/mask.py +35 -1
  63. xinference/thirdparty/cosyvoice/utils/train_utils.py +24 -6
  64. xinference/thirdparty/cosyvoice/vllm/cosyvoice2.py +103 -0
  65. xinference/types.py +105 -2
  66. xinference/ui/gradio/chat_interface.py +2 -0
  67. xinference/ui/gradio/media_interface.py +353 -7
  68. xinference/ui/web/ui/build/asset-manifest.json +3 -3
  69. xinference/ui/web/ui/build/index.html +1 -1
  70. xinference/ui/web/ui/build/static/js/main.1086c759.js +3 -0
  71. xinference/ui/web/ui/build/static/js/main.1086c759.js.map +1 -0
  72. xinference/ui/web/ui/node_modules/.cache/babel-loader/3c5758bd12fa334294b1de0ff6b1a4bac8d963c45472eab9dc3e530d82aa6b3f.json +1 -0
  73. xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +1 -0
  74. xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +1 -0
  75. xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +1 -0
  76. xinference/ui/web/ui/src/locales/en.json +2 -0
  77. xinference/ui/web/ui/src/locales/ja.json +2 -0
  78. xinference/ui/web/ui/src/locales/ko.json +2 -0
  79. xinference/ui/web/ui/src/locales/zh.json +2 -0
  80. {xinference-1.9.0.dist-info → xinference-1.10.0.dist-info}/METADATA +16 -12
  81. {xinference-1.9.0.dist-info → xinference-1.10.0.dist-info}/RECORD +86 -77
  82. xinference/ui/web/ui/build/static/js/main.4918643a.js +0 -3
  83. xinference/ui/web/ui/build/static/js/main.4918643a.js.map +0 -1
  84. xinference/ui/web/ui/node_modules/.cache/babel-loader/3d2a89f0eccc1f90fc5036c9a1d587c2120e6a6b128aae31d1db7d6bad52722b.json +0 -1
  85. xinference/ui/web/ui/node_modules/.cache/babel-loader/89179f8f51887b9167721860a12412549ff04f78162e921a7b6aa6532646deb2.json +0 -1
  86. xinference/ui/web/ui/node_modules/.cache/babel-loader/8e5cb82c2ff3299c6a44563fe6b1c5515c9750613c51bb63abee0b1d70fc5019.json +0 -1
  87. xinference/ui/web/ui/node_modules/.cache/babel-loader/9dc5cfc67dd0617b0272aeef8651f1589b2155a4ff1fd72ad3166b217089b619.json +0 -1
  88. /xinference/ui/web/ui/build/static/js/{main.4918643a.js.LICENSE.txt → main.1086c759.js.LICENSE.txt} +0 -0
  89. {xinference-1.9.0.dist-info → xinference-1.10.0.dist-info}/WHEEL +0 -0
  90. {xinference-1.9.0.dist-info → xinference-1.10.0.dist-info}/entry_points.txt +0 -0
  91. {xinference-1.9.0.dist-info → xinference-1.10.0.dist-info}/licenses/LICENSE +0 -0
  92. {xinference-1.9.0.dist-info → xinference-1.10.0.dist-info}/top_level.txt +0 -0

xinference/thirdparty/cosyvoice/flow/flow_matching.py

@@ -1,4 +1,5 @@
  # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
+ # 2025 Alibaba Inc (authors: Xiang Lyu, Bofan Zhou)
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -11,10 +12,10 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
- import threading
  import torch
  import torch.nn.functional as F
  from matcha.models.components.flow_matching import BASECFM
+ from cosyvoice.utils.common import set_all_random_seed


  class ConditionalCFM(BASECFM):
@@ -31,7 +32,6 @@ class ConditionalCFM(BASECFM):
  in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
  # Just change the architecture of the estimator here
  self.estimator = estimator
- self.lock = threading.Lock()

  @torch.inference_mode()
  def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, cache=torch.zeros(1, 80, 0, 2)):
@@ -68,7 +68,7 @@ class ConditionalCFM(BASECFM):
  t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
  return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), cache

- def solve_euler(self, x, t_span, mu, mask, spks, cond):
+ def solve_euler(self, x, t_span, mu, mask, spks, cond, streaming=False):
  """
  Fixed euler solver for ODEs.
  Args:
@@ -109,7 +109,8 @@ class ConditionalCFM(BASECFM):
  x_in, mask_in,
  mu_in, t_in,
  spks_in,
- cond_in
+ cond_in,
+ streaming
  )
  dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
  dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
@@ -121,25 +122,33 @@
  return sol[-1].float()

- def forward_estimator(self, x, mask, mu, t, spks, cond):
+ def forward_estimator(self, x, mask, mu, t, spks, cond, streaming=False):
  if isinstance(self.estimator, torch.nn.Module):
- return self.estimator(x, mask, mu, t, spks, cond)
+ return self.estimator(x, mask, mu, t, spks, cond, streaming=streaming)
  else:
- with self.lock:
- self.estimator.set_input_shape('x', (2, 80, x.size(2)))
- self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
- self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
- self.estimator.set_input_shape('t', (2,))
- self.estimator.set_input_shape('spks', (2, 80))
- self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
+ [estimator, stream], trt_engine = self.estimator.acquire_estimator()
+ # NOTE need to synchronize when switching stream
+ torch.cuda.current_stream().synchronize()
+ with stream:
+ estimator.set_input_shape('x', (2, 80, x.size(2)))
+ estimator.set_input_shape('mask', (2, 1, x.size(2)))
+ estimator.set_input_shape('mu', (2, 80, x.size(2)))
+ estimator.set_input_shape('t', (2,))
+ estimator.set_input_shape('spks', (2, 80))
+ estimator.set_input_shape('cond', (2, 80, x.size(2)))
+ data_ptrs = [x.contiguous().data_ptr(),
+ mask.contiguous().data_ptr(),
+ mu.contiguous().data_ptr(),
+ t.contiguous().data_ptr(),
+ spks.contiguous().data_ptr(),
+ cond.contiguous().data_ptr(),
+ x.data_ptr()]
+ for i, j in enumerate(data_ptrs):
+ estimator.set_tensor_address(trt_engine.get_tensor_name(i), j)
  # run trt engine
- assert self.estimator.execute_v2([x.contiguous().data_ptr(),
- mask.contiguous().data_ptr(),
- mu.contiguous().data_ptr(),
- t.contiguous().data_ptr(),
- spks.contiguous().data_ptr(),
- cond.contiguous().data_ptr(),
- x.data_ptr()]) is True
+ assert estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
+ torch.cuda.current_stream().synchronize()
+ self.estimator.release_estimator(estimator, stream)
  return x

  def compute_loss(self, x1, mask, mu, spks=None, cond=None, streaming=False):
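
The rewritten TensorRT branch above no longer serializes calls behind a single threading.Lock; instead it expects self.estimator to be a small pool that exposes acquire_estimator() / release_estimator() and hands out an execution context paired with its own CUDA stream. That wrapper is not part of this diff; a minimal sketch of what such a pool could look like (class name, pool size, and structure are assumptions) is:

    import queue
    import torch

    class TrtEstimatorPool:
        """Hypothetical pool pairing TensorRT execution contexts with CUDA streams.

        Illustrative sketch only; the real CosyVoice wrapper may differ.
        """

        def __init__(self, trt_engine, num_contexts: int = 2):
            self.trt_engine = trt_engine
            self._free = queue.Queue()
            for _ in range(num_contexts):
                context = trt_engine.create_execution_context()
                stream_ctx = torch.cuda.stream(torch.cuda.Stream())
                self._free.put((context, stream_ctx))

        def acquire_estimator(self):
            # Blocks until a (context, stream) pair is free; also returns the engine
            # so callers can look up tensor names for set_tensor_address().
            return self._free.get(), self.trt_engine

        def release_estimator(self, context, stream_ctx):
            self._free.put((context, stream_ctx))
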
@@ -187,10 +196,11 @@ class ConditionalCFM(BASECFM):
  class CausalConditionalCFM(ConditionalCFM):
  def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
  super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator)
+ set_all_random_seed(0)
  self.rand_noise = torch.randn([1, 80, 50 * 300])

  @torch.inference_mode()
- def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, cache={}):
+ def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, streaming=False):
  """Forward diffusion

  Args:
@@ -209,131 +219,9 @@ class CausalConditionalCFM(ConditionalCFM):
  shape: (batch_size, n_feats, mel_timesteps)
  """

- offset = cache.pop('offset')
- z = self.rand_noise[:, :, :mu.size(2) + offset].to(mu.device).to(mu.dtype) * temperature
- z = z[:, :, offset:]
- offset += mu.size(2)
+ z = self.rand_noise[:, :, :mu.size(2)].to(mu.device).to(mu.dtype) * temperature
  # fix prompt and overlap part mu and z
  t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
  if self.t_scheduler == 'cosine':
  t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
- mel, cache = self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond, cache=cache)
- cache['offset'] = offset
- return mel, cache
-
- def solve_euler(self, x, t_span, mu, mask, spks, cond, cache):
- """
- Fixed euler solver for ODEs.
- Args:
- x (torch.Tensor): random noise
- t_span (torch.Tensor): n_timesteps interpolated
- shape: (n_timesteps + 1,)
- mu (torch.Tensor): output of encoder
- shape: (batch_size, n_feats, mel_timesteps)
- mask (torch.Tensor): output_mask
- shape: (batch_size, 1, mel_timesteps)
- spks (torch.Tensor, optional): speaker ids. Defaults to None.
- shape: (batch_size, spk_emb_dim)
- cond: Not used but kept for future purposes
- """
- t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
- t = t.unsqueeze(dim=0)
-
- # I am storing this because I can later plot it by putting a debugger here and saving it to a file
- # Or in future might add like a return_all_steps flag
- sol = []
-
- # Do not use concat, it may cause memory format changed and trt infer with wrong results!
- x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
- mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
- mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
- t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
- spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
- cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
- flow_cache_size = cache['down_blocks_kv_cache'].shape[4]
- for step in range(1, len(t_span)):
- # Classifier-Free Guidance inference introduced in VoiceBox
- x_in[:] = x
- mask_in[:] = mask
- mu_in[0] = mu
- t_in[:] = t.unsqueeze(0)
- spks_in[0] = spks
- cond_in[0] = cond
- cache_step = {k: v[step - 1] for k, v in cache.items()}
- dphi_dt, cache_step = self.forward_estimator(
- x_in, mask_in,
- mu_in, t_in,
- spks_in,
- cond_in,
- cache_step
- )
- # NOTE if smaller than flow_cache_size, means last chunk, no need to cache
- if flow_cache_size != 0 and x_in.shape[2] >= flow_cache_size:
- cache['down_blocks_conv_cache'][step - 1] = cache_step[0]
- cache['down_blocks_kv_cache'][step - 1] = cache_step[1][:, :, :, -flow_cache_size:]
- cache['mid_blocks_conv_cache'][step - 1] = cache_step[2]
- cache['mid_blocks_kv_cache'][step - 1] = cache_step[3][:, :, :, -flow_cache_size:]
- cache['up_blocks_conv_cache'][step - 1] = cache_step[4]
- cache['up_blocks_kv_cache'][step - 1] = cache_step[5][:, :, :, -flow_cache_size:]
- cache['final_blocks_conv_cache'][step - 1] = cache_step[6]
- dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
- dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
- x = x + dt * dphi_dt
- t = t + dt
- sol.append(x)
- if step < len(t_span) - 1:
- dt = t_span[step + 1] - t
- return sol[-1].float(), cache
-
- def forward_estimator(self, x, mask, mu, t, spks, cond, cache):
- if isinstance(self.estimator, torch.nn.Module):
- x, cache1, cache2, cache3, cache4, cache5, cache6, cache7 = self.estimator.forward_chunk(x, mask, mu, t, spks, cond, **cache)
- cache = (cache1, cache2, cache3, cache4, cache5, cache6, cache7)
- else:
- with self.lock:
- self.estimator.set_input_shape('x', (2, 80, x.size(2)))
- self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
- self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
- self.estimator.set_input_shape('t', (2,))
- self.estimator.set_input_shape('spks', (2, 80))
- self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
- self.estimator.set_input_shape('down_blocks_conv_cache', cache['down_blocks_conv_cache'].shape)
- self.estimator.set_input_shape('down_blocks_kv_cache', cache['down_blocks_kv_cache'].shape)
- self.estimator.set_input_shape('mid_blocks_conv_cache', cache['mid_blocks_conv_cache'].shape)
- self.estimator.set_input_shape('mid_blocks_kv_cache', cache['mid_blocks_kv_cache'].shape)
- self.estimator.set_input_shape('up_blocks_conv_cache', cache['up_blocks_conv_cache'].shape)
- self.estimator.set_input_shape('up_blocks_kv_cache', cache['up_blocks_kv_cache'].shape)
- self.estimator.set_input_shape('final_blocks_conv_cache', cache['final_blocks_conv_cache'].shape)
- # run trt engine
- down_blocks_kv_cache_out = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x)
- mid_blocks_kv_cache_out = torch.zeros(12, 4, 2, x.size(2), 512, 2).to(x)
- up_blocks_kv_cache_out = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x)
- assert self.estimator.execute_v2([x.contiguous().data_ptr(),
- mask.contiguous().data_ptr(),
- mu.contiguous().data_ptr(),
- t.contiguous().data_ptr(),
- spks.contiguous().data_ptr(),
- cond.contiguous().data_ptr(),
- cache['down_blocks_conv_cache'].contiguous().data_ptr(),
- cache['down_blocks_kv_cache'].contiguous().data_ptr(),
- cache['mid_blocks_conv_cache'].contiguous().data_ptr(),
- cache['mid_blocks_kv_cache'].contiguous().data_ptr(),
- cache['up_blocks_conv_cache'].contiguous().data_ptr(),
- cache['up_blocks_kv_cache'].contiguous().data_ptr(),
- cache['final_blocks_conv_cache'].contiguous().data_ptr(),
- x.data_ptr(),
- cache['down_blocks_conv_cache'].data_ptr(),
- down_blocks_kv_cache_out.data_ptr(),
- cache['mid_blocks_conv_cache'].data_ptr(),
- mid_blocks_kv_cache_out.data_ptr(),
- cache['up_blocks_conv_cache'].data_ptr(),
- up_blocks_kv_cache_out.data_ptr(),
- cache['final_blocks_conv_cache'].data_ptr()]) is True
- cache = (cache['down_blocks_conv_cache'],
- down_blocks_kv_cache_out,
- cache['mid_blocks_conv_cache'],
- mid_blocks_kv_cache_out,
- cache['up_blocks_conv_cache'],
- up_blocks_kv_cache_out,
- cache['final_blocks_conv_cache'])
- return x, cache
+ return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond, streaming=streaming), None
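
With the chunked KV/conv cache plumbing removed, CausalConditionalCFM.forward reduces to drawing a fixed noise slice, building a cosine time grid, and running the parent class's fixed-step Euler solve with classifier-free guidance. A self-contained sketch of that solver loop follows (illustrative only: estimator is any callable here, and the real solve_euler batches the conditional and unconditional branches into a single estimator call instead of two):

    import torch

    def euler_cfg_solve(estimator, z, mu, mask, spks, cond, n_timesteps: int, cfg_rate: float = 0.7):
        # Cosine t-scheduler, as in the diff: t' = 1 - cos(t * pi / 2).
        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
        t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
        x, t = z, t_span[0].unsqueeze(dim=0)
        dt = t_span[1] - t_span[0]
        for step in range(1, len(t_span)):
            # Classifier-free guidance: mix conditional and unconditional velocities.
            v_cond = estimator(x, mask, mu, t, spks, cond)
            v_uncond = estimator(x, mask, torch.zeros_like(mu), t, torch.zeros_like(spks), torch.zeros_like(cond))
            v = (1.0 + cfg_rate) * v_cond - cfg_rate * v_uncond
            x = x + dt * v
            t = t + dt
            if step < len(t_span) - 1:
                dt = t_span[step + 1] - t
        return x.float()
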

xinference/thirdparty/cosyvoice/hifigan/generator.py

@@ -223,6 +223,172 @@ class SourceModuleHnNSF(torch.nn.Module):
  return sine_merge, noise, uv


+ class SineGen2(torch.nn.Module):
+ """ Definition of sine generator
+ SineGen(samp_rate, harmonic_num = 0,
+ sine_amp = 0.1, noise_std = 0.003,
+ voiced_threshold = 0,
+ flag_for_pulse=False)
+ samp_rate: sampling rate in Hz
+ harmonic_num: number of harmonic overtones (default 0)
+ sine_amp: amplitude of sine-wavefrom (default 0.1)
+ noise_std: std of Gaussian noise (default 0.003)
+ voiced_thoreshold: F0 threshold for U/V classification (default 0)
+ flag_for_pulse: this SinGen is used inside PulseGen (default False)
+ Note: when flag_for_pulse is True, the first time step of a voiced
+ segment is always sin(np.pi) or cos(0)
+ """
+
+ def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
+ sine_amp=0.1, noise_std=0.003,
+ voiced_threshold=0,
+ flag_for_pulse=False):
+ super(SineGen2, self).__init__()
+ self.sine_amp = sine_amp
+ self.noise_std = noise_std
+ self.harmonic_num = harmonic_num
+ self.dim = self.harmonic_num + 1
+ self.sampling_rate = samp_rate
+ self.voiced_threshold = voiced_threshold
+ self.flag_for_pulse = flag_for_pulse
+ self.upsample_scale = upsample_scale
+
+ def _f02uv(self, f0):
+ # generate uv signal
+ uv = (f0 > self.voiced_threshold).type(torch.float32)
+ return uv
+
+ def _f02sine(self, f0_values):
+ """ f0_values: (batchsize, length, dim)
+ where dim indicates fundamental tone and overtones
+ """
+ # convert to F0 in rad. The interger part n can be ignored
+ # because 2 * np.pi * n doesn't affect phase
+ rad_values = (f0_values / self.sampling_rate) % 1
+
+ # initial phase noise (no noise for fundamental component)
+ rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], device=f0_values.device)
+ rand_ini[:, 0] = 0
+ rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
+
+ # instantanouse phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
+ if not self.flag_for_pulse:
+ rad_values = torch.nn.functional.interpolate(rad_values.transpose(1, 2),
+ scale_factor=1 / self.upsample_scale,
+ mode="linear").transpose(1, 2)
+
+ phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
+ phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale,
+ scale_factor=self.upsample_scale, mode="linear").transpose(1, 2)
+ sines = torch.sin(phase)
+ else:
+ # If necessary, make sure that the first time step of every
+ # voiced segments is sin(pi) or cos(0)
+ # This is used for pulse-train generation
+
+ # identify the last time step in unvoiced segments
+ uv = self._f02uv(f0_values)
+ uv_1 = torch.roll(uv, shifts=-1, dims=1)
+ uv_1[:, -1, :] = 1
+ u_loc = (uv < 1) * (uv_1 > 0)
+
+ # get the instantanouse phase
+ tmp_cumsum = torch.cumsum(rad_values, dim=1)
+ # different batch needs to be processed differently
+ for idx in range(f0_values.shape[0]):
+ temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
+ temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
+ # stores the accumulation of i.phase within
+ # each voiced segments
+ tmp_cumsum[idx, :, :] = 0
+ tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
+
+ # rad_values - tmp_cumsum: remove the accumulation of i.phase
+ # within the previous voiced segment.
+ i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
+
+ # get the sines
+ sines = torch.cos(i_phase * 2 * np.pi)
+ return sines
+
+ def forward(self, f0):
+ """ sine_tensor, uv = forward(f0)
+ input F0: tensor(batchsize=1, length, dim=1)
+ f0 for unvoiced steps should be 0
+ output sine_tensor: tensor(batchsize=1, length, dim)
+ output uv: tensor(batchsize=1, length, 1)
+ """
+ # fundamental component
+ fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))
+
+ # generate sine waveforms
+ sine_waves = self._f02sine(fn) * self.sine_amp
+
+ # generate uv signal
+ uv = self._f02uv(f0)
+
+ # noise: for unvoiced should be similar to sine_amp
+ # std = self.sine_amp/3 -> max value ~ self.sine_amp
+ # . for voiced regions is self.noise_std
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
+ noise = noise_amp * torch.randn_like(sine_waves)
+
+ # first: set the unvoiced part to 0 by uv
+ # then: additive noise
+ sine_waves = sine_waves * uv + noise
+ return sine_waves, uv, noise
+
+
+ class SourceModuleHnNSF2(torch.nn.Module):
+ """ SourceModule for hn-nsf
+ SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
+ add_noise_std=0.003, voiced_threshod=0)
+ sampling_rate: sampling_rate in Hz
+ harmonic_num: number of harmonic above F0 (default: 0)
+ sine_amp: amplitude of sine source signal (default: 0.1)
+ add_noise_std: std of additive Gaussian noise (default: 0.003)
+ note that amplitude of noise in unvoiced is decided
+ by sine_amp
+ voiced_threshold: threhold to set U/V given F0 (default: 0)
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
+ F0_sampled (batchsize, length, 1)
+ Sine_source (batchsize, length, 1)
+ noise_source (batchsize, length 1)
+ uv (batchsize, length, 1)
+ """
+
+ def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
+ add_noise_std=0.003, voiced_threshod=0):
+ super(SourceModuleHnNSF2, self).__init__()
+
+ self.sine_amp = sine_amp
+ self.noise_std = add_noise_std
+
+ # to produce sine waveforms
+ self.l_sin_gen = SineGen2(sampling_rate, upsample_scale, harmonic_num,
+ sine_amp, add_noise_std, voiced_threshod)
+
+ # to merge source harmonics into a single excitation
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
+ self.l_tanh = torch.nn.Tanh()
+
+ def forward(self, x):
+ """
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
+ F0_sampled (batchsize, length, 1)
+ Sine_source (batchsize, length, 1)
+ noise_source (batchsize, length 1)
+ """
+ # source for harmonic branch
+ with torch.no_grad():
+ sine_wavs, uv, _ = self.l_sin_gen(x)
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
+
+ # source for noise branch, in the same shape as uv
+ noise = torch.randn_like(uv) * self.sine_amp / 3
+ return sine_merge, noise, uv
+
+
  class HiFTGenerator(nn.Module):
  """
  HiFTNet Generator: Neural Source Filter + ISTFTNet
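
The new SineGen2 / SourceModuleHnNSF2 pair mirrors the classic harmonic-plus-noise NSF source, with the extra upsample_scale argument controlling the down/up interpolation of the phase accumulator. A small usage sketch (shapes follow the docstrings above; the 24 kHz rate, upsample_scale=480, and harmonic_num=8 are assumed example values, and the classes are taken from the hunk above, e.g. importable from xinference.thirdparty.cosyvoice.hifigan.generator):

    import torch

    # Assumes SineGen2 / SourceModuleHnNSF2 defined above are in scope.
    source = SourceModuleHnNSF2(sampling_rate=24000, upsample_scale=480, harmonic_num=8)

    # F0 contour already upsampled to waveform resolution: (batch, length, 1), in Hz,
    # with 0 marking unvoiced frames. Length should be a multiple of upsample_scale.
    f0 = torch.abs(torch.randn(1, 4800, 1)) * 200.0

    sine_merge, noise, uv = source(f0)
    print(sine_merge.shape, noise.shape, uv.shape)  # each torch.Size([1, 4800, 1])
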
@@ -259,7 +425,9 @@ class HiFTGenerator(nn.Module):

  self.num_kernels = len(resblock_kernel_sizes)
  self.num_upsamples = len(upsample_rates)
- self.m_source = SourceModuleHnNSF(
+ # NOTE in CosyVoice2, we use the original SourceModuleHnNSF implementation
+ this_SourceModuleHnNSF = SourceModuleHnNSF if self.sampling_rate == 22050 else SourceModuleHnNSF2
+ self.m_source = this_SourceModuleHnNSF(
  sampling_rate=sampling_rate,
  upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
  harmonic_num=nb_harmonics,

xinference/thirdparty/cosyvoice/llm/llm.py

@@ -1,4 +1,5 @@
  # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
+ # 2025 Alibaba Inc (authors: Xiang Lyu, Yabin Li, Qihua, Shengqiang Li)
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -11,7 +12,10 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
+ import queue
  import random
+ import time
+ import threading
  from typing import Dict, Optional, Callable, List, Generator
  import torch
  from torch import nn
@@ -170,6 +174,7 @@ class TransformerLM(torch.nn.Module):
  sampling: int = 25,
  max_token_text_ratio: float = 20,
  min_token_text_ratio: float = 2,
+ uuid: str = '',
  ) -> Generator[torch.Tensor, None, None]:
  device = text.device
  text = torch.concat([prompt_text, text], dim=1)
@@ -270,7 +275,6 @@ class Qwen2LM(TransformerLM):
  self.llm_input_size = llm_input_size
  self.llm_output_size = llm_output_size
  self.speech_token_size = speech_token_size
-
  # 2. build speech token language model related modules
  self.sos_eos = 0
  self.task_id = 1
@@ -293,6 +297,10 @@ class Qwen2LM(TransformerLM):
  self.sampling = sampling
  self.mix_ratio = mix_ratio

+ # 5. vllm related
+ self.stop_token_ids = [speech_token_size + i for i in range(3)]
+ self.vllm_output_queue = {}
+
  def prepare_lm_input_target(self, text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len):
  lm_target, lm_input = [], []
  text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
@@ -369,6 +377,53 @@ class Qwen2LM(TransformerLM):
  acc = th_accuracy(logits.view(-1, self.speech_token_size + 3), lm_target, ignore_label=IGNORE_ID)
  return {'loss': loss, 'acc': acc}

+ def forward_dpo(
+ self,
+ batch: dict,
+ device: torch.device,
+ ) -> Dict[str, Optional[torch.Tensor]]:
+ text_token = batch['text_token'].to(device)
+ text_token_len = batch['text_token_len'].to(device)
+ speech_token = batch['speech_token'].to(device)
+ speech_token_len = batch['speech_token_len'].to(device)
+ reject_speech_token = batch['reject_speech_token'].to(device)
+ reject_speech_token_len = batch['reject_speech_token_len'].to(device)
+
+ # 1. encode text_token
+ text_token_emb = self.llm.model.model.embed_tokens(text_token)
+
+ # 2. encode speech_token
+ speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
+ reject_speech_token = unpad_sequence(reject_speech_token, reject_speech_token_len.cpu(), batch_first=True)
+ speech_token_combined = speech_token + reject_speech_token
+ speech_token_combined = pad_sequence(speech_token_combined, batch_first=True, padding_value=0)
+ speech_token_combined_len = torch.concat([speech_token_len, reject_speech_token_len], dim=0)
+ speech_token_combined_emb = self.speech_embedding(speech_token_combined)
+
+ # 3. prepare llm_input/target
+ lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(text_token.repeat(2, 1), text_token_emb.repeat(2, 1, 1), text_token_len.repeat(2),
+ speech_token_combined, speech_token_combined_emb, speech_token_combined_len)
+ lm_target = lm_target.to(device)
+
+ # 4. run lm forward
+ lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
+ logits = self.llm_decoder(lm_output)
+ chosen_logits = logits[:text_token.shape[0]]
+ rejected_logits = logits[text_token.shape[0]:]
+ chosen_lm_target = lm_target[:text_token.shape[0]]
+ rejected_lm_target = lm_target[text_token.shape[0]:]
+ loss = self.criterion_ce(chosen_logits, chosen_lm_target.to(device))
+ acc = th_accuracy(chosen_logits.view(-1, self.speech_token_size + 3), chosen_lm_target, ignore_label=IGNORE_ID)
+
+ # 5. calculate dpo logits
+ chosen_lm_mask = chosen_lm_target == IGNORE_ID
+ rejected_lm_mask = rejected_lm_target == IGNORE_ID
+ chosen_logps = torch.gather(chosen_logits.log_softmax(dim=-1), dim=2, index=chosen_lm_target.masked_fill(chosen_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1)
+ rejected_logps = torch.gather(rejected_logits.log_softmax(dim=-1), dim=2, index=rejected_lm_target.masked_fill(rejected_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1)
+ chosen_logps = (chosen_logps * chosen_lm_mask).sum(dim=-1) / chosen_lm_mask.sum(dim=-1)
+ rejected_logps = (rejected_logps * rejected_lm_mask).sum(dim=-1) / rejected_lm_mask.sum(dim=-1)
+ return {'loss': loss, 'acc': acc, 'chosen_logps': chosen_logps, 'rejected_logps': rejected_logps}
+
  @torch.inference_mode()
  def inference(
  self,
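
The new forward_dpo returns per-sequence chosen_logps / rejected_logps alongside the usual CE loss. The diff does not show how the training loop combines them; for reference, a standard DPO objective built from such log-probs (with a frozen reference model providing the ref_* values) is sketched below, illustrative only and not the CosyVoice trainer:

    import torch
    import torch.nn.functional as F

    def dpo_loss(policy_chosen_logps: torch.Tensor,
                 policy_rejected_logps: torch.Tensor,
                 ref_chosen_logps: torch.Tensor,
                 ref_rejected_logps: torch.Tensor,
                 beta: float = 0.1) -> torch.Tensor:
        # Standard sigmoid DPO: the reward margin is the gap between the
        # policy and reference chosen/rejected log-ratios.
        pi_logratios = policy_chosen_logps - policy_rejected_logps
        ref_logratios = ref_chosen_logps - ref_rejected_logps
        return -F.logsigmoid(beta * (pi_logratios - ref_logratios)).mean()
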
@@ -382,6 +437,7 @@ class Qwen2LM(TransformerLM):
  sampling: int = 25,
  max_token_text_ratio: float = 20,
  min_token_text_ratio: float = 2,
+ uuid: str = '',
  ) -> Generator[torch.Tensor, None, None]:
  device = text.device
  text = torch.concat([prompt_text, text], dim=1)
@@ -402,22 +458,57 @@ class Qwen2LM(TransformerLM):
  max_len = int((text_len - prompt_text_len) * max_token_text_ratio)

  # 5. step by step decode
- out_tokens = []
- cache = None
- for i in range(max_len):
- y_pred, cache = self.llm.forward_one_step(lm_input,
- masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
- cache=cache)
- logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
- top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
- if top_ids == self.speech_token_size:
- break
- if top_ids > self.speech_token_size:
- continue
- # in stream mode, yield token one by one
- yield top_ids
- out_tokens.append(top_ids)
- lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
+ for token in self.inference_wrapper(lm_input, sampling, min_len, max_len, uuid):
+ yield token
+
+ @torch.inference_mode()
+ def inference_wrapper(self, lm_input, sampling, min_len, max_len, uuid):
+ if hasattr(self, 'vllm'):
+ from vllm import SamplingParams, RequestOutput
+ sampling_params = SamplingParams(top_k=sampling,
+ stop_token_ids=self.stop_token_ids,
+ min_tokens=min_len,
+ max_tokens=max_len)
+ with self.lock:
+ self.vllm.add_request(uuid, {"prompt_embeds": lm_input.squeeze(0).to(torch.bfloat16).to(lm_input.device)}, sampling_params)
+ self.vllm_output_queue[uuid] = queue.Queue()
+ out_tokens = []
+ while True:
+ with self.lock:
+ if self.vllm_output_queue[uuid].empty() is True:
+ request_outputs: List[RequestOutput] = self.vllm.step()
+ for request_output in request_outputs:
+ top_ids = list(request_output.outputs[0].token_ids)[-1]
+ self.vllm_output_queue[request_output.request_id].put(top_ids)
+ if self.vllm_output_queue[uuid].empty() is False:
+ top_ids = self.vllm_output_queue[uuid].get()
+ if top_ids in self.stop_token_ids:
+ break
+ # in stream mode, yield token one by one
+ yield top_ids
+ out_tokens.append(top_ids)
+ if len(out_tokens) == max_len:
+ break
+ time.sleep(0.001)
+ with self.lock:
+ self.vllm_output_queue.pop(uuid)
+ else:
+ out_tokens = []
+ cache = None
+ for i in range(max_len):
+ y_pred, cache = self.llm.forward_one_step(lm_input,
+ masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
+ cache=cache)
+ logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
+ top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
+ if top_ids == self.speech_token_size:
+ break
+ if top_ids > self.speech_token_size:
+ continue
+ # in stream mode, yield token one by one
+ yield top_ids
+ out_tokens.append(top_ids)
+ lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)

  @torch.inference_mode()
  def inference_bistream(
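
The vLLM branch of inference_wrapper drives a shared engine by repeatedly calling self.vllm.step() under a lock and routing each request's newest token into a per-request queue. A stripped-down sketch of that pump pattern using the stock vllm.LLMEngine text API follows (the diff passes prompt_embeds to its own engine instance; this sketch uses plain text prompts, and the model name is only an example):

    import queue
    import threading
    import uuid

    from vllm import EngineArgs, LLMEngine, SamplingParams

    engine = LLMEngine.from_engine_args(EngineArgs(model="Qwen/Qwen2.5-0.5B-Instruct"))
    output_queues: dict = {}
    lock = threading.Lock()

    def submit(prompt: str, params: SamplingParams) -> str:
        # Register a request and give it its own output queue, as in the diff.
        request_id = str(uuid.uuid4())
        with lock:
            engine.add_request(request_id, prompt, params)
            output_queues[request_id] = queue.Queue()
        return request_id

    def pump_once() -> None:
        # One scheduler step; fan each request's newest token id out to its queue,
        # mirroring the polling loop inside inference_wrapper.
        with lock:
            for out in engine.step():
                if out.outputs:
                    output_queues[out.request_id].put(list(out.outputs[0].token_ids)[-1])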