xinference 1.0.1__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic.

Files changed (170)
  1. xinference/_compat.py +2 -0
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +28 -6
  4. xinference/core/utils.py +10 -6
  5. xinference/deploy/cmdline.py +3 -1
  6. xinference/deploy/test/test_cmdline.py +56 -0
  7. xinference/isolation.py +24 -0
  8. xinference/model/audio/core.py +10 -0
  9. xinference/model/audio/cosyvoice.py +25 -3
  10. xinference/model/audio/f5tts.py +200 -0
  11. xinference/model/audio/f5tts_mlx.py +260 -0
  12. xinference/model/audio/fish_speech.py +36 -111
  13. xinference/model/audio/model_spec.json +27 -3
  14. xinference/model/audio/model_spec_modelscope.json +18 -0
  15. xinference/model/audio/utils.py +32 -0
  16. xinference/model/embedding/core.py +203 -142
  17. xinference/model/embedding/model_spec.json +7 -0
  18. xinference/model/embedding/model_spec_modelscope.json +8 -0
  19. xinference/model/image/core.py +69 -1
  20. xinference/model/image/model_spec.json +127 -4
  21. xinference/model/image/model_spec_modelscope.json +130 -4
  22. xinference/model/image/stable_diffusion/core.py +45 -13
  23. xinference/model/llm/__init__.py +2 -2
  24. xinference/model/llm/llm_family.json +219 -53
  25. xinference/model/llm/llm_family.py +15 -36
  26. xinference/model/llm/llm_family_modelscope.json +167 -20
  27. xinference/model/llm/mlx/core.py +287 -51
  28. xinference/model/llm/sglang/core.py +1 -0
  29. xinference/model/llm/transformers/chatglm.py +9 -5
  30. xinference/model/llm/transformers/core.py +1 -0
  31. xinference/model/llm/transformers/qwen2_vl.py +2 -0
  32. xinference/model/llm/transformers/utils.py +16 -8
  33. xinference/model/llm/utils.py +5 -1
  34. xinference/model/llm/vllm/core.py +16 -2
  35. xinference/thirdparty/cosyvoice/bin/average_model.py +92 -0
  36. xinference/thirdparty/cosyvoice/bin/export_jit.py +12 -2
  37. xinference/thirdparty/cosyvoice/bin/export_onnx.py +112 -0
  38. xinference/thirdparty/cosyvoice/bin/export_trt.sh +9 -0
  39. xinference/thirdparty/cosyvoice/bin/inference.py +5 -7
  40. xinference/thirdparty/cosyvoice/bin/train.py +42 -8
  41. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +96 -25
  42. xinference/thirdparty/cosyvoice/cli/frontend.py +77 -30
  43. xinference/thirdparty/cosyvoice/cli/model.py +330 -80
  44. xinference/thirdparty/cosyvoice/dataset/dataset.py +6 -2
  45. xinference/thirdparty/cosyvoice/dataset/processor.py +76 -14
  46. xinference/thirdparty/cosyvoice/flow/decoder.py +92 -13
  47. xinference/thirdparty/cosyvoice/flow/flow.py +99 -9
  48. xinference/thirdparty/cosyvoice/flow/flow_matching.py +110 -13
  49. xinference/thirdparty/cosyvoice/flow/length_regulator.py +5 -4
  50. xinference/thirdparty/cosyvoice/hifigan/discriminator.py +140 -0
  51. xinference/thirdparty/cosyvoice/hifigan/generator.py +58 -42
  52. xinference/thirdparty/cosyvoice/hifigan/hifigan.py +67 -0
  53. xinference/thirdparty/cosyvoice/llm/llm.py +139 -6
  54. xinference/thirdparty/cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
  55. xinference/thirdparty/cosyvoice/tokenizer/tokenizer.py +279 -0
  56. xinference/thirdparty/cosyvoice/transformer/embedding.py +2 -2
  57. xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +7 -7
  58. xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +318 -0
  59. xinference/thirdparty/cosyvoice/utils/common.py +28 -1
  60. xinference/thirdparty/cosyvoice/utils/executor.py +69 -7
  61. xinference/thirdparty/cosyvoice/utils/file_utils.py +2 -12
  62. xinference/thirdparty/cosyvoice/utils/frontend_utils.py +9 -5
  63. xinference/thirdparty/cosyvoice/utils/losses.py +20 -0
  64. xinference/thirdparty/cosyvoice/utils/scheduler.py +1 -2
  65. xinference/thirdparty/cosyvoice/utils/train_utils.py +101 -45
  66. xinference/thirdparty/f5_tts/api.py +166 -0
  67. xinference/thirdparty/f5_tts/configs/E2TTS_Base_train.yaml +44 -0
  68. xinference/thirdparty/f5_tts/configs/E2TTS_Small_train.yaml +44 -0
  69. xinference/thirdparty/f5_tts/configs/F5TTS_Base_train.yaml +46 -0
  70. xinference/thirdparty/f5_tts/configs/F5TTS_Small_train.yaml +46 -0
  71. xinference/thirdparty/f5_tts/eval/README.md +49 -0
  72. xinference/thirdparty/f5_tts/eval/ecapa_tdnn.py +330 -0
  73. xinference/thirdparty/f5_tts/eval/eval_infer_batch.py +207 -0
  74. xinference/thirdparty/f5_tts/eval/eval_infer_batch.sh +13 -0
  75. xinference/thirdparty/f5_tts/eval/eval_librispeech_test_clean.py +84 -0
  76. xinference/thirdparty/f5_tts/eval/eval_seedtts_testset.py +84 -0
  77. xinference/thirdparty/f5_tts/eval/utils_eval.py +405 -0
  78. xinference/thirdparty/f5_tts/infer/README.md +191 -0
  79. xinference/thirdparty/f5_tts/infer/SHARED.md +74 -0
  80. xinference/thirdparty/f5_tts/infer/examples/basic/basic.toml +11 -0
  81. xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_en.wav +0 -0
  82. xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_zh.wav +0 -0
  83. xinference/thirdparty/f5_tts/infer/examples/multi/country.flac +0 -0
  84. xinference/thirdparty/f5_tts/infer/examples/multi/main.flac +0 -0
  85. xinference/thirdparty/f5_tts/infer/examples/multi/story.toml +19 -0
  86. xinference/thirdparty/f5_tts/infer/examples/multi/story.txt +1 -0
  87. xinference/thirdparty/f5_tts/infer/examples/multi/town.flac +0 -0
  88. xinference/thirdparty/f5_tts/infer/examples/vocab.txt +2545 -0
  89. xinference/thirdparty/f5_tts/infer/infer_cli.py +226 -0
  90. xinference/thirdparty/f5_tts/infer/infer_gradio.py +851 -0
  91. xinference/thirdparty/f5_tts/infer/speech_edit.py +193 -0
  92. xinference/thirdparty/f5_tts/infer/utils_infer.py +538 -0
  93. xinference/thirdparty/f5_tts/model/__init__.py +10 -0
  94. xinference/thirdparty/f5_tts/model/backbones/README.md +20 -0
  95. xinference/thirdparty/f5_tts/model/backbones/dit.py +163 -0
  96. xinference/thirdparty/f5_tts/model/backbones/mmdit.py +146 -0
  97. xinference/thirdparty/f5_tts/model/backbones/unett.py +219 -0
  98. xinference/thirdparty/f5_tts/model/cfm.py +285 -0
  99. xinference/thirdparty/f5_tts/model/dataset.py +319 -0
  100. xinference/thirdparty/f5_tts/model/modules.py +658 -0
  101. xinference/thirdparty/f5_tts/model/trainer.py +366 -0
  102. xinference/thirdparty/f5_tts/model/utils.py +185 -0
  103. xinference/thirdparty/f5_tts/scripts/count_max_epoch.py +33 -0
  104. xinference/thirdparty/f5_tts/scripts/count_params_gflops.py +39 -0
  105. xinference/thirdparty/f5_tts/socket_server.py +159 -0
  106. xinference/thirdparty/f5_tts/train/README.md +77 -0
  107. xinference/thirdparty/f5_tts/train/datasets/prepare_csv_wavs.py +139 -0
  108. xinference/thirdparty/f5_tts/train/datasets/prepare_emilia.py +230 -0
  109. xinference/thirdparty/f5_tts/train/datasets/prepare_libritts.py +92 -0
  110. xinference/thirdparty/f5_tts/train/datasets/prepare_ljspeech.py +65 -0
  111. xinference/thirdparty/f5_tts/train/datasets/prepare_wenetspeech4tts.py +125 -0
  112. xinference/thirdparty/f5_tts/train/finetune_cli.py +174 -0
  113. xinference/thirdparty/f5_tts/train/finetune_gradio.py +1846 -0
  114. xinference/thirdparty/f5_tts/train/train.py +75 -0
  115. xinference/thirdparty/fish_speech/fish_speech/conversation.py +94 -83
  116. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +63 -20
  117. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +1 -26
  118. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +1 -1
  119. xinference/thirdparty/fish_speech/fish_speech/tokenizer.py +152 -0
  120. xinference/thirdparty/fish_speech/fish_speech/train.py +2 -2
  121. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1 -1
  122. xinference/thirdparty/fish_speech/tools/{post_api.py → api_client.py} +7 -13
  123. xinference/thirdparty/fish_speech/tools/api_server.py +98 -0
  124. xinference/thirdparty/fish_speech/tools/download_models.py +5 -5
  125. xinference/thirdparty/fish_speech/tools/fish_e2e.py +2 -2
  126. xinference/thirdparty/fish_speech/tools/inference_engine/__init__.py +192 -0
  127. xinference/thirdparty/fish_speech/tools/inference_engine/reference_loader.py +125 -0
  128. xinference/thirdparty/fish_speech/tools/inference_engine/utils.py +39 -0
  129. xinference/thirdparty/fish_speech/tools/inference_engine/vq_manager.py +57 -0
  130. xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +2 -2
  131. xinference/thirdparty/fish_speech/tools/llama/generate.py +117 -89
  132. xinference/thirdparty/fish_speech/tools/run_webui.py +104 -0
  133. xinference/thirdparty/fish_speech/tools/schema.py +11 -28
  134. xinference/thirdparty/fish_speech/tools/server/agent/__init__.py +57 -0
  135. xinference/thirdparty/fish_speech/tools/server/agent/generate.py +119 -0
  136. xinference/thirdparty/fish_speech/tools/server/agent/generation_utils.py +122 -0
  137. xinference/thirdparty/fish_speech/tools/server/agent/pre_generation_utils.py +72 -0
  138. xinference/thirdparty/fish_speech/tools/server/api_utils.py +75 -0
  139. xinference/thirdparty/fish_speech/tools/server/exception_handler.py +27 -0
  140. xinference/thirdparty/fish_speech/tools/server/inference.py +45 -0
  141. xinference/thirdparty/fish_speech/tools/server/model_manager.py +122 -0
  142. xinference/thirdparty/fish_speech/tools/server/model_utils.py +129 -0
  143. xinference/thirdparty/fish_speech/tools/server/views.py +246 -0
  144. xinference/thirdparty/fish_speech/tools/webui/__init__.py +173 -0
  145. xinference/thirdparty/fish_speech/tools/webui/inference.py +91 -0
  146. xinference/thirdparty/fish_speech/tools/webui/variables.py +14 -0
  147. xinference/thirdparty/matcha/utils/utils.py +2 -2
  148. xinference/web/ui/build/asset-manifest.json +3 -3
  149. xinference/web/ui/build/index.html +1 -1
  150. xinference/web/ui/build/static/js/{main.2f269bb3.js → main.4eb4ee80.js} +3 -3
  151. xinference/web/ui/build/static/js/main.4eb4ee80.js.map +1 -0
  152. xinference/web/ui/node_modules/.cache/babel-loader/8c5eeb02f772d02cbe8b89c05428d0dd41a97866f75f7dc1c2164a67f5a1cf98.json +1 -0
  153. {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/METADATA +41 -17
  154. {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/RECORD +160 -88
  155. xinference/thirdparty/cosyvoice/bin/export_trt.py +0 -8
  156. xinference/thirdparty/cosyvoice/flow/__init__.py +0 -0
  157. xinference/thirdparty/cosyvoice/hifigan/__init__.py +0 -0
  158. xinference/thirdparty/cosyvoice/llm/__init__.py +0 -0
  159. xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
  160. xinference/thirdparty/fish_speech/tools/api.py +0 -943
  161. xinference/thirdparty/fish_speech/tools/msgpack_api.py +0 -95
  162. xinference/thirdparty/fish_speech/tools/webui.py +0 -548
  163. xinference/web/ui/build/static/js/main.2f269bb3.js.map +0 -1
  164. xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +0 -1
  165. /xinference/thirdparty/{cosyvoice/bin → f5_tts}/__init__.py +0 -0
  166. /xinference/web/ui/build/static/js/{main.2f269bb3.js.LICENSE.txt → main.4eb4ee80.js.LICENSE.txt} +0 -0
  167. {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/LICENSE +0 -0
  168. {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/WHEEL +0 -0
  169. {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/entry_points.txt +0 -0
  170. {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/top_level.txt +0 -0
xinference/thirdparty/f5_tts/configs/E2TTS_Base_train.yaml
@@ -0,0 +1,44 @@
+ hydra:
+   run:
+     dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
+
+ datasets:
+   name: Emilia_ZH_EN # dataset name
+   batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
+   batch_size_type: frame # "frame" or "sample"
+   max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
+   num_workers: 16
+
+ optim:
+   epochs: 15
+   learning_rate: 7.5e-5
+   num_warmup_updates: 20000 # warmup steps
+   grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
+   max_grad_norm: 1.0 # gradient clipping
+   bnb_optimizer: False # use bnb 8bit AdamW optimizer or not
+
+ model:
+   name: E2TTS_Base
+   tokenizer: pinyin
+   tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
+   arch:
+     dim: 1024
+     depth: 24
+     heads: 16
+     ff_mult: 4
+   mel_spec:
+     target_sample_rate: 24000
+     n_mel_channels: 100
+     hop_length: 256
+     win_length: 1024
+     n_fft: 1024
+     mel_spec_type: vocos # 'vocos' or 'bigvgan'
+   vocoder:
+     is_local: False # use local offline ckpt or not
+     local_path: None # local vocoder path
+
+ ckpts:
+   logger: wandb # wandb | tensorboard | None
+   save_per_updates: 50000 # save checkpoint per steps
+   last_per_steps: 5000 # save last checkpoint per steps
+   save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
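
As a rough sketch of how a Hydra-style config like the one above can be consumed, the snippet below loads it with OmegaConf and reproduces the batch/update arithmetic noted in the config comments. The file path, the choice of OmegaConf, and the step count are illustrative assumptions only, not the package's actual training entry point.

```python
# Hedged sketch: inspect the E2TTS_Base training config and work through the
# arithmetic documented in its comments. Assumes omegaconf is installed and the
# wheel is unpacked so the YAML path below exists.
from omegaconf import OmegaConf

cfg = OmegaConf.load("xinference/thirdparty/f5_tts/configs/E2TTS_Base_train.yaml")

frames_per_gpu = cfg.datasets.batch_size_per_gpu  # 38400 frames per GPU per step
num_gpus = 8                                      # per the config comment
total_frames = num_gpus * frames_per_gpu          # 8 * 38400 = 307200

# Config comment: "updates = steps / grad_accumulation_steps"
steps = 1_000_000                                 # hypothetical step count
updates = steps // cfg.optim.grad_accumulation_steps

print(total_frames, updates, cfg.model.arch.dim)
```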
xinference/thirdparty/f5_tts/configs/E2TTS_Small_train.yaml
@@ -0,0 +1,44 @@
+ hydra:
+   run:
+     dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
+
+ datasets:
+   name: Emilia_ZH_EN
+   batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
+   batch_size_type: frame # "frame" or "sample"
+   max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
+   num_workers: 16
+
+ optim:
+   epochs: 15
+   learning_rate: 7.5e-5
+   num_warmup_updates: 20000 # warmup steps
+   grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
+   max_grad_norm: 1.0
+   bnb_optimizer: False
+
+ model:
+   name: E2TTS_Small
+   tokenizer: pinyin
+   tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
+   arch:
+     dim: 768
+     depth: 20
+     heads: 12
+     ff_mult: 4
+   mel_spec:
+     target_sample_rate: 24000
+     n_mel_channels: 100
+     hop_length: 256
+     win_length: 1024
+     n_fft: 1024
+     mel_spec_type: vocos # 'vocos' or 'bigvgan'
+   vocoder:
+     is_local: False # use local offline ckpt or not
+     local_path: None # local vocoder path
+
+ ckpts:
+   logger: wandb # wandb | tensorboard | None
+   save_per_updates: 50000 # save checkpoint per steps
+   last_per_steps: 5000 # save last checkpoint per steps
+   save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
xinference/thirdparty/f5_tts/configs/F5TTS_Base_train.yaml
@@ -0,0 +1,46 @@
+ hydra:
+   run:
+     dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
+
+ datasets:
+   name: Emilia_ZH_EN # dataset name
+   batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
+   batch_size_type: frame # "frame" or "sample"
+   max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
+   num_workers: 16
+
+ optim:
+   epochs: 15
+   learning_rate: 7.5e-5
+   num_warmup_updates: 20000 # warmup steps
+   grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
+   max_grad_norm: 1.0 # gradient clipping
+   bnb_optimizer: False # use bnb 8bit AdamW optimizer or not
+
+ model:
+   name: F5TTS_Base # model name
+   tokenizer: pinyin # tokenizer type
+   tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
+   arch:
+     dim: 1024
+     depth: 22
+     heads: 16
+     ff_mult: 2
+     text_dim: 512
+     conv_layers: 4
+   mel_spec:
+     target_sample_rate: 24000
+     n_mel_channels: 100
+     hop_length: 256
+     win_length: 1024
+     n_fft: 1024
+     mel_spec_type: vocos # 'vocos' or 'bigvgan'
+   vocoder:
+     is_local: False # use local offline ckpt or not
+     local_path: None # local vocoder path
+
+ ckpts:
+   logger: wandb # wandb | tensorboard | None
+   save_per_updates: 50000 # save checkpoint per steps
+   last_per_steps: 5000 # save last checkpoint per steps
+   save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
xinference/thirdparty/f5_tts/configs/F5TTS_Small_train.yaml
@@ -0,0 +1,46 @@
+ hydra:
+   run:
+     dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
+
+ datasets:
+   name: Emilia_ZH_EN
+   batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
+   batch_size_type: frame # "frame" or "sample"
+   max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
+   num_workers: 16
+
+ optim:
+   epochs: 15
+   learning_rate: 7.5e-5
+   num_warmup_updates: 20000 # warmup steps
+   grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
+   max_grad_norm: 1.0 # gradient clipping
+   bnb_optimizer: False # use bnb 8bit AdamW optimizer or not
+
+ model:
+   name: F5TTS_Small
+   tokenizer: pinyin
+   tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
+   arch:
+     dim: 768
+     depth: 18
+     heads: 12
+     ff_mult: 2
+     text_dim: 512
+     conv_layers: 4
+   mel_spec:
+     target_sample_rate: 24000
+     n_mel_channels: 100
+     hop_length: 256
+     win_length: 1024
+     n_fft: 1024
+     mel_spec_type: vocos # 'vocos' or 'bigvgan'
+   vocoder:
+     is_local: False # use local offline ckpt or not
+     local_path: None # local vocoder path
+
+ ckpts:
+   logger: wandb # wandb | tensorboard | None
+   save_per_updates: 50000 # save checkpoint per steps
+   last_per_steps: 5000 # save last checkpoint per steps
+   save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
xinference/thirdparty/f5_tts/eval/README.md
@@ -0,0 +1,49 @@
+
+ # Evaluation
+
+ Install packages for evaluation:
+
+ ```bash
+ pip install -e .[eval]
+ ```
+
+ ## Generating Samples for Evaluation
+
+ ### Prepare Test Datasets
+
+ 1. *Seed-TTS testset*: Download from [seed-tts-eval](https://github.com/BytedanceSpeech/seed-tts-eval).
+ 2. *LibriSpeech test-clean*: Download from [OpenSLR](http://www.openslr.org/12/).
+ 3. Unzip the downloaded datasets and place them in the `data/` directory.
+ 4. Update the path for *LibriSpeech test-clean* data in `src/f5_tts/eval/eval_infer_batch.py`
+ 5. Our filtered LibriSpeech-PC 4-10s subset: `data/librispeech_pc_test_clean_cross_sentence.lst`
+
+ ### Batch Inference for Test Set
+
+ To run batch inference for evaluations, execute the following commands:
+
+ ```bash
+ # batch inference for evaluations
+ accelerate config # if not set before
+ bash src/f5_tts/eval/eval_infer_batch.sh
+ ```
+
+ ## Objective Evaluation on Generated Results
+
+ ### Download Evaluation Model Checkpoints
+
+ 1. Chinese ASR Model: [Paraformer-zh](https://huggingface.co/funasr/paraformer-zh)
+ 2. English ASR Model: [Faster-Whisper](https://huggingface.co/Systran/faster-whisper-large-v3)
+ 3. WavLM Model: Download from [Google Drive](https://drive.google.com/file/d/1-aE1NfzpRCLxA4GUxX9ITI3F9LlbtEGP/view).
+
+ Then update in the following scripts with the paths you put evaluation model ckpts to.
+
+ ### Objective Evaluation
+
+ Update the path with your batch-inferenced results, and carry out WER / SIM evaluations:
+ ```bash
+ # Evaluation for Seed-TTS test set
+ python src/f5_tts/eval/eval_seedtts_testset.py --gen_wav_dir <GEN_WAVE_DIR>
+
+ # Evaluation for LibriSpeech-PC test-clean (cross-sentence)
+ python src/f5_tts/eval/eval_librispeech_test_clean.py --gen_wav_dir <GEN_WAVE_DIR> --librispeech_test_clean_path <TEST_CLEAN_PATH>
+ ```
xinference/thirdparty/f5_tts/eval/ecapa_tdnn.py
@@ -0,0 +1,330 @@
+ # just for speaker similarity evaluation, third-party code
+
+ # From https://github.com/microsoft/UniSpeech/blob/main/downstreams/speaker_verification/models/
+ # part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN
+
+ import os
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ """ Res2Conv1d + BatchNorm1d + ReLU
+ """
+
+
+ class Res2Conv1dReluBn(nn.Module):
+     """
+     in_channels == out_channels == channels
+     """
+
+     def __init__(self, channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True, scale=4):
+         super().__init__()
+         assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
+         self.scale = scale
+         self.width = channels // scale
+         self.nums = scale if scale == 1 else scale - 1
+
+         self.convs = []
+         self.bns = []
+         for i in range(self.nums):
+             self.convs.append(nn.Conv1d(self.width, self.width, kernel_size, stride, padding, dilation, bias=bias))
+             self.bns.append(nn.BatchNorm1d(self.width))
+         self.convs = nn.ModuleList(self.convs)
+         self.bns = nn.ModuleList(self.bns)
+
+     def forward(self, x):
+         out = []
+         spx = torch.split(x, self.width, 1)
+         for i in range(self.nums):
+             if i == 0:
+                 sp = spx[i]
+             else:
+                 sp = sp + spx[i]
+             # Order: conv -> relu -> bn
+             sp = self.convs[i](sp)
+             sp = self.bns[i](F.relu(sp))
+             out.append(sp)
+         if self.scale != 1:
+             out.append(spx[self.nums])
+         out = torch.cat(out, dim=1)
+
+         return out
+
+
+ """ Conv1d + BatchNorm1d + ReLU
+ """
+
+
+ class Conv1dReluBn(nn.Module):
+     def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True):
+         super().__init__()
+         self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)
+         self.bn = nn.BatchNorm1d(out_channels)
+
+     def forward(self, x):
+         return self.bn(F.relu(self.conv(x)))
+
+
+ """ The SE connection of 1D case.
+ """
+
+
+ class SE_Connect(nn.Module):
+     def __init__(self, channels, se_bottleneck_dim=128):
+         super().__init__()
+         self.linear1 = nn.Linear(channels, se_bottleneck_dim)
+         self.linear2 = nn.Linear(se_bottleneck_dim, channels)
+
+     def forward(self, x):
+         out = x.mean(dim=2)
+         out = F.relu(self.linear1(out))
+         out = torch.sigmoid(self.linear2(out))
+         out = x * out.unsqueeze(2)
+
+         return out
+
+
+ """ SE-Res2Block of the ECAPA-TDNN architecture.
+ """
+
+ # def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale):
+ #     return nn.Sequential(
+ #         Conv1dReluBn(channels, 512, kernel_size=1, stride=1, padding=0),
+ #         Res2Conv1dReluBn(512, kernel_size, stride, padding, dilation, scale=scale),
+ #         Conv1dReluBn(512, channels, kernel_size=1, stride=1, padding=0),
+ #         SE_Connect(channels)
+ #     )
+
+
+ class SE_Res2Block(nn.Module):
+     def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, scale, se_bottleneck_dim):
+         super().__init__()
+         self.Conv1dReluBn1 = Conv1dReluBn(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+         self.Res2Conv1dReluBn = Res2Conv1dReluBn(out_channels, kernel_size, stride, padding, dilation, scale=scale)
+         self.Conv1dReluBn2 = Conv1dReluBn(out_channels, out_channels, kernel_size=1, stride=1, padding=0)
+         self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim)
+
+         self.shortcut = None
+         if in_channels != out_channels:
+             self.shortcut = nn.Conv1d(
+                 in_channels=in_channels,
+                 out_channels=out_channels,
+                 kernel_size=1,
+             )
+
+     def forward(self, x):
+         residual = x
+         if self.shortcut:
+             residual = self.shortcut(x)
+
+         x = self.Conv1dReluBn1(x)
+         x = self.Res2Conv1dReluBn(x)
+         x = self.Conv1dReluBn2(x)
+         x = self.SE_Connect(x)
+
+         return x + residual
+
+
+ """ Attentive weighted mean and standard deviation pooling.
+ """
+
+
+ class AttentiveStatsPool(nn.Module):
+     def __init__(self, in_dim, attention_channels=128, global_context_att=False):
+         super().__init__()
+         self.global_context_att = global_context_att
+
+         # Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs.
+         if global_context_att:
+             self.linear1 = nn.Conv1d(in_dim * 3, attention_channels, kernel_size=1) # equals W and b in the paper
+         else:
+             self.linear1 = nn.Conv1d(in_dim, attention_channels, kernel_size=1) # equals W and b in the paper
+         self.linear2 = nn.Conv1d(attention_channels, in_dim, kernel_size=1) # equals V and k in the paper
+
+     def forward(self, x):
+         if self.global_context_att:
+             context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
+             context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
+             x_in = torch.cat((x, context_mean, context_std), dim=1)
+         else:
+             x_in = x
+
+         # DON'T use ReLU here! In experiments, I find ReLU hard to converge.
+         alpha = torch.tanh(self.linear1(x_in))
+         # alpha = F.relu(self.linear1(x_in))
+         alpha = torch.softmax(self.linear2(alpha), dim=2)
+         mean = torch.sum(alpha * x, dim=2)
+         residuals = torch.sum(alpha * (x**2), dim=2) - mean**2
+         std = torch.sqrt(residuals.clamp(min=1e-9))
+         return torch.cat([mean, std], dim=1)
+
+
+ class ECAPA_TDNN(nn.Module):
+     def __init__(
+         self,
+         feat_dim=80,
+         channels=512,
+         emb_dim=192,
+         global_context_att=False,
+         feat_type="wavlm_large",
+         sr=16000,
+         feature_selection="hidden_states",
+         update_extract=False,
+         config_path=None,
+     ):
+         super().__init__()
+
+         self.feat_type = feat_type
+         self.feature_selection = feature_selection
+         self.update_extract = update_extract
+         self.sr = sr
+
+         torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
+         try:
+             local_s3prl_path = os.path.expanduser("~/.cache/torch/hub/s3prl_s3prl_main")
+             self.feature_extract = torch.hub.load(local_s3prl_path, feat_type, source="local", config_path=config_path)
+         except: # noqa: E722
+             self.feature_extract = torch.hub.load("s3prl/s3prl", feat_type)
+
+         if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
+             self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention"
+         ):
+             self.feature_extract.model.encoder.layers[23].self_attn.fp32_attention = False
+         if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
+             self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention"
+         ):
+             self.feature_extract.model.encoder.layers[11].self_attn.fp32_attention = False
+
+         self.feat_num = self.get_feat_num()
+         self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))
+
+         if feat_type != "fbank" and feat_type != "mfcc":
+             freeze_list = ["final_proj", "label_embs_concat", "mask_emb", "project_q", "quantizer"]
+             for name, param in self.feature_extract.named_parameters():
+                 for freeze_val in freeze_list:
+                     if freeze_val in name:
+                         param.requires_grad = False
+                         break
+
+         if not self.update_extract:
+             for param in self.feature_extract.parameters():
+                 param.requires_grad = False
+
+         self.instance_norm = nn.InstanceNorm1d(feat_dim)
+         # self.channels = [channels] * 4 + [channels * 3]
+         self.channels = [channels] * 4 + [1536]
+
+         self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
+         self.layer2 = SE_Res2Block(
+             self.channels[0],
+             self.channels[1],
+             kernel_size=3,
+             stride=1,
+             padding=2,
+             dilation=2,
+             scale=8,
+             se_bottleneck_dim=128,
+         )
+         self.layer3 = SE_Res2Block(
+             self.channels[1],
+             self.channels[2],
+             kernel_size=3,
+             stride=1,
+             padding=3,
+             dilation=3,
+             scale=8,
+             se_bottleneck_dim=128,
+         )
+         self.layer4 = SE_Res2Block(
+             self.channels[2],
+             self.channels[3],
+             kernel_size=3,
+             stride=1,
+             padding=4,
+             dilation=4,
+             scale=8,
+             se_bottleneck_dim=128,
+         )
+
+         # self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
+         cat_channels = channels * 3
+         self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
+         self.pooling = AttentiveStatsPool(
+             self.channels[-1], attention_channels=128, global_context_att=global_context_att
+         )
+         self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
+         self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)
+
+     def get_feat_num(self):
+         self.feature_extract.eval()
+         wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
+         with torch.no_grad():
+             features = self.feature_extract(wav)
+         select_feature = features[self.feature_selection]
+         if isinstance(select_feature, (list, tuple)):
+             return len(select_feature)
+         else:
+             return 1
+
+     def get_feat(self, x):
+         if self.update_extract:
+             x = self.feature_extract([sample for sample in x])
+         else:
+             with torch.no_grad():
+                 if self.feat_type == "fbank" or self.feat_type == "mfcc":
+                     x = self.feature_extract(x) + 1e-6 # B x feat_dim x time_len
+                 else:
+                     x = self.feature_extract([sample for sample in x])
+
+         if self.feat_type == "fbank":
+             x = x.log()
+
+         if self.feat_type != "fbank" and self.feat_type != "mfcc":
+             x = x[self.feature_selection]
+             if isinstance(x, (list, tuple)):
+                 x = torch.stack(x, dim=0)
+             else:
+                 x = x.unsqueeze(0)
+             norm_weights = F.softmax(self.feature_weight, dim=-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
+             x = (norm_weights * x).sum(dim=0)
+             x = torch.transpose(x, 1, 2) + 1e-6
+
+         x = self.instance_norm(x)
+         return x
+
+     def forward(self, x):
+         x = self.get_feat(x)
+
+         out1 = self.layer1(x)
+         out2 = self.layer2(out1)
+         out3 = self.layer3(out2)
+         out4 = self.layer4(out3)
+
+         out = torch.cat([out2, out3, out4], dim=1)
+         out = F.relu(self.conv(out))
+         out = self.bn(self.pooling(out))
+         out = self.linear(out)
+
+         return out
+
+
+ def ECAPA_TDNN_SMALL(
+     feat_dim,
+     emb_dim=256,
+     feat_type="wavlm_large",
+     sr=16000,
+     feature_selection="hidden_states",
+     update_extract=False,
+     config_path=None,
+ ):
+     return ECAPA_TDNN(
+         feat_dim=feat_dim,
+         channels=512,
+         emb_dim=emb_dim,
+         feat_type=feat_type,
+         sr=sr,
+         feature_selection=feature_selection,
+         update_extract=update_extract,
+         config_path=config_path,
+     )
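
The vendored `ecapa_tdnn.py` above is used for the speaker-similarity (SIM) part of the evaluation. Below is a minimal usage sketch, assuming s3prl is installed so `torch.hub` can supply the WavLM feature extractor; the import path is inferred from the file's location in the wheel, and the waveforms are random placeholders rather than real reference/generated audio.

```python
# Hedged sketch: embed two utterances with the vendored ECAPA-TDNN and compare
# them with cosine similarity, as done for SIM scoring.
import torch
import torch.nn.functional as F

# Import path assumed from the file location inside the wheel.
from xinference.thirdparty.f5_tts.eval.ecapa_tdnn import ECAPA_TDNN_SMALL

# feat_dim=1024 matches the hidden size of the wavlm_large features used by default.
model = ECAPA_TDNN_SMALL(feat_dim=1024, emb_dim=256, feat_type="wavlm_large")
model.eval()

# Two 16 kHz waveforms (reference vs. generated); features are extracted inside the model.
ref_wav = torch.randn(1, 16000)
gen_wav = torch.randn(1, 16000)

with torch.no_grad():
    ref_emb = model(ref_wav)  # (1, 256) speaker embedding
    gen_emb = model(gen_wav)  # (1, 256) speaker embedding

sim = F.cosine_similarity(ref_emb, gen_emb).item()
print(f"speaker similarity: {sim:.3f}")
```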