xinference 1.1.0__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic.

Files changed (104)
  1. xinference/_compat.py +2 -0
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +23 -1
  4. xinference/core/model.py +1 -6
  5. xinference/core/utils.py +10 -6
  6. xinference/model/audio/core.py +5 -0
  7. xinference/model/audio/cosyvoice.py +25 -3
  8. xinference/model/audio/f5tts.py +15 -10
  9. xinference/model/audio/f5tts_mlx.py +260 -0
  10. xinference/model/audio/fish_speech.py +35 -111
  11. xinference/model/audio/model_spec.json +19 -3
  12. xinference/model/audio/model_spec_modelscope.json +9 -0
  13. xinference/model/audio/utils.py +32 -0
  14. xinference/model/image/core.py +69 -1
  15. xinference/model/image/model_spec.json +127 -4
  16. xinference/model/image/model_spec_modelscope.json +130 -4
  17. xinference/model/image/stable_diffusion/core.py +45 -13
  18. xinference/model/llm/llm_family.json +47 -0
  19. xinference/model/llm/llm_family.py +15 -36
  20. xinference/model/llm/llm_family_modelscope.json +49 -0
  21. xinference/model/llm/mlx/core.py +68 -13
  22. xinference/model/llm/transformers/core.py +1 -0
  23. xinference/model/llm/transformers/qwen2_vl.py +2 -0
  24. xinference/model/llm/utils.py +1 -0
  25. xinference/model/llm/vllm/core.py +11 -2
  26. xinference/thirdparty/cosyvoice/bin/average_model.py +92 -0
  27. xinference/thirdparty/cosyvoice/bin/export_jit.py +12 -2
  28. xinference/thirdparty/cosyvoice/bin/export_onnx.py +112 -0
  29. xinference/thirdparty/cosyvoice/bin/export_trt.sh +9 -0
  30. xinference/thirdparty/cosyvoice/bin/inference.py +5 -7
  31. xinference/thirdparty/cosyvoice/bin/train.py +42 -8
  32. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +96 -25
  33. xinference/thirdparty/cosyvoice/cli/frontend.py +77 -30
  34. xinference/thirdparty/cosyvoice/cli/model.py +330 -80
  35. xinference/thirdparty/cosyvoice/dataset/dataset.py +6 -2
  36. xinference/thirdparty/cosyvoice/dataset/processor.py +76 -14
  37. xinference/thirdparty/cosyvoice/flow/decoder.py +92 -13
  38. xinference/thirdparty/cosyvoice/flow/flow.py +99 -9
  39. xinference/thirdparty/cosyvoice/flow/flow_matching.py +110 -13
  40. xinference/thirdparty/cosyvoice/flow/length_regulator.py +5 -4
  41. xinference/thirdparty/cosyvoice/hifigan/discriminator.py +140 -0
  42. xinference/thirdparty/cosyvoice/hifigan/generator.py +58 -42
  43. xinference/thirdparty/cosyvoice/hifigan/hifigan.py +67 -0
  44. xinference/thirdparty/cosyvoice/llm/llm.py +139 -6
  45. xinference/thirdparty/cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
  46. xinference/thirdparty/cosyvoice/tokenizer/tokenizer.py +279 -0
  47. xinference/thirdparty/cosyvoice/transformer/embedding.py +2 -2
  48. xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +7 -7
  49. xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +318 -0
  50. xinference/thirdparty/cosyvoice/utils/common.py +28 -1
  51. xinference/thirdparty/cosyvoice/utils/executor.py +69 -7
  52. xinference/thirdparty/cosyvoice/utils/file_utils.py +2 -12
  53. xinference/thirdparty/cosyvoice/utils/frontend_utils.py +9 -5
  54. xinference/thirdparty/cosyvoice/utils/losses.py +20 -0
  55. xinference/thirdparty/cosyvoice/utils/scheduler.py +1 -2
  56. xinference/thirdparty/cosyvoice/utils/train_utils.py +101 -45
  57. xinference/thirdparty/fish_speech/fish_speech/conversation.py +94 -83
  58. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +63 -20
  59. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +1 -26
  60. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +1 -1
  61. xinference/thirdparty/fish_speech/fish_speech/tokenizer.py +152 -0
  62. xinference/thirdparty/fish_speech/fish_speech/train.py +2 -2
  63. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1 -1
  64. xinference/thirdparty/fish_speech/tools/{post_api.py → api_client.py} +7 -13
  65. xinference/thirdparty/fish_speech/tools/api_server.py +98 -0
  66. xinference/thirdparty/fish_speech/tools/download_models.py +5 -5
  67. xinference/thirdparty/fish_speech/tools/fish_e2e.py +2 -2
  68. xinference/thirdparty/fish_speech/tools/inference_engine/__init__.py +192 -0
  69. xinference/thirdparty/fish_speech/tools/inference_engine/reference_loader.py +125 -0
  70. xinference/thirdparty/fish_speech/tools/inference_engine/utils.py +39 -0
  71. xinference/thirdparty/fish_speech/tools/inference_engine/vq_manager.py +57 -0
  72. xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +2 -2
  73. xinference/thirdparty/fish_speech/tools/llama/generate.py +117 -89
  74. xinference/thirdparty/fish_speech/tools/run_webui.py +104 -0
  75. xinference/thirdparty/fish_speech/tools/schema.py +11 -28
  76. xinference/thirdparty/fish_speech/tools/server/agent/__init__.py +57 -0
  77. xinference/thirdparty/fish_speech/tools/server/agent/generate.py +119 -0
  78. xinference/thirdparty/fish_speech/tools/server/agent/generation_utils.py +122 -0
  79. xinference/thirdparty/fish_speech/tools/server/agent/pre_generation_utils.py +72 -0
  80. xinference/thirdparty/fish_speech/tools/server/api_utils.py +75 -0
  81. xinference/thirdparty/fish_speech/tools/server/exception_handler.py +27 -0
  82. xinference/thirdparty/fish_speech/tools/server/inference.py +45 -0
  83. xinference/thirdparty/fish_speech/tools/server/model_manager.py +122 -0
  84. xinference/thirdparty/fish_speech/tools/server/model_utils.py +129 -0
  85. xinference/thirdparty/fish_speech/tools/server/views.py +246 -0
  86. xinference/thirdparty/fish_speech/tools/webui/__init__.py +173 -0
  87. xinference/thirdparty/fish_speech/tools/webui/inference.py +91 -0
  88. xinference/thirdparty/fish_speech/tools/webui/variables.py +14 -0
  89. xinference/thirdparty/matcha/utils/utils.py +2 -2
  90. {xinference-1.1.0.dist-info → xinference-1.1.1.dist-info}/METADATA +11 -6
  91. {xinference-1.1.0.dist-info → xinference-1.1.1.dist-info}/RECORD +95 -74
  92. xinference/thirdparty/cosyvoice/bin/__init__.py +0 -0
  93. xinference/thirdparty/cosyvoice/bin/export_trt.py +0 -8
  94. xinference/thirdparty/cosyvoice/flow/__init__.py +0 -0
  95. xinference/thirdparty/cosyvoice/hifigan/__init__.py +0 -0
  96. xinference/thirdparty/cosyvoice/llm/__init__.py +0 -0
  97. xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
  98. xinference/thirdparty/fish_speech/tools/api.py +0 -943
  99. xinference/thirdparty/fish_speech/tools/msgpack_api.py +0 -95
  100. xinference/thirdparty/fish_speech/tools/webui.py +0 -548
  101. {xinference-1.1.0.dist-info → xinference-1.1.1.dist-info}/LICENSE +0 -0
  102. {xinference-1.1.0.dist-info → xinference-1.1.1.dist-info}/WHEEL +0 -0
  103. {xinference-1.1.0.dist-info → xinference-1.1.1.dist-info}/entry_points.txt +0 -0
  104. {xinference-1.1.0.dist-info → xinference-1.1.1.dist-info}/top_level.txt +0 -0

xinference/thirdparty/fish_speech/tools/llama/generate.py

@@ -17,9 +17,16 @@ from loguru import logger
 from tqdm import tqdm
 from transformers import AutoTokenizer

-from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
+from fish_speech.conversation import (
+    CODEBOOK_PAD_TOKEN_ID,
+    Conversation,
+    Message,
+    TextPart,
+    VQPart,
+)
 from fish_speech.models.text2semantic.llama import BaseModelArgs
 from fish_speech.text import clean_text, split_text
+from fish_speech.tokenizer import IM_END_TOKEN, FishTokenizer

 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 torch._inductor.config.coordinate_descent_tuning = True
@@ -145,8 +152,8 @@ def decode_one_token_ar_agent(
     model: DualARTransformer,
     x: torch.Tensor,
     input_pos: torch.Tensor,
+    semantic_ids: list,
     previous_tokens: torch.Tensor = None,
-    semantic_id: int = 32003,
     **sampling_kwargs,
 ) -> torch.Tensor:
     # print(x, input_pos)
@@ -190,19 +197,13 @@ def decode_one_token_ar_agent(
         codebooks.append(a)

     codebooks = torch.stack(codebooks, dim=1)
+    semantic_ids_tensor = torch.tensor(semantic_ids, device=codebooks.device)
     codebooks[:, 1:, :] = torch.masked_fill(
-        codebooks[:, 1:, :], codebooks[:, :1, :] != semantic_id, CODEBOOK_PAD_TOKEN_ID
+        codebooks[:, 1:, :],
+        ~torch.isin(codebooks[:, :1, :], semantic_ids_tensor),
+        CODEBOOK_PAD_TOKEN_ID,
     )

-    # for i in range(codebooks.size(1) - 1):
-    #     codebooks[:, i + 1, :] = torch.masked_fill(
-    #         codebooks[:, i + 1, :],
-    #         codebooks[:, :1, :] != semantic_id,
-    #         CODEBOOK_PAD_TOKEN_ID + i * 1024,
-    #     )
-
-    # print(codebooks)
-
     return codebooks

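The functional change in this file is that codebook masking no longer compares against a single hard-coded semantic_id but against the full set of <|semantic:i|> token ids via torch.isin. Below is a minimal, self-contained sketch of that masking pattern; it is illustrative only, with invented ids and pad value (the real values come from the Fish Speech tokenizer).

import torch

# Invented values for illustration; in the diff these come from the tokenizer.
CODEBOOK_PAD_TOKEN_ID = 0
semantic_ids = [101, 102, 103]  # stand-ins for the <|semantic:i|> token ids

# codebooks: (batch, 1 + num_codebooks, seq_len); row 0 is the text-token stream.
codebooks = torch.tensor([[[101, 7, 103],
                           [11, 12, 13],
                           [21, 22, 23]]])

semantic_ids_tensor = torch.tensor(semantic_ids)
# Wherever the first row is NOT a semantic token, pad every lower codebook,
# mirroring the masked_fill call in the hunk above.
mask = ~torch.isin(codebooks[:, :1, :], semantic_ids_tensor)
codebooks[:, 1:, :] = torch.masked_fill(codebooks[:, 1:, :], mask, CODEBOOK_PAD_TOKEN_ID)

print(codebooks)
# tensor([[[101,   7, 103],
#          [ 11,   0,  13],
#          [ 21,   0,  23]]])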
@@ -210,8 +211,8 @@ def decode_one_token_naive_agent(
     model: NaiveTransformer,
     x: torch.Tensor,
     input_pos: torch.Tensor,
+    semantic_ids: list,
     previous_tokens: torch.Tensor = None,
-    semantic_id: int = 32003,
     **sampling_kwargs,
 ) -> torch.Tensor:
     x = model.forward_generate(x, input_pos)
@@ -236,8 +237,11 @@ def decode_one_token_naive_agent(
         )

     codebooks = torch.stack(codebooks, dim=1)
+    semantic_ids_tensor = torch.tensor(semantic_ids, device=codebooks.device)
     codebooks[:, 1:, :] = torch.masked_fill(
-        codebooks[:, 1:, :], codebooks[:, :1, :] != semantic_id, CODEBOOK_PAD_TOKEN_ID
+        codebooks[:, 1:, :],
+        ~torch.isin(codebooks[:, :1, :], semantic_ids_tensor),
+        CODEBOOK_PAD_TOKEN_ID,
     )

     return codebooks
@@ -247,8 +251,8 @@ def decode_one_token_ar(
     model: DualARTransformer,
     x: torch.Tensor,
     input_pos: torch.Tensor,
+    semantic_ids: list,
     previous_tokens: torch.Tensor = None,
-    semantic_id: int = 0,
     **sampling_kwargs,
 ) -> torch.Tensor:
     x = model.forward_generate(x, input_pos)
@@ -261,21 +265,32 @@ def decode_one_token_ar(
     codebooks = [
         sample(
             x.logits,
-            previous_tokens=None,  # Disable repetition penalty for the token codebook
+            previous_tokens=(
+                previous_tokens[0] if previous_tokens is not None else None
+            ),  # Disable repetition penalty for the token codebook
             **sampling_kwargs_main,
         )[0]
     ]

-    x = x.hidden_states
+    hidden_states = x.hidden_states

     # Cleanup the cache
     for layer in model.fast_layers:
         layer.attention.kv_cache.k_cache.fill_(0)
         layer.attention.kv_cache.v_cache.fill_(0)

-    for codebook_idx in range(model.config.num_codebooks):
-        input_pos = torch.tensor([codebook_idx], device=x.device, dtype=torch.long)
-        logits = model.forward_generate_fast(x, input_pos)
+    input_pos = torch.tensor([0], device=hidden_states.device, dtype=torch.long)
+    model.forward_generate_fast(hidden_states, input_pos)
+    a = codebooks[0] - model.tokenizer.semantic_begin_id
+    a[a < 0] = 0
+    hidden_states = model.fast_embeddings(a)
+    codebooks.append(a)
+
+    for codebook_idx in range(1, model.config.num_codebooks):
+        input_pos = torch.tensor(
+            [codebook_idx], device=hidden_states.device, dtype=torch.long
+        )
+        logits = model.forward_generate_fast(hidden_states, input_pos)
         a = sample(
             logits,
             previous_tokens=(
@@ -285,14 +300,16 @@ def decode_one_token_ar(
             ),
             **sampling_kwargs,
         )[0]
-        x = model.fast_embeddings(a)
+        hidden_states = model.fast_embeddings(a)
         codebooks.append(a)

     codebooks = torch.stack(codebooks, dim=0)
-    codebooks[1:, :] = torch.masked_fill(
-        codebooks[1:, :], codebooks[:1, :] != semantic_id, CODEBOOK_PAD_TOKEN_ID
-    )
+    # semantic_ids_tensor = torch.tensor(semantic_ids, device=codebooks.device)
+    # codebooks[1:, :] = torch.masked_fill(
+    #     codebooks[1:, :], ~torch.isin(codebooks[:1, :], semantic_ids_tensor), CODEBOOK_PAD_TOKEN_ID
+    # )

+    # print(codebooks)
     return codebooks

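In the rewritten fast-codebook loop above, the first codebook is now derived directly from the sampled text token by subtracting model.tokenizer.semantic_begin_id and clamping negatives to zero. A toy illustration of that offset arithmetic follows; the concrete ids are invented, only the pattern matches the diff.

import torch

# Hypothetical id of <|semantic:0|> in the text vocabulary; stands in for
# model.tokenizer.semantic_begin_id from the diff.
semantic_begin_id = 100_000

sampled = torch.tensor([100_005, 100_123, 7])  # last entry: some non-semantic token

codes = sampled - semantic_begin_id
codes[codes < 0] = 0  # non-semantic tokens clamp to codebook index 0, as in the diff

print(codes)  # tensor([  5, 123,   0])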
@@ -337,9 +354,8 @@ def decode_n_tokens(
     cur_token: torch.Tensor,
     input_pos: torch.Tensor,
     num_new_tokens: int,
-    im_end_id: int = 4,
+    semantic_ids: list,
     decode_one_token=decode_one_token_naive,
-    semantic_id: int = 0,
     **sampling_kwargs,
 ):
     previous_tokens = torch.zeros(
@@ -368,7 +384,7 @@ def decode_n_tokens(
             x=cur_token,
             input_pos=input_pos,
             previous_tokens=window,
-            semantic_id=semantic_id,
+            semantic_ids=semantic_ids,
             **sampling_kwargs,
         )

@@ -378,7 +394,7 @@ def decode_n_tokens(
             model.config.num_codebooks + 1, -1
         )

-        if cur_token[0, 0, -1] == im_end_id:
+        if cur_token[0, 0, -1] == model.tokenizer.get_token_id(IM_END_TOKEN):
            break

     return previous_tokens[:, : i + 1]
@@ -391,7 +407,6 @@ def generate(
     model: NaiveTransformer,
     prompt: torch.Tensor,
     max_new_tokens: int,
-    im_end_id: int = 4,
     decode_one_token=decode_one_token_naive,
     **sampling_kwargs,
 ) -> torch.Tensor:
@@ -401,7 +416,10 @@ def generate(

     # create an empty tensor of the expected final shape and fill in the current tokens
     T = prompt.size(1)
-    semantic_id = model.tokenizer.convert_tokens_to_ids("<|semantic|>")
+    # semantic_id = model.tokenizer.convert_tokens_to_ids("<|semantic|>")
+    semantic_ids = [
+        model.tokenizer.get_token_id(f"<|semantic:{i}|>") for i in range(1024)
+    ]

     if max_new_tokens:
         if T + max_new_tokens > model.config.max_seq_len:
@@ -435,7 +453,7 @@ def generate(
         model,
         prompt.view(1, codebook_dim, -1),
         input_pos,
-        semantic_id=semantic_id,
+        semantic_ids=semantic_ids,
         **sampling_kwargs,
     )
     seq[:, T : T + 1] = next_token
@@ -446,9 +464,8 @@ def generate(
         next_token.view(1, codebook_dim, -1),
         input_pos,
         max_new_tokens - 1,
-        im_end_id=im_end_id,
         decode_one_token=decode_one_token,
-        semantic_id=semantic_id,
+        semantic_ids=semantic_ids,
         **sampling_kwargs,
     )
     # x = torch.cat(generated_tokens, dim=1)
@@ -463,8 +480,8 @@ def decode_n_tokens_agent(
     cur_token: torch.Tensor,
     input_pos: torch.Tensor,
     num_new_tokens: int,
+    semantic_ids: list,
     im_end_id: int = 4,
-    semantic_id: int = 32003,
     decode_one_token=decode_one_token_naive_agent,
     early_stop_threshold: float = 0.6,
     **sampling_kwargs,
@@ -495,7 +512,7 @@ def decode_n_tokens_agent(
             x=cur_token,
             input_pos=input_pos,
             previous_tokens=window,
-            semantic_id=semantic_id,
+            semantic_ids=semantic_ids,
             **sampling_kwargs,
         )

@@ -529,8 +546,8 @@ def generate_agent(
     model: BaseTransformer,
     prompt: torch.Tensor,
     max_new_tokens: int,
+    semantic_ids: list,
     im_end_id: int = 4,
-    semantic_id: int = 32003,
     decode_one_token=decode_one_token_naive_agent,
     num_samples: int = 1,
     early_stop_threshold: float = 0.6,
@@ -574,7 +591,7 @@ def generate_agent(
         model,
         prompt,
         input_pos,
-        semantic_id=semantic_id,
+        semantic_ids=semantic_ids,
         **sampling_kwargs,
     ).view(num_samples, codebook_dim, -1)
     yield next_token.cpu()
@@ -587,7 +604,7 @@ def generate_agent(
         input_pos,
         max_new_tokens - 1,
         im_end_id=im_end_id,
-        semantic_id=semantic_id,
+        semantic_ids=semantic_ids,
         decode_one_token=decode_one_token,
         early_stop_threshold=early_stop_threshold,
         **sampling_kwargs,
@@ -602,65 +619,63 @@ def encode_tokens(
     num_codebooks=4,
 ):
     string = clean_text(string)
-    string = f"<|im_start|>user\n{string}<|im_end|><|im_start|>assistant\n"

-    new_tokens = tokenizer.encode(
-        string,
-        add_special_tokens=False,
-        max_length=10**6,
-        truncation=False,
+    messages = []
+    messages.append(
+        Message(
+            role="user",
+            parts=[TextPart(text=string)],
+            cal_loss=False,
+        )
     )
-    tokens = torch.tensor([new_tokens], dtype=torch.int, device=device)

-    # Codebooks
-    zeros = (
-        torch.ones((num_codebooks, tokens.size(1)), dtype=torch.int, device=device)
-        * CODEBOOK_PAD_TOKEN_ID
-    )
-    prompt = torch.cat((tokens, zeros), dim=0)
+    if prompt_tokens is not None:
+        if prompt_tokens.ndim == 3:
+            assert (
+                prompt_tokens.shape[0] == 1
+            ), "3D prompt tokens should have shape (1, num_codebooks, seq_len)"
+            prompt_tokens = prompt_tokens[0]

-    if prompt_tokens is None:
-        return prompt
+        assert prompt_tokens.ndim == 2, "Prompt tokens should be 2D tensor"

-    # Get prompt tokens
-    if prompt_tokens.ndim == 3:
-        assert (
-            prompt_tokens.shape[0] == 1
-        ), f"3 dim prompt tokens should have shape (1, num_codebooks, seq_len)"
-        prompt_tokens = prompt_tokens[0]
+        if prompt_tokens.shape[0] > num_codebooks:
+            logger.warning(
+                f"Prompt tokens shape {prompt_tokens.shape} is larger than num_codebooks {num_codebooks}, getting first {num_codebooks} codebooks"
+            )
+            prompt_tokens = prompt_tokens[:num_codebooks]

-    assert prompt_tokens.ndim == 2
-    data = prompt_tokens + 1
+        vq_part = VQPart(codes=prompt_tokens.to(device))

-    if prompt_tokens.shape[0] > num_codebooks:
-        logger.warning(
-            f"Prompt tokens shape {prompt_tokens.shape} is larger than num_codebooks {num_codebooks}, getting first {num_codebooks} codebooks"
+        messages.append(
+            Message(
+                role="assistant",
+                parts=[TextPart(text="<|voice|>"), vq_part],
+                cal_loss=False,
+            )
+        )
+    else:
+        messages.append(
+            Message(
+                role="assistant",
+                parts=[TextPart(text="<|voice|>")],
+                cal_loss=False,
+                add_im_end=False,
+            )
         )
-        data = data[:num_codebooks]
-
-    # Add pad token for each codebook
-    data = torch.cat(
-        (data, torch.zeros((data.size(0), 1), dtype=torch.int, device=device)),
-        dim=1,
-    )

-    # Since 1.0, we use <|semantic|>
-    s0_token_id = tokenizer.convert_tokens_to_ids("<|semantic|>")
-    end_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
-    main_token_ids = (
-        torch.ones((1, data.size(1)), dtype=torch.int, device=device) * s0_token_id
+    conversation = Conversation(messages=messages)
+    # conversation.visualize(tokenizer)
+    encoded = conversation.encode_for_inference(
+        tokenizer=tokenizer,
+        num_codebooks=num_codebooks,
     )
-    main_token_ids[0, -1] = end_token_id
-
-    data = torch.cat((main_token_ids, data), dim=0)
-    prompt = torch.cat((prompt, data), dim=1)

-    return prompt
+    return encoded.to(device)


 def load_model(checkpoint_path, device, precision, compile=False, is_agent=False):
     model: Union[NaiveTransformer, DualARTransformer] = BaseTransformer.from_pretrained(
-        checkpoint_path, load_weights=True
+        checkpoint_path, load_weights=True, is_agent=is_agent
     )

     model = model.to(device=device, dtype=precision)
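encode_tokens() no longer builds the prompt tensor by hand (token ids plus a padded codebook block); it now assembles Message/TextPart/VQPart objects and lets Conversation.encode_for_inference() produce the prompt. The following is a condensed sketch of that flow, assuming a loaded FishTokenizer-compatible tokenizer; it only rearranges calls that appear verbatim in the hunk above, and build_prompt itself is a hypothetical helper, not part of the package.

from typing import Optional

import torch

from fish_speech.conversation import Conversation, Message, TextPart, VQPart


def build_prompt(
    tokenizer,
    text: str,
    codes: Optional[torch.Tensor] = None,
    num_codebooks: int = 4,
    device: str = "cpu",
) -> torch.Tensor:
    # User turn carries the text; an optional assistant turn carries reference
    # VQ codes behind a <|voice|> marker, exactly as in the new encode_tokens().
    messages = [Message(role="user", parts=[TextPart(text=text)], cal_loss=False)]

    if codes is not None:
        messages.append(
            Message(
                role="assistant",
                parts=[TextPart(text="<|voice|>"), VQPart(codes=codes.to(device))],
                cal_loss=False,
            )
        )
    else:
        # No reference audio: open the assistant turn but do not close it with <|im_end|>.
        messages.append(
            Message(
                role="assistant",
                parts=[TextPart(text="<|voice|>")],
                cal_loss=False,
                add_im_end=False,
            )
        )

    encoded = Conversation(messages=messages).encode_for_inference(
        tokenizer=tokenizer, num_codebooks=num_codebooks
    )
    return encoded.to(device)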
@@ -729,11 +744,26 @@ def generate_long(

    model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
    tokenizer = model.tokenizer
-    im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
+    im_end_id = tokenizer.get_token_id("<|im_end|>")

    encoded = []
    texts = split_text(text, chunk_length) if iterative_prompt else [text]
-    encoded_prompts = []
+    encoded_prompts = [
+        Conversation(
+            messages=[
+                Message(
+                    role="system",
+                    parts=[TextPart(text="Speak out the provided text.")],
+                    cal_loss=False,
+                )
+            ]
+        )
+        .encode_for_inference(
+            tokenizer=tokenizer,
+            num_codebooks=model.config.num_codebooks,
+        )
+        .to(device)
+    ]

    if use_prompt:
        for idx, (t, c) in enumerate(zip(prompt_text, prompt_tokens)):
@@ -812,7 +842,6 @@ def generate_long(
            model=model,
            prompt=cat_encoded,
            max_new_tokens=max_new_tokens,
-            im_end_id=im_end_id,
            decode_one_token=decode_one_token,
            temperature=temperature,
            top_p=top_p,
@@ -842,12 +871,11 @@ def generate_long(
        )

        # Put the generated tokens
-        # since there is <im_end> and <eos> tokens, we remove last 2 tokens
-        codes = y[1:, prompt_length:-1].clone()
-        codes = codes - 1
+        # since there is <im_end>, we remove last token
+        codes = y[1:, prompt_length + 1 :].clone()
        assert (codes >= 0).all(), f"Negative code found"

-        decoded = y[:, prompt_length:-1].clone()
+        decoded = y[:, prompt_length:].clone()
        # But for global encoding, we should keep the <im_end> token


        global_encoded.append(decoded)
xinference/thirdparty/fish_speech/tools/run_webui.py (new file)

@@ -0,0 +1,104 @@
+import os
+from argparse import ArgumentParser
+from pathlib import Path
+
+import pyrootutils
+import torch
+from loguru import logger
+
+pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+
+from tools.inference_engine import TTSInferenceEngine
+from tools.llama.generate import launch_thread_safe_queue
+from tools.schema import ServeTTSRequest
+from tools.vqgan.inference import load_model as load_decoder_model
+from tools.webui import build_app
+from tools.webui.inference import get_inference_wrapper
+
+# Make einx happy
+os.environ["EINX_FILTER_TRACEBACK"] = "false"
+
+
+def parse_args():
+    parser = ArgumentParser()
+    parser.add_argument(
+        "--llama-checkpoint-path",
+        type=Path,
+        default="checkpoints/fish-speech-1.5",
+    )
+    parser.add_argument(
+        "--decoder-checkpoint-path",
+        type=Path,
+        default="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
+    )
+    parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
+    parser.add_argument("--device", type=str, default="cuda")
+    parser.add_argument("--half", action="store_true")
+    parser.add_argument("--compile", action="store_true")
+    parser.add_argument("--max-gradio-length", type=int, default=0)
+    parser.add_argument("--theme", type=str, default="light")
+
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    args.precision = torch.half if args.half else torch.bfloat16
+
+    # Check if MPS or CUDA is available
+    if torch.backends.mps.is_available():
+        args.device = "mps"
+        logger.info("mps is available, running on mps.")
+    elif not torch.cuda.is_available():
+        logger.info("CUDA is not available, running on CPU.")
+        args.device = "cpu"
+
+    logger.info("Loading Llama model...")
+    llama_queue = launch_thread_safe_queue(
+        checkpoint_path=args.llama_checkpoint_path,
+        device=args.device,
+        precision=args.precision,
+        compile=args.compile,
+    )
+
+    logger.info("Loading VQ-GAN model...")
+    decoder_model = load_decoder_model(
+        config_name=args.decoder_config_name,
+        checkpoint_path=args.decoder_checkpoint_path,
+        device=args.device,
+    )
+
+    logger.info("Decoder model loaded, warming up...")
+
+    # Create the inference engine
+    inference_engine = TTSInferenceEngine(
+        llama_queue=llama_queue,
+        decoder_model=decoder_model,
+        compile=args.compile,
+        precision=args.precision,
+    )
+
+    # Dry run to check if the model is loaded correctly and avoid the first-time latency
+    list(
+        inference_engine.inference(
+            ServeTTSRequest(
+                text="Hello world.",
+                references=[],
+                reference_id=None,
+                max_new_tokens=1024,
+                chunk_length=200,
+                top_p=0.7,
+                repetition_penalty=1.5,
+                temperature=0.7,
+                format="wav",
+            )
+        )
+    )
+
+    logger.info("Warming up done, launching the web UI...")
+
+    # Get the inference function with the immutable arguments
+    inference_fct = get_inference_wrapper(inference_engine)
+
+    app = build_app(inference_fct, args.theme)
+    app.launch(show_api=True)
xinference/thirdparty/fish_speech/tools/schema.py

@@ -1,16 +1,14 @@
 import os
 import queue
 from dataclasses import dataclass
-from typing import Annotated, Literal, Optional
+from typing import Annotated, Literal

 import torch
-from pydantic import AfterValidator, BaseModel, Field, confloat, conint, conlist
+from pydantic import BaseModel, Field, conint, conlist
 from pydantic.functional_validators import SkipValidation

 from fish_speech.conversation import Message, TextPart, VQPart

-GLOBAL_NUM_SAMPLES = int(os.getenv("GLOBAL_NUM_SAMPLES", 1))
-

 class ServeVQPart(BaseModel):
     type: Literal["vq"] = "vq"
@@ -69,6 +67,9 @@ class ServeMessage(BaseModel):

     def to_conversation_message(self):
         new_message = Message(role=self.role, parts=[])
+        if self.role == "assistant":
+            new_message.modality = "voice"
+
         for part in self.parts:
             if isinstance(part, ServeTextPart):
                 new_message.parts.append(TextPart(text=part.text))
@@ -82,7 +83,7 @@ class ServeMessage(BaseModel):
         return new_message


-class ServeRequest(BaseModel):
+class ServeChatRequest(BaseModel):
     messages: Annotated[list[ServeMessage], conlist(ServeMessage, min_length=1)]
     max_new_tokens: int = 1024
     top_p: float = 0.7
@@ -111,11 +112,6 @@ class ServeVQGANDecodeResponse(BaseModel):
     audios: list[bytes]


-class ServeReferenceAudio(BaseModel):
-    audio: bytes
-    text: str
-
-
 class ServeForwardMessage(BaseModel):
     role: str
     content: str
@@ -147,24 +143,11 @@ class ServeReferenceAudio(BaseModel):
         return f"ServeReferenceAudio(text={self.text!r}, audio_size={len(self.audio)})"


-class ServeChatRequestV1(BaseModel):
-    model: str = "llama3-8b"
-    messages: list[ServeForwardMessage] = []
-    audio: bytes | None = None
-    temperature: float = 1.0
-    top_p: float = 1.0
-    max_tokens: int = 256
-    voice: str = "jessica"
-    tts_audio_format: Literal["mp3", "pcm", "opus"] = "mp3"
-    tts_audio_bitrate: Literal[16, 24, 32, 48, 64, 96, 128, 192] = 128
-
-
 class ServeTTSRequest(BaseModel):
     text: str
     chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
     # Audio format
     format: Literal["wav", "pcm", "mp3"] = "wav"
-    mp3_bitrate: Literal[64, 128, 192] = 128
     # References audios for in-context learning
     references: list[ServeReferenceAudio] = []
     # Reference id
@@ -172,16 +155,16 @@ class ServeTTSRequest(BaseModel):
     # Just pass 7f92f8afb8ec43bf81429cc1c9199cb1
     reference_id: str | None = None
     seed: int | None = None
-    use_memory_cache: Literal["on-demand", "never"] = "never"
+    use_memory_cache: Literal["on", "off"] = "off"
     # Normalize text for en & zh, this increase stability for numbers
     normalize: bool = True
-    mp3_bitrate: Optional[int] = 64
-    opus_bitrate: Optional[int] = -1000
-    # Balance mode will reduce latency to 300ms, but may decrease stability
-    latency: Literal["normal", "balanced"] = "normal"
     # not usually used below
     streaming: bool = False
     max_new_tokens: int = 1024
     top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
     repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
     temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
+
+    class Config:
+        # Allow arbitrary types for pytorch related types
+        arbitrary_types_allowed = True
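For reference, a request against the trimmed-down schema above might be constructed as follows. The values are arbitrary; the field set is taken from the class as it appears in the hunk, with mp3_bitrate, opus_bitrate and latency removed and use_memory_cache now accepting "on"/"off".

from tools.schema import ServeTTSRequest

req = ServeTTSRequest(
    text="Hello from xinference 1.1.1.",
    chunk_length=200,
    format="wav",
    references=[],           # optional in-context reference audios
    reference_id=None,
    seed=None,
    use_memory_cache="off",  # was "on-demand" / "never" in 1.1.0
    normalize=True,
    streaming=False,
    max_new_tokens=1024,
    top_p=0.7,
    repetition_penalty=1.2,
    temperature=0.7,
)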
xinference/thirdparty/fish_speech/tools/server/agent/__init__.py (new file)

@@ -0,0 +1,57 @@
+import struct
+from functools import partial
+
+import ormsgpack
+
+from tools.server.agent.generate import generate_responses
+from tools.server.agent.pre_generation_utils import prepare_messages
+
+
+def execute_request(input_queue, tokenizer, config, request, device):
+    """
+    This function prepares the conversation, encodes the request,
+    sends the generation request, and handles decoding/streaming.
+    It returns a response generator (ServeResponse or ServeStreamResponse).
+    """
+    prompt, im_end_id = prepare_messages(request, tokenizer, config)
+    yield from generate_responses(
+        input_queue, tokenizer, config, request, prompt, im_end_id, device
+    )
+
+
+def response_generator(req, llama_queue, tokenizer, config, device):
+    """
+    Non-streaming response wrapper for the chat endpoint.
+    Only returns the final result.
+    """
+    generator = execute_request(llama_queue, tokenizer, config, req, device)
+    return next(generator)
+
+
+async def streaming_generator(req, llama_queue, tokenizer, config, device, json_mode):
+    """
+    Streaming response wrapper for the chat endpoint.
+    Returns the response in chunks.
+    """
+    generator = execute_request(llama_queue, tokenizer, config, req, device)
+    for i in generator:
+        if json_mode:
+            body = i.model_dump_json().encode("utf-8")
+            yield b"data: " + body + b"\n\n"
+        else:
+            body = ormsgpack.packb(i, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)
+            yield struct.pack("I", len(body)) + body
+
+
+def get_response_generator(
+    llama_queue, tokenizer, config, req, device, json_mode
+) -> partial:
+    """
+    Get the correct response generator based on the request.
+    """
+    if not req.streaming:
+        return partial(response_generator, req, llama_queue, tokenizer, config, device)
+    else:
+        return partial(
+            streaming_generator, req, llama_queue, tokenizer, config, device, json_mode
+        )
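The msgpack branch of streaming_generator() frames every chunk with a 4-byte native-endian length prefix before the packed body. A hypothetical client-side helper could reassemble those frames as sketched below; it is not part of the package and assumes the consumer iterates over raw byte chunks and has ormsgpack installed.

import struct

import ormsgpack


def iter_msgpack_frames(stream):
    # `stream` is any iterable of bytes chunks, e.g. an HTTP response body iterator.
    buffer = b""
    for chunk in stream:
        buffer += chunk
        while len(buffer) >= 4:
            (length,) = struct.unpack("I", buffer[:4])  # same "I" format as the server side
            if len(buffer) < 4 + length:
                break  # incomplete frame, wait for more data
            frame, buffer = buffer[4 : 4 + length], buffer[4 + length :]
            yield ormsgpack.unpackb(frame)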