xinference 1.0.1__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic. Click here for more details.

Files changed (343) hide show
  1. xinference/_compat.py +2 -0
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +77 -71
  4. xinference/core/chat_interface.py +6 -1
  5. xinference/core/model.py +79 -19
  6. xinference/core/supervisor.py +172 -10
  7. xinference/core/utils.py +12 -8
  8. xinference/core/worker.py +102 -4
  9. xinference/deploy/cmdline.py +3 -1
  10. xinference/deploy/test/test_cmdline.py +56 -0
  11. xinference/isolation.py +24 -0
  12. xinference/model/audio/core.py +16 -0
  13. xinference/model/audio/cosyvoice.py +39 -6
  14. xinference/model/audio/f5tts.py +200 -0
  15. xinference/model/audio/f5tts_mlx.py +260 -0
  16. xinference/model/audio/fish_speech.py +36 -111
  17. xinference/model/audio/melotts.py +110 -0
  18. xinference/model/audio/model_spec.json +99 -3
  19. xinference/model/audio/model_spec_modelscope.json +27 -0
  20. xinference/model/audio/utils.py +32 -0
  21. xinference/model/audio/whisper.py +35 -10
  22. xinference/model/embedding/core.py +203 -142
  23. xinference/model/embedding/model_spec.json +7 -0
  24. xinference/model/embedding/model_spec_modelscope.json +8 -0
  25. xinference/model/image/core.py +69 -1
  26. xinference/model/image/model_spec.json +145 -4
  27. xinference/model/image/model_spec_modelscope.json +150 -4
  28. xinference/model/image/stable_diffusion/core.py +45 -13
  29. xinference/model/llm/__init__.py +4 -2
  30. xinference/model/llm/llm_family.json +536 -53
  31. xinference/model/llm/llm_family.py +15 -36
  32. xinference/model/llm/llm_family_modelscope.json +454 -20
  33. xinference/model/llm/memory.py +1 -1
  34. xinference/model/llm/mlx/core.py +248 -52
  35. xinference/model/llm/sglang/core.py +1 -0
  36. xinference/model/llm/transformers/chatglm.py +9 -5
  37. xinference/model/llm/transformers/cogagent.py +272 -0
  38. xinference/model/llm/transformers/core.py +2 -0
  39. xinference/model/llm/transformers/qwen2_vl.py +12 -1
  40. xinference/model/llm/transformers/utils.py +16 -8
  41. xinference/model/llm/utils.py +36 -4
  42. xinference/model/llm/vllm/core.py +53 -10
  43. xinference/model/llm/vllm/xavier/__init__.py +13 -0
  44. xinference/model/llm/vllm/xavier/allocator.py +74 -0
  45. xinference/model/llm/vllm/xavier/block.py +111 -0
  46. xinference/model/llm/vllm/xavier/block_manager.py +71 -0
  47. xinference/model/llm/vllm/xavier/block_tracker.py +129 -0
  48. xinference/model/llm/vllm/xavier/collective.py +74 -0
  49. xinference/model/llm/vllm/xavier/collective_manager.py +147 -0
  50. xinference/model/llm/vllm/xavier/engine.py +247 -0
  51. xinference/model/llm/vllm/xavier/executor.py +134 -0
  52. xinference/model/llm/vllm/xavier/scheduler.py +438 -0
  53. xinference/model/llm/vllm/xavier/test/__init__.py +13 -0
  54. xinference/model/llm/vllm/xavier/test/test_xavier.py +147 -0
  55. xinference/model/llm/vllm/xavier/transfer.py +319 -0
  56. xinference/model/video/diffusers.py +14 -0
  57. xinference/model/video/model_spec.json +15 -0
  58. xinference/model/video/model_spec_modelscope.json +16 -0
  59. xinference/thirdparty/cosyvoice/bin/average_model.py +92 -0
  60. xinference/thirdparty/cosyvoice/bin/export_jit.py +12 -2
  61. xinference/thirdparty/cosyvoice/bin/export_onnx.py +112 -0
  62. xinference/thirdparty/cosyvoice/bin/export_trt.sh +9 -0
  63. xinference/thirdparty/cosyvoice/bin/inference.py +5 -7
  64. xinference/thirdparty/cosyvoice/bin/spk2info.pt +0 -0
  65. xinference/thirdparty/cosyvoice/bin/train.py +42 -8
  66. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +96 -25
  67. xinference/thirdparty/cosyvoice/cli/frontend.py +77 -30
  68. xinference/thirdparty/cosyvoice/cli/model.py +330 -80
  69. xinference/thirdparty/cosyvoice/dataset/dataset.py +6 -2
  70. xinference/thirdparty/cosyvoice/dataset/processor.py +76 -14
  71. xinference/thirdparty/cosyvoice/flow/decoder.py +92 -13
  72. xinference/thirdparty/cosyvoice/flow/flow.py +99 -9
  73. xinference/thirdparty/cosyvoice/flow/flow_matching.py +110 -13
  74. xinference/thirdparty/cosyvoice/flow/length_regulator.py +5 -4
  75. xinference/thirdparty/cosyvoice/hifigan/discriminator.py +140 -0
  76. xinference/thirdparty/cosyvoice/hifigan/generator.py +58 -42
  77. xinference/thirdparty/cosyvoice/hifigan/hifigan.py +67 -0
  78. xinference/thirdparty/cosyvoice/llm/llm.py +139 -6
  79. xinference/thirdparty/cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
  80. xinference/thirdparty/cosyvoice/tokenizer/tokenizer.py +279 -0
  81. xinference/thirdparty/cosyvoice/transformer/embedding.py +2 -2
  82. xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +7 -7
  83. xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +318 -0
  84. xinference/thirdparty/cosyvoice/utils/common.py +28 -1
  85. xinference/thirdparty/cosyvoice/utils/executor.py +69 -7
  86. xinference/thirdparty/cosyvoice/utils/file_utils.py +2 -12
  87. xinference/thirdparty/cosyvoice/utils/frontend_utils.py +9 -5
  88. xinference/thirdparty/cosyvoice/utils/losses.py +20 -0
  89. xinference/thirdparty/cosyvoice/utils/scheduler.py +1 -2
  90. xinference/thirdparty/cosyvoice/utils/train_utils.py +101 -45
  91. xinference/thirdparty/f5_tts/api.py +166 -0
  92. xinference/thirdparty/f5_tts/configs/E2TTS_Base_train.yaml +44 -0
  93. xinference/thirdparty/f5_tts/configs/E2TTS_Small_train.yaml +44 -0
  94. xinference/thirdparty/f5_tts/configs/F5TTS_Base_train.yaml +46 -0
  95. xinference/thirdparty/f5_tts/configs/F5TTS_Small_train.yaml +46 -0
  96. xinference/thirdparty/f5_tts/eval/README.md +49 -0
  97. xinference/thirdparty/f5_tts/eval/ecapa_tdnn.py +330 -0
  98. xinference/thirdparty/f5_tts/eval/eval_infer_batch.py +207 -0
  99. xinference/thirdparty/f5_tts/eval/eval_infer_batch.sh +13 -0
  100. xinference/thirdparty/f5_tts/eval/eval_librispeech_test_clean.py +84 -0
  101. xinference/thirdparty/f5_tts/eval/eval_seedtts_testset.py +84 -0
  102. xinference/thirdparty/f5_tts/eval/utils_eval.py +405 -0
  103. xinference/thirdparty/f5_tts/infer/README.md +191 -0
  104. xinference/thirdparty/f5_tts/infer/SHARED.md +74 -0
  105. xinference/thirdparty/f5_tts/infer/examples/basic/basic.toml +11 -0
  106. xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_en.wav +0 -0
  107. xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_zh.wav +0 -0
  108. xinference/thirdparty/f5_tts/infer/examples/multi/country.flac +0 -0
  109. xinference/thirdparty/f5_tts/infer/examples/multi/main.flac +0 -0
  110. xinference/thirdparty/f5_tts/infer/examples/multi/story.toml +19 -0
  111. xinference/thirdparty/f5_tts/infer/examples/multi/story.txt +1 -0
  112. xinference/thirdparty/f5_tts/infer/examples/multi/town.flac +0 -0
  113. xinference/thirdparty/f5_tts/infer/examples/vocab.txt +2545 -0
  114. xinference/thirdparty/f5_tts/infer/infer_cli.py +226 -0
  115. xinference/thirdparty/f5_tts/infer/infer_gradio.py +851 -0
  116. xinference/thirdparty/f5_tts/infer/speech_edit.py +193 -0
  117. xinference/thirdparty/f5_tts/infer/utils_infer.py +538 -0
  118. xinference/thirdparty/f5_tts/model/__init__.py +10 -0
  119. xinference/thirdparty/f5_tts/model/backbones/README.md +20 -0
  120. xinference/thirdparty/f5_tts/model/backbones/dit.py +163 -0
  121. xinference/thirdparty/f5_tts/model/backbones/mmdit.py +146 -0
  122. xinference/thirdparty/f5_tts/model/backbones/unett.py +219 -0
  123. xinference/thirdparty/f5_tts/model/cfm.py +285 -0
  124. xinference/thirdparty/f5_tts/model/dataset.py +319 -0
  125. xinference/thirdparty/f5_tts/model/modules.py +658 -0
  126. xinference/thirdparty/f5_tts/model/trainer.py +366 -0
  127. xinference/thirdparty/f5_tts/model/utils.py +185 -0
  128. xinference/thirdparty/f5_tts/scripts/count_max_epoch.py +33 -0
  129. xinference/thirdparty/f5_tts/scripts/count_params_gflops.py +39 -0
  130. xinference/thirdparty/f5_tts/socket_server.py +159 -0
  131. xinference/thirdparty/f5_tts/train/README.md +77 -0
  132. xinference/thirdparty/f5_tts/train/datasets/prepare_csv_wavs.py +139 -0
  133. xinference/thirdparty/f5_tts/train/datasets/prepare_emilia.py +230 -0
  134. xinference/thirdparty/f5_tts/train/datasets/prepare_libritts.py +92 -0
  135. xinference/thirdparty/f5_tts/train/datasets/prepare_ljspeech.py +65 -0
  136. xinference/thirdparty/f5_tts/train/datasets/prepare_wenetspeech4tts.py +125 -0
  137. xinference/thirdparty/f5_tts/train/finetune_cli.py +174 -0
  138. xinference/thirdparty/f5_tts/train/finetune_gradio.py +1846 -0
  139. xinference/thirdparty/f5_tts/train/train.py +75 -0
  140. xinference/thirdparty/fish_speech/fish_speech/conversation.py +94 -83
  141. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +63 -20
  142. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +1 -26
  143. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +1 -1
  144. xinference/thirdparty/fish_speech/fish_speech/tokenizer.py +152 -0
  145. xinference/thirdparty/fish_speech/fish_speech/train.py +2 -2
  146. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1 -1
  147. xinference/thirdparty/fish_speech/tools/{post_api.py → api_client.py} +7 -13
  148. xinference/thirdparty/fish_speech/tools/api_server.py +98 -0
  149. xinference/thirdparty/fish_speech/tools/download_models.py +5 -5
  150. xinference/thirdparty/fish_speech/tools/fish_e2e.py +2 -2
  151. xinference/thirdparty/fish_speech/tools/inference_engine/__init__.py +192 -0
  152. xinference/thirdparty/fish_speech/tools/inference_engine/reference_loader.py +125 -0
  153. xinference/thirdparty/fish_speech/tools/inference_engine/utils.py +39 -0
  154. xinference/thirdparty/fish_speech/tools/inference_engine/vq_manager.py +57 -0
  155. xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +2 -2
  156. xinference/thirdparty/fish_speech/tools/llama/generate.py +117 -89
  157. xinference/thirdparty/fish_speech/tools/run_webui.py +104 -0
  158. xinference/thirdparty/fish_speech/tools/schema.py +11 -28
  159. xinference/thirdparty/fish_speech/tools/server/agent/__init__.py +57 -0
  160. xinference/thirdparty/fish_speech/tools/server/agent/generate.py +119 -0
  161. xinference/thirdparty/fish_speech/tools/server/agent/generation_utils.py +122 -0
  162. xinference/thirdparty/fish_speech/tools/server/agent/pre_generation_utils.py +72 -0
  163. xinference/thirdparty/fish_speech/tools/server/api_utils.py +75 -0
  164. xinference/thirdparty/fish_speech/tools/server/exception_handler.py +27 -0
  165. xinference/thirdparty/fish_speech/tools/server/inference.py +45 -0
  166. xinference/thirdparty/fish_speech/tools/server/model_manager.py +122 -0
  167. xinference/thirdparty/fish_speech/tools/server/model_utils.py +129 -0
  168. xinference/thirdparty/fish_speech/tools/server/views.py +246 -0
  169. xinference/thirdparty/fish_speech/tools/webui/__init__.py +173 -0
  170. xinference/thirdparty/fish_speech/tools/webui/inference.py +91 -0
  171. xinference/thirdparty/fish_speech/tools/webui/variables.py +14 -0
  172. xinference/thirdparty/matcha/utils/utils.py +2 -2
  173. xinference/thirdparty/melo/api.py +135 -0
  174. xinference/thirdparty/melo/app.py +61 -0
  175. xinference/thirdparty/melo/attentions.py +459 -0
  176. xinference/thirdparty/melo/commons.py +160 -0
  177. xinference/thirdparty/melo/configs/config.json +94 -0
  178. xinference/thirdparty/melo/data/example/metadata.list +20 -0
  179. xinference/thirdparty/melo/data_utils.py +413 -0
  180. xinference/thirdparty/melo/download_utils.py +67 -0
  181. xinference/thirdparty/melo/infer.py +25 -0
  182. xinference/thirdparty/melo/init_downloads.py +14 -0
  183. xinference/thirdparty/melo/losses.py +58 -0
  184. xinference/thirdparty/melo/main.py +36 -0
  185. xinference/thirdparty/melo/mel_processing.py +174 -0
  186. xinference/thirdparty/melo/models.py +1030 -0
  187. xinference/thirdparty/melo/modules.py +598 -0
  188. xinference/thirdparty/melo/monotonic_align/__init__.py +16 -0
  189. xinference/thirdparty/melo/monotonic_align/core.py +46 -0
  190. xinference/thirdparty/melo/preprocess_text.py +135 -0
  191. xinference/thirdparty/melo/split_utils.py +174 -0
  192. xinference/thirdparty/melo/text/__init__.py +35 -0
  193. xinference/thirdparty/melo/text/chinese.py +199 -0
  194. xinference/thirdparty/melo/text/chinese_bert.py +107 -0
  195. xinference/thirdparty/melo/text/chinese_mix.py +253 -0
  196. xinference/thirdparty/melo/text/cleaner.py +36 -0
  197. xinference/thirdparty/melo/text/cleaner_multiling.py +110 -0
  198. xinference/thirdparty/melo/text/cmudict.rep +129530 -0
  199. xinference/thirdparty/melo/text/cmudict_cache.pickle +0 -0
  200. xinference/thirdparty/melo/text/english.py +284 -0
  201. xinference/thirdparty/melo/text/english_bert.py +39 -0
  202. xinference/thirdparty/melo/text/english_utils/abbreviations.py +35 -0
  203. xinference/thirdparty/melo/text/english_utils/number_norm.py +97 -0
  204. xinference/thirdparty/melo/text/english_utils/time_norm.py +47 -0
  205. xinference/thirdparty/melo/text/es_phonemizer/base.py +140 -0
  206. xinference/thirdparty/melo/text/es_phonemizer/cleaner.py +109 -0
  207. xinference/thirdparty/melo/text/es_phonemizer/es_symbols.json +79 -0
  208. xinference/thirdparty/melo/text/es_phonemizer/es_symbols.txt +1 -0
  209. xinference/thirdparty/melo/text/es_phonemizer/es_symbols_v2.json +83 -0
  210. xinference/thirdparty/melo/text/es_phonemizer/es_to_ipa.py +12 -0
  211. xinference/thirdparty/melo/text/es_phonemizer/example_ipa.txt +400 -0
  212. xinference/thirdparty/melo/text/es_phonemizer/gruut_wrapper.py +253 -0
  213. xinference/thirdparty/melo/text/es_phonemizer/punctuation.py +174 -0
  214. xinference/thirdparty/melo/text/es_phonemizer/spanish_symbols.txt +1 -0
  215. xinference/thirdparty/melo/text/es_phonemizer/test.ipynb +124 -0
  216. xinference/thirdparty/melo/text/fr_phonemizer/base.py +140 -0
  217. xinference/thirdparty/melo/text/fr_phonemizer/cleaner.py +122 -0
  218. xinference/thirdparty/melo/text/fr_phonemizer/en_symbols.json +78 -0
  219. xinference/thirdparty/melo/text/fr_phonemizer/example_ipa.txt +1 -0
  220. xinference/thirdparty/melo/text/fr_phonemizer/fr_symbols.json +89 -0
  221. xinference/thirdparty/melo/text/fr_phonemizer/fr_to_ipa.py +30 -0
  222. xinference/thirdparty/melo/text/fr_phonemizer/french_abbreviations.py +48 -0
  223. xinference/thirdparty/melo/text/fr_phonemizer/french_symbols.txt +1 -0
  224. xinference/thirdparty/melo/text/fr_phonemizer/gruut_wrapper.py +258 -0
  225. xinference/thirdparty/melo/text/fr_phonemizer/punctuation.py +172 -0
  226. xinference/thirdparty/melo/text/french.py +94 -0
  227. xinference/thirdparty/melo/text/french_bert.py +39 -0
  228. xinference/thirdparty/melo/text/japanese.py +647 -0
  229. xinference/thirdparty/melo/text/japanese_bert.py +49 -0
  230. xinference/thirdparty/melo/text/ko_dictionary.py +44 -0
  231. xinference/thirdparty/melo/text/korean.py +192 -0
  232. xinference/thirdparty/melo/text/opencpop-strict.txt +429 -0
  233. xinference/thirdparty/melo/text/spanish.py +122 -0
  234. xinference/thirdparty/melo/text/spanish_bert.py +39 -0
  235. xinference/thirdparty/melo/text/symbols.py +290 -0
  236. xinference/thirdparty/melo/text/tone_sandhi.py +769 -0
  237. xinference/thirdparty/melo/train.py +635 -0
  238. xinference/thirdparty/melo/train.sh +19 -0
  239. xinference/thirdparty/melo/transforms.py +209 -0
  240. xinference/thirdparty/melo/utils.py +424 -0
  241. xinference/types.py +15 -0
  242. xinference/web/ui/build/asset-manifest.json +6 -6
  243. xinference/web/ui/build/index.html +1 -1
  244. xinference/web/ui/build/static/css/main.51a587ff.css +2 -0
  245. xinference/web/ui/build/static/css/main.51a587ff.css.map +1 -0
  246. xinference/web/ui/build/static/js/main.b0936c54.js +3 -0
  247. xinference/web/ui/build/static/js/main.b0936c54.js.map +1 -0
  248. xinference/web/ui/node_modules/.cache/babel-loader/03c4052f1b91f6ba0c5389bdcf49c43319b4076c08e4b8585dab312538ae290a.json +1 -0
  249. xinference/web/ui/node_modules/.cache/babel-loader/1786b83003b8e9605a0f5f855a185d4d16e38fc893dfb326a2a9cca206b4240a.json +1 -0
  250. xinference/web/ui/node_modules/.cache/babel-loader/17cbc181dd674b9150b80c73ed6a82656de0082d857f6e5f66d9716129ac0b38.json +1 -0
  251. xinference/web/ui/node_modules/.cache/babel-loader/185ceb8872d562e032b47e79df6a45670e06345b8ed70aad1a131e0476783c5c.json +1 -0
  252. xinference/web/ui/node_modules/.cache/babel-loader/26b8c9f34b0bed789b3a833767672e39302d1e0c09b4276f4d58d1df7b6bd93b.json +1 -0
  253. xinference/web/ui/node_modules/.cache/babel-loader/2b484da66c724d0d56a40849c109327408796a668b1381511b6e9e03baa48658.json +1 -0
  254. xinference/web/ui/node_modules/.cache/babel-loader/2cbbbce9b84df73330d4c42b82436ed881b3847628f2fbc346aa62e2859fd88c.json +1 -0
  255. xinference/web/ui/node_modules/.cache/babel-loader/2ec9b14431ed33ce6901bf9f27007be4e6e472709c99d6e22b50ce528e4b78ee.json +1 -0
  256. xinference/web/ui/node_modules/.cache/babel-loader/3b966db018f96be4a055d6ca205f0990d4d0b370e2980c17d8bca2c9a021819c.json +1 -0
  257. xinference/web/ui/node_modules/.cache/babel-loader/3eefb411b24c2b3ce053570ef50daccf154022f0e168be5ed0fec21394baf9f4.json +1 -0
  258. xinference/web/ui/node_modules/.cache/babel-loader/522b229e3cac219123f0d69673f5570e191c2d2a505dc65b312d336eae2279c0.json +1 -0
  259. xinference/web/ui/node_modules/.cache/babel-loader/52e45f17ba300580ea3fcc9f9228ccba194bb092b76f25e9255af311f8b05aab.json +1 -0
  260. xinference/web/ui/node_modules/.cache/babel-loader/5a0bc4631f936459afc1a3b1d3ec2420118b1f00e11f60ccac3e08088f3f27a8.json +1 -0
  261. xinference/web/ui/node_modules/.cache/babel-loader/611fa2c6c53b66039991d06dfb0473b5ab37fc63b4564e0f6e1718523768a045.json +1 -0
  262. xinference/web/ui/node_modules/.cache/babel-loader/6329bc76c406fe5eb305412383fbde5950f847bb5e43261f73f37622c365acb4.json +1 -0
  263. xinference/web/ui/node_modules/.cache/babel-loader/63c8e07687ea53a4f8a910ee5e42e0eb26cd1acbfbe820f3e3248a786ee51401.json +1 -0
  264. xinference/web/ui/node_modules/.cache/babel-loader/69b2d5001684174ec9da57e07914eed3eac4960018bceb6cbfa801d861301d7c.json +1 -0
  265. xinference/web/ui/node_modules/.cache/babel-loader/710c1acda69e561e30a933b98c6a56d50197868b15c21e2aad55ab6d46649eb6.json +1 -0
  266. xinference/web/ui/node_modules/.cache/babel-loader/720deca1fce5a1dc5056048fa8258fd138a82ea855f350b6613f104a73fb761f.json +1 -0
  267. xinference/web/ui/node_modules/.cache/babel-loader/76a23b92d26a499c57e61eea2b895fbc9771bd0849a72e66f8e633192017978b.json +1 -0
  268. xinference/web/ui/node_modules/.cache/babel-loader/858063f23b34dfe600254eb5afd85518b0002ec4b30b7386616c45600826e3b2.json +1 -0
  269. xinference/web/ui/node_modules/.cache/babel-loader/920b82c1c89124cf217109eeedbfcd3aae3b917be50c9dfb6bbb4ce26bdfd2e7.json +1 -0
  270. xinference/web/ui/node_modules/.cache/babel-loader/94d8b7aeb0076f2ce07db598cea0e87b13bc8d5614eb530b8d6e696c2daf6f88.json +1 -0
  271. xinference/web/ui/node_modules/.cache/babel-loader/9e917fe7022d01b2ccbe5cc0ce73d70bb72bee584ff293bad71bdff6695dee28.json +1 -0
  272. xinference/web/ui/node_modules/.cache/babel-loader/9f28fdb8399f1d0474f0aca86f1658dc94f5bf0c90f6146352de150692de8862.json +1 -0
  273. xinference/web/ui/node_modules/.cache/babel-loader/a0dfafa06b2bb7cba8cad41c482503f61944f759f4318139362602ef5cc47ccb.json +1 -0
  274. xinference/web/ui/node_modules/.cache/babel-loader/a3ff866acddf34917a7ee399e0e571a4dfd8ba66d5057db885f243e16a6eb17d.json +1 -0
  275. xinference/web/ui/node_modules/.cache/babel-loader/afb8084f539534cd594755ea2205ecd5bd1f62dddcfdf75a2eace59a28131278.json +1 -0
  276. xinference/web/ui/node_modules/.cache/babel-loader/b57b1438b77294c1f3f6cfce12ac487d8106c6f016975ba0aec94d98997e2e1e.json +1 -0
  277. xinference/web/ui/node_modules/.cache/babel-loader/b9917b0bf8e4d55ccbac1c334aa04d6ff3c5b6ed9e5d38b9ea2c687fa7d3f5a9.json +1 -0
  278. xinference/web/ui/node_modules/.cache/babel-loader/bbcc94b0149963d1d6f267ee1f4f03d3925b758392ce2f516c3fe8af0e0169fc.json +1 -0
  279. xinference/web/ui/node_modules/.cache/babel-loader/bdee44abeadc4abc17d41c52eb49c6e19a4b1a267b6e16876ce91bdeeebfc52d.json +1 -0
  280. xinference/web/ui/node_modules/.cache/babel-loader/beb112b70f4a56db95920a9e20efb6c97c37b68450716730217a9ee1a9ae92be.json +1 -0
  281. xinference/web/ui/node_modules/.cache/babel-loader/c88db97be0cdf440193b3995996e83510a04cb00048135485fc0e26d197e80b5.json +1 -0
  282. xinference/web/ui/node_modules/.cache/babel-loader/d49e5314d34310a62d01a03067ce1bec5da00abce84c5196aa9c6842fa79a430.json +1 -0
  283. xinference/web/ui/node_modules/.cache/babel-loader/d7664d18c4ddbad9c3a6a31b91f7c00fb0dde804608674a9860ee50f33e54708.json +1 -0
  284. xinference/web/ui/node_modules/.cache/babel-loader/d9072c318b819b7c90a0f7e9cc0b6413b4dbeb8e9859898e53d75ea882fcde99.json +1 -0
  285. xinference/web/ui/node_modules/.cache/babel-loader/db16a983bc08a05f0439cc61ca0840e49e1d8400eef678909f16c032a418a3d6.json +1 -0
  286. xinference/web/ui/node_modules/.cache/babel-loader/dc249829767b8abcbc3677e0b07b6d3ecbfdfe6d08cfe23a665eb33373a9aa9d.json +1 -0
  287. xinference/web/ui/node_modules/.cache/babel-loader/e242c583c2dbc2784f0fcf513523975f7d5df447e106c1c17e49e8578a6fc3ed.json +1 -0
  288. xinference/web/ui/node_modules/.cache/babel-loader/eac5f1296513e69e4b96f750ddccd4d0264e2bae4e4c449144e83274a48698d9.json +1 -0
  289. xinference/web/ui/node_modules/.cache/babel-loader/ed57202cb79649bb716400436590245547df241988fc7c8e1d85d132299542d2.json +1 -0
  290. xinference/web/ui/node_modules/.cache/babel-loader/f125bf72e773a14cdaebd0c343e80adb909d12e317ee5c00cd4a57442fbe2c62.json +1 -0
  291. xinference/web/ui/node_modules/.cache/babel-loader/f91af913d7f91c410719ab13136aaed3aaf0f8dda06652f25c42cb5231587398.json +1 -0
  292. xinference/web/ui/node_modules/.package-lock.json +67 -3
  293. xinference/web/ui/node_modules/@babel/runtime/package.json +592 -538
  294. xinference/web/ui/node_modules/html-parse-stringify/package.json +50 -0
  295. xinference/web/ui/node_modules/i18next/dist/esm/package.json +1 -0
  296. xinference/web/ui/node_modules/i18next/package.json +129 -0
  297. xinference/web/ui/node_modules/react-i18next/.eslintrc.json +74 -0
  298. xinference/web/ui/node_modules/react-i18next/dist/es/package.json +1 -0
  299. xinference/web/ui/node_modules/react-i18next/package.json +162 -0
  300. xinference/web/ui/node_modules/void-elements/package.json +34 -0
  301. xinference/web/ui/package-lock.json +69 -3
  302. xinference/web/ui/package.json +2 -0
  303. xinference/web/ui/src/locales/en.json +186 -0
  304. xinference/web/ui/src/locales/zh.json +186 -0
  305. {xinference-1.0.1.dist-info → xinference-1.2.1.dist-info}/METADATA +68 -32
  306. {xinference-1.0.1.dist-info → xinference-1.2.1.dist-info}/RECORD +316 -122
  307. xinference/thirdparty/cosyvoice/bin/export_trt.py +0 -8
  308. xinference/thirdparty/fish_speech/tools/api.py +0 -943
  309. xinference/thirdparty/fish_speech/tools/msgpack_api.py +0 -95
  310. xinference/thirdparty/fish_speech/tools/webui.py +0 -548
  311. xinference/web/ui/build/static/css/main.5061c4c3.css +0 -2
  312. xinference/web/ui/build/static/css/main.5061c4c3.css.map +0 -1
  313. xinference/web/ui/build/static/js/main.2f269bb3.js +0 -3
  314. xinference/web/ui/build/static/js/main.2f269bb3.js.map +0 -1
  315. xinference/web/ui/node_modules/.cache/babel-loader/07ce9e632e6aff24d7aa3ad8e48224433bbfeb0d633fca723453f1fcae0c9f1c.json +0 -1
  316. xinference/web/ui/node_modules/.cache/babel-loader/1130403f9e46f5738a23b45ac59b57de8f360c908c713e2c0670c2cce9bd367a.json +0 -1
  317. xinference/web/ui/node_modules/.cache/babel-loader/131091b25d26b17cdca187d7542a21475c211138d900cf667682260e76ef9463.json +0 -1
  318. xinference/web/ui/node_modules/.cache/babel-loader/1f269fb2a368363c1cb2237825f1dba093b6bdd8c44cc05954fd19ec2c1fff03.json +0 -1
  319. xinference/web/ui/node_modules/.cache/babel-loader/331312668fa8bd3d7401818f4a25fa98135d7f61371cd6bfff78b18cf4fbdd92.json +0 -1
  320. xinference/web/ui/node_modules/.cache/babel-loader/40f17338fc75ae095de7d2b4d8eae0d5ca0193a7e2bcece4ee745b22a7a2f4b7.json +0 -1
  321. xinference/web/ui/node_modules/.cache/babel-loader/4de9a6942c5f1749d6cbfdd54279699975f16016b182848bc253886f52ec2ec3.json +0 -1
  322. xinference/web/ui/node_modules/.cache/babel-loader/822586ed1077201b64b954f12f25e3f9b45678c1acbabe53d8af3ca82ca71f33.json +0 -1
  323. xinference/web/ui/node_modules/.cache/babel-loader/8d33354bd2100c8602afc3341f131a88cc36aaeecd5a4b365ed038514708e350.json +0 -1
  324. xinference/web/ui/node_modules/.cache/babel-loader/9375a35b05d56989b2755bf72161fa707c92f28569d33765a75f91a568fda6e9.json +0 -1
  325. xinference/web/ui/node_modules/.cache/babel-loader/a158a9ffa0c9b169aee53dd4a0c44501a596755b4e4f6ede7746d65a72e2a71f.json +0 -1
  326. xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +0 -1
  327. xinference/web/ui/node_modules/.cache/babel-loader/c7bf40bab396765f67d0fed627ed3665890608b2d0edaa3e8cb7cfc96310db45.json +0 -1
  328. xinference/web/ui/node_modules/.cache/babel-loader/d6c643278a0b28320e6f33a60f5fb64c053997cbdc39a60e53ccc574688ade9e.json +0 -1
  329. xinference/web/ui/node_modules/.cache/babel-loader/e42b72d4cc1ea412ebecbb8d040dc6c6bfee462c33903c2f1f3facb602ad742e.json +0 -1
  330. xinference/web/ui/node_modules/.cache/babel-loader/e64b7e8cedcf43d4c95deba60ec1341855c887705805bb62431693118b870c69.json +0 -1
  331. xinference/web/ui/node_modules/.cache/babel-loader/f5039ddbeb815c51491a1989532006b96fc3ae49c6c60e3c097f875b4ae915ae.json +0 -1
  332. xinference/web/ui/node_modules/.cache/babel-loader/f72f011744c4649fabddca6f7a9327861ac0a315a89b1a2e62a39774e7863845.json +0 -1
  333. xinference/web/ui/node_modules/.cache/babel-loader/feabb04b4aa507102da0a64398a40818e878fd1df9b75dda8461b3e1e7ff3f11.json +0 -1
  334. /xinference/thirdparty/{cosyvoice/bin → f5_tts}/__init__.py +0 -0
  335. /xinference/thirdparty/{cosyvoice/flow → melo}/__init__.py +0 -0
  336. /xinference/thirdparty/{cosyvoice/hifigan → melo/text/english_utils}/__init__.py +0 -0
  337. /xinference/thirdparty/{cosyvoice/llm → melo/text/es_phonemizer}/__init__.py +0 -0
  338. /xinference/thirdparty/{fish_speech/tools → melo/text/fr_phonemizer}/__init__.py +0 -0
  339. /xinference/web/ui/build/static/js/{main.2f269bb3.js.LICENSE.txt → main.b0936c54.js.LICENSE.txt} +0 -0
  340. {xinference-1.0.1.dist-info → xinference-1.2.1.dist-info}/LICENSE +0 -0
  341. {xinference-1.0.1.dist-info → xinference-1.2.1.dist-info}/WHEEL +0 -0
  342. {xinference-1.0.1.dist-info → xinference-1.2.1.dist-info}/entry_points.txt +0 -0
  343. {xinference-1.0.1.dist-info → xinference-1.2.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,152 @@
1
+ import base64
2
+ import json
3
+ import logging
4
+ from pathlib import Path
5
+
6
+ import tiktoken
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
# This is a modified version of the default pattern from GPT-4o, that better
# handles punctuations.
# NOTE: the \p{...} property classes are not supported by Python's stdlib
# `re`; tiktoken compiles this with the Rust `fancy-regex` engine.
FISH_TIKTOKEN_PATTERN = "|".join(
    [
        r"(?i:'s|'t|'re|'ve|'m|'ll|'d)",
        r"\p{P}",
        r"[^\r\n\p{L}\p{N}]?\p{L}+",
        r"\p{N}",
        r" ?[^\s\p{L}\p{N}]+[\r\n]*",
        r"\s*[\r\n]+",
        # FIX: was r"\s+(\?!\S)", i.e. a capturing group matching a literal
        # "?!".  The GPT-4o reference pattern uses a negative lookahead so a
        # run of trailing whitespace (not followed by a non-space) is split
        # off separately from inter-word whitespace.
        r"\s+(?!\S)",
        r"\s+",
    ]
)
# Inputs longer than this are encoded in chunks to bound regex/BPE work.
TIKTOKEN_MAX_ENCODE_CHARS = 400_000

BOS_TOKEN = "<|begin_of_text|>"
EOS_TOKEN = "<|end_of_text|>"
PAD_TOKEN = "<|pad|>"
IM_START_TOKEN = "<|im_start|>"
IM_END_TOKEN = "<|im_end|>"

MODALITY_TEXT_TOKEN = "<|text|>"
MODALITY_VOICE_TOKEN = "<|voice|>"
MODALITY_INTERLEAVE_TOKEN = "<|interleave|>"
# Lookup from modality name to its control token.
MODALITY_TOKENS = {
    "text": MODALITY_TEXT_TOKEN,
    "voice": MODALITY_VOICE_TOKEN,
    "interleave": MODALITY_INTERLEAVE_TOKEN,
}

# Reserved slots for future special tokens (comprehension avoids leaking a
# module-level loop variable, unlike the original index-assignment loop).
PLACEHOLDER_TOKEN = [f"<|placeholder:{i}|>" for i in range(4)]

SEMANTIC_TOKEN_TEMPLATE = "<|semantic:{i}|>"
SEMANTIC_TOKENS = [SEMANTIC_TOKEN_TEMPLATE.format(i=i) for i in range(1024)]

# Warning: when you add a new special token, you should only add it to the
# end of the list — ids are assigned by position after the BPE ranks, so
# reordering would silently remap every checkpoint's special-token ids.
ALL_SPECIAL_TOKENS = [
    BOS_TOKEN,
    EOS_TOKEN,
    PAD_TOKEN,
    IM_START_TOKEN,
    IM_END_TOKEN,
    PLACEHOLDER_TOKEN[0],
    PLACEHOLDER_TOKEN[1],
    PLACEHOLDER_TOKEN[2],
    PLACEHOLDER_TOKEN[3],
    MODALITY_TEXT_TOKEN,
    MODALITY_VOICE_TOKEN,
    MODALITY_INTERLEAVE_TOKEN,
    *SEMANTIC_TOKENS,
]
63
+
64
+
65
+ class FishTokenizer:
66
+ def __init__(self, model_path: str) -> None:
67
+ mergeable_ranks = self.load_tiktoken_bpe(model_path)
68
+ special_token_begin = len(mergeable_ranks)
69
+ self.all_special_tokens_with_ids = {
70
+ token: special_token_begin + i for i, token in enumerate(ALL_SPECIAL_TOKENS)
71
+ }
72
+ self.semantic_id_to_token_id = {
73
+ i: self.all_special_tokens_with_ids[token]
74
+ for i, token in enumerate(SEMANTIC_TOKENS)
75
+ }
76
+ self.semantic_begin_id = self.all_special_tokens_with_ids[SEMANTIC_TOKENS[0]]
77
+ self.semantic_end_id = self.all_special_tokens_with_ids[SEMANTIC_TOKENS[-1]]
78
+
79
+ self.tkt_model = tiktoken.core.Encoding(
80
+ name=Path(model_path).stem,
81
+ pat_str=FISH_TIKTOKEN_PATTERN,
82
+ mergeable_ranks=mergeable_ranks,
83
+ special_tokens=self.all_special_tokens_with_ids,
84
+ )
85
+
86
+ @staticmethod
87
+ def load_tiktoken_bpe(tiktoken_bpe_file: str) -> dict[bytes, int]:
88
+ data = {}
89
+ for line in open(tiktoken_bpe_file).read().splitlines():
90
+ if not line:
91
+ continue
92
+ token, rank = line.split()
93
+ data[base64.b64decode(token)] = int(rank)
94
+ return data
95
+
96
+ def get_token_id(self, token: str) -> int:
97
+ return self.all_special_tokens_with_ids[token]
98
+
99
+ def encode(self, s: str, allowed_special: bool | set[str] = True) -> list[int]:
100
+ assert isinstance(s, str)
101
+
102
+ subs = []
103
+ for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS):
104
+ subs.append(s[i : i + TIKTOKEN_MAX_ENCODE_CHARS])
105
+
106
+ if allowed_special is True:
107
+ allowed_special = self.tkt_model.special_tokens_set
108
+ elif allowed_special is False:
109
+ allowed_special = set()
110
+
111
+ return sum(
112
+ self.tkt_model.encode_batch(
113
+ subs, allowed_special=allowed_special, disallowed_special=set()
114
+ ),
115
+ start=[],
116
+ )
117
+
118
+ def decode(self, tokens: list[int]) -> str:
119
+ return self.tkt_model.decode(tokens)
120
+
121
+ def save_pretrained(self, path: str):
122
+ path = Path(path)
123
+ path.mkdir(parents=True, exist_ok=True)
124
+
125
+ with open(path / "tokenizer.tiktoken", "w") as f:
126
+ for token, rank in self.tkt_model._mergeable_ranks.items():
127
+ f.write(f"{base64.b64encode(token).decode()} {rank}\n")
128
+
129
+ with open(path / "special_tokens.json", "w") as f:
130
+ json.dump(
131
+ self.all_special_tokens_with_ids,
132
+ f,
133
+ indent=2,
134
+ ensure_ascii=False,
135
+ )
136
+
137
+ @staticmethod
138
+ def from_pretrained(path: str):
139
+ return FishTokenizer(Path(path) / "tokenizer.tiktoken")
140
+
141
+
142
if __name__ == "__main__":
    # Round-trip smoke test: build a tokenizer from a raw tiktoken file,
    # save it, reload the saved copy, and show the token-by-token decoding
    # of a short bilingual sample.
    tokenizer = FishTokenizer("data/mpacks/v1.4-pretrain/tokenizer.all.tiktoken")
    tokenizer.save_pretrained("checkpoints/fish-speech-0.5B")
    tokenizer = FishTokenizer.from_pretrained("checkpoints/fish-speech-0.5B")

    sample = f"{BOS_TOKEN}你好,世界!{EOS_TOKEN}"
    pieces = []
    for token_id in tokenizer.encode(sample):
        pieces.append(tokenizer.decode([token_id]))
    print(pieces)
