xinference 1.5.1__py3-none-any.whl → 1.6.0.post1__py3-none-any.whl

This diff shows the content changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release.


This version of xinference might be problematic.

Files changed (96)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +97 -8
  3. xinference/client/restful/restful_client.py +51 -11
  4. xinference/core/media_interface.py +758 -0
  5. xinference/core/model.py +49 -9
  6. xinference/core/worker.py +31 -37
  7. xinference/deploy/utils.py +0 -3
  8. xinference/model/audio/__init__.py +16 -27
  9. xinference/model/audio/core.py +1 -0
  10. xinference/model/audio/cosyvoice.py +4 -2
  11. xinference/model/audio/model_spec.json +20 -3
  12. xinference/model/audio/model_spec_modelscope.json +18 -1
  13. xinference/model/embedding/__init__.py +16 -24
  14. xinference/model/image/__init__.py +15 -25
  15. xinference/model/llm/__init__.py +37 -110
  16. xinference/model/llm/core.py +15 -6
  17. xinference/model/llm/llama_cpp/core.py +25 -353
  18. xinference/model/llm/llm_family.json +613 -89
  19. xinference/model/llm/llm_family.py +9 -1
  20. xinference/model/llm/llm_family_modelscope.json +540 -90
  21. xinference/model/llm/mlx/core.py +6 -3
  22. xinference/model/llm/reasoning_parser.py +281 -5
  23. xinference/model/llm/sglang/core.py +16 -3
  24. xinference/model/llm/transformers/chatglm.py +2 -2
  25. xinference/model/llm/transformers/cogagent.py +1 -1
  26. xinference/model/llm/transformers/cogvlm2.py +1 -1
  27. xinference/model/llm/transformers/core.py +9 -3
  28. xinference/model/llm/transformers/glm4v.py +1 -1
  29. xinference/model/llm/transformers/minicpmv26.py +1 -1
  30. xinference/model/llm/transformers/qwen-omni.py +6 -0
  31. xinference/model/llm/transformers/qwen_vl.py +1 -1
  32. xinference/model/llm/utils.py +68 -45
  33. xinference/model/llm/vllm/core.py +38 -18
  34. xinference/model/llm/vllm/xavier/test/test_xavier.py +1 -10
  35. xinference/model/rerank/__init__.py +13 -24
  36. xinference/model/video/__init__.py +15 -25
  37. xinference/model/video/core.py +3 -3
  38. xinference/model/video/diffusers.py +133 -16
  39. xinference/model/video/model_spec.json +54 -0
  40. xinference/model/video/model_spec_modelscope.json +56 -0
  41. xinference/thirdparty/cosyvoice/bin/average_model.py +5 -4
  42. xinference/thirdparty/cosyvoice/bin/export_jit.py +50 -20
  43. xinference/thirdparty/cosyvoice/bin/export_onnx.py +136 -51
  44. xinference/thirdparty/cosyvoice/bin/inference.py +15 -5
  45. xinference/thirdparty/cosyvoice/bin/train.py +7 -2
  46. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +72 -52
  47. xinference/thirdparty/cosyvoice/cli/frontend.py +58 -58
  48. xinference/thirdparty/cosyvoice/cli/model.py +140 -155
  49. xinference/thirdparty/cosyvoice/dataset/processor.py +9 -5
  50. xinference/thirdparty/cosyvoice/flow/decoder.py +656 -54
  51. xinference/thirdparty/cosyvoice/flow/flow.py +69 -11
  52. xinference/thirdparty/cosyvoice/flow/flow_matching.py +167 -63
  53. xinference/thirdparty/cosyvoice/flow/length_regulator.py +1 -0
  54. xinference/thirdparty/cosyvoice/hifigan/discriminator.py +91 -1
  55. xinference/thirdparty/cosyvoice/hifigan/f0_predictor.py +4 -1
  56. xinference/thirdparty/cosyvoice/hifigan/generator.py +4 -1
  57. xinference/thirdparty/cosyvoice/hifigan/hifigan.py +2 -2
  58. xinference/thirdparty/cosyvoice/llm/llm.py +198 -18
  59. xinference/thirdparty/cosyvoice/transformer/embedding.py +12 -4
  60. xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +124 -21
  61. xinference/thirdparty/cosyvoice/utils/class_utils.py +13 -0
  62. xinference/thirdparty/cosyvoice/utils/common.py +1 -1
  63. xinference/thirdparty/cosyvoice/utils/file_utils.py +40 -2
  64. xinference/thirdparty/cosyvoice/utils/frontend_utils.py +7 -0
  65. xinference/thirdparty/cosyvoice/utils/mask.py +4 -0
  66. xinference/thirdparty/cosyvoice/utils/train_utils.py +5 -1
  67. xinference/thirdparty/matcha/hifigan/xutils.py +3 -3
  68. xinference/types.py +0 -71
  69. xinference/web/ui/build/asset-manifest.json +3 -3
  70. xinference/web/ui/build/index.html +1 -1
  71. xinference/web/ui/build/static/js/main.ae579a97.js +3 -0
  72. xinference/web/ui/build/static/js/main.ae579a97.js.map +1 -0
  73. xinference/web/ui/node_modules/.cache/babel-loader/0196a4b09e3264614e54360d5f832c46b31d964ec58296765ebff191ace6adbf.json +1 -0
  74. xinference/web/ui/node_modules/.cache/babel-loader/12e02ee790dbf57ead09a241a93bb5f893393aa36628ca741d44390e836a103f.json +1 -0
  75. xinference/web/ui/node_modules/.cache/babel-loader/18fa271456b31cded36c05c4c71c6b2b1cf4e4128c1e32f0e45d8b9f21764397.json +1 -0
  76. xinference/web/ui/node_modules/.cache/babel-loader/2fdc61dcb6a9d1fbcb44be592d0e87d8c3f21297a7327559ef5345665f8343f7.json +1 -0
  77. xinference/web/ui/node_modules/.cache/babel-loader/3d596a3e8dd6430d7ce81d164e32c31f8d47cfa5f725c328a298754d78563e14.json +1 -0
  78. xinference/web/ui/node_modules/.cache/babel-loader/8472e58a31720892d534f3febda31f746b25ec4aa60787eef34217b074e67965.json +1 -0
  79. xinference/web/ui/src/locales/en.json +6 -4
  80. xinference/web/ui/src/locales/zh.json +6 -4
  81. {xinference-1.5.1.dist-info → xinference-1.6.0.post1.dist-info}/METADATA +59 -39
  82. {xinference-1.5.1.dist-info → xinference-1.6.0.post1.dist-info}/RECORD +87 -87
  83. {xinference-1.5.1.dist-info → xinference-1.6.0.post1.dist-info}/WHEEL +1 -1
  84. xinference/core/image_interface.py +0 -377
  85. xinference/thirdparty/cosyvoice/bin/export_trt.sh +0 -9
  86. xinference/web/ui/build/static/js/main.91e77b5c.js +0 -3
  87. xinference/web/ui/build/static/js/main.91e77b5c.js.map +0 -1
  88. xinference/web/ui/node_modules/.cache/babel-loader/0f0adb2283a8f469d097a7a0ebb754624fa52414c83b83696c41f2e6a737ceda.json +0 -1
  89. xinference/web/ui/node_modules/.cache/babel-loader/5e6edb0fb87e3798f142e9abf8dd2dc46bab33a60d31dff525797c0c99887097.json +0 -1
  90. xinference/web/ui/node_modules/.cache/babel-loader/6087820be1bd5c02c42dff797e7df365448ef35ab26dd5d6bd33e967e05cbfd4.json +0 -1
  91. xinference/web/ui/node_modules/.cache/babel-loader/8157db83995c671eb57abc316c337f867d1dc63fb83520bb4ff351fee57dcce2.json +0 -1
  92. xinference/web/ui/node_modules/.cache/babel-loader/f04f666b77b44d7be3e16034d6b0074de2ba9c254f1fae15222b3148608fa8b3.json +0 -1
  93. /xinference/web/ui/build/static/js/{main.91e77b5c.js.LICENSE.txt → main.ae579a97.js.LICENSE.txt} +0 -0
  94. {xinference-1.5.1.dist-info → xinference-1.6.0.post1.dist-info}/entry_points.txt +0 -0
  95. {xinference-1.5.1.dist-info → xinference-1.6.0.post1.dist-info}/licenses/LICENSE +0 -0
  96. {xinference-1.5.1.dist-info → xinference-1.6.0.post1.dist-info}/top_level.txt +0 -0
