xinference 0.16.3__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic.

Files changed (373)
  1. xinference/_compat.py +24 -2
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +219 -77
  4. xinference/client/restful/restful_client.py +47 -2
  5. xinference/constants.py +1 -0
  6. xinference/core/chat_interface.py +6 -1
  7. xinference/core/model.py +124 -34
  8. xinference/core/supervisor.py +180 -12
  9. xinference/core/utils.py +73 -4
  10. xinference/core/worker.py +102 -4
  11. xinference/deploy/cmdline.py +3 -1
  12. xinference/deploy/test/test_cmdline.py +56 -0
  13. xinference/isolation.py +24 -0
  14. xinference/model/audio/__init__.py +12 -0
  15. xinference/model/audio/core.py +37 -4
  16. xinference/model/audio/cosyvoice.py +39 -6
  17. xinference/model/audio/f5tts.py +200 -0
  18. xinference/model/audio/f5tts_mlx.py +260 -0
  19. xinference/model/audio/fish_speech.py +70 -110
  20. xinference/model/audio/melotts.py +110 -0
  21. xinference/model/audio/model_spec.json +179 -3
  22. xinference/model/audio/model_spec_modelscope.json +27 -0
  23. xinference/model/audio/utils.py +32 -0
  24. xinference/model/audio/whisper.py +35 -10
  25. xinference/model/audio/whisper_mlx.py +208 -0
  26. xinference/model/embedding/core.py +322 -6
  27. xinference/model/embedding/model_spec.json +8 -1
  28. xinference/model/embedding/model_spec_modelscope.json +9 -1
  29. xinference/model/image/core.py +69 -1
  30. xinference/model/image/model_spec.json +145 -4
  31. xinference/model/image/model_spec_modelscope.json +150 -4
  32. xinference/model/image/stable_diffusion/core.py +50 -15
  33. xinference/model/llm/__init__.py +6 -2
  34. xinference/model/llm/llm_family.json +1055 -93
  35. xinference/model/llm/llm_family.py +15 -36
  36. xinference/model/llm/llm_family_modelscope.json +1031 -78
  37. xinference/model/llm/memory.py +1 -1
  38. xinference/model/llm/mlx/core.py +285 -47
  39. xinference/model/llm/sglang/core.py +2 -0
  40. xinference/model/llm/transformers/chatglm.py +9 -5
  41. xinference/model/llm/transformers/cogagent.py +272 -0
  42. xinference/model/llm/transformers/core.py +3 -0
  43. xinference/model/llm/transformers/glm_edge_v.py +230 -0
  44. xinference/model/llm/transformers/qwen2_vl.py +12 -1
  45. xinference/model/llm/transformers/utils.py +16 -8
  46. xinference/model/llm/utils.py +55 -4
  47. xinference/model/llm/vllm/core.py +137 -12
  48. xinference/model/llm/vllm/xavier/__init__.py +13 -0
  49. xinference/model/llm/vllm/xavier/allocator.py +74 -0
  50. xinference/model/llm/vllm/xavier/block.py +111 -0
  51. xinference/model/llm/vllm/xavier/block_manager.py +71 -0
  52. xinference/model/llm/vllm/xavier/block_tracker.py +129 -0
  53. xinference/model/llm/vllm/xavier/collective.py +74 -0
  54. xinference/model/llm/vllm/xavier/collective_manager.py +147 -0
  55. xinference/model/llm/vllm/xavier/engine.py +247 -0
  56. xinference/model/llm/vllm/xavier/executor.py +134 -0
  57. xinference/model/llm/vllm/xavier/scheduler.py +438 -0
  58. xinference/model/llm/vllm/xavier/test/__init__.py +13 -0
  59. xinference/model/llm/vllm/xavier/test/test_xavier.py +147 -0
  60. xinference/model/llm/vllm/xavier/transfer.py +319 -0
  61. xinference/model/rerank/core.py +11 -4
  62. xinference/model/video/diffusers.py +14 -0
  63. xinference/model/video/model_spec.json +15 -0
  64. xinference/model/video/model_spec_modelscope.json +16 -0
  65. xinference/thirdparty/cosyvoice/bin/average_model.py +92 -0
  66. xinference/thirdparty/cosyvoice/bin/export_jit.py +12 -2
  67. xinference/thirdparty/cosyvoice/bin/export_onnx.py +112 -0
  68. xinference/thirdparty/cosyvoice/bin/export_trt.sh +9 -0
  69. xinference/thirdparty/cosyvoice/bin/inference.py +5 -7
  70. xinference/thirdparty/cosyvoice/bin/spk2info.pt +0 -0
  71. xinference/thirdparty/cosyvoice/bin/train.py +42 -8
  72. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +96 -25
  73. xinference/thirdparty/cosyvoice/cli/frontend.py +77 -30
  74. xinference/thirdparty/cosyvoice/cli/model.py +330 -80
  75. xinference/thirdparty/cosyvoice/dataset/dataset.py +6 -2
  76. xinference/thirdparty/cosyvoice/dataset/processor.py +76 -14
  77. xinference/thirdparty/cosyvoice/flow/decoder.py +92 -13
  78. xinference/thirdparty/cosyvoice/flow/flow.py +99 -9
  79. xinference/thirdparty/cosyvoice/flow/flow_matching.py +110 -13
  80. xinference/thirdparty/cosyvoice/flow/length_regulator.py +5 -4
  81. xinference/thirdparty/cosyvoice/hifigan/discriminator.py +140 -0
  82. xinference/thirdparty/cosyvoice/hifigan/generator.py +58 -42
  83. xinference/thirdparty/cosyvoice/hifigan/hifigan.py +67 -0
  84. xinference/thirdparty/cosyvoice/llm/llm.py +139 -6
  85. xinference/thirdparty/cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
  86. xinference/thirdparty/cosyvoice/tokenizer/tokenizer.py +279 -0
  87. xinference/thirdparty/cosyvoice/transformer/embedding.py +2 -2
  88. xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +7 -7
  89. xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +318 -0
  90. xinference/thirdparty/cosyvoice/utils/common.py +28 -1
  91. xinference/thirdparty/cosyvoice/utils/executor.py +69 -7
  92. xinference/thirdparty/cosyvoice/utils/file_utils.py +2 -12
  93. xinference/thirdparty/cosyvoice/utils/frontend_utils.py +9 -5
  94. xinference/thirdparty/cosyvoice/utils/losses.py +20 -0
  95. xinference/thirdparty/cosyvoice/utils/scheduler.py +1 -2
  96. xinference/thirdparty/cosyvoice/utils/train_utils.py +101 -45
  97. xinference/thirdparty/f5_tts/api.py +166 -0
  98. xinference/thirdparty/f5_tts/configs/E2TTS_Base_train.yaml +44 -0
  99. xinference/thirdparty/f5_tts/configs/E2TTS_Small_train.yaml +44 -0
  100. xinference/thirdparty/f5_tts/configs/F5TTS_Base_train.yaml +46 -0
  101. xinference/thirdparty/f5_tts/configs/F5TTS_Small_train.yaml +46 -0
  102. xinference/thirdparty/f5_tts/eval/README.md +49 -0
  103. xinference/thirdparty/f5_tts/eval/ecapa_tdnn.py +330 -0
  104. xinference/thirdparty/f5_tts/eval/eval_infer_batch.py +207 -0
  105. xinference/thirdparty/f5_tts/eval/eval_infer_batch.sh +13 -0
  106. xinference/thirdparty/f5_tts/eval/eval_librispeech_test_clean.py +84 -0
  107. xinference/thirdparty/f5_tts/eval/eval_seedtts_testset.py +84 -0
  108. xinference/thirdparty/f5_tts/eval/utils_eval.py +405 -0
  109. xinference/thirdparty/f5_tts/infer/README.md +191 -0
  110. xinference/thirdparty/f5_tts/infer/SHARED.md +74 -0
  111. xinference/thirdparty/f5_tts/infer/examples/basic/basic.toml +11 -0
  112. xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_en.wav +0 -0
  113. xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_zh.wav +0 -0
  114. xinference/thirdparty/f5_tts/infer/examples/multi/country.flac +0 -0
  115. xinference/thirdparty/f5_tts/infer/examples/multi/main.flac +0 -0
  116. xinference/thirdparty/f5_tts/infer/examples/multi/story.toml +19 -0
  117. xinference/thirdparty/f5_tts/infer/examples/multi/story.txt +1 -0
  118. xinference/thirdparty/f5_tts/infer/examples/multi/town.flac +0 -0
  119. xinference/thirdparty/f5_tts/infer/examples/vocab.txt +2545 -0
  120. xinference/thirdparty/f5_tts/infer/infer_cli.py +226 -0
  121. xinference/thirdparty/f5_tts/infer/infer_gradio.py +851 -0
  122. xinference/thirdparty/f5_tts/infer/speech_edit.py +193 -0
  123. xinference/thirdparty/f5_tts/infer/utils_infer.py +538 -0
  124. xinference/thirdparty/f5_tts/model/__init__.py +10 -0
  125. xinference/thirdparty/f5_tts/model/backbones/README.md +20 -0
  126. xinference/thirdparty/f5_tts/model/backbones/dit.py +163 -0
  127. xinference/thirdparty/f5_tts/model/backbones/mmdit.py +146 -0
  128. xinference/thirdparty/f5_tts/model/backbones/unett.py +219 -0
  129. xinference/thirdparty/f5_tts/model/cfm.py +285 -0
  130. xinference/thirdparty/f5_tts/model/dataset.py +319 -0
  131. xinference/thirdparty/f5_tts/model/modules.py +658 -0
  132. xinference/thirdparty/f5_tts/model/trainer.py +366 -0
  133. xinference/thirdparty/f5_tts/model/utils.py +185 -0
  134. xinference/thirdparty/f5_tts/scripts/count_max_epoch.py +33 -0
  135. xinference/thirdparty/f5_tts/scripts/count_params_gflops.py +39 -0
  136. xinference/thirdparty/f5_tts/socket_server.py +159 -0
  137. xinference/thirdparty/f5_tts/train/README.md +77 -0
  138. xinference/thirdparty/f5_tts/train/datasets/prepare_csv_wavs.py +139 -0
  139. xinference/thirdparty/f5_tts/train/datasets/prepare_emilia.py +230 -0
  140. xinference/thirdparty/f5_tts/train/datasets/prepare_libritts.py +92 -0
  141. xinference/thirdparty/f5_tts/train/datasets/prepare_ljspeech.py +65 -0
  142. xinference/thirdparty/f5_tts/train/datasets/prepare_wenetspeech4tts.py +125 -0
  143. xinference/thirdparty/f5_tts/train/finetune_cli.py +174 -0
  144. xinference/thirdparty/f5_tts/train/finetune_gradio.py +1846 -0
  145. xinference/thirdparty/f5_tts/train/train.py +75 -0
  146. xinference/thirdparty/fish_speech/fish_speech/conversation.py +266 -1
  147. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +2 -1
  148. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +2 -1
  149. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +2 -2
  150. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ko_KR.json +123 -0
  151. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +2 -1
  152. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +137 -29
  153. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +9 -9
  154. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +1 -1
  155. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +17 -11
  156. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +1 -1
  157. xinference/thirdparty/fish_speech/fish_speech/tokenizer.py +152 -0
  158. xinference/thirdparty/fish_speech/fish_speech/train.py +2 -2
  159. xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +2 -1
  160. xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +22 -0
  161. xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +1 -1
  162. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +2 -2
  163. xinference/thirdparty/fish_speech/tools/{post_api.py → api_client.py} +34 -18
  164. xinference/thirdparty/fish_speech/tools/api_server.py +98 -0
  165. xinference/thirdparty/fish_speech/tools/download_models.py +5 -5
  166. xinference/thirdparty/fish_speech/tools/e2e_webui.py +232 -0
  167. xinference/thirdparty/fish_speech/tools/fish_e2e.py +298 -0
  168. xinference/thirdparty/fish_speech/tools/inference_engine/__init__.py +192 -0
  169. xinference/thirdparty/fish_speech/tools/inference_engine/reference_loader.py +125 -0
  170. xinference/thirdparty/fish_speech/tools/inference_engine/utils.py +39 -0
  171. xinference/thirdparty/fish_speech/tools/inference_engine/vq_manager.py +57 -0
  172. xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +2 -2
  173. xinference/thirdparty/fish_speech/tools/llama/generate.py +484 -72
  174. xinference/thirdparty/fish_speech/tools/run_webui.py +104 -0
  175. xinference/thirdparty/fish_speech/tools/schema.py +170 -0
  176. xinference/thirdparty/fish_speech/tools/server/agent/__init__.py +57 -0
  177. xinference/thirdparty/fish_speech/tools/server/agent/generate.py +119 -0
  178. xinference/thirdparty/fish_speech/tools/server/agent/generation_utils.py +122 -0
  179. xinference/thirdparty/fish_speech/tools/server/agent/pre_generation_utils.py +72 -0
  180. xinference/thirdparty/fish_speech/tools/server/api_utils.py +75 -0
  181. xinference/thirdparty/fish_speech/tools/server/exception_handler.py +27 -0
  182. xinference/thirdparty/fish_speech/tools/server/inference.py +45 -0
  183. xinference/thirdparty/fish_speech/tools/server/model_manager.py +122 -0
  184. xinference/thirdparty/fish_speech/tools/server/model_utils.py +129 -0
  185. xinference/thirdparty/fish_speech/tools/server/views.py +246 -0
  186. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +7 -1
  187. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +2 -3
  188. xinference/thirdparty/fish_speech/tools/webui/__init__.py +173 -0
  189. xinference/thirdparty/fish_speech/tools/webui/inference.py +91 -0
  190. xinference/thirdparty/fish_speech/tools/webui/variables.py +14 -0
  191. xinference/thirdparty/matcha/utils/utils.py +2 -2
  192. xinference/thirdparty/melo/api.py +135 -0
  193. xinference/thirdparty/melo/app.py +61 -0
  194. xinference/thirdparty/melo/attentions.py +459 -0
  195. xinference/thirdparty/melo/commons.py +160 -0
  196. xinference/thirdparty/melo/configs/config.json +94 -0
  197. xinference/thirdparty/melo/data/example/metadata.list +20 -0
  198. xinference/thirdparty/melo/data_utils.py +413 -0
  199. xinference/thirdparty/melo/download_utils.py +67 -0
  200. xinference/thirdparty/melo/infer.py +25 -0
  201. xinference/thirdparty/melo/init_downloads.py +14 -0
  202. xinference/thirdparty/melo/losses.py +58 -0
  203. xinference/thirdparty/melo/main.py +36 -0
  204. xinference/thirdparty/melo/mel_processing.py +174 -0
  205. xinference/thirdparty/melo/models.py +1030 -0
  206. xinference/thirdparty/melo/modules.py +598 -0
  207. xinference/thirdparty/melo/monotonic_align/__init__.py +16 -0
  208. xinference/thirdparty/melo/monotonic_align/core.py +46 -0
  209. xinference/thirdparty/melo/preprocess_text.py +135 -0
  210. xinference/thirdparty/melo/split_utils.py +174 -0
  211. xinference/thirdparty/melo/text/__init__.py +35 -0
  212. xinference/thirdparty/melo/text/chinese.py +199 -0
  213. xinference/thirdparty/melo/text/chinese_bert.py +107 -0
  214. xinference/thirdparty/melo/text/chinese_mix.py +253 -0
  215. xinference/thirdparty/melo/text/cleaner.py +36 -0
  216. xinference/thirdparty/melo/text/cleaner_multiling.py +110 -0
  217. xinference/thirdparty/melo/text/cmudict.rep +129530 -0
  218. xinference/thirdparty/melo/text/cmudict_cache.pickle +0 -0
  219. xinference/thirdparty/melo/text/english.py +284 -0
  220. xinference/thirdparty/melo/text/english_bert.py +39 -0
  221. xinference/thirdparty/melo/text/english_utils/abbreviations.py +35 -0
  222. xinference/thirdparty/melo/text/english_utils/number_norm.py +97 -0
  223. xinference/thirdparty/melo/text/english_utils/time_norm.py +47 -0
  224. xinference/thirdparty/melo/text/es_phonemizer/base.py +140 -0
  225. xinference/thirdparty/melo/text/es_phonemizer/cleaner.py +109 -0
  226. xinference/thirdparty/melo/text/es_phonemizer/es_symbols.json +79 -0
  227. xinference/thirdparty/melo/text/es_phonemizer/es_symbols.txt +1 -0
  228. xinference/thirdparty/melo/text/es_phonemizer/es_symbols_v2.json +83 -0
  229. xinference/thirdparty/melo/text/es_phonemizer/es_to_ipa.py +12 -0
  230. xinference/thirdparty/melo/text/es_phonemizer/example_ipa.txt +400 -0
  231. xinference/thirdparty/melo/text/es_phonemizer/gruut_wrapper.py +253 -0
  232. xinference/thirdparty/melo/text/es_phonemizer/punctuation.py +174 -0
  233. xinference/thirdparty/melo/text/es_phonemizer/spanish_symbols.txt +1 -0
  234. xinference/thirdparty/melo/text/es_phonemizer/test.ipynb +124 -0
  235. xinference/thirdparty/melo/text/fr_phonemizer/base.py +140 -0
  236. xinference/thirdparty/melo/text/fr_phonemizer/cleaner.py +122 -0
  237. xinference/thirdparty/melo/text/fr_phonemizer/en_symbols.json +78 -0
  238. xinference/thirdparty/melo/text/fr_phonemizer/example_ipa.txt +1 -0
  239. xinference/thirdparty/melo/text/fr_phonemizer/fr_symbols.json +89 -0
  240. xinference/thirdparty/melo/text/fr_phonemizer/fr_to_ipa.py +30 -0
  241. xinference/thirdparty/melo/text/fr_phonemizer/french_abbreviations.py +48 -0
  242. xinference/thirdparty/melo/text/fr_phonemizer/french_symbols.txt +1 -0
  243. xinference/thirdparty/melo/text/fr_phonemizer/gruut_wrapper.py +258 -0
  244. xinference/thirdparty/melo/text/fr_phonemizer/punctuation.py +172 -0
  245. xinference/thirdparty/melo/text/french.py +94 -0
  246. xinference/thirdparty/melo/text/french_bert.py +39 -0
  247. xinference/thirdparty/melo/text/japanese.py +647 -0
  248. xinference/thirdparty/melo/text/japanese_bert.py +49 -0
  249. xinference/thirdparty/melo/text/ko_dictionary.py +44 -0
  250. xinference/thirdparty/melo/text/korean.py +192 -0
  251. xinference/thirdparty/melo/text/opencpop-strict.txt +429 -0
  252. xinference/thirdparty/melo/text/spanish.py +122 -0
  253. xinference/thirdparty/melo/text/spanish_bert.py +39 -0
  254. xinference/thirdparty/melo/text/symbols.py +290 -0
  255. xinference/thirdparty/melo/text/tone_sandhi.py +769 -0
  256. xinference/thirdparty/melo/train.py +635 -0
  257. xinference/thirdparty/melo/train.sh +19 -0
  258. xinference/thirdparty/melo/transforms.py +209 -0
  259. xinference/thirdparty/melo/utils.py +424 -0
  260. xinference/types.py +17 -1
  261. xinference/web/ui/build/asset-manifest.json +6 -6
  262. xinference/web/ui/build/index.html +1 -1
  263. xinference/web/ui/build/static/css/main.51a587ff.css +2 -0
  264. xinference/web/ui/build/static/css/main.51a587ff.css.map +1 -0
  265. xinference/web/ui/build/static/js/main.b0936c54.js +3 -0
  266. xinference/web/ui/build/static/js/main.b0936c54.js.map +1 -0
  267. xinference/web/ui/node_modules/.cache/babel-loader/03c4052f1b91f6ba0c5389bdcf49c43319b4076c08e4b8585dab312538ae290a.json +1 -0
  268. xinference/web/ui/node_modules/.cache/babel-loader/1786b83003b8e9605a0f5f855a185d4d16e38fc893dfb326a2a9cca206b4240a.json +1 -0
  269. xinference/web/ui/node_modules/.cache/babel-loader/17cbc181dd674b9150b80c73ed6a82656de0082d857f6e5f66d9716129ac0b38.json +1 -0
  270. xinference/web/ui/node_modules/.cache/babel-loader/185ceb8872d562e032b47e79df6a45670e06345b8ed70aad1a131e0476783c5c.json +1 -0
  271. xinference/web/ui/node_modules/.cache/babel-loader/26b8c9f34b0bed789b3a833767672e39302d1e0c09b4276f4d58d1df7b6bd93b.json +1 -0
  272. xinference/web/ui/node_modules/.cache/babel-loader/2b484da66c724d0d56a40849c109327408796a668b1381511b6e9e03baa48658.json +1 -0
  273. xinference/web/ui/node_modules/.cache/babel-loader/2cbbbce9b84df73330d4c42b82436ed881b3847628f2fbc346aa62e2859fd88c.json +1 -0
  274. xinference/web/ui/node_modules/.cache/babel-loader/2ec9b14431ed33ce6901bf9f27007be4e6e472709c99d6e22b50ce528e4b78ee.json +1 -0
  275. xinference/web/ui/node_modules/.cache/babel-loader/3b966db018f96be4a055d6ca205f0990d4d0b370e2980c17d8bca2c9a021819c.json +1 -0
  276. xinference/web/ui/node_modules/.cache/babel-loader/3eefb411b24c2b3ce053570ef50daccf154022f0e168be5ed0fec21394baf9f4.json +1 -0
  277. xinference/web/ui/node_modules/.cache/babel-loader/522b229e3cac219123f0d69673f5570e191c2d2a505dc65b312d336eae2279c0.json +1 -0
  278. xinference/web/ui/node_modules/.cache/babel-loader/52e45f17ba300580ea3fcc9f9228ccba194bb092b76f25e9255af311f8b05aab.json +1 -0
  279. xinference/web/ui/node_modules/.cache/babel-loader/5a0bc4631f936459afc1a3b1d3ec2420118b1f00e11f60ccac3e08088f3f27a8.json +1 -0
  280. xinference/web/ui/node_modules/.cache/babel-loader/611fa2c6c53b66039991d06dfb0473b5ab37fc63b4564e0f6e1718523768a045.json +1 -0
  281. xinference/web/ui/node_modules/.cache/babel-loader/6329bc76c406fe5eb305412383fbde5950f847bb5e43261f73f37622c365acb4.json +1 -0
  282. xinference/web/ui/node_modules/.cache/babel-loader/63c8e07687ea53a4f8a910ee5e42e0eb26cd1acbfbe820f3e3248a786ee51401.json +1 -0
  283. xinference/web/ui/node_modules/.cache/babel-loader/69b2d5001684174ec9da57e07914eed3eac4960018bceb6cbfa801d861301d7c.json +1 -0
  284. xinference/web/ui/node_modules/.cache/babel-loader/710c1acda69e561e30a933b98c6a56d50197868b15c21e2aad55ab6d46649eb6.json +1 -0
  285. xinference/web/ui/node_modules/.cache/babel-loader/720deca1fce5a1dc5056048fa8258fd138a82ea855f350b6613f104a73fb761f.json +1 -0
  286. xinference/web/ui/node_modules/.cache/babel-loader/76a23b92d26a499c57e61eea2b895fbc9771bd0849a72e66f8e633192017978b.json +1 -0
  287. xinference/web/ui/node_modules/.cache/babel-loader/858063f23b34dfe600254eb5afd85518b0002ec4b30b7386616c45600826e3b2.json +1 -0
  288. xinference/web/ui/node_modules/.cache/babel-loader/920b82c1c89124cf217109eeedbfcd3aae3b917be50c9dfb6bbb4ce26bdfd2e7.json +1 -0
  289. xinference/web/ui/node_modules/.cache/babel-loader/94d8b7aeb0076f2ce07db598cea0e87b13bc8d5614eb530b8d6e696c2daf6f88.json +1 -0
  290. xinference/web/ui/node_modules/.cache/babel-loader/9e917fe7022d01b2ccbe5cc0ce73d70bb72bee584ff293bad71bdff6695dee28.json +1 -0
  291. xinference/web/ui/node_modules/.cache/babel-loader/9f28fdb8399f1d0474f0aca86f1658dc94f5bf0c90f6146352de150692de8862.json +1 -0
  292. xinference/web/ui/node_modules/.cache/babel-loader/a0dfafa06b2bb7cba8cad41c482503f61944f759f4318139362602ef5cc47ccb.json +1 -0
  293. xinference/web/ui/node_modules/.cache/babel-loader/a3ff866acddf34917a7ee399e0e571a4dfd8ba66d5057db885f243e16a6eb17d.json +1 -0
  294. xinference/web/ui/node_modules/.cache/babel-loader/afb8084f539534cd594755ea2205ecd5bd1f62dddcfdf75a2eace59a28131278.json +1 -0
  295. xinference/web/ui/node_modules/.cache/babel-loader/b57b1438b77294c1f3f6cfce12ac487d8106c6f016975ba0aec94d98997e2e1e.json +1 -0
  296. xinference/web/ui/node_modules/.cache/babel-loader/b9917b0bf8e4d55ccbac1c334aa04d6ff3c5b6ed9e5d38b9ea2c687fa7d3f5a9.json +1 -0
  297. xinference/web/ui/node_modules/.cache/babel-loader/bbcc94b0149963d1d6f267ee1f4f03d3925b758392ce2f516c3fe8af0e0169fc.json +1 -0
  298. xinference/web/ui/node_modules/.cache/babel-loader/bdee44abeadc4abc17d41c52eb49c6e19a4b1a267b6e16876ce91bdeeebfc52d.json +1 -0
  299. xinference/web/ui/node_modules/.cache/babel-loader/beb112b70f4a56db95920a9e20efb6c97c37b68450716730217a9ee1a9ae92be.json +1 -0
  300. xinference/web/ui/node_modules/.cache/babel-loader/c88db97be0cdf440193b3995996e83510a04cb00048135485fc0e26d197e80b5.json +1 -0
  301. xinference/web/ui/node_modules/.cache/babel-loader/d49e5314d34310a62d01a03067ce1bec5da00abce84c5196aa9c6842fa79a430.json +1 -0
  302. xinference/web/ui/node_modules/.cache/babel-loader/d7664d18c4ddbad9c3a6a31b91f7c00fb0dde804608674a9860ee50f33e54708.json +1 -0
  303. xinference/web/ui/node_modules/.cache/babel-loader/d9072c318b819b7c90a0f7e9cc0b6413b4dbeb8e9859898e53d75ea882fcde99.json +1 -0
  304. xinference/web/ui/node_modules/.cache/babel-loader/db16a983bc08a05f0439cc61ca0840e49e1d8400eef678909f16c032a418a3d6.json +1 -0
  305. xinference/web/ui/node_modules/.cache/babel-loader/dc249829767b8abcbc3677e0b07b6d3ecbfdfe6d08cfe23a665eb33373a9aa9d.json +1 -0
  306. xinference/web/ui/node_modules/.cache/babel-loader/e242c583c2dbc2784f0fcf513523975f7d5df447e106c1c17e49e8578a6fc3ed.json +1 -0
  307. xinference/web/ui/node_modules/.cache/babel-loader/eac5f1296513e69e4b96f750ddccd4d0264e2bae4e4c449144e83274a48698d9.json +1 -0
  308. xinference/web/ui/node_modules/.cache/babel-loader/ed57202cb79649bb716400436590245547df241988fc7c8e1d85d132299542d2.json +1 -0
  309. xinference/web/ui/node_modules/.cache/babel-loader/f125bf72e773a14cdaebd0c343e80adb909d12e317ee5c00cd4a57442fbe2c62.json +1 -0
  310. xinference/web/ui/node_modules/.cache/babel-loader/f91af913d7f91c410719ab13136aaed3aaf0f8dda06652f25c42cb5231587398.json +1 -0
  311. xinference/web/ui/node_modules/.package-lock.json +67 -3
  312. xinference/web/ui/node_modules/@babel/runtime/package.json +592 -538
  313. xinference/web/ui/node_modules/html-parse-stringify/package.json +50 -0
  314. xinference/web/ui/node_modules/i18next/dist/esm/package.json +1 -0
  315. xinference/web/ui/node_modules/i18next/package.json +129 -0
  316. xinference/web/ui/node_modules/react-i18next/.eslintrc.json +74 -0
  317. xinference/web/ui/node_modules/react-i18next/dist/es/package.json +1 -0
  318. xinference/web/ui/node_modules/react-i18next/package.json +162 -0
  319. xinference/web/ui/node_modules/void-elements/package.json +34 -0
  320. xinference/web/ui/package-lock.json +69 -3
  321. xinference/web/ui/package.json +2 -0
  322. xinference/web/ui/src/locales/en.json +186 -0
  323. xinference/web/ui/src/locales/zh.json +186 -0
  324. {xinference-0.16.3.dist-info → xinference-1.2.1.dist-info}/METADATA +96 -36
  325. {xinference-0.16.3.dist-info → xinference-1.2.1.dist-info}/RECORD +335 -146
  326. {xinference-0.16.3.dist-info → xinference-1.2.1.dist-info}/WHEEL +1 -1
  327. xinference/thirdparty/cosyvoice/bin/export_trt.py +0 -8
  328. xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
  329. xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
  330. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
  331. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
  332. xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
  333. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
  334. xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
  335. xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
  336. xinference/thirdparty/fish_speech/tools/api.py +0 -440
  337. xinference/thirdparty/fish_speech/tools/commons.py +0 -35
  338. xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
  339. xinference/thirdparty/fish_speech/tools/msgpack_api.py +0 -34
  340. xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
  341. xinference/thirdparty/fish_speech/tools/webui.py +0 -485
  342. xinference/web/ui/build/static/css/main.5061c4c3.css +0 -2
  343. xinference/web/ui/build/static/css/main.5061c4c3.css.map +0 -1
  344. xinference/web/ui/build/static/js/main.2f269bb3.js +0 -3
  345. xinference/web/ui/build/static/js/main.2f269bb3.js.map +0 -1
  346. xinference/web/ui/node_modules/.cache/babel-loader/07ce9e632e6aff24d7aa3ad8e48224433bbfeb0d633fca723453f1fcae0c9f1c.json +0 -1
  347. xinference/web/ui/node_modules/.cache/babel-loader/1130403f9e46f5738a23b45ac59b57de8f360c908c713e2c0670c2cce9bd367a.json +0 -1
  348. xinference/web/ui/node_modules/.cache/babel-loader/131091b25d26b17cdca187d7542a21475c211138d900cf667682260e76ef9463.json +0 -1
  349. xinference/web/ui/node_modules/.cache/babel-loader/1f269fb2a368363c1cb2237825f1dba093b6bdd8c44cc05954fd19ec2c1fff03.json +0 -1
  350. xinference/web/ui/node_modules/.cache/babel-loader/331312668fa8bd3d7401818f4a25fa98135d7f61371cd6bfff78b18cf4fbdd92.json +0 -1
  351. xinference/web/ui/node_modules/.cache/babel-loader/40f17338fc75ae095de7d2b4d8eae0d5ca0193a7e2bcece4ee745b22a7a2f4b7.json +0 -1
  352. xinference/web/ui/node_modules/.cache/babel-loader/4de9a6942c5f1749d6cbfdd54279699975f16016b182848bc253886f52ec2ec3.json +0 -1
  353. xinference/web/ui/node_modules/.cache/babel-loader/822586ed1077201b64b954f12f25e3f9b45678c1acbabe53d8af3ca82ca71f33.json +0 -1
  354. xinference/web/ui/node_modules/.cache/babel-loader/8d33354bd2100c8602afc3341f131a88cc36aaeecd5a4b365ed038514708e350.json +0 -1
  355. xinference/web/ui/node_modules/.cache/babel-loader/9375a35b05d56989b2755bf72161fa707c92f28569d33765a75f91a568fda6e9.json +0 -1
  356. xinference/web/ui/node_modules/.cache/babel-loader/a158a9ffa0c9b169aee53dd4a0c44501a596755b4e4f6ede7746d65a72e2a71f.json +0 -1
  357. xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +0 -1
  358. xinference/web/ui/node_modules/.cache/babel-loader/c7bf40bab396765f67d0fed627ed3665890608b2d0edaa3e8cb7cfc96310db45.json +0 -1
  359. xinference/web/ui/node_modules/.cache/babel-loader/d6c643278a0b28320e6f33a60f5fb64c053997cbdc39a60e53ccc574688ade9e.json +0 -1
  360. xinference/web/ui/node_modules/.cache/babel-loader/e42b72d4cc1ea412ebecbb8d040dc6c6bfee462c33903c2f1f3facb602ad742e.json +0 -1
  361. xinference/web/ui/node_modules/.cache/babel-loader/e64b7e8cedcf43d4c95deba60ec1341855c887705805bb62431693118b870c69.json +0 -1
  362. xinference/web/ui/node_modules/.cache/babel-loader/f5039ddbeb815c51491a1989532006b96fc3ae49c6c60e3c097f875b4ae915ae.json +0 -1
  363. xinference/web/ui/node_modules/.cache/babel-loader/f72f011744c4649fabddca6f7a9327861ac0a315a89b1a2e62a39774e7863845.json +0 -1
  364. xinference/web/ui/node_modules/.cache/babel-loader/feabb04b4aa507102da0a64398a40818e878fd1df9b75dda8461b3e1e7ff3f11.json +0 -1
  365. /xinference/thirdparty/{cosyvoice/bin → f5_tts}/__init__.py +0 -0
  366. /xinference/thirdparty/{cosyvoice/flow → melo}/__init__.py +0 -0
  367. /xinference/thirdparty/{cosyvoice/hifigan → melo/text/english_utils}/__init__.py +0 -0
  368. /xinference/thirdparty/{cosyvoice/llm → melo/text/es_phonemizer}/__init__.py +0 -0
  369. /xinference/thirdparty/{fish_speech/fish_speech/configs → melo/text/fr_phonemizer}/__init__.py +0 -0
  370. /xinference/web/ui/build/static/js/{main.2f269bb3.js.LICENSE.txt → main.b0936c54.js.LICENSE.txt} +0 -0
  371. {xinference-0.16.3.dist-info → xinference-1.2.1.dist-info}/LICENSE +0 -0
  372. {xinference-0.16.3.dist-info → xinference-1.2.1.dist-info}/entry_points.txt +0 -0
  373. {xinference-0.16.3.dist-info → xinference-1.2.1.dist-info}/top_level.txt +0 -0
xinference/thirdparty/cosyvoice/dataset/processor.py

@@ -23,7 +23,7 @@ import torch.nn.functional as F
 
 torchaudio.set_audio_backend('soundfile')
 
-AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'])
+AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}
 
 
 def parquet_opener(data, mode='train', tts_data={}):
@@ -40,20 +40,22 @@ def parquet_opener(data, mode='train', tts_data={})
         assert 'src' in sample
         url = sample['src']
         try:
-            df = pq.read_table(url).to_pandas()
-            for i in range(len(df)):
-                if mode == 'inference' and df.loc[i, 'utt'] not in tts_data:
-                    continue
-                sample.update(dict(df.loc[i]))
-                if mode == 'train':
-                    # NOTE do not return sample directly, must initialize a new dict
-                    yield {**sample}
-                else:
-                    for index, text in enumerate(tts_data[df.loc[i, 'utt']]):
-                        yield {**sample, 'tts_index': index, 'tts_text': text}
+            for df in pq.ParquetFile(url).iter_batches(batch_size=64):
+                df = df.to_pandas()
+                for i in range(len(df)):
+                    if mode == 'inference' and df.loc[i, 'utt'] not in tts_data:
+                        continue
+                    sample.update(dict(df.loc[i]))
+                    if mode == 'train':
+                        # NOTE do not return sample directly, must initialize a new dict
+                        yield {**sample}
+                    else:
+                        for index, text in enumerate(tts_data[df.loc[i, 'utt']]):
+                            yield {**sample, 'tts_index': index, 'tts_text': text}
         except Exception as ex:
             logging.warning('Failed to open {}, ex info {}'.format(url, ex))
 
+
 def filter(data,
            max_length=10240,
           min_length=10,
@@ -84,6 +86,7 @@ def filter(data,
     """
     for sample in data:
         sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data']))
+        sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
         del sample['audio_data']
         # sample['wav'] is torch.Tensor, we have 100 frames every second
         num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
@@ -133,6 +136,27 @@ def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'):
         yield sample
 
 
+def truncate(data, truncate_length=24576, mode='train'):
+    """ Truncate data.
+
+    Args:
+        data: Iterable[{key, wav, label, sample_rate}]
+        truncate_length: truncate length
+
+    Returns:
+        Iterable[{key, wav, label, sample_rate}]
+    """
+    for sample in data:
+        waveform = sample['speech']
+        if waveform.shape[1] > truncate_length:
+            start = random.randint(0, waveform.shape[1] - truncate_length)
+            waveform = waveform[:, start: start + truncate_length]
+        else:
+            waveform = torch.concat([waveform, torch.zeros(1, truncate_length - waveform.shape[1])], dim=1)
+        sample['speech'] = waveform
+        yield sample
+
+
 def compute_fbank(data,
                   feat_extractor,
                   mode='train'):
@@ -152,7 +176,27 @@ def compute_fbank(data,
         waveform = sample['speech']
         mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
         sample['speech_feat'] = mat
-        del sample['speech']
+        yield sample
+
+
+def compute_f0(data, pitch_extractor, mode='train'):
+    """ Extract f0
+
+    Args:
+        data: Iterable[{key, wav, label, sample_rate}]
+
+    Returns:
+        Iterable[{key, feat, label}]
+    """
+    for sample in data:
+        assert 'sample_rate' in sample
+        assert 'speech' in sample
+        assert 'utt' in sample
+        assert 'text_token' in sample
+        waveform = sample['speech']
+        mat = pitch_extractor(waveform).transpose(1, 2)
+        mat = F.interpolate(mat, size=sample['speech_feat'].shape[0], mode='linear')
+        sample['pitch_feat'] = mat[0, 0]
         yield sample
 
 
@@ -308,7 +352,7 @@ def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, m
         logging.fatal('Unsupported batch type {}'.format(batch_type))
 
 
-def padding(data, use_spk_embedding, mode='train'):
+def padding(data, use_spk_embedding, mode='train', gan=False):
     """ Padding the data into training data
 
     Args:
@@ -324,6 +368,9 @@ def padding(data, use_spk_embedding, mode='train'):
         order = torch.argsort(speech_feat_len, descending=True)
 
         utts = [sample[i]['utt'] for i in order]
+        speech = [sample[i]['speech'].squeeze(dim=0) for i in order]
+        speech_len = torch.tensor([i.size(0) for i in speech], dtype=torch.int32)
+        speech = pad_sequence(speech, batch_first=True, padding_value=0)
         speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
         speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
         speech_token = pad_sequence(speech_token,
@@ -342,6 +389,8 @@ def padding(data, use_spk_embedding, mode='train'):
         spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
         batch = {
             "utts": utts,
+            "speech": speech,
+            "speech_len": speech_len,
             "speech_token": speech_token,
             "speech_token_len": speech_token_len,
             "speech_feat": speech_feat,
@@ -352,6 +401,19 @@ def padding(data, use_spk_embedding, mode='train'):
             "utt_embedding": utt_embedding,
             "spk_embedding": spk_embedding,
         }
+        if gan is True:
+            # in gan train, we need pitch_feat
+            pitch_feat = [sample[i]['pitch_feat'] for i in order]
+            pitch_feat_len = torch.tensor([i.size(0) for i in pitch_feat], dtype=torch.int32)
+            pitch_feat = pad_sequence(pitch_feat,
+                                      batch_first=True,
+                                      padding_value=0)
+            batch["pitch_feat"] = pitch_feat
+            batch["pitch_feat_len"] = pitch_feat_len
+        else:
+            # only gan train needs speech, delete it to save memory
+            del batch["speech"]
+            del batch["speech_len"]
         if mode == 'inference':
             tts_text = [sample[i]['tts_text'] for i in order]
             tts_index = [sample[i]['tts_index'] for i in order]
xinference/thirdparty/cosyvoice/flow/decoder.py

@@ -13,16 +13,83 @@
 # limitations under the License.
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from einops import pack, rearrange, repeat
+from cosyvoice.utils.common import mask_to_bias
+from cosyvoice.utils.mask import add_optional_chunk_mask
 from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
 from matcha.models.components.transformer import BasicTransformerBlock
 
 
+class Transpose(torch.nn.Module):
+    def __init__(self, dim0: int, dim1: int):
+        super().__init__()
+        self.dim0 = dim0
+        self.dim1 = dim1
+
+    def forward(self, x: torch.Tensor):
+        x = torch.transpose(x, self.dim0, self.dim1)
+        return x
+
+
+class CausalBlock1D(Block1D):
+    def __init__(self, dim: int, dim_out: int):
+        super(CausalBlock1D, self).__init__(dim, dim_out)
+        self.block = torch.nn.Sequential(
+            CausalConv1d(dim, dim_out, 3),
+            Transpose(1, 2),
+            nn.LayerNorm(dim_out),
+            Transpose(1, 2),
+            nn.Mish(),
+        )
+
+    def forward(self, x: torch.Tensor, mask: torch.Tensor):
+        output = self.block(x * mask)
+        return output * mask
+
+
+class CausalResnetBlock1D(ResnetBlock1D):
+    def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
+        super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
+        self.block1 = CausalBlock1D(dim, dim_out)
+        self.block2 = CausalBlock1D(dim_out, dim_out)
+
+
+class CausalConv1d(torch.nn.Conv1d):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        stride: int = 1,
+        dilation: int = 1,
+        groups: int = 1,
+        bias: bool = True,
+        padding_mode: str = 'zeros',
+        device=None,
+        dtype=None
+    ) -> None:
+        super(CausalConv1d, self).__init__(in_channels, out_channels,
+                                           kernel_size, stride,
+                                           padding=0, dilation=dilation,
+                                           groups=groups, bias=bias,
+                                           padding_mode=padding_mode,
+                                           device=device, dtype=dtype)
+        assert stride == 1
+        self.causal_padding = (kernel_size - 1, 0)
+
+    def forward(self, x: torch.Tensor):
+        x = F.pad(x, self.causal_padding)
+        x = super(CausalConv1d, self).forward(x)
+        return x
+
+
 class ConditionalDecoder(nn.Module):
     def __init__(
         self,
         in_channels,
         out_channels,
+        causal=False,
         channels=(256, 256),
         dropout=0.05,
         attention_head_dim=64,
@@ -39,7 +106,7 @@ class ConditionalDecoder(nn.Module):
         channels = tuple(channels)
         self.in_channels = in_channels
         self.out_channels = out_channels
-
+        self.causal = causal
         self.time_embeddings = SinusoidalPosEmb(in_channels)
         time_embed_dim = channels[0] * 4
         self.time_mlp = TimestepEmbedding(
@@ -56,7 +123,8 @@ class ConditionalDecoder(nn.Module):
             input_channel = output_channel
             output_channel = channels[i]
             is_last = i == len(channels) - 1
-            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+            resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
+                ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
             transformer_blocks = nn.ModuleList(
                 [
                     BasicTransformerBlock(
@@ -70,14 +138,16 @@ class ConditionalDecoder(nn.Module):
                 ]
             )
             downsample = (
-                Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
+                Downsample1D(output_channel) if not is_last else
+                CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
             )
             self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
 
-        for i in range(num_mid_blocks):
+        for _ in range(num_mid_blocks):
             input_channel = channels[-1]
             out_channels = channels[-1]
-            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+            resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
+                ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
 
             transformer_blocks = nn.ModuleList(
                 [
@@ -99,7 +169,11 @@ class ConditionalDecoder(nn.Module):
             input_channel = channels[i] * 2
             output_channel = channels[i + 1]
             is_last = i == len(channels) - 2
-            resnet = ResnetBlock1D(
+            resnet = CausalResnetBlock1D(
+                dim=input_channel,
+                dim_out=output_channel,
+                time_emb_dim=time_embed_dim,
+            ) if self.causal else ResnetBlock1D(
                 dim=input_channel,
                 dim_out=output_channel,
                 time_emb_dim=time_embed_dim,
@@ -119,14 +193,13 @@ class ConditionalDecoder(nn.Module):
             upsample = (
                 Upsample1D(output_channel, use_conv_transpose=True)
                 if not is_last
-                else nn.Conv1d(output_channel, output_channel, 3, padding=1)
+                else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
             )
             self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
-        self.final_block = Block1D(channels[-1], channels[-1])
+        self.final_block = CausalBlock1D(channels[-1], channels[-1]) if self.causal else Block1D(channels[-1], channels[-1])
         self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
         self.initialize_weights()
 
-
     def initialize_weights(self):
         for m in self.modules():
             if isinstance(m, nn.Conv1d):
@@ -159,7 +232,7 @@ class ConditionalDecoder(nn.Module):
             _type_: _description_
         """
 
-        t = self.time_embeddings(t)
+        t = self.time_embeddings(t).to(t.dtype)
         t = self.time_mlp(t)
 
         x = pack([x, mu], "b * t")[0]
@@ -176,7 +249,9 @@ class ConditionalDecoder(nn.Module):
             mask_down = masks[-1]
             x = resnet(x, mask_down, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
+            # attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
+            attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
+            attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
             for transformer_block in transformer_blocks:
                 x = transformer_block(
                     hidden_states=x,
@@ -193,7 +268,9 @@ class ConditionalDecoder(nn.Module):
         for resnet, transformer_blocks in self.mid_blocks:
            x = resnet(x, mask_mid, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
-           attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
+           # attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
+           attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
+           attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
@@ -208,7 +285,9 @@ class ConditionalDecoder(nn.Module):
            x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
            x = resnet(x, mask_up, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
-           attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
+           # attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
+           attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
+           attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
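Note: CausalConv1d above keeps the convolution causal by left-padding the input with kernel_size - 1 zeros and running an unpadded Conv1d, so each output frame depends only on current and past frames while the sequence length is preserved. A minimal sketch of that padding scheme (illustrative only, not the class shipped in the package):

import torch
import torch.nn as nn
import torch.nn.functional as F

class LeftPaddedConv1d(nn.Conv1d):
    # Conv1d whose receptive field never looks ahead in time.
    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__(in_channels, out_channels, kernel_size, padding=0)
        self.causal_padding = (kernel_size - 1, 0)  # pad on the left only

    def forward(self, x):
        return super().forward(F.pad(x, self.causal_padding))

x = torch.randn(1, 8, 100)            # (batch, channels, time)
y = LeftPaddedConv1d(8, 8, 3)(x)
assert y.shape[-1] == x.shape[-1]      # time dimension unchanged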
xinference/thirdparty/cosyvoice/flow/flow.py

@@ -33,8 +33,13 @@ class MaskedDiffWithXvec(torch.nn.Module):
                 encoder: torch.nn.Module = None,
                 length_regulator: torch.nn.Module = None,
                 decoder: torch.nn.Module = None,
-                decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1, 'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine', 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}), 'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64, 'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
-                mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050, 'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
+                decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
+                                      'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
+                                                                'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
+                                      'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
+                                                         'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
+                mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
+                                       'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
         super().__init__()
         self.input_size = input_size
         self.output_size = output_size
@@ -104,7 +109,8 @@ class MaskedDiffWithXvec(torch.nn.Module):
                  prompt_token_len,
                  prompt_feat,
                  prompt_feat_len,
-                 embedding):
+                 embedding,
+                 flow_cache):
         assert token.shape[0] == 1
         # xvec projection
         embedding = F.normalize(embedding, dim=1)
@@ -113,23 +119,107 @@ class MaskedDiffWithXvec(torch.nn.Module):
         # concat text and prompt_text
         token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
         token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
-        mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(embedding)
+        mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
         token = self.input_embedding(torch.clamp(token, min=0)) * mask
 
         # text encode
         h, h_lengths = self.encoder(token, token_len)
         h = self.encoder_proj(h)
-        mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / 50 * 22050 / 256)
-        h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2)
+        mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / self.input_frame_rate * 22050 / 256)
+        h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate)
 
         # get conditions
         conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device)
         conds[:, :mel_len1] = prompt_feat
         conds = conds.transpose(1, 2)
 
-        # mask = (~make_pad_mask(feat_len)).to(h)
         mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
-        feat = self.decoder(
+        feat, flow_cache = self.decoder(
+            mu=h.transpose(1, 2).contiguous(),
+            mask=mask.unsqueeze(1),
+            spks=embedding,
+            cond=conds,
+            n_timesteps=10,
+            prompt_len=mel_len1,
+            flow_cache=flow_cache
+        )
+        feat = feat[:, :, mel_len1:]
+        assert feat.shape[2] == mel_len2
+        return feat, flow_cache
+
+
+class CausalMaskedDiffWithXvec(torch.nn.Module):
+    def __init__(self,
+                 input_size: int = 512,
+                 output_size: int = 80,
+                 spk_embed_dim: int = 192,
+                 output_type: str = "mel",
+                 vocab_size: int = 4096,
+                 input_frame_rate: int = 50,
+                 only_mask_loss: bool = True,
+                 token_mel_ratio: int = 2,
+                 pre_lookahead_len: int = 3,
+                 encoder: torch.nn.Module = None,
+                 decoder: torch.nn.Module = None,
+                 decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
+                                       'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
+                                                                 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
+                                       'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
+                                                          'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
+                 mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
+                                        'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
+        super().__init__()
+        self.input_size = input_size
+        self.output_size = output_size
+        self.decoder_conf = decoder_conf
+        self.mel_feat_conf = mel_feat_conf
+        self.vocab_size = vocab_size
+        self.output_type = output_type
+        self.input_frame_rate = input_frame_rate
+        logging.info(f"input frame rate={self.input_frame_rate}")
+        self.input_embedding = nn.Embedding(vocab_size, input_size)
+        self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
+        self.encoder = encoder
+        self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
+        self.decoder = decoder
+        self.only_mask_loss = only_mask_loss
+        self.token_mel_ratio = token_mel_ratio
+        self.pre_lookahead_len = pre_lookahead_len
+
+    @torch.inference_mode()
+    def inference(self,
+                  token,
+                  token_len,
+                  prompt_token,
+                  prompt_token_len,
+                  prompt_feat,
+                  prompt_feat_len,
+                  embedding,
+                  finalize):
+        assert token.shape[0] == 1
+        # xvec projection
+        embedding = F.normalize(embedding, dim=1)
+        embedding = self.spk_embed_affine_layer(embedding)
+
+        # concat text and prompt_text
+        token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
+        mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
+        token = self.input_embedding(torch.clamp(token, min=0)) * mask
+
+        # text encode
+        h, h_lengths = self.encoder(token, token_len)
+        if finalize is False:
+            h = h[:, :-self.pre_lookahead_len * self.token_mel_ratio]
+        mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
+        h = self.encoder_proj(h)
+
+        # get conditions
+        conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device)
+        conds[:, :mel_len1] = prompt_feat
+        conds = conds.transpose(1, 2)
+
+        mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
+        feat, _ = self.decoder(
             mu=h.transpose(1, 2).contiguous(),
             mask=mask.unsqueeze(1),
             spks=embedding,
@@ -138,4 +228,4 @@ class MaskedDiffWithXvec(torch.nn.Module):
         )
         feat = feat[:, :, mel_len1:]
         assert feat.shape[2] == mel_len2
-        return feat
+        return feat, None
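Note: in the inference paths above, the number of mel frames to generate is derived from the number of speech tokens as mel_len2 = int(token_len2 / self.input_frame_rate * 22050 / 256): tokens divided by the token rate give seconds, times the 22050 Hz sampling rate give samples, divided by the 256-sample hop give mel frames. A small worked example (the values are chosen purely for illustration):

token_len2 = 100          # speech tokens
input_frame_rate = 50     # tokens per second
mel_len2 = int(token_len2 / input_frame_rate * 22050 / 256)
# 100 tokens -> 2.0 s of audio -> 44100 samples -> 172 mel frames
print(mel_len2)           # 172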
xinference/thirdparty/cosyvoice/flow/flow_matching.py

@@ -11,10 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import onnxruntime
 import torch
 import torch.nn.functional as F
 from matcha.models.components.flow_matching import BASECFM
 
+
 class ConditionalCFM(BASECFM):
     def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
         super().__init__(
@@ -31,7 +33,7 @@ class ConditionalCFM(BASECFM):
         self.estimator = estimator
 
     @torch.inference_mode()
-    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
+    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)):
         """Forward diffusion
 
         Args:
@@ -49,11 +51,21 @@ class ConditionalCFM(BASECFM):
             sample: generated mel-spectrogram
                 shape: (batch_size, n_feats, mel_timesteps)
         """
+
         z = torch.randn_like(mu) * temperature
-        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
+        cache_size = flow_cache.shape[2]
+        # fix prompt and overlap part mu and z
+        if cache_size != 0:
+            z[:, :, :cache_size] = flow_cache[:, :, :, 0]
+            mu[:, :, :cache_size] = flow_cache[:, :, :, 1]
+        z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -34:]], dim=2)
+        mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], dim=2)
+        flow_cache = torch.stack([z_cache, mu_cache], dim=-1)
+
+        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
         if self.t_scheduler == 'cosine':
             t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
-        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)
+        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), flow_cache
 
     def solve_euler(self, x, t_span, mu, mask, spks, cond):
         """
@@ -71,30 +83,80 @@ class ConditionalCFM(BASECFM):
             cond: Not used but kept for future purposes
         """
         t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
+        t = t.unsqueeze(dim=0)
 
         # I am storing this because I can later plot it by putting a debugger here and saving it to a file
         # Or in future might add like a return_all_steps flag
         sol = []
 
+        if self.inference_cfg_rate > 0:
+            # Do not use concat, it may cause memory format changed and trt infer with wrong results!
+            x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
+            mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
+            mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
+            t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
+            spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
+            cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
+        else:
+            x_in, mask_in, mu_in, t_in, spks_in, cond_in = x, mask, mu, t, spks, cond
         for step in range(1, len(t_span)):
-            dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
             # Classifier-Free Guidance inference introduced in VoiceBox
             if self.inference_cfg_rate > 0:
-                cfg_dphi_dt = self.estimator(
-                    x, mask,
-                    torch.zeros_like(mu), t,
-                    torch.zeros_like(spks) if spks is not None else None,
-                    torch.zeros_like(cond)
-                )
-                dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt -
-                           self.inference_cfg_rate * cfg_dphi_dt)
+                x_in[:] = x
+                mask_in[:] = mask
+                mu_in[0] = mu
+                t_in[:] = t.unsqueeze(0)
+                spks_in[0] = spks
+                cond_in[0] = cond
+            else:
+                x_in, mask_in, mu_in, t_in, spks_in, cond_in = x, mask, mu, t, spks, cond
+            dphi_dt = self.forward_estimator(
+                x_in, mask_in,
+                mu_in, t_in,
+                spks_in,
+                cond_in
+            )
+            if self.inference_cfg_rate > 0:
+                dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
+                dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
             x = x + dt * dphi_dt
             t = t + dt
            sol.append(x)
            if step < len(t_span) - 1:
                dt = t_span[step + 1] - t
 
-        return sol[-1]
+        return sol[-1].float()
+
+    def forward_estimator(self, x, mask, mu, t, spks, cond):
+        if isinstance(self.estimator, torch.nn.Module):
+            return self.estimator.forward(x, mask, mu, t, spks, cond)
+        elif isinstance(self.estimator, onnxruntime.InferenceSession):
+            ort_inputs = {
+                'x': x.cpu().numpy(),
+                'mask': mask.cpu().numpy(),
+                'mu': mu.cpu().numpy(),
+                't': t.cpu().numpy(),
+                'spks': spks.cpu().numpy(),
+                'cond': cond.cpu().numpy()
+            }
+            output = self.estimator.run(None, ort_inputs)[0]
+            return torch.tensor(output, dtype=x.dtype, device=x.device)
+        else:
+            self.estimator.set_input_shape('x', (2, 80, x.size(2)))
+            self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
+            self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
+            self.estimator.set_input_shape('t', (2,))
+            self.estimator.set_input_shape('spks', (2, 80))
+            self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
+            # run trt engine
+            self.estimator.execute_v2([x.contiguous().data_ptr(),
+                                       mask.contiguous().data_ptr(),
+                                       mu.contiguous().data_ptr(),
+                                       t.contiguous().data_ptr(),
+                                       spks.contiguous().data_ptr(),
+                                       cond.contiguous().data_ptr(),
+                                       x.data_ptr()])
+            return x
 
     def compute_loss(self, x1, mask, mu, spks=None, cond=None):
         """Computes diffusion loss
@@ -136,3 +198,38 @@ class ConditionalCFM(BASECFM):
         pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
         loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
         return loss, y
+
+
+class CausalConditionalCFM(ConditionalCFM):
+    def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
+        super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator)
+        self.rand_noise = torch.randn([1, 80, 50 * 300])
+
+    @torch.inference_mode()
+    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
+        """Forward diffusion
+
+        Args:
+            mu (torch.Tensor): output of encoder
+                shape: (batch_size, n_feats, mel_timesteps)
+            mask (torch.Tensor): output_mask
+                shape: (batch_size, 1, mel_timesteps)
+            n_timesteps (int): number of diffusion steps
+            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
+            spks (torch.Tensor, optional): speaker ids. Defaults to None.
+                shape: (batch_size, spk_emb_dim)
+            cond: Not used but kept for future purposes
+
+        Returns:
+            sample: generated mel-spectrogram
+                shape: (batch_size, n_feats, mel_timesteps)
+        """
+
+        z = self.rand_noise[:, :, :mu.size(2)].to(mu.device) * temperature
+        if self.fp16 is True:
+            z = z.half()
+        # fix prompt and overlap part mu and z
+        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
+        if self.t_scheduler == 'cosine':
+            t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
+        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None
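Note: solve_euler above now evaluates the conditional and unconditional (classifier-free guidance) estimator passes as a single batch of two, filling pre-allocated buffers rather than concatenating (per the in-code comment about TensorRT), then splits the output and blends the halves. A minimal sketch of the blend step, with a generic estimator callable standing in for the real one (the helper name and signature are illustrative, not part of the package):

import torch

def cfg_blend(estimator, x_in, mask_in, mu_in, t_in, spks_in, cond_in, cfg_rate, batch_size):
    # Row block 0 holds the conditional branch, row block 1 the unconditional
    # branch (zeroed mu/spks/cond); one call evaluates both.
    out = estimator(x_in, mask_in, mu_in, t_in, spks_in, cond_in)
    cond_out, uncond_out = torch.split(out, [batch_size, batch_size], dim=0)
    # Classifier-free guidance: push the conditional prediction away from the unconditional one.
    return (1.0 + cfg_rate) * cond_out - cfg_rate * uncond_out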