xinference 0.16.3__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic.

Files changed (373)
  1. xinference/_compat.py +24 -2
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +219 -77
  4. xinference/client/restful/restful_client.py +47 -2
  5. xinference/constants.py +1 -0
  6. xinference/core/chat_interface.py +6 -1
  7. xinference/core/model.py +124 -34
  8. xinference/core/supervisor.py +180 -12
  9. xinference/core/utils.py +73 -4
  10. xinference/core/worker.py +102 -4
  11. xinference/deploy/cmdline.py +3 -1
  12. xinference/deploy/test/test_cmdline.py +56 -0
  13. xinference/isolation.py +24 -0
  14. xinference/model/audio/__init__.py +12 -0
  15. xinference/model/audio/core.py +37 -4
  16. xinference/model/audio/cosyvoice.py +39 -6
  17. xinference/model/audio/f5tts.py +200 -0
  18. xinference/model/audio/f5tts_mlx.py +260 -0
  19. xinference/model/audio/fish_speech.py +70 -110
  20. xinference/model/audio/melotts.py +110 -0
  21. xinference/model/audio/model_spec.json +179 -3
  22. xinference/model/audio/model_spec_modelscope.json +27 -0
  23. xinference/model/audio/utils.py +32 -0
  24. xinference/model/audio/whisper.py +35 -10
  25. xinference/model/audio/whisper_mlx.py +208 -0
  26. xinference/model/embedding/core.py +322 -6
  27. xinference/model/embedding/model_spec.json +8 -1
  28. xinference/model/embedding/model_spec_modelscope.json +9 -1
  29. xinference/model/image/core.py +69 -1
  30. xinference/model/image/model_spec.json +145 -4
  31. xinference/model/image/model_spec_modelscope.json +150 -4
  32. xinference/model/image/stable_diffusion/core.py +50 -15
  33. xinference/model/llm/__init__.py +6 -2
  34. xinference/model/llm/llm_family.json +1055 -93
  35. xinference/model/llm/llm_family.py +15 -36
  36. xinference/model/llm/llm_family_modelscope.json +1031 -78
  37. xinference/model/llm/memory.py +1 -1
  38. xinference/model/llm/mlx/core.py +285 -47
  39. xinference/model/llm/sglang/core.py +2 -0
  40. xinference/model/llm/transformers/chatglm.py +9 -5
  41. xinference/model/llm/transformers/cogagent.py +272 -0
  42. xinference/model/llm/transformers/core.py +3 -0
  43. xinference/model/llm/transformers/glm_edge_v.py +230 -0
  44. xinference/model/llm/transformers/qwen2_vl.py +12 -1
  45. xinference/model/llm/transformers/utils.py +16 -8
  46. xinference/model/llm/utils.py +55 -4
  47. xinference/model/llm/vllm/core.py +137 -12
  48. xinference/model/llm/vllm/xavier/__init__.py +13 -0
  49. xinference/model/llm/vllm/xavier/allocator.py +74 -0
  50. xinference/model/llm/vllm/xavier/block.py +111 -0
  51. xinference/model/llm/vllm/xavier/block_manager.py +71 -0
  52. xinference/model/llm/vllm/xavier/block_tracker.py +129 -0
  53. xinference/model/llm/vllm/xavier/collective.py +74 -0
  54. xinference/model/llm/vllm/xavier/collective_manager.py +147 -0
  55. xinference/model/llm/vllm/xavier/engine.py +247 -0
  56. xinference/model/llm/vllm/xavier/executor.py +134 -0
  57. xinference/model/llm/vllm/xavier/scheduler.py +438 -0
  58. xinference/model/llm/vllm/xavier/test/__init__.py +13 -0
  59. xinference/model/llm/vllm/xavier/test/test_xavier.py +147 -0
  60. xinference/model/llm/vllm/xavier/transfer.py +319 -0
  61. xinference/model/rerank/core.py +11 -4
  62. xinference/model/video/diffusers.py +14 -0
  63. xinference/model/video/model_spec.json +15 -0
  64. xinference/model/video/model_spec_modelscope.json +16 -0
  65. xinference/thirdparty/cosyvoice/bin/average_model.py +92 -0
  66. xinference/thirdparty/cosyvoice/bin/export_jit.py +12 -2
  67. xinference/thirdparty/cosyvoice/bin/export_onnx.py +112 -0
  68. xinference/thirdparty/cosyvoice/bin/export_trt.sh +9 -0
  69. xinference/thirdparty/cosyvoice/bin/inference.py +5 -7
  70. xinference/thirdparty/cosyvoice/bin/spk2info.pt +0 -0
  71. xinference/thirdparty/cosyvoice/bin/train.py +42 -8
  72. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +96 -25
  73. xinference/thirdparty/cosyvoice/cli/frontend.py +77 -30
  74. xinference/thirdparty/cosyvoice/cli/model.py +330 -80
  75. xinference/thirdparty/cosyvoice/dataset/dataset.py +6 -2
  76. xinference/thirdparty/cosyvoice/dataset/processor.py +76 -14
  77. xinference/thirdparty/cosyvoice/flow/decoder.py +92 -13
  78. xinference/thirdparty/cosyvoice/flow/flow.py +99 -9
  79. xinference/thirdparty/cosyvoice/flow/flow_matching.py +110 -13
  80. xinference/thirdparty/cosyvoice/flow/length_regulator.py +5 -4
  81. xinference/thirdparty/cosyvoice/hifigan/discriminator.py +140 -0
  82. xinference/thirdparty/cosyvoice/hifigan/generator.py +58 -42
  83. xinference/thirdparty/cosyvoice/hifigan/hifigan.py +67 -0
  84. xinference/thirdparty/cosyvoice/llm/llm.py +139 -6
  85. xinference/thirdparty/cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
  86. xinference/thirdparty/cosyvoice/tokenizer/tokenizer.py +279 -0
  87. xinference/thirdparty/cosyvoice/transformer/embedding.py +2 -2
  88. xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +7 -7
  89. xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +318 -0
  90. xinference/thirdparty/cosyvoice/utils/common.py +28 -1
  91. xinference/thirdparty/cosyvoice/utils/executor.py +69 -7
  92. xinference/thirdparty/cosyvoice/utils/file_utils.py +2 -12
  93. xinference/thirdparty/cosyvoice/utils/frontend_utils.py +9 -5
  94. xinference/thirdparty/cosyvoice/utils/losses.py +20 -0
  95. xinference/thirdparty/cosyvoice/utils/scheduler.py +1 -2
  96. xinference/thirdparty/cosyvoice/utils/train_utils.py +101 -45
  97. xinference/thirdparty/f5_tts/api.py +166 -0
  98. xinference/thirdparty/f5_tts/configs/E2TTS_Base_train.yaml +44 -0
  99. xinference/thirdparty/f5_tts/configs/E2TTS_Small_train.yaml +44 -0
  100. xinference/thirdparty/f5_tts/configs/F5TTS_Base_train.yaml +46 -0
  101. xinference/thirdparty/f5_tts/configs/F5TTS_Small_train.yaml +46 -0
  102. xinference/thirdparty/f5_tts/eval/README.md +49 -0
  103. xinference/thirdparty/f5_tts/eval/ecapa_tdnn.py +330 -0
  104. xinference/thirdparty/f5_tts/eval/eval_infer_batch.py +207 -0
  105. xinference/thirdparty/f5_tts/eval/eval_infer_batch.sh +13 -0
  106. xinference/thirdparty/f5_tts/eval/eval_librispeech_test_clean.py +84 -0
  107. xinference/thirdparty/f5_tts/eval/eval_seedtts_testset.py +84 -0
  108. xinference/thirdparty/f5_tts/eval/utils_eval.py +405 -0
  109. xinference/thirdparty/f5_tts/infer/README.md +191 -0
  110. xinference/thirdparty/f5_tts/infer/SHARED.md +74 -0
  111. xinference/thirdparty/f5_tts/infer/examples/basic/basic.toml +11 -0
  112. xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_en.wav +0 -0
  113. xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_zh.wav +0 -0
  114. xinference/thirdparty/f5_tts/infer/examples/multi/country.flac +0 -0
  115. xinference/thirdparty/f5_tts/infer/examples/multi/main.flac +0 -0
  116. xinference/thirdparty/f5_tts/infer/examples/multi/story.toml +19 -0
  117. xinference/thirdparty/f5_tts/infer/examples/multi/story.txt +1 -0
  118. xinference/thirdparty/f5_tts/infer/examples/multi/town.flac +0 -0
  119. xinference/thirdparty/f5_tts/infer/examples/vocab.txt +2545 -0
  120. xinference/thirdparty/f5_tts/infer/infer_cli.py +226 -0
  121. xinference/thirdparty/f5_tts/infer/infer_gradio.py +851 -0
  122. xinference/thirdparty/f5_tts/infer/speech_edit.py +193 -0
  123. xinference/thirdparty/f5_tts/infer/utils_infer.py +538 -0
  124. xinference/thirdparty/f5_tts/model/__init__.py +10 -0
  125. xinference/thirdparty/f5_tts/model/backbones/README.md +20 -0
  126. xinference/thirdparty/f5_tts/model/backbones/dit.py +163 -0
  127. xinference/thirdparty/f5_tts/model/backbones/mmdit.py +146 -0
  128. xinference/thirdparty/f5_tts/model/backbones/unett.py +219 -0
  129. xinference/thirdparty/f5_tts/model/cfm.py +285 -0
  130. xinference/thirdparty/f5_tts/model/dataset.py +319 -0
  131. xinference/thirdparty/f5_tts/model/modules.py +658 -0
  132. xinference/thirdparty/f5_tts/model/trainer.py +366 -0
  133. xinference/thirdparty/f5_tts/model/utils.py +185 -0
  134. xinference/thirdparty/f5_tts/scripts/count_max_epoch.py +33 -0
  135. xinference/thirdparty/f5_tts/scripts/count_params_gflops.py +39 -0
  136. xinference/thirdparty/f5_tts/socket_server.py +159 -0
  137. xinference/thirdparty/f5_tts/train/README.md +77 -0
  138. xinference/thirdparty/f5_tts/train/datasets/prepare_csv_wavs.py +139 -0
  139. xinference/thirdparty/f5_tts/train/datasets/prepare_emilia.py +230 -0
  140. xinference/thirdparty/f5_tts/train/datasets/prepare_libritts.py +92 -0
  141. xinference/thirdparty/f5_tts/train/datasets/prepare_ljspeech.py +65 -0
  142. xinference/thirdparty/f5_tts/train/datasets/prepare_wenetspeech4tts.py +125 -0
  143. xinference/thirdparty/f5_tts/train/finetune_cli.py +174 -0
  144. xinference/thirdparty/f5_tts/train/finetune_gradio.py +1846 -0
  145. xinference/thirdparty/f5_tts/train/train.py +75 -0
  146. xinference/thirdparty/fish_speech/fish_speech/conversation.py +266 -1
  147. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +2 -1
  148. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +2 -1
  149. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +2 -2
  150. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ko_KR.json +123 -0
  151. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +2 -1
  152. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +137 -29
  153. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +9 -9
  154. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +1 -1
  155. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +17 -11
  156. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +1 -1
  157. xinference/thirdparty/fish_speech/fish_speech/tokenizer.py +152 -0
  158. xinference/thirdparty/fish_speech/fish_speech/train.py +2 -2
  159. xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +2 -1
  160. xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +22 -0
  161. xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +1 -1
  162. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +2 -2
  163. xinference/thirdparty/fish_speech/tools/{post_api.py → api_client.py} +34 -18
  164. xinference/thirdparty/fish_speech/tools/api_server.py +98 -0
  165. xinference/thirdparty/fish_speech/tools/download_models.py +5 -5
  166. xinference/thirdparty/fish_speech/tools/e2e_webui.py +232 -0
  167. xinference/thirdparty/fish_speech/tools/fish_e2e.py +298 -0
  168. xinference/thirdparty/fish_speech/tools/inference_engine/__init__.py +192 -0
  169. xinference/thirdparty/fish_speech/tools/inference_engine/reference_loader.py +125 -0
  170. xinference/thirdparty/fish_speech/tools/inference_engine/utils.py +39 -0
  171. xinference/thirdparty/fish_speech/tools/inference_engine/vq_manager.py +57 -0
  172. xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +2 -2
  173. xinference/thirdparty/fish_speech/tools/llama/generate.py +484 -72
  174. xinference/thirdparty/fish_speech/tools/run_webui.py +104 -0
  175. xinference/thirdparty/fish_speech/tools/schema.py +170 -0
  176. xinference/thirdparty/fish_speech/tools/server/agent/__init__.py +57 -0
  177. xinference/thirdparty/fish_speech/tools/server/agent/generate.py +119 -0
  178. xinference/thirdparty/fish_speech/tools/server/agent/generation_utils.py +122 -0
  179. xinference/thirdparty/fish_speech/tools/server/agent/pre_generation_utils.py +72 -0
  180. xinference/thirdparty/fish_speech/tools/server/api_utils.py +75 -0
  181. xinference/thirdparty/fish_speech/tools/server/exception_handler.py +27 -0
  182. xinference/thirdparty/fish_speech/tools/server/inference.py +45 -0
  183. xinference/thirdparty/fish_speech/tools/server/model_manager.py +122 -0
  184. xinference/thirdparty/fish_speech/tools/server/model_utils.py +129 -0
  185. xinference/thirdparty/fish_speech/tools/server/views.py +246 -0
  186. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +7 -1
  187. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +2 -3
  188. xinference/thirdparty/fish_speech/tools/webui/__init__.py +173 -0
  189. xinference/thirdparty/fish_speech/tools/webui/inference.py +91 -0
  190. xinference/thirdparty/fish_speech/tools/webui/variables.py +14 -0
  191. xinference/thirdparty/matcha/utils/utils.py +2 -2
  192. xinference/thirdparty/melo/api.py +135 -0
  193. xinference/thirdparty/melo/app.py +61 -0
  194. xinference/thirdparty/melo/attentions.py +459 -0
  195. xinference/thirdparty/melo/commons.py +160 -0
  196. xinference/thirdparty/melo/configs/config.json +94 -0
  197. xinference/thirdparty/melo/data/example/metadata.list +20 -0
  198. xinference/thirdparty/melo/data_utils.py +413 -0
  199. xinference/thirdparty/melo/download_utils.py +67 -0
  200. xinference/thirdparty/melo/infer.py +25 -0
  201. xinference/thirdparty/melo/init_downloads.py +14 -0
  202. xinference/thirdparty/melo/losses.py +58 -0
  203. xinference/thirdparty/melo/main.py +36 -0
  204. xinference/thirdparty/melo/mel_processing.py +174 -0
  205. xinference/thirdparty/melo/models.py +1030 -0
  206. xinference/thirdparty/melo/modules.py +598 -0
  207. xinference/thirdparty/melo/monotonic_align/__init__.py +16 -0
  208. xinference/thirdparty/melo/monotonic_align/core.py +46 -0
  209. xinference/thirdparty/melo/preprocess_text.py +135 -0
  210. xinference/thirdparty/melo/split_utils.py +174 -0
  211. xinference/thirdparty/melo/text/__init__.py +35 -0
  212. xinference/thirdparty/melo/text/chinese.py +199 -0
  213. xinference/thirdparty/melo/text/chinese_bert.py +107 -0
  214. xinference/thirdparty/melo/text/chinese_mix.py +253 -0
  215. xinference/thirdparty/melo/text/cleaner.py +36 -0
  216. xinference/thirdparty/melo/text/cleaner_multiling.py +110 -0
  217. xinference/thirdparty/melo/text/cmudict.rep +129530 -0
  218. xinference/thirdparty/melo/text/cmudict_cache.pickle +0 -0
  219. xinference/thirdparty/melo/text/english.py +284 -0
  220. xinference/thirdparty/melo/text/english_bert.py +39 -0
  221. xinference/thirdparty/melo/text/english_utils/abbreviations.py +35 -0
  222. xinference/thirdparty/melo/text/english_utils/number_norm.py +97 -0
  223. xinference/thirdparty/melo/text/english_utils/time_norm.py +47 -0
  224. xinference/thirdparty/melo/text/es_phonemizer/base.py +140 -0
  225. xinference/thirdparty/melo/text/es_phonemizer/cleaner.py +109 -0
  226. xinference/thirdparty/melo/text/es_phonemizer/es_symbols.json +79 -0
  227. xinference/thirdparty/melo/text/es_phonemizer/es_symbols.txt +1 -0
  228. xinference/thirdparty/melo/text/es_phonemizer/es_symbols_v2.json +83 -0
  229. xinference/thirdparty/melo/text/es_phonemizer/es_to_ipa.py +12 -0
  230. xinference/thirdparty/melo/text/es_phonemizer/example_ipa.txt +400 -0
  231. xinference/thirdparty/melo/text/es_phonemizer/gruut_wrapper.py +253 -0
  232. xinference/thirdparty/melo/text/es_phonemizer/punctuation.py +174 -0
  233. xinference/thirdparty/melo/text/es_phonemizer/spanish_symbols.txt +1 -0
  234. xinference/thirdparty/melo/text/es_phonemizer/test.ipynb +124 -0
  235. xinference/thirdparty/melo/text/fr_phonemizer/base.py +140 -0
  236. xinference/thirdparty/melo/text/fr_phonemizer/cleaner.py +122 -0
  237. xinference/thirdparty/melo/text/fr_phonemizer/en_symbols.json +78 -0
  238. xinference/thirdparty/melo/text/fr_phonemizer/example_ipa.txt +1 -0
  239. xinference/thirdparty/melo/text/fr_phonemizer/fr_symbols.json +89 -0
  240. xinference/thirdparty/melo/text/fr_phonemizer/fr_to_ipa.py +30 -0
  241. xinference/thirdparty/melo/text/fr_phonemizer/french_abbreviations.py +48 -0
  242. xinference/thirdparty/melo/text/fr_phonemizer/french_symbols.txt +1 -0
  243. xinference/thirdparty/melo/text/fr_phonemizer/gruut_wrapper.py +258 -0
  244. xinference/thirdparty/melo/text/fr_phonemizer/punctuation.py +172 -0
  245. xinference/thirdparty/melo/text/french.py +94 -0
  246. xinference/thirdparty/melo/text/french_bert.py +39 -0
  247. xinference/thirdparty/melo/text/japanese.py +647 -0
  248. xinference/thirdparty/melo/text/japanese_bert.py +49 -0
  249. xinference/thirdparty/melo/text/ko_dictionary.py +44 -0
  250. xinference/thirdparty/melo/text/korean.py +192 -0
  251. xinference/thirdparty/melo/text/opencpop-strict.txt +429 -0
  252. xinference/thirdparty/melo/text/spanish.py +122 -0
  253. xinference/thirdparty/melo/text/spanish_bert.py +39 -0
  254. xinference/thirdparty/melo/text/symbols.py +290 -0
  255. xinference/thirdparty/melo/text/tone_sandhi.py +769 -0
  256. xinference/thirdparty/melo/train.py +635 -0
  257. xinference/thirdparty/melo/train.sh +19 -0
  258. xinference/thirdparty/melo/transforms.py +209 -0
  259. xinference/thirdparty/melo/utils.py +424 -0
  260. xinference/types.py +17 -1
  261. xinference/web/ui/build/asset-manifest.json +6 -6
  262. xinference/web/ui/build/index.html +1 -1
  263. xinference/web/ui/build/static/css/main.51a587ff.css +2 -0
  264. xinference/web/ui/build/static/css/main.51a587ff.css.map +1 -0
  265. xinference/web/ui/build/static/js/main.b0936c54.js +3 -0
  266. xinference/web/ui/build/static/js/main.b0936c54.js.map +1 -0
  267. xinference/web/ui/node_modules/.cache/babel-loader/03c4052f1b91f6ba0c5389bdcf49c43319b4076c08e4b8585dab312538ae290a.json +1 -0
  268. xinference/web/ui/node_modules/.cache/babel-loader/1786b83003b8e9605a0f5f855a185d4d16e38fc893dfb326a2a9cca206b4240a.json +1 -0
  269. xinference/web/ui/node_modules/.cache/babel-loader/17cbc181dd674b9150b80c73ed6a82656de0082d857f6e5f66d9716129ac0b38.json +1 -0
  270. xinference/web/ui/node_modules/.cache/babel-loader/185ceb8872d562e032b47e79df6a45670e06345b8ed70aad1a131e0476783c5c.json +1 -0
  271. xinference/web/ui/node_modules/.cache/babel-loader/26b8c9f34b0bed789b3a833767672e39302d1e0c09b4276f4d58d1df7b6bd93b.json +1 -0
  272. xinference/web/ui/node_modules/.cache/babel-loader/2b484da66c724d0d56a40849c109327408796a668b1381511b6e9e03baa48658.json +1 -0
  273. xinference/web/ui/node_modules/.cache/babel-loader/2cbbbce9b84df73330d4c42b82436ed881b3847628f2fbc346aa62e2859fd88c.json +1 -0
  274. xinference/web/ui/node_modules/.cache/babel-loader/2ec9b14431ed33ce6901bf9f27007be4e6e472709c99d6e22b50ce528e4b78ee.json +1 -0
  275. xinference/web/ui/node_modules/.cache/babel-loader/3b966db018f96be4a055d6ca205f0990d4d0b370e2980c17d8bca2c9a021819c.json +1 -0
  276. xinference/web/ui/node_modules/.cache/babel-loader/3eefb411b24c2b3ce053570ef50daccf154022f0e168be5ed0fec21394baf9f4.json +1 -0
  277. xinference/web/ui/node_modules/.cache/babel-loader/522b229e3cac219123f0d69673f5570e191c2d2a505dc65b312d336eae2279c0.json +1 -0
  278. xinference/web/ui/node_modules/.cache/babel-loader/52e45f17ba300580ea3fcc9f9228ccba194bb092b76f25e9255af311f8b05aab.json +1 -0
  279. xinference/web/ui/node_modules/.cache/babel-loader/5a0bc4631f936459afc1a3b1d3ec2420118b1f00e11f60ccac3e08088f3f27a8.json +1 -0
  280. xinference/web/ui/node_modules/.cache/babel-loader/611fa2c6c53b66039991d06dfb0473b5ab37fc63b4564e0f6e1718523768a045.json +1 -0
  281. xinference/web/ui/node_modules/.cache/babel-loader/6329bc76c406fe5eb305412383fbde5950f847bb5e43261f73f37622c365acb4.json +1 -0
  282. xinference/web/ui/node_modules/.cache/babel-loader/63c8e07687ea53a4f8a910ee5e42e0eb26cd1acbfbe820f3e3248a786ee51401.json +1 -0
  283. xinference/web/ui/node_modules/.cache/babel-loader/69b2d5001684174ec9da57e07914eed3eac4960018bceb6cbfa801d861301d7c.json +1 -0
  284. xinference/web/ui/node_modules/.cache/babel-loader/710c1acda69e561e30a933b98c6a56d50197868b15c21e2aad55ab6d46649eb6.json +1 -0
  285. xinference/web/ui/node_modules/.cache/babel-loader/720deca1fce5a1dc5056048fa8258fd138a82ea855f350b6613f104a73fb761f.json +1 -0
  286. xinference/web/ui/node_modules/.cache/babel-loader/76a23b92d26a499c57e61eea2b895fbc9771bd0849a72e66f8e633192017978b.json +1 -0
  287. xinference/web/ui/node_modules/.cache/babel-loader/858063f23b34dfe600254eb5afd85518b0002ec4b30b7386616c45600826e3b2.json +1 -0
  288. xinference/web/ui/node_modules/.cache/babel-loader/920b82c1c89124cf217109eeedbfcd3aae3b917be50c9dfb6bbb4ce26bdfd2e7.json +1 -0
  289. xinference/web/ui/node_modules/.cache/babel-loader/94d8b7aeb0076f2ce07db598cea0e87b13bc8d5614eb530b8d6e696c2daf6f88.json +1 -0
  290. xinference/web/ui/node_modules/.cache/babel-loader/9e917fe7022d01b2ccbe5cc0ce73d70bb72bee584ff293bad71bdff6695dee28.json +1 -0
  291. xinference/web/ui/node_modules/.cache/babel-loader/9f28fdb8399f1d0474f0aca86f1658dc94f5bf0c90f6146352de150692de8862.json +1 -0
  292. xinference/web/ui/node_modules/.cache/babel-loader/a0dfafa06b2bb7cba8cad41c482503f61944f759f4318139362602ef5cc47ccb.json +1 -0
  293. xinference/web/ui/node_modules/.cache/babel-loader/a3ff866acddf34917a7ee399e0e571a4dfd8ba66d5057db885f243e16a6eb17d.json +1 -0
  294. xinference/web/ui/node_modules/.cache/babel-loader/afb8084f539534cd594755ea2205ecd5bd1f62dddcfdf75a2eace59a28131278.json +1 -0
  295. xinference/web/ui/node_modules/.cache/babel-loader/b57b1438b77294c1f3f6cfce12ac487d8106c6f016975ba0aec94d98997e2e1e.json +1 -0
  296. xinference/web/ui/node_modules/.cache/babel-loader/b9917b0bf8e4d55ccbac1c334aa04d6ff3c5b6ed9e5d38b9ea2c687fa7d3f5a9.json +1 -0
  297. xinference/web/ui/node_modules/.cache/babel-loader/bbcc94b0149963d1d6f267ee1f4f03d3925b758392ce2f516c3fe8af0e0169fc.json +1 -0
  298. xinference/web/ui/node_modules/.cache/babel-loader/bdee44abeadc4abc17d41c52eb49c6e19a4b1a267b6e16876ce91bdeeebfc52d.json +1 -0
  299. xinference/web/ui/node_modules/.cache/babel-loader/beb112b70f4a56db95920a9e20efb6c97c37b68450716730217a9ee1a9ae92be.json +1 -0
  300. xinference/web/ui/node_modules/.cache/babel-loader/c88db97be0cdf440193b3995996e83510a04cb00048135485fc0e26d197e80b5.json +1 -0
  301. xinference/web/ui/node_modules/.cache/babel-loader/d49e5314d34310a62d01a03067ce1bec5da00abce84c5196aa9c6842fa79a430.json +1 -0
  302. xinference/web/ui/node_modules/.cache/babel-loader/d7664d18c4ddbad9c3a6a31b91f7c00fb0dde804608674a9860ee50f33e54708.json +1 -0
  303. xinference/web/ui/node_modules/.cache/babel-loader/d9072c318b819b7c90a0f7e9cc0b6413b4dbeb8e9859898e53d75ea882fcde99.json +1 -0
  304. xinference/web/ui/node_modules/.cache/babel-loader/db16a983bc08a05f0439cc61ca0840e49e1d8400eef678909f16c032a418a3d6.json +1 -0
  305. xinference/web/ui/node_modules/.cache/babel-loader/dc249829767b8abcbc3677e0b07b6d3ecbfdfe6d08cfe23a665eb33373a9aa9d.json +1 -0
  306. xinference/web/ui/node_modules/.cache/babel-loader/e242c583c2dbc2784f0fcf513523975f7d5df447e106c1c17e49e8578a6fc3ed.json +1 -0
  307. xinference/web/ui/node_modules/.cache/babel-loader/eac5f1296513e69e4b96f750ddccd4d0264e2bae4e4c449144e83274a48698d9.json +1 -0
  308. xinference/web/ui/node_modules/.cache/babel-loader/ed57202cb79649bb716400436590245547df241988fc7c8e1d85d132299542d2.json +1 -0
  309. xinference/web/ui/node_modules/.cache/babel-loader/f125bf72e773a14cdaebd0c343e80adb909d12e317ee5c00cd4a57442fbe2c62.json +1 -0
  310. xinference/web/ui/node_modules/.cache/babel-loader/f91af913d7f91c410719ab13136aaed3aaf0f8dda06652f25c42cb5231587398.json +1 -0
  311. xinference/web/ui/node_modules/.package-lock.json +67 -3
  312. xinference/web/ui/node_modules/@babel/runtime/package.json +592 -538
  313. xinference/web/ui/node_modules/html-parse-stringify/package.json +50 -0
  314. xinference/web/ui/node_modules/i18next/dist/esm/package.json +1 -0
  315. xinference/web/ui/node_modules/i18next/package.json +129 -0
  316. xinference/web/ui/node_modules/react-i18next/.eslintrc.json +74 -0
  317. xinference/web/ui/node_modules/react-i18next/dist/es/package.json +1 -0
  318. xinference/web/ui/node_modules/react-i18next/package.json +162 -0
  319. xinference/web/ui/node_modules/void-elements/package.json +34 -0
  320. xinference/web/ui/package-lock.json +69 -3
  321. xinference/web/ui/package.json +2 -0
  322. xinference/web/ui/src/locales/en.json +186 -0
  323. xinference/web/ui/src/locales/zh.json +186 -0
  324. {xinference-0.16.3.dist-info → xinference-1.2.1.dist-info}/METADATA +96 -36
  325. {xinference-0.16.3.dist-info → xinference-1.2.1.dist-info}/RECORD +335 -146
  326. {xinference-0.16.3.dist-info → xinference-1.2.1.dist-info}/WHEEL +1 -1
  327. xinference/thirdparty/cosyvoice/bin/export_trt.py +0 -8
  328. xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
  329. xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
  330. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
  331. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
  332. xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
  333. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
  334. xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
  335. xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
  336. xinference/thirdparty/fish_speech/tools/api.py +0 -440
  337. xinference/thirdparty/fish_speech/tools/commons.py +0 -35
  338. xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
  339. xinference/thirdparty/fish_speech/tools/msgpack_api.py +0 -34
  340. xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
  341. xinference/thirdparty/fish_speech/tools/webui.py +0 -485
  342. xinference/web/ui/build/static/css/main.5061c4c3.css +0 -2
  343. xinference/web/ui/build/static/css/main.5061c4c3.css.map +0 -1
  344. xinference/web/ui/build/static/js/main.2f269bb3.js +0 -3
  345. xinference/web/ui/build/static/js/main.2f269bb3.js.map +0 -1
  346. xinference/web/ui/node_modules/.cache/babel-loader/07ce9e632e6aff24d7aa3ad8e48224433bbfeb0d633fca723453f1fcae0c9f1c.json +0 -1
  347. xinference/web/ui/node_modules/.cache/babel-loader/1130403f9e46f5738a23b45ac59b57de8f360c908c713e2c0670c2cce9bd367a.json +0 -1
  348. xinference/web/ui/node_modules/.cache/babel-loader/131091b25d26b17cdca187d7542a21475c211138d900cf667682260e76ef9463.json +0 -1
  349. xinference/web/ui/node_modules/.cache/babel-loader/1f269fb2a368363c1cb2237825f1dba093b6bdd8c44cc05954fd19ec2c1fff03.json +0 -1
  350. xinference/web/ui/node_modules/.cache/babel-loader/331312668fa8bd3d7401818f4a25fa98135d7f61371cd6bfff78b18cf4fbdd92.json +0 -1
  351. xinference/web/ui/node_modules/.cache/babel-loader/40f17338fc75ae095de7d2b4d8eae0d5ca0193a7e2bcece4ee745b22a7a2f4b7.json +0 -1
  352. xinference/web/ui/node_modules/.cache/babel-loader/4de9a6942c5f1749d6cbfdd54279699975f16016b182848bc253886f52ec2ec3.json +0 -1
  353. xinference/web/ui/node_modules/.cache/babel-loader/822586ed1077201b64b954f12f25e3f9b45678c1acbabe53d8af3ca82ca71f33.json +0 -1
  354. xinference/web/ui/node_modules/.cache/babel-loader/8d33354bd2100c8602afc3341f131a88cc36aaeecd5a4b365ed038514708e350.json +0 -1
  355. xinference/web/ui/node_modules/.cache/babel-loader/9375a35b05d56989b2755bf72161fa707c92f28569d33765a75f91a568fda6e9.json +0 -1
  356. xinference/web/ui/node_modules/.cache/babel-loader/a158a9ffa0c9b169aee53dd4a0c44501a596755b4e4f6ede7746d65a72e2a71f.json +0 -1
  357. xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +0 -1
  358. xinference/web/ui/node_modules/.cache/babel-loader/c7bf40bab396765f67d0fed627ed3665890608b2d0edaa3e8cb7cfc96310db45.json +0 -1
  359. xinference/web/ui/node_modules/.cache/babel-loader/d6c643278a0b28320e6f33a60f5fb64c053997cbdc39a60e53ccc574688ade9e.json +0 -1
  360. xinference/web/ui/node_modules/.cache/babel-loader/e42b72d4cc1ea412ebecbb8d040dc6c6bfee462c33903c2f1f3facb602ad742e.json +0 -1
  361. xinference/web/ui/node_modules/.cache/babel-loader/e64b7e8cedcf43d4c95deba60ec1341855c887705805bb62431693118b870c69.json +0 -1
  362. xinference/web/ui/node_modules/.cache/babel-loader/f5039ddbeb815c51491a1989532006b96fc3ae49c6c60e3c097f875b4ae915ae.json +0 -1
  363. xinference/web/ui/node_modules/.cache/babel-loader/f72f011744c4649fabddca6f7a9327861ac0a315a89b1a2e62a39774e7863845.json +0 -1
  364. xinference/web/ui/node_modules/.cache/babel-loader/feabb04b4aa507102da0a64398a40818e878fd1df9b75dda8461b3e1e7ff3f11.json +0 -1
  365. /xinference/thirdparty/{cosyvoice/bin → f5_tts}/__init__.py +0 -0
  366. /xinference/thirdparty/{cosyvoice/flow → melo}/__init__.py +0 -0
  367. /xinference/thirdparty/{cosyvoice/hifigan → melo/text/english_utils}/__init__.py +0 -0
  368. /xinference/thirdparty/{cosyvoice/llm → melo/text/es_phonemizer}/__init__.py +0 -0
  369. /xinference/thirdparty/{fish_speech/fish_speech/configs → melo/text/fr_phonemizer}/__init__.py +0 -0
  370. /xinference/web/ui/build/static/js/{main.2f269bb3.js.LICENSE.txt → main.b0936c54.js.LICENSE.txt} +0 -0
  371. {xinference-0.16.3.dist-info → xinference-1.2.1.dist-info}/LICENSE +0 -0
  372. {xinference-0.16.3.dist-info → xinference-1.2.1.dist-info}/entry_points.txt +0 -0
  373. {xinference-0.16.3.dist-info → xinference-1.2.1.dist-info}/top_level.txt +0 -0
@@ -49,13 +49,14 @@ class InterpolateRegulator(nn.Module):
         olens = ylens
         return out * mask, olens
 
-    def inference(self, x1, x2, mel_len1, mel_len2):
+    def inference(self, x1, x2, mel_len1, mel_len2, input_frame_rate=50):
         # in inference mode, interploate prompt token and token(head/mid/tail) seprately, so we can get a clear separation point of mel
         # x in (B, T, D)
         if x2.shape[1] > 40:
-            x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=34, mode='linear')
-            x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - 34 * 2, mode='linear')
-            x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=34, mode='linear')
+            x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
+            x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - int(20 / input_frame_rate * 22050 / 256) * 2,
+                                   mode='linear')
+            x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
             x2 = torch.concat([x2_head, x2_mid, x2_tail], dim=2)
         else:
             x2 = F.interpolate(x2.transpose(1, 2).contiguous(), size=mel_len2, mode='linear')
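
Editorial note on the hunk above: the new `size` expression converts the 20-token prompt edge into mel frames, and with the default `input_frame_rate=50` it reproduces the old hard-coded 34. A minimal sketch of the arithmetic, assuming 22050 and 256 are the mel sample rate and hop length (they appear only as literals in the expression):

# 20 speech tokens at `input_frame_rate` tokens/s cover 20 / input_frame_rate seconds of audio;
# expressing that duration in mel frames at a 22050 Hz sample rate with a 256-sample hop
# gives the interpolation target size.
def prompt_edge_frames(num_tokens=20, input_frame_rate=50, sample_rate=22050, hop_length=256):
    return int(num_tokens / input_frame_rate * sample_rate / hop_length)

print(prompt_edge_frames(input_frame_rate=50))  # 34 -> the old hard-coded constant
print(prompt_edge_frames(input_frame_rate=25))  # 68 -> e.g. a 25 Hz speech tokenizer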
@@ -0,0 +1,140 @@
+import torch
+import torch.nn as nn
+from torch.nn.utils import weight_norm
+from typing import List, Optional, Tuple
+from einops import rearrange
+from torchaudio.transforms import Spectrogram
+
+
+class MultipleDiscriminator(nn.Module):
+    def __init__(
+        self, mpd: nn.Module, mrd: nn.Module
+    ):
+        super().__init__()
+        self.mpd = mpd
+        self.mrd = mrd
+
+    def forward(self, y: torch.Tensor, y_hat: torch.Tensor):
+        y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
+        this_y_d_rs, this_y_d_gs, this_fmap_rs, this_fmap_gs = self.mpd(y.unsqueeze(dim=1), y_hat.unsqueeze(dim=1))
+        y_d_rs += this_y_d_rs
+        y_d_gs += this_y_d_gs
+        fmap_rs += this_fmap_rs
+        fmap_gs += this_fmap_gs
+        this_y_d_rs, this_y_d_gs, this_fmap_rs, this_fmap_gs = self.mrd(y, y_hat)
+        y_d_rs += this_y_d_rs
+        y_d_gs += this_y_d_gs
+        fmap_rs += this_fmap_rs
+        fmap_gs += this_fmap_gs
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+class MultiResolutionDiscriminator(nn.Module):
+    def __init__(
+        self,
+        fft_sizes: Tuple[int, ...] = (2048, 1024, 512),
+        num_embeddings: Optional[int] = None,
+    ):
+        """
+        Multi-Resolution Discriminator module adapted from https://github.com/descriptinc/descript-audio-codec.
+        Additionally, it allows incorporating conditional information with a learned embeddings table.
+
+        Args:
+            fft_sizes (tuple[int]): Tuple of window lengths for FFT. Defaults to (2048, 1024, 512).
+            num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator.
+                Defaults to None.
+        """
+
+        super().__init__()
+        self.discriminators = nn.ModuleList(
+            [DiscriminatorR(window_length=w, num_embeddings=num_embeddings) for w in fft_sizes]
+        )
+
+    def forward(
+        self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None
+    ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
+        y_d_rs = []
+        y_d_gs = []
+        fmap_rs = []
+        fmap_gs = []
+
+        for d in self.discriminators:
+            y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
+            y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
+            y_d_rs.append(y_d_r)
+            fmap_rs.append(fmap_r)
+            y_d_gs.append(y_d_g)
+            fmap_gs.append(fmap_g)
+
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+class DiscriminatorR(nn.Module):
+    def __init__(
+        self,
+        window_length: int,
+        num_embeddings: Optional[int] = None,
+        channels: int = 32,
+        hop_factor: float = 0.25,
+        bands: Tuple[Tuple[float, float], ...] = ((0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)),
+    ):
+        super().__init__()
+        self.window_length = window_length
+        self.hop_factor = hop_factor
+        self.spec_fn = Spectrogram(
+            n_fft=window_length, hop_length=int(window_length * hop_factor), win_length=window_length, power=None
+        )
+        n_fft = window_length // 2 + 1
+        bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
+        self.bands = bands
+        convs = lambda: nn.ModuleList(
+            [
+                weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
+                weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
+                weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
+                weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
+                weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))),
+            ]
+        )
+        self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
+
+        if num_embeddings is not None:
+            self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=channels)
+            torch.nn.init.zeros_(self.emb.weight)
+
+        self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1)))
+
+    def spectrogram(self, x):
+        # Remove DC offset
+        x = x - x.mean(dim=-1, keepdims=True)
+        # Peak normalize the volume of input audio
+        x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
+        x = self.spec_fn(x)
+        x = torch.view_as_real(x)
+        x = rearrange(x, "b f t c -> b c t f")
+        # Split into bands
+        x_bands = [x[..., b[0]: b[1]] for b in self.bands]
+        return x_bands
+
+    def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None):
+        x_bands = self.spectrogram(x)
+        fmap = []
+        x = []
+        for band, stack in zip(x_bands, self.band_convs):
+            for i, layer in enumerate(stack):
+                band = layer(band)
+                band = torch.nn.functional.leaky_relu(band, 0.1)
+                if i > 0:
+                    fmap.append(band)
+            x.append(band)
+        x = torch.cat(x, dim=-1)
+        if cond_embedding_id is not None:
+            emb = self.emb(cond_embedding_id)
+            h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
+        else:
+            h = 0
+        x = self.conv_post(x)
+        fmap.append(x)
+        x += h
+
+        return x, fmap
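
Editorial note: a minimal usage sketch for the new discriminator file above (not part of the package; batch size and audio length are arbitrary):

import torch

# Run the multi-resolution discriminator on dummy waveforms of shape (batch, samples).
# Each element of y_d_rs / y_d_gs is a logit map and each element of fmap_rs / fmap_gs
# is a list of intermediate feature maps -- one entry per FFT size (three here).
mrd = MultiResolutionDiscriminator(fft_sizes=(2048, 1024, 512))
real = torch.randn(2, 22050)
fake = torch.randn(2, 22050)
y_d_rs, y_d_gs, fmap_rs, fmap_gs = mrd(real, fake)
print(len(y_d_rs), len(fmap_rs))  # 3 3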
@@ -14,7 +14,7 @@
 
 """HIFI-GAN"""
 
-import typing as tp
+from typing import Dict, Optional, List
 import numpy as np
 from scipy.signal import get_window
 import torch
@@ -38,13 +38,15 @@ This code is modified from https://github.com/jik876/hifi-gan
 https://github.com/NVIDIA/BigVGAN
 
 """
+
+
 class ResBlock(torch.nn.Module):
     """Residual block module in HiFiGAN/BigVGAN."""
     def __init__(
             self,
             channels: int = 512,
             kernel_size: int = 3,
-            dilations: tp.List[int] = [1, 3, 5],
+            dilations: List[int] = [1, 3, 5],
     ):
         super(ResBlock, self).__init__()
         self.convs1 = nn.ModuleList()
@@ -100,6 +102,7 @@ class ResBlock(torch.nn.Module):
             remove_weight_norm(self.convs1[idx])
             remove_weight_norm(self.convs2[idx])
 
+
 class SineGen(torch.nn.Module):
     """ Definition of sine generator
     SineGen(samp_rate, harmonic_num = 0,
@@ -231,13 +234,13 @@ class HiFTGenerator(nn.Module):
             nsf_alpha: float = 0.1,
             nsf_sigma: float = 0.003,
             nsf_voiced_threshold: float = 10,
-            upsample_rates: tp.List[int] = [8, 8],
-            upsample_kernel_sizes: tp.List[int] = [16, 16],
-            istft_params: tp.Dict[str, int] = {"n_fft": 16, "hop_len": 4},
-            resblock_kernel_sizes: tp.List[int] = [3, 7, 11],
-            resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
-            source_resblock_kernel_sizes: tp.List[int] = [7, 11],
-            source_resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5]],
+            upsample_rates: List[int] = [8, 8],
+            upsample_kernel_sizes: List[int] = [16, 16],
+            istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4},
+            resblock_kernel_sizes: List[int] = [3, 7, 11],
+            resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+            source_resblock_kernel_sizes: List[int] = [7, 11],
+            source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5]],
             lrelu_slope: float = 0.1,
             audio_limit: float = 0.99,
             f0_predictor: torch.nn.Module = None,
@@ -286,8 +289,7 @@ class HiFTGenerator(nn.Module):
         self.source_resblocks = nn.ModuleList()
         downsample_rates = [1] + upsample_rates[::-1][:-1]
         downsample_cum_rates = np.cumprod(downsample_rates)
-        for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes,
-                                          source_resblock_dilation_sizes)):
+        for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)):
             if u == 1:
                 self.source_downs.append(
                     Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
@@ -304,7 +306,7 @@ class HiFTGenerator(nn.Module):
         self.resblocks = nn.ModuleList()
         for i in range(len(self.ups)):
             ch = base_channels // (2**(i + 1))
-            for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
+            for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                 self.resblocks.append(ResBlock(ch, k, d))
 
         self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3))
@@ -314,11 +316,19 @@ class HiFTGenerator(nn.Module):
         self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
         self.f0_predictor = f0_predictor
 
-    def _f02source(self, f0: torch.Tensor) -> torch.Tensor:
-        f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # bs,n,t
-
-        har_source, _, _ = self.m_source(f0)
-        return har_source.transpose(1, 2)
+    def remove_weight_norm(self):
+        print('Removing weight norm...')
+        for l in self.ups:
+            remove_weight_norm(l)
+        for l in self.resblocks:
+            l.remove_weight_norm()
+        remove_weight_norm(self.conv_pre)
+        remove_weight_norm(self.conv_post)
+        self.m_source.remove_weight_norm()
+        for l in self.source_downs:
+            remove_weight_norm(l)
+        for l in self.source_resblocks:
+            l.remove_weight_norm()
 
     def _stft(self, x):
         spec = torch.stft(
@@ -332,17 +342,11 @@ class HiFTGenerator(nn.Module):
         magnitude = torch.clip(magnitude, max=1e2)
         real = magnitude * torch.cos(phase)
         img = magnitude * torch.sin(phase)
-        inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
+        inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"],
+                                        self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
         return inverse_transform
 
-    def forward(self, x: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
-        f0 = self.f0_predictor(x)
-        s = self._f02source(f0)
-
-        # use cache_source to avoid glitch
-        if cache_source.shape[2] == 0:
-            s[:, :, :cache_source.shape[2]] = cache_source
-
+    def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
         s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
         s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
 
@@ -374,22 +378,34 @@ class HiFTGenerator(nn.Module):
 
         x = self._istft(magnitude, phase)
         x = torch.clamp(x, -self.audio_limit, self.audio_limit)
-        return x, s
+        return x
 
-    def remove_weight_norm(self):
-        print('Removing weight norm...')
-        for l in self.ups:
-            remove_weight_norm(l)
-        for l in self.resblocks:
-            l.remove_weight_norm()
-        remove_weight_norm(self.conv_pre)
-        remove_weight_norm(self.conv_post)
-        self.source_module.remove_weight_norm()
-        for l in self.source_downs:
-            remove_weight_norm(l)
-        for l in self.source_resblocks:
-            l.remove_weight_norm()
+    def forward(
+            self,
+            batch: dict,
+            device: torch.device,
+    ) -> Dict[str, Optional[torch.Tensor]]:
+        speech_feat = batch['speech_feat'].transpose(1, 2).to(device)
+        # mel->f0
+        f0 = self.f0_predictor(speech_feat)
+        # f0->source
+        s = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # bs,n,t
+        s, _, _ = self.m_source(s)
+        s = s.transpose(1, 2)
+        # mel+source->speech
+        generated_speech = self.decode(x=speech_feat, s=s)
+        return generated_speech, f0
 
     @torch.inference_mode()
-    def inference(self, mel: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
-        return self.forward(x=mel, cache_source=cache_source)
+    def inference(self, speech_feat: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
+        # mel->f0
+        f0 = self.f0_predictor(speech_feat)
+        # f0->source
+        s = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # bs,n,t
+        s, _, _ = self.m_source(s)
+        s = s.transpose(1, 2)
+        # use cache_source to avoid glitch
+        if cache_source.shape[2] != 0:
+            s[:, :, :cache_source.shape[2]] = cache_source
+        generated_speech = self.decode(x=speech_feat, s=s)
+        return generated_speech, s
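
Editorial note: the refactor above splits the old forward pass into a shared `decode()` plus an `inference()` that rebuilds the excitation source from the mel chunk and lets the caller overwrite its first frames with a cached source, which is what keeps chunked (streaming) vocoding glitch-free. A rough sketch of the calling pattern, where `hift` and `mel` are placeholders rather than package code:

import torch

# `hift` is assumed to be a constructed HiFTGenerator, `mel` a (1, num_mels, T) mel chunk.
cache_source = torch.zeros(1, 1, 0)  # empty cache for the very first chunk
audio, source = hift.inference(speech_feat=mel, cache_source=cache_source)
# For the next chunk, the caller passes the source frames of the overlapping region back
# in as `cache_source`; inference() copies them over the freshly generated source so the
# excitation stays continuous across the chunk boundary.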
@@ -0,0 +1,67 @@
+from typing import Dict, Optional
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from matcha.hifigan.models import feature_loss, generator_loss, discriminator_loss
+from cosyvoice.utils.losses import tpr_loss, mel_loss
+
+
+class HiFiGan(nn.Module):
+    def __init__(self, generator, discriminator, mel_spec_transform,
+                 multi_mel_spectral_recon_loss_weight=45, feat_match_loss_weight=2.0,
+                 tpr_loss_weight=1.0, tpr_loss_tau=0.04):
+        super(HiFiGan, self).__init__()
+        self.generator = generator
+        self.discriminator = discriminator
+        self.mel_spec_transform = mel_spec_transform
+        self.multi_mel_spectral_recon_loss_weight = multi_mel_spectral_recon_loss_weight
+        self.feat_match_loss_weight = feat_match_loss_weight
+        self.tpr_loss_weight = tpr_loss_weight
+        self.tpr_loss_tau = tpr_loss_tau
+
+    def forward(
+            self,
+            batch: dict,
+            device: torch.device,
+    ) -> Dict[str, Optional[torch.Tensor]]:
+        if batch['turn'] == 'generator':
+            return self.forward_generator(batch, device)
+        else:
+            return self.forward_discriminator(batch, device)
+
+    def forward_generator(self, batch, device):
+        real_speech = batch['speech'].to(device)
+        pitch_feat = batch['pitch_feat'].to(device)
+        # 1. calculate generator outputs
+        generated_speech, generated_f0 = self.generator(batch, device)
+        # 2. calculate discriminator outputs
+        y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech)
+        # 3. calculate generator losses, feature loss, mel loss, tpr losses [Optional]
+        loss_gen, _ = generator_loss(y_d_gs)
+        loss_fm = feature_loss(fmap_rs, fmap_gs)
+        loss_mel = mel_loss(real_speech, generated_speech, self.mel_spec_transform)
+        if self.tpr_loss_weight != 0:
+            loss_tpr = tpr_loss(y_d_rs, y_d_gs, self.tpr_loss_tau)
+        else:
+            loss_tpr = torch.zeros(1).to(device)
+        loss_f0 = F.l1_loss(generated_f0, pitch_feat)
+        loss = loss_gen + self.feat_match_loss_weight * loss_fm + \
+            self.multi_mel_spectral_recon_loss_weight * loss_mel + \
+            self.tpr_loss_weight * loss_tpr + loss_f0
+        return {'loss': loss, 'loss_gen': loss_gen, 'loss_fm': loss_fm, 'loss_mel': loss_mel, 'loss_tpr': loss_tpr, 'loss_f0': loss_f0}
+
+    def forward_discriminator(self, batch, device):
+        real_speech = batch['speech'].to(device)
+        # 1. calculate generator outputs
+        with torch.no_grad():
+            generated_speech, generated_f0 = self.generator(batch, device)
+        # 2. calculate discriminator outputs
+        y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech)
+        # 3. calculate discriminator losses, tpr losses [Optional]
+        loss_disc, _, _ = discriminator_loss(y_d_rs, y_d_gs)
+        if self.tpr_loss_weight != 0:
+            loss_tpr = tpr_loss(y_d_rs, y_d_gs, self.tpr_loss_tau)
+        else:
+            loss_tpr = torch.zeros(1).to(device)
+        loss = loss_disc + self.tpr_loss_weight * loss_tpr
+        return {'loss': loss, 'loss_disc': loss_disc, 'loss_tpr': loss_tpr}
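
Editorial note: the `batch['turn']` switch above is what drives alternating GAN updates during training. A rough sketch of the loop a trainer would run (the dataloader, optimizers, and `hifigan` instance are assumptions, not code from this package):

for batch in dataloader:
    # Discriminator step: forward_discriminator() runs the generator under no_grad.
    batch['turn'] = 'discriminator'
    d_out = hifigan(batch, device)
    optim_d.zero_grad()
    d_out['loss'].backward()
    optim_d.step()

    # Generator step: adversarial + feature-matching + mel + optional TPR + F0 losses.
    batch['turn'] = 'generator'
    g_out = hifigan(batch, device)
    optim_g.zero_grad()
    g_out['loss'].backward()
    optim_g.step()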
@@ -15,6 +15,7 @@ from typing import Dict, Optional, Callable, List, Generator
 import torch
 from torch import nn
 import torch.nn.functional as F
+from transformers import Qwen2ForCausalLM
 from torch.nn.utils.rnn import pad_sequence, unpad_sequence
 from cosyvoice.utils.common import IGNORE_ID
 from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
@@ -80,7 +81,8 @@ class TransformerLM(torch.nn.Module):
     def pad_unpad_sequence(self, sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
         text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
         speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
-        lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0) for i in range(len(text_token))]
+        lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
+                    for i in range(len(text_token))]
         lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
         lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
         return lm_input, lm_input_len
@@ -104,7 +106,8 @@ class TransformerLM(torch.nn.Module):
         embedding = batch['embedding'].to(device)
 
         # 1. prepare llm_target
-        lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() + [self.speech_token_size]) for i in range(text_token.size(0))]
+        lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() +
+                                  [self.speech_token_size]) for i in range(text_token.size(0))]
         lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID).to(device)
 
         # 1. encode text_token
@@ -124,7 +127,8 @@ class TransformerLM(torch.nn.Module):
         speech_token = self.speech_embedding(speech_token)
 
         # 5. unpad and pad
-        lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len)
+        lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, embedding, text_token, text_token_len,
+                                                         task_id_emb, speech_token, speech_token_len)
 
         # 6. run lm forward
         lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
@@ -194,14 +198,143 @@ class TransformerLM(torch.nn.Module):
         offset = 0
         att_cache, cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device), torch.zeros((0, 0, 0, 0), device=lm_input.device)
         for i in range(max_len):
-            y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=0, required_cache_size=-1, att_cache=att_cache, cnn_cache=cnn_cache,
-                                                                  att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool))
+            y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=offset, required_cache_size=-1,
+                                                                  att_cache=att_cache, cnn_cache=cnn_cache,
+                                                                  att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
+                                                                                                 device=lm_input.device)).to(torch.bool))
             logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
+            # force continue decode first token
+            if i == 0:
+                logp[:, self.speech_token_size] = -float('inf')
             top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
             if top_ids == self.speech_token_size:
                 break
             # in stream mode, yield token one by one
-            yield torch.tensor([[top_ids]], dtype=torch.int64, device=device)
+            yield top_ids
             out_tokens.append(top_ids)
             offset += lm_input.size(1)
             lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
+
+
+class Qwen2Encoder(torch.nn.Module):
+    def __init__(self, pretrain_path):
+        super().__init__()
+        self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path)
+
+    def forward_one_step(self, xs, masks, cache=None):
+        input_masks = masks[:, -1, :]
+        outs = self.model(
+            inputs_embeds=xs,
+            attention_mask=input_masks,
+            output_hidden_states=True,
+            return_dict=True,
+            use_cache=True,
+            past_key_values=cache,
+        )
+        xs = outs.hidden_states[-1]
+        new_cache = outs.past_key_values
+        return xs, new_cache
+
+
+class Qwen2LM(torch.nn.Module):
+    def __init__(
+            self,
+            llm_input_size: int,
+            llm_output_size: int,
+            speech_token_size: int,
+            llm: torch.nn.Module,
+            sampling: Callable,
+            length_normalized_loss: bool = True,
+            lsm_weight: float = 0.0,
+    ):
+        super().__init__()
+        self.llm_input_size = llm_input_size
+        self.llm_output_size = llm_output_size
+        self.speech_token_size = speech_token_size
+
+        # 2. build speech token language model related modules
+        self.sos_eos = 0
+        self.task_id = 1
+        self.fill_token = 2
+
+        self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
+        self.llm = llm
+        self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 3)
+        self.criterion_ce = LabelSmoothingLoss(
+            size=speech_token_size + 3,
+            padding_idx=IGNORE_ID,
+            smoothing=lsm_weight,
+            normalize_length=length_normalized_loss,
+        )
+
+        # 3. [Optional] build speech token related modules
+        self.speech_embedding = torch.nn.Embedding(speech_token_size + 3, llm_input_size)
+
+        # 4. sampling method
+        self.sampling = sampling
+
+    def sampling_ids(
+            self,
+            weighted_scores: torch.Tensor,
+            decoded_tokens: List,
+            sampling: int,
+            ignore_eos: bool = True,
+    ):
+        while True:
+            top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
+            if (not ignore_eos) or (self.speech_token_size not in top_ids):
+                break
+        return top_ids
+
+    @torch.inference_mode()
+    def inference(
+            self,
+            text: torch.Tensor,
+            text_len: torch.Tensor,
+            prompt_text: torch.Tensor,
+            prompt_text_len: torch.Tensor,
+            prompt_speech_token: torch.Tensor,
+            prompt_speech_token_len: torch.Tensor,
+            embedding: torch.Tensor,
+            sampling: int = 25,
+            max_token_text_ratio: float = 20,
+            min_token_text_ratio: float = 2,
+    ) -> Generator[torch.Tensor, None, None]:
+        device = text.device
+        text = torch.concat([prompt_text, text], dim=1)
+        text_len += prompt_text_len
+        text = self.llm.model.model.embed_tokens(text)
+
+        # 2. encode embedding
+        embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
+
+        # 3. concat llm_input
+        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
+        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
+        if prompt_speech_token_len != 0:
+            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
+        else:
+            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
+        lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
+
+        # 4. cal min/max_length
+        min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
+        max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
+
+        # 5. step by step decode
+        out_tokens = []
+        cache = None
+        for i in range(max_len):
+            y_pred, cache = self.llm.forward_one_step(lm_input,
+                                                      masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
+                                                      cache=cache)
+            logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
+            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
+            if top_ids == self.speech_token_size:
+                break
+            if top_ids > self.speech_token_size:
+                continue
+            # in stream mode, yield token one by one
+            yield top_ids
+            out_tokens.append(top_ids)
+            lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
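
Editorial note: both `TransformerLM.inference` and the new `Qwen2LM.inference` are generators that now yield plain integer speech-token ids one at a time, so downstream stages can start before decoding finishes. A sketch of how a caller consumes the stream (`llm` and the input tensors are placeholders, not package code):

speech_tokens = []
for token_id in llm.inference(text=text, text_len=text_len,
                              prompt_text=prompt_text, prompt_text_len=prompt_text_len,
                              prompt_speech_token=prompt_speech_token,
                              prompt_speech_token_len=prompt_speech_token_len,
                              embedding=embedding):
    # Each id can be forwarded to the flow-matching / vocoder stage immediately.
    speech_tokens.append(token_id)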