@@ -11,6 +11,7 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
+ import random
  from typing import Dict, Optional, Callable, List, Generator
  import torch
  from torch import nn
@@ -20,6 +21,8 @@ from torch.nn.utils.rnn import pad_sequence, unpad_sequence
  from cosyvoice.utils.common import IGNORE_ID
  from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
  from cosyvoice.utils.common import th_accuracy
+ from cosyvoice.utils.file_utils import logging
+ from cosyvoice.utils.mask import make_pad_mask


  class TransformerLM(torch.nn.Module):
@@ -144,10 +147,14 @@ class TransformerLM(torch.nn.Module):
  sampling: int,
  ignore_eos: bool = True,
  ):
+ num_trials, max_trials = 0, 100
  while True:
  top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
  if (not ignore_eos) or (self.speech_token_size not in top_ids):
  break
+ num_trials += 1
+ if num_trials > max_trials:
+ raise RuntimeError('sampling reaches max_trials {} and still get eos when ignore_eos is True, check your input!'.format(max_trials))
  return top_ids

  @torch.inference_mode()
@@ -178,7 +185,7 @@ class TransformerLM(torch.nn.Module):
  embedding = self.spk_embed_affine_layer(embedding)
  embedding = embedding.unsqueeze(dim=1)
  else:
- embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
+ embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device).to(text.dtype)

  # 3. concat llm_input
  sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
