xinference 0.16.3__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff shows the changes between publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.

Potentially problematic release.

Files changed (373)
  1. xinference/_compat.py +24 -2
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +219 -77
  4. xinference/client/restful/restful_client.py +47 -2
  5. xinference/constants.py +1 -0
  6. xinference/core/chat_interface.py +6 -1
  7. xinference/core/model.py +124 -34
  8. xinference/core/supervisor.py +180 -12
  9. xinference/core/utils.py +73 -4
  10. xinference/core/worker.py +102 -4
  11. xinference/deploy/cmdline.py +3 -1
  12. xinference/deploy/test/test_cmdline.py +56 -0
  13. xinference/isolation.py +24 -0
  14. xinference/model/audio/__init__.py +12 -0
  15. xinference/model/audio/core.py +37 -4
  16. xinference/model/audio/cosyvoice.py +39 -6
  17. xinference/model/audio/f5tts.py +200 -0
  18. xinference/model/audio/f5tts_mlx.py +260 -0
  19. xinference/model/audio/fish_speech.py +70 -110
  20. xinference/model/audio/melotts.py +110 -0
  21. xinference/model/audio/model_spec.json +179 -3
  22. xinference/model/audio/model_spec_modelscope.json +27 -0
  23. xinference/model/audio/utils.py +32 -0
  24. xinference/model/audio/whisper.py +35 -10
  25. xinference/model/audio/whisper_mlx.py +208 -0
  26. xinference/model/embedding/core.py +322 -6
  27. xinference/model/embedding/model_spec.json +8 -1
  28. xinference/model/embedding/model_spec_modelscope.json +9 -1
  29. xinference/model/image/core.py +69 -1
  30. xinference/model/image/model_spec.json +145 -4
  31. xinference/model/image/model_spec_modelscope.json +150 -4
  32. xinference/model/image/stable_diffusion/core.py +50 -15
  33. xinference/model/llm/__init__.py +6 -2
  34. xinference/model/llm/llm_family.json +1055 -93
  35. xinference/model/llm/llm_family.py +15 -36
  36. xinference/model/llm/llm_family_modelscope.json +1031 -78
  37. xinference/model/llm/memory.py +1 -1
  38. xinference/model/llm/mlx/core.py +285 -47
  39. xinference/model/llm/sglang/core.py +2 -0
  40. xinference/model/llm/transformers/chatglm.py +9 -5
  41. xinference/model/llm/transformers/cogagent.py +272 -0
  42. xinference/model/llm/transformers/core.py +3 -0
  43. xinference/model/llm/transformers/glm_edge_v.py +230 -0
  44. xinference/model/llm/transformers/qwen2_vl.py +12 -1
  45. xinference/model/llm/transformers/utils.py +16 -8
  46. xinference/model/llm/utils.py +55 -4
  47. xinference/model/llm/vllm/core.py +137 -12
  48. xinference/model/llm/vllm/xavier/__init__.py +13 -0
  49. xinference/model/llm/vllm/xavier/allocator.py +74 -0
  50. xinference/model/llm/vllm/xavier/block.py +111 -0
  51. xinference/model/llm/vllm/xavier/block_manager.py +71 -0
  52. xinference/model/llm/vllm/xavier/block_tracker.py +129 -0
  53. xinference/model/llm/vllm/xavier/collective.py +74 -0
  54. xinference/model/llm/vllm/xavier/collective_manager.py +147 -0
  55. xinference/model/llm/vllm/xavier/engine.py +247 -0
  56. xinference/model/llm/vllm/xavier/executor.py +134 -0
  57. xinference/model/llm/vllm/xavier/scheduler.py +438 -0
  58. xinference/model/llm/vllm/xavier/test/__init__.py +13 -0
  59. xinference/model/llm/vllm/xavier/test/test_xavier.py +147 -0
  60. xinference/model/llm/vllm/xavier/transfer.py +319 -0
  61. xinference/model/rerank/core.py +11 -4
  62. xinference/model/video/diffusers.py +14 -0
  63. xinference/model/video/model_spec.json +15 -0
  64. xinference/model/video/model_spec_modelscope.json +16 -0
  65. xinference/thirdparty/cosyvoice/bin/average_model.py +92 -0
  66. xinference/thirdparty/cosyvoice/bin/export_jit.py +12 -2
  67. xinference/thirdparty/cosyvoice/bin/export_onnx.py +112 -0
  68. xinference/thirdparty/cosyvoice/bin/export_trt.sh +9 -0
  69. xinference/thirdparty/cosyvoice/bin/inference.py +5 -7
  70. xinference/thirdparty/cosyvoice/bin/spk2info.pt +0 -0
  71. xinference/thirdparty/cosyvoice/bin/train.py +42 -8
  72. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +96 -25
  73. xinference/thirdparty/cosyvoice/cli/frontend.py +77 -30
  74. xinference/thirdparty/cosyvoice/cli/model.py +330 -80
  75. xinference/thirdparty/cosyvoice/dataset/dataset.py +6 -2
  76. xinference/thirdparty/cosyvoice/dataset/processor.py +76 -14
  77. xinference/thirdparty/cosyvoice/flow/decoder.py +92 -13
  78. xinference/thirdparty/cosyvoice/flow/flow.py +99 -9
  79. xinference/thirdparty/cosyvoice/flow/flow_matching.py +110 -13
  80. xinference/thirdparty/cosyvoice/flow/length_regulator.py +5 -4
  81. xinference/thirdparty/cosyvoice/hifigan/discriminator.py +140 -0
  82. xinference/thirdparty/cosyvoice/hifigan/generator.py +58 -42
  83. xinference/thirdparty/cosyvoice/hifigan/hifigan.py +67 -0
  84. xinference/thirdparty/cosyvoice/llm/llm.py +139 -6
  85. xinference/thirdparty/cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
  86. xinference/thirdparty/cosyvoice/tokenizer/tokenizer.py +279 -0
  87. xinference/thirdparty/cosyvoice/transformer/embedding.py +2 -2
  88. xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +7 -7
  89. xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +318 -0
  90. xinference/thirdparty/cosyvoice/utils/common.py +28 -1
  91. xinference/thirdparty/cosyvoice/utils/executor.py +69 -7
  92. xinference/thirdparty/cosyvoice/utils/file_utils.py +2 -12
  93. xinference/thirdparty/cosyvoice/utils/frontend_utils.py +9 -5
  94. xinference/thirdparty/cosyvoice/utils/losses.py +20 -0
  95. xinference/thirdparty/cosyvoice/utils/scheduler.py +1 -2
  96. xinference/thirdparty/cosyvoice/utils/train_utils.py +101 -45
  97. xinference/thirdparty/f5_tts/api.py +166 -0
  98. xinference/thirdparty/f5_tts/configs/E2TTS_Base_train.yaml +44 -0
  99. xinference/thirdparty/f5_tts/configs/E2TTS_Small_train.yaml +44 -0
  100. xinference/thirdparty/f5_tts/configs/F5TTS_Base_train.yaml +46 -0
  101. xinference/thirdparty/f5_tts/configs/F5TTS_Small_train.yaml +46 -0
  102. xinference/thirdparty/f5_tts/eval/README.md +49 -0
  103. xinference/thirdparty/f5_tts/eval/ecapa_tdnn.py +330 -0
  104. xinference/thirdparty/f5_tts/eval/eval_infer_batch.py +207 -0
  105. xinference/thirdparty/f5_tts/eval/eval_infer_batch.sh +13 -0
  106. xinference/thirdparty/f5_tts/eval/eval_librispeech_test_clean.py +84 -0
  107. xinference/thirdparty/f5_tts/eval/eval_seedtts_testset.py +84 -0
  108. xinference/thirdparty/f5_tts/eval/utils_eval.py +405 -0
  109. xinference/thirdparty/f5_tts/infer/README.md +191 -0
  110. xinference/thirdparty/f5_tts/infer/SHARED.md +74 -0
  111. xinference/thirdparty/f5_tts/infer/examples/basic/basic.toml +11 -0
  112. xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_en.wav +0 -0
  113. xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_zh.wav +0 -0
  114. xinference/thirdparty/f5_tts/infer/examples/multi/country.flac +0 -0
  115. xinference/thirdparty/f5_tts/infer/examples/multi/main.flac +0 -0
  116. xinference/thirdparty/f5_tts/infer/examples/multi/story.toml +19 -0
  117. xinference/thirdparty/f5_tts/infer/examples/multi/story.txt +1 -0
  118. xinference/thirdparty/f5_tts/infer/examples/multi/town.flac +0 -0
  119. xinference/thirdparty/f5_tts/infer/examples/vocab.txt +2545 -0
  120. xinference/thirdparty/f5_tts/infer/infer_cli.py +226 -0
  121. xinference/thirdparty/f5_tts/infer/infer_gradio.py +851 -0
  122. xinference/thirdparty/f5_tts/infer/speech_edit.py +193 -0
  123. xinference/thirdparty/f5_tts/infer/utils_infer.py +538 -0
  124. xinference/thirdparty/f5_tts/model/__init__.py +10 -0
  125. xinference/thirdparty/f5_tts/model/backbones/README.md +20 -0
  126. xinference/thirdparty/f5_tts/model/backbones/dit.py +163 -0
  127. xinference/thirdparty/f5_tts/model/backbones/mmdit.py +146 -0
  128. xinference/thirdparty/f5_tts/model/backbones/unett.py +219 -0
  129. xinference/thirdparty/f5_tts/model/cfm.py +285 -0
  130. xinference/thirdparty/f5_tts/model/dataset.py +319 -0
  131. xinference/thirdparty/f5_tts/model/modules.py +658 -0
  132. xinference/thirdparty/f5_tts/model/trainer.py +366 -0
  133. xinference/thirdparty/f5_tts/model/utils.py +185 -0
  134. xinference/thirdparty/f5_tts/scripts/count_max_epoch.py +33 -0
  135. xinference/thirdparty/f5_tts/scripts/count_params_gflops.py +39 -0
  136. xinference/thirdparty/f5_tts/socket_server.py +159 -0
  137. xinference/thirdparty/f5_tts/train/README.md +77 -0
  138. xinference/thirdparty/f5_tts/train/datasets/prepare_csv_wavs.py +139 -0
  139. xinference/thirdparty/f5_tts/train/datasets/prepare_emilia.py +230 -0
  140. xinference/thirdparty/f5_tts/train/datasets/prepare_libritts.py +92 -0
  141. xinference/thirdparty/f5_tts/train/datasets/prepare_ljspeech.py +65 -0
  142. xinference/thirdparty/f5_tts/train/datasets/prepare_wenetspeech4tts.py +125 -0
  143. xinference/thirdparty/f5_tts/train/finetune_cli.py +174 -0
  144. xinference/thirdparty/f5_tts/train/finetune_gradio.py +1846 -0
  145. xinference/thirdparty/f5_tts/train/train.py +75 -0
  146. xinference/thirdparty/fish_speech/fish_speech/conversation.py +266 -1
  147. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +2 -1
  148. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +2 -1
  149. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +2 -2
  150. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ko_KR.json +123 -0
  151. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +2 -1
  152. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +137 -29
  153. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +9 -9
  154. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +1 -1
  155. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +17 -11
  156. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +1 -1
  157. xinference/thirdparty/fish_speech/fish_speech/tokenizer.py +152 -0
  158. xinference/thirdparty/fish_speech/fish_speech/train.py +2 -2
  159. xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +2 -1
  160. xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +22 -0
  161. xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +1 -1
  162. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +2 -2
  163. xinference/thirdparty/fish_speech/tools/{post_api.py → api_client.py} +34 -18
  164. xinference/thirdparty/fish_speech/tools/api_server.py +98 -0
  165. xinference/thirdparty/fish_speech/tools/download_models.py +5 -5
  166. xinference/thirdparty/fish_speech/tools/e2e_webui.py +232 -0
  167. xinference/thirdparty/fish_speech/tools/fish_e2e.py +298 -0
  168. xinference/thirdparty/fish_speech/tools/inference_engine/__init__.py +192 -0
  169. xinference/thirdparty/fish_speech/tools/inference_engine/reference_loader.py +125 -0
  170. xinference/thirdparty/fish_speech/tools/inference_engine/utils.py +39 -0
  171. xinference/thirdparty/fish_speech/tools/inference_engine/vq_manager.py +57 -0
  172. xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +2 -2
  173. xinference/thirdparty/fish_speech/tools/llama/generate.py +484 -72
  174. xinference/thirdparty/fish_speech/tools/run_webui.py +104 -0
  175. xinference/thirdparty/fish_speech/tools/schema.py +170 -0
  176. xinference/thirdparty/fish_speech/tools/server/agent/__init__.py +57 -0
  177. xinference/thirdparty/fish_speech/tools/server/agent/generate.py +119 -0
  178. xinference/thirdparty/fish_speech/tools/server/agent/generation_utils.py +122 -0
  179. xinference/thirdparty/fish_speech/tools/server/agent/pre_generation_utils.py +72 -0
  180. xinference/thirdparty/fish_speech/tools/server/api_utils.py +75 -0
  181. xinference/thirdparty/fish_speech/tools/server/exception_handler.py +27 -0
  182. xinference/thirdparty/fish_speech/tools/server/inference.py +45 -0
  183. xinference/thirdparty/fish_speech/tools/server/model_manager.py +122 -0
  184. xinference/thirdparty/fish_speech/tools/server/model_utils.py +129 -0
  185. xinference/thirdparty/fish_speech/tools/server/views.py +246 -0
  186. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +7 -1
  187. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +2 -3
  188. xinference/thirdparty/fish_speech/tools/webui/__init__.py +173 -0
  189. xinference/thirdparty/fish_speech/tools/webui/inference.py +91 -0
  190. xinference/thirdparty/fish_speech/tools/webui/variables.py +14 -0
  191. xinference/thirdparty/matcha/utils/utils.py +2 -2
  192. xinference/thirdparty/melo/api.py +135 -0
  193. xinference/thirdparty/melo/app.py +61 -0
  194. xinference/thirdparty/melo/attentions.py +459 -0
  195. xinference/thirdparty/melo/commons.py +160 -0
  196. xinference/thirdparty/melo/configs/config.json +94 -0
  197. xinference/thirdparty/melo/data/example/metadata.list +20 -0
  198. xinference/thirdparty/melo/data_utils.py +413 -0
  199. xinference/thirdparty/melo/download_utils.py +67 -0
  200. xinference/thirdparty/melo/infer.py +25 -0
  201. xinference/thirdparty/melo/init_downloads.py +14 -0
  202. xinference/thirdparty/melo/losses.py +58 -0
  203. xinference/thirdparty/melo/main.py +36 -0
  204. xinference/thirdparty/melo/mel_processing.py +174 -0
  205. xinference/thirdparty/melo/models.py +1030 -0
  206. xinference/thirdparty/melo/modules.py +598 -0
  207. xinference/thirdparty/melo/monotonic_align/__init__.py +16 -0
  208. xinference/thirdparty/melo/monotonic_align/core.py +46 -0
  209. xinference/thirdparty/melo/preprocess_text.py +135 -0
  210. xinference/thirdparty/melo/split_utils.py +174 -0
  211. xinference/thirdparty/melo/text/__init__.py +35 -0
  212. xinference/thirdparty/melo/text/chinese.py +199 -0
  213. xinference/thirdparty/melo/text/chinese_bert.py +107 -0
  214. xinference/thirdparty/melo/text/chinese_mix.py +253 -0
  215. xinference/thirdparty/melo/text/cleaner.py +36 -0
  216. xinference/thirdparty/melo/text/cleaner_multiling.py +110 -0
  217. xinference/thirdparty/melo/text/cmudict.rep +129530 -0
  218. xinference/thirdparty/melo/text/cmudict_cache.pickle +0 -0
  219. xinference/thirdparty/melo/text/english.py +284 -0
  220. xinference/thirdparty/melo/text/english_bert.py +39 -0
  221. xinference/thirdparty/melo/text/english_utils/abbreviations.py +35 -0
  222. xinference/thirdparty/melo/text/english_utils/number_norm.py +97 -0
  223. xinference/thirdparty/melo/text/english_utils/time_norm.py +47 -0
  224. xinference/thirdparty/melo/text/es_phonemizer/base.py +140 -0
  225. xinference/thirdparty/melo/text/es_phonemizer/cleaner.py +109 -0
  226. xinference/thirdparty/melo/text/es_phonemizer/es_symbols.json +79 -0
  227. xinference/thirdparty/melo/text/es_phonemizer/es_symbols.txt +1 -0
  228. xinference/thirdparty/melo/text/es_phonemizer/es_symbols_v2.json +83 -0
  229. xinference/thirdparty/melo/text/es_phonemizer/es_to_ipa.py +12 -0
  230. xinference/thirdparty/melo/text/es_phonemizer/example_ipa.txt +400 -0
  231. xinference/thirdparty/melo/text/es_phonemizer/gruut_wrapper.py +253 -0
  232. xinference/thirdparty/melo/text/es_phonemizer/punctuation.py +174 -0
  233. xinference/thirdparty/melo/text/es_phonemizer/spanish_symbols.txt +1 -0
  234. xinference/thirdparty/melo/text/es_phonemizer/test.ipynb +124 -0
  235. xinference/thirdparty/melo/text/fr_phonemizer/base.py +140 -0
  236. xinference/thirdparty/melo/text/fr_phonemizer/cleaner.py +122 -0
  237. xinference/thirdparty/melo/text/fr_phonemizer/en_symbols.json +78 -0
  238. xinference/thirdparty/melo/text/fr_phonemizer/example_ipa.txt +1 -0
  239. xinference/thirdparty/melo/text/fr_phonemizer/fr_symbols.json +89 -0
  240. xinference/thirdparty/melo/text/fr_phonemizer/fr_to_ipa.py +30 -0
  241. xinference/thirdparty/melo/text/fr_phonemizer/french_abbreviations.py +48 -0
  242. xinference/thirdparty/melo/text/fr_phonemizer/french_symbols.txt +1 -0
  243. xinference/thirdparty/melo/text/fr_phonemizer/gruut_wrapper.py +258 -0
  244. xinference/thirdparty/melo/text/fr_phonemizer/punctuation.py +172 -0
  245. xinference/thirdparty/melo/text/french.py +94 -0
  246. xinference/thirdparty/melo/text/french_bert.py +39 -0
  247. xinference/thirdparty/melo/text/japanese.py +647 -0
  248. xinference/thirdparty/melo/text/japanese_bert.py +49 -0
  249. xinference/thirdparty/melo/text/ko_dictionary.py +44 -0
  250. xinference/thirdparty/melo/text/korean.py +192 -0
  251. xinference/thirdparty/melo/text/opencpop-strict.txt +429 -0
  252. xinference/thirdparty/melo/text/spanish.py +122 -0
  253. xinference/thirdparty/melo/text/spanish_bert.py +39 -0
  254. xinference/thirdparty/melo/text/symbols.py +290 -0
  255. xinference/thirdparty/melo/text/tone_sandhi.py +769 -0
  256. xinference/thirdparty/melo/train.py +635 -0
  257. xinference/thirdparty/melo/train.sh +19 -0
  258. xinference/thirdparty/melo/transforms.py +209 -0
  259. xinference/thirdparty/melo/utils.py +424 -0
  260. xinference/types.py +17 -1
  261. xinference/web/ui/build/asset-manifest.json +6 -6
  262. xinference/web/ui/build/index.html +1 -1
  263. xinference/web/ui/build/static/css/main.51a587ff.css +2 -0
  264. xinference/web/ui/build/static/css/main.51a587ff.css.map +1 -0
  265. xinference/web/ui/build/static/js/main.b0936c54.js +3 -0
  266. xinference/web/ui/build/static/js/main.b0936c54.js.map +1 -0
  267. xinference/web/ui/node_modules/.cache/babel-loader/03c4052f1b91f6ba0c5389bdcf49c43319b4076c08e4b8585dab312538ae290a.json +1 -0
  268. xinference/web/ui/node_modules/.cache/babel-loader/1786b83003b8e9605a0f5f855a185d4d16e38fc893dfb326a2a9cca206b4240a.json +1 -0
  269. xinference/web/ui/node_modules/.cache/babel-loader/17cbc181dd674b9150b80c73ed6a82656de0082d857f6e5f66d9716129ac0b38.json +1 -0
  270. xinference/web/ui/node_modules/.cache/babel-loader/185ceb8872d562e032b47e79df6a45670e06345b8ed70aad1a131e0476783c5c.json +1 -0
  271. xinference/web/ui/node_modules/.cache/babel-loader/26b8c9f34b0bed789b3a833767672e39302d1e0c09b4276f4d58d1df7b6bd93b.json +1 -0
  272. xinference/web/ui/node_modules/.cache/babel-loader/2b484da66c724d0d56a40849c109327408796a668b1381511b6e9e03baa48658.json +1 -0
  273. xinference/web/ui/node_modules/.cache/babel-loader/2cbbbce9b84df73330d4c42b82436ed881b3847628f2fbc346aa62e2859fd88c.json +1 -0
  274. xinference/web/ui/node_modules/.cache/babel-loader/2ec9b14431ed33ce6901bf9f27007be4e6e472709c99d6e22b50ce528e4b78ee.json +1 -0
  275. xinference/web/ui/node_modules/.cache/babel-loader/3b966db018f96be4a055d6ca205f0990d4d0b370e2980c17d8bca2c9a021819c.json +1 -0
  276. xinference/web/ui/node_modules/.cache/babel-loader/3eefb411b24c2b3ce053570ef50daccf154022f0e168be5ed0fec21394baf9f4.json +1 -0
  277. xinference/web/ui/node_modules/.cache/babel-loader/522b229e3cac219123f0d69673f5570e191c2d2a505dc65b312d336eae2279c0.json +1 -0
  278. xinference/web/ui/node_modules/.cache/babel-loader/52e45f17ba300580ea3fcc9f9228ccba194bb092b76f25e9255af311f8b05aab.json +1 -0
  279. xinference/web/ui/node_modules/.cache/babel-loader/5a0bc4631f936459afc1a3b1d3ec2420118b1f00e11f60ccac3e08088f3f27a8.json +1 -0
  280. xinference/web/ui/node_modules/.cache/babel-loader/611fa2c6c53b66039991d06dfb0473b5ab37fc63b4564e0f6e1718523768a045.json +1 -0
  281. xinference/web/ui/node_modules/.cache/babel-loader/6329bc76c406fe5eb305412383fbde5950f847bb5e43261f73f37622c365acb4.json +1 -0
  282. xinference/web/ui/node_modules/.cache/babel-loader/63c8e07687ea53a4f8a910ee5e42e0eb26cd1acbfbe820f3e3248a786ee51401.json +1 -0
  283. xinference/web/ui/node_modules/.cache/babel-loader/69b2d5001684174ec9da57e07914eed3eac4960018bceb6cbfa801d861301d7c.json +1 -0
  284. xinference/web/ui/node_modules/.cache/babel-loader/710c1acda69e561e30a933b98c6a56d50197868b15c21e2aad55ab6d46649eb6.json +1 -0
  285. xinference/web/ui/node_modules/.cache/babel-loader/720deca1fce5a1dc5056048fa8258fd138a82ea855f350b6613f104a73fb761f.json +1 -0
  286. xinference/web/ui/node_modules/.cache/babel-loader/76a23b92d26a499c57e61eea2b895fbc9771bd0849a72e66f8e633192017978b.json +1 -0
  287. xinference/web/ui/node_modules/.cache/babel-loader/858063f23b34dfe600254eb5afd85518b0002ec4b30b7386616c45600826e3b2.json +1 -0
  288. xinference/web/ui/node_modules/.cache/babel-loader/920b82c1c89124cf217109eeedbfcd3aae3b917be50c9dfb6bbb4ce26bdfd2e7.json +1 -0
  289. xinference/web/ui/node_modules/.cache/babel-loader/94d8b7aeb0076f2ce07db598cea0e87b13bc8d5614eb530b8d6e696c2daf6f88.json +1 -0
  290. xinference/web/ui/node_modules/.cache/babel-loader/9e917fe7022d01b2ccbe5cc0ce73d70bb72bee584ff293bad71bdff6695dee28.json +1 -0
  291. xinference/web/ui/node_modules/.cache/babel-loader/9f28fdb8399f1d0474f0aca86f1658dc94f5bf0c90f6146352de150692de8862.json +1 -0
  292. xinference/web/ui/node_modules/.cache/babel-loader/a0dfafa06b2bb7cba8cad41c482503f61944f759f4318139362602ef5cc47ccb.json +1 -0
  293. xinference/web/ui/node_modules/.cache/babel-loader/a3ff866acddf34917a7ee399e0e571a4dfd8ba66d5057db885f243e16a6eb17d.json +1 -0
  294. xinference/web/ui/node_modules/.cache/babel-loader/afb8084f539534cd594755ea2205ecd5bd1f62dddcfdf75a2eace59a28131278.json +1 -0
  295. xinference/web/ui/node_modules/.cache/babel-loader/b57b1438b77294c1f3f6cfce12ac487d8106c6f016975ba0aec94d98997e2e1e.json +1 -0
  296. xinference/web/ui/node_modules/.cache/babel-loader/b9917b0bf8e4d55ccbac1c334aa04d6ff3c5b6ed9e5d38b9ea2c687fa7d3f5a9.json +1 -0
  297. xinference/web/ui/node_modules/.cache/babel-loader/bbcc94b0149963d1d6f267ee1f4f03d3925b758392ce2f516c3fe8af0e0169fc.json +1 -0
  298. xinference/web/ui/node_modules/.cache/babel-loader/bdee44abeadc4abc17d41c52eb49c6e19a4b1a267b6e16876ce91bdeeebfc52d.json +1 -0
  299. xinference/web/ui/node_modules/.cache/babel-loader/beb112b70f4a56db95920a9e20efb6c97c37b68450716730217a9ee1a9ae92be.json +1 -0
  300. xinference/web/ui/node_modules/.cache/babel-loader/c88db97be0cdf440193b3995996e83510a04cb00048135485fc0e26d197e80b5.json +1 -0
  301. xinference/web/ui/node_modules/.cache/babel-loader/d49e5314d34310a62d01a03067ce1bec5da00abce84c5196aa9c6842fa79a430.json +1 -0
  302. xinference/web/ui/node_modules/.cache/babel-loader/d7664d18c4ddbad9c3a6a31b91f7c00fb0dde804608674a9860ee50f33e54708.json +1 -0
  303. xinference/web/ui/node_modules/.cache/babel-loader/d9072c318b819b7c90a0f7e9cc0b6413b4dbeb8e9859898e53d75ea882fcde99.json +1 -0
  304. xinference/web/ui/node_modules/.cache/babel-loader/db16a983bc08a05f0439cc61ca0840e49e1d8400eef678909f16c032a418a3d6.json +1 -0
  305. xinference/web/ui/node_modules/.cache/babel-loader/dc249829767b8abcbc3677e0b07b6d3ecbfdfe6d08cfe23a665eb33373a9aa9d.json +1 -0
  306. xinference/web/ui/node_modules/.cache/babel-loader/e242c583c2dbc2784f0fcf513523975f7d5df447e106c1c17e49e8578a6fc3ed.json +1 -0
  307. xinference/web/ui/node_modules/.cache/babel-loader/eac5f1296513e69e4b96f750ddccd4d0264e2bae4e4c449144e83274a48698d9.json +1 -0
  308. xinference/web/ui/node_modules/.cache/babel-loader/ed57202cb79649bb716400436590245547df241988fc7c8e1d85d132299542d2.json +1 -0
  309. xinference/web/ui/node_modules/.cache/babel-loader/f125bf72e773a14cdaebd0c343e80adb909d12e317ee5c00cd4a57442fbe2c62.json +1 -0
  310. xinference/web/ui/node_modules/.cache/babel-loader/f91af913d7f91c410719ab13136aaed3aaf0f8dda06652f25c42cb5231587398.json +1 -0
  311. xinference/web/ui/node_modules/.package-lock.json +67 -3
  312. xinference/web/ui/node_modules/@babel/runtime/package.json +592 -538
  313. xinference/web/ui/node_modules/html-parse-stringify/package.json +50 -0
  314. xinference/web/ui/node_modules/i18next/dist/esm/package.json +1 -0
  315. xinference/web/ui/node_modules/i18next/package.json +129 -0
  316. xinference/web/ui/node_modules/react-i18next/.eslintrc.json +74 -0
  317. xinference/web/ui/node_modules/react-i18next/dist/es/package.json +1 -0
  318. xinference/web/ui/node_modules/react-i18next/package.json +162 -0
  319. xinference/web/ui/node_modules/void-elements/package.json +34 -0
  320. xinference/web/ui/package-lock.json +69 -3
  321. xinference/web/ui/package.json +2 -0
  322. xinference/web/ui/src/locales/en.json +186 -0
  323. xinference/web/ui/src/locales/zh.json +186 -0
  324. {xinference-0.16.3.dist-info → xinference-1.2.1.dist-info}/METADATA +96 -36
  325. {xinference-0.16.3.dist-info → xinference-1.2.1.dist-info}/RECORD +335 -146
  326. {xinference-0.16.3.dist-info → xinference-1.2.1.dist-info}/WHEEL +1 -1
  327. xinference/thirdparty/cosyvoice/bin/export_trt.py +0 -8
  328. xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
  329. xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
  330. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
  331. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
  332. xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
  333. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
  334. xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
  335. xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
  336. xinference/thirdparty/fish_speech/tools/api.py +0 -440
  337. xinference/thirdparty/fish_speech/tools/commons.py +0 -35
  338. xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
  339. xinference/thirdparty/fish_speech/tools/msgpack_api.py +0 -34
  340. xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
  341. xinference/thirdparty/fish_speech/tools/webui.py +0 -485
  342. xinference/web/ui/build/static/css/main.5061c4c3.css +0 -2
  343. xinference/web/ui/build/static/css/main.5061c4c3.css.map +0 -1
  344. xinference/web/ui/build/static/js/main.2f269bb3.js +0 -3
  345. xinference/web/ui/build/static/js/main.2f269bb3.js.map +0 -1
  346. xinference/web/ui/node_modules/.cache/babel-loader/07ce9e632e6aff24d7aa3ad8e48224433bbfeb0d633fca723453f1fcae0c9f1c.json +0 -1
  347. xinference/web/ui/node_modules/.cache/babel-loader/1130403f9e46f5738a23b45ac59b57de8f360c908c713e2c0670c2cce9bd367a.json +0 -1
  348. xinference/web/ui/node_modules/.cache/babel-loader/131091b25d26b17cdca187d7542a21475c211138d900cf667682260e76ef9463.json +0 -1
  349. xinference/web/ui/node_modules/.cache/babel-loader/1f269fb2a368363c1cb2237825f1dba093b6bdd8c44cc05954fd19ec2c1fff03.json +0 -1
  350. xinference/web/ui/node_modules/.cache/babel-loader/331312668fa8bd3d7401818f4a25fa98135d7f61371cd6bfff78b18cf4fbdd92.json +0 -1
  351. xinference/web/ui/node_modules/.cache/babel-loader/40f17338fc75ae095de7d2b4d8eae0d5ca0193a7e2bcece4ee745b22a7a2f4b7.json +0 -1
  352. xinference/web/ui/node_modules/.cache/babel-loader/4de9a6942c5f1749d6cbfdd54279699975f16016b182848bc253886f52ec2ec3.json +0 -1
  353. xinference/web/ui/node_modules/.cache/babel-loader/822586ed1077201b64b954f12f25e3f9b45678c1acbabe53d8af3ca82ca71f33.json +0 -1
  354. xinference/web/ui/node_modules/.cache/babel-loader/8d33354bd2100c8602afc3341f131a88cc36aaeecd5a4b365ed038514708e350.json +0 -1
  355. xinference/web/ui/node_modules/.cache/babel-loader/9375a35b05d56989b2755bf72161fa707c92f28569d33765a75f91a568fda6e9.json +0 -1
  356. xinference/web/ui/node_modules/.cache/babel-loader/a158a9ffa0c9b169aee53dd4a0c44501a596755b4e4f6ede7746d65a72e2a71f.json +0 -1
  357. xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +0 -1
  358. xinference/web/ui/node_modules/.cache/babel-loader/c7bf40bab396765f67d0fed627ed3665890608b2d0edaa3e8cb7cfc96310db45.json +0 -1
  359. xinference/web/ui/node_modules/.cache/babel-loader/d6c643278a0b28320e6f33a60f5fb64c053997cbdc39a60e53ccc574688ade9e.json +0 -1
  360. xinference/web/ui/node_modules/.cache/babel-loader/e42b72d4cc1ea412ebecbb8d040dc6c6bfee462c33903c2f1f3facb602ad742e.json +0 -1
  361. xinference/web/ui/node_modules/.cache/babel-loader/e64b7e8cedcf43d4c95deba60ec1341855c887705805bb62431693118b870c69.json +0 -1
  362. xinference/web/ui/node_modules/.cache/babel-loader/f5039ddbeb815c51491a1989532006b96fc3ae49c6c60e3c097f875b4ae915ae.json +0 -1
  363. xinference/web/ui/node_modules/.cache/babel-loader/f72f011744c4649fabddca6f7a9327861ac0a315a89b1a2e62a39774e7863845.json +0 -1
  364. xinference/web/ui/node_modules/.cache/babel-loader/feabb04b4aa507102da0a64398a40818e878fd1df9b75dda8461b3e1e7ff3f11.json +0 -1
  365. /xinference/thirdparty/{cosyvoice/bin → f5_tts}/__init__.py +0 -0
  366. /xinference/thirdparty/{cosyvoice/flow → melo}/__init__.py +0 -0
  367. /xinference/thirdparty/{cosyvoice/hifigan → melo/text/english_utils}/__init__.py +0 -0
  368. /xinference/thirdparty/{cosyvoice/llm → melo/text/es_phonemizer}/__init__.py +0 -0
  369. /xinference/thirdparty/{fish_speech/fish_speech/configs → melo/text/fr_phonemizer}/__init__.py +0 -0
  370. /xinference/web/ui/build/static/js/{main.2f269bb3.js.LICENSE.txt → main.b0936c54.js.LICENSE.txt} +0 -0
  371. {xinference-0.16.3.dist-info → xinference-1.2.1.dist-info}/LICENSE +0 -0
  372. {xinference-0.16.3.dist-info → xinference-1.2.1.dist-info}/entry_points.txt +0 -0
  373. {xinference-0.16.3.dist-info → xinference-1.2.1.dist-info}/top_level.txt +0 -0
xinference/thirdparty/fish_speech/tools/run_webui.py
@@ -0,0 +1,104 @@
+ import os
+ from argparse import ArgumentParser
+ from pathlib import Path
+
+ import pyrootutils
+ import torch
+ from loguru import logger
+
+ pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+
+ from tools.inference_engine import TTSInferenceEngine
+ from tools.llama.generate import launch_thread_safe_queue
+ from tools.schema import ServeTTSRequest
+ from tools.vqgan.inference import load_model as load_decoder_model
+ from tools.webui import build_app
+ from tools.webui.inference import get_inference_wrapper
+
+ # Make einx happy
+ os.environ["EINX_FILTER_TRACEBACK"] = "false"
+
+
+ def parse_args():
+     parser = ArgumentParser()
+     parser.add_argument(
+         "--llama-checkpoint-path",
+         type=Path,
+         default="checkpoints/fish-speech-1.5",
+     )
+     parser.add_argument(
+         "--decoder-checkpoint-path",
+         type=Path,
+         default="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
+     )
+     parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
+     parser.add_argument("--device", type=str, default="cuda")
+     parser.add_argument("--half", action="store_true")
+     parser.add_argument("--compile", action="store_true")
+     parser.add_argument("--max-gradio-length", type=int, default=0)
+     parser.add_argument("--theme", type=str, default="light")
+
+     return parser.parse_args()
+
+
+ if __name__ == "__main__":
+     args = parse_args()
+     args.precision = torch.half if args.half else torch.bfloat16
+
+     # Check if MPS or CUDA is available
+     if torch.backends.mps.is_available():
+         args.device = "mps"
+         logger.info("mps is available, running on mps.")
+     elif not torch.cuda.is_available():
+         logger.info("CUDA is not available, running on CPU.")
+         args.device = "cpu"
+
+     logger.info("Loading Llama model...")
+     llama_queue = launch_thread_safe_queue(
+         checkpoint_path=args.llama_checkpoint_path,
+         device=args.device,
+         precision=args.precision,
+         compile=args.compile,
+     )
+
+     logger.info("Loading VQ-GAN model...")
+     decoder_model = load_decoder_model(
+         config_name=args.decoder_config_name,
+         checkpoint_path=args.decoder_checkpoint_path,
+         device=args.device,
+     )
+
+     logger.info("Decoder model loaded, warming up...")
+
+     # Create the inference engine
+     inference_engine = TTSInferenceEngine(
+         llama_queue=llama_queue,
+         decoder_model=decoder_model,
+         compile=args.compile,
+         precision=args.precision,
+     )
+
+     # Dry run to check if the model is loaded correctly and avoid the first-time latency
+     list(
+         inference_engine.inference(
+             ServeTTSRequest(
+                 text="Hello world.",
+                 references=[],
+                 reference_id=None,
+                 max_new_tokens=1024,
+                 chunk_length=200,
+                 top_p=0.7,
+                 repetition_penalty=1.5,
+                 temperature=0.7,
+                 format="wav",
+             )
+         )
+     )
+
+     logger.info("Warming up done, launching the web UI...")
+
+     # Get the inference function with the immutable arguments
+     inference_fct = get_inference_wrapper(inference_engine)
+
+     app = build_app(inference_fct, args.theme)
+     app.launch(show_api=True)
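The script above is driven entirely by its argparse flags. As a hedged usage sketch (the checkpoint paths are just the defaults baked into parse_args above; adjust them to your setup), launching the UI with half precision and compilation could look like this:

import subprocess

# Launch the fish_speech web UI with the flags defined in parse_args above.
# Paths are the script's own defaults, not values confirmed for any install.
subprocess.run(
    [
        "python", "tools/run_webui.py",
        "--llama-checkpoint-path", "checkpoints/fish-speech-1.5",
        "--decoder-checkpoint-path",
        "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
        "--half",     # torch.half instead of the bfloat16 default
        "--compile",  # compile the llama model before the warm-up dry run
    ],
    check=True,
)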
xinference/thirdparty/fish_speech/tools/schema.py
@@ -0,0 +1,170 @@
+ import os
+ import queue
+ from dataclasses import dataclass
+ from typing import Annotated, Literal
+
+ import torch
+ from pydantic import BaseModel, Field, conint, conlist
+ from pydantic.functional_validators import SkipValidation
+
+ from fish_speech.conversation import Message, TextPart, VQPart
+
+
+ class ServeVQPart(BaseModel):
+     type: Literal["vq"] = "vq"
+     codes: SkipValidation[list[list[int]]]
+
+
+ class ServeTextPart(BaseModel):
+     type: Literal["text"] = "text"
+     text: str
+
+
+ class ServeAudioPart(BaseModel):
+     type: Literal["audio"] = "audio"
+     audio: bytes
+
+
+ @dataclass
+ class ASRPackRequest:
+     audio: torch.Tensor
+     result_queue: queue.Queue
+     language: str
+
+
+ class ServeASRRequest(BaseModel):
+     # The audio should be an uncompressed PCM float16 audio
+     audios: list[bytes]
+     sample_rate: int = 44100
+     language: Literal["zh", "en", "ja", "auto"] = "auto"
+
+
+ class ServeASRTranscription(BaseModel):
+     text: str
+     duration: float
+     huge_gap: bool
+
+
+ class ServeASRSegment(BaseModel):
+     text: str
+     start: float
+     end: float
+
+
+ class ServeTimedASRResponse(BaseModel):
+     text: str
+     segments: list[ServeASRSegment]
+     duration: float
+
+
+ class ServeASRResponse(BaseModel):
+     transcriptions: list[ServeASRTranscription]
+
+
+ class ServeMessage(BaseModel):
+     role: Literal["system", "assistant", "user"]
+     parts: list[ServeVQPart | ServeTextPart]
+
+     def to_conversation_message(self):
+         new_message = Message(role=self.role, parts=[])
+         if self.role == "assistant":
+             new_message.modality = "voice"
+
+         for part in self.parts:
+             if isinstance(part, ServeTextPart):
+                 new_message.parts.append(TextPart(text=part.text))
+             elif isinstance(part, ServeVQPart):
+                 new_message.parts.append(
+                     VQPart(codes=torch.tensor(part.codes, dtype=torch.int))
+                 )
+             else:
+                 raise ValueError(f"Unsupported part type: {part}")
+
+         return new_message
+
+
+ class ServeChatRequest(BaseModel):
+     messages: Annotated[list[ServeMessage], conlist(ServeMessage, min_length=1)]
+     max_new_tokens: int = 1024
+     top_p: float = 0.7
+     repetition_penalty: float = 1.2
+     temperature: float = 0.7
+     streaming: bool = False
+     num_samples: int = 1
+     early_stop_threshold: float = 1.0
+
+
+ class ServeVQGANEncodeRequest(BaseModel):
+     # The audio here should be in wav, mp3, etc
+     audios: list[bytes]
+
+
+ class ServeVQGANEncodeResponse(BaseModel):
+     tokens: SkipValidation[list[list[list[int]]]]
+
+
+ class ServeVQGANDecodeRequest(BaseModel):
+     tokens: SkipValidation[list[list[list[int]]]]
+
+
+ class ServeVQGANDecodeResponse(BaseModel):
+     # The audio here should be in PCM float16 format
+     audios: list[bytes]
+
+
+ class ServeForwardMessage(BaseModel):
+     role: str
+     content: str
+
+
+ class ServeResponse(BaseModel):
+     messages: list[ServeMessage]
+     finish_reason: Literal["stop", "error"] | None = None
+     stats: dict[str, int | float | str] = {}
+
+
+ class ServeStreamDelta(BaseModel):
+     role: Literal["system", "assistant", "user"] | None = None
+     part: ServeVQPart | ServeTextPart | None = None
+
+
+ class ServeStreamResponse(BaseModel):
+     sample_id: int = 0
+     delta: ServeStreamDelta | None = None
+     finish_reason: Literal["stop", "error"] | None = None
+     stats: dict[str, int | float | str] | None = None
+
+
+ class ServeReferenceAudio(BaseModel):
+     audio: bytes
+     text: str
+
+     def __repr__(self) -> str:
+         return f"ServeReferenceAudio(text={self.text!r}, audio_size={len(self.audio)})"
+
+
+ class ServeTTSRequest(BaseModel):
+     text: str
+     chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
+     # Audio format
+     format: Literal["wav", "pcm", "mp3"] = "wav"
+     # References audios for in-context learning
+     references: list[ServeReferenceAudio] = []
+     # Reference id
+     # For example, if you want use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
+     # Just pass 7f92f8afb8ec43bf81429cc1c9199cb1
+     reference_id: str | None = None
+     seed: int | None = None
+     use_memory_cache: Literal["on", "off"] = "off"
+     # Normalize text for en & zh, this increase stability for numbers
+     normalize: bool = True
+     # not usually used below
+     streaming: bool = False
+     max_new_tokens: int = 1024
+     top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
+     repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
+     temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
+
+     class Config:
+         # Allow arbitrary types for pytorch related types
+         arbitrary_types_allowed = True
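A minimal sketch of how the new ServeTTSRequest schema behaves under pydantic v2 validation, assuming the fish_speech tools package is importable; the byte payload is a placeholder, not real audio:

from pydantic import ValidationError

from tools.schema import ServeReferenceAudio, ServeTTSRequest

# Defaults come from the model definition above: format="wav", chunk_length=200.
req = ServeTTSRequest(
    text="Hello world.",
    references=[ServeReferenceAudio(audio=b"<wav bytes>", text="reference transcript")],
)
print(req.format, req.chunk_length)  # wav 200

# chunk_length is constrained to 100 <= n <= 300, so this is rejected.
try:
    ServeTTSRequest(text="Hello", chunk_length=500)
except ValidationError as exc:
    print(exc.error_count(), "validation error(s)")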
xinference/thirdparty/fish_speech/tools/server/agent/__init__.py
@@ -0,0 +1,57 @@
+ import struct
+ from functools import partial
+
+ import ormsgpack
+
+ from tools.server.agent.generate import generate_responses
+ from tools.server.agent.pre_generation_utils import prepare_messages
+
+
+ def execute_request(input_queue, tokenizer, config, request, device):
+     """
+     This function prepares the conversation, encodes the request,
+     sends the generation request, and handles decoding/streaming.
+     It returns a response generator (ServeResponse or ServeStreamResponse).
+     """
+     prompt, im_end_id = prepare_messages(request, tokenizer, config)
+     yield from generate_responses(
+         input_queue, tokenizer, config, request, prompt, im_end_id, device
+     )
+
+
+ def response_generator(req, llama_queue, tokenizer, config, device):
+     """
+     Non-streaming response wrapper for the chat endpoint.
+     Only returns the final result.
+     """
+     generator = execute_request(llama_queue, tokenizer, config, req, device)
+     return next(generator)
+
+
+ async def streaming_generator(req, llama_queue, tokenizer, config, device, json_mode):
+     """
+     Streaming response wrapper for the chat endpoint.
+     Returns the response in chunks.
+     """
+     generator = execute_request(llama_queue, tokenizer, config, req, device)
+     for i in generator:
+         if json_mode:
+             body = i.model_dump_json().encode("utf-8")
+             yield b"data: " + body + b"\n\n"
+         else:
+             body = ormsgpack.packb(i, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)
+             yield struct.pack("I", len(body)) + body
+
+
+ def get_response_generator(
+     llama_queue, tokenizer, config, req, device, json_mode
+ ) -> partial:
+     """
+     Get the correct response generator based on the request.
+     """
+     if not req.streaming:
+         return partial(response_generator, req, llama_queue, tokenizer, config, device)
+     else:
+         return partial(
+             streaming_generator, req, llama_queue, tokenizer, config, device, json_mode
+         )
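Note that get_response_generator returns a functools.partial rather than starting anything itself, so the endpoint decides when generation begins. A hedged dispatch sketch; the queue, tokenizer, config, and device arguments stand in for the server's real objects:

from tools.server.agent import get_response_generator

async def handle_chat(req, llama_queue, tokenizer, config, device):
    generator = get_response_generator(
        llama_queue, tokenizer, config, req, device, json_mode=True
    )
    if req.streaming:
        # streaming_generator is an async generator yielding framed chunks
        # (SSE-style b"data: ..." lines here, since json_mode=True).
        return [chunk async for chunk in generator()]
    # response_generator returns the single final ServeResponse.
    return generator()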
xinference/thirdparty/fish_speech/tools/server/agent/generate.py
@@ -0,0 +1,119 @@
+ import time
+
+ from tools.schema import ServeMessage, ServeResponse, ServeStreamResponse
+ from tools.server.agent.generation_utils import (
+     initialize_decode_buffers,
+     process_response_tokens,
+     send_reset_buffer,
+ )
+ from tools.server.agent.pre_generation_utils import (
+     create_generation_request,
+     send_generation_request,
+ )
+
+
+ def generate_responses(
+     input_queue, tokenizer, config, request, prompt, im_end_id, device
+ ):
+     """
+     Main generation function that handles the conversation, encodes the request,
+     sends the generation request, and handles decoding/streaming.
+     It returns a response generator (ServeResponse or ServeStreamResponse).
+     """
+     stats = {}
+     start = time.time()
+     stats["start_time"] = start
+     stats["tokens_count"] = 0
+
+     # Prepare and send the generation request
+     req = create_generation_request(prompt, request, im_end_id, device)
+     response_queue = send_generation_request(input_queue, req)
+     decode_buffer, parts, finished = initialize_decode_buffers(request.num_samples)
+
+     while True:
+         response = response_queue.get()
+
+         # Handle abnormal finish or error
+         if response in ["stop", "error"]:
+             finish_reason = response
+             break
+
+         # Process the response tokens
+         is_first_token = stats["tokens_count"] == 0
+         responses = process_response_tokens(
+             response,
+             tokenizer,
+             config,
+             request,
+             decode_buffer,
+             parts,
+             finished,
+             im_end_id,
+             stats,
+             start,
+             is_first_token,
+         )
+
+         # Yield the responses if streaming
+         if request.streaming and responses:
+             for r in responses:
+                 yield r
+
+         stats["tokens_count"] += 1
+
+         # Check if all samples are finished
+         if all(finished):
+             finish_reason = "stop"
+             break
+
+     # Finalize the response
+     final_responses = finalize_response(
+         request, finished, decode_buffer, tokenizer, parts, stats, finish_reason
+     )
+     for fr in final_responses:
+         yield fr
+
+
+ def finalize_response(
+     request, finished, decode_buffer, tokenizer, parts, stats, finish_reason
+ ):
+     """
+     Finalize the response by sending the remaining text buffers.
+     """
+     responses = []
+
+     # Send the remaining text buffers
+     for sample_id in range(request.num_samples):
+         responses.extend(
+             send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request)
+         )
+
+     # Calculate the final stats
+     stats["total_time"] = (time.time() - stats["start_time"]) * 1000
+     stats["total_tokens"] = stats["tokens_count"]
+
+     # If streaming, send the final chunks for each sample
+     if request.streaming:
+         for sample_id in range(request.num_samples):
+             if finished[sample_id]:
+                 continue
+             responses.append(
+                 ServeStreamResponse(
+                     finish_reason=finish_reason, stats=stats, sample_id=sample_id
+                 )
+             )
+     else:
+         # If not streaming, send the full messages for each sample
+         full_messages = [
+             ServeMessage(role="assistant", parts=parts[i])
+             for i in range(request.num_samples)
+         ]
+         responses.append(
+             ServeResponse(
+                 messages=full_messages,
+                 finish_reason=finish_reason,
+                 stats=stats,
+             )
+         )
+
+     return responses
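A short consumption sketch for generate_responses (arguments are stand-ins for the server's real objects): in streaming mode it yields ServeStreamResponse deltas as token batches arrive, and the final items carry finish_reason plus the stats assembled above (time_to_first_token and total_time in milliseconds, total_tokens as a count).

from tools.server.agent.generate import generate_responses

def collect_chat(input_queue, tokenizer, config, request, prompt, im_end_id, device):
    # Drain the generator; a streaming endpoint would forward each item
    # to the client instead of accumulating them.
    return list(
        generate_responses(
            input_queue, tokenizer, config, request, prompt, im_end_id, device
        )
    )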
xinference/thirdparty/fish_speech/tools/server/agent/generation_utils.py
@@ -0,0 +1,122 @@
+ import time
+
+ from tools.schema import (
+     ServeStreamDelta,
+     ServeStreamResponse,
+     ServeTextPart,
+     ServeVQPart,
+ )
+
+
+ def initialize_decode_buffers(num_samples):
+     """Initialise the decode buffers for each sample."""
+     decode_buffer = [[] for _ in range(num_samples)]
+     parts = [[] for _ in range(num_samples)]
+     finished = [False for _ in range(num_samples)]
+     return decode_buffer, parts, finished
+
+
+ def send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request):
+     """Send the remaining text buffer for a sample."""
+     if len(decode_buffer[sample_id]) == 0:
+         return []
+
+     decoded = tokenizer.decode(decode_buffer[sample_id])
+     part = ServeTextPart(text=decoded)
+
+     responses = []
+     if request.streaming:
+         responses.append(ServeStreamResponse(delta=ServeStreamDelta(part=part)))
+     else:
+         parts[sample_id].append(part)
+
+     decode_buffer[sample_id] = []
+     return responses
+
+
+ def handle_semantic_tokens(tokens, config, sample_id, parts, request):
+     """Handle the semantic tokens returned by the model."""
+     responses = []
+     _tokens = tokens[1:].clone()
+
+     if not config.share_codebook_embeddings:
+         for i in range(len(_tokens)):
+             _tokens[i] -= config.codebook_size * i
+
+     # If streaming, send the VQ parts directly
+     if request.streaming:
+         responses.append(
+             ServeStreamResponse(
+                 sample_id=sample_id,
+                 delta=ServeStreamDelta(part=ServeVQPart(codes=_tokens.tolist())),
+             )
+         )
+     else:
+         # If not streaming, accumulate the VQ parts
+         if not parts[sample_id] or not isinstance(parts[sample_id][-1], ServeVQPart):
+             parts[sample_id].append(ServeVQPart(codes=_tokens.tolist()))
+         else:
+             # Accumulate the codes
+             for codebook_id, value in enumerate(_tokens):
+                 parts[sample_id][-1].codes[codebook_id].append(value.item())
+
+     return responses
+
+
+ def process_response_tokens(
+     response,
+     tokenizer,
+     config,
+     request,
+     decode_buffer,
+     parts,
+     finished,
+     im_end_id,
+     stats,
+     start,
+     is_first_token,
+ ):
+     """Process the response tokens returned by the model."""
+     responses = []
+     for sample_id, tokens in enumerate(response):
+         if finished[sample_id]:
+             continue
+
+         # End of the conversation
+         if tokens[0] == im_end_id:
+             finished[sample_id] = True
+             # Send the remaining text buffer
+             responses.extend(
+                 send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request)
+             )
+             if request.streaming:
+                 responses.append(
+                     ServeStreamResponse(
+                         sample_id=sample_id,
+                         finish_reason="stop",
+                         stats=stats,
+                     )
+                 )
+             continue
+
+         # Check if the token is semantic
+         is_semantic = (
+             tokenizer.semantic_begin_id <= tokens[0] <= tokenizer.semantic_end_id
+         )
+
+         if is_semantic:
+             # Before the semantic tokens, send the remaining text buffer
+             responses.extend(
+                 send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request)
+             )
+             responses.extend(
+                 handle_semantic_tokens(tokens, config, sample_id, parts, request)
+             )
+         else:
+             # Accumulate the text tokens (not implemented?)
+             decode_buffer[sample_id].append(tokens[0, 0])
+
+         if is_first_token:
+             stats["time_to_first_token"] = (time.time() - start) * 1000
+
+     return responses
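The de-offset loop in handle_semantic_tokens undoes the per-codebook vocabulary shift the model applies when codebook embeddings are not shared. A toy illustration of just that loop, with made-up values (codebook_size=1024 is an assumption for the demo, not read from any real config):

import torch

codebook_size = 1024
# Row 0 is the text-stream token; rows 1..n are per-codebook codes, each
# offset by codebook_size * codebook_index in the model's output space.
tokens = torch.tensor([[5], [10], [1024 + 20], [2 * 1024 + 30]])

_tokens = tokens[1:].clone()
for i in range(len(_tokens)):
    _tokens[i] -= codebook_size * i

print(_tokens.tolist())  # [[10], [20], [30]] -- raw per-codebook codes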
xinference/thirdparty/fish_speech/tools/server/agent/pre_generation_utils.py
@@ -0,0 +1,72 @@
+ import queue
+
+ from fish_speech.conversation import Conversation, Message
+ from fish_speech.tokenizer import IM_END_TOKEN
+ from tools.llama.generate import GenerateRequest
+
+
+ def prepare_messages(request, tokenizer, config):
+     """
+     Reorganise the provided list of messages into a conversation.
+     Encode the conversation for inference.
+     """
+     # Convert the messages to ConversationMessage objects
+     messages = [msg.to_conversation_message() for msg in request.messages]
+
+     if len(messages) < 1:
+         raise ValueError("At least one message is required")
+
+     # Check the last message to determine the next step
+     last_role = messages[-1].role
+     match last_role:
+         case "user":
+             # The last message is from the user, ask the assistant to respond with a new message
+             messages.append(
+                 Message(role="assistant", parts=[], add_im_end=False, modality="voice")
+             )
+         case "raw":
+             # The last message is raw text, ask the assistant to complete it
+             messages[-1].add_im_start = False
+             messages[-1].add_im_end = False
+             messages[-1].modality = "voice"
+         case "assistant":
+             # The last message is from the assistant, ask the assistant to continue
+             messages[-1].add_im_end = False
+         case _:
+             # We expect it to be assistant if not user or raw
+             raise ValueError("The last message must be from the assistant, user or raw")
+
+     # Create a conversation object and encode it for inference
+     conv = Conversation(messages=messages)
+     prompt = conv.encode_for_inference(
+         tokenizer=tokenizer, num_codebooks=config.num_codebooks
+     )
+     im_end_id = tokenizer.get_token_id(IM_END_TOKEN)
+
+     return prompt, im_end_id
+
+
+ def create_generation_request(prompt, request, im_end_id, device):
+     """
+     Convert the request into a dictionary that can be sent to the model for generation.
+     """
+     req = {
+         "prompt": prompt.to(device),
+         "max_new_tokens": request.max_new_tokens,
+         "im_end_id": im_end_id,
+         "temperature": request.temperature,
+         "top_p": request.top_p,
+         "repetition_penalty": request.repetition_penalty,
+         "num_samples": request.num_samples,
+         "early_stop_threshold": request.early_stop_threshold,
+     }
+     return req
+
+
+ def send_generation_request(input_queue, req):
+     """
+     Send the generation request to the model and return a queue to get the response.
+     """
+     response_queue = queue.Queue()
+     input_queue.put(GenerateRequest(req, response_queue))
+     return response_queue
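Putting the three helpers together, a request travels as one round-trip over a pair of queues. A minimal sketch, assuming a llama worker thread (as started by launch_thread_safe_queue in tools/llama/generate.py) is consuming input_queue:

from tools.server.agent.pre_generation_utils import (
    create_generation_request,
    prepare_messages,
    send_generation_request,
)

def stream_tokens(request, input_queue, tokenizer, config, device):
    prompt, im_end_id = prepare_messages(request, tokenizer, config)
    req = create_generation_request(prompt, request, im_end_id, device)
    response_queue = send_generation_request(input_queue, req)
    # The worker posts token batches until it signals "stop" or "error".
    while (response := response_queue.get()) not in ("stop", "error"):
        yield response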