xinference 0.16.3__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic. See the registry's advisory for this release for more details.

Files changed (373)
  1. xinference/_compat.py +24 -2
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +219 -77
  4. xinference/client/restful/restful_client.py +47 -2
  5. xinference/constants.py +1 -0
  6. xinference/core/chat_interface.py +6 -1
  7. xinference/core/model.py +124 -34
  8. xinference/core/supervisor.py +180 -12
  9. xinference/core/utils.py +73 -4
  10. xinference/core/worker.py +102 -4
  11. xinference/deploy/cmdline.py +3 -1
  12. xinference/deploy/test/test_cmdline.py +56 -0
  13. xinference/isolation.py +24 -0
  14. xinference/model/audio/__init__.py +12 -0
  15. xinference/model/audio/core.py +37 -4
  16. xinference/model/audio/cosyvoice.py +39 -6
  17. xinference/model/audio/f5tts.py +200 -0
  18. xinference/model/audio/f5tts_mlx.py +260 -0
  19. xinference/model/audio/fish_speech.py +70 -110
  20. xinference/model/audio/melotts.py +110 -0
  21. xinference/model/audio/model_spec.json +179 -3
  22. xinference/model/audio/model_spec_modelscope.json +27 -0
  23. xinference/model/audio/utils.py +32 -0
  24. xinference/model/audio/whisper.py +35 -10
  25. xinference/model/audio/whisper_mlx.py +208 -0
  26. xinference/model/embedding/core.py +322 -6
  27. xinference/model/embedding/model_spec.json +8 -1
  28. xinference/model/embedding/model_spec_modelscope.json +9 -1
  29. xinference/model/image/core.py +69 -1
  30. xinference/model/image/model_spec.json +145 -4
  31. xinference/model/image/model_spec_modelscope.json +150 -4
  32. xinference/model/image/stable_diffusion/core.py +50 -15
  33. xinference/model/llm/__init__.py +6 -2
  34. xinference/model/llm/llm_family.json +1055 -93
  35. xinference/model/llm/llm_family.py +15 -36
  36. xinference/model/llm/llm_family_modelscope.json +1031 -78
  37. xinference/model/llm/memory.py +1 -1
  38. xinference/model/llm/mlx/core.py +285 -47
  39. xinference/model/llm/sglang/core.py +2 -0
  40. xinference/model/llm/transformers/chatglm.py +9 -5
  41. xinference/model/llm/transformers/cogagent.py +272 -0
  42. xinference/model/llm/transformers/core.py +3 -0
  43. xinference/model/llm/transformers/glm_edge_v.py +230 -0
  44. xinference/model/llm/transformers/qwen2_vl.py +12 -1
  45. xinference/model/llm/transformers/utils.py +16 -8
  46. xinference/model/llm/utils.py +55 -4
  47. xinference/model/llm/vllm/core.py +137 -12
  48. xinference/model/llm/vllm/xavier/__init__.py +13 -0
  49. xinference/model/llm/vllm/xavier/allocator.py +74 -0
  50. xinference/model/llm/vllm/xavier/block.py +111 -0
  51. xinference/model/llm/vllm/xavier/block_manager.py +71 -0
  52. xinference/model/llm/vllm/xavier/block_tracker.py +129 -0
  53. xinference/model/llm/vllm/xavier/collective.py +74 -0
  54. xinference/model/llm/vllm/xavier/collective_manager.py +147 -0
  55. xinference/model/llm/vllm/xavier/engine.py +247 -0
  56. xinference/model/llm/vllm/xavier/executor.py +134 -0
  57. xinference/model/llm/vllm/xavier/scheduler.py +438 -0
  58. xinference/model/llm/vllm/xavier/test/__init__.py +13 -0
  59. xinference/model/llm/vllm/xavier/test/test_xavier.py +147 -0
  60. xinference/model/llm/vllm/xavier/transfer.py +319 -0
  61. xinference/model/rerank/core.py +11 -4
  62. xinference/model/video/diffusers.py +14 -0
  63. xinference/model/video/model_spec.json +15 -0
  64. xinference/model/video/model_spec_modelscope.json +16 -0
  65. xinference/thirdparty/cosyvoice/bin/average_model.py +92 -0
  66. xinference/thirdparty/cosyvoice/bin/export_jit.py +12 -2
  67. xinference/thirdparty/cosyvoice/bin/export_onnx.py +112 -0
  68. xinference/thirdparty/cosyvoice/bin/export_trt.sh +9 -0
  69. xinference/thirdparty/cosyvoice/bin/inference.py +5 -7
  70. xinference/thirdparty/cosyvoice/bin/spk2info.pt +0 -0
  71. xinference/thirdparty/cosyvoice/bin/train.py +42 -8
  72. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +96 -25
  73. xinference/thirdparty/cosyvoice/cli/frontend.py +77 -30
  74. xinference/thirdparty/cosyvoice/cli/model.py +330 -80
  75. xinference/thirdparty/cosyvoice/dataset/dataset.py +6 -2
  76. xinference/thirdparty/cosyvoice/dataset/processor.py +76 -14
  77. xinference/thirdparty/cosyvoice/flow/decoder.py +92 -13
  78. xinference/thirdparty/cosyvoice/flow/flow.py +99 -9
  79. xinference/thirdparty/cosyvoice/flow/flow_matching.py +110 -13
  80. xinference/thirdparty/cosyvoice/flow/length_regulator.py +5 -4
  81. xinference/thirdparty/cosyvoice/hifigan/discriminator.py +140 -0
  82. xinference/thirdparty/cosyvoice/hifigan/generator.py +58 -42
  83. xinference/thirdparty/cosyvoice/hifigan/hifigan.py +67 -0
  84. xinference/thirdparty/cosyvoice/llm/llm.py +139 -6
  85. xinference/thirdparty/cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
  86. xinference/thirdparty/cosyvoice/tokenizer/tokenizer.py +279 -0
  87. xinference/thirdparty/cosyvoice/transformer/embedding.py +2 -2
  88. xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +7 -7
  89. xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +318 -0
  90. xinference/thirdparty/cosyvoice/utils/common.py +28 -1
  91. xinference/thirdparty/cosyvoice/utils/executor.py +69 -7
  92. xinference/thirdparty/cosyvoice/utils/file_utils.py +2 -12
  93. xinference/thirdparty/cosyvoice/utils/frontend_utils.py +9 -5
  94. xinference/thirdparty/cosyvoice/utils/losses.py +20 -0
  95. xinference/thirdparty/cosyvoice/utils/scheduler.py +1 -2
  96. xinference/thirdparty/cosyvoice/utils/train_utils.py +101 -45
  97. xinference/thirdparty/f5_tts/api.py +166 -0
  98. xinference/thirdparty/f5_tts/configs/E2TTS_Base_train.yaml +44 -0
  99. xinference/thirdparty/f5_tts/configs/E2TTS_Small_train.yaml +44 -0
  100. xinference/thirdparty/f5_tts/configs/F5TTS_Base_train.yaml +46 -0
  101. xinference/thirdparty/f5_tts/configs/F5TTS_Small_train.yaml +46 -0
  102. xinference/thirdparty/f5_tts/eval/README.md +49 -0
  103. xinference/thirdparty/f5_tts/eval/ecapa_tdnn.py +330 -0
  104. xinference/thirdparty/f5_tts/eval/eval_infer_batch.py +207 -0
  105. xinference/thirdparty/f5_tts/eval/eval_infer_batch.sh +13 -0
  106. xinference/thirdparty/f5_tts/eval/eval_librispeech_test_clean.py +84 -0
  107. xinference/thirdparty/f5_tts/eval/eval_seedtts_testset.py +84 -0
  108. xinference/thirdparty/f5_tts/eval/utils_eval.py +405 -0
  109. xinference/thirdparty/f5_tts/infer/README.md +191 -0
  110. xinference/thirdparty/f5_tts/infer/SHARED.md +74 -0
  111. xinference/thirdparty/f5_tts/infer/examples/basic/basic.toml +11 -0
  112. xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_en.wav +0 -0
  113. xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_zh.wav +0 -0
  114. xinference/thirdparty/f5_tts/infer/examples/multi/country.flac +0 -0
  115. xinference/thirdparty/f5_tts/infer/examples/multi/main.flac +0 -0
  116. xinference/thirdparty/f5_tts/infer/examples/multi/story.toml +19 -0
  117. xinference/thirdparty/f5_tts/infer/examples/multi/story.txt +1 -0
  118. xinference/thirdparty/f5_tts/infer/examples/multi/town.flac +0 -0
  119. xinference/thirdparty/f5_tts/infer/examples/vocab.txt +2545 -0
  120. xinference/thirdparty/f5_tts/infer/infer_cli.py +226 -0
  121. xinference/thirdparty/f5_tts/infer/infer_gradio.py +851 -0
  122. xinference/thirdparty/f5_tts/infer/speech_edit.py +193 -0
  123. xinference/thirdparty/f5_tts/infer/utils_infer.py +538 -0
  124. xinference/thirdparty/f5_tts/model/__init__.py +10 -0
  125. xinference/thirdparty/f5_tts/model/backbones/README.md +20 -0
  126. xinference/thirdparty/f5_tts/model/backbones/dit.py +163 -0
  127. xinference/thirdparty/f5_tts/model/backbones/mmdit.py +146 -0
  128. xinference/thirdparty/f5_tts/model/backbones/unett.py +219 -0
  129. xinference/thirdparty/f5_tts/model/cfm.py +285 -0
  130. xinference/thirdparty/f5_tts/model/dataset.py +319 -0
  131. xinference/thirdparty/f5_tts/model/modules.py +658 -0
  132. xinference/thirdparty/f5_tts/model/trainer.py +366 -0
  133. xinference/thirdparty/f5_tts/model/utils.py +185 -0
  134. xinference/thirdparty/f5_tts/scripts/count_max_epoch.py +33 -0
  135. xinference/thirdparty/f5_tts/scripts/count_params_gflops.py +39 -0
  136. xinference/thirdparty/f5_tts/socket_server.py +159 -0
  137. xinference/thirdparty/f5_tts/train/README.md +77 -0
  138. xinference/thirdparty/f5_tts/train/datasets/prepare_csv_wavs.py +139 -0
  139. xinference/thirdparty/f5_tts/train/datasets/prepare_emilia.py +230 -0
  140. xinference/thirdparty/f5_tts/train/datasets/prepare_libritts.py +92 -0
  141. xinference/thirdparty/f5_tts/train/datasets/prepare_ljspeech.py +65 -0
  142. xinference/thirdparty/f5_tts/train/datasets/prepare_wenetspeech4tts.py +125 -0
  143. xinference/thirdparty/f5_tts/train/finetune_cli.py +174 -0
  144. xinference/thirdparty/f5_tts/train/finetune_gradio.py +1846 -0
  145. xinference/thirdparty/f5_tts/train/train.py +75 -0
  146. xinference/thirdparty/fish_speech/fish_speech/conversation.py +266 -1
  147. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +2 -1
  148. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +2 -1
  149. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +2 -2
  150. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ko_KR.json +123 -0
  151. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +2 -1
  152. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +137 -29
  153. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +9 -9
  154. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +1 -1
  155. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +17 -11
  156. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +1 -1
  157. xinference/thirdparty/fish_speech/fish_speech/tokenizer.py +152 -0
  158. xinference/thirdparty/fish_speech/fish_speech/train.py +2 -2
  159. xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +2 -1
  160. xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +22 -0
  161. xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +1 -1
  162. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +2 -2
  163. xinference/thirdparty/fish_speech/tools/{post_api.py → api_client.py} +34 -18
  164. xinference/thirdparty/fish_speech/tools/api_server.py +98 -0
  165. xinference/thirdparty/fish_speech/tools/download_models.py +5 -5
  166. xinference/thirdparty/fish_speech/tools/e2e_webui.py +232 -0
  167. xinference/thirdparty/fish_speech/tools/fish_e2e.py +298 -0
  168. xinference/thirdparty/fish_speech/tools/inference_engine/__init__.py +192 -0
  169. xinference/thirdparty/fish_speech/tools/inference_engine/reference_loader.py +125 -0
  170. xinference/thirdparty/fish_speech/tools/inference_engine/utils.py +39 -0
  171. xinference/thirdparty/fish_speech/tools/inference_engine/vq_manager.py +57 -0
  172. xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +2 -2
  173. xinference/thirdparty/fish_speech/tools/llama/generate.py +484 -72
  174. xinference/thirdparty/fish_speech/tools/run_webui.py +104 -0
  175. xinference/thirdparty/fish_speech/tools/schema.py +170 -0
  176. xinference/thirdparty/fish_speech/tools/server/agent/__init__.py +57 -0
  177. xinference/thirdparty/fish_speech/tools/server/agent/generate.py +119 -0
  178. xinference/thirdparty/fish_speech/tools/server/agent/generation_utils.py +122 -0
  179. xinference/thirdparty/fish_speech/tools/server/agent/pre_generation_utils.py +72 -0
  180. xinference/thirdparty/fish_speech/tools/server/api_utils.py +75 -0
  181. xinference/thirdparty/fish_speech/tools/server/exception_handler.py +27 -0
  182. xinference/thirdparty/fish_speech/tools/server/inference.py +45 -0
  183. xinference/thirdparty/fish_speech/tools/server/model_manager.py +122 -0
  184. xinference/thirdparty/fish_speech/tools/server/model_utils.py +129 -0
  185. xinference/thirdparty/fish_speech/tools/server/views.py +246 -0
  186. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +7 -1
  187. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +2 -3
  188. xinference/thirdparty/fish_speech/tools/webui/__init__.py +173 -0
  189. xinference/thirdparty/fish_speech/tools/webui/inference.py +91 -0
  190. xinference/thirdparty/fish_speech/tools/webui/variables.py +14 -0
  191. xinference/thirdparty/matcha/utils/utils.py +2 -2
  192. xinference/thirdparty/melo/api.py +135 -0
  193. xinference/thirdparty/melo/app.py +61 -0
  194. xinference/thirdparty/melo/attentions.py +459 -0
  195. xinference/thirdparty/melo/commons.py +160 -0
  196. xinference/thirdparty/melo/configs/config.json +94 -0
  197. xinference/thirdparty/melo/data/example/metadata.list +20 -0
  198. xinference/thirdparty/melo/data_utils.py +413 -0
  199. xinference/thirdparty/melo/download_utils.py +67 -0
  200. xinference/thirdparty/melo/infer.py +25 -0
  201. xinference/thirdparty/melo/init_downloads.py +14 -0
  202. xinference/thirdparty/melo/losses.py +58 -0
  203. xinference/thirdparty/melo/main.py +36 -0
  204. xinference/thirdparty/melo/mel_processing.py +174 -0
  205. xinference/thirdparty/melo/models.py +1030 -0
  206. xinference/thirdparty/melo/modules.py +598 -0
  207. xinference/thirdparty/melo/monotonic_align/__init__.py +16 -0
  208. xinference/thirdparty/melo/monotonic_align/core.py +46 -0
  209. xinference/thirdparty/melo/preprocess_text.py +135 -0
  210. xinference/thirdparty/melo/split_utils.py +174 -0
  211. xinference/thirdparty/melo/text/__init__.py +35 -0
  212. xinference/thirdparty/melo/text/chinese.py +199 -0
  213. xinference/thirdparty/melo/text/chinese_bert.py +107 -0
  214. xinference/thirdparty/melo/text/chinese_mix.py +253 -0
  215. xinference/thirdparty/melo/text/cleaner.py +36 -0
  216. xinference/thirdparty/melo/text/cleaner_multiling.py +110 -0
  217. xinference/thirdparty/melo/text/cmudict.rep +129530 -0
  218. xinference/thirdparty/melo/text/cmudict_cache.pickle +0 -0
  219. xinference/thirdparty/melo/text/english.py +284 -0
  220. xinference/thirdparty/melo/text/english_bert.py +39 -0
  221. xinference/thirdparty/melo/text/english_utils/abbreviations.py +35 -0
  222. xinference/thirdparty/melo/text/english_utils/number_norm.py +97 -0
  223. xinference/thirdparty/melo/text/english_utils/time_norm.py +47 -0
  224. xinference/thirdparty/melo/text/es_phonemizer/base.py +140 -0
  225. xinference/thirdparty/melo/text/es_phonemizer/cleaner.py +109 -0
  226. xinference/thirdparty/melo/text/es_phonemizer/es_symbols.json +79 -0
  227. xinference/thirdparty/melo/text/es_phonemizer/es_symbols.txt +1 -0
  228. xinference/thirdparty/melo/text/es_phonemizer/es_symbols_v2.json +83 -0
  229. xinference/thirdparty/melo/text/es_phonemizer/es_to_ipa.py +12 -0
  230. xinference/thirdparty/melo/text/es_phonemizer/example_ipa.txt +400 -0
  231. xinference/thirdparty/melo/text/es_phonemizer/gruut_wrapper.py +253 -0
  232. xinference/thirdparty/melo/text/es_phonemizer/punctuation.py +174 -0
  233. xinference/thirdparty/melo/text/es_phonemizer/spanish_symbols.txt +1 -0
  234. xinference/thirdparty/melo/text/es_phonemizer/test.ipynb +124 -0
  235. xinference/thirdparty/melo/text/fr_phonemizer/base.py +140 -0
  236. xinference/thirdparty/melo/text/fr_phonemizer/cleaner.py +122 -0
  237. xinference/thirdparty/melo/text/fr_phonemizer/en_symbols.json +78 -0
  238. xinference/thirdparty/melo/text/fr_phonemizer/example_ipa.txt +1 -0
  239. xinference/thirdparty/melo/text/fr_phonemizer/fr_symbols.json +89 -0
  240. xinference/thirdparty/melo/text/fr_phonemizer/fr_to_ipa.py +30 -0
  241. xinference/thirdparty/melo/text/fr_phonemizer/french_abbreviations.py +48 -0
  242. xinference/thirdparty/melo/text/fr_phonemizer/french_symbols.txt +1 -0
  243. xinference/thirdparty/melo/text/fr_phonemizer/gruut_wrapper.py +258 -0
  244. xinference/thirdparty/melo/text/fr_phonemizer/punctuation.py +172 -0
  245. xinference/thirdparty/melo/text/french.py +94 -0
  246. xinference/thirdparty/melo/text/french_bert.py +39 -0
  247. xinference/thirdparty/melo/text/japanese.py +647 -0
  248. xinference/thirdparty/melo/text/japanese_bert.py +49 -0
  249. xinference/thirdparty/melo/text/ko_dictionary.py +44 -0
  250. xinference/thirdparty/melo/text/korean.py +192 -0
  251. xinference/thirdparty/melo/text/opencpop-strict.txt +429 -0
  252. xinference/thirdparty/melo/text/spanish.py +122 -0
  253. xinference/thirdparty/melo/text/spanish_bert.py +39 -0
  254. xinference/thirdparty/melo/text/symbols.py +290 -0
  255. xinference/thirdparty/melo/text/tone_sandhi.py +769 -0
  256. xinference/thirdparty/melo/train.py +635 -0
  257. xinference/thirdparty/melo/train.sh +19 -0
  258. xinference/thirdparty/melo/transforms.py +209 -0
  259. xinference/thirdparty/melo/utils.py +424 -0
  260. xinference/types.py +17 -1
  261. xinference/web/ui/build/asset-manifest.json +6 -6
  262. xinference/web/ui/build/index.html +1 -1
  263. xinference/web/ui/build/static/css/main.51a587ff.css +2 -0
  264. xinference/web/ui/build/static/css/main.51a587ff.css.map +1 -0
  265. xinference/web/ui/build/static/js/main.b0936c54.js +3 -0
  266. xinference/web/ui/build/static/js/main.b0936c54.js.map +1 -0
  267. xinference/web/ui/node_modules/.cache/babel-loader/03c4052f1b91f6ba0c5389bdcf49c43319b4076c08e4b8585dab312538ae290a.json +1 -0
  268. xinference/web/ui/node_modules/.cache/babel-loader/1786b83003b8e9605a0f5f855a185d4d16e38fc893dfb326a2a9cca206b4240a.json +1 -0
  269. xinference/web/ui/node_modules/.cache/babel-loader/17cbc181dd674b9150b80c73ed6a82656de0082d857f6e5f66d9716129ac0b38.json +1 -0
  270. xinference/web/ui/node_modules/.cache/babel-loader/185ceb8872d562e032b47e79df6a45670e06345b8ed70aad1a131e0476783c5c.json +1 -0
  271. xinference/web/ui/node_modules/.cache/babel-loader/26b8c9f34b0bed789b3a833767672e39302d1e0c09b4276f4d58d1df7b6bd93b.json +1 -0
  272. xinference/web/ui/node_modules/.cache/babel-loader/2b484da66c724d0d56a40849c109327408796a668b1381511b6e9e03baa48658.json +1 -0
  273. xinference/web/ui/node_modules/.cache/babel-loader/2cbbbce9b84df73330d4c42b82436ed881b3847628f2fbc346aa62e2859fd88c.json +1 -0
  274. xinference/web/ui/node_modules/.cache/babel-loader/2ec9b14431ed33ce6901bf9f27007be4e6e472709c99d6e22b50ce528e4b78ee.json +1 -0
  275. xinference/web/ui/node_modules/.cache/babel-loader/3b966db018f96be4a055d6ca205f0990d4d0b370e2980c17d8bca2c9a021819c.json +1 -0
  276. xinference/web/ui/node_modules/.cache/babel-loader/3eefb411b24c2b3ce053570ef50daccf154022f0e168be5ed0fec21394baf9f4.json +1 -0
  277. xinference/web/ui/node_modules/.cache/babel-loader/522b229e3cac219123f0d69673f5570e191c2d2a505dc65b312d336eae2279c0.json +1 -0
  278. xinference/web/ui/node_modules/.cache/babel-loader/52e45f17ba300580ea3fcc9f9228ccba194bb092b76f25e9255af311f8b05aab.json +1 -0
  279. xinference/web/ui/node_modules/.cache/babel-loader/5a0bc4631f936459afc1a3b1d3ec2420118b1f00e11f60ccac3e08088f3f27a8.json +1 -0
  280. xinference/web/ui/node_modules/.cache/babel-loader/611fa2c6c53b66039991d06dfb0473b5ab37fc63b4564e0f6e1718523768a045.json +1 -0
  281. xinference/web/ui/node_modules/.cache/babel-loader/6329bc76c406fe5eb305412383fbde5950f847bb5e43261f73f37622c365acb4.json +1 -0
  282. xinference/web/ui/node_modules/.cache/babel-loader/63c8e07687ea53a4f8a910ee5e42e0eb26cd1acbfbe820f3e3248a786ee51401.json +1 -0
  283. xinference/web/ui/node_modules/.cache/babel-loader/69b2d5001684174ec9da57e07914eed3eac4960018bceb6cbfa801d861301d7c.json +1 -0
  284. xinference/web/ui/node_modules/.cache/babel-loader/710c1acda69e561e30a933b98c6a56d50197868b15c21e2aad55ab6d46649eb6.json +1 -0
  285. xinference/web/ui/node_modules/.cache/babel-loader/720deca1fce5a1dc5056048fa8258fd138a82ea855f350b6613f104a73fb761f.json +1 -0
  286. xinference/web/ui/node_modules/.cache/babel-loader/76a23b92d26a499c57e61eea2b895fbc9771bd0849a72e66f8e633192017978b.json +1 -0
  287. xinference/web/ui/node_modules/.cache/babel-loader/858063f23b34dfe600254eb5afd85518b0002ec4b30b7386616c45600826e3b2.json +1 -0
  288. xinference/web/ui/node_modules/.cache/babel-loader/920b82c1c89124cf217109eeedbfcd3aae3b917be50c9dfb6bbb4ce26bdfd2e7.json +1 -0
  289. xinference/web/ui/node_modules/.cache/babel-loader/94d8b7aeb0076f2ce07db598cea0e87b13bc8d5614eb530b8d6e696c2daf6f88.json +1 -0
  290. xinference/web/ui/node_modules/.cache/babel-loader/9e917fe7022d01b2ccbe5cc0ce73d70bb72bee584ff293bad71bdff6695dee28.json +1 -0
  291. xinference/web/ui/node_modules/.cache/babel-loader/9f28fdb8399f1d0474f0aca86f1658dc94f5bf0c90f6146352de150692de8862.json +1 -0
  292. xinference/web/ui/node_modules/.cache/babel-loader/a0dfafa06b2bb7cba8cad41c482503f61944f759f4318139362602ef5cc47ccb.json +1 -0
  293. xinference/web/ui/node_modules/.cache/babel-loader/a3ff866acddf34917a7ee399e0e571a4dfd8ba66d5057db885f243e16a6eb17d.json +1 -0
  294. xinference/web/ui/node_modules/.cache/babel-loader/afb8084f539534cd594755ea2205ecd5bd1f62dddcfdf75a2eace59a28131278.json +1 -0
  295. xinference/web/ui/node_modules/.cache/babel-loader/b57b1438b77294c1f3f6cfce12ac487d8106c6f016975ba0aec94d98997e2e1e.json +1 -0
  296. xinference/web/ui/node_modules/.cache/babel-loader/b9917b0bf8e4d55ccbac1c334aa04d6ff3c5b6ed9e5d38b9ea2c687fa7d3f5a9.json +1 -0
  297. xinference/web/ui/node_modules/.cache/babel-loader/bbcc94b0149963d1d6f267ee1f4f03d3925b758392ce2f516c3fe8af0e0169fc.json +1 -0
  298. xinference/web/ui/node_modules/.cache/babel-loader/bdee44abeadc4abc17d41c52eb49c6e19a4b1a267b6e16876ce91bdeeebfc52d.json +1 -0
  299. xinference/web/ui/node_modules/.cache/babel-loader/beb112b70f4a56db95920a9e20efb6c97c37b68450716730217a9ee1a9ae92be.json +1 -0
  300. xinference/web/ui/node_modules/.cache/babel-loader/c88db97be0cdf440193b3995996e83510a04cb00048135485fc0e26d197e80b5.json +1 -0
  301. xinference/web/ui/node_modules/.cache/babel-loader/d49e5314d34310a62d01a03067ce1bec5da00abce84c5196aa9c6842fa79a430.json +1 -0
  302. xinference/web/ui/node_modules/.cache/babel-loader/d7664d18c4ddbad9c3a6a31b91f7c00fb0dde804608674a9860ee50f33e54708.json +1 -0
  303. xinference/web/ui/node_modules/.cache/babel-loader/d9072c318b819b7c90a0f7e9cc0b6413b4dbeb8e9859898e53d75ea882fcde99.json +1 -0
  304. xinference/web/ui/node_modules/.cache/babel-loader/db16a983bc08a05f0439cc61ca0840e49e1d8400eef678909f16c032a418a3d6.json +1 -0
  305. xinference/web/ui/node_modules/.cache/babel-loader/dc249829767b8abcbc3677e0b07b6d3ecbfdfe6d08cfe23a665eb33373a9aa9d.json +1 -0
  306. xinference/web/ui/node_modules/.cache/babel-loader/e242c583c2dbc2784f0fcf513523975f7d5df447e106c1c17e49e8578a6fc3ed.json +1 -0
  307. xinference/web/ui/node_modules/.cache/babel-loader/eac5f1296513e69e4b96f750ddccd4d0264e2bae4e4c449144e83274a48698d9.json +1 -0
  308. xinference/web/ui/node_modules/.cache/babel-loader/ed57202cb79649bb716400436590245547df241988fc7c8e1d85d132299542d2.json +1 -0
  309. xinference/web/ui/node_modules/.cache/babel-loader/f125bf72e773a14cdaebd0c343e80adb909d12e317ee5c00cd4a57442fbe2c62.json +1 -0
  310. xinference/web/ui/node_modules/.cache/babel-loader/f91af913d7f91c410719ab13136aaed3aaf0f8dda06652f25c42cb5231587398.json +1 -0
  311. xinference/web/ui/node_modules/.package-lock.json +67 -3
  312. xinference/web/ui/node_modules/@babel/runtime/package.json +592 -538
  313. xinference/web/ui/node_modules/html-parse-stringify/package.json +50 -0
  314. xinference/web/ui/node_modules/i18next/dist/esm/package.json +1 -0
  315. xinference/web/ui/node_modules/i18next/package.json +129 -0
  316. xinference/web/ui/node_modules/react-i18next/.eslintrc.json +74 -0
  317. xinference/web/ui/node_modules/react-i18next/dist/es/package.json +1 -0
  318. xinference/web/ui/node_modules/react-i18next/package.json +162 -0
  319. xinference/web/ui/node_modules/void-elements/package.json +34 -0
  320. xinference/web/ui/package-lock.json +69 -3
  321. xinference/web/ui/package.json +2 -0
  322. xinference/web/ui/src/locales/en.json +186 -0
  323. xinference/web/ui/src/locales/zh.json +186 -0
  324. {xinference-0.16.3.dist-info → xinference-1.2.1.dist-info}/METADATA +96 -36
  325. {xinference-0.16.3.dist-info → xinference-1.2.1.dist-info}/RECORD +335 -146
  326. {xinference-0.16.3.dist-info → xinference-1.2.1.dist-info}/WHEEL +1 -1
  327. xinference/thirdparty/cosyvoice/bin/export_trt.py +0 -8
  328. xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
  329. xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
  330. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
  331. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
  332. xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
  333. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
  334. xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
  335. xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
  336. xinference/thirdparty/fish_speech/tools/api.py +0 -440
  337. xinference/thirdparty/fish_speech/tools/commons.py +0 -35
  338. xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
  339. xinference/thirdparty/fish_speech/tools/msgpack_api.py +0 -34
  340. xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
  341. xinference/thirdparty/fish_speech/tools/webui.py +0 -485
  342. xinference/web/ui/build/static/css/main.5061c4c3.css +0 -2
  343. xinference/web/ui/build/static/css/main.5061c4c3.css.map +0 -1
  344. xinference/web/ui/build/static/js/main.2f269bb3.js +0 -3
  345. xinference/web/ui/build/static/js/main.2f269bb3.js.map +0 -1
  346. xinference/web/ui/node_modules/.cache/babel-loader/07ce9e632e6aff24d7aa3ad8e48224433bbfeb0d633fca723453f1fcae0c9f1c.json +0 -1
  347. xinference/web/ui/node_modules/.cache/babel-loader/1130403f9e46f5738a23b45ac59b57de8f360c908c713e2c0670c2cce9bd367a.json +0 -1
  348. xinference/web/ui/node_modules/.cache/babel-loader/131091b25d26b17cdca187d7542a21475c211138d900cf667682260e76ef9463.json +0 -1
  349. xinference/web/ui/node_modules/.cache/babel-loader/1f269fb2a368363c1cb2237825f1dba093b6bdd8c44cc05954fd19ec2c1fff03.json +0 -1
  350. xinference/web/ui/node_modules/.cache/babel-loader/331312668fa8bd3d7401818f4a25fa98135d7f61371cd6bfff78b18cf4fbdd92.json +0 -1
  351. xinference/web/ui/node_modules/.cache/babel-loader/40f17338fc75ae095de7d2b4d8eae0d5ca0193a7e2bcece4ee745b22a7a2f4b7.json +0 -1
  352. xinference/web/ui/node_modules/.cache/babel-loader/4de9a6942c5f1749d6cbfdd54279699975f16016b182848bc253886f52ec2ec3.json +0 -1
  353. xinference/web/ui/node_modules/.cache/babel-loader/822586ed1077201b64b954f12f25e3f9b45678c1acbabe53d8af3ca82ca71f33.json +0 -1
  354. xinference/web/ui/node_modules/.cache/babel-loader/8d33354bd2100c8602afc3341f131a88cc36aaeecd5a4b365ed038514708e350.json +0 -1
  355. xinference/web/ui/node_modules/.cache/babel-loader/9375a35b05d56989b2755bf72161fa707c92f28569d33765a75f91a568fda6e9.json +0 -1
  356. xinference/web/ui/node_modules/.cache/babel-loader/a158a9ffa0c9b169aee53dd4a0c44501a596755b4e4f6ede7746d65a72e2a71f.json +0 -1
  357. xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +0 -1
  358. xinference/web/ui/node_modules/.cache/babel-loader/c7bf40bab396765f67d0fed627ed3665890608b2d0edaa3e8cb7cfc96310db45.json +0 -1
  359. xinference/web/ui/node_modules/.cache/babel-loader/d6c643278a0b28320e6f33a60f5fb64c053997cbdc39a60e53ccc574688ade9e.json +0 -1
  360. xinference/web/ui/node_modules/.cache/babel-loader/e42b72d4cc1ea412ebecbb8d040dc6c6bfee462c33903c2f1f3facb602ad742e.json +0 -1
  361. xinference/web/ui/node_modules/.cache/babel-loader/e64b7e8cedcf43d4c95deba60ec1341855c887705805bb62431693118b870c69.json +0 -1
  362. xinference/web/ui/node_modules/.cache/babel-loader/f5039ddbeb815c51491a1989532006b96fc3ae49c6c60e3c097f875b4ae915ae.json +0 -1
  363. xinference/web/ui/node_modules/.cache/babel-loader/f72f011744c4649fabddca6f7a9327861ac0a315a89b1a2e62a39774e7863845.json +0 -1
  364. xinference/web/ui/node_modules/.cache/babel-loader/feabb04b4aa507102da0a64398a40818e878fd1df9b75dda8461b3e1e7ff3f11.json +0 -1
  365. /xinference/thirdparty/{cosyvoice/bin → f5_tts}/__init__.py +0 -0
  366. /xinference/thirdparty/{cosyvoice/flow → melo}/__init__.py +0 -0
  367. /xinference/thirdparty/{cosyvoice/hifigan → melo/text/english_utils}/__init__.py +0 -0
  368. /xinference/thirdparty/{cosyvoice/llm → melo/text/es_phonemizer}/__init__.py +0 -0
  369. /xinference/thirdparty/{fish_speech/fish_speech/configs → melo/text/fr_phonemizer}/__init__.py +0 -0
  370. /xinference/web/ui/build/static/js/{main.2f269bb3.js.LICENSE.txt → main.b0936c54.js.LICENSE.txt} +0 -0
  371. {xinference-0.16.3.dist-info → xinference-1.2.1.dist-info}/LICENSE +0 -0
  372. {xinference-0.16.3.dist-info → xinference-1.2.1.dist-info}/entry_points.txt +0 -0
  373. {xinference-0.16.3.dist-info → xinference-1.2.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,366 @@
1
+ from __future__ import annotations
2
+
3
+ import gc
4
+ import os
5
+
6
+ import torch
7
+ import torchaudio
8
+ import wandb
9
+ from accelerate import Accelerator
10
+ from accelerate.utils import DistributedDataParallelKwargs
11
+ from ema_pytorch import EMA
12
+ from torch.optim import AdamW
13
+ from torch.optim.lr_scheduler import LinearLR, SequentialLR
14
+ from torch.utils.data import DataLoader, Dataset, SequentialSampler
15
+ from tqdm import tqdm
16
+
17
+ from f5_tts.model import CFM
18
+ from f5_tts.model.dataset import DynamicBatchSampler, collate_fn
19
+ from f5_tts.model.utils import default, exists
20
+
21
+ # trainer
22
+
23
+
24
class Trainer:
    """Accelerate-based training loop for a CFM text-to-speech model.

    The model and optimizer are wrapped with ``accelerate``; an EMA copy of the model is
    kept on the main process only. ``train`` builds a linear-warmup / linear-decay LR
    schedule, resumes from the newest checkpoint under ``checkpoint_path`` and saves
    numbered and ``model_last.pt`` checkpoints periodically.
    """

    def __init__(
        self,
        model: CFM,
        epochs,
        learning_rate,
        num_warmup_updates=20000,
        save_per_updates=1000,
        checkpoint_path=None,
        batch_size=32,
        batch_size_type: str = "sample",
        max_samples=32,
        grad_accumulation_steps=1,
        max_grad_norm=1.0,
        noise_scheduler: str | None = None,
        duration_predictor: torch.nn.Module | None = None,
        logger: str | None = "wandb",  # "wandb" | "tensorboard" | None
        wandb_project="test_e2-tts",
        wandb_run_name="test_run",
        wandb_resume_id: str | None = None,
        log_samples: bool = False,
        last_per_steps=None,
        accelerate_kwargs: dict | None = None,
        ema_kwargs: dict | None = None,
        bnb_optimizer: bool = False,
        mel_spec_type: str = "vocos",  # "vocos" | "bigvgan"
        is_local_vocoder: bool = False,  # use local path vocoder
        local_vocoder_path: str = "",  # local vocoder path
    ):
        """Set up accelerator, logging, EMA and optimizer.

        Args:
            model: the CFM model to train.
            epochs: number of passes over the training set.
            learning_rate: peak learning rate reached after warmup.
            num_warmup_updates: warmup length (scaled by the number of processes in ``train``).
            save_per_updates: a numbered checkpoint is saved every
                ``save_per_updates * grad_accumulation_steps`` batches.
            checkpoint_path: checkpoint directory; defaults to ``"ckpts/test_e2-tts"``.
            batch_size_type: ``"sample"`` (fixed batch size) or ``"frame"`` (dynamic batching).
            max_samples: cap on samples per batch for ``"frame"`` batching.
            logger: ``"wandb"``, ``"tensorboard"`` or ``None``; wandb falls back to ``None``
                when no wandb API key is configured.
            last_per_steps: batches between ``model_last.pt`` saves; defaults to
                ``save_per_updates * grad_accumulation_steps``.
            accelerate_kwargs: extra keyword arguments for ``Accelerator`` (default: none).
            ema_kwargs: extra keyword arguments for ``EMA`` (default: none).
            bnb_optimizer: use bitsandbytes 8-bit AdamW instead of ``torch.optim.AdamW``.
            mel_spec_type / is_local_vocoder / local_vocoder_path: vocoder used when
                ``log_samples`` is enabled.
        """
        # Fresh dicts per instance instead of shared mutable default arguments.
        accelerate_kwargs = {} if accelerate_kwargs is None else accelerate_kwargs
        ema_kwargs = {} if ema_kwargs is None else ema_kwargs

        ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)

        if logger == "wandb" and not wandb.api.api_key:
            logger = None
        print(f"Using logger: {logger}")
        self.log_samples = log_samples

        self.accelerator = Accelerator(
            log_with=logger if logger == "wandb" else None,
            kwargs_handlers=[ddp_kwargs],
            gradient_accumulation_steps=grad_accumulation_steps,
            **accelerate_kwargs,
        )

        self.logger = logger
        if self.logger == "wandb":
            if exists(wandb_resume_id):
                init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name, "id": wandb_resume_id}}
            else:
                init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}

            self.accelerator.init_trackers(
                project_name=wandb_project,
                init_kwargs=init_kwargs,
                config={
                    "epochs": epochs,
                    "learning_rate": learning_rate,
                    "num_warmup_updates": num_warmup_updates,
                    "batch_size": batch_size,
                    "batch_size_type": batch_size_type,
                    "max_samples": max_samples,
                    "grad_accumulation_steps": grad_accumulation_steps,
                    "max_grad_norm": max_grad_norm,
                    "gpus": self.accelerator.num_processes,
                    "noise_scheduler": noise_scheduler,
                },
            )

        elif self.logger == "tensorboard":
            from torch.utils.tensorboard import SummaryWriter

            self.writer = SummaryWriter(log_dir=f"runs/{wandb_run_name}")

        self.model = model

        # The EMA copy only exists on the main process; every use below is guarded by is_main.
        if self.is_main:
            self.ema_model = EMA(model, include_online_model=False, **ema_kwargs)
            self.ema_model.to(self.accelerator.device)

        self.epochs = epochs
        self.num_warmup_updates = num_warmup_updates
        self.save_per_updates = save_per_updates
        self.last_per_steps = default(last_per_steps, save_per_updates * grad_accumulation_steps)
        self.checkpoint_path = default(checkpoint_path, "ckpts/test_e2-tts")

        self.batch_size = batch_size
        self.batch_size_type = batch_size_type
        self.max_samples = max_samples
        self.grad_accumulation_steps = grad_accumulation_steps
        self.max_grad_norm = max_grad_norm

        # mel vocoder config
        self.vocoder_name = mel_spec_type
        self.is_local_vocoder = is_local_vocoder
        self.local_vocoder_path = local_vocoder_path

        self.noise_scheduler = noise_scheduler

        self.duration_predictor = duration_predictor

        if bnb_optimizer:
            import bitsandbytes as bnb

            self.optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=learning_rate)
        else:
            self.optimizer = AdamW(model.parameters(), lr=learning_rate)
        self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)

    @property
    def is_main(self):
        """True on the global main process (the only process holding ``self.ema_model``)."""
        return self.accelerator.is_main_process

    @staticmethod
    def _latest_checkpoint_name(filenames):
        """Return the checkpoint file name to resume from, or ``None`` if there is none.

        ``model_last.pt`` wins when present. Otherwise the ``.pt`` file whose
        concatenated digits form the largest integer is chosen (``model_12000.pt`` ->
        12000). Files without any digit rank below every numbered checkpoint instead of
        raising ``ValueError`` from ``int("")``.
        """
        checkpoints = [f for f in filenames if f.endswith(".pt")]
        if not checkpoints:
            return None
        if "model_last.pt" in checkpoints:
            return "model_last.pt"

        def _step_of(name):
            digits = "".join(filter(str.isdigit, name))
            return int(digits) if digits else -1

        return sorted(checkpoints, key=_step_of)[-1]

    def save_checkpoint(self, step, last=False):
        """Save model, optimizer, EMA and scheduler state from the main process.

        Writes ``model_last.pt`` when ``last`` is true, otherwise ``model_{step}.pt``.
        Requires ``self.scheduler``, which is created by ``train``.
        """
        self.accelerator.wait_for_everyone()
        if self.is_main:
            checkpoint = dict(
                model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(),
                optimizer_state_dict=self.accelerator.unwrap_model(self.optimizer).state_dict(),
                ema_model_state_dict=self.ema_model.state_dict(),
                scheduler_state_dict=self.scheduler.state_dict(),
                step=step,
            )
            os.makedirs(self.checkpoint_path, exist_ok=True)
            if last:
                self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
                print(f"Saved last checkpoint at step {step}")
            else:
                self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{step}.pt")

    def load_checkpoint(self):
        """Restore state from the newest checkpoint in ``checkpoint_path``.

        Returns:
            The step to resume from. This is the saved ``step`` for full training
            checkpoints. It is 0 when the checkpoint has no ``step`` entry (only its EMA
            weights are copied into the model) or when no ``.pt`` file exists.
        """
        if (
            not exists(self.checkpoint_path)
            or not os.path.exists(self.checkpoint_path)
            or not any(filename.endswith(".pt") for filename in os.listdir(self.checkpoint_path))
        ):
            return 0

        self.accelerator.wait_for_everyone()
        latest_checkpoint = self._latest_checkpoint_name(os.listdir(self.checkpoint_path))
        # Loaded onto CPU first; accelerator.load_state would be the more idiomatic route.
        checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu")

        # patch for backward compatibility, 305e3ea
        for key in ["ema_model.mel_spec.mel_stft.mel_scale.fb", "ema_model.mel_spec.mel_stft.spectrogram.window"]:
            if key in checkpoint["ema_model_state_dict"]:
                del checkpoint["ema_model_state_dict"][key]

        if self.is_main:
            self.ema_model.load_state_dict(checkpoint["ema_model_state_dict"])

        if "step" in checkpoint:
            # patch for backward compatibility, 305e3ea
            for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]:
                if key in checkpoint["model_state_dict"]:
                    del checkpoint["model_state_dict"][key]

            self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"])
            self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint["optimizer_state_dict"])
            if self.scheduler:
                self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
            step = checkpoint["step"]
        else:
            # No training step recorded: rebuild online-model weights from the EMA state by
            # stripping the "ema_model." prefix and the EMA bookkeeping entries.
            checkpoint["model_state_dict"] = {
                k.replace("ema_model.", ""): v
                for k, v in checkpoint["ema_model_state_dict"].items()
                if k not in ["initted", "step"]
            }
            self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"])
            step = 0

        del checkpoint
        gc.collect()
        return step

    def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int | None = None):
        """Train on ``train_dataset`` for ``self.epochs`` epochs, resuming if a checkpoint exists.

        Args:
            train_dataset: dataset whose items are batched by ``collate_fn``.
            num_workers: DataLoader worker processes; 0 loads data in the main process.
            resumable_with_seed: when given, seeds the shuffling generator / dynamic batch
                sampler so a resumed run can skip batches already seen in the interrupted epoch.

        Raises:
            ValueError: if ``batch_size_type`` is neither ``"sample"`` nor ``"frame"``.
        """
        if self.log_samples:
            from f5_tts.infer.utils_infer import cfg_strength, load_vocoder, nfe_step, sway_sampling_coef

            vocoder = load_vocoder(
                vocoder_name=self.vocoder_name, is_local=self.is_local_vocoder, local_path=self.local_vocoder_path
            )
            target_sample_rate = self.accelerator.unwrap_model(self.model).mel_spec.target_sample_rate
            log_samples_path = f"{self.checkpoint_path}/samples"
            os.makedirs(log_samples_path, exist_ok=True)

        if exists(resumable_with_seed):
            generator = torch.Generator()
            generator.manual_seed(resumable_with_seed)
        else:
            generator = None

        # DataLoader rejects persistent_workers=True when num_workers == 0.
        persistent_workers = num_workers > 0

        if self.batch_size_type == "sample":
            train_dataloader = DataLoader(
                train_dataset,
                collate_fn=collate_fn,
                num_workers=num_workers,
                pin_memory=True,
                persistent_workers=persistent_workers,
                batch_size=self.batch_size,
                shuffle=True,
                generator=generator,
            )
        elif self.batch_size_type == "frame":
            self.accelerator.even_batches = False
            sampler = SequentialSampler(train_dataset)
            batch_sampler = DynamicBatchSampler(
                sampler, self.batch_size, max_samples=self.max_samples, random_seed=resumable_with_seed, drop_last=False
            )
            train_dataloader = DataLoader(
                train_dataset,
                collate_fn=collate_fn,
                num_workers=num_workers,
                pin_memory=True,
                persistent_workers=persistent_workers,
                batch_sampler=batch_sampler,
            )
        else:
            raise ValueError(f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}")

        # accelerator.prepare() dispatches batches to devices;
        # which means the length of dataloader calculated before, should consider the number of devices
        warmup_steps = (
            self.num_warmup_updates * self.accelerator.num_processes
        )  # consider a fixed warmup steps while using accelerate multi-gpu ddp
        # otherwise by default with split_batches=False, warmup steps change with num_processes
        total_steps = len(train_dataloader) * self.epochs / self.grad_accumulation_steps
        decay_steps = total_steps - warmup_steps
        warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
        decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
        self.scheduler = SequentialLR(
            self.optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[warmup_steps]
        )
        train_dataloader, self.scheduler = self.accelerator.prepare(
            train_dataloader, self.scheduler
        )  # actual steps = 1 gpu steps / gpus
        start_step = self.load_checkpoint()
        global_step = start_step

        if exists(resumable_with_seed):
            orig_epoch_step = len(train_dataloader)
            skipped_epoch = int(start_step // orig_epoch_step)
            skipped_batch = start_step % orig_epoch_step
            skipped_dataloader = self.accelerator.skip_first_batches(train_dataloader, num_batches=skipped_batch)
        else:
            skipped_epoch = 0

        for epoch in range(skipped_epoch, self.epochs):
            self.model.train()
            if exists(resumable_with_seed) and epoch == skipped_epoch:
                progress_bar = tqdm(
                    skipped_dataloader,
                    desc=f"Epoch {epoch+1}/{self.epochs}",
                    unit="step",
                    disable=not self.accelerator.is_local_main_process,
                    initial=skipped_batch,
                    total=orig_epoch_step,
                )
            else:
                progress_bar = tqdm(
                    train_dataloader,
                    desc=f"Epoch {epoch+1}/{self.epochs}",
                    unit="step",
                    disable=not self.accelerator.is_local_main_process,
                )

            for batch in progress_bar:
                with self.accelerator.accumulate(self.model):
                    text_inputs = batch["text"]
                    mel_spec = batch["mel"].permute(0, 2, 1)
                    mel_lengths = batch["mel_lengths"]

                    # TODO. add duration predictor training
                    if self.duration_predictor is not None and self.accelerator.is_local_main_process:
                        dur_loss = self.duration_predictor(mel_spec, lens=batch.get("durations"))
                        self.accelerator.log({"duration loss": dur_loss.item()}, step=global_step)

                    loss, cond, pred = self.model(
                        mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler
                    )
                    self.accelerator.backward(loss)

                    if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
                        self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)

                    self.optimizer.step()
                    self.scheduler.step()
                    self.optimizer.zero_grad()

                if self.is_main:
                    self.ema_model.update()

                global_step += 1

                if self.accelerator.is_local_main_process:
                    self.accelerator.log({"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_step)
                    if self.logger == "tensorboard":
                        self.writer.add_scalar("loss", loss.item(), global_step)
                        self.writer.add_scalar("lr", self.scheduler.get_last_lr()[0], global_step)

                progress_bar.set_postfix(step=str(global_step), loss=loss.item())

                if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0:
                    self.save_checkpoint(global_step)

                    if self.log_samples and self.accelerator.is_local_main_process:
                        # Synthesize the first sample's text twice, conditioned on its own audio.
                        ref_audio_len = mel_lengths[0]
                        infer_text = [
                            text_inputs[0] + ([" "] if isinstance(text_inputs[0], list) else " ") + text_inputs[0]
                        ]
                        with torch.inference_mode():
                            generated, _ = self.accelerator.unwrap_model(self.model).sample(
                                cond=mel_spec[0][:ref_audio_len].unsqueeze(0),
                                text=infer_text,
                                duration=ref_audio_len * 2,
                                steps=nfe_step,
                                cfg_strength=cfg_strength,
                                sway_sampling_coef=sway_sampling_coef,
                            )
                            generated = generated.to(torch.float32)
                            gen_mel_spec = generated[:, ref_audio_len:, :].permute(0, 2, 1).to(self.accelerator.device)
                            ref_mel_spec = batch["mel"][0].unsqueeze(0)
                            if self.vocoder_name == "vocos":
                                gen_audio = vocoder.decode(gen_mel_spec).cpu()
                                ref_audio = vocoder.decode(ref_mel_spec).cpu()
                            elif self.vocoder_name == "bigvgan":
                                gen_audio = vocoder(gen_mel_spec).squeeze(0).cpu()
                                ref_audio = vocoder(ref_mel_spec).squeeze(0).cpu()

                        torchaudio.save(f"{log_samples_path}/step_{global_step}_gen.wav", gen_audio, target_sample_rate)
                        torchaudio.save(f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio, target_sample_rate)

                if global_step % self.last_per_steps == 0:
                    self.save_checkpoint(global_step, last=True)

        self.save_checkpoint(global_step, last=True)

        self.accelerator.end_training()
@@ -0,0 +1,185 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import random
5
+ from collections import defaultdict
6
+ from importlib.resources import files
7
+
8
+ import torch
9
+ from torch.nn.utils.rnn import pad_sequence
10
+
11
+ import jieba
12
+ from pypinyin import lazy_pinyin, Style
13
+
14
+
15
+ # seed everything
16
+
17
+
18
def seed_everything(seed=0):
    """Seed the Python and PyTorch (CPU and CUDA) RNGs and force deterministic cuDNN.

    PYTHONHASHSEED is exported too. The running interpreter's hash seed is fixed at
    startup, so that setting mainly affects processes spawned afterwards.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
26
+
27
+
28
+ # helpers
29
+
30
+
31
def exists(v):
    """Return True for any value except ``None``; falsy values such as 0 or "" still count."""
    return False if v is None else True
33
+
34
+
35
def default(v, d):
    """Return ``v``, or the fallback ``d`` when ``v`` is ``None``."""
    if not exists(v):
        return d
    return v
37
+
38
+
39
+ # tensor helpers
40
+
41
+
42
def lens_to_mask(t: int["b"], length: int | None = None) -> bool["b n"]:  # noqa: F722 F821
    """Turn per-sample lengths into a boolean mask that is True for valid positions.

    Args:
        t: per-sample lengths.
        length: mask width; defaults to the largest value in ``t``.

    Returns:
        Mask where entry ``[i, j]`` is ``j < t[i]``.
    """
    width = length if exists(length) else t.amax()
    positions = torch.arange(width, device=t.device)
    return positions.unsqueeze(0) < t.unsqueeze(1)
48
+
49
+
50
def mask_from_start_end_indices(seq_len: int["b"], start: int["b"], end: int["b"]):  # noqa: F722 F821
    """Return a boolean mask that is True on the half-open span ``[start[i], end[i])`` of each row.

    The mask width is ``seq_len.max()``. It is built on ``start``'s device.
    """
    positions = torch.arange(seq_len.max().item(), device=start.device).long()
    row = positions.unsqueeze(0)
    return (row >= start.unsqueeze(1)) & (row < end.unsqueeze(1))
56
+
57
+
58
def mask_from_frac_lengths(seq_len: int["b"], frac_lengths: float["b"]):  # noqa: F722 F821
    """Mask a randomly placed span covering ``frac_lengths`` of each sequence.

    The span length is ``floor(frac * seq_len)``. Its start is drawn uniformly so the span
    stays within ``seq_len``.
    """
    span = (frac_lengths * seq_len).long()
    latest_start = seq_len - span
    start = (latest_start * torch.rand_like(frac_lengths)).long().clamp(min=0)
    return mask_from_start_end_indices(seq_len, start, start + span)
67
+
68
+
69
def maybe_masked_mean(t: float["b n d"], mask: bool["b n"] = None) -> float["b d"]:  # noqa: F722
    """Average ``t`` over dim 1, optionally counting only positions where ``mask`` is True.

    Args:
        t: tensor indexed as (batch, n, d).
        mask: optional boolean (batch, n) tensor; masked-out positions contribute nothing.

    Returns:
        A (batch, d) tensor. Rows whose mask is entirely False produce zeros.
    """
    if mask is None:
        return t.mean(dim=1)

    t = torch.where(mask[:, :, None], t, torch.tensor(0.0, device=t.device))
    num = t.sum(dim=1)  # (b, d)
    den = mask.float().sum(dim=1)  # (b,)

    # den must be (b, 1) to divide each sample's features by that sample's count.
    # Dividing (b, d) by (b,) would broadcast den along the feature axis instead.
    return num / den.clamp(min=1.0)[:, None]
78
+
79
+
80
+ # simple utf-8 tokenizer, since paper went character based
81
def list_str_to_tensor(text: list[str], padding_value=-1) -> int["b nt"]:  # noqa: F722
    """Byte-level (ByT5-style) tokenization: encode each string as UTF-8 bytes and pad.

    Args:
        text: strings to encode.
        padding_value: fill value for positions past the end of shorter strings.

    Returns:
        A long tensor of shape (len(text), max_bytes).
    """
    # Explicit dtype: torch.tensor([]) for an empty string is float32, and pad_sequence
    # takes its output dtype from the first sequence, so a batch could become float.
    list_tensors = [torch.tensor([*bytes(t, "UTF-8")], dtype=torch.long) for t in text]
    return pad_sequence(list_tensors, padding_value=padding_value, batch_first=True)
85
+
86
+
87
+ # char tokenizer, based on custom dataset's extracted .txt file
88
def list_str_to_idx(
    text: list[str] | list[list[str]],
    vocab_char_map: dict[str, int],  # {char: idx}
    padding_value=-1,
) -> int["b nt"]:  # noqa: F722
    """Map each character (or pinyin token) to its vocabulary index and pad the batch.

    Tokens missing from ``vocab_char_map`` map to index 0.

    Returns:
        A long tensor of shape (len(text), max_tokens).
    """
    # Explicit dtype so an empty sequence does not produce a float32 tensor (and, via
    # pad_sequence, possibly a float batch).
    list_idx_tensors = [
        torch.tensor([vocab_char_map.get(c, 0) for c in t], dtype=torch.long) for t in text
    ]  # pinyin or char style
    return pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
96
+
97
+
98
+ # Get tokenizer
99
+
100
+
101
def _read_vocab(path):
    """Read a one-token-per-line vocab file into ``{token: line_index}``.

    Only a single trailing newline is removed. Whitespace tokens such as " " survive,
    and a last line without a newline keeps its final character. Later duplicates
    overwrite earlier indices.
    """
    vocab_char_map = {}
    with open(path, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            vocab_char_map[line[:-1] if line.endswith("\n") else line] = i
    return vocab_char_map


def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
    """
    tokenizer   - "pinyin" do g2p for only chinese characters, need .txt vocab_file
                - "char" for char-wise tokenizer, need .txt vocab_file
                - "byte" for utf-8 tokenizer
                - "custom" if you're directly passing in a path to the vocab.txt you want to use
    vocab_size  - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
                - if use "char", derived from unfiltered character & symbol counts of custom dataset
                - if use "byte", set to 256 (unicode byte range)

    Returns:
        ``(vocab_char_map, vocab_size)``; ``vocab_char_map`` is ``None`` for "byte".

    Raises:
        ValueError: for an unknown ``tokenizer`` value.
    """
    if tokenizer in ["pinyin", "char"]:
        tokenizer_path = os.path.join(files("f5_tts").joinpath("../../data"), f"{dataset_name}_{tokenizer}/vocab.txt")
        vocab_char_map = _read_vocab(tokenizer_path)
        vocab_size = len(vocab_char_map)
        assert vocab_char_map.get(" ") == 0, "make sure space is of idx 0 in vocab.txt, cuz 0 is used for unknown char"

    elif tokenizer == "byte":
        vocab_char_map = None
        vocab_size = 256

    elif tokenizer == "custom":
        # dataset_name is the path to the vocab file itself in this mode.
        vocab_char_map = _read_vocab(dataset_name)
        vocab_size = len(vocab_char_map)

    else:
        raise ValueError(f"tokenizer must be one of 'pinyin', 'char', 'byte', 'custom', but received {tokenizer}")

    return vocab_char_map, vocab_size
132
+
133
+
134
+ # convert char to pinyin
135
+
136
+
137
def convert_char_to_pinyin(text_list, polyphone=True):
    """Convert each text into a list of tokens: ASCII/Latin-1 characters plus pinyin syllables.

    Each text is normalized, then segmented with ``jieba.cut``. Per segment:
    - Pure single-byte segments (byte length == char length) are kept character by
      character, with a space inserted before multi-char segments unless the previous
      token is a space, colon or quote.
    - Segments whose UTF-8 length is exactly 3 bytes per character (treated here as pure
      Chinese) are converted with ``lazy_pinyin`` (TONE3, tone sandhi) when ``polyphone``
      is true, with a space token before each syllable that is not a listed punctuation mark.
    - Anything else is handled one character at a time: code points < 256 are kept,
      other characters become pinyin (preceded by a space) unless they are listed
      punctuation, which is kept as-is.

    Args:
        text_list: iterable of strings.
        polyphone: if false, 3-byte segments take the per-character path instead.

    Returns:
        A list with one token list per input text.
    """
    final_text_list = []
    god_knows_why_en_testset_contains_zh_quote = str.maketrans(
        {"“": '"', "”": '"', "‘": "'", "’": "'"}
    )  # in case librispeech (orig no-pc) test-clean
    custom_trans = str.maketrans({";": ","})  # add custom trans here, to address oov
    for text in text_list:
        char_list = []
        text = text.translate(god_knows_why_en_testset_contains_zh_quote)
        text = text.translate(custom_trans)
        for seg in jieba.cut(text):
            seg_byte_len = len(bytes(seg, "UTF-8"))
            if seg_byte_len == len(seg):  # if pure alphabets and symbols
                # Separate this segment from the previous one with a space, except after
                # a space, colon or quote.
                if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
                    char_list.append(" ")
                char_list.extend(seg)
            elif polyphone and seg_byte_len == 3 * len(seg):  # if pure chinese characters
                # Whole-segment conversion lets lazy_pinyin use context for polyphones.
                seg = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
                for c in seg:
                    # Substring test: punctuation marks get no leading space.
                    if c not in "。,、;:?!《》【】—…":
                        char_list.append(" ")
                    char_list.append(c)
            else:  # if mixed chinese characters, alphabets and symbols
                for c in seg:
                    if ord(c) < 256:
                        char_list.extend(c)
                    else:
                        if c not in "。,、;:?!《》【】—…":
                            char_list.append(" ")
                            char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True))
                        else:  # if is zh punc
                            char_list.append(c)
        final_text_list.append(char_list)

    return final_text_list
172
+
173
+
174
+ # filter func for dirty data with many repetitions
175
+
176
+
177
def repetition_found(text, length=2, tolerance=10):
    """Report whether any substring of ``length`` items occurs more than ``tolerance`` times.

    Every overlapping window of ``text`` is counted, so ``"aaa"`` with ``length=1`` counts
    ``"a"`` three times. Inputs shorter than ``length`` have no windows and return False.
    """
    window_counts = defaultdict(int)
    window_total = len(text) - length + 1
    for offset in range(window_total):
        window_counts[text[offset : offset + length]] += 1
    return any(seen > tolerance for seen in window_counts.values())
@@ -0,0 +1,33 @@
1
"""Estimate the epoch count needed with adaptive (frame-budget) batching.

From the dataset size in hours, mel hop length / sampling rate and the per-GPU frame
budget, print how many epochs reach ``wanted_max_updates`` optimizer updates and the
per-epoch length the progress bar should show.
"""

print("Adaptive batch size: using grouping batch sampler, frames_per_gpu fixed fed in")
print(" -> least padding, gather wavs with accumulated frames in a batch\n")

# Dataset size and mel-spectrogram framing.
total_hours = 95282
mel_hop_length = 256
mel_sampling_rate = 24000

# Desired number of optimizer updates.
wanted_max_updates = 1000000

# Training setup.
gpus = 8
frames_per_gpu = 38400  # 8 * 38400 = 307200
grad_accum = 1

# Frames consumed per optimizer update across all GPUs, and the audio hours they span.
mini_batch_frames = gpus * frames_per_gpu * grad_accum
seconds_per_hour = 3600
mini_batch_hours = mini_batch_frames * mel_hop_length / mel_sampling_rate / seconds_per_hour
updates_per_epoch = total_hours / mini_batch_hours
steps_per_epoch = grad_accum * updates_per_epoch

# Epochs required to hit the update target.
epochs = wanted_max_updates / updates_per_epoch
print(f"epochs should be set to: {epochs:.0f} ({epochs/grad_accum:.1f} x gd_acum {grad_accum})")
print(f"progress_bar should show approx. 0/{updates_per_epoch:.0f} updates")
print(f" or approx. 0/{steps_per_epoch:.0f} steps")

# Summary of the inputs.
print(f"total {total_hours:.0f} hours")
print(f"mini-batch of {mini_batch_frames:.0f} frames, {mini_batch_hours:.2f} hours per mini-batch")
@@ -0,0 +1,39 @@
1
+ import sys
2
+ import os
3
+
4
+ sys.path.append(os.getcwd())
5
+
6
+ from f5_tts.model import CFM, DiT
7
+
8
+ import torch
9
+ import thop
10
+
11
+
12
# Reference model configurations; uncomment one to profile it instead.
#
# ~155M parameters:
# transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4)
# transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4, text_dim = 512, conv_layers = 4)
# transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2)
# transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2, text_dim = 512, conv_layers = 4)
# transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2, text_dim = 512, conv_layers = 4, long_skip_connection = True)
# transformer = MMDiT(dim = 512, depth = 16, heads = 16, ff_mult = 2)
#
# ~335M parameters:
# FLOPs: 622.1 G, Params: 333.2 M
# transformer = UNetT(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
# FLOPs: 363.4 G, Params: 335.8 M
transformer = DiT(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)


model = CFM(transformer=transformer)

# Dummy inputs: `duration` seconds of audio at `target_sample_rate` with hop `hop_length`
# gives `frame_length` mel frames of `n_mel_channels` bins, plus `text_length` token ids.
target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
duration = 20
frame_length = int(duration * target_sample_rate / hop_length)
text_length = 150

dummy_mel = torch.randn(1, frame_length, n_mel_channels)
dummy_text = torch.zeros(1, text_length, dtype=torch.long)
flops, params = thop.profile(model, inputs=(dummy_mel, dummy_text))
print(f"FLOPs: {flops / 1e9} G")
print(f"Params: {params / 1e6} M")