@@ -221,6 +228,17 @@ class Qwen2Encoder(torch.nn.Module):
  super().__init__()
  self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path)

+ def forward(self, xs: torch.Tensor, xs_lens: torch.Tensor):
+ T = xs.size(1)
+ masks = ~make_pad_mask(xs_lens, T)
+ outs = self.model(
+ inputs_embeds=xs,
+ attention_mask=masks,
+ output_hidden_states=True,
+ return_dict=True,
+ )
+ return outs.hidden_states[-1], masks.unsqueeze(1)
+
  def forward_one_step(self, xs, masks, cache=None):
  input_masks = masks[:, -1, :]
  outs = self.model(
@@ -236,7 +254,7 @@ class Qwen2Encoder(torch.nn.Module):
  return xs, new_cache


- class Qwen2LM(torch.nn.Module):
+ class Qwen2LM(TransformerLM):
  def __init__(
  self,
  llm_input_size: int,
@@ -246,8 +264,9 @@ class Qwen2LM(torch.nn.Module):
  sampling: Callable,
  length_normalized_loss: bool = True,
  lsm_weight: float = 0.0,
+ mix_ratio: List[int] = [5, 15],
  ):
- super().__init__()
+ torch.nn.Module.__init__(self)
  self.llm_input_size = llm_input_size
  self.llm_output_size = llm_output_size
  self.speech_token_size = speech_token_size
@@ -272,19 +291,83 @@ class Qwen2LM(torch.nn.Module):

  # 4. sampling method
  self.sampling = sampling
+ self.mix_ratio = mix_ratio

- def sampling_ids(
+ def prepare_lm_input_target(self, text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len):
+ lm_target, lm_input = [], []
+ text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
+ speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
+ text_token_emb = unpad_sequence(text_token_emb, text_token_len.cpu(), batch_first=True)
+ speech_token_emb = unpad_sequence(speech_token_emb, speech_token_len.cpu(), batch_first=True)
+ for i in range(len(text_token)):
+ # bistream sequence
+ if random.random() < 0.5 and speech_token_len[i] / text_token_len[i] > self.mix_ratio[1] / self.mix_ratio[0]:
+ this_lm_target, this_lm_input = [], []
+ this_lm_target.append(IGNORE_ID)
+ this_lm_input.append(self.llm_embedding.weight[self.sos_eos].reshape(1, -1))
+ for j in range(((text_token_len[i] + 1) / self.mix_ratio[0]).ceil().int().item()):
+ this_text_token = text_token[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]].tolist()
+ this_speech_token = speech_token[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]].tolist()
+ if len(this_text_token) == self.mix_ratio[0]:
+ assert len(this_speech_token) == self.mix_ratio[1]
+ this_lm_target += [IGNORE_ID] * (self.mix_ratio[0] - 1)
+ this_lm_target += this_speech_token
+ this_lm_target.append(self.speech_token_size + 2)
+ this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]])
+ this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]])
+ else:
+ this_lm_target += [-1] * len(this_text_token)
+ this_lm_target += speech_token[i][j * self.mix_ratio[1]:].tolist()
+ this_lm_target.append(self.speech_token_size)
+ this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]:])
+ this_lm_input.append(self.llm_embedding.weight[self.task_id].reshape(1, -1))
+ this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]:])
+ this_lm_target, this_lm_input = torch.tensor(this_lm_target), torch.concat(this_lm_input, dim=0)
+ # unistream sequence
+ else:
+ this_lm_target = torch.tensor([IGNORE_ID] * (1 + text_token_len[i]) + speech_token[i].tolist() + [self.speech_token_size])
+ this_lm_input = torch.concat([self.llm_embedding.weight[self.sos_eos].reshape(1, -1), text_token_emb[i],
+ self.llm_embedding.weight[self.task_id].reshape(1, -1), speech_token_emb[i]], dim=0)
+ lm_target.append(this_lm_target)
+ lm_input.append(this_lm_input)
+ lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
+ lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
+ lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID)
+ return lm_target, lm_input, lm_input_len
+
+ def forward(
  self,
- weighted_scores: torch.Tensor,
- decoded_tokens: List,
- sampling: int,
- ignore_eos: bool = True,
- ):
- while True:
- top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
- if (not ignore_eos) or (self.speech_token_size not in top_ids):
- break
- return top_ids
+ batch: dict,
+ device: torch.device,
+ ) -> Dict[str, Optional[torch.Tensor]]:
+ """
+ Args:
+ text: (B, L, D)
+ text_lengths: (B,)
+ audio: (B, T, N) or (B, T)
+ audio_lengths: (B,)
+ """
+ text_token = batch['text_token'].to(device)
+ text_token_len = batch['text_token_len'].to(device)
+ speech_token = batch['speech_token'].to(device)
+ speech_token_len = batch['speech_token_len'].to(device)
+
+ # 1. encode text_token
+ text_token_emb = self.llm.model.model.embed_tokens(text_token)
+
+ # 2. encode speech_token
+ speech_token_emb = self.speech_embedding(speech_token)
+
+ # 3. prepare llm_input/target
+ lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len)
+ lm_target = lm_target.to(device)
+
+ # 4. run lm forward
+ lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
+ logits = self.llm_decoder(lm_output)
+ loss = self.criterion_ce(logits, lm_target.to(device))
+ acc = th_accuracy(logits.view(-1, self.speech_token_size + 3), lm_target, ignore_label=IGNORE_ID)
+ return {'loss': loss, 'acc': acc}

  @torch.inference_mode()
  def inference(
@@ -305,9 +388,6 @@ class Qwen2LM(torch.nn.Module):
  text_len += prompt_text_len
  text = self.llm.model.model.embed_tokens(text)

- # 2. encode embedding
- embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
-
  # 3. concat llm_input
  sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
  task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
@@ -315,7 +395,7 @@ class Qwen2LM(torch.nn.Module):
  prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
  else:
  prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
- lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
+ lm_input = torch.concat([sos_eos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)

  # 4. cal min/max_length
  min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
@@ -338,3 +418,103 @@ class Qwen2LM(torch.nn.Module):
  yield top_ids
  out_tokens.append(top_ids)
  lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
+
+ @torch.inference_mode()
+ def inference_bistream(
+ self,
+ text: Generator,
+ prompt_text: torch.Tensor,
+ prompt_text_len: torch.Tensor,
+ prompt_speech_token: torch.Tensor,
+ prompt_speech_token_len: torch.Tensor,
+ embedding: torch.Tensor,
+ sampling: int = 25,
+ max_token_text_ratio: float = 20,
+ min_token_text_ratio: float = 2,
+ ) -> Generator[torch.Tensor, None, None]:
+
+ device = prompt_text.device
+ # 1. prepare input
+ sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
+ task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
+ if prompt_speech_token_len != 0:
+ prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
+ else:
+ prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=prompt_text.dtype).to(device)
+ lm_input = torch.concat([sos_eos_emb], dim=1)
+
+ # 2. iterate text
+ out_tokens = []
+ cache = None
+ # NOTE init prompt_text as text_cache as it is basically impossible prompt_speech_token/prompt_text < 15/5
+ text_cache = self.llm.model.model.embed_tokens(prompt_text)
+ next_fill_index = -1
+ for this_text in text:
+ text_cache = torch.concat([text_cache, self.llm.model.model.embed_tokens(this_text)], dim=1)
+ # prompt_speech_token_emb not empty, try append to lm_input
+ while prompt_speech_token_emb.size(1) != 0:
+ if text_cache.size(1) >= self.mix_ratio[0]:
+ lm_input_text, lm_input_speech = text_cache[:, :self.mix_ratio[0]], prompt_speech_token_emb[:, :self.mix_ratio[1]]
+ logging.info('append {} text token {} speech token'.format(lm_input_text.size(1), lm_input_speech.size(1)))
+ lm_input = torch.concat([lm_input, lm_input_text, lm_input_speech], dim=1)
+ text_cache, prompt_speech_token_emb = text_cache[:, self.mix_ratio[0]:], prompt_speech_token_emb[:, self.mix_ratio[1]:]
+ else:
+ logging.info('not enough text token to decode, wait for more')
+ break
+ # no prompt_speech_token_emb remain, can decode some speech token
+ if prompt_speech_token_emb.size(1) == 0:
+ if (len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2) or (len(out_tokens) == 0 and lm_input.size(1) == 1):
+ logging.info('get fill token, need to append more text token')
+ if text_cache.size(1) >= self.mix_ratio[0]:
+ lm_input_text = text_cache[:, :self.mix_ratio[0]]
+ logging.info('append {} text token'.format(lm_input_text.size(1)))
+ if len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2:
+ lm_input = lm_input_text
+ else:
+ lm_input = torch.concat([lm_input, lm_input_text], dim=1)
+ text_cache = text_cache[:, self.mix_ratio[0]:]
+ else:
+ logging.info('not enough text token to decode, wait for more')
+ continue
+ while True:
+ seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
+ y_pred, cache = self.llm.forward_one_step(lm_input,
+ masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
+ cache=cache)
+ logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
+ if next_fill_index != -1 and len(out_tokens) == next_fill_index:
+ top_ids = self.speech_token_size + 2
+ next_fill_index += (self.mix_ratio[1] + 1)
+ else:
+ top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True).item()
+ if top_ids == self.speech_token_size + 2:
+ next_fill_index = len(out_tokens) + self.mix_ratio[1] + 1
+ logging.info('fill_token index {} next fill_token index {}'.format(len(out_tokens), next_fill_index))
+ out_tokens.append(top_ids)
+ if top_ids >= self.speech_token_size:
+ if top_ids == self.speech_token_size + 2:
+ break
+ else:
+ raise ValueError('should not get token {}'.format(top_ids))
+ yield top_ids
+ lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
+
+ # 3. final decode
+ lm_input = torch.concat([lm_input, text_cache, task_id_emb], dim=1)
+ logging.info('no more text token, decode until met eos')
+ while True:
+ seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
+ y_pred, cache = self.llm.forward_one_step(lm_input,
+ masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
+ cache=cache)
+ logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
+ top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=False).item()
+ out_tokens.append(top_ids)
+ if top_ids >= self.speech_token_size:
+ if top_ids == self.speech_token_size:
+ break
+ else:
+ raise ValueError('should not get token {}'.format(top_ids))
+ # in stream mode, yield token one by one
+ yield top_ids
+ lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
@@ -287,8 +287,16 @@ class EspnetRelPositionalEncoding(torch.nn.Module):
  Returns:
  torch.Tensor: Corresponding encoding
  """
- pos_emb = self.pe[
- :,
- self.pe.size(1) // 2 - size + 1: self.pe.size(1) // 2 + size,
- ]
+ # How to subscript a Union type:
+ # https://github.com/pytorch/pytorch/issues/69434
+ if isinstance(offset, int):
+ pos_emb = self.pe[
+ :,
+ self.pe.size(1) // 2 - size - offset + 1: self.pe.size(1) // 2 + size + offset,
+ ]
+ elif isinstance(offset, torch.Tensor):
+ pos_emb = self.pe[
+ :,
+ self.pe.size(1) // 2 - size - offset + 1: self.pe.size(1) // 2 + size + offset,
+ ]
  return pos_emb
@@ -56,11 +56,16 @@ class Upsample1D(nn.Module):
  # In this mode, first repeat interpolate, than conv with stride=1
  self.conv = nn.Conv1d(self.channels, self.out_channels, stride * 2 + 1, stride=1, padding=0)

- def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor):
+ def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor, conv_cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  outputs = F.interpolate(inputs, scale_factor=float(self.stride), mode="nearest")
- outputs = F.pad(outputs, (self.stride * 2, 0), value=0.0)
+ if conv_cache.size(2) == 0:
+ outputs = F.pad(outputs, (self.stride * 2, 0), value=0.0)
+ else:
+ assert conv_cache.size(2) == self.stride * 2
+ outputs = torch.concat([conv_cache, outputs], dim=2)
+ conv_cache_new = outputs[:, :, -self.stride * 2:]
  outputs = self.conv(outputs)
- return outputs, input_lengths * self.stride
+ return outputs, input_lengths * self.stride, conv_cache_new


  class PreLookaheadLayer(nn.Module):
@@ -78,22 +83,32 @@
  kernel_size=3, stride=1, padding=0,
  )

- def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+ def forward(self, inputs: torch.Tensor, context: torch.Tensor = torch.zeros(0, 0, 0), conv2_cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
  """
  inputs: (batch_size, seq_len, channels)
  """
  outputs = inputs.transpose(1, 2).contiguous()
+ context = context.transpose(1, 2).contiguous()
  # look ahead
- outputs = F.pad(outputs, (0, self.pre_lookahead_len), mode='constant', value=0.0)
+ if context.size(2) == 0:
+ outputs = F.pad(outputs, (0, self.pre_lookahead_len), mode='constant', value=0.0)
+ else:
+ assert context.size(2) == self.pre_lookahead_len
+ outputs = F.pad(torch.concat([outputs, context], dim=2), (0, self.pre_lookahead_len - context.size(2)), mode='constant', value=0.0)
  outputs = F.leaky_relu(self.conv1(outputs))
  # outputs
- outputs = F.pad(outputs, (2, 0), mode='constant', value=0.0)
+ if conv2_cache.size(2) == 0:
+ outputs = F.pad(outputs, (self.conv2.kernel_size[0] - 1, 0), mode='constant', value=0.0)
+ else:
+ assert conv2_cache.size(2) == self.conv2.kernel_size[0] - 1
+ outputs = torch.concat([conv2_cache, outputs], dim=2)
+ conv2_cache_new = outputs[:, :, -(self.conv2.kernel_size[0] - 1):]
  outputs = self.conv2(outputs)
  outputs = outputs.transpose(1, 2).contiguous()

  # residual connection
  outputs = outputs + inputs
- return outputs
+ return outputs, conv2_cache_new


  class UpsampleConformerEncoder(torch.nn.Module):
@@ -240,6 +255,7 @@ class UpsampleConformerEncoder(torch.nn.Module):
  xs_lens: torch.Tensor,
  decoding_chunk_size: int = 0,
  num_decoding_left_chunks: int = -1,
+ streaming: bool = False,
  ) -> Tuple[torch.Tensor, torch.Tensor]:
  """Embed positions in tensor.

