xinference 1.0.0__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (94)
  1. xinference/_compat.py +22 -2
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +91 -6
  4. xinference/client/restful/restful_client.py +39 -0
  5. xinference/core/model.py +41 -13
  6. xinference/deploy/cmdline.py +3 -1
  7. xinference/deploy/test/test_cmdline.py +56 -0
  8. xinference/isolation.py +24 -0
  9. xinference/model/audio/__init__.py +12 -0
  10. xinference/model/audio/core.py +26 -4
  11. xinference/model/audio/f5tts.py +195 -0
  12. xinference/model/audio/fish_speech.py +71 -35
  13. xinference/model/audio/model_spec.json +88 -0
  14. xinference/model/audio/model_spec_modelscope.json +9 -0
  15. xinference/model/audio/whisper_mlx.py +208 -0
  16. xinference/model/embedding/core.py +322 -6
  17. xinference/model/embedding/model_spec.json +8 -1
  18. xinference/model/embedding/model_spec_modelscope.json +9 -1
  19. xinference/model/llm/__init__.py +4 -2
  20. xinference/model/llm/llm_family.json +479 -53
  21. xinference/model/llm/llm_family_modelscope.json +423 -17
  22. xinference/model/llm/mlx/core.py +230 -50
  23. xinference/model/llm/sglang/core.py +2 -0
  24. xinference/model/llm/transformers/chatglm.py +9 -5
  25. xinference/model/llm/transformers/core.py +1 -0
  26. xinference/model/llm/transformers/glm_edge_v.py +230 -0
  27. xinference/model/llm/transformers/utils.py +16 -8
  28. xinference/model/llm/utils.py +23 -1
  29. xinference/model/llm/vllm/core.py +89 -2
  30. xinference/thirdparty/f5_tts/__init__.py +0 -0
  31. xinference/thirdparty/f5_tts/api.py +166 -0
  32. xinference/thirdparty/f5_tts/configs/E2TTS_Base_train.yaml +44 -0
  33. xinference/thirdparty/f5_tts/configs/E2TTS_Small_train.yaml +44 -0
  34. xinference/thirdparty/f5_tts/configs/F5TTS_Base_train.yaml +46 -0
  35. xinference/thirdparty/f5_tts/configs/F5TTS_Small_train.yaml +46 -0
  36. xinference/thirdparty/f5_tts/eval/README.md +49 -0
  37. xinference/thirdparty/f5_tts/eval/ecapa_tdnn.py +330 -0
  38. xinference/thirdparty/f5_tts/eval/eval_infer_batch.py +207 -0
  39. xinference/thirdparty/f5_tts/eval/eval_infer_batch.sh +13 -0
  40. xinference/thirdparty/f5_tts/eval/eval_librispeech_test_clean.py +84 -0
  41. xinference/thirdparty/f5_tts/eval/eval_seedtts_testset.py +84 -0
  42. xinference/thirdparty/f5_tts/eval/utils_eval.py +405 -0
  43. xinference/thirdparty/f5_tts/infer/README.md +191 -0
  44. xinference/thirdparty/f5_tts/infer/SHARED.md +74 -0
  45. xinference/thirdparty/f5_tts/infer/examples/basic/basic.toml +11 -0
  46. xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_en.wav +0 -0
  47. xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_zh.wav +0 -0
  48. xinference/thirdparty/f5_tts/infer/examples/multi/country.flac +0 -0
  49. xinference/thirdparty/f5_tts/infer/examples/multi/main.flac +0 -0
  50. xinference/thirdparty/f5_tts/infer/examples/multi/story.toml +19 -0
  51. xinference/thirdparty/f5_tts/infer/examples/multi/story.txt +1 -0
  52. xinference/thirdparty/f5_tts/infer/examples/multi/town.flac +0 -0
  53. xinference/thirdparty/f5_tts/infer/examples/vocab.txt +2545 -0
  54. xinference/thirdparty/f5_tts/infer/infer_cli.py +226 -0
  55. xinference/thirdparty/f5_tts/infer/infer_gradio.py +851 -0
  56. xinference/thirdparty/f5_tts/infer/speech_edit.py +193 -0
  57. xinference/thirdparty/f5_tts/infer/utils_infer.py +538 -0
  58. xinference/thirdparty/f5_tts/model/__init__.py +10 -0
  59. xinference/thirdparty/f5_tts/model/backbones/README.md +20 -0
  60. xinference/thirdparty/f5_tts/model/backbones/dit.py +163 -0
  61. xinference/thirdparty/f5_tts/model/backbones/mmdit.py +146 -0
  62. xinference/thirdparty/f5_tts/model/backbones/unett.py +219 -0
  63. xinference/thirdparty/f5_tts/model/cfm.py +285 -0
  64. xinference/thirdparty/f5_tts/model/dataset.py +319 -0
  65. xinference/thirdparty/f5_tts/model/modules.py +658 -0
  66. xinference/thirdparty/f5_tts/model/trainer.py +366 -0
  67. xinference/thirdparty/f5_tts/model/utils.py +185 -0
  68. xinference/thirdparty/f5_tts/scripts/count_max_epoch.py +33 -0
  69. xinference/thirdparty/f5_tts/scripts/count_params_gflops.py +39 -0
  70. xinference/thirdparty/f5_tts/socket_server.py +159 -0
  71. xinference/thirdparty/f5_tts/train/README.md +77 -0
  72. xinference/thirdparty/f5_tts/train/datasets/prepare_csv_wavs.py +139 -0
  73. xinference/thirdparty/f5_tts/train/datasets/prepare_emilia.py +230 -0
  74. xinference/thirdparty/f5_tts/train/datasets/prepare_libritts.py +92 -0
  75. xinference/thirdparty/f5_tts/train/datasets/prepare_ljspeech.py +65 -0
  76. xinference/thirdparty/f5_tts/train/datasets/prepare_wenetspeech4tts.py +125 -0
  77. xinference/thirdparty/f5_tts/train/finetune_cli.py +174 -0
  78. xinference/thirdparty/f5_tts/train/finetune_gradio.py +1846 -0
  79. xinference/thirdparty/f5_tts/train/train.py +75 -0
  80. xinference/types.py +2 -1
  81. xinference/web/ui/build/asset-manifest.json +3 -3
  82. xinference/web/ui/build/index.html +1 -1
  83. xinference/web/ui/build/static/js/{main.2f269bb3.js → main.4eb4ee80.js} +3 -3
  84. xinference/web/ui/build/static/js/main.4eb4ee80.js.map +1 -0
  85. xinference/web/ui/node_modules/.cache/babel-loader/8c5eeb02f772d02cbe8b89c05428d0dd41a97866f75f7dc1c2164a67f5a1cf98.json +1 -0
  86. {xinference-1.0.0.dist-info → xinference-1.1.0.dist-info}/METADATA +39 -18
  87. {xinference-1.0.0.dist-info → xinference-1.1.0.dist-info}/RECORD +92 -39
  88. {xinference-1.0.0.dist-info → xinference-1.1.0.dist-info}/WHEEL +1 -1
  89. xinference/web/ui/build/static/js/main.2f269bb3.js.map +0 -1
  90. xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +0 -1
  91. /xinference/web/ui/build/static/js/{main.2f269bb3.js.LICENSE.txt → main.4eb4ee80.js.LICENSE.txt} +0 -0
  92. {xinference-1.0.0.dist-info → xinference-1.1.0.dist-info}/LICENSE +0 -0
  93. {xinference-1.0.0.dist-info → xinference-1.1.0.dist-info}/entry_points.txt +0 -0
  94. {xinference-1.0.0.dist-info → xinference-1.1.0.dist-info}/top_level.txt +0 -0
@@ -156,6 +156,7 @@ def _get_completion(
     finish_reason: Optional[str],
     model_uid: str,
     r: InferenceRequest,
+    completion_tokens: int,
 ):
     completion_choice = CompletionChoice(
         text=output, index=0, logprobs=None, finish_reason=finish_reason
@@ -170,8 +171,8 @@ def _get_completion(
     )
     completion_usage = CompletionUsage(
         prompt_tokens=len(r.prompt_tokens),
-        completion_tokens=len(r.new_tokens),
-        total_tokens=len(r.prompt_tokens) + len(r.new_tokens),
+        completion_tokens=completion_tokens,
+        total_tokens=len(r.prompt_tokens) + completion_tokens,
     )
     completion = Completion(
         id=completion_chunk["id"],
@@ -371,7 +372,7 @@ def _batch_inference_one_step_internal(
             r.stopped = stopped
             r.finish_reason = finish_reason
 
-            if r.stopped and r not in stop_token_mapping and r not in output_mapping:
+            if r.stopped and r not in stop_token_mapping:
                 stop_token_mapping[r] = _i + 1
 
             if r.stream:
@@ -446,12 +447,14 @@ def _batch_inference_one_step_internal(
             else:
                 # last round, handle non-stream result
                 if r.stopped and _i == decode_round - 1:
-                    invalid_token_num = decode_round - stop_token_mapping[r]
+                    invalid_token_num = (
+                        (decode_round - stop_token_mapping[r] + 1)
+                        if r.finish_reason == "stop"
+                        else (decode_round - stop_token_mapping[r])
+                    )
                     outputs = (
                         tokenizer.decode(
-                            r.new_tokens[: -(invalid_token_num + 1)]
-                            if r.finish_reason == "stop"
-                            else r.new_tokens[:-invalid_token_num],
+                            r.new_tokens[:-invalid_token_num],
                             skip_special_tokens=True,
                             spaces_between_special_tokens=False,
                             clean_up_tokenization_spaces=True,
@@ -460,7 +463,12 @@ def _batch_inference_one_step_internal(
                         else output_mapping[r]
                     )
                     completion = _get_completion(
-                        outputs, r.chunk_id, r.finish_reason, model_uid, r
+                        outputs,
+                        r.chunk_id,
+                        r.finish_reason,
+                        model_uid,
+                        r,
+                        len(r.new_tokens) - invalid_token_num,
                     )
                     r.completion = [completion]
 
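The hunks above fix usage accounting for batched non-stream requests: `_get_completion` now receives an explicit `completion_tokens` count, and tokens produced in the decode steps after a request hits its stop token are excluded from both the decoded text and the reported usage (previously `len(r.new_tokens)` over-counted them). A minimal sketch of that arithmetic with hypothetical numbers; the real code operates on `InferenceRequest` objects:

```python
# Hypothetical illustration of the token accounting above: `new_tokens` stands in for
# the cumulative r.new_tokens list, decode_round/stop_round for the batching state.
decode_round = 16      # decode steps executed in the current batched round
stop_round = 10        # 1-based step at which this request emitted its stop token
finish_reason = "stop"
new_tokens = list(range(40))  # 40 tokens so far; the last 16 came from this round

# For "stop", the stop token itself and everything after it is padding;
# for "length", only the steps after the last kept token are padding.
invalid_token_num = (
    (decode_round - stop_round + 1)
    if finish_reason == "stop"
    else (decode_round - stop_round)
)

valid_tokens = new_tokens[:-invalid_token_num]
completion_tokens = len(new_tokens) - invalid_token_num
assert len(valid_tokens) == completion_tokens == 33  # usage now matches the decoded text
```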
@@ -324,7 +324,10 @@ class ChatModelMixin:
         """
         try:
             if isinstance(c, dict):
-                return [(None, c["name"], c["arguments"])]
+                try:
+                    return [(None, c["name"], json.loads(c["arguments"]))]
+                except Exception:
+                    return [(None, c["name"], c["arguments"])]
         except KeyError:
             logger.error("Can't parse glm output: %s", c)
             return [(str(c), None, None)]
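The hunk above makes GLM tool-call `arguments` come back as parsed JSON when possible instead of always as a raw string, while keeping the old behaviour as a fallback. A standalone sketch of the same fallback; the helper name and payloads are made up for illustration:

```python
import json

def _parse_glm_tool_call(c: dict):
    # Prefer structured arguments, fall back to the raw string when they are not valid JSON.
    try:
        return [(None, c["name"], json.loads(c["arguments"]))]
    except Exception:
        return [(None, c["name"], c["arguments"])]

print(_parse_glm_tool_call({"name": "get_weather", "arguments": '{"city": "Paris"}'}))
# -> [(None, 'get_weather', {'city': 'Paris'})]
print(_parse_glm_tool_call({"name": "get_weather", "arguments": "city=Paris"}))
# -> [(None, 'get_weather', 'city=Paris')]
```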
@@ -569,6 +572,25 @@ def _decode_image(_url):
             return Image.open(BytesIO(response.content)).convert("RGB")
 
 
+def _decode_image_without_rgb(_url):
+    if _url.startswith("data:"):
+        logging.info("Parse url by base64 decoder.")
+        # https://platform.openai.com/docs/guides/vision/uploading-base-64-encoded-images
+        # e.g. f"data:image/jpeg;base64,{base64_image}"
+        _type, data = _url.split(";")
+        _, ext = _type.split("/")
+        data = data[len("base64,") :]
+        data = base64.b64decode(data.encode("utf-8"))
+        return Image.open(BytesIO(data))
+    else:
+        try:
+            response = requests.get(_url)
+        except requests.exceptions.MissingSchema:
+            return Image.open(_url)
+        else:
+            return Image.open(BytesIO(response.content))
+
+
 @typing.no_type_check
 def generate_completion_chunk(
     chunk_text: Optional[str],
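The new `_decode_image_without_rgb` helper mirrors the existing `_decode_image` but drops the final `.convert("RGB")`, so decoded images keep their original mode (for example RGBA transparency). A quick self-contained check of the difference, assuming only Pillow; the sample image is generated in memory:

```python
import base64
from io import BytesIO

from PIL import Image

# Build a tiny RGBA PNG and wrap it in an OpenAI-style data URL.
buf = BytesIO()
Image.new("RGBA", (2, 2), (255, 0, 0, 128)).save(buf, format="PNG")
data_url = "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode("utf-8")

# Same parsing steps as the helper above.
_type, data = data_url.split(";")
data = base64.b64decode(data[len("base64,"):].encode("utf-8"))

img = Image.open(BytesIO(data))
print(img.mode)                 # RGBA -- preserved when .convert("RGB") is skipped
print(img.convert("RGB").mode)  # RGB  -- what _decode_image returns
```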
@@ -69,6 +69,7 @@ class VLLMModelConfig(TypedDict, total=False):
     quantization: Optional[str]
     max_model_len: Optional[int]
     limit_mm_per_prompt: Optional[Dict[str, int]]
+    guided_decoding_backend: Optional[str]
 
 
 class VLLMGenerateConfig(TypedDict, total=False):
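`guided_decoding_backend` joins the engine-level `VLLMModelConfig` (it is defaulted to `"outlines"` further down in `_sanitize_model_config`). Assuming extra `launch_model` keyword arguments are forwarded into this config the way the existing fields are, the backend could be chosen at launch time; a hedged sketch with placeholder endpoint and model names:

```python
# Sketch only: endpoint, model name, and size are placeholders for your deployment,
# and the kwargs passthrough into VLLMModelConfig is an assumption to verify.
from xinference.client import Client

client = Client("http://localhost:9997")
model_uid = client.launch_model(
    model_name="qwen2.5-instruct",
    model_engine="vLLM",
    model_size_in_billions=7,
    guided_decoding_backend="outlines",  # new VLLMModelConfig field above
)
print(model_uid)
```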
@@ -85,6 +86,15 @@ class VLLMGenerateConfig(TypedDict, total=False):
     stop: Optional[Union[str, List[str]]]
     stream: bool  # non-sampling param, should not be passed to the engine.
     stream_options: Optional[Union[dict, None]]
+    skip_special_tokens: Optional[bool]
+    response_format: Optional[dict]
+    guided_json: Optional[Union[str, dict]]
+    guided_regex: Optional[str]
+    guided_choice: Optional[List[str]]
+    guided_grammar: Optional[str]
+    guided_json_object: Optional[bool]
+    guided_decoding_backend: Optional[str]
+    guided_whitespace_pattern: Optional[str]
 
 
 try:
@@ -144,6 +154,7 @@ if VLLM_INSTALLED and vllm.__version__ >= "0.3.0":
     VLLM_SUPPORTED_CHAT_MODELS.append("qwen2.5-instruct")
     VLLM_SUPPORTED_MODELS.append("qwen2.5-coder")
     VLLM_SUPPORTED_CHAT_MODELS.append("qwen2.5-coder-instruct")
+    VLLM_SUPPORTED_CHAT_MODELS.append("QwQ-32B-Preview")
 
 
 if VLLM_INSTALLED and vllm.__version__ >= "0.3.2":
@@ -171,6 +182,7 @@ if VLLM_INSTALLED and vllm.__version__ >= "0.5.3":
 if VLLM_INSTALLED and vllm.__version__ > "0.5.3":
     VLLM_SUPPORTED_MODELS.append("llama-3.1")
     VLLM_SUPPORTED_CHAT_MODELS.append("llama-3.1-instruct")
+    VLLM_SUPPORTED_CHAT_MODELS.append("llama-3.3-instruct")
 
 if VLLM_INSTALLED and vllm.__version__ >= "0.6.1":
     VLLM_SUPPORTED_VISION_MODEL_LIST.append("internvl2")
@@ -314,6 +326,7 @@ class VLLMModel(LLM):
         model_config.setdefault("max_num_seqs", 256)
         model_config.setdefault("quantization", None)
         model_config.setdefault("max_model_len", None)
+        model_config.setdefault("guided_decoding_backend", "outlines")
 
         return model_config
 
@@ -325,6 +338,22 @@ class VLLMModel(LLM):
             generate_config = {}
 
         sanitized = VLLMGenerateConfig()
+
+        response_format = generate_config.pop("response_format", None)
+        guided_decoding_backend = generate_config.get("guided_decoding_backend", None)
+        guided_json_object = None
+        guided_json = None
+
+        if response_format is not None:
+            if response_format.get("type") == "json_object":
+                guided_json_object = True
+            elif response_format.get("type") == "json_schema":
+                json_schema = response_format.get("json_schema")
+                assert json_schema is not None
+                guided_json = json_schema.get("json_schema")
+            if guided_decoding_backend is None:
+                guided_decoding_backend = "outlines"
+
         sanitized.setdefault("lora_name", generate_config.get("lora_name", None))
         sanitized.setdefault("n", generate_config.get("n", 1))
         sanitized.setdefault("best_of", generate_config.get("best_of", None))
@@ -346,6 +375,31 @@ class VLLMModel(LLM):
         sanitized.setdefault(
             "stream_options", generate_config.get("stream_options", None)
         )
+        sanitized.setdefault(
+            "skip_special_tokens", generate_config.get("skip_special_tokens", True)
+        )
+        sanitized.setdefault(
+            "guided_json", generate_config.get("guided_json", guided_json)
+        )
+        sanitized.setdefault("guided_regex", generate_config.get("guided_regex", None))
+        sanitized.setdefault(
+            "guided_choice", generate_config.get("guided_choice", None)
+        )
+        sanitized.setdefault(
+            "guided_grammar", generate_config.get("guided_grammar", None)
+        )
+        sanitized.setdefault(
+            "guided_whitespace_pattern",
+            generate_config.get("guided_whitespace_pattern", None),
+        )
+        sanitized.setdefault(
+            "guided_json_object",
+            generate_config.get("guided_json_object", guided_json_object),
+        )
+        sanitized.setdefault(
+            "guided_decoding_backend",
+            generate_config.get("guided_decoding_backend", guided_decoding_backend),
+        )
 
         return sanitized
 
@@ -483,13 +537,46 @@ class VLLMModel(LLM):
             if isinstance(stream_options, dict)
             else False
         )
-        sampling_params = SamplingParams(**sanitized_generate_config)
+
+        if VLLM_INSTALLED and vllm.__version__ >= "0.6.3":
+            # guided decoding only available for vllm >= 0.6.3
+            from vllm.sampling_params import GuidedDecodingParams
+
+            guided_options = GuidedDecodingParams.from_optional(
+                json=sanitized_generate_config.pop("guided_json", None),
+                regex=sanitized_generate_config.pop("guided_regex", None),
+                choice=sanitized_generate_config.pop("guided_choice", None),
+                grammar=sanitized_generate_config.pop("guided_grammar", None),
+                json_object=sanitized_generate_config.pop("guided_json_object", None),
+                backend=sanitized_generate_config.pop("guided_decoding_backend", None),
+                whitespace_pattern=sanitized_generate_config.pop(
+                    "guided_whitespace_pattern", None
+                ),
+            )
+
+            sampling_params = SamplingParams(
+                guided_decoding=guided_options, **sanitized_generate_config
+            )
+        else:
+            # ignore generate configs
+            sanitized_generate_config.pop("guided_json", None)
+            sanitized_generate_config.pop("guided_regex", None)
+            sanitized_generate_config.pop("guided_choice", None)
+            sanitized_generate_config.pop("guided_grammar", None)
+            sanitized_generate_config.pop("guided_json_object", None)
+            sanitized_generate_config.pop("guided_decoding_backend", None)
+            sanitized_generate_config.pop("guided_whitespace_pattern", None)
+            sampling_params = SamplingParams(**sanitized_generate_config)
+
         if not request_id:
             request_id = str(uuid.uuid1())
 
         assert self._engine is not None
         results_generator = self._engine.generate(
-            prompt, sampling_params, request_id, lora_request=lora_request
+            prompt,
+            sampling_params,
+            request_id,
+            lora_request,
         )
 
         async def stream_results() -> AsyncGenerator[CompletionChunk, None]:
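Taken together, the vLLM changes let an OpenAI-style `response_format` flow through `_sanitize_generate_config` into vLLM's `GuidedDecodingParams` when vLLM >= 0.6.3 is installed (older versions silently drop the guided options). A hedged usage sketch against Xinference's OpenAI-compatible endpoint; host, port, and model UID are placeholders, and it assumes the server forwards `response_format` into the generate config as the hunks suggest:

```python
# Sketch only: assumes an Xinference server at localhost:9997 serving a vLLM-backed
# chat model registered under the UID "qwen2.5-instruct".
from openai import OpenAI

client = OpenAI(base_url="http://localhost:9997/v1", api_key="not-needed")

resp = client.chat.completions.create(
    model="qwen2.5-instruct",
    messages=[
        {
            "role": "user",
            "content": "Return a JSON object with keys 'city' and 'population' for Paris.",
        }
    ],
    response_format={"type": "json_object"},  # mapped to guided_json_object above
)
print(resp.choices[0].message.content)
```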
File without changes
@@ -0,0 +1,166 @@
+import random
+import sys
+from importlib.resources import files
+
+import soundfile as sf
+import tqdm
+from cached_path import cached_path
+
+from f5_tts.infer.utils_infer import (
+    hop_length,
+    infer_process,
+    load_model,
+    load_vocoder,
+    preprocess_ref_audio_text,
+    remove_silence_for_generated_wav,
+    save_spectrogram,
+    transcribe,
+    target_sample_rate,
+)
+from f5_tts.model import DiT, UNetT
+from f5_tts.model.utils import seed_everything
+
+
+class F5TTS:
+    def __init__(
+        self,
+        model_type="F5-TTS",
+        ckpt_file="",
+        vocab_file="",
+        ode_method="euler",
+        use_ema=True,
+        vocoder_name="vocos",
+        local_path=None,
+        device=None,
+        hf_cache_dir=None,
+    ):
+        # Initialize parameters
+        self.final_wave = None
+        self.target_sample_rate = target_sample_rate
+        self.hop_length = hop_length
+        self.seed = -1
+        self.mel_spec_type = vocoder_name
+
+        # Set device
+        if device is not None:
+            self.device = device
+        else:
+            import torch
+
+            self.device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+
+        # Load models
+        self.load_vocoder_model(vocoder_name, local_path=local_path, hf_cache_dir=hf_cache_dir)
+        self.load_ema_model(
+            model_type, ckpt_file, vocoder_name, vocab_file, ode_method, use_ema, hf_cache_dir=hf_cache_dir
+        )
+
+    def load_vocoder_model(self, vocoder_name, local_path=None, hf_cache_dir=None):
+        self.vocoder = load_vocoder(vocoder_name, local_path is not None, local_path, self.device, hf_cache_dir)
+
+    def load_ema_model(self, model_type, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, hf_cache_dir=None):
+        if model_type == "F5-TTS":
+            if not ckpt_file:
+                if mel_spec_type == "vocos":
+                    ckpt_file = str(
+                        cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors", cache_dir=hf_cache_dir)
+                    )
+                elif mel_spec_type == "bigvgan":
+                    ckpt_file = str(
+                        cached_path("hf://SWivid/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt", cache_dir=hf_cache_dir)
+                    )
+            model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
+            model_cls = DiT
+        elif model_type == "E2-TTS":
+            if not ckpt_file:
+                ckpt_file = str(
+                    cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors", cache_dir=hf_cache_dir)
+                )
+            model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
+            model_cls = UNetT
+        else:
+            raise ValueError(f"Unknown model type: {model_type}")
+
+        self.ema_model = load_model(
+            model_cls, model_cfg, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, self.device
+        )
+
+    def transcribe(self, ref_audio, language=None):
+        return transcribe(ref_audio, language)
+
+    def export_wav(self, wav, file_wave, remove_silence=False):
+        sf.write(file_wave, wav, self.target_sample_rate)
+
+        if remove_silence:
+            remove_silence_for_generated_wav(file_wave)
+
+    def export_spectrogram(self, spect, file_spect):
+        save_spectrogram(spect, file_spect)
+
+    def infer(
+        self,
+        ref_file,
+        ref_text,
+        gen_text,
+        show_info=print,
+        progress=tqdm,
+        target_rms=0.1,
+        cross_fade_duration=0.15,
+        sway_sampling_coef=-1,
+        cfg_strength=2,
+        nfe_step=32,
+        speed=1.0,
+        fix_duration=None,
+        remove_silence=False,
+        file_wave=None,
+        file_spect=None,
+        seed=-1,
+    ):
+        if seed == -1:
+            seed = random.randint(0, sys.maxsize)
+        seed_everything(seed)
+        self.seed = seed
+
+        ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text, device=self.device)
+
+        wav, sr, spect = infer_process(
+            ref_file,
+            ref_text,
+            gen_text,
+            self.ema_model,
+            self.vocoder,
+            self.mel_spec_type,
+            show_info=show_info,
+            progress=progress,
+            target_rms=target_rms,
+            cross_fade_duration=cross_fade_duration,
+            nfe_step=nfe_step,
+            cfg_strength=cfg_strength,
+            sway_sampling_coef=sway_sampling_coef,
+            speed=speed,
+            fix_duration=fix_duration,
+            device=self.device,
+        )
+
+        if file_wave is not None:
+            self.export_wav(wav, file_wave, remove_silence)
+
+        if file_spect is not None:
+            self.export_spectrogram(spect, file_spect)
+
+        return wav, sr, spect
+
+
+if __name__ == "__main__":
+    f5tts = F5TTS()
+
+    wav, sr, spect = f5tts.infer(
+        ref_file=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")),
+        ref_text="some call me nature, others call me mother nature.",
+        gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""",
+        file_wave=str(files("f5_tts").joinpath("../../tests/api_out.wav")),
+        file_spect=str(files("f5_tts").joinpath("../../tests/api_out.png")),
+        seed=-1,  # random seed = -1
+    )
+
+    print("seed :", f5tts.seed)
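This vendored `api.py` is what the new `xinference/model/audio/f5tts.py` wraps (entries 11 and 13 in the file list register F5-TTS as an audio model). A heavily hedged sketch of driving it through the Xinference client; the model name, the `speech()` handle method, and the reliance on a built-in reference voice are assumptions to check against your installed version:

```python
# Sketch only: model name "F5-TTS" and the audio handle's speech() method are assumptions.
from xinference.client import Client

client = Client("http://localhost:9997")
model_uid = client.launch_model(model_name="F5-TTS", model_type="audio")
model = client.get_model(model_uid)

# speech() is expected to return raw audio bytes, mirroring OpenAI's audio API.
audio_bytes = model.speech("Hello from the new F5-TTS integration.")
with open("f5tts_demo.mp3", "wb") as f:
    f.write(audio_bytes)
```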
@@ -0,0 +1,44 @@
+hydra:
+  run:
+    dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
+
+datasets:
+  name: Emilia_ZH_EN # dataset name
+  batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
+  batch_size_type: frame # "frame" or "sample"
+  max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
+  num_workers: 16
+
+optim:
+  epochs: 15
+  learning_rate: 7.5e-5
+  num_warmup_updates: 20000 # warmup steps
+  grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
+  max_grad_norm: 1.0 # gradient clipping
+  bnb_optimizer: False # use bnb 8bit AdamW optimizer or not
+
+model:
+  name: E2TTS_Base
+  tokenizer: pinyin
+  tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
+  arch:
+    dim: 1024
+    depth: 24
+    heads: 16
+    ff_mult: 4
+  mel_spec:
+    target_sample_rate: 24000
+    n_mel_channels: 100
+    hop_length: 256
+    win_length: 1024
+    n_fft: 1024
+    mel_spec_type: vocos # 'vocos' or 'bigvgan'
+  vocoder:
+    is_local: False # use local offline ckpt or not
+    local_path: None # local vocoder path
+
+ckpts:
+  logger: wandb # wandb | tensorboard | None
+  save_per_updates: 50000 # save checkpoint per steps
+  last_per_steps: 5000 # save last checkpoint per steps
+  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
@@ -0,0 +1,44 @@
+hydra:
+  run:
+    dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
+
+datasets:
+  name: Emilia_ZH_EN
+  batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
+  batch_size_type: frame # "frame" or "sample"
+  max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
+  num_workers: 16
+
+optim:
+  epochs: 15
+  learning_rate: 7.5e-5
+  num_warmup_updates: 20000 # warmup steps
+  grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
+  max_grad_norm: 1.0
+  bnb_optimizer: False
+
+model:
+  name: E2TTS_Small
+  tokenizer: pinyin
+  tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
+  arch:
+    dim: 768
+    depth: 20
+    heads: 12
+    ff_mult: 4
+  mel_spec:
+    target_sample_rate: 24000
+    n_mel_channels: 100
+    hop_length: 256
+    win_length: 1024
+    n_fft: 1024
+    mel_spec_type: vocos # 'vocos' or 'bigvgan'
+  vocoder:
+    is_local: False # use local offline ckpt or not
+    local_path: None # local vocoder path
+
+ckpts:
+  logger: wandb # wandb | tensorboard | None
+  save_per_updates: 50000 # save checkpoint per steps
+  last_per_steps: 5000 # save last checkpoint per steps
+  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
@@ -0,0 +1,46 @@
+hydra:
+  run:
+    dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
+
+datasets:
+  name: Emilia_ZH_EN # dataset name
+  batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
+  batch_size_type: frame # "frame" or "sample"
+  max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
+  num_workers: 16
+
+optim:
+  epochs: 15
+  learning_rate: 7.5e-5
+  num_warmup_updates: 20000 # warmup steps
+  grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
+  max_grad_norm: 1.0 # gradient clipping
+  bnb_optimizer: False # use bnb 8bit AdamW optimizer or not
+
+model:
+  name: F5TTS_Base # model name
+  tokenizer: pinyin # tokenizer type
+  tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
+  arch:
+    dim: 1024
+    depth: 22
+    heads: 16
+    ff_mult: 2
+    text_dim: 512
+    conv_layers: 4
+  mel_spec:
+    target_sample_rate: 24000
+    n_mel_channels: 100
+    hop_length: 256
+    win_length: 1024
+    n_fft: 1024
+    mel_spec_type: vocos # 'vocos' or 'bigvgan'
+  vocoder:
+    is_local: False # use local offline ckpt or not
+    local_path: None # local vocoder path
+
+ckpts:
+  logger: wandb # wandb | tensorboard | None
+  save_per_updates: 50000 # save checkpoint per steps
+  last_per_steps: 5000 # save last checkpoint per steps
+  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
@@ -0,0 +1,46 @@
+hydra:
+  run:
+    dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
+
+datasets:
+  name: Emilia_ZH_EN
+  batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
+  batch_size_type: frame # "frame" or "sample"
+  max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
+  num_workers: 16
+
+optim:
+  epochs: 15
+  learning_rate: 7.5e-5
+  num_warmup_updates: 20000 # warmup steps
+  grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
+  max_grad_norm: 1.0 # gradient clipping
+  bnb_optimizer: False # use bnb 8bit AdamW optimizer or not
+
+model:
+  name: F5TTS_Small
+  tokenizer: pinyin
+  tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
+  arch:
+    dim: 768
+    depth: 18
+    heads: 12
+    ff_mult: 2
+    text_dim: 512
+    conv_layers: 4
+  mel_spec:
+    target_sample_rate: 24000
+    n_mel_channels: 100
+    hop_length: 256
+    win_length: 1024
+    n_fft: 1024
+    mel_spec_type: vocos # 'vocos' or 'bigvgan'
+  vocoder:
+    is_local: False # use local offline ckpt or not
+    local_path: None # local vocoder path
+
+ckpts:
+  logger: wandb # wandb | tensorboard | None
+  save_per_updates: 50000 # save checkpoint per steps
+  last_per_steps: 5000 # save last checkpoint per steps
+  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
@@ -0,0 +1,49 @@
+
+# Evaluation
+
+Install packages for evaluation:
+
+```bash
+pip install -e .[eval]
+```
+
+## Generating Samples for Evaluation
+
+### Prepare Test Datasets
+
+1. *Seed-TTS testset*: Download from [seed-tts-eval](https://github.com/BytedanceSpeech/seed-tts-eval).
+2. *LibriSpeech test-clean*: Download from [OpenSLR](http://www.openslr.org/12/).
+3. Unzip the downloaded datasets and place them in the `data/` directory.
+4. Update the path for *LibriSpeech test-clean* data in `src/f5_tts/eval/eval_infer_batch.py`
+5. Our filtered LibriSpeech-PC 4-10s subset: `data/librispeech_pc_test_clean_cross_sentence.lst`
+
+### Batch Inference for Test Set
+
+To run batch inference for evaluations, execute the following commands:
+
+```bash
+# batch inference for evaluations
+accelerate config # if not set before
+bash src/f5_tts/eval/eval_infer_batch.sh
+```
+
+## Objective Evaluation on Generated Results
+
+### Download Evaluation Model Checkpoints
+
+1. Chinese ASR Model: [Paraformer-zh](https://huggingface.co/funasr/paraformer-zh)
+2. English ASR Model: [Faster-Whisper](https://huggingface.co/Systran/faster-whisper-large-v3)
+3. WavLM Model: Download from [Google Drive](https://drive.google.com/file/d/1-aE1NfzpRCLxA4GUxX9ITI3F9LlbtEGP/view).
+
+Then update in the following scripts with the paths you put evaluation model ckpts to.
+
+### Objective Evaluation
+
+Update the path with your batch-inferenced results, and carry out WER / SIM evaluations:
+```bash
+# Evaluation for Seed-TTS test set
+python src/f5_tts/eval/eval_seedtts_testset.py --gen_wav_dir <GEN_WAVE_DIR>
+
+# Evaluation for LibriSpeech-PC test-clean (cross-sentence)
+python src/f5_tts/eval/eval_librispeech_test_clean.py --gen_wav_dir <GEN_WAVE_DIR> --librispeech_test_clean_path <TEST_CLEAN_PATH>
+```