@@ -6,7 +6,7 @@ from typing import Optional
6
6
 
7
7
  import hydra
8
8
  import lightning as L
9
- # import pyrootutils
9
+ import pyrootutils
10
10
  import torch
11
11
  from lightning import Callback, LightningDataModule, LightningModule, Trainer
12
12
  from lightning.pytorch.loggers import Logger
@@ -18,7 +18,7 @@ os.environ.pop("SLURM_JOB_NAME", None)
18
18
  os.environ.pop("SLURM_NTASKS_PER_NODE", None)
19
19
 
20
20
  # register eval resolver and root
21
- # pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
21
+ pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
22
22
 
23
23
  # Allow TF32 on Ampere GPUs
24
24
  torch.set_float32_matmul_precision("high")
@@ -176,7 +176,7 @@ def change_infer(
176
176
  p_infer = subprocess.Popen(
177
177
  [
178
178
  PYTHON,
179
- "tools/webui.py",
179
+ "tools/run_webui.py",
180
180
  "--decoder-checkpoint-path",
181
181
  infer_decoder_model,
182
182
  "--decoder-config-name",
@@ -69,10 +69,6 @@ def parse_args():
69
69
  parser.add_argument(
70
70
  "--format", type=str, choices=["wav", "mp3", "flac"], default="wav"
71
71
  )
72
- parser.add_argument(
73
- "--mp3_bitrate", type=int, choices=[64, 128, 192], default=64, help="kHz"
74
- )
75
- parser.add_argument("--opus_bitrate", type=int, default=-1000)
76
72
  parser.add_argument(
77
73
  "--latency",
78
74
  type=str,
@@ -83,7 +79,7 @@ def parse_args():
83
79
  parser.add_argument(
84
80
  "--max_new_tokens",
85
81
  type=int,
86
- default=0,
82
+ default=1024,
87
83
  help="Maximum new tokens to generate. \n0 means no limit.",
88
84
  )
89
85
  parser.add_argument(
@@ -112,11 +108,9 @@ def parse_args():
112
108
  parser.add_argument(
113
109
  "--use_memory_cache",
114
110
  type=str,
115
- default="never",
116
- choices=["on-demand", "never"],
117
- help="Cache encoded references codes in memory.\n"
118
- "If `on-demand`, the server will use cached encodings\n "
119
- "instead of encoding reference audio again.",
111
+ default="off",
112
+ choices=["on", "off"],
113
+ help="Cache encoded references codes in memory.\n",
120
114
  )
121
115
  parser.add_argument(
122
116
  "--seed",
@@ -154,14 +148,14 @@ if __name__ == "__main__":
154
148
  data = {
155
149
  "text": args.text,
156
150
  "references": [
157
- ServeReferenceAudio(audio=ref_audio, text=ref_text)
151
+ ServeReferenceAudio(
152
+ audio=ref_audio if ref_audio is not None else b"", text=ref_text
153
+ )
158
154
  for ref_text, ref_audio in zip(ref_texts, byte_audios)
159
155
  ],
160
156
  "reference_id": idstr,
161
157
  "normalize": args.normalize,
162
158
  "format": args.format,
163
- "mp3_bitrate": args.mp3_bitrate,
164
- "opus_bitrate": args.opus_bitrate,
165
159
  "max_new_tokens": args.max_new_tokens,
166
160
  "chunk_length": args.chunk_length,
167
161
  "top_p": args.top_p,
@@ -0,0 +1,98 @@
1
+ from threading import Lock
2
+
3
+ import pyrootutils
4
+ import uvicorn
5
+ from kui.asgi import FactoryClass, HTTPException, HttpRoute, Kui, OpenAPI, Routes
6
+ from loguru import logger
7
+
8
+ pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
9
+
10
+ from tools.server.api_utils import MsgPackRequest, parse_args
11
+ from tools.server.exception_handler import ExceptionHandler
12
+ from tools.server.model_manager import ModelManager
13
+ from tools.server.views import (
14
+ ASRView,
15
+ ChatView,
16
+ HealthView,
17
+ TTSView,
18
+ VQGANDecodeView,
19
+ VQGANEncodeView,
20
+ )
21
+
22
+
23
+ class API(ExceptionHandler):
24
+ def __init__(self):
25
+ self.args = parse_args()
26
+ self.routes = [
27
+ ("/v1/health", HealthView),
28
+ ("/v1/vqgan/encode", VQGANEncodeView),
29
+ ("/v1/vqgan/decode", VQGANDecodeView),
30
+ ("/v1/asr", ASRView),
31
+ ("/v1/tts", TTSView),
32
+ ("/v1/chat", ChatView),
33
+ ]
34
+ self.routes = Routes([HttpRoute(path, view) for path, view in self.routes])
35
+
36
+ self.openapi = OpenAPI(
37
+ {
38
+ "title": "Fish Speech API",
39
+ "version": "1.5.0",
40
+ },
41
+ ).routes
42
+
43
+ # Initialize the app
44
+ self.app = Kui(
45
+ routes=self.routes + self.openapi[1:], # Remove the default route
46
+ exception_handlers={
47
+ HTTPException: self.http_exception_handler,
48
+ Exception: self.other_exception_handler,
49
+ },
50
+ factory_class=FactoryClass(http=MsgPackRequest),
51
+ cors_config={},
52
+ )
53
+
54
+ # Add the state variables
55
+ self.app.state.lock = Lock()
56
+ self.app.state.device = self.args.device
57
+ self.app.state.max_text_length = self.args.max_text_length
58
+
59
+ # Associate the app with the model manager
60
+ self.app.on_startup(self.initialize_app)
61
+
62
+ async def initialize_app(self, app: Kui):
63
+ # Make the ModelManager available to the views
64
+ app.state.model_manager = ModelManager(
65
+ mode=self.args.mode,
66
+ device=self.args.device,
67
+ half=self.args.half,
68
+ compile=self.args.compile,
69
+ asr_enabled=self.args.load_asr_model,
70
+ llama_checkpoint_path=self.args.llama_checkpoint_path,
71
+ decoder_checkpoint_path=self.args.decoder_checkpoint_path,
72
+ decoder_config_name=self.args.decoder_config_name,
73
+ )
74
+
75
+ logger.info(f"Startup done, listening server at http://{self.args.listen}")
76
+
77
+
78
+ # Each worker process created by Uvicorn has its own memory space,
79
+ # meaning that models and variables are not shared between processes.
80
+ # Therefore, any variables (like `llama_queue` or `decoder_model`)
81
+ # will not be shared across workers.
82
+
83
+ # Multi-threading for deep learning can cause issues, such as inconsistent
84
+ # outputs if multiple threads access the same buffers simultaneously.
85
+ # Instead, it's better to use multiprocessing or independent models per thread.
86
+
87
+ if __name__ == "__main__":
88
+
89
+ api = API()
90
+ host, port = api.args.listen.split(":")
91
+
92
+ uvicorn.run(
93
+ api.app,
94
+ host=host,
95
+ port=int(port),
96
+ workers=api.args.workers,
97
+ log_level="info",
98
+ )
@@ -22,14 +22,14 @@ def check_and_download_files(repo_id, file_list, local_dir):
22
22
 
23
23
 
24
24
  # 1st
25
- repo_id_1 = "fishaudio/fish-speech-1.4"
26
- local_dir_1 = "./checkpoints/fish-speech-1.4"
25
+ repo_id_1 = "fishaudio/fish-speech-1.5"
26
+ local_dir_1 = "./checkpoints/fish-speech-1.5"
27
27
  files_1 = [
28
+ "gitattributes",
28
29
  "model.pth",
29
30
  "README.md",
30
- "special_tokens_map.json",
31
- "tokenizer_config.json",
32
- "tokenizer.json",
31
+ "special_tokens.json",
32
+ "tokenizer.tiktoken",
33
33
  "config.json",
34
34
  "firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
35
35
  ]
@@ -14,8 +14,8 @@ import ormsgpack
14
14
  import soundfile as sf
15
15
 
16
16
  from .schema import (
17
+ ServeChatRequest,
17
18
  ServeMessage,
18
- ServeRequest,
19
19
  ServeTextPart,
20
20
  ServeVQGANDecodeRequest,
21
21
  ServeVQGANEncodeRequest,
@@ -163,7 +163,7 @@ class FishE2EAgent:
163
163
  else:
164
164
  user_codes = None
165
165
 
166
- request = ServeRequest(
166
+ request = ServeChatRequest(
167
167
  messages=prev_messages
168
168
  + (
169
169
  [
@@ -0,0 +1,192 @@
1
+ import gc
2
+ import queue
3
+ from typing import Generator
4
+
5
+ import numpy as np
6
+ import torch
7
+ from loguru import logger
8
+
9
+ from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
10
+ from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
11
+ from fish_speech.utils import autocast_exclude_mps, set_seed
12
+ from tools.inference_engine.reference_loader import ReferenceLoader
13
+ from tools.inference_engine.utils import InferenceResult, wav_chunk_header
14
+ from tools.inference_engine.vq_manager import VQManager
15
+ from tools.llama.generate import (
16
+ GenerateRequest,
17
+ GenerateResponse,
18
+ WrappedGenerateResponse,
19
+ )
20
+ from tools.schema import ServeTTSRequest
21
+
22
+
23
+ class TTSInferenceEngine(ReferenceLoader, VQManager):
24
+
25
+ def __init__(
26
+ self,
27
+ llama_queue: queue.Queue,
28
+ decoder_model: FireflyArchitecture,
29
+ precision: torch.dtype,
30
+ compile: bool,
31
+ ) -> None:
32
+
33
+ super().__init__()
34
+
35
+ self.llama_queue = llama_queue
36
+ self.decoder_model = decoder_model
37
+ self.precision = precision
38
+ self.compile = compile
39
+
40
+ @torch.inference_mode()
41
+ def inference(self, req: ServeTTSRequest) -> Generator[InferenceResult, None, None]:
42
+ """
43
+ Main inference function:
44
+ - Loads the reference audio and text.
45
+ - Calls the LLAMA model for inference.
46
+ - Decodes the VQ tokens to audio.
47
+ """
48
+
49
+ ref_id: str | None = req.reference_id
50
+ prompt_tokens, prompt_texts = [], []
51
+ # Load the reference audio and text based on id or hash
52
+ if ref_id is not None:
53
+ prompt_tokens, prompt_texts = self.load_by_id(ref_id, req.use_memory_cache)
54
+
55
+ elif req.references:
56
+ prompt_tokens, prompt_texts = self.load_by_hash(
57
+ req.references, req.use_memory_cache
58
+ )
59
+
60
+ # Set the random seed if provided
61
+ if req.seed is not None:
62
+ set_seed(req.seed)
63
+ logger.warning(f"set seed: {req.seed}")
64
+
65
+ # Get the symbolic tokens from the LLAMA model
66
+ response_queue = self.send_Llama_request(req, prompt_tokens, prompt_texts)
67
+
68
+ # Get the sample rate from the decoder model
69
+ sample_rate = self.decoder_model.spec_transform.sample_rate
70
+
71
+ # If streaming, send the header
72
+ # if req.streaming:
73
+ # yield InferenceResult(
74
+ # code="header",
75
+ # audio=(sample_rate, wav_chunk_header(sample_rate=sample_rate)),
76
+ # error=None,
77
+ # )
78
+
79
+ segments = []
80
+
81
+ while True:
82
+ # Get the response from the LLAMA model
83
+ wrapped_result: WrappedGenerateResponse = response_queue.get()
84
+ if wrapped_result.status == "error":
85
+ yield InferenceResult(
86
+ code="error",
87
+ audio=None,
88
+ error=(
89
+ wrapped_result.response
90
+ if isinstance(wrapped_result.response, Exception)
91
+ else Exception("Unknown error")
92
+ ),
93
+ )
94
+ break
95
+
96
+ # Check the response type
97
+ if not isinstance(wrapped_result.response, GenerateResponse):
98
+ raise TypeError(
99
+ "Expected GenerateResponse, got {type(wrapped_result.response).__name__}"
100
+ )
101
+
102
+ result: GenerateResponse = wrapped_result.response
103
+ if result.action != "next":
104
+ segment = self.get_audio_segment(result)
105
+
106
+ if req.streaming: # Used only by the API server
107
+ yield InferenceResult(
108
+ code="segment",
109
+ audio=(sample_rate, segment),
110
+ error=None,
111
+ )
112
+ segments.append(segment)
113
+ else:
114
+ break
115
+
116
+ # Clean up the memory
117
+ if torch.cuda.is_available():
118
+ torch.cuda.empty_cache()
119
+ gc.collect()
120
+
121
+ # Edge case: no audio generated
122
+ if len(segments) == 0:
123
+ yield InferenceResult(
124
+ code="error",
125
+ audio=None,
126
+ error=RuntimeError("No audio generated, please check the input text."),
127
+ )
128
+ else:
129
+ # Streaming or not, return the final audio
130
+ audio = np.concatenate(segments, axis=0)
131
+ yield InferenceResult(
132
+ code="final",
133
+ audio=(sample_rate, audio),
134
+ error=None,
135
+ )
136
+
137
+ return None
138
+
139
+ def send_Llama_request(
140
+ self, req: ServeTTSRequest, prompt_tokens: list, prompt_texts: list
141
+ ) -> queue.Queue:
142
+ """
143
+ Send a request to the LLAMA model to generate the symbolic tokens.
144
+ """
145
+
146
+ # Prepare the request
147
+ request = dict(
148
+ device=self.decoder_model.device,
149
+ max_new_tokens=req.max_new_tokens,
150
+ text=(
151
+ req.text
152
+ if not req.normalize
153
+ else ChnNormedText(raw_text=req.text).normalize()
154
+ ),
155
+ top_p=req.top_p,
156
+ repetition_penalty=req.repetition_penalty,
157
+ temperature=req.temperature,
158
+ compile=self.compile,
159
+ iterative_prompt=req.chunk_length > 0,
160
+ chunk_length=req.chunk_length,
161
+ max_length=4096,
162
+ prompt_tokens=prompt_tokens,
163
+ prompt_text=prompt_texts,
164
+ )
165
+
166
+ # Create a queue to get the response
167
+ response_queue = queue.Queue()
168
+
169
+ # Send the request to the LLAMA model
170
+ self.llama_queue.put(
171
+ GenerateRequest(
172
+ request=request,
173
+ response_queue=response_queue,
174
+ )
175
+ )
176
+
177
+ return response_queue
178
+
179
+ def get_audio_segment(self, result: GenerateResponse) -> np.ndarray:
180
+ """
181
+ Decode the VQ tokens to audio.
182
+ """
183
+
184
+ # Don't use autocast on MPS devices
185
+ with autocast_exclude_mps(
186
+ device_type=self.decoder_model.device.type, dtype=self.precision
187
+ ):
188
+ # Decode the symbolic tokens to audio
189
+ segment = self.decode_vq_tokens(codes=result.codes)
190
+
191
+ # Convert the audio to numpy
192
+ return segment.float().cpu().numpy()
@@ -0,0 +1,125 @@
1
+ import io
2
+ from hashlib import sha256
3
+ from pathlib import Path
4
+ from typing import Callable, Literal, Tuple
5
+
6
+ import torch
7
+ import torchaudio
8
+ from loguru import logger
9
+
10
+ from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
11
+ from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, read_ref_text
12
+ from tools.schema import ServeReferenceAudio
13
+
14
+
15
+ class ReferenceLoader:
16
+
17
+ def __init__(self) -> None:
18
+ """
19
+ Component of the TTSInferenceEngine class.
20
+ Loads and manages the cache for the reference audio and text.
21
+ """
22
+ self.ref_by_id: dict = {}
23
+ self.ref_by_hash: dict = {}
24
+
25
+ # Make Pylance happy (attribut/method not defined...)
26
+ self.decoder_model: FireflyArchitecture
27
+ self.encode_reference: Callable
28
+
29
+ # Define the torchaudio backend
30
+ backends = torchaudio.list_audio_backends()
31
+ if "ffmpeg" in backends:
32
+ self.backend = "ffmpeg"
33
+ else:
34
+ self.backend = "soundfile"
35
+
36
+ def load_by_id(
37
+ self,
38
+ id: str,
39
+ use_cache: Literal["on", "off"],
40
+ ) -> Tuple:
41
+
42
+ # Load the references audio and text by id
43
+ ref_folder = Path("references") / id
44
+ ref_folder.mkdir(parents=True, exist_ok=True)
45
+ ref_audios = list_files(
46
+ ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False
47
+ )
48
+
49
+ if use_cache == "off" or id not in self.ref_by_id:
50
+ # If the references are not already loaded, encode them
51
+ prompt_tokens = [
52
+ self.encode_reference(
53
+ # decoder_model=self.decoder_model,
54
+ reference_audio=audio_to_bytes(str(ref_audio)),
55
+ enable_reference_audio=True,
56
+ )
57
+ for ref_audio in ref_audios
58
+ ]
59
+ prompt_texts = [
60
+ read_ref_text(str(ref_audio.with_suffix(".lab")))
61
+ for ref_audio in ref_audios
62
+ ]
63
+ self.ref_by_id[id] = (prompt_tokens, prompt_texts)
64
+
65
+ else:
66
+ # Reuse already encoded references
67
+ logger.info("Use same references")
68
+ prompt_tokens, prompt_texts = self.ref_by_id[id]
69
+
70
+ return prompt_tokens, prompt_texts
71
+
72
+ def load_by_hash(
73
+ self,
74
+ references: list[ServeReferenceAudio],
75
+ use_cache: Literal["on", "off"],
76
+ ) -> Tuple:
77
+
78
+ # Load the references audio and text by hash
79
+ audio_hashes = [sha256(ref.audio).hexdigest() for ref in references]
80
+
81
+ cache_used = False
82
+ prompt_tokens, prompt_texts = [], []
83
+ for i, ref in enumerate(references):
84
+ if use_cache == "off" or audio_hashes[i] not in self.ref_by_hash:
85
+ # If the references are not already loaded, encode them
86
+ prompt_tokens.append(
87
+ self.encode_reference(
88
+ reference_audio=ref.audio,
89
+ enable_reference_audio=True,
90
+ )
91
+ )
92
+ prompt_texts.append(ref.text)
93
+ self.ref_by_hash[audio_hashes[i]] = (prompt_tokens, prompt_texts)
94
+
95
+ else:
96
+ # Reuse already encoded references
97
+ prompt_tokens, prompt_texts = self.ref_by_hash[audio_hashes[i]]
98
+ cache_used = True
99
+
100
+ if cache_used:
101
+ logger.info("Use same references")
102
+
103
+ return prompt_tokens, prompt_texts
104
+
105
+ def load_audio(self, reference_audio, sr):
106
+ """
107
+ Load the audio data from a file or bytes.
108
+ """
109
+ if len(reference_audio) > 255 or not Path(reference_audio).exists():
110
+ audio_data = reference_audio
111
+ reference_audio = io.BytesIO(audio_data)
112
+
113
+ waveform, original_sr = torchaudio.load(reference_audio, backend=self.backend)
114
+
115
+ if waveform.shape[0] > 1:
116
+ waveform = torch.mean(waveform, dim=0, keepdim=True)
117
+
118
+ if original_sr != sr:
119
+ resampler = torchaudio.transforms.Resample(
120
+ orig_freq=original_sr, new_freq=sr
121
+ )
122
+ waveform = resampler(waveform)
123
+
124
+ audio = waveform.squeeze().numpy()
125
+ return audio