@@ -270,30 +286,20 @@ class UpsampleConformerEncoder(torch.nn.Module):
  xs = self.global_cmvn(xs)
  xs, pos_emb, masks = self.embed(xs, masks)
  mask_pad = masks # (B, 1, T/subsample_rate)
- chunk_masks = add_optional_chunk_mask(xs, masks,
- self.use_dynamic_chunk,
- self.use_dynamic_left_chunk,
- decoding_chunk_size,
- self.static_chunk_size,
- num_decoding_left_chunks)
+ chunk_masks = add_optional_chunk_mask(xs, masks, False, False, 0, self.static_chunk_size if streaming is True else 0, -1)
  # lookahead + conformer encoder
- xs = self.pre_lookahead_layer(xs)
+ xs, _ = self.pre_lookahead_layer(xs)
  xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)

  # upsample + conformer encoder
  xs = xs.transpose(1, 2).contiguous()
- xs, xs_lens = self.up_layer(xs, xs_lens)
+ xs, xs_lens, _ = self.up_layer(xs, xs_lens)
  xs = xs.transpose(1, 2).contiguous()
  T = xs.size(1)
  masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
  xs, pos_emb, masks = self.up_embed(xs, masks)
  mask_pad = masks # (B, 1, T/subsample_rate)
- chunk_masks = add_optional_chunk_mask(xs, masks,
- self.use_dynamic_chunk,
- self.use_dynamic_left_chunk,
- decoding_chunk_size,
- self.static_chunk_size * self.up_layer.stride,
- num_decoding_left_chunks)
+ chunk_masks = add_optional_chunk_mask(xs, masks, False, False, 0, self.static_chunk_size * self.up_layer.stride if streaming is True else 0, -1)
  xs = self.forward_up_layers(xs, chunk_masks, pos_emb, mask_pad)

  if self.normalize_before:
