xinference 0.16.3__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (373)
  1. xinference/_compat.py +24 -2
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +219 -77
  4. xinference/client/restful/restful_client.py +47 -2
  5. xinference/constants.py +1 -0
  6. xinference/core/chat_interface.py +6 -1
  7. xinference/core/model.py +124 -34
  8. xinference/core/supervisor.py +180 -12
  9. xinference/core/utils.py +73 -4
  10. xinference/core/worker.py +102 -4
  11. xinference/deploy/cmdline.py +3 -1
  12. xinference/deploy/test/test_cmdline.py +56 -0
  13. xinference/isolation.py +24 -0
  14. xinference/model/audio/__init__.py +12 -0
  15. xinference/model/audio/core.py +37 -4
  16. xinference/model/audio/cosyvoice.py +39 -6
  17. xinference/model/audio/f5tts.py +200 -0
  18. xinference/model/audio/f5tts_mlx.py +260 -0
  19. xinference/model/audio/fish_speech.py +70 -110
  20. xinference/model/audio/melotts.py +110 -0
  21. xinference/model/audio/model_spec.json +179 -3
  22. xinference/model/audio/model_spec_modelscope.json +27 -0
  23. xinference/model/audio/utils.py +32 -0
  24. xinference/model/audio/whisper.py +35 -10
  25. xinference/model/audio/whisper_mlx.py +208 -0
  26. xinference/model/embedding/core.py +322 -6
  27. xinference/model/embedding/model_spec.json +8 -1
  28. xinference/model/embedding/model_spec_modelscope.json +9 -1
  29. xinference/model/image/core.py +69 -1
  30. xinference/model/image/model_spec.json +145 -4
  31. xinference/model/image/model_spec_modelscope.json +150 -4
  32. xinference/model/image/stable_diffusion/core.py +50 -15
  33. xinference/model/llm/__init__.py +6 -2
  34. xinference/model/llm/llm_family.json +1055 -93
  35. xinference/model/llm/llm_family.py +15 -36
  36. xinference/model/llm/llm_family_modelscope.json +1031 -78
  37. xinference/model/llm/memory.py +1 -1
  38. xinference/model/llm/mlx/core.py +285 -47
  39. xinference/model/llm/sglang/core.py +2 -0
  40. xinference/model/llm/transformers/chatglm.py +9 -5
  41. xinference/model/llm/transformers/cogagent.py +272 -0
  42. xinference/model/llm/transformers/core.py +3 -0
  43. xinference/model/llm/transformers/glm_edge_v.py +230 -0
  44. xinference/model/llm/transformers/qwen2_vl.py +12 -1
  45. xinference/model/llm/transformers/utils.py +16 -8
  46. xinference/model/llm/utils.py +55 -4
  47. xinference/model/llm/vllm/core.py +137 -12
  48. xinference/model/llm/vllm/xavier/__init__.py +13 -0
  49. xinference/model/llm/vllm/xavier/allocator.py +74 -0
  50. xinference/model/llm/vllm/xavier/block.py +111 -0
  51. xinference/model/llm/vllm/xavier/block_manager.py +71 -0
  52. xinference/model/llm/vllm/xavier/block_tracker.py +129 -0
  53. xinference/model/llm/vllm/xavier/collective.py +74 -0
  54. xinference/model/llm/vllm/xavier/collective_manager.py +147 -0
  55. xinference/model/llm/vllm/xavier/engine.py +247 -0
  56. xinference/model/llm/vllm/xavier/executor.py +134 -0
  57. xinference/model/llm/vllm/xavier/scheduler.py +438 -0
  58. xinference/model/llm/vllm/xavier/test/__init__.py +13 -0
  59. xinference/model/llm/vllm/xavier/test/test_xavier.py +147 -0
  60. xinference/model/llm/vllm/xavier/transfer.py +319 -0
  61. xinference/model/rerank/core.py +11 -4
  62. xinference/model/video/diffusers.py +14 -0
  63. xinference/model/video/model_spec.json +15 -0
  64. xinference/model/video/model_spec_modelscope.json +16 -0
  65. xinference/thirdparty/cosyvoice/bin/average_model.py +92 -0
  66. xinference/thirdparty/cosyvoice/bin/export_jit.py +12 -2
  67. xinference/thirdparty/cosyvoice/bin/export_onnx.py +112 -0
  68. xinference/thirdparty/cosyvoice/bin/export_trt.sh +9 -0
  69. xinference/thirdparty/cosyvoice/bin/inference.py +5 -7
  70. xinference/thirdparty/cosyvoice/bin/spk2info.pt +0 -0
  71. xinference/thirdparty/cosyvoice/bin/train.py +42 -8
  72. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +96 -25
  73. xinference/thirdparty/cosyvoice/cli/frontend.py +77 -30
  74. xinference/thirdparty/cosyvoice/cli/model.py +330 -80
  75. xinference/thirdparty/cosyvoice/dataset/dataset.py +6 -2
  76. xinference/thirdparty/cosyvoice/dataset/processor.py +76 -14
  77. xinference/thirdparty/cosyvoice/flow/decoder.py +92 -13
  78. xinference/thirdparty/cosyvoice/flow/flow.py +99 -9
  79. xinference/thirdparty/cosyvoice/flow/flow_matching.py +110 -13
  80. xinference/thirdparty/cosyvoice/flow/length_regulator.py +5 -4
  81. xinference/thirdparty/cosyvoice/hifigan/discriminator.py +140 -0
  82. xinference/thirdparty/cosyvoice/hifigan/generator.py +58 -42
  83. xinference/thirdparty/cosyvoice/hifigan/hifigan.py +67 -0
  84. xinference/thirdparty/cosyvoice/llm/llm.py +139 -6
  85. xinference/thirdparty/cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
  86. xinference/thirdparty/cosyvoice/tokenizer/tokenizer.py +279 -0
  87. xinference/thirdparty/cosyvoice/transformer/embedding.py +2 -2
  88. xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +7 -7
  89. xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +318 -0
  90. xinference/thirdparty/cosyvoice/utils/common.py +28 -1
  91. xinference/thirdparty/cosyvoice/utils/executor.py +69 -7
  92. xinference/thirdparty/cosyvoice/utils/file_utils.py +2 -12
  93. xinference/thirdparty/cosyvoice/utils/frontend_utils.py +9 -5
  94. xinference/thirdparty/cosyvoice/utils/losses.py +20 -0
  95. xinference/thirdparty/cosyvoice/utils/scheduler.py +1 -2
  96. xinference/thirdparty/cosyvoice/utils/train_utils.py +101 -45
  97. xinference/thirdparty/f5_tts/api.py +166 -0
  98. xinference/thirdparty/f5_tts/configs/E2TTS_Base_train.yaml +44 -0
  99. xinference/thirdparty/f5_tts/configs/E2TTS_Small_train.yaml +44 -0
  100. xinference/thirdparty/f5_tts/configs/F5TTS_Base_train.yaml +46 -0
  101. xinference/thirdparty/f5_tts/configs/F5TTS_Small_train.yaml +46 -0
  102. xinference/thirdparty/f5_tts/eval/README.md +49 -0
  103. xinference/thirdparty/f5_tts/eval/ecapa_tdnn.py +330 -0
  104. xinference/thirdparty/f5_tts/eval/eval_infer_batch.py +207 -0
  105. xinference/thirdparty/f5_tts/eval/eval_infer_batch.sh +13 -0
  106. xinference/thirdparty/f5_tts/eval/eval_librispeech_test_clean.py +84 -0
  107. xinference/thirdparty/f5_tts/eval/eval_seedtts_testset.py +84 -0
  108. xinference/thirdparty/f5_tts/eval/utils_eval.py +405 -0
  109. xinference/thirdparty/f5_tts/infer/README.md +191 -0
  110. xinference/thirdparty/f5_tts/infer/SHARED.md +74 -0
  111. xinference/thirdparty/f5_tts/infer/examples/basic/basic.toml +11 -0
  112. xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_en.wav +0 -0
  113. xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_zh.wav +0 -0
  114. xinference/thirdparty/f5_tts/infer/examples/multi/country.flac +0 -0
  115. xinference/thirdparty/f5_tts/infer/examples/multi/main.flac +0 -0
  116. xinference/thirdparty/f5_tts/infer/examples/multi/story.toml +19 -0
  117. xinference/thirdparty/f5_tts/infer/examples/multi/story.txt +1 -0
  118. xinference/thirdparty/f5_tts/infer/examples/multi/town.flac +0 -0
  119. xinference/thirdparty/f5_tts/infer/examples/vocab.txt +2545 -0
  120. xinference/thirdparty/f5_tts/infer/infer_cli.py +226 -0
  121. xinference/thirdparty/f5_tts/infer/infer_gradio.py +851 -0
  122. xinference/thirdparty/f5_tts/infer/speech_edit.py +193 -0
  123. xinference/thirdparty/f5_tts/infer/utils_infer.py +538 -0
  124. xinference/thirdparty/f5_tts/model/__init__.py +10 -0
  125. xinference/thirdparty/f5_tts/model/backbones/README.md +20 -0
  126. xinference/thirdparty/f5_tts/model/backbones/dit.py +163 -0
  127. xinference/thirdparty/f5_tts/model/backbones/mmdit.py +146 -0
  128. xinference/thirdparty/f5_tts/model/backbones/unett.py +219 -0
  129. xinference/thirdparty/f5_tts/model/cfm.py +285 -0
  130. xinference/thirdparty/f5_tts/model/dataset.py +319 -0
  131. xinference/thirdparty/f5_tts/model/modules.py +658 -0
  132. xinference/thirdparty/f5_tts/model/trainer.py +366 -0
  133. xinference/thirdparty/f5_tts/model/utils.py +185 -0
  134. xinference/thirdparty/f5_tts/scripts/count_max_epoch.py +33 -0
  135. xinference/thirdparty/f5_tts/scripts/count_params_gflops.py +39 -0
  136. xinference/thirdparty/f5_tts/socket_server.py +159 -0
  137. xinference/thirdparty/f5_tts/train/README.md +77 -0
  138. xinference/thirdparty/f5_tts/train/datasets/prepare_csv_wavs.py +139 -0
  139. xinference/thirdparty/f5_tts/train/datasets/prepare_emilia.py +230 -0
  140. xinference/thirdparty/f5_tts/train/datasets/prepare_libritts.py +92 -0
  141. xinference/thirdparty/f5_tts/train/datasets/prepare_ljspeech.py +65 -0
  142. xinference/thirdparty/f5_tts/train/datasets/prepare_wenetspeech4tts.py +125 -0
  143. xinference/thirdparty/f5_tts/train/finetune_cli.py +174 -0
  144. xinference/thirdparty/f5_tts/train/finetune_gradio.py +1846 -0
  145. xinference/thirdparty/f5_tts/train/train.py +75 -0
  146. xinference/thirdparty/fish_speech/fish_speech/conversation.py +266 -1
  147. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +2 -1
  148. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +2 -1
  149. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +2 -2
  150. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ko_KR.json +123 -0
  151. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +2 -1
  152. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +137 -29
  153. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +9 -9
  154. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +1 -1
  155. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +17 -11
  156. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +1 -1
  157. xinference/thirdparty/fish_speech/fish_speech/tokenizer.py +152 -0
  158. xinference/thirdparty/fish_speech/fish_speech/train.py +2 -2
  159. xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +2 -1
  160. xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +22 -0
  161. xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +1 -1
  162. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +2 -2
  163. xinference/thirdparty/fish_speech/tools/{post_api.py → api_client.py} +34 -18
  164. xinference/thirdparty/fish_speech/tools/api_server.py +98 -0
  165. xinference/thirdparty/fish_speech/tools/download_models.py +5 -5
  166. xinference/thirdparty/fish_speech/tools/e2e_webui.py +232 -0
  167. xinference/thirdparty/fish_speech/tools/fish_e2e.py +298 -0
  168. xinference/thirdparty/fish_speech/tools/inference_engine/__init__.py +192 -0
  169. xinference/thirdparty/fish_speech/tools/inference_engine/reference_loader.py +125 -0
  170. xinference/thirdparty/fish_speech/tools/inference_engine/utils.py +39 -0
  171. xinference/thirdparty/fish_speech/tools/inference_engine/vq_manager.py +57 -0
  172. xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +2 -2
  173. xinference/thirdparty/fish_speech/tools/llama/generate.py +484 -72
  174. xinference/thirdparty/fish_speech/tools/run_webui.py +104 -0
  175. xinference/thirdparty/fish_speech/tools/schema.py +170 -0
  176. xinference/thirdparty/fish_speech/tools/server/agent/__init__.py +57 -0
  177. xinference/thirdparty/fish_speech/tools/server/agent/generate.py +119 -0
  178. xinference/thirdparty/fish_speech/tools/server/agent/generation_utils.py +122 -0
  179. xinference/thirdparty/fish_speech/tools/server/agent/pre_generation_utils.py +72 -0
  180. xinference/thirdparty/fish_speech/tools/server/api_utils.py +75 -0
  181. xinference/thirdparty/fish_speech/tools/server/exception_handler.py +27 -0
  182. xinference/thirdparty/fish_speech/tools/server/inference.py +45 -0
  183. xinference/thirdparty/fish_speech/tools/server/model_manager.py +122 -0
  184. xinference/thirdparty/fish_speech/tools/server/model_utils.py +129 -0
  185. xinference/thirdparty/fish_speech/tools/server/views.py +246 -0
  186. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +7 -1
  187. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +2 -3
  188. xinference/thirdparty/fish_speech/tools/webui/__init__.py +173 -0
  189. xinference/thirdparty/fish_speech/tools/webui/inference.py +91 -0
  190. xinference/thirdparty/fish_speech/tools/webui/variables.py +14 -0
  191. xinference/thirdparty/matcha/utils/utils.py +2 -2
  192. xinference/thirdparty/melo/api.py +135 -0
  193. xinference/thirdparty/melo/app.py +61 -0
  194. xinference/thirdparty/melo/attentions.py +459 -0
  195. xinference/thirdparty/melo/commons.py +160 -0
  196. xinference/thirdparty/melo/configs/config.json +94 -0
  197. xinference/thirdparty/melo/data/example/metadata.list +20 -0
  198. xinference/thirdparty/melo/data_utils.py +413 -0
  199. xinference/thirdparty/melo/download_utils.py +67 -0
  200. xinference/thirdparty/melo/infer.py +25 -0
  201. xinference/thirdparty/melo/init_downloads.py +14 -0
  202. xinference/thirdparty/melo/losses.py +58 -0
  203. xinference/thirdparty/melo/main.py +36 -0
  204. xinference/thirdparty/melo/mel_processing.py +174 -0
  205. xinference/thirdparty/melo/models.py +1030 -0
  206. xinference/thirdparty/melo/modules.py +598 -0
  207. xinference/thirdparty/melo/monotonic_align/__init__.py +16 -0
  208. xinference/thirdparty/melo/monotonic_align/core.py +46 -0
  209. xinference/thirdparty/melo/preprocess_text.py +135 -0
  210. xinference/thirdparty/melo/split_utils.py +174 -0
  211. xinference/thirdparty/melo/text/__init__.py +35 -0
  212. xinference/thirdparty/melo/text/chinese.py +199 -0
  213. xinference/thirdparty/melo/text/chinese_bert.py +107 -0
  214. xinference/thirdparty/melo/text/chinese_mix.py +253 -0
  215. xinference/thirdparty/melo/text/cleaner.py +36 -0
  216. xinference/thirdparty/melo/text/cleaner_multiling.py +110 -0
  217. xinference/thirdparty/melo/text/cmudict.rep +129530 -0
  218. xinference/thirdparty/melo/text/cmudict_cache.pickle +0 -0
  219. xinference/thirdparty/melo/text/english.py +284 -0
  220. xinference/thirdparty/melo/text/english_bert.py +39 -0
  221. xinference/thirdparty/melo/text/english_utils/abbreviations.py +35 -0
  222. xinference/thirdparty/melo/text/english_utils/number_norm.py +97 -0
  223. xinference/thirdparty/melo/text/english_utils/time_norm.py +47 -0
  224. xinference/thirdparty/melo/text/es_phonemizer/base.py +140 -0
  225. xinference/thirdparty/melo/text/es_phonemizer/cleaner.py +109 -0
  226. xinference/thirdparty/melo/text/es_phonemizer/es_symbols.json +79 -0
  227. xinference/thirdparty/melo/text/es_phonemizer/es_symbols.txt +1 -0
  228. xinference/thirdparty/melo/text/es_phonemizer/es_symbols_v2.json +83 -0
  229. xinference/thirdparty/melo/text/es_phonemizer/es_to_ipa.py +12 -0
  230. xinference/thirdparty/melo/text/es_phonemizer/example_ipa.txt +400 -0
  231. xinference/thirdparty/melo/text/es_phonemizer/gruut_wrapper.py +253 -0
  232. xinference/thirdparty/melo/text/es_phonemizer/punctuation.py +174 -0
  233. xinference/thirdparty/melo/text/es_phonemizer/spanish_symbols.txt +1 -0
  234. xinference/thirdparty/melo/text/es_phonemizer/test.ipynb +124 -0
  235. xinference/thirdparty/melo/text/fr_phonemizer/base.py +140 -0
  236. xinference/thirdparty/melo/text/fr_phonemizer/cleaner.py +122 -0
  237. xinference/thirdparty/melo/text/fr_phonemizer/en_symbols.json +78 -0
  238. xinference/thirdparty/melo/text/fr_phonemizer/example_ipa.txt +1 -0
  239. xinference/thirdparty/melo/text/fr_phonemizer/fr_symbols.json +89 -0
  240. xinference/thirdparty/melo/text/fr_phonemizer/fr_to_ipa.py +30 -0
  241. xinference/thirdparty/melo/text/fr_phonemizer/french_abbreviations.py +48 -0
  242. xinference/thirdparty/melo/text/fr_phonemizer/french_symbols.txt +1 -0
  243. xinference/thirdparty/melo/text/fr_phonemizer/gruut_wrapper.py +258 -0
  244. xinference/thirdparty/melo/text/fr_phonemizer/punctuation.py +172 -0
  245. xinference/thirdparty/melo/text/french.py +94 -0
  246. xinference/thirdparty/melo/text/french_bert.py +39 -0
  247. xinference/thirdparty/melo/text/japanese.py +647 -0
  248. xinference/thirdparty/melo/text/japanese_bert.py +49 -0
  249. xinference/thirdparty/melo/text/ko_dictionary.py +44 -0
  250. xinference/thirdparty/melo/text/korean.py +192 -0
  251. xinference/thirdparty/melo/text/opencpop-strict.txt +429 -0
  252. xinference/thirdparty/melo/text/spanish.py +122 -0
  253. xinference/thirdparty/melo/text/spanish_bert.py +39 -0
  254. xinference/thirdparty/melo/text/symbols.py +290 -0
  255. xinference/thirdparty/melo/text/tone_sandhi.py +769 -0
  256. xinference/thirdparty/melo/train.py +635 -0
  257. xinference/thirdparty/melo/train.sh +19 -0
  258. xinference/thirdparty/melo/transforms.py +209 -0
  259. xinference/thirdparty/melo/utils.py +424 -0
  260. xinference/types.py +17 -1
  261. xinference/web/ui/build/asset-manifest.json +6 -6
  262. xinference/web/ui/build/index.html +1 -1
  263. xinference/web/ui/build/static/css/main.51a587ff.css +2 -0
  264. xinference/web/ui/build/static/css/main.51a587ff.css.map +1 -0
  265. xinference/web/ui/build/static/js/main.b0936c54.js +3 -0
  266. xinference/web/ui/build/static/js/main.b0936c54.js.map +1 -0
  267. xinference/web/ui/node_modules/.cache/babel-loader/03c4052f1b91f6ba0c5389bdcf49c43319b4076c08e4b8585dab312538ae290a.json +1 -0
  268. xinference/web/ui/node_modules/.cache/babel-loader/1786b83003b8e9605a0f5f855a185d4d16e38fc893dfb326a2a9cca206b4240a.json +1 -0
  269. xinference/web/ui/node_modules/.cache/babel-loader/17cbc181dd674b9150b80c73ed6a82656de0082d857f6e5f66d9716129ac0b38.json +1 -0
  270. xinference/web/ui/node_modules/.cache/babel-loader/185ceb8872d562e032b47e79df6a45670e06345b8ed70aad1a131e0476783c5c.json +1 -0
  271. xinference/web/ui/node_modules/.cache/babel-loader/26b8c9f34b0bed789b3a833767672e39302d1e0c09b4276f4d58d1df7b6bd93b.json +1 -0
  272. xinference/web/ui/node_modules/.cache/babel-loader/2b484da66c724d0d56a40849c109327408796a668b1381511b6e9e03baa48658.json +1 -0
  273. xinference/web/ui/node_modules/.cache/babel-loader/2cbbbce9b84df73330d4c42b82436ed881b3847628f2fbc346aa62e2859fd88c.json +1 -0
  274. xinference/web/ui/node_modules/.cache/babel-loader/2ec9b14431ed33ce6901bf9f27007be4e6e472709c99d6e22b50ce528e4b78ee.json +1 -0
  275. xinference/web/ui/node_modules/.cache/babel-loader/3b966db018f96be4a055d6ca205f0990d4d0b370e2980c17d8bca2c9a021819c.json +1 -0
  276. xinference/web/ui/node_modules/.cache/babel-loader/3eefb411b24c2b3ce053570ef50daccf154022f0e168be5ed0fec21394baf9f4.json +1 -0
  277. xinference/web/ui/node_modules/.cache/babel-loader/522b229e3cac219123f0d69673f5570e191c2d2a505dc65b312d336eae2279c0.json +1 -0
  278. xinference/web/ui/node_modules/.cache/babel-loader/52e45f17ba300580ea3fcc9f9228ccba194bb092b76f25e9255af311f8b05aab.json +1 -0
  279. xinference/web/ui/node_modules/.cache/babel-loader/5a0bc4631f936459afc1a3b1d3ec2420118b1f00e11f60ccac3e08088f3f27a8.json +1 -0
  280. xinference/web/ui/node_modules/.cache/babel-loader/611fa2c6c53b66039991d06dfb0473b5ab37fc63b4564e0f6e1718523768a045.json +1 -0
  281. xinference/web/ui/node_modules/.cache/babel-loader/6329bc76c406fe5eb305412383fbde5950f847bb5e43261f73f37622c365acb4.json +1 -0
  282. xinference/web/ui/node_modules/.cache/babel-loader/63c8e07687ea53a4f8a910ee5e42e0eb26cd1acbfbe820f3e3248a786ee51401.json +1 -0
  283. xinference/web/ui/node_modules/.cache/babel-loader/69b2d5001684174ec9da57e07914eed3eac4960018bceb6cbfa801d861301d7c.json +1 -0
  284. xinference/web/ui/node_modules/.cache/babel-loader/710c1acda69e561e30a933b98c6a56d50197868b15c21e2aad55ab6d46649eb6.json +1 -0
  285. xinference/web/ui/node_modules/.cache/babel-loader/720deca1fce5a1dc5056048fa8258fd138a82ea855f350b6613f104a73fb761f.json +1 -0
  286. xinference/web/ui/node_modules/.cache/babel-loader/76a23b92d26a499c57e61eea2b895fbc9771bd0849a72e66f8e633192017978b.json +1 -0
  287. xinference/web/ui/node_modules/.cache/babel-loader/858063f23b34dfe600254eb5afd85518b0002ec4b30b7386616c45600826e3b2.json +1 -0
  288. xinference/web/ui/node_modules/.cache/babel-loader/920b82c1c89124cf217109eeedbfcd3aae3b917be50c9dfb6bbb4ce26bdfd2e7.json +1 -0
  289. xinference/web/ui/node_modules/.cache/babel-loader/94d8b7aeb0076f2ce07db598cea0e87b13bc8d5614eb530b8d6e696c2daf6f88.json +1 -0
  290. xinference/web/ui/node_modules/.cache/babel-loader/9e917fe7022d01b2ccbe5cc0ce73d70bb72bee584ff293bad71bdff6695dee28.json +1 -0
  291. xinference/web/ui/node_modules/.cache/babel-loader/9f28fdb8399f1d0474f0aca86f1658dc94f5bf0c90f6146352de150692de8862.json +1 -0
  292. xinference/web/ui/node_modules/.cache/babel-loader/a0dfafa06b2bb7cba8cad41c482503f61944f759f4318139362602ef5cc47ccb.json +1 -0
  293. xinference/web/ui/node_modules/.cache/babel-loader/a3ff866acddf34917a7ee399e0e571a4dfd8ba66d5057db885f243e16a6eb17d.json +1 -0
  294. xinference/web/ui/node_modules/.cache/babel-loader/afb8084f539534cd594755ea2205ecd5bd1f62dddcfdf75a2eace59a28131278.json +1 -0
  295. xinference/web/ui/node_modules/.cache/babel-loader/b57b1438b77294c1f3f6cfce12ac487d8106c6f016975ba0aec94d98997e2e1e.json +1 -0
  296. xinference/web/ui/node_modules/.cache/babel-loader/b9917b0bf8e4d55ccbac1c334aa04d6ff3c5b6ed9e5d38b9ea2c687fa7d3f5a9.json +1 -0
  297. xinference/web/ui/node_modules/.cache/babel-loader/bbcc94b0149963d1d6f267ee1f4f03d3925b758392ce2f516c3fe8af0e0169fc.json +1 -0
  298. xinference/web/ui/node_modules/.cache/babel-loader/bdee44abeadc4abc17d41c52eb49c6e19a4b1a267b6e16876ce91bdeeebfc52d.json +1 -0
  299. xinference/web/ui/node_modules/.cache/babel-loader/beb112b70f4a56db95920a9e20efb6c97c37b68450716730217a9ee1a9ae92be.json +1 -0
  300. xinference/web/ui/node_modules/.cache/babel-loader/c88db97be0cdf440193b3995996e83510a04cb00048135485fc0e26d197e80b5.json +1 -0
  301. xinference/web/ui/node_modules/.cache/babel-loader/d49e5314d34310a62d01a03067ce1bec5da00abce84c5196aa9c6842fa79a430.json +1 -0
  302. xinference/web/ui/node_modules/.cache/babel-loader/d7664d18c4ddbad9c3a6a31b91f7c00fb0dde804608674a9860ee50f33e54708.json +1 -0
  303. xinference/web/ui/node_modules/.cache/babel-loader/d9072c318b819b7c90a0f7e9cc0b6413b4dbeb8e9859898e53d75ea882fcde99.json +1 -0
  304. xinference/web/ui/node_modules/.cache/babel-loader/db16a983bc08a05f0439cc61ca0840e49e1d8400eef678909f16c032a418a3d6.json +1 -0
  305. xinference/web/ui/node_modules/.cache/babel-loader/dc249829767b8abcbc3677e0b07b6d3ecbfdfe6d08cfe23a665eb33373a9aa9d.json +1 -0
  306. xinference/web/ui/node_modules/.cache/babel-loader/e242c583c2dbc2784f0fcf513523975f7d5df447e106c1c17e49e8578a6fc3ed.json +1 -0
  307. xinference/web/ui/node_modules/.cache/babel-loader/eac5f1296513e69e4b96f750ddccd4d0264e2bae4e4c449144e83274a48698d9.json +1 -0
  308. xinference/web/ui/node_modules/.cache/babel-loader/ed57202cb79649bb716400436590245547df241988fc7c8e1d85d132299542d2.json +1 -0
  309. xinference/web/ui/node_modules/.cache/babel-loader/f125bf72e773a14cdaebd0c343e80adb909d12e317ee5c00cd4a57442fbe2c62.json +1 -0
  310. xinference/web/ui/node_modules/.cache/babel-loader/f91af913d7f91c410719ab13136aaed3aaf0f8dda06652f25c42cb5231587398.json +1 -0
  311. xinference/web/ui/node_modules/.package-lock.json +67 -3
  312. xinference/web/ui/node_modules/@babel/runtime/package.json +592 -538
  313. xinference/web/ui/node_modules/html-parse-stringify/package.json +50 -0
  314. xinference/web/ui/node_modules/i18next/dist/esm/package.json +1 -0
  315. xinference/web/ui/node_modules/i18next/package.json +129 -0
  316. xinference/web/ui/node_modules/react-i18next/.eslintrc.json +74 -0
  317. xinference/web/ui/node_modules/react-i18next/dist/es/package.json +1 -0
  318. xinference/web/ui/node_modules/react-i18next/package.json +162 -0
  319. xinference/web/ui/node_modules/void-elements/package.json +34 -0
  320. xinference/web/ui/package-lock.json +69 -3
  321. xinference/web/ui/package.json +2 -0
  322. xinference/web/ui/src/locales/en.json +186 -0
  323. xinference/web/ui/src/locales/zh.json +186 -0
  324. {xinference-0.16.3.dist-info → xinference-1.2.1.dist-info}/METADATA +96 -36
  325. {xinference-0.16.3.dist-info → xinference-1.2.1.dist-info}/RECORD +335 -146
  326. {xinference-0.16.3.dist-info → xinference-1.2.1.dist-info}/WHEEL +1 -1
  327. xinference/thirdparty/cosyvoice/bin/export_trt.py +0 -8
  328. xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
  329. xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
  330. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
  331. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
  332. xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
  333. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
  334. xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
  335. xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
  336. xinference/thirdparty/fish_speech/tools/api.py +0 -440
  337. xinference/thirdparty/fish_speech/tools/commons.py +0 -35
  338. xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
  339. xinference/thirdparty/fish_speech/tools/msgpack_api.py +0 -34
  340. xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
  341. xinference/thirdparty/fish_speech/tools/webui.py +0 -485
  342. xinference/web/ui/build/static/css/main.5061c4c3.css +0 -2
  343. xinference/web/ui/build/static/css/main.5061c4c3.css.map +0 -1
  344. xinference/web/ui/build/static/js/main.2f269bb3.js +0 -3
  345. xinference/web/ui/build/static/js/main.2f269bb3.js.map +0 -1
  346. xinference/web/ui/node_modules/.cache/babel-loader/07ce9e632e6aff24d7aa3ad8e48224433bbfeb0d633fca723453f1fcae0c9f1c.json +0 -1
  347. xinference/web/ui/node_modules/.cache/babel-loader/1130403f9e46f5738a23b45ac59b57de8f360c908c713e2c0670c2cce9bd367a.json +0 -1
  348. xinference/web/ui/node_modules/.cache/babel-loader/131091b25d26b17cdca187d7542a21475c211138d900cf667682260e76ef9463.json +0 -1
  349. xinference/web/ui/node_modules/.cache/babel-loader/1f269fb2a368363c1cb2237825f1dba093b6bdd8c44cc05954fd19ec2c1fff03.json +0 -1
  350. xinference/web/ui/node_modules/.cache/babel-loader/331312668fa8bd3d7401818f4a25fa98135d7f61371cd6bfff78b18cf4fbdd92.json +0 -1
  351. xinference/web/ui/node_modules/.cache/babel-loader/40f17338fc75ae095de7d2b4d8eae0d5ca0193a7e2bcece4ee745b22a7a2f4b7.json +0 -1
  352. xinference/web/ui/node_modules/.cache/babel-loader/4de9a6942c5f1749d6cbfdd54279699975f16016b182848bc253886f52ec2ec3.json +0 -1
  353. xinference/web/ui/node_modules/.cache/babel-loader/822586ed1077201b64b954f12f25e3f9b45678c1acbabe53d8af3ca82ca71f33.json +0 -1
  354. xinference/web/ui/node_modules/.cache/babel-loader/8d33354bd2100c8602afc3341f131a88cc36aaeecd5a4b365ed038514708e350.json +0 -1
  355. xinference/web/ui/node_modules/.cache/babel-loader/9375a35b05d56989b2755bf72161fa707c92f28569d33765a75f91a568fda6e9.json +0 -1
  356. xinference/web/ui/node_modules/.cache/babel-loader/a158a9ffa0c9b169aee53dd4a0c44501a596755b4e4f6ede7746d65a72e2a71f.json +0 -1
  357. xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +0 -1
  358. xinference/web/ui/node_modules/.cache/babel-loader/c7bf40bab396765f67d0fed627ed3665890608b2d0edaa3e8cb7cfc96310db45.json +0 -1
  359. xinference/web/ui/node_modules/.cache/babel-loader/d6c643278a0b28320e6f33a60f5fb64c053997cbdc39a60e53ccc574688ade9e.json +0 -1
  360. xinference/web/ui/node_modules/.cache/babel-loader/e42b72d4cc1ea412ebecbb8d040dc6c6bfee462c33903c2f1f3facb602ad742e.json +0 -1
  361. xinference/web/ui/node_modules/.cache/babel-loader/e64b7e8cedcf43d4c95deba60ec1341855c887705805bb62431693118b870c69.json +0 -1
  362. xinference/web/ui/node_modules/.cache/babel-loader/f5039ddbeb815c51491a1989532006b96fc3ae49c6c60e3c097f875b4ae915ae.json +0 -1
  363. xinference/web/ui/node_modules/.cache/babel-loader/f72f011744c4649fabddca6f7a9327861ac0a315a89b1a2e62a39774e7863845.json +0 -1
  364. xinference/web/ui/node_modules/.cache/babel-loader/feabb04b4aa507102da0a64398a40818e878fd1df9b75dda8461b3e1e7ff3f11.json +0 -1
  365. /xinference/thirdparty/{cosyvoice/bin → f5_tts}/__init__.py +0 -0
  366. /xinference/thirdparty/{cosyvoice/flow → melo}/__init__.py +0 -0
  367. /xinference/thirdparty/{cosyvoice/hifigan → melo/text/english_utils}/__init__.py +0 -0
  368. /xinference/thirdparty/{cosyvoice/llm → melo/text/es_phonemizer}/__init__.py +0 -0
  369. /xinference/thirdparty/{fish_speech/fish_speech/configs → melo/text/fr_phonemizer}/__init__.py +0 -0
  370. /xinference/web/ui/build/static/js/{main.2f269bb3.js.LICENSE.txt → main.b0936c54.js.LICENSE.txt} +0 -0
  371. {xinference-0.16.3.dist-info → xinference-1.2.1.dist-info}/LICENSE +0 -0
  372. {xinference-0.16.3.dist-info → xinference-1.2.1.dist-info}/entry_points.txt +0 -0
  373. {xinference-0.16.3.dist-info → xinference-1.2.1.dist-info}/top_level.txt +0 -0
xinference/thirdparty/fish_speech/tools/server/api_utils.py
@@ -0,0 +1,75 @@
+from argparse import ArgumentParser
+from http import HTTPStatus
+from typing import Annotated, Any
+
+import ormsgpack
+from baize.datastructures import ContentType
+from kui.asgi import HTTPException, HttpRequest
+
+from tools.inference_engine import TTSInferenceEngine
+from tools.schema import ServeTTSRequest
+from tools.server.inference import inference_wrapper as inference
+
+
+def parse_args():
+    parser = ArgumentParser()
+    parser.add_argument("--mode", type=str, choices=["agent", "tts"], default="tts")
+    parser.add_argument("--load-asr-model", action="store_true")
+    parser.add_argument(
+        "--llama-checkpoint-path",
+        type=str,
+        default="checkpoints/fish-speech-1.5",
+    )
+    parser.add_argument(
+        "--decoder-checkpoint-path",
+        type=str,
+        default="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
+    )
+    parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
+    parser.add_argument("--device", type=str, default="cuda")
+    parser.add_argument("--half", action="store_true")
+    parser.add_argument("--compile", action="store_true")
+    parser.add_argument("--max-text-length", type=int, default=0)
+    parser.add_argument("--listen", type=str, default="127.0.0.1:8080")
+    parser.add_argument("--workers", type=int, default=1)
+
+    return parser.parse_args()
+
+
+class MsgPackRequest(HttpRequest):
+    async def data(
+        self,
+    ) -> Annotated[
+        Any, ContentType("application/msgpack"), ContentType("application/json")
+    ]:
+        if self.content_type == "application/msgpack":
+            return ormsgpack.unpackb(await self.body)
+
+        elif self.content_type == "application/json":
+            return await self.json
+
+        raise HTTPException(
+            HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
+            headers={"Accept": "application/msgpack, application/json"},
+        )
+
+
+async def inference_async(req: ServeTTSRequest, engine: TTSInferenceEngine):
+    for chunk in inference(req, engine):
+        if isinstance(chunk, bytes):
+            yield chunk
+
+
+async def buffer_to_async_generator(buffer):
+    yield buffer
+
+
+def get_content_type(audio_format):
+    if audio_format == "wav":
+        return "audio/wav"
+    elif audio_format == "flac":
+        return "audio/flac"
+    elif audio_format == "mp3":
+        return "audio/mpeg"
+    else:
+        return "application/octet-stream"
xinference/thirdparty/fish_speech/tools/server/exception_handler.py
@@ -0,0 +1,27 @@
+import traceback
+from http import HTTPStatus
+
+from kui.asgi import HTTPException, JSONResponse
+
+
+class ExceptionHandler:
+
+    async def http_exception_handler(self, exc: HTTPException):
+        return JSONResponse(
+            dict(
+                statusCode=exc.status_code,
+                message=exc.content,
+                error=HTTPStatus(exc.status_code).phrase,
+            ),
+            exc.status_code,
+            exc.headers,
+        )
+
+    async def other_exception_handler(self, exc: Exception):
+        traceback.print_exc()
+
+        status = HTTPStatus.INTERNAL_SERVER_ERROR
+        return JSONResponse(
+            dict(statusCode=status, message=str(exc), error=status.phrase),
+            status,
+        )
xinference/thirdparty/fish_speech/tools/server/inference.py
@@ -0,0 +1,45 @@
+from http import HTTPStatus
+
+import numpy as np
+from kui.asgi import HTTPException
+
+from tools.inference_engine import TTSInferenceEngine
+from tools.schema import ServeTTSRequest
+
+AMPLITUDE = 32768  # Needs an explaination
+
+
+def inference_wrapper(req: ServeTTSRequest, engine: TTSInferenceEngine):
+    """
+    Wrapper for the inference function.
+    Used in the API server.
+    """
+    count = 0
+    for result in engine.inference(req):
+        match result.code:
+            case "header":
+                if isinstance(result.audio, tuple):
+                    yield result.audio[1]
+
+            case "error":
+                raise HTTPException(
+                    HTTPStatus.INTERNAL_SERVER_ERROR,
+                    content=str(result.error),
+                )
+
+            case "segment":
+                count += 1
+                if isinstance(result.audio, tuple):
+                    yield (result.audio[1] * AMPLITUDE).astype(np.int16).tobytes()
+
+            case "final":
+                count += 1
+                if isinstance(result.audio, tuple):
+                    yield result.audio[1]
+                return None  # Stop the generator
+
+    if count == 0:
+        raise HTTPException(
+            HTTPStatus.INTERNAL_SERVER_ERROR,
+            content="No audio generated, please check the input text.",
+        )
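Note: the AMPLITUDE constant above (whose upstream comment reads "Needs an explaination") is the int16 PCM scale factor: the engine yields float samples in [-1.0, 1.0], and the "segment" branch multiplies by 32768 and casts to np.int16 to get raw 16-bit PCM for streaming. A standalone illustration; samples at exactly +1.0 would wrap, which is why clipping is often applied first.

import numpy as np

samples = np.array([0.0, 0.5, -0.5, 0.999], dtype=np.float32)
pcm = (samples * 32768).astype(np.int16)  # int16 spans -32768..32767
print(pcm)                 # [0 16384 -16384 32735]
print(len(pcm.tobytes()))  # 8 bytes: four 16-bit frames, ready to stream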
xinference/thirdparty/fish_speech/tools/server/model_manager.py
@@ -0,0 +1,122 @@
+import torch
+from funasr import AutoModel
+from loguru import logger
+
+from tools.inference_engine import TTSInferenceEngine
+from tools.llama.generate import (
+    launch_thread_safe_queue,
+    launch_thread_safe_queue_agent,
+)
+from tools.schema import ServeTTSRequest
+from tools.server.inference import inference_wrapper as inference
+from tools.vqgan.inference import load_model as load_decoder_model
+
+ASR_MODEL_NAME = "iic/SenseVoiceSmall"
+
+
+class ModelManager:
+    def __init__(
+        self,
+        mode: str,
+        device: str,
+        half: bool,
+        compile: bool,
+        asr_enabled: bool,
+        llama_checkpoint_path: str,
+        decoder_checkpoint_path: str,
+        decoder_config_name: str,
+    ) -> None:
+
+        self.mode = mode
+        self.device = device
+        self.half = half
+        self.compile = compile
+
+        self.precision = torch.half if half else torch.bfloat16
+
+        # Check if MPS or CUDA is available
+        if torch.backends.mps.is_available():
+            self.device = "mps"
+            logger.info("mps is available, running on mps.")
+        elif not torch.cuda.is_available():
+            self.device = "cpu"
+            logger.info("CUDA is not available, running on CPU.")
+
+        # Load the ASR model if enabled
+        if asr_enabled:
+            self.load_asr_model(self.device)
+
+        # Load the TTS models
+        self.load_llama_model(
+            llama_checkpoint_path, self.device, self.precision, self.compile, self.mode
+        )
+        self.load_decoder_model(
+            decoder_config_name, decoder_checkpoint_path, self.device
+        )
+        self.tts_inference_engine = TTSInferenceEngine(
+            llama_queue=self.llama_queue,
+            decoder_model=self.decoder_model,
+            precision=self.precision,
+            compile=self.compile,
+        )
+
+        # Warm up the models
+        if self.mode == "tts":
+            self.warm_up(self.tts_inference_engine)
+
+    def load_asr_model(self, device, hub="ms") -> None:
+        self.asr_model = AutoModel(
+            model=ASR_MODEL_NAME,
+            device=device,
+            disable_pbar=True,
+            hub=hub,
+        )
+        logger.info("ASR model loaded.")
+
+    def load_llama_model(
+        self, checkpoint_path, device, precision, compile, mode
+    ) -> None:
+
+        if mode == "tts":
+            self.llama_queue = launch_thread_safe_queue(
+                checkpoint_path=checkpoint_path,
+                device=device,
+                precision=precision,
+                compile=compile,
+            )
+        elif mode == "agent":
+            self.llama_queue, self.tokenizer, self.config = (
+                launch_thread_safe_queue_agent(
+                    checkpoint_path=checkpoint_path,
+                    device=device,
+                    precision=precision,
+                    compile=compile,
+                )
+            )
+        else:
+            raise ValueError(f"Invalid mode: {mode}")
+
+        logger.info("LLAMA model loaded.")
+
+    def load_decoder_model(self, config_name, checkpoint_path, device) -> None:
+        self.decoder_model = load_decoder_model(
+            config_name=config_name,
+            checkpoint_path=checkpoint_path,
+            device=device,
+        )
+        logger.info("Decoder model loaded.")
+
+    def warm_up(self, tts_inference_engine) -> None:
+        request = ServeTTSRequest(
+            text="Hello world.",
+            references=[],
+            reference_id=None,
+            max_new_tokens=1024,
+            chunk_length=200,
+            top_p=0.7,
+            repetition_penalty=1.2,
+            temperature=0.7,
+            format="wav",
+        )
+        list(inference(request, tts_inference_engine))
+        logger.info("Models warmed up.")
xinference/thirdparty/fish_speech/tools/server/model_utils.py
@@ -0,0 +1,129 @@
+import io
+import re
+
+import librosa
+import torch
+import torchaudio
+from cachetools import LRUCache, cached
+
+CACHE_MAXSIZE = 10000
+MICRO_BATCH_SIZE = 8
+ASR_SAMPLE_RATE = 16000
+HUGE_GAP_THRESHOLD = 4000
+
+
+@torch.no_grad()
+@torch.autocast(device_type="cuda", dtype=torch.half)
+def batch_encode(model, audios_list: list[bytes]):
+    audios: list[torch.Tensor] = [
+        (
+            torch.from_numpy(
+                librosa.load(io.BytesIO(audio), sr=model.spec_transform.sample_rate)[0]
+            )[None]
+            if isinstance(audio, bytes)
+            else audio
+        )
+        for audio in audios_list
+    ]
+
+    lengths = torch.tensor([audio.shape[-1] for audio in audios], device=model.device)
+    max_length = lengths.max().item()
+
+    print(f"Encode max length: {max_length / model.spec_transform.sample_rate:.2f}s")
+
+    padded = torch.stack(
+        [
+            torch.nn.functional.pad(audio, (0, int(max_length - audio.shape[-1])))
+            for audio in audios
+        ]
+    ).to(model.device)
+
+    features, feature_lengths = model.encode(padded, audio_lengths=lengths)
+    features, feature_lengths = features.cpu(), feature_lengths.cpu()
+
+    return [feature[..., :length] for feature, length in zip(features, feature_lengths)]
+
+
+@cached(
+    cache=LRUCache(maxsize=CACHE_MAXSIZE),
+    key=lambda model, audios: (model.device, tuple(audios)),
+)
+def cached_vqgan_batch_encode(model, audios: list[bytes]):
+    return batch_encode(model, audios)
+
+
+@torch.no_grad()
+@torch.autocast(device_type="cuda", dtype=torch.half)
+def vqgan_decode(model, features):
+    lengths = torch.tensor(
+        [feature.shape[-1] for feature in features], device=model.device
+    )
+    max_length = lengths.max().item()
+    padded = torch.stack(
+        [
+            torch.nn.functional.pad(feature, (0, max_length - feature.shape[-1]))
+            for feature in features
+        ]
+    ).to(model.device)
+
+    # If bs too large, we do micro batch decode
+    audios, audio_lengths = [], []
+    for i in range(0, padded.shape[0], MICRO_BATCH_SIZE):
+        audio, audio_length = model.decode(
+            padded[i : i + MICRO_BATCH_SIZE],
+            feature_lengths=lengths[i : i + MICRO_BATCH_SIZE],
+        )
+        audios.append(audio)
+        audio_lengths.append(audio_length)
+    audios = torch.cat(audios, dim=0)
+    audio_lengths = torch.cat(audio_lengths, dim=0)
+    audios, audio_lengths = audios.cpu(), audio_lengths.cpu()
+
+    return [audio[..., :length].numpy() for audio, length in zip(audios, audio_lengths)]
+
+
+@torch.no_grad()
+def batch_asr(model, lock, audios, sr, language="auto"):
+    resampled_audios = []
+    for audio in audios:
+        audio = torchaudio.functional.resample(audio, sr, ASR_SAMPLE_RATE)
+        assert audio.ndim == 1
+        resampled_audios.append(audio)
+
+    with lock:
+        res = model.generate(
+            input=resampled_audios,
+            batch_size=len(resampled_audios),
+            language=language,
+            use_itn=True,
+        )
+
+    results = []
+    for r, audio in zip(res, audios):
+        text = r["text"]
+        text = re.sub(r"<\|.*?\|>", "", text)
+        duration = len(audio) / sr * 1000
+        huge_gap = False
+
+        if "timestamp" in r and len(r["timestamp"]) > 2:
+            for timestamp_a, timestamp_b in zip(
+                r["timestamp"][:-1], r["timestamp"][1:]
+            ):
+                # If there is a gap of more than 4 seconds, we consider it as a huge gap
+                if timestamp_b[0] - timestamp_a[1] > HUGE_GAP_THRESHOLD:
+                    huge_gap = True
+                    break
+
+            # Doesn't make sense to have a huge gap at the end
+            if duration - r["timestamp"][-1][1] > HUGE_GAP_THRESHOLD:
+                huge_gap = True
+
+        results.append(
+            {
+                "text": text,
+                "duration": duration,
+                "huge_gap": huge_gap,
+            }
+        )
+
+    return results
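Note: cached_vqgan_batch_encode above keys its LRU cache on (model.device, tuple(audios)), which works because the reference audios arrive as immutable, hashable bytes, so byte-identical reference audio skips re-encoding. (HUGE_GAP_THRESHOLD compares against durations computed in milliseconds, so 4000 matches the "4 seconds" comment.) A self-contained sketch of the same cachetools pattern; encode here is a stand-in, not the real model call:

from cachetools import LRUCache, cached

calls = 0

@cached(cache=LRUCache(maxsize=4), key=lambda audios: tuple(audios))
def encode(audios: list[bytes]):
    global calls
    calls += 1  # counts how often the body actually runs
    return [len(a) for a in audios]

encode([b"ref-a", b"ref-b"])
encode([b"ref-a", b"ref-b"])  # same tuple key: served from cache
print(calls)  # 1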
xinference/thirdparty/fish_speech/tools/server/views.py
@@ -0,0 +1,246 @@
+import io
+import os
+import time
+from http import HTTPStatus
+
+import numpy as np
+import ormsgpack
+import soundfile as sf
+import torch
+from kui.asgi import HTTPException, HttpView, JSONResponse, StreamResponse, request
+from loguru import logger
+
+from tools.schema import (
+    ServeASRRequest,
+    ServeASRResponse,
+    ServeChatRequest,
+    ServeTTSRequest,
+    ServeVQGANDecodeRequest,
+    ServeVQGANDecodeResponse,
+    ServeVQGANEncodeRequest,
+    ServeVQGANEncodeResponse,
+)
+from tools.server.agent import get_response_generator
+from tools.server.api_utils import (
+    buffer_to_async_generator,
+    get_content_type,
+    inference_async,
+)
+from tools.server.inference import inference_wrapper as inference
+from tools.server.model_manager import ModelManager
+from tools.server.model_utils import batch_asr, cached_vqgan_batch_encode, vqgan_decode
+
+MAX_NUM_SAMPLES = int(os.getenv("NUM_SAMPLES", 1))
+
+
+class HealthView(HttpView):
+    """
+    Return the health status of the server.
+    """
+
+    @classmethod
+    async def post(cls):
+        return JSONResponse({"status": "ok"})
+
+
+class VQGANEncodeView(HttpView):
+    """
+    Encode the audio into symbolic tokens.
+    """
+
+    @classmethod
+    async def post(cls):
+        # Decode the request
+        payload = await request.data()
+        req = ServeVQGANEncodeRequest(**payload)
+
+        # Get the model from the app
+        model_manager: ModelManager = request.app.state.model_manager
+        decoder_model = model_manager.decoder_model
+
+        # Encode the audio
+        start_time = time.time()
+        tokens = cached_vqgan_batch_encode(decoder_model, req.audios)
+        logger.info(
+            f"[EXEC] VQGAN encode time: {(time.time() - start_time) * 1000:.2f}ms"
+        )
+
+        # Return the response
+        return ormsgpack.packb(
+            ServeVQGANEncodeResponse(tokens=[i.tolist() for i in tokens]),
+            option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
+        )
+
+
+class VQGANDecodeView(HttpView):
+    """
+    Decode the symbolic tokens into audio.
+    """
+
+    @classmethod
+    async def post(cls):
+        # Decode the request
+        payload = await request.data()
+        req = ServeVQGANDecodeRequest(**payload)
+
+        # Get the model from the app
+        model_manager: ModelManager = request.app.state.model_manager
+        decoder_model = model_manager.decoder_model
+
+        # Decode the audio
+        tokens = [torch.tensor(token, dtype=torch.int) for token in req.tokens]
+        start_time = time.time()
+        audios = vqgan_decode(decoder_model, tokens)
+        logger.info(
+            f"[EXEC] VQGAN decode time: {(time.time() - start_time) * 1000:.2f}ms"
+        )
+        audios = [audio.astype(np.float16).tobytes() for audio in audios]
+
+        # Return the response
+        return ormsgpack.packb(
+            ServeVQGANDecodeResponse(audios=audios),
+            option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
+        )
+
+
+class ASRView(HttpView):
+    """
+    Perform automatic speech recognition on the audio.
+    """
+
+    @classmethod
+    async def post(cls):
+        # Decode the request
+        payload = await request.data()
+        req = ServeASRRequest(**payload)
+
+        # Get the model from the app
+        model_manager: ModelManager = request.app.state.model_manager
+        asr_model = model_manager.asr_model
+        lock = request.app.state.lock
+
+        # Perform ASR
+        start_time = time.time()
+        audios = [np.frombuffer(audio, dtype=np.float16) for audio in req.audios]
+        audios = [torch.from_numpy(audio).float() for audio in audios]
+
+        if any(audios.shape[-1] >= 30 * req.sample_rate for audios in audios):
+            raise HTTPException(status_code=400, content="Audio length is too long")
+
+        transcriptions = batch_asr(
+            asr_model, lock, audios=audios, sr=req.sample_rate, language=req.language
+        )
+        logger.info(f"[EXEC] ASR time: {(time.time() - start_time) * 1000:.2f}ms")
+
+        # Return the response
+        return ormsgpack.packb(
+            ServeASRResponse(transcriptions=transcriptions),
+            option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
+        )
+
+
+class TTSView(HttpView):
+    """
+    Perform text-to-speech on the input text.
+    """
+
+    @classmethod
+    async def post(cls):
+        # Decode the request
+        payload = await request.data()
+        req = ServeTTSRequest(**payload)
+
+        # Get the model from the app
+        app_state = request.app.state
+        model_manager: ModelManager = app_state.model_manager
+        engine = model_manager.tts_inference_engine
+        sample_rate = engine.decoder_model.spec_transform.sample_rate
+
+        # Check if the text is too long
+        if app_state.max_text_length > 0 and len(req.text) > app_state.max_text_length:
+            raise HTTPException(
+                HTTPStatus.BAD_REQUEST,
+                content=f"Text is too long, max length is {app_state.max_text_length}",
+            )
+
+        # Check if streaming is enabled
+        if req.streaming and req.format != "wav":
+            raise HTTPException(
+                HTTPStatus.BAD_REQUEST,
+                content="Streaming only supports WAV format",
+            )
+
+        # Perform TTS
+        if req.streaming:
+            return StreamResponse(
+                iterable=inference_async(req, engine),
+                headers={
+                    "Content-Disposition": f"attachment; filename=audio.{req.format}",
+                },
+                content_type=get_content_type(req.format),
+            )
+        else:
+            fake_audios = next(inference(req, engine))
+            buffer = io.BytesIO()
+            sf.write(
+                buffer,
+                fake_audios,
+                sample_rate,
+                format=req.format,
+            )
+
+            return StreamResponse(
+                iterable=buffer_to_async_generator(buffer.getvalue()),
+                headers={
+                    "Content-Disposition": f"attachment; filename=audio.{req.format}",
+                },
+                content_type=get_content_type(req.format),
+            )
+
+
+class ChatView(HttpView):
+    """
+    Perform chatbot inference on the input text.
+    """
+
+    @classmethod
+    async def post(cls):
+        # Decode the request
+        payload = await request.data()
+        req = ServeChatRequest(**payload)
+
+        # Check that the number of samples requested is correct
+        if req.num_samples < 1 or req.num_samples > MAX_NUM_SAMPLES:
+            raise HTTPException(
+                HTTPStatus.BAD_REQUEST,
+                content=f"Number of samples must be between 1 and {MAX_NUM_SAMPLES}",
+            )
+
+        # Get the type of content provided
+        content_type = request.headers.get("Content-Type", "application/json")
+        json_mode = "application/json" in content_type
+
+        # Get the models from the app
+        model_manager: ModelManager = request.app.state.model_manager
+        llama_queue = model_manager.llama_queue
+        tokenizer = model_manager.tokenizer
+        config = model_manager.config
+
+        device = request.app.state.device
+
+        # Get the response generators
+        response_generator = get_response_generator(
+            llama_queue, tokenizer, config, req, device, json_mode
+        )
+
+        # Return the response in the correct format
+        if req.streaming is False:
+            result = response_generator()
+            if json_mode:
+                return JSONResponse(result.model_dump())
+            else:
+                return ormsgpack.packb(result, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)
+
+        return StreamResponse(
+            iterable=response_generator(), content_type="text/event-stream"
+        )
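Note: ASRView above reinterprets each uploaded audio as raw float16 samples via np.frombuffer, so a client has to serialize mono float audio the same way. A sketch (not in the diff); the /v1/asr route name is an assumption, while the audios/sample_rate/language fields are the ones ASRView reads:

import numpy as np
import ormsgpack
import requests
import soundfile as sf

samples, sr = sf.read("speech.wav", dtype="float32")  # must be mono: batch_asr asserts ndim == 1
payload = {
    "audios": [samples.astype(np.float16).tobytes()],  # raw float16, as ASRView expects
    "sample_rate": sr,
    "language": "auto",
}
resp = requests.post(
    "http://127.0.0.1:8080/v1/asr",  # assumed route
    data=ormsgpack.packb(payload),
    headers={"Content-Type": "application/msgpack"},
)
print(ormsgpack.unpackb(resp.content))  # {"transcriptions": [...]}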
xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py
@@ -24,6 +24,12 @@ OmegaConf.register_new_resolver("eval", eval)
 # This file is used to convert the audio files to text files using the Whisper model.
 # It's mainly used to generate the training data for the VQ model.
 
+backends = torchaudio.list_audio_backends()
+
+if "ffmpeg" in backends:
+    backend = "ffmpeg"
+else:
+    backend = "soundfile"
 
 RANK = int(os.environ.get("SLURM_PROCID", 0))
 WORLD_SIZE = int(os.environ.get("SLURM_NTASKS", 1))
@@ -81,7 +87,7 @@ def process_batch(files: list[Path], model) -> float:
     for file in files:
         try:
             wav, sr = torchaudio.load(
-                str(file), backend="sox" if sys.platform == "linux" else "soundfile"
+                str(file), backend=backend
             )  # Need to install libsox-dev
         except Exception as e:
            logger.error(f"Error reading {file}: {e}")
xinference/thirdparty/fish_speech/tools/vqgan/inference.py
@@ -24,8 +24,7 @@ def load_model(config_name, checkpoint_path, device="cuda"):
 
     model = instantiate(cfg)
     state_dict = torch.load(
-        checkpoint_path,
-        map_location=device,
+        checkpoint_path, map_location=device, mmap=True, weights_only=True
     )
     if "state_dict" in state_dict:
         state_dict = state_dict["state_dict"]
@@ -37,7 +36,7 @@ def load_model(config_name, checkpoint_path, device="cuda"):
         if "generator." in k
     }
 
-    result = model.load_state_dict(state_dict, strict=False)
+    result = model.load_state_dict(state_dict, strict=False, assign=True)
     model.eval()
     model.to(device)
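Note: the new flags here are stock PyTorch (>= 2.1): mmap=True memory-maps the checkpoint instead of reading it fully into RAM, weights_only=True restricts unpickling to tensors and primitives, and assign=True lets load_state_dict adopt the loaded tensors rather than copying them into preallocated parameters. The same low-copy pattern in isolation:

import torch
import torch.nn as nn

model = nn.Linear(4, 4)
torch.save(model.state_dict(), "ckpt.pth")

# Memory-mapped, unpickling restricted to tensor data.
state = torch.load("ckpt.pth", map_location="cpu", mmap=True, weights_only=True)
model.load_state_dict(state, strict=False, assign=True)  # adopt tensors in place
model.eval()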