xinference 1.0.1__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic; consult the registry's advisory page for more details.

Files changed (343)
  1. xinference/_compat.py +2 -0
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +77 -71
  4. xinference/core/chat_interface.py +6 -1
  5. xinference/core/model.py +79 -19
  6. xinference/core/supervisor.py +172 -10
  7. xinference/core/utils.py +12 -8
  8. xinference/core/worker.py +102 -4
  9. xinference/deploy/cmdline.py +3 -1
  10. xinference/deploy/test/test_cmdline.py +56 -0
  11. xinference/isolation.py +24 -0
  12. xinference/model/audio/core.py +16 -0
  13. xinference/model/audio/cosyvoice.py +39 -6
  14. xinference/model/audio/f5tts.py +200 -0
  15. xinference/model/audio/f5tts_mlx.py +260 -0
  16. xinference/model/audio/fish_speech.py +36 -111
  17. xinference/model/audio/melotts.py +110 -0
  18. xinference/model/audio/model_spec.json +99 -3
  19. xinference/model/audio/model_spec_modelscope.json +27 -0
  20. xinference/model/audio/utils.py +32 -0
  21. xinference/model/audio/whisper.py +35 -10
  22. xinference/model/embedding/core.py +203 -142
  23. xinference/model/embedding/model_spec.json +7 -0
  24. xinference/model/embedding/model_spec_modelscope.json +8 -0
  25. xinference/model/image/core.py +69 -1
  26. xinference/model/image/model_spec.json +145 -4
  27. xinference/model/image/model_spec_modelscope.json +150 -4
  28. xinference/model/image/stable_diffusion/core.py +45 -13
  29. xinference/model/llm/__init__.py +4 -2
  30. xinference/model/llm/llm_family.json +536 -53
  31. xinference/model/llm/llm_family.py +15 -36
  32. xinference/model/llm/llm_family_modelscope.json +454 -20
  33. xinference/model/llm/memory.py +1 -1
  34. xinference/model/llm/mlx/core.py +248 -52
  35. xinference/model/llm/sglang/core.py +1 -0
  36. xinference/model/llm/transformers/chatglm.py +9 -5
  37. xinference/model/llm/transformers/cogagent.py +272 -0
  38. xinference/model/llm/transformers/core.py +2 -0
  39. xinference/model/llm/transformers/qwen2_vl.py +12 -1
  40. xinference/model/llm/transformers/utils.py +16 -8
  41. xinference/model/llm/utils.py +36 -4
  42. xinference/model/llm/vllm/core.py +53 -10
  43. xinference/model/llm/vllm/xavier/__init__.py +13 -0
  44. xinference/model/llm/vllm/xavier/allocator.py +74 -0
  45. xinference/model/llm/vllm/xavier/block.py +111 -0
  46. xinference/model/llm/vllm/xavier/block_manager.py +71 -0
  47. xinference/model/llm/vllm/xavier/block_tracker.py +129 -0
  48. xinference/model/llm/vllm/xavier/collective.py +74 -0
  49. xinference/model/llm/vllm/xavier/collective_manager.py +147 -0
  50. xinference/model/llm/vllm/xavier/engine.py +247 -0
  51. xinference/model/llm/vllm/xavier/executor.py +134 -0
  52. xinference/model/llm/vllm/xavier/scheduler.py +438 -0
  53. xinference/model/llm/vllm/xavier/test/__init__.py +13 -0
  54. xinference/model/llm/vllm/xavier/test/test_xavier.py +147 -0
  55. xinference/model/llm/vllm/xavier/transfer.py +319 -0
  56. xinference/model/video/diffusers.py +14 -0
  57. xinference/model/video/model_spec.json +15 -0
  58. xinference/model/video/model_spec_modelscope.json +16 -0
  59. xinference/thirdparty/cosyvoice/bin/average_model.py +92 -0
  60. xinference/thirdparty/cosyvoice/bin/export_jit.py +12 -2
  61. xinference/thirdparty/cosyvoice/bin/export_onnx.py +112 -0
  62. xinference/thirdparty/cosyvoice/bin/export_trt.sh +9 -0
  63. xinference/thirdparty/cosyvoice/bin/inference.py +5 -7
  64. xinference/thirdparty/cosyvoice/bin/spk2info.pt +0 -0
  65. xinference/thirdparty/cosyvoice/bin/train.py +42 -8
  66. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +96 -25
  67. xinference/thirdparty/cosyvoice/cli/frontend.py +77 -30
  68. xinference/thirdparty/cosyvoice/cli/model.py +330 -80
  69. xinference/thirdparty/cosyvoice/dataset/dataset.py +6 -2
  70. xinference/thirdparty/cosyvoice/dataset/processor.py +76 -14
  71. xinference/thirdparty/cosyvoice/flow/decoder.py +92 -13
  72. xinference/thirdparty/cosyvoice/flow/flow.py +99 -9
  73. xinference/thirdparty/cosyvoice/flow/flow_matching.py +110 -13
  74. xinference/thirdparty/cosyvoice/flow/length_regulator.py +5 -4
  75. xinference/thirdparty/cosyvoice/hifigan/discriminator.py +140 -0
  76. xinference/thirdparty/cosyvoice/hifigan/generator.py +58 -42
  77. xinference/thirdparty/cosyvoice/hifigan/hifigan.py +67 -0
  78. xinference/thirdparty/cosyvoice/llm/llm.py +139 -6
  79. xinference/thirdparty/cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
  80. xinference/thirdparty/cosyvoice/tokenizer/tokenizer.py +279 -0
  81. xinference/thirdparty/cosyvoice/transformer/embedding.py +2 -2
  82. xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +7 -7
  83. xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +318 -0
  84. xinference/thirdparty/cosyvoice/utils/common.py +28 -1
  85. xinference/thirdparty/cosyvoice/utils/executor.py +69 -7
  86. xinference/thirdparty/cosyvoice/utils/file_utils.py +2 -12
  87. xinference/thirdparty/cosyvoice/utils/frontend_utils.py +9 -5
  88. xinference/thirdparty/cosyvoice/utils/losses.py +20 -0
  89. xinference/thirdparty/cosyvoice/utils/scheduler.py +1 -2
  90. xinference/thirdparty/cosyvoice/utils/train_utils.py +101 -45
  91. xinference/thirdparty/f5_tts/api.py +166 -0
  92. xinference/thirdparty/f5_tts/configs/E2TTS_Base_train.yaml +44 -0
  93. xinference/thirdparty/f5_tts/configs/E2TTS_Small_train.yaml +44 -0
  94. xinference/thirdparty/f5_tts/configs/F5TTS_Base_train.yaml +46 -0
  95. xinference/thirdparty/f5_tts/configs/F5TTS_Small_train.yaml +46 -0
  96. xinference/thirdparty/f5_tts/eval/README.md +49 -0
  97. xinference/thirdparty/f5_tts/eval/ecapa_tdnn.py +330 -0
  98. xinference/thirdparty/f5_tts/eval/eval_infer_batch.py +207 -0
  99. xinference/thirdparty/f5_tts/eval/eval_infer_batch.sh +13 -0
  100. xinference/thirdparty/f5_tts/eval/eval_librispeech_test_clean.py +84 -0
  101. xinference/thirdparty/f5_tts/eval/eval_seedtts_testset.py +84 -0
  102. xinference/thirdparty/f5_tts/eval/utils_eval.py +405 -0
  103. xinference/thirdparty/f5_tts/infer/README.md +191 -0
  104. xinference/thirdparty/f5_tts/infer/SHARED.md +74 -0
  105. xinference/thirdparty/f5_tts/infer/examples/basic/basic.toml +11 -0
  106. xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_en.wav +0 -0
  107. xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_zh.wav +0 -0
  108. xinference/thirdparty/f5_tts/infer/examples/multi/country.flac +0 -0
  109. xinference/thirdparty/f5_tts/infer/examples/multi/main.flac +0 -0
  110. xinference/thirdparty/f5_tts/infer/examples/multi/story.toml +19 -0
  111. xinference/thirdparty/f5_tts/infer/examples/multi/story.txt +1 -0
  112. xinference/thirdparty/f5_tts/infer/examples/multi/town.flac +0 -0
  113. xinference/thirdparty/f5_tts/infer/examples/vocab.txt +2545 -0
  114. xinference/thirdparty/f5_tts/infer/infer_cli.py +226 -0
  115. xinference/thirdparty/f5_tts/infer/infer_gradio.py +851 -0
  116. xinference/thirdparty/f5_tts/infer/speech_edit.py +193 -0
  117. xinference/thirdparty/f5_tts/infer/utils_infer.py +538 -0
  118. xinference/thirdparty/f5_tts/model/__init__.py +10 -0
  119. xinference/thirdparty/f5_tts/model/backbones/README.md +20 -0
  120. xinference/thirdparty/f5_tts/model/backbones/dit.py +163 -0
  121. xinference/thirdparty/f5_tts/model/backbones/mmdit.py +146 -0
  122. xinference/thirdparty/f5_tts/model/backbones/unett.py +219 -0
  123. xinference/thirdparty/f5_tts/model/cfm.py +285 -0
  124. xinference/thirdparty/f5_tts/model/dataset.py +319 -0
  125. xinference/thirdparty/f5_tts/model/modules.py +658 -0
  126. xinference/thirdparty/f5_tts/model/trainer.py +366 -0
  127. xinference/thirdparty/f5_tts/model/utils.py +185 -0
  128. xinference/thirdparty/f5_tts/scripts/count_max_epoch.py +33 -0
  129. xinference/thirdparty/f5_tts/scripts/count_params_gflops.py +39 -0
  130. xinference/thirdparty/f5_tts/socket_server.py +159 -0
  131. xinference/thirdparty/f5_tts/train/README.md +77 -0
  132. xinference/thirdparty/f5_tts/train/datasets/prepare_csv_wavs.py +139 -0
  133. xinference/thirdparty/f5_tts/train/datasets/prepare_emilia.py +230 -0
  134. xinference/thirdparty/f5_tts/train/datasets/prepare_libritts.py +92 -0
  135. xinference/thirdparty/f5_tts/train/datasets/prepare_ljspeech.py +65 -0
  136. xinference/thirdparty/f5_tts/train/datasets/prepare_wenetspeech4tts.py +125 -0
  137. xinference/thirdparty/f5_tts/train/finetune_cli.py +174 -0
  138. xinference/thirdparty/f5_tts/train/finetune_gradio.py +1846 -0
  139. xinference/thirdparty/f5_tts/train/train.py +75 -0
  140. xinference/thirdparty/fish_speech/fish_speech/conversation.py +94 -83
  141. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +63 -20
  142. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +1 -26
  143. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +1 -1
  144. xinference/thirdparty/fish_speech/fish_speech/tokenizer.py +152 -0
  145. xinference/thirdparty/fish_speech/fish_speech/train.py +2 -2
  146. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1 -1
  147. xinference/thirdparty/fish_speech/tools/{post_api.py → api_client.py} +7 -13
  148. xinference/thirdparty/fish_speech/tools/api_server.py +98 -0
  149. xinference/thirdparty/fish_speech/tools/download_models.py +5 -5
  150. xinference/thirdparty/fish_speech/tools/fish_e2e.py +2 -2
  151. xinference/thirdparty/fish_speech/tools/inference_engine/__init__.py +192 -0
  152. xinference/thirdparty/fish_speech/tools/inference_engine/reference_loader.py +125 -0
  153. xinference/thirdparty/fish_speech/tools/inference_engine/utils.py +39 -0
  154. xinference/thirdparty/fish_speech/tools/inference_engine/vq_manager.py +57 -0
  155. xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +2 -2
  156. xinference/thirdparty/fish_speech/tools/llama/generate.py +117 -89
  157. xinference/thirdparty/fish_speech/tools/run_webui.py +104 -0
  158. xinference/thirdparty/fish_speech/tools/schema.py +11 -28
  159. xinference/thirdparty/fish_speech/tools/server/agent/__init__.py +57 -0
  160. xinference/thirdparty/fish_speech/tools/server/agent/generate.py +119 -0
  161. xinference/thirdparty/fish_speech/tools/server/agent/generation_utils.py +122 -0
  162. xinference/thirdparty/fish_speech/tools/server/agent/pre_generation_utils.py +72 -0
  163. xinference/thirdparty/fish_speech/tools/server/api_utils.py +75 -0
  164. xinference/thirdparty/fish_speech/tools/server/exception_handler.py +27 -0
  165. xinference/thirdparty/fish_speech/tools/server/inference.py +45 -0
  166. xinference/thirdparty/fish_speech/tools/server/model_manager.py +122 -0
  167. xinference/thirdparty/fish_speech/tools/server/model_utils.py +129 -0
  168. xinference/thirdparty/fish_speech/tools/server/views.py +246 -0
  169. xinference/thirdparty/fish_speech/tools/webui/__init__.py +173 -0
  170. xinference/thirdparty/fish_speech/tools/webui/inference.py +91 -0
  171. xinference/thirdparty/fish_speech/tools/webui/variables.py +14 -0
  172. xinference/thirdparty/matcha/utils/utils.py +2 -2
  173. xinference/thirdparty/melo/api.py +135 -0
  174. xinference/thirdparty/melo/app.py +61 -0
  175. xinference/thirdparty/melo/attentions.py +459 -0
  176. xinference/thirdparty/melo/commons.py +160 -0
  177. xinference/thirdparty/melo/configs/config.json +94 -0
  178. xinference/thirdparty/melo/data/example/metadata.list +20 -0
  179. xinference/thirdparty/melo/data_utils.py +413 -0
  180. xinference/thirdparty/melo/download_utils.py +67 -0
  181. xinference/thirdparty/melo/infer.py +25 -0
  182. xinference/thirdparty/melo/init_downloads.py +14 -0
  183. xinference/thirdparty/melo/losses.py +58 -0
  184. xinference/thirdparty/melo/main.py +36 -0
  185. xinference/thirdparty/melo/mel_processing.py +174 -0
  186. xinference/thirdparty/melo/models.py +1030 -0
  187. xinference/thirdparty/melo/modules.py +598 -0
  188. xinference/thirdparty/melo/monotonic_align/__init__.py +16 -0
  189. xinference/thirdparty/melo/monotonic_align/core.py +46 -0
  190. xinference/thirdparty/melo/preprocess_text.py +135 -0
  191. xinference/thirdparty/melo/split_utils.py +174 -0
  192. xinference/thirdparty/melo/text/__init__.py +35 -0
  193. xinference/thirdparty/melo/text/chinese.py +199 -0
  194. xinference/thirdparty/melo/text/chinese_bert.py +107 -0
  195. xinference/thirdparty/melo/text/chinese_mix.py +253 -0
  196. xinference/thirdparty/melo/text/cleaner.py +36 -0
  197. xinference/thirdparty/melo/text/cleaner_multiling.py +110 -0
  198. xinference/thirdparty/melo/text/cmudict.rep +129530 -0
  199. xinference/thirdparty/melo/text/cmudict_cache.pickle +0 -0
  200. xinference/thirdparty/melo/text/english.py +284 -0
  201. xinference/thirdparty/melo/text/english_bert.py +39 -0
  202. xinference/thirdparty/melo/text/english_utils/abbreviations.py +35 -0
  203. xinference/thirdparty/melo/text/english_utils/number_norm.py +97 -0
  204. xinference/thirdparty/melo/text/english_utils/time_norm.py +47 -0
  205. xinference/thirdparty/melo/text/es_phonemizer/base.py +140 -0
  206. xinference/thirdparty/melo/text/es_phonemizer/cleaner.py +109 -0
  207. xinference/thirdparty/melo/text/es_phonemizer/es_symbols.json +79 -0
  208. xinference/thirdparty/melo/text/es_phonemizer/es_symbols.txt +1 -0
  209. xinference/thirdparty/melo/text/es_phonemizer/es_symbols_v2.json +83 -0
  210. xinference/thirdparty/melo/text/es_phonemizer/es_to_ipa.py +12 -0
  211. xinference/thirdparty/melo/text/es_phonemizer/example_ipa.txt +400 -0
  212. xinference/thirdparty/melo/text/es_phonemizer/gruut_wrapper.py +253 -0
  213. xinference/thirdparty/melo/text/es_phonemizer/punctuation.py +174 -0
  214. xinference/thirdparty/melo/text/es_phonemizer/spanish_symbols.txt +1 -0
  215. xinference/thirdparty/melo/text/es_phonemizer/test.ipynb +124 -0
  216. xinference/thirdparty/melo/text/fr_phonemizer/base.py +140 -0
  217. xinference/thirdparty/melo/text/fr_phonemizer/cleaner.py +122 -0
  218. xinference/thirdparty/melo/text/fr_phonemizer/en_symbols.json +78 -0
  219. xinference/thirdparty/melo/text/fr_phonemizer/example_ipa.txt +1 -0
  220. xinference/thirdparty/melo/text/fr_phonemizer/fr_symbols.json +89 -0
  221. xinference/thirdparty/melo/text/fr_phonemizer/fr_to_ipa.py +30 -0
  222. xinference/thirdparty/melo/text/fr_phonemizer/french_abbreviations.py +48 -0
  223. xinference/thirdparty/melo/text/fr_phonemizer/french_symbols.txt +1 -0
  224. xinference/thirdparty/melo/text/fr_phonemizer/gruut_wrapper.py +258 -0
  225. xinference/thirdparty/melo/text/fr_phonemizer/punctuation.py +172 -0
  226. xinference/thirdparty/melo/text/french.py +94 -0
  227. xinference/thirdparty/melo/text/french_bert.py +39 -0
  228. xinference/thirdparty/melo/text/japanese.py +647 -0
  229. xinference/thirdparty/melo/text/japanese_bert.py +49 -0
  230. xinference/thirdparty/melo/text/ko_dictionary.py +44 -0
  231. xinference/thirdparty/melo/text/korean.py +192 -0
  232. xinference/thirdparty/melo/text/opencpop-strict.txt +429 -0
  233. xinference/thirdparty/melo/text/spanish.py +122 -0
  234. xinference/thirdparty/melo/text/spanish_bert.py +39 -0
  235. xinference/thirdparty/melo/text/symbols.py +290 -0
  236. xinference/thirdparty/melo/text/tone_sandhi.py +769 -0
  237. xinference/thirdparty/melo/train.py +635 -0
  238. xinference/thirdparty/melo/train.sh +19 -0
  239. xinference/thirdparty/melo/transforms.py +209 -0
  240. xinference/thirdparty/melo/utils.py +424 -0
  241. xinference/types.py +15 -0
  242. xinference/web/ui/build/asset-manifest.json +6 -6
  243. xinference/web/ui/build/index.html +1 -1
  244. xinference/web/ui/build/static/css/main.51a587ff.css +2 -0
  245. xinference/web/ui/build/static/css/main.51a587ff.css.map +1 -0
  246. xinference/web/ui/build/static/js/main.b0936c54.js +3 -0
  247. xinference/web/ui/build/static/js/main.b0936c54.js.map +1 -0
  248. xinference/web/ui/node_modules/.cache/babel-loader/03c4052f1b91f6ba0c5389bdcf49c43319b4076c08e4b8585dab312538ae290a.json +1 -0
  249. xinference/web/ui/node_modules/.cache/babel-loader/1786b83003b8e9605a0f5f855a185d4d16e38fc893dfb326a2a9cca206b4240a.json +1 -0
  250. xinference/web/ui/node_modules/.cache/babel-loader/17cbc181dd674b9150b80c73ed6a82656de0082d857f6e5f66d9716129ac0b38.json +1 -0
  251. xinference/web/ui/node_modules/.cache/babel-loader/185ceb8872d562e032b47e79df6a45670e06345b8ed70aad1a131e0476783c5c.json +1 -0
  252. xinference/web/ui/node_modules/.cache/babel-loader/26b8c9f34b0bed789b3a833767672e39302d1e0c09b4276f4d58d1df7b6bd93b.json +1 -0
  253. xinference/web/ui/node_modules/.cache/babel-loader/2b484da66c724d0d56a40849c109327408796a668b1381511b6e9e03baa48658.json +1 -0
  254. xinference/web/ui/node_modules/.cache/babel-loader/2cbbbce9b84df73330d4c42b82436ed881b3847628f2fbc346aa62e2859fd88c.json +1 -0
  255. xinference/web/ui/node_modules/.cache/babel-loader/2ec9b14431ed33ce6901bf9f27007be4e6e472709c99d6e22b50ce528e4b78ee.json +1 -0
  256. xinference/web/ui/node_modules/.cache/babel-loader/3b966db018f96be4a055d6ca205f0990d4d0b370e2980c17d8bca2c9a021819c.json +1 -0
  257. xinference/web/ui/node_modules/.cache/babel-loader/3eefb411b24c2b3ce053570ef50daccf154022f0e168be5ed0fec21394baf9f4.json +1 -0
  258. xinference/web/ui/node_modules/.cache/babel-loader/522b229e3cac219123f0d69673f5570e191c2d2a505dc65b312d336eae2279c0.json +1 -0
  259. xinference/web/ui/node_modules/.cache/babel-loader/52e45f17ba300580ea3fcc9f9228ccba194bb092b76f25e9255af311f8b05aab.json +1 -0
  260. xinference/web/ui/node_modules/.cache/babel-loader/5a0bc4631f936459afc1a3b1d3ec2420118b1f00e11f60ccac3e08088f3f27a8.json +1 -0
  261. xinference/web/ui/node_modules/.cache/babel-loader/611fa2c6c53b66039991d06dfb0473b5ab37fc63b4564e0f6e1718523768a045.json +1 -0
  262. xinference/web/ui/node_modules/.cache/babel-loader/6329bc76c406fe5eb305412383fbde5950f847bb5e43261f73f37622c365acb4.json +1 -0
  263. xinference/web/ui/node_modules/.cache/babel-loader/63c8e07687ea53a4f8a910ee5e42e0eb26cd1acbfbe820f3e3248a786ee51401.json +1 -0
  264. xinference/web/ui/node_modules/.cache/babel-loader/69b2d5001684174ec9da57e07914eed3eac4960018bceb6cbfa801d861301d7c.json +1 -0
  265. xinference/web/ui/node_modules/.cache/babel-loader/710c1acda69e561e30a933b98c6a56d50197868b15c21e2aad55ab6d46649eb6.json +1 -0
  266. xinference/web/ui/node_modules/.cache/babel-loader/720deca1fce5a1dc5056048fa8258fd138a82ea855f350b6613f104a73fb761f.json +1 -0
  267. xinference/web/ui/node_modules/.cache/babel-loader/76a23b92d26a499c57e61eea2b895fbc9771bd0849a72e66f8e633192017978b.json +1 -0
  268. xinference/web/ui/node_modules/.cache/babel-loader/858063f23b34dfe600254eb5afd85518b0002ec4b30b7386616c45600826e3b2.json +1 -0
  269. xinference/web/ui/node_modules/.cache/babel-loader/920b82c1c89124cf217109eeedbfcd3aae3b917be50c9dfb6bbb4ce26bdfd2e7.json +1 -0
  270. xinference/web/ui/node_modules/.cache/babel-loader/94d8b7aeb0076f2ce07db598cea0e87b13bc8d5614eb530b8d6e696c2daf6f88.json +1 -0
  271. xinference/web/ui/node_modules/.cache/babel-loader/9e917fe7022d01b2ccbe5cc0ce73d70bb72bee584ff293bad71bdff6695dee28.json +1 -0
  272. xinference/web/ui/node_modules/.cache/babel-loader/9f28fdb8399f1d0474f0aca86f1658dc94f5bf0c90f6146352de150692de8862.json +1 -0
  273. xinference/web/ui/node_modules/.cache/babel-loader/a0dfafa06b2bb7cba8cad41c482503f61944f759f4318139362602ef5cc47ccb.json +1 -0
  274. xinference/web/ui/node_modules/.cache/babel-loader/a3ff866acddf34917a7ee399e0e571a4dfd8ba66d5057db885f243e16a6eb17d.json +1 -0
  275. xinference/web/ui/node_modules/.cache/babel-loader/afb8084f539534cd594755ea2205ecd5bd1f62dddcfdf75a2eace59a28131278.json +1 -0
  276. xinference/web/ui/node_modules/.cache/babel-loader/b57b1438b77294c1f3f6cfce12ac487d8106c6f016975ba0aec94d98997e2e1e.json +1 -0
  277. xinference/web/ui/node_modules/.cache/babel-loader/b9917b0bf8e4d55ccbac1c334aa04d6ff3c5b6ed9e5d38b9ea2c687fa7d3f5a9.json +1 -0
  278. xinference/web/ui/node_modules/.cache/babel-loader/bbcc94b0149963d1d6f267ee1f4f03d3925b758392ce2f516c3fe8af0e0169fc.json +1 -0
  279. xinference/web/ui/node_modules/.cache/babel-loader/bdee44abeadc4abc17d41c52eb49c6e19a4b1a267b6e16876ce91bdeeebfc52d.json +1 -0
  280. xinference/web/ui/node_modules/.cache/babel-loader/beb112b70f4a56db95920a9e20efb6c97c37b68450716730217a9ee1a9ae92be.json +1 -0
  281. xinference/web/ui/node_modules/.cache/babel-loader/c88db97be0cdf440193b3995996e83510a04cb00048135485fc0e26d197e80b5.json +1 -0
  282. xinference/web/ui/node_modules/.cache/babel-loader/d49e5314d34310a62d01a03067ce1bec5da00abce84c5196aa9c6842fa79a430.json +1 -0
  283. xinference/web/ui/node_modules/.cache/babel-loader/d7664d18c4ddbad9c3a6a31b91f7c00fb0dde804608674a9860ee50f33e54708.json +1 -0
  284. xinference/web/ui/node_modules/.cache/babel-loader/d9072c318b819b7c90a0f7e9cc0b6413b4dbeb8e9859898e53d75ea882fcde99.json +1 -0
  285. xinference/web/ui/node_modules/.cache/babel-loader/db16a983bc08a05f0439cc61ca0840e49e1d8400eef678909f16c032a418a3d6.json +1 -0
  286. xinference/web/ui/node_modules/.cache/babel-loader/dc249829767b8abcbc3677e0b07b6d3ecbfdfe6d08cfe23a665eb33373a9aa9d.json +1 -0
  287. xinference/web/ui/node_modules/.cache/babel-loader/e242c583c2dbc2784f0fcf513523975f7d5df447e106c1c17e49e8578a6fc3ed.json +1 -0
  288. xinference/web/ui/node_modules/.cache/babel-loader/eac5f1296513e69e4b96f750ddccd4d0264e2bae4e4c449144e83274a48698d9.json +1 -0
  289. xinference/web/ui/node_modules/.cache/babel-loader/ed57202cb79649bb716400436590245547df241988fc7c8e1d85d132299542d2.json +1 -0
  290. xinference/web/ui/node_modules/.cache/babel-loader/f125bf72e773a14cdaebd0c343e80adb909d12e317ee5c00cd4a57442fbe2c62.json +1 -0
  291. xinference/web/ui/node_modules/.cache/babel-loader/f91af913d7f91c410719ab13136aaed3aaf0f8dda06652f25c42cb5231587398.json +1 -0
  292. xinference/web/ui/node_modules/.package-lock.json +67 -3
  293. xinference/web/ui/node_modules/@babel/runtime/package.json +592 -538
  294. xinference/web/ui/node_modules/html-parse-stringify/package.json +50 -0
  295. xinference/web/ui/node_modules/i18next/dist/esm/package.json +1 -0
  296. xinference/web/ui/node_modules/i18next/package.json +129 -0
  297. xinference/web/ui/node_modules/react-i18next/.eslintrc.json +74 -0
  298. xinference/web/ui/node_modules/react-i18next/dist/es/package.json +1 -0
  299. xinference/web/ui/node_modules/react-i18next/package.json +162 -0
  300. xinference/web/ui/node_modules/void-elements/package.json +34 -0
  301. xinference/web/ui/package-lock.json +69 -3
  302. xinference/web/ui/package.json +2 -0
  303. xinference/web/ui/src/locales/en.json +186 -0
  304. xinference/web/ui/src/locales/zh.json +186 -0
  305. {xinference-1.0.1.dist-info → xinference-1.2.1.dist-info}/METADATA +68 -32
  306. {xinference-1.0.1.dist-info → xinference-1.2.1.dist-info}/RECORD +316 -122
  307. xinference/thirdparty/cosyvoice/bin/export_trt.py +0 -8
  308. xinference/thirdparty/fish_speech/tools/api.py +0 -943
  309. xinference/thirdparty/fish_speech/tools/msgpack_api.py +0 -95
  310. xinference/thirdparty/fish_speech/tools/webui.py +0 -548
  311. xinference/web/ui/build/static/css/main.5061c4c3.css +0 -2
  312. xinference/web/ui/build/static/css/main.5061c4c3.css.map +0 -1
  313. xinference/web/ui/build/static/js/main.2f269bb3.js +0 -3
  314. xinference/web/ui/build/static/js/main.2f269bb3.js.map +0 -1
  315. xinference/web/ui/node_modules/.cache/babel-loader/07ce9e632e6aff24d7aa3ad8e48224433bbfeb0d633fca723453f1fcae0c9f1c.json +0 -1
  316. xinference/web/ui/node_modules/.cache/babel-loader/1130403f9e46f5738a23b45ac59b57de8f360c908c713e2c0670c2cce9bd367a.json +0 -1
  317. xinference/web/ui/node_modules/.cache/babel-loader/131091b25d26b17cdca187d7542a21475c211138d900cf667682260e76ef9463.json +0 -1
  318. xinference/web/ui/node_modules/.cache/babel-loader/1f269fb2a368363c1cb2237825f1dba093b6bdd8c44cc05954fd19ec2c1fff03.json +0 -1
  319. xinference/web/ui/node_modules/.cache/babel-loader/331312668fa8bd3d7401818f4a25fa98135d7f61371cd6bfff78b18cf4fbdd92.json +0 -1
  320. xinference/web/ui/node_modules/.cache/babel-loader/40f17338fc75ae095de7d2b4d8eae0d5ca0193a7e2bcece4ee745b22a7a2f4b7.json +0 -1
  321. xinference/web/ui/node_modules/.cache/babel-loader/4de9a6942c5f1749d6cbfdd54279699975f16016b182848bc253886f52ec2ec3.json +0 -1
  322. xinference/web/ui/node_modules/.cache/babel-loader/822586ed1077201b64b954f12f25e3f9b45678c1acbabe53d8af3ca82ca71f33.json +0 -1
  323. xinference/web/ui/node_modules/.cache/babel-loader/8d33354bd2100c8602afc3341f131a88cc36aaeecd5a4b365ed038514708e350.json +0 -1
  324. xinference/web/ui/node_modules/.cache/babel-loader/9375a35b05d56989b2755bf72161fa707c92f28569d33765a75f91a568fda6e9.json +0 -1
  325. xinference/web/ui/node_modules/.cache/babel-loader/a158a9ffa0c9b169aee53dd4a0c44501a596755b4e4f6ede7746d65a72e2a71f.json +0 -1
  326. xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +0 -1
  327. xinference/web/ui/node_modules/.cache/babel-loader/c7bf40bab396765f67d0fed627ed3665890608b2d0edaa3e8cb7cfc96310db45.json +0 -1
  328. xinference/web/ui/node_modules/.cache/babel-loader/d6c643278a0b28320e6f33a60f5fb64c053997cbdc39a60e53ccc574688ade9e.json +0 -1
  329. xinference/web/ui/node_modules/.cache/babel-loader/e42b72d4cc1ea412ebecbb8d040dc6c6bfee462c33903c2f1f3facb602ad742e.json +0 -1
  330. xinference/web/ui/node_modules/.cache/babel-loader/e64b7e8cedcf43d4c95deba60ec1341855c887705805bb62431693118b870c69.json +0 -1
  331. xinference/web/ui/node_modules/.cache/babel-loader/f5039ddbeb815c51491a1989532006b96fc3ae49c6c60e3c097f875b4ae915ae.json +0 -1
  332. xinference/web/ui/node_modules/.cache/babel-loader/f72f011744c4649fabddca6f7a9327861ac0a315a89b1a2e62a39774e7863845.json +0 -1
  333. xinference/web/ui/node_modules/.cache/babel-loader/feabb04b4aa507102da0a64398a40818e878fd1df9b75dda8461b3e1e7ff3f11.json +0 -1
  334. /xinference/thirdparty/{cosyvoice/bin → f5_tts}/__init__.py +0 -0
  335. /xinference/thirdparty/{cosyvoice/flow → melo}/__init__.py +0 -0
  336. /xinference/thirdparty/{cosyvoice/hifigan → melo/text/english_utils}/__init__.py +0 -0
  337. /xinference/thirdparty/{cosyvoice/llm → melo/text/es_phonemizer}/__init__.py +0 -0
  338. /xinference/thirdparty/{fish_speech/tools → melo/text/fr_phonemizer}/__init__.py +0 -0
  339. /xinference/web/ui/build/static/js/{main.2f269bb3.js.LICENSE.txt → main.b0936c54.js.LICENSE.txt} +0 -0
  340. {xinference-1.0.1.dist-info → xinference-1.2.1.dist-info}/LICENSE +0 -0
  341. {xinference-1.0.1.dist-info → xinference-1.2.1.dist-info}/WHEEL +0 -0
  342. {xinference-1.0.1.dist-info → xinference-1.2.1.dist-info}/entry_points.txt +0 -0
  343. {xinference-1.0.1.dist-info → xinference-1.2.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,635 @@
1
+ # flake8: noqa: E402
2
+
3
+ import os
4
+ import torch
5
+ from torch.nn import functional as F
6
+ from torch.utils.data import DataLoader
7
+ from torch.utils.tensorboard import SummaryWriter
8
+ import torch.distributed as dist
9
+ from torch.nn.parallel import DistributedDataParallel as DDP
10
+ from torch.cuda.amp import autocast, GradScaler
11
+ from tqdm import tqdm
12
+ import logging
13
+
14
+ logging.getLogger("numba").setLevel(logging.WARNING)
15
+ import commons
16
+ import utils
17
+ from data_utils import (
18
+ TextAudioSpeakerLoader,
19
+ TextAudioSpeakerCollate,
20
+ DistributedBucketSampler,
21
+ )
22
+ from models import (
23
+ SynthesizerTrn,
24
+ MultiPeriodDiscriminator,
25
+ DurationDiscriminator,
26
+ )
27
+ from losses import generator_loss, discriminator_loss, feature_loss, kl_loss
28
+ from mel_processing import mel_spectrogram_torch, spec_to_mel_torch
29
+ from text.symbols import symbols
30
+ from melo.download_utils import load_pretrain_model
31
+
32
+ torch.backends.cuda.matmul.allow_tf32 = True
33
+ torch.backends.cudnn.allow_tf32 = (
34
+ True # If encontered training problem,please try to disable TF32.
35
+ )
36
+ torch.set_float32_matmul_precision("medium")
37
+
38
+
39
+ torch.backends.cudnn.benchmark = True
40
+ torch.backends.cuda.sdp_kernel("flash")
41
+ torch.backends.cuda.enable_flash_sdp(True)
42
+ # torch.backends.cuda.enable_mem_efficient_sdp(
43
+ # True
44
+ # ) # Not available if torch version is lower than 2.0
45
+ torch.backends.cuda.enable_math_sdp(True)
46
+ global_step = 0
47
+
48
+
49
+ def run():
50
+ hps = utils.get_hparams()
51
+ local_rank = int(os.environ["LOCAL_RANK"])
52
+ dist.init_process_group(
53
+ backend="gloo",
54
+ init_method="env://", # Due to some training problem,we proposed to use gloo instead of nccl.
55
+ rank=local_rank,
56
+ ) # Use torchrun instead of mp.spawn
57
+ rank = dist.get_rank()
58
+ n_gpus = dist.get_world_size()
59
+
60
+ torch.manual_seed(hps.train.seed)
61
+ torch.cuda.set_device(rank)
62
+ global global_step
63
+ if rank == 0:
64
+ logger = utils.get_logger(hps.model_dir)
65
+ logger.info(hps)
66
+ utils.check_git_hash(hps.model_dir)
67
+ writer = SummaryWriter(log_dir=hps.model_dir)
68
+ writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval"))
69
+ train_dataset = TextAudioSpeakerLoader(hps.data.training_files, hps.data)
70
+ train_sampler = DistributedBucketSampler(
71
+ train_dataset,
72
+ hps.train.batch_size,
73
+ [32, 300, 400, 500, 600, 700, 800, 900, 1000],
74
+ num_replicas=n_gpus,
75
+ rank=rank,
76
+ shuffle=True,
77
+ )
78
+ collate_fn = TextAudioSpeakerCollate()
79
+ train_loader = DataLoader(
80
+ train_dataset,
81
+ num_workers=16,
82
+ shuffle=False,
83
+ pin_memory=True,
84
+ collate_fn=collate_fn,
85
+ batch_sampler=train_sampler,
86
+ persistent_workers=True,
87
+ prefetch_factor=4,
88
+ ) # DataLoader config could be adjusted.
89
+ if rank == 0:
90
+ eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data)
91
+ eval_loader = DataLoader(
92
+ eval_dataset,
93
+ num_workers=0,
94
+ shuffle=False,
95
+ batch_size=1,
96
+ pin_memory=True,
97
+ drop_last=False,
98
+ collate_fn=collate_fn,
99
+ )
100
+ if (
101
+ "use_noise_scaled_mas" in hps.model.keys()
102
+ and hps.model.use_noise_scaled_mas is True
103
+ ):
104
+ print("Using noise scaled MAS for VITS2")
105
+ mas_noise_scale_initial = 0.01
106
+ noise_scale_delta = 2e-6
107
+ else:
108
+ print("Using normal MAS for VITS1")
109
+ mas_noise_scale_initial = 0.0
110
+ noise_scale_delta = 0.0
111
+ if (
112
+ "use_duration_discriminator" in hps.model.keys()
113
+ and hps.model.use_duration_discriminator is True
114
+ ):
115
+ print("Using duration discriminator for VITS2")
116
+ net_dur_disc = DurationDiscriminator(
117
+ hps.model.hidden_channels,
118
+ hps.model.hidden_channels,
119
+ 3,
120
+ 0.1,
121
+ gin_channels=hps.model.gin_channels if hps.data.n_speakers != 0 else 0,
122
+ ).cuda(rank)
123
+ if (
124
+ "use_spk_conditioned_encoder" in hps.model.keys()
125
+ and hps.model.use_spk_conditioned_encoder is True
126
+ ):
127
+ if hps.data.n_speakers == 0:
128
+ raise ValueError(
129
+ "n_speakers must be > 0 when using spk conditioned encoder to train multi-speaker model"
130
+ )
131
+ else:
132
+ print("Using normal encoder for VITS1")
133
+
134
+ net_g = SynthesizerTrn(
135
+ len(symbols),
136
+ hps.data.filter_length // 2 + 1,
137
+ hps.train.segment_size // hps.data.hop_length,
138
+ n_speakers=hps.data.n_speakers,
139
+ mas_noise_scale_initial=mas_noise_scale_initial,
140
+ noise_scale_delta=noise_scale_delta,
141
+ **hps.model,
142
+ ).cuda(rank)
143
+
144
+ net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank)
145
+ optim_g = torch.optim.AdamW(
146
+ filter(lambda p: p.requires_grad, net_g.parameters()),
147
+ hps.train.learning_rate,
148
+ betas=hps.train.betas,
149
+ eps=hps.train.eps,
150
+ )
151
+ optim_d = torch.optim.AdamW(
152
+ net_d.parameters(),
153
+ hps.train.learning_rate,
154
+ betas=hps.train.betas,
155
+ eps=hps.train.eps,
156
+ )
157
+ if net_dur_disc is not None:
158
+ optim_dur_disc = torch.optim.AdamW(
159
+ net_dur_disc.parameters(),
160
+ hps.train.learning_rate,
161
+ betas=hps.train.betas,
162
+ eps=hps.train.eps,
163
+ )
164
+ else:
165
+ optim_dur_disc = None
166
+ net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
167
+ net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
168
+
169
+ pretrain_G, pretrain_D, pretrain_dur = load_pretrain_model()
170
+ hps.pretrain_G = hps.pretrain_G or pretrain_G
171
+ hps.pretrain_D = hps.pretrain_D or pretrain_D
172
+ hps.pretrain_dur = hps.pretrain_dur or pretrain_dur
173
+
174
+ if hps.pretrain_G:
175
+ utils.load_checkpoint(
176
+ hps.pretrain_G,
177
+ net_g,
178
+ None,
179
+ skip_optimizer=True
180
+ )
181
+ if hps.pretrain_D:
182
+ utils.load_checkpoint(
183
+ hps.pretrain_D,
184
+ net_d,
185
+ None,
186
+ skip_optimizer=True
187
+ )
188
+
189
+
190
+ if net_dur_disc is not None:
191
+ net_dur_disc = DDP(net_dur_disc, device_ids=[rank], find_unused_parameters=True)
192
+ if hps.pretrain_dur:
193
+ utils.load_checkpoint(
194
+ hps.pretrain_dur,
195
+ net_dur_disc,
196
+ None,
197
+ skip_optimizer=True
198
+ )
199
+
200
+ try:
201
+ if net_dur_disc is not None:
202
+ _, _, dur_resume_lr, epoch_str = utils.load_checkpoint(
203
+ utils.latest_checkpoint_path(hps.model_dir, "DUR_*.pth"),
204
+ net_dur_disc,
205
+ optim_dur_disc,
206
+ skip_optimizer=hps.train.skip_optimizer
207
+ if "skip_optimizer" in hps.train
208
+ else True,
209
+ )
210
+ _, optim_g, g_resume_lr, epoch_str = utils.load_checkpoint(
211
+ utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"),
212
+ net_g,
213
+ optim_g,
214
+ skip_optimizer=hps.train.skip_optimizer
215
+ if "skip_optimizer" in hps.train
216
+ else True,
217
+ )
218
+ _, optim_d, d_resume_lr, epoch_str = utils.load_checkpoint(
219
+ utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"),
220
+ net_d,
221
+ optim_d,
222
+ skip_optimizer=hps.train.skip_optimizer
223
+ if "skip_optimizer" in hps.train
224
+ else True,
225
+ )
226
+ if not optim_g.param_groups[0].get("initial_lr"):
227
+ optim_g.param_groups[0]["initial_lr"] = g_resume_lr
228
+ if not optim_d.param_groups[0].get("initial_lr"):
229
+ optim_d.param_groups[0]["initial_lr"] = d_resume_lr
230
+ if not optim_dur_disc.param_groups[0].get("initial_lr"):
231
+ optim_dur_disc.param_groups[0]["initial_lr"] = dur_resume_lr
232
+
233
+ epoch_str = max(epoch_str, 1)
234
+ global_step = (epoch_str - 1) * len(train_loader)
235
+ except Exception as e:
236
+ print(e)
237
+ epoch_str = 1
238
+ global_step = 0
239
+
240
+ scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
241
+ optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
242
+ )
243
+ scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
244
+ optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
245
+ )
246
+ if net_dur_disc is not None:
247
+ scheduler_dur_disc = torch.optim.lr_scheduler.ExponentialLR(
248
+ optim_dur_disc, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
249
+ )
250
+ else:
251
+ scheduler_dur_disc = None
252
+ scaler = GradScaler(enabled=hps.train.fp16_run)
253
+
254
+ for epoch in range(epoch_str, hps.train.epochs + 1):
255
+ try:
256
+ if rank == 0:
257
+ train_and_evaluate(
258
+ rank,
259
+ epoch,
260
+ hps,
261
+ [net_g, net_d, net_dur_disc],
262
+ [optim_g, optim_d, optim_dur_disc],
263
+ [scheduler_g, scheduler_d, scheduler_dur_disc],
264
+ scaler,
265
+ [train_loader, eval_loader],
266
+ logger,
267
+ [writer, writer_eval],
268
+ )
269
+ else:
270
+ train_and_evaluate(
271
+ rank,
272
+ epoch,
273
+ hps,
274
+ [net_g, net_d, net_dur_disc],
275
+ [optim_g, optim_d, optim_dur_disc],
276
+ [scheduler_g, scheduler_d, scheduler_dur_disc],
277
+ scaler,
278
+ [train_loader, None],
279
+ None,
280
+ None,
281
+ )
282
+ except Exception as e:
283
+ print(e)
284
+ torch.cuda.empty_cache()
285
+ scheduler_g.step()
286
+ scheduler_d.step()
287
+ if net_dur_disc is not None:
288
+ scheduler_dur_disc.step()
289
+
290
+
291
def train_and_evaluate(
    rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers
):
    """Run one GAN training epoch; on rank 0 also log, evaluate and checkpoint.

    Per batch: forward the generator, step the (optional) duration
    discriminator, step the waveform discriminator, then step the generator —
    all under AMP autocast with a shared GradScaler.  The statement order
    (backward -> unscale_ -> clip -> step, discriminators before generator)
    is load-bearing; do not reorder.

    Args:
        rank: local CUDA device index / DDP rank (only rank 0 logs/saves).
        hps: hyperparameter namespace (train.*, data.* fields are read).
        nets: [net_g, net_d, net_dur_disc] — net_dur_disc may be None.
        optims: [optim_g, optim_d, optim_dur_disc] — last may be None.
        schedulers: matching LR schedulers (unused here; stepped by caller).
        scaler: shared torch.cuda.amp.GradScaler.
        loaders: [train_loader, eval_loader] — eval_loader is None off rank 0.
        logger, writers: rank-0 logger and [writer, writer_eval]; None elsewhere.

    Side effects: advances the module-level ``global_step`` once per batch.
    """
    net_g, net_d, net_dur_disc = nets
    optim_g, optim_d, optim_dur_disc = optims
    scheduler_g, scheduler_d, scheduler_dur_disc = schedulers
    train_loader, eval_loader = loaders
    # writers is None on non-zero ranks; writer/writer_eval are only read
    # inside rank == 0 branches below.
    if writers is not None:
        writer, writer_eval = writers

    # Reseed the bucket sampler so each epoch sees a different shuffle.
    train_loader.batch_sampler.set_epoch(epoch)
    global global_step

    net_g.train()
    net_d.train()
    if net_dur_disc is not None:
        net_dur_disc.train()
    # Each batch is an 11-tuple: text, spec, wave (with lengths), plus
    # speaker ids, tone/language ids and two BERT feature tensors.
    for batch_idx, (
        x,
        x_lengths,
        spec,
        spec_lengths,
        y,
        y_lengths,
        speakers,
        tone,
        language,
        bert,
        ja_bert,
    ) in enumerate(tqdm(train_loader)):
        if net_g.module.use_noise_scaled_mas:
            # Linearly anneal the MAS noise scale toward 0 over training,
            # clamping at 0 once the schedule crosses it.
            current_mas_noise_scale = (
                net_g.module.mas_noise_scale_initial
                - net_g.module.noise_scale_delta * global_step
            )
            net_g.module.current_mas_noise_scale = max(current_mas_noise_scale, 0.0)
        x, x_lengths = x.cuda(rank, non_blocking=True), x_lengths.cuda(
            rank, non_blocking=True
        )
        spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(
            rank, non_blocking=True
        )
        y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(
            rank, non_blocking=True
        )
        speakers = speakers.cuda(rank, non_blocking=True)
        tone = tone.cuda(rank, non_blocking=True)
        language = language.cuda(rank, non_blocking=True)
        bert = bert.cuda(rank, non_blocking=True)
        ja_bert = ja_bert.cuda(rank, non_blocking=True)

        with autocast(enabled=hps.train.fp16_run):
            (
                y_hat,
                l_length,
                attn,
                ids_slice,
                x_mask,
                z_mask,
                (z, z_p, m_p, logs_p, m_q, logs_q),
                (hidden_x, logw, logw_),
            ) = net_g(
                x,
                x_lengths,
                spec,
                spec_lengths,
                speakers,
                tone,
                language,
                bert,
                ja_bert,
            )
            mel = spec_to_mel_torch(
                spec,
                hps.data.filter_length,
                hps.data.n_mel_channels,
                hps.data.sampling_rate,
                hps.data.mel_fmin,
                hps.data.mel_fmax,
            )
            # Ground-truth mel slice matching the generator's random segment.
            y_mel = commons.slice_segments(
                mel, ids_slice, hps.train.segment_size // hps.data.hop_length
            )
            y_hat_mel = mel_spectrogram_torch(
                y_hat.squeeze(1),
                hps.data.filter_length,
                hps.data.n_mel_channels,
                hps.data.sampling_rate,
                hps.data.hop_length,
                hps.data.win_length,
                hps.data.mel_fmin,
                hps.data.mel_fmax,
            )

            y = commons.slice_segments(
                y, ids_slice * hps.data.hop_length, hps.train.segment_size
            )  # slice

            # Discriminator
            # .detach() stops discriminator gradients flowing into the generator.
            y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
            with autocast(enabled=False):
                # Losses are computed in fp32 for numerical stability.
                loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
                    y_d_hat_r, y_d_hat_g
                )
                loss_disc_all = loss_disc
            if net_dur_disc is not None:
                y_dur_hat_r, y_dur_hat_g = net_dur_disc(
                    hidden_x.detach(), x_mask.detach(), logw.detach(), logw_.detach()
                )
                with autocast(enabled=False):
                    # TODO: I think need to mean using the mask, but for now, just mean all
                    (
                        loss_dur_disc,
                        losses_dur_disc_r,
                        losses_dur_disc_g,
                    ) = discriminator_loss(y_dur_hat_r, y_dur_hat_g)
                    loss_dur_disc_all = loss_dur_disc
                # Duration-discriminator optimizer step (scaled AMP backward).
                optim_dur_disc.zero_grad()
                scaler.scale(loss_dur_disc_all).backward()
                scaler.unscale_(optim_dur_disc)
                commons.clip_grad_value_(net_dur_disc.parameters(), None)
                scaler.step(optim_dur_disc)

        # Waveform-discriminator optimizer step.
        optim_d.zero_grad()
        scaler.scale(loss_disc_all).backward()
        scaler.unscale_(optim_d)
        grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
        scaler.step(optim_d)

        with autocast(enabled=hps.train.fp16_run):
            # Generator
            y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat)
            if net_dur_disc is not None:
                y_dur_hat_r, y_dur_hat_g = net_dur_disc(hidden_x, x_mask, logw, logw_)
            with autocast(enabled=False):
                loss_dur = torch.sum(l_length.float())
                loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
                loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl

                loss_fm = feature_loss(fmap_r, fmap_g)
                loss_gen, losses_gen = generator_loss(y_d_hat_g)
                # Total generator loss: adversarial + feature-matching +
                # mel reconstruction + duration + KL terms.
                loss_gen_all = loss_gen + loss_fm + loss_mel + loss_dur + loss_kl
                if net_dur_disc is not None:
                    loss_dur_gen, losses_dur_gen = generator_loss(y_dur_hat_g)
                    loss_gen_all += loss_dur_gen
        # Generator optimizer step; scaler.update() once per iteration, after
        # every optimizer this scaler drives has stepped.
        optim_g.zero_grad()
        scaler.scale(loss_gen_all).backward()
        scaler.unscale_(optim_g)
        grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
        scaler.step(optim_g)
        scaler.update()

        if rank == 0:
            if global_step % hps.train.log_interval == 0:
                lr = optim_g.param_groups[0]["lr"]
                losses = [loss_disc, loss_gen, loss_fm, loss_mel, loss_dur, loss_kl]
                logger.info(
                    "Train Epoch: {} [{:.0f}%]".format(
                        epoch, 100.0 * batch_idx / len(train_loader)
                    )
                )
                logger.info([x.item() for x in losses] + [global_step, lr])

                scalar_dict = {
                    "loss/g/total": loss_gen_all,
                    "loss/d/total": loss_disc_all,
                    "learning_rate": lr,
                    "grad_norm_d": grad_norm_d,
                    "grad_norm_g": grad_norm_g,
                }
                scalar_dict.update(
                    {
                        "loss/g/fm": loss_fm,
                        "loss/g/mel": loss_mel,
                        "loss/g/dur": loss_dur,
                        "loss/g/kl": loss_kl,
                    }
                )
                # Per-sub-discriminator breakdowns of the adversarial losses.
                scalar_dict.update(
                    {"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)}
                )
                scalar_dict.update(
                    {"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)}
                )
                scalar_dict.update(
                    {"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)}
                )

                image_dict = {
                    "slice/mel_org": utils.plot_spectrogram_to_numpy(
                        y_mel[0].data.cpu().numpy()
                    ),
                    "slice/mel_gen": utils.plot_spectrogram_to_numpy(
                        y_hat_mel[0].data.cpu().numpy()
                    ),
                    "all/mel": utils.plot_spectrogram_to_numpy(
                        mel[0].data.cpu().numpy()
                    ),
                    "all/attn": utils.plot_alignment_to_numpy(
                        attn[0, 0].data.cpu().numpy()
                    ),
                }
                utils.summarize(
                    writer=writer,
                    global_step=global_step,
                    images=image_dict,
                    scalars=scalar_dict,
                )

            if global_step % hps.train.eval_interval == 0:
                # Periodic validation + checkpoints named by global step.
                evaluate(hps, net_g, eval_loader, writer_eval)
                utils.save_checkpoint(
                    net_g,
                    optim_g,
                    hps.train.learning_rate,
                    epoch,
                    os.path.join(hps.model_dir, "G_{}.pth".format(global_step)),
                )
                utils.save_checkpoint(
                    net_d,
                    optim_d,
                    hps.train.learning_rate,
                    epoch,
                    os.path.join(hps.model_dir, "D_{}.pth".format(global_step)),
                )
                if net_dur_disc is not None:
                    utils.save_checkpoint(
                        net_dur_disc,
                        optim_dur_disc,
                        hps.train.learning_rate,
                        epoch,
                        os.path.join(hps.model_dir, "DUR_{}.pth".format(global_step)),
                    )
                # Prune old checkpoints, keeping the newest `keep_ckpts`
                # (<= 0 disables pruning).
                keep_ckpts = getattr(hps.train, "keep_ckpts", 5)
                if keep_ckpts > 0:
                    utils.clean_checkpoints(
                        path_to_models=hps.model_dir,
                        n_ckpts_to_keep=keep_ckpts,
                        sort_by_time=True,
                    )

        global_step += 1

    if rank == 0:
        logger.info("====> Epoch: {}".format(epoch))
    torch.cuda.empty_cache()
537
+
538
+
539
def evaluate(hps, generator, eval_loader, writer_eval):
    """Synthesize the validation set and log mel images / audio to TensorBoard.

    Switches *generator* to eval mode, runs inference on every validation
    batch twice — once through the stochastic duration predictor
    (``sdp_ratio=1.0``) and once deterministically (``sdp_ratio=0.0``) — and
    records generated and ground-truth spectrogram images plus audio clips.
    Restores train mode and frees cached CUDA memory before returning.

    Args:
        hps: hyperparameter namespace; ``hps.data.*`` mel/STFT settings are read.
        generator: DDP-wrapped synthesizer; inference goes through
            ``generator.module.infer``.
        eval_loader: validation DataLoader yielding the same 11-tuple batches
            as training.
        writer_eval: TensorBoard SummaryWriter for evaluation output.
    """
    generator.eval()
    image_dict = {}
    audio_dict = {}
    print("Evaluating ...")
    with torch.no_grad():
        for batch_idx, (
            x,
            x_lengths,
            spec,
            spec_lengths,
            y,
            y_lengths,
            speakers,
            tone,
            language,
            bert,
            ja_bert,
        ) in enumerate(eval_loader):
            x, x_lengths = x.cuda(), x_lengths.cuda()
            spec, spec_lengths = spec.cuda(), spec_lengths.cuda()
            y, y_lengths = y.cuda(), y_lengths.cuda()
            speakers = speakers.cuda()
            bert = bert.cuda()
            ja_bert = ja_bert.cuda()
            tone = tone.cuda()
            language = language.cuda()
            # Compare both duration paths: stochastic (SDP) and deterministic.
            for use_sdp in [True, False]:
                y_hat, attn, mask, *_ = generator.module.infer(
                    x,
                    x_lengths,
                    speakers,
                    tone,
                    language,
                    bert,
                    ja_bert,
                    y=spec,
                    max_len=1000,
                    sdp_ratio=0.0 if not use_sdp else 1.0,
                )
                # Valid frames per item, converted to sample counts.
                y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length

                mel = spec_to_mel_torch(
                    spec,
                    hps.data.filter_length,
                    hps.data.n_mel_channels,
                    hps.data.sampling_rate,
                    hps.data.mel_fmin,
                    hps.data.mel_fmax,
                )
                y_hat_mel = mel_spectrogram_torch(
                    y_hat.squeeze(1).float(),
                    hps.data.filter_length,
                    hps.data.n_mel_channels,
                    hps.data.sampling_rate,
                    hps.data.hop_length,
                    hps.data.win_length,
                    hps.data.mel_fmin,
                    hps.data.mel_fmax,
                )
                image_dict.update(
                    {
                        f"gen/mel_{batch_idx}": utils.plot_spectrogram_to_numpy(
                            y_hat_mel[0].cpu().numpy()
                        )
                    }
                )
                audio_dict.update(
                    {
                        f"gen/audio_{batch_idx}_{use_sdp}": y_hat[
                            0, :, : y_hat_lengths[0]
                        ]
                    }
                )
                image_dict.update(
                    {
                        f"gt/mel_{batch_idx}": utils.plot_spectrogram_to_numpy(
                            mel[0].cpu().numpy()
                        )
                    }
                )
                audio_dict.update({f"gt/audio_{batch_idx}": y[0, :, : y_lengths[0]]})

    utils.summarize(
        writer=writer_eval,
        global_step=global_step,
        images=image_dict,
        audios=audio_dict,
        audio_sampling_rate=hps.data.sampling_rate,
    )
    generator.train()
    # Fixed typo in the progress message ("Evauate" -> "Evaluate").
    print('Evaluate done')
    torch.cuda.empty_cache()
632
+
633
+
634
# Script entry point: launch training when executed directly (not on import).
if __name__ == "__main__":
    run()
@@ -0,0 +1,19 @@
1
#!/usr/bin/env bash
# Launch distributed VITS training and automatically restart it on crash.
#
# Usage: <script> <path/to/config.json> <num_gpus>
#
# The model name passed to train.py is the name of the config file's parent
# directory. All variable expansions are quoted so paths with spaces or glob
# characters do not word-split.

CONFIG=$1
GPUS=$2
MODEL_NAME=$(basename "$(dirname "$CONFIG")")

PORT=10902

while : # auto-resume: the code sometimes crash due to bug of gloo on some gpus
do
    torchrun --nproc_per_node="$GPUS" \
        --master_port="$PORT" \
        train.py --c "$CONFIG" --model "$MODEL_NAME"

    # Best-effort cleanup: kill any leftover python workers still holding the
    # config. NOTE(review): ps|grep matching is coarse and kill -9 skips
    # graceful shutdown — kept intentionally as a crash-recovery hammer.
    for PID in $(ps -aux | grep "$CONFIG" | grep python | awk '{print $2}')
    do
        echo "$PID"
        kill -9 "$PID"
    done
    sleep 30
done