@@ -316,3 +322,100 @@ class UpsampleConformerEncoder(torch.nn.Module):
  for layer in self.up_encoders:
  xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
  return xs
+
+ @torch.jit.export
+ def forward_chunk(
+ self,
+ xs: torch.Tensor,
+ xs_lens: torch.Tensor,
+ offset: int = 0,
+ context: torch.Tensor = torch.zeros(0, 0, 0),
+ pre_lookahead_layer_conv2_cache: torch.Tensor = torch.zeros(0, 0, 0),
+ encoders_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0),
+ upsample_offset: int = 0,
+ upsample_conv_cache: torch.Tensor = torch.zeros(0, 0, 0),
+ upsample_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0)
+ ) -> Tuple[torch.Tensor, torch.Tensor, Tuple[int, torch.Tensor, torch.Tensor, int, torch.Tensor, torch.Tensor]]:
+ """Embed positions in tensor.
+
+ Args:
+ xs: padded input tensor (B, T, D)
+ xs_lens: input length (B)
+ decoding_chunk_size: decoding chunk size for dynamic chunk
+ 0: default for training, use random dynamic chunk.
+ <0: for decoding, use full chunk.
+ >0: for decoding, use fixed chunk size as set.
+ num_decoding_left_chunks: number of left chunks, this is for decoding,
+ the chunk size is decoding_chunk_size.
+ >=0: use num_decoding_left_chunks
+ <0: use all left chunks
+ Returns:
+ encoder output tensor xs, and subsampled masks
+ xs: padded output tensor (B, T' ~= T/subsample_rate, D)
+ masks: torch.Tensor batch padding mask after subsample
+ (B, 1, T' ~= T/subsample_rate)
+ NOTE(xcsong):
+ We pass the `__call__` method of the modules instead of `forward` to the
+ checkpointing API because `__call__` attaches all the hooks of the module.
+ https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
+ """
+ assert xs.size(0) == 1
+ # tmp_masks is just for interface compatibility
+ tmp_masks = torch.ones(1,
+ xs.size(1),
+ device=xs.device,
+ dtype=torch.bool)
+ tmp_masks = tmp_masks.unsqueeze(1)
+ if self.global_cmvn is not None:
+ xs = self.global_cmvn(xs)
+ # NOTE(xcsong): Before embed, shape(xs) is (b=1, time, mel-dim)
+ xs, pos_emb, _ = self.embed(xs, tmp_masks, offset)
+ offset += xs.size(1)
+ tmp_masks = torch.ones(1,
+ context.size(1),
+ device=context.device,
+ dtype=torch.bool)
+ tmp_masks = tmp_masks.unsqueeze(1)
+ if context.size(1) != 0:
+ context, _, _ = self.embed(context, tmp_masks, offset)
+
+ # lookahead + conformer encoder
+ xs, pre_lookahead_layer_conv2_cache = self.pre_lookahead_layer(xs, context, pre_lookahead_layer_conv2_cache)
+ # NOTE in cache mode we do not need to call add_optional_chunk_mask
+ chunk_masks = torch.ones((1, xs.size(1), offset), dtype=torch.bool, device=xs.device)
+ mask_pad = torch.ones((0, 0, 0), dtype=torch.bool, device=xs.device)
+ encoders_kv_cache_list = []
+ for index, layer in enumerate(self.encoders):
+ xs, chunk_masks, encoders_kv_cache_new, _ = layer(xs, chunk_masks, pos_emb, mask_pad, encoders_kv_cache[index])
+ encoders_kv_cache_list.append(encoders_kv_cache_new)
+ encoders_kv_cache = torch.stack(encoders_kv_cache_list, dim=0)
+
+ # upsample
+ xs = xs.transpose(1, 2).contiguous()
+ xs, xs_lens, upsample_conv_cache = self.up_layer(xs, xs_lens, upsample_conv_cache)
+ xs = xs.transpose(1, 2).contiguous()
+
+ # tmp_masks is just for interface compatibility
+ tmp_masks = torch.ones(1,
+ xs.size(1),
+ device=xs.device,
+ dtype=torch.bool)
+ tmp_masks = tmp_masks.unsqueeze(1)
+ xs, pos_emb, masks = self.up_embed(xs, tmp_masks, upsample_offset)
+ upsample_offset += xs.size(1)
+
+ # conformer encoder
+ chunk_masks = torch.ones((1, xs.size(1), upsample_offset), dtype=torch.bool, device=xs.device)
+ mask_pad = torch.ones((0, 0, 0), dtype=torch.bool, device=xs.device)
+ upsample_kv_cache_list = []
+ for index, layer in enumerate(self.up_encoders):
+ xs, chunk_masks, upsample_kv_cache_new, _ = layer(xs, chunk_masks, pos_emb, mask_pad, upsample_kv_cache[index])
+ upsample_kv_cache_list.append(upsample_kv_cache_new)
+ upsample_kv_cache = torch.stack(upsample_kv_cache_list, dim=0)
+
+ if self.normalize_before:
+ xs = self.after_norm(xs)
+ # Here we assume the mask is not changed in encoder layers, so just
+ # return the masks before encoder layers, and the masks will be used
+ # for cross attention with decoder later
+ return xs, masks, (offset, pre_lookahead_layer_conv2_cache, encoders_kv_cache, upsample_offset, upsample_conv_cache, upsample_kv_cache)
@@ -32,6 +32,10 @@ from cosyvoice.transformer.attention import (MultiHeadedAttention,
  RelPositionMultiHeadedAttention)
  from cosyvoice.transformer.embedding import EspnetRelPositionalEncoding
  from cosyvoice.transformer.subsampling import LegacyLinearNoSubsampling
+ from cosyvoice.llm.llm import TransformerLM, Qwen2LM
+ from cosyvoice.flow.flow import MaskedDiffWithXvec, CausalMaskedDiffWithXvec
+ from cosyvoice.hifigan.generator import HiFTGenerator
+ from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model


  COSYVOICE_ACTIVATION_CLASSES = {
@@ -68,3 +72,12 @@ COSYVOICE_ATTENTION_CLASSES = {
  "selfattn": MultiHeadedAttention,
  "rel_selfattn": RelPositionMultiHeadedAttention,
  }
+
+
+ def get_model_type(configs):
+ # NOTE CosyVoice2Model inherits CosyVoiceModel
+ if isinstance(configs['llm'], TransformerLM) and isinstance(configs['flow'], MaskedDiffWithXvec) and isinstance(configs['hift'], HiFTGenerator):
+ return CosyVoiceModel
+ if isinstance(configs['llm'], Qwen2LM) and isinstance(configs['flow'], CausalMaskedDiffWithXvec) and isinstance(configs['hift'], HiFTGenerator):
+ return CosyVoice2Model
+ raise TypeError('No valid model type found!')
@@ -162,5 +162,5 @@ def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
  # attention mask bias
  # NOTE(Mddct): torch.finfo jit issues
  # chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
- mask = (1.0 - mask) * torch.finfo(dtype).min
+ mask = (1.0 - mask) * -1.0e+10
  return mask
@@ -1,5 +1,5 @@
  # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
- # 2024 Alibaba Inc (authors: Xiang Lyu)
+ # 2024 Alibaba Inc (authors: Xiang Lyu, Zetao Hu)
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -39,9 +39,47 @@ def read_json_lists(list_file):


  def load_wav(wav, target_sr):
- speech, sample_rate = torchaudio.load(wav)
+ speech, sample_rate = torchaudio.load(wav, backend='soundfile')
  speech = speech.mean(dim=0, keepdim=True)
  if sample_rate != target_sr:
  assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr)
  speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)
  return speech
+
+
+ def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16):
+ import tensorrt as trt
+ logging.info("Converting onnx to trt...")
+ network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
+ logger = trt.Logger(trt.Logger.INFO)
+ builder = trt.Builder(logger)
+ network = builder.create_network(network_flags)
+ parser = trt.OnnxParser(network, logger)
+ config = builder.create_builder_config()
+ config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 33) # 8GB
+ if fp16:
+ config.set_flag(trt.BuilderFlag.FP16)
+ profile = builder.create_optimization_profile()
+ # load onnx model
+ with open(onnx_model, "rb") as f:
+ if not parser.parse(f.read()):
+ for error in range(parser.num_errors):
+ print(parser.get_error(error))
+ raise ValueError('failed to parse {}'.format(onnx_model))
+ # set input shapes
+ for i in range(len(trt_kwargs['input_names'])):
+ profile.set_shape(trt_kwargs['input_names'][i], trt_kwargs['min_shape'][i], trt_kwargs['opt_shape'][i], trt_kwargs['max_shape'][i])
+ tensor_dtype = trt.DataType.HALF if fp16 else trt.DataType.FLOAT
+ # set input and output data type
+ for i in range(network.num_inputs):
+ input_tensor = network.get_input(i)
+ input_tensor.dtype = tensor_dtype
+ for i in range(network.num_outputs):
+ output_tensor = network.get_output(i)
+ output_tensor.dtype = tensor_dtype
+ config.add_optimization_profile(profile)
+ engine_bytes = builder.build_serialized_network(network, config)
+ # save trt engine
+ with open(trt_model, "wb") as f:
+ f.write(engine_bytes)
+ logging.info("Succesfully convert onnx to trt...")
@@ -13,6 +13,7 @@
  # limitations under the License.

  import re
+ import regex
  chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]+')


@@ -127,3 +128,9 @@ def replace_blank(text: str):
  else:
  out_str.append(c)
  return "".join(out_str)
+
+
+ def is_only_punctuation(text):
+ # Regular expression: Match strings that consist only of punctuation marks or are empty.
+ punctuation_pattern = r'^[\p{P}\p{S}]*$'
+ return bool(regex.fullmatch(punctuation_pattern, text))
@@ -195,6 +195,10 @@ def add_optional_chunk_mask(xs: torch.Tensor,
  chunk_masks = masks & chunk_masks # (B, L, L)
  else:
  chunk_masks = masks
+ assert chunk_masks.dtype == torch.bool
+ if (chunk_masks.sum(dim=-1) == 0).sum().item() != 0:
+ print('get chunk_masks all false at some timestep, force set to true, make sure they are masked in futuer computation!')
+ chunk_masks[chunk_masks.sum(dim=-1) == 0] = True
  return chunk_masks


@@ -286,11 +286,15 @@ def update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict):
  # optimizer.step().
  if torch.isfinite(grad_norm):
  scaler.step(optimizer)
+ else:
+ logging.warning('get infinite grad_norm, check your code/data if it appears frequently')
  scaler.update()
  else:
  grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip'])
  if torch.isfinite(grad_norm):
  optimizer.step()
+ else:
+ logging.warning('get infinite grad_norm, check your code/data if it appears frequently')
  optimizer.zero_grad()
  scheduler.step()
  info_dict["lr"] = optimizer.param_groups[0]['lr']
@@ -336,7 +340,7 @@ def log_per_save(writer, info_dict):
  rank = int(os.environ.get('RANK', 0))
  logging.info(
  'Epoch {} Step {} CV info lr {} {} rank {}'.format(
- epoch, step + 1, lr, rank, ' '.join(['{}_{}'.format(k, v) for k, v in loss_dict.items()])))
+ epoch, step + 1, lr, rank, ' '.join(['{} {}'.format(k, v) for k, v in loss_dict.items()])))

  if writer is not None:
  for k in ['epoch', 'lr']: