xinference 0.16.3__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

Files changed (373)
  1. xinference/_compat.py +24 -2
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +219 -77
  4. xinference/client/restful/restful_client.py +47 -2
  5. xinference/constants.py +1 -0
  6. xinference/core/chat_interface.py +6 -1
  7. xinference/core/model.py +124 -34
  8. xinference/core/supervisor.py +180 -12
  9. xinference/core/utils.py +73 -4
  10. xinference/core/worker.py +102 -4
  11. xinference/deploy/cmdline.py +3 -1
  12. xinference/deploy/test/test_cmdline.py +56 -0
  13. xinference/isolation.py +24 -0
  14. xinference/model/audio/__init__.py +12 -0
  15. xinference/model/audio/core.py +37 -4
  16. xinference/model/audio/cosyvoice.py +39 -6
  17. xinference/model/audio/f5tts.py +200 -0
  18. xinference/model/audio/f5tts_mlx.py +260 -0
  19. xinference/model/audio/fish_speech.py +70 -110
  20. xinference/model/audio/melotts.py +110 -0
  21. xinference/model/audio/model_spec.json +179 -3
  22. xinference/model/audio/model_spec_modelscope.json +27 -0
  23. xinference/model/audio/utils.py +32 -0
  24. xinference/model/audio/whisper.py +35 -10
  25. xinference/model/audio/whisper_mlx.py +208 -0
  26. xinference/model/embedding/core.py +322 -6
  27. xinference/model/embedding/model_spec.json +8 -1
  28. xinference/model/embedding/model_spec_modelscope.json +9 -1
  29. xinference/model/image/core.py +69 -1
  30. xinference/model/image/model_spec.json +145 -4
  31. xinference/model/image/model_spec_modelscope.json +150 -4
  32. xinference/model/image/stable_diffusion/core.py +50 -15
  33. xinference/model/llm/__init__.py +6 -2
  34. xinference/model/llm/llm_family.json +1055 -93
  35. xinference/model/llm/llm_family.py +15 -36
  36. xinference/model/llm/llm_family_modelscope.json +1031 -78
  37. xinference/model/llm/memory.py +1 -1
  38. xinference/model/llm/mlx/core.py +285 -47
  39. xinference/model/llm/sglang/core.py +2 -0
  40. xinference/model/llm/transformers/chatglm.py +9 -5
  41. xinference/model/llm/transformers/cogagent.py +272 -0
  42. xinference/model/llm/transformers/core.py +3 -0
  43. xinference/model/llm/transformers/glm_edge_v.py +230 -0
  44. xinference/model/llm/transformers/qwen2_vl.py +12 -1
  45. xinference/model/llm/transformers/utils.py +16 -8
  46. xinference/model/llm/utils.py +55 -4
  47. xinference/model/llm/vllm/core.py +137 -12
  48. xinference/model/llm/vllm/xavier/__init__.py +13 -0
  49. xinference/model/llm/vllm/xavier/allocator.py +74 -0
  50. xinference/model/llm/vllm/xavier/block.py +111 -0
  51. xinference/model/llm/vllm/xavier/block_manager.py +71 -0
  52. xinference/model/llm/vllm/xavier/block_tracker.py +129 -0
  53. xinference/model/llm/vllm/xavier/collective.py +74 -0
  54. xinference/model/llm/vllm/xavier/collective_manager.py +147 -0
  55. xinference/model/llm/vllm/xavier/engine.py +247 -0
  56. xinference/model/llm/vllm/xavier/executor.py +134 -0
  57. xinference/model/llm/vllm/xavier/scheduler.py +438 -0
  58. xinference/model/llm/vllm/xavier/test/__init__.py +13 -0
  59. xinference/model/llm/vllm/xavier/test/test_xavier.py +147 -0
  60. xinference/model/llm/vllm/xavier/transfer.py +319 -0
  61. xinference/model/rerank/core.py +11 -4
  62. xinference/model/video/diffusers.py +14 -0
  63. xinference/model/video/model_spec.json +15 -0
  64. xinference/model/video/model_spec_modelscope.json +16 -0
  65. xinference/thirdparty/cosyvoice/bin/average_model.py +92 -0
  66. xinference/thirdparty/cosyvoice/bin/export_jit.py +12 -2
  67. xinference/thirdparty/cosyvoice/bin/export_onnx.py +112 -0
  68. xinference/thirdparty/cosyvoice/bin/export_trt.sh +9 -0
  69. xinference/thirdparty/cosyvoice/bin/inference.py +5 -7
  70. xinference/thirdparty/cosyvoice/bin/spk2info.pt +0 -0
  71. xinference/thirdparty/cosyvoice/bin/train.py +42 -8
  72. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +96 -25
  73. xinference/thirdparty/cosyvoice/cli/frontend.py +77 -30
  74. xinference/thirdparty/cosyvoice/cli/model.py +330 -80
  75. xinference/thirdparty/cosyvoice/dataset/dataset.py +6 -2
  76. xinference/thirdparty/cosyvoice/dataset/processor.py +76 -14
  77. xinference/thirdparty/cosyvoice/flow/decoder.py +92 -13
  78. xinference/thirdparty/cosyvoice/flow/flow.py +99 -9
  79. xinference/thirdparty/cosyvoice/flow/flow_matching.py +110 -13
  80. xinference/thirdparty/cosyvoice/flow/length_regulator.py +5 -4
  81. xinference/thirdparty/cosyvoice/hifigan/discriminator.py +140 -0
  82. xinference/thirdparty/cosyvoice/hifigan/generator.py +58 -42
  83. xinference/thirdparty/cosyvoice/hifigan/hifigan.py +67 -0
  84. xinference/thirdparty/cosyvoice/llm/llm.py +139 -6
  85. xinference/thirdparty/cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
  86. xinference/thirdparty/cosyvoice/tokenizer/tokenizer.py +279 -0
  87. xinference/thirdparty/cosyvoice/transformer/embedding.py +2 -2
  88. xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +7 -7
  89. xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +318 -0
  90. xinference/thirdparty/cosyvoice/utils/common.py +28 -1
  91. xinference/thirdparty/cosyvoice/utils/executor.py +69 -7
  92. xinference/thirdparty/cosyvoice/utils/file_utils.py +2 -12
  93. xinference/thirdparty/cosyvoice/utils/frontend_utils.py +9 -5
  94. xinference/thirdparty/cosyvoice/utils/losses.py +20 -0
  95. xinference/thirdparty/cosyvoice/utils/scheduler.py +1 -2
  96. xinference/thirdparty/cosyvoice/utils/train_utils.py +101 -45
  97. xinference/thirdparty/f5_tts/api.py +166 -0
  98. xinference/thirdparty/f5_tts/configs/E2TTS_Base_train.yaml +44 -0
  99. xinference/thirdparty/f5_tts/configs/E2TTS_Small_train.yaml +44 -0
  100. xinference/thirdparty/f5_tts/configs/F5TTS_Base_train.yaml +46 -0
  101. xinference/thirdparty/f5_tts/configs/F5TTS_Small_train.yaml +46 -0
  102. xinference/thirdparty/f5_tts/eval/README.md +49 -0
  103. xinference/thirdparty/f5_tts/eval/ecapa_tdnn.py +330 -0
  104. xinference/thirdparty/f5_tts/eval/eval_infer_batch.py +207 -0
  105. xinference/thirdparty/f5_tts/eval/eval_infer_batch.sh +13 -0
  106. xinference/thirdparty/f5_tts/eval/eval_librispeech_test_clean.py +84 -0
  107. xinference/thirdparty/f5_tts/eval/eval_seedtts_testset.py +84 -0
  108. xinference/thirdparty/f5_tts/eval/utils_eval.py +405 -0
  109. xinference/thirdparty/f5_tts/infer/README.md +191 -0
  110. xinference/thirdparty/f5_tts/infer/SHARED.md +74 -0
  111. xinference/thirdparty/f5_tts/infer/examples/basic/basic.toml +11 -0
  112. xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_en.wav +0 -0
  113. xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_zh.wav +0 -0
  114. xinference/thirdparty/f5_tts/infer/examples/multi/country.flac +0 -0
  115. xinference/thirdparty/f5_tts/infer/examples/multi/main.flac +0 -0
  116. xinference/thirdparty/f5_tts/infer/examples/multi/story.toml +19 -0
  117. xinference/thirdparty/f5_tts/infer/examples/multi/story.txt +1 -0
  118. xinference/thirdparty/f5_tts/infer/examples/multi/town.flac +0 -0
  119. xinference/thirdparty/f5_tts/infer/examples/vocab.txt +2545 -0
  120. xinference/thirdparty/f5_tts/infer/infer_cli.py +226 -0
  121. xinference/thirdparty/f5_tts/infer/infer_gradio.py +851 -0
  122. xinference/thirdparty/f5_tts/infer/speech_edit.py +193 -0
  123. xinference/thirdparty/f5_tts/infer/utils_infer.py +538 -0
  124. xinference/thirdparty/f5_tts/model/__init__.py +10 -0
  125. xinference/thirdparty/f5_tts/model/backbones/README.md +20 -0
  126. xinference/thirdparty/f5_tts/model/backbones/dit.py +163 -0
  127. xinference/thirdparty/f5_tts/model/backbones/mmdit.py +146 -0
  128. xinference/thirdparty/f5_tts/model/backbones/unett.py +219 -0
  129. xinference/thirdparty/f5_tts/model/cfm.py +285 -0
  130. xinference/thirdparty/f5_tts/model/dataset.py +319 -0
  131. xinference/thirdparty/f5_tts/model/modules.py +658 -0
  132. xinference/thirdparty/f5_tts/model/trainer.py +366 -0
  133. xinference/thirdparty/f5_tts/model/utils.py +185 -0
  134. xinference/thirdparty/f5_tts/scripts/count_max_epoch.py +33 -0
  135. xinference/thirdparty/f5_tts/scripts/count_params_gflops.py +39 -0
  136. xinference/thirdparty/f5_tts/socket_server.py +159 -0
  137. xinference/thirdparty/f5_tts/train/README.md +77 -0
  138. xinference/thirdparty/f5_tts/train/datasets/prepare_csv_wavs.py +139 -0
  139. xinference/thirdparty/f5_tts/train/datasets/prepare_emilia.py +230 -0
  140. xinference/thirdparty/f5_tts/train/datasets/prepare_libritts.py +92 -0
  141. xinference/thirdparty/f5_tts/train/datasets/prepare_ljspeech.py +65 -0
  142. xinference/thirdparty/f5_tts/train/datasets/prepare_wenetspeech4tts.py +125 -0
  143. xinference/thirdparty/f5_tts/train/finetune_cli.py +174 -0
  144. xinference/thirdparty/f5_tts/train/finetune_gradio.py +1846 -0
  145. xinference/thirdparty/f5_tts/train/train.py +75 -0
  146. xinference/thirdparty/fish_speech/fish_speech/conversation.py +266 -1
  147. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +2 -1
  148. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +2 -1
  149. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +2 -2
  150. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ko_KR.json +123 -0
  151. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +2 -1
  152. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +137 -29
  153. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +9 -9
  154. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +1 -1
  155. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +17 -11
  156. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +1 -1
  157. xinference/thirdparty/fish_speech/fish_speech/tokenizer.py +152 -0
  158. xinference/thirdparty/fish_speech/fish_speech/train.py +2 -2
  159. xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +2 -1
  160. xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +22 -0
  161. xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +1 -1
  162. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +2 -2
  163. xinference/thirdparty/fish_speech/tools/{post_api.py → api_client.py} +34 -18
  164. xinference/thirdparty/fish_speech/tools/api_server.py +98 -0
  165. xinference/thirdparty/fish_speech/tools/download_models.py +5 -5
  166. xinference/thirdparty/fish_speech/tools/e2e_webui.py +232 -0
  167. xinference/thirdparty/fish_speech/tools/fish_e2e.py +298 -0
  168. xinference/thirdparty/fish_speech/tools/inference_engine/__init__.py +192 -0
  169. xinference/thirdparty/fish_speech/tools/inference_engine/reference_loader.py +125 -0
  170. xinference/thirdparty/fish_speech/tools/inference_engine/utils.py +39 -0
  171. xinference/thirdparty/fish_speech/tools/inference_engine/vq_manager.py +57 -0
  172. xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +2 -2
  173. xinference/thirdparty/fish_speech/tools/llama/generate.py +484 -72
  174. xinference/thirdparty/fish_speech/tools/run_webui.py +104 -0
  175. xinference/thirdparty/fish_speech/tools/schema.py +170 -0
  176. xinference/thirdparty/fish_speech/tools/server/agent/__init__.py +57 -0
  177. xinference/thirdparty/fish_speech/tools/server/agent/generate.py +119 -0
  178. xinference/thirdparty/fish_speech/tools/server/agent/generation_utils.py +122 -0
  179. xinference/thirdparty/fish_speech/tools/server/agent/pre_generation_utils.py +72 -0
  180. xinference/thirdparty/fish_speech/tools/server/api_utils.py +75 -0
  181. xinference/thirdparty/fish_speech/tools/server/exception_handler.py +27 -0
  182. xinference/thirdparty/fish_speech/tools/server/inference.py +45 -0
  183. xinference/thirdparty/fish_speech/tools/server/model_manager.py +122 -0
  184. xinference/thirdparty/fish_speech/tools/server/model_utils.py +129 -0
  185. xinference/thirdparty/fish_speech/tools/server/views.py +246 -0
  186. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +7 -1
  187. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +2 -3
  188. xinference/thirdparty/fish_speech/tools/webui/__init__.py +173 -0
  189. xinference/thirdparty/fish_speech/tools/webui/inference.py +91 -0
  190. xinference/thirdparty/fish_speech/tools/webui/variables.py +14 -0
  191. xinference/thirdparty/matcha/utils/utils.py +2 -2
  192. xinference/thirdparty/melo/api.py +135 -0
  193. xinference/thirdparty/melo/app.py +61 -0
  194. xinference/thirdparty/melo/attentions.py +459 -0
  195. xinference/thirdparty/melo/commons.py +160 -0
  196. xinference/thirdparty/melo/configs/config.json +94 -0
  197. xinference/thirdparty/melo/data/example/metadata.list +20 -0
  198. xinference/thirdparty/melo/data_utils.py +413 -0
  199. xinference/thirdparty/melo/download_utils.py +67 -0
  200. xinference/thirdparty/melo/infer.py +25 -0
  201. xinference/thirdparty/melo/init_downloads.py +14 -0
  202. xinference/thirdparty/melo/losses.py +58 -0
  203. xinference/thirdparty/melo/main.py +36 -0
  204. xinference/thirdparty/melo/mel_processing.py +174 -0
  205. xinference/thirdparty/melo/models.py +1030 -0
  206. xinference/thirdparty/melo/modules.py +598 -0
  207. xinference/thirdparty/melo/monotonic_align/__init__.py +16 -0
  208. xinference/thirdparty/melo/monotonic_align/core.py +46 -0
  209. xinference/thirdparty/melo/preprocess_text.py +135 -0
  210. xinference/thirdparty/melo/split_utils.py +174 -0
  211. xinference/thirdparty/melo/text/__init__.py +35 -0
  212. xinference/thirdparty/melo/text/chinese.py +199 -0
  213. xinference/thirdparty/melo/text/chinese_bert.py +107 -0
  214. xinference/thirdparty/melo/text/chinese_mix.py +253 -0
  215. xinference/thirdparty/melo/text/cleaner.py +36 -0
  216. xinference/thirdparty/melo/text/cleaner_multiling.py +110 -0
  217. xinference/thirdparty/melo/text/cmudict.rep +129530 -0
  218. xinference/thirdparty/melo/text/cmudict_cache.pickle +0 -0
  219. xinference/thirdparty/melo/text/english.py +284 -0
  220. xinference/thirdparty/melo/text/english_bert.py +39 -0
  221. xinference/thirdparty/melo/text/english_utils/abbreviations.py +35 -0
  222. xinference/thirdparty/melo/text/english_utils/number_norm.py +97 -0
  223. xinference/thirdparty/melo/text/english_utils/time_norm.py +47 -0
  224. xinference/thirdparty/melo/text/es_phonemizer/base.py +140 -0
  225. xinference/thirdparty/melo/text/es_phonemizer/cleaner.py +109 -0
  226. xinference/thirdparty/melo/text/es_phonemizer/es_symbols.json +79 -0
  227. xinference/thirdparty/melo/text/es_phonemizer/es_symbols.txt +1 -0
  228. xinference/thirdparty/melo/text/es_phonemizer/es_symbols_v2.json +83 -0
  229. xinference/thirdparty/melo/text/es_phonemizer/es_to_ipa.py +12 -0
  230. xinference/thirdparty/melo/text/es_phonemizer/example_ipa.txt +400 -0
  231. xinference/thirdparty/melo/text/es_phonemizer/gruut_wrapper.py +253 -0
  232. xinference/thirdparty/melo/text/es_phonemizer/punctuation.py +174 -0
  233. xinference/thirdparty/melo/text/es_phonemizer/spanish_symbols.txt +1 -0
  234. xinference/thirdparty/melo/text/es_phonemizer/test.ipynb +124 -0
  235. xinference/thirdparty/melo/text/fr_phonemizer/base.py +140 -0
  236. xinference/thirdparty/melo/text/fr_phonemizer/cleaner.py +122 -0
  237. xinference/thirdparty/melo/text/fr_phonemizer/en_symbols.json +78 -0
  238. xinference/thirdparty/melo/text/fr_phonemizer/example_ipa.txt +1 -0
  239. xinference/thirdparty/melo/text/fr_phonemizer/fr_symbols.json +89 -0
  240. xinference/thirdparty/melo/text/fr_phonemizer/fr_to_ipa.py +30 -0
  241. xinference/thirdparty/melo/text/fr_phonemizer/french_abbreviations.py +48 -0
  242. xinference/thirdparty/melo/text/fr_phonemizer/french_symbols.txt +1 -0
  243. xinference/thirdparty/melo/text/fr_phonemizer/gruut_wrapper.py +258 -0
  244. xinference/thirdparty/melo/text/fr_phonemizer/punctuation.py +172 -0
  245. xinference/thirdparty/melo/text/french.py +94 -0
  246. xinference/thirdparty/melo/text/french_bert.py +39 -0
  247. xinference/thirdparty/melo/text/japanese.py +647 -0
  248. xinference/thirdparty/melo/text/japanese_bert.py +49 -0
  249. xinference/thirdparty/melo/text/ko_dictionary.py +44 -0
  250. xinference/thirdparty/melo/text/korean.py +192 -0
  251. xinference/thirdparty/melo/text/opencpop-strict.txt +429 -0
  252. xinference/thirdparty/melo/text/spanish.py +122 -0
  253. xinference/thirdparty/melo/text/spanish_bert.py +39 -0
  254. xinference/thirdparty/melo/text/symbols.py +290 -0
  255. xinference/thirdparty/melo/text/tone_sandhi.py +769 -0
  256. xinference/thirdparty/melo/train.py +635 -0
  257. xinference/thirdparty/melo/train.sh +19 -0
  258. xinference/thirdparty/melo/transforms.py +209 -0
  259. xinference/thirdparty/melo/utils.py +424 -0
  260. xinference/types.py +17 -1
  261. xinference/web/ui/build/asset-manifest.json +6 -6
  262. xinference/web/ui/build/index.html +1 -1
  263. xinference/web/ui/build/static/css/main.51a587ff.css +2 -0
  264. xinference/web/ui/build/static/css/main.51a587ff.css.map +1 -0
  265. xinference/web/ui/build/static/js/main.b0936c54.js +3 -0
  266. xinference/web/ui/build/static/js/main.b0936c54.js.map +1 -0
  267. xinference/web/ui/node_modules/.cache/babel-loader/03c4052f1b91f6ba0c5389bdcf49c43319b4076c08e4b8585dab312538ae290a.json +1 -0
  268. xinference/web/ui/node_modules/.cache/babel-loader/1786b83003b8e9605a0f5f855a185d4d16e38fc893dfb326a2a9cca206b4240a.json +1 -0
  269. xinference/web/ui/node_modules/.cache/babel-loader/17cbc181dd674b9150b80c73ed6a82656de0082d857f6e5f66d9716129ac0b38.json +1 -0
  270. xinference/web/ui/node_modules/.cache/babel-loader/185ceb8872d562e032b47e79df6a45670e06345b8ed70aad1a131e0476783c5c.json +1 -0
  271. xinference/web/ui/node_modules/.cache/babel-loader/26b8c9f34b0bed789b3a833767672e39302d1e0c09b4276f4d58d1df7b6bd93b.json +1 -0
  272. xinference/web/ui/node_modules/.cache/babel-loader/2b484da66c724d0d56a40849c109327408796a668b1381511b6e9e03baa48658.json +1 -0
  273. xinference/web/ui/node_modules/.cache/babel-loader/2cbbbce9b84df73330d4c42b82436ed881b3847628f2fbc346aa62e2859fd88c.json +1 -0
  274. xinference/web/ui/node_modules/.cache/babel-loader/2ec9b14431ed33ce6901bf9f27007be4e6e472709c99d6e22b50ce528e4b78ee.json +1 -0
  275. xinference/web/ui/node_modules/.cache/babel-loader/3b966db018f96be4a055d6ca205f0990d4d0b370e2980c17d8bca2c9a021819c.json +1 -0
  276. xinference/web/ui/node_modules/.cache/babel-loader/3eefb411b24c2b3ce053570ef50daccf154022f0e168be5ed0fec21394baf9f4.json +1 -0
  277. xinference/web/ui/node_modules/.cache/babel-loader/522b229e3cac219123f0d69673f5570e191c2d2a505dc65b312d336eae2279c0.json +1 -0
  278. xinference/web/ui/node_modules/.cache/babel-loader/52e45f17ba300580ea3fcc9f9228ccba194bb092b76f25e9255af311f8b05aab.json +1 -0
  279. xinference/web/ui/node_modules/.cache/babel-loader/5a0bc4631f936459afc1a3b1d3ec2420118b1f00e11f60ccac3e08088f3f27a8.json +1 -0
  280. xinference/web/ui/node_modules/.cache/babel-loader/611fa2c6c53b66039991d06dfb0473b5ab37fc63b4564e0f6e1718523768a045.json +1 -0
  281. xinference/web/ui/node_modules/.cache/babel-loader/6329bc76c406fe5eb305412383fbde5950f847bb5e43261f73f37622c365acb4.json +1 -0
  282. xinference/web/ui/node_modules/.cache/babel-loader/63c8e07687ea53a4f8a910ee5e42e0eb26cd1acbfbe820f3e3248a786ee51401.json +1 -0
  283. xinference/web/ui/node_modules/.cache/babel-loader/69b2d5001684174ec9da57e07914eed3eac4960018bceb6cbfa801d861301d7c.json +1 -0
  284. xinference/web/ui/node_modules/.cache/babel-loader/710c1acda69e561e30a933b98c6a56d50197868b15c21e2aad55ab6d46649eb6.json +1 -0
  285. xinference/web/ui/node_modules/.cache/babel-loader/720deca1fce5a1dc5056048fa8258fd138a82ea855f350b6613f104a73fb761f.json +1 -0
  286. xinference/web/ui/node_modules/.cache/babel-loader/76a23b92d26a499c57e61eea2b895fbc9771bd0849a72e66f8e633192017978b.json +1 -0
  287. xinference/web/ui/node_modules/.cache/babel-loader/858063f23b34dfe600254eb5afd85518b0002ec4b30b7386616c45600826e3b2.json +1 -0
  288. xinference/web/ui/node_modules/.cache/babel-loader/920b82c1c89124cf217109eeedbfcd3aae3b917be50c9dfb6bbb4ce26bdfd2e7.json +1 -0
  289. xinference/web/ui/node_modules/.cache/babel-loader/94d8b7aeb0076f2ce07db598cea0e87b13bc8d5614eb530b8d6e696c2daf6f88.json +1 -0
  290. xinference/web/ui/node_modules/.cache/babel-loader/9e917fe7022d01b2ccbe5cc0ce73d70bb72bee584ff293bad71bdff6695dee28.json +1 -0
  291. xinference/web/ui/node_modules/.cache/babel-loader/9f28fdb8399f1d0474f0aca86f1658dc94f5bf0c90f6146352de150692de8862.json +1 -0
  292. xinference/web/ui/node_modules/.cache/babel-loader/a0dfafa06b2bb7cba8cad41c482503f61944f759f4318139362602ef5cc47ccb.json +1 -0
  293. xinference/web/ui/node_modules/.cache/babel-loader/a3ff866acddf34917a7ee399e0e571a4dfd8ba66d5057db885f243e16a6eb17d.json +1 -0
  294. xinference/web/ui/node_modules/.cache/babel-loader/afb8084f539534cd594755ea2205ecd5bd1f62dddcfdf75a2eace59a28131278.json +1 -0
  295. xinference/web/ui/node_modules/.cache/babel-loader/b57b1438b77294c1f3f6cfce12ac487d8106c6f016975ba0aec94d98997e2e1e.json +1 -0
  296. xinference/web/ui/node_modules/.cache/babel-loader/b9917b0bf8e4d55ccbac1c334aa04d6ff3c5b6ed9e5d38b9ea2c687fa7d3f5a9.json +1 -0
  297. xinference/web/ui/node_modules/.cache/babel-loader/bbcc94b0149963d1d6f267ee1f4f03d3925b758392ce2f516c3fe8af0e0169fc.json +1 -0
  298. xinference/web/ui/node_modules/.cache/babel-loader/bdee44abeadc4abc17d41c52eb49c6e19a4b1a267b6e16876ce91bdeeebfc52d.json +1 -0
  299. xinference/web/ui/node_modules/.cache/babel-loader/beb112b70f4a56db95920a9e20efb6c97c37b68450716730217a9ee1a9ae92be.json +1 -0
  300. xinference/web/ui/node_modules/.cache/babel-loader/c88db97be0cdf440193b3995996e83510a04cb00048135485fc0e26d197e80b5.json +1 -0
  301. xinference/web/ui/node_modules/.cache/babel-loader/d49e5314d34310a62d01a03067ce1bec5da00abce84c5196aa9c6842fa79a430.json +1 -0
  302. xinference/web/ui/node_modules/.cache/babel-loader/d7664d18c4ddbad9c3a6a31b91f7c00fb0dde804608674a9860ee50f33e54708.json +1 -0
  303. xinference/web/ui/node_modules/.cache/babel-loader/d9072c318b819b7c90a0f7e9cc0b6413b4dbeb8e9859898e53d75ea882fcde99.json +1 -0
  304. xinference/web/ui/node_modules/.cache/babel-loader/db16a983bc08a05f0439cc61ca0840e49e1d8400eef678909f16c032a418a3d6.json +1 -0
  305. xinference/web/ui/node_modules/.cache/babel-loader/dc249829767b8abcbc3677e0b07b6d3ecbfdfe6d08cfe23a665eb33373a9aa9d.json +1 -0
  306. xinference/web/ui/node_modules/.cache/babel-loader/e242c583c2dbc2784f0fcf513523975f7d5df447e106c1c17e49e8578a6fc3ed.json +1 -0
  307. xinference/web/ui/node_modules/.cache/babel-loader/eac5f1296513e69e4b96f750ddccd4d0264e2bae4e4c449144e83274a48698d9.json +1 -0
  308. xinference/web/ui/node_modules/.cache/babel-loader/ed57202cb79649bb716400436590245547df241988fc7c8e1d85d132299542d2.json +1 -0
  309. xinference/web/ui/node_modules/.cache/babel-loader/f125bf72e773a14cdaebd0c343e80adb909d12e317ee5c00cd4a57442fbe2c62.json +1 -0
  310. xinference/web/ui/node_modules/.cache/babel-loader/f91af913d7f91c410719ab13136aaed3aaf0f8dda06652f25c42cb5231587398.json +1 -0
  311. xinference/web/ui/node_modules/.package-lock.json +67 -3
  312. xinference/web/ui/node_modules/@babel/runtime/package.json +592 -538
  313. xinference/web/ui/node_modules/html-parse-stringify/package.json +50 -0
  314. xinference/web/ui/node_modules/i18next/dist/esm/package.json +1 -0
  315. xinference/web/ui/node_modules/i18next/package.json +129 -0
  316. xinference/web/ui/node_modules/react-i18next/.eslintrc.json +74 -0
  317. xinference/web/ui/node_modules/react-i18next/dist/es/package.json +1 -0
  318. xinference/web/ui/node_modules/react-i18next/package.json +162 -0
  319. xinference/web/ui/node_modules/void-elements/package.json +34 -0
  320. xinference/web/ui/package-lock.json +69 -3
  321. xinference/web/ui/package.json +2 -0
  322. xinference/web/ui/src/locales/en.json +186 -0
  323. xinference/web/ui/src/locales/zh.json +186 -0
  324. {xinference-0.16.3.dist-info → xinference-1.2.1.dist-info}/METADATA +96 -36
  325. {xinference-0.16.3.dist-info → xinference-1.2.1.dist-info}/RECORD +335 -146
  326. {xinference-0.16.3.dist-info → xinference-1.2.1.dist-info}/WHEEL +1 -1
  327. xinference/thirdparty/cosyvoice/bin/export_trt.py +0 -8
  328. xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
  329. xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
  330. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
  331. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
  332. xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
  333. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
  334. xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
  335. xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
  336. xinference/thirdparty/fish_speech/tools/api.py +0 -440
  337. xinference/thirdparty/fish_speech/tools/commons.py +0 -35
  338. xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
  339. xinference/thirdparty/fish_speech/tools/msgpack_api.py +0 -34
  340. xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
  341. xinference/thirdparty/fish_speech/tools/webui.py +0 -485
  342. xinference/web/ui/build/static/css/main.5061c4c3.css +0 -2
  343. xinference/web/ui/build/static/css/main.5061c4c3.css.map +0 -1
  344. xinference/web/ui/build/static/js/main.2f269bb3.js +0 -3
  345. xinference/web/ui/build/static/js/main.2f269bb3.js.map +0 -1
  346. xinference/web/ui/node_modules/.cache/babel-loader/07ce9e632e6aff24d7aa3ad8e48224433bbfeb0d633fca723453f1fcae0c9f1c.json +0 -1
  347. xinference/web/ui/node_modules/.cache/babel-loader/1130403f9e46f5738a23b45ac59b57de8f360c908c713e2c0670c2cce9bd367a.json +0 -1
  348. xinference/web/ui/node_modules/.cache/babel-loader/131091b25d26b17cdca187d7542a21475c211138d900cf667682260e76ef9463.json +0 -1
  349. xinference/web/ui/node_modules/.cache/babel-loader/1f269fb2a368363c1cb2237825f1dba093b6bdd8c44cc05954fd19ec2c1fff03.json +0 -1
  350. xinference/web/ui/node_modules/.cache/babel-loader/331312668fa8bd3d7401818f4a25fa98135d7f61371cd6bfff78b18cf4fbdd92.json +0 -1
  351. xinference/web/ui/node_modules/.cache/babel-loader/40f17338fc75ae095de7d2b4d8eae0d5ca0193a7e2bcece4ee745b22a7a2f4b7.json +0 -1
  352. xinference/web/ui/node_modules/.cache/babel-loader/4de9a6942c5f1749d6cbfdd54279699975f16016b182848bc253886f52ec2ec3.json +0 -1
  353. xinference/web/ui/node_modules/.cache/babel-loader/822586ed1077201b64b954f12f25e3f9b45678c1acbabe53d8af3ca82ca71f33.json +0 -1
  354. xinference/web/ui/node_modules/.cache/babel-loader/8d33354bd2100c8602afc3341f131a88cc36aaeecd5a4b365ed038514708e350.json +0 -1
  355. xinference/web/ui/node_modules/.cache/babel-loader/9375a35b05d56989b2755bf72161fa707c92f28569d33765a75f91a568fda6e9.json +0 -1
  356. xinference/web/ui/node_modules/.cache/babel-loader/a158a9ffa0c9b169aee53dd4a0c44501a596755b4e4f6ede7746d65a72e2a71f.json +0 -1
  357. xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +0 -1
  358. xinference/web/ui/node_modules/.cache/babel-loader/c7bf40bab396765f67d0fed627ed3665890608b2d0edaa3e8cb7cfc96310db45.json +0 -1
  359. xinference/web/ui/node_modules/.cache/babel-loader/d6c643278a0b28320e6f33a60f5fb64c053997cbdc39a60e53ccc574688ade9e.json +0 -1
  360. xinference/web/ui/node_modules/.cache/babel-loader/e42b72d4cc1ea412ebecbb8d040dc6c6bfee462c33903c2f1f3facb602ad742e.json +0 -1
  361. xinference/web/ui/node_modules/.cache/babel-loader/e64b7e8cedcf43d4c95deba60ec1341855c887705805bb62431693118b870c69.json +0 -1
  362. xinference/web/ui/node_modules/.cache/babel-loader/f5039ddbeb815c51491a1989532006b96fc3ae49c6c60e3c097f875b4ae915ae.json +0 -1
  363. xinference/web/ui/node_modules/.cache/babel-loader/f72f011744c4649fabddca6f7a9327861ac0a315a89b1a2e62a39774e7863845.json +0 -1
  364. xinference/web/ui/node_modules/.cache/babel-loader/feabb04b4aa507102da0a64398a40818e878fd1df9b75dda8461b3e1e7ff3f11.json +0 -1
  365. /xinference/thirdparty/{cosyvoice/bin → f5_tts}/__init__.py +0 -0
  366. /xinference/thirdparty/{cosyvoice/flow → melo}/__init__.py +0 -0
  367. /xinference/thirdparty/{cosyvoice/hifigan → melo/text/english_utils}/__init__.py +0 -0
  368. /xinference/thirdparty/{cosyvoice/llm → melo/text/es_phonemizer}/__init__.py +0 -0
  369. /xinference/thirdparty/{fish_speech/fish_speech/configs → melo/text/fr_phonemizer}/__init__.py +0 -0
  370. /xinference/web/ui/build/static/js/{main.2f269bb3.js.LICENSE.txt → main.b0936c54.js.LICENSE.txt} +0 -0
  371. {xinference-0.16.3.dist-info → xinference-1.2.1.dist-info}/LICENSE +0 -0
  372. {xinference-0.16.3.dist-info → xinference-1.2.1.dist-info}/entry_points.txt +0 -0
  373. {xinference-0.16.3.dist-info → xinference-1.2.1.dist-info}/top_level.txt +0 -0
xinference/thirdparty/fish_speech/tools/llama/generate.py

@@ -15,9 +15,18 @@ import torch._dynamo.config
 import torch._inductor.config
 from loguru import logger
 from tqdm import tqdm
-
-from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
+from transformers import AutoTokenizer
+
+from fish_speech.conversation import (
+    CODEBOOK_PAD_TOKEN_ID,
+    Conversation,
+    Message,
+    TextPart,
+    VQPart,
+)
+from fish_speech.models.text2semantic.llama import BaseModelArgs
 from fish_speech.text import clean_text, split_text
+from fish_speech.tokenizer import IM_END_TOKEN, FishTokenizer

 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 torch._inductor.config.coordinate_descent_tuning = True
@@ -28,6 +37,8 @@ if hasattr(torch._inductor.config, "fx_graph_cache"):
     torch._inductor.config.fx_graph_cache = True


+from torch.nn.attention import SDPBackend, sdpa_kernel
+
 from fish_speech.models.text2semantic.llama import (
     BaseTransformer,
     DualARTransformer,
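
Note: `sdpa_kernel` is PyTorch's context manager for pinning the scaled-dot-product-attention backend; the agent decoding loop added below wraps each step in `sdpa_kernel(SDPBackend.MATH)`. A minimal standalone sketch of the call (assuming PyTorch 2.3+, where it lives in `torch.nn.attention`):

import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(1, 8, 16, 64)  # (batch, heads, seq_len, head_dim)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

# Pin the math (non-fused) backend; per the diff's own comment,
# this is friendlier for Inductor to codegen attention.
with sdpa_kernel(SDPBackend.MATH):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)

print(out.shape)  # torch.Size([1, 8, 16, 64])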
@@ -74,6 +85,45 @@ def logits_to_probs(
     return probs


+def multinomial_sample_one_no_sync_agent(
+    probs_sort,
+):  # Does multinomial sampling without a cuda synchronization
+    q = torch.empty_like(probs_sort).exponential_(1)
+    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
+
+
+def logits_to_probs_agent(
+    logits,
+    previous_tokens: Optional[torch.Tensor] = None,
+    temperature: torch.Tensor = 1.0,
+    top_p: torch.Tensor = 1.0,
+    repetition_penalty: torch.Tensor = 1.0,
+) -> torch.Tensor:
+    # Apply repetition penalty
+    if previous_tokens is not None:
+        previous_tokens = previous_tokens.long()
+        score = torch.gather(logits, dim=-1, index=previous_tokens)
+        score = torch.where(
+            score < 0, score * repetition_penalty, score / repetition_penalty
+        )
+        logits.scatter_(dim=-1, index=previous_tokens, src=score)
+
+    # Apply top-p sampling
+    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+    cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
+    sorted_indices_to_remove = cum_probs > top_p
+    sorted_indices_to_remove[..., 0] = False  # keep at least one option
+    indices_to_remove = sorted_indices_to_remove.scatter(
+        dim=-1, index=sorted_indices, src=sorted_indices_to_remove
+    )
+    logits = logits.masked_fill(indices_to_remove, -float("Inf"))
+
+    logits = logits / max(temperature, 1e-5)
+
+    probs = torch.nn.functional.softmax(logits, dim=-1)
+    return probs
+
+
 def sample(
     logits,
     previous_tokens: Optional[torch.Tensor] = None,
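
Note: the `multinomial_sample_one_no_sync_agent` helper above draws from a categorical distribution via the exponential-race trick: for probabilities p_i and independent Exponential(1) samples q_i, argmax_i(p_i / q_i) is distributed according to p, and unlike `torch.multinomial` it avoids a host/device synchronization. A self-contained illustration:

import torch

probs = torch.tensor([[0.1, 0.2, 0.7]])

# Divide each probability by an independent Exponential(1) draw and take
# the argmax; index 2 wins roughly 70% of the time, matching probs.
q = torch.empty_like(probs).exponential_(1)
idx = torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int)
print(idx)  # e.g. tensor([[2]], dtype=torch.int32)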
@@ -86,38 +136,161 @@ def sample(
     return idx_next, probs


-def decode_one_token_ar(
+def sample_agent(
+    logits,
+    previous_tokens: Optional[torch.Tensor] = None,
+    **sampling_kwargs,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    probs = logits_to_probs_agent(
+        logits=logits[:, -1], previous_tokens=previous_tokens, **sampling_kwargs
+    )
+    idx_next = multinomial_sample_one_no_sync_agent(probs)
+    return idx_next, probs
+
+
+def decode_one_token_ar_agent(
     model: DualARTransformer,
     x: torch.Tensor,
     input_pos: torch.Tensor,
+    semantic_ids: list,
     previous_tokens: torch.Tensor = None,
     **sampling_kwargs,
 ) -> torch.Tensor:
+    # print(x, input_pos)
     x = model.forward_generate(x, input_pos)
+    logits = x.logits  # [:, -1:]
+    hidden_states = x.hidden_states  # [:, -1:]

     sampling_kwargs_main = sampling_kwargs.copy()
     sampling_kwargs_main["temperature"] = 0.1
     sampling_kwargs_main["top_p"] = 0.1
     sampling_kwargs_main["repetition_penalty"] = 1.0

+    codebooks = [
+        sample_agent(
+            logits,
+            previous_tokens=None,  # Disable repetition penalty for the token codebook
+            **sampling_kwargs_main,
+        )[0]
+    ]
+
+    # Cleanup the cache
+    for layer in model.fast_layers:
+        layer.attention.kv_cache.k_cache.fill_(0)
+        layer.attention.kv_cache.v_cache.fill_(0)
+
+    for codebook_idx in range(model.config.num_codebooks):
+        input_pos = torch.tensor(
+            [codebook_idx], device=hidden_states.device, dtype=torch.long
+        )
+        logits = model.forward_generate_fast(hidden_states, input_pos)
+        a = sample_agent(
+            logits,
+            previous_tokens=(
+                previous_tokens[:, codebook_idx + 1]
+                if previous_tokens is not None
+                else None
+            ),
+            **sampling_kwargs,
+        )[0]
+        hidden_states = model.fast_embeddings(a)
+        codebooks.append(a)
+
+    codebooks = torch.stack(codebooks, dim=1)
+    semantic_ids_tensor = torch.tensor(semantic_ids, device=codebooks.device)
+    codebooks[:, 1:, :] = torch.masked_fill(
+        codebooks[:, 1:, :],
+        ~torch.isin(codebooks[:, :1, :], semantic_ids_tensor),
+        CODEBOOK_PAD_TOKEN_ID,
+    )
+
+    return codebooks
+
+
+def decode_one_token_naive_agent(
+    model: NaiveTransformer,
+    x: torch.Tensor,
+    input_pos: torch.Tensor,
+    semantic_ids: list,
+    previous_tokens: torch.Tensor = None,
+    **sampling_kwargs,
+) -> torch.Tensor:
+    x = model.forward_generate(x, input_pos)
+
     codebooks = [
         sample(
-            x.logits,
+            x.token_logits,
             previous_tokens=None,  # Disable repetition penalty for the token codebook
+            **sampling_kwargs,
+        )[0]
+    ]
+
+    for i in range(model.config.num_codebooks):
+        codebooks.append(
+            sample_agent(
+                x.codebook_logits[:, :, i],
+                previous_tokens=(
+                    previous_tokens[:, i + 1] if previous_tokens is not None else None
+                ),
+                **sampling_kwargs,
+            )[0]
+        )
+
+    codebooks = torch.stack(codebooks, dim=1)
+    semantic_ids_tensor = torch.tensor(semantic_ids, device=codebooks.device)
+    codebooks[:, 1:, :] = torch.masked_fill(
+        codebooks[:, 1:, :],
+        ~torch.isin(codebooks[:, :1, :], semantic_ids_tensor),
+        CODEBOOK_PAD_TOKEN_ID,
+    )
+
+    return codebooks
+
+
+def decode_one_token_ar(
+    model: DualARTransformer,
+    x: torch.Tensor,
+    input_pos: torch.Tensor,
+    semantic_ids: list,
+    previous_tokens: torch.Tensor = None,
+    **sampling_kwargs,
+) -> torch.Tensor:
+    x = model.forward_generate(x, input_pos)
+
+    sampling_kwargs_main = sampling_kwargs.copy()
+    # sampling_kwargs_main["temperature"] = 0.1
+    # sampling_kwargs_main["top_p"] = 0.1
+    # sampling_kwargs_main["repetition_penalty"] = 1.0
+
+    codebooks = [
+        sample(
+            x.logits,
+            previous_tokens=(
+                previous_tokens[0] if previous_tokens is not None else None
+            ),  # Disable repetition penalty for the token codebook
             **sampling_kwargs_main,
         )[0]
     ]

-    x = x.hidden_states
+    hidden_states = x.hidden_states

     # Cleanup the cache
     for layer in model.fast_layers:
         layer.attention.kv_cache.k_cache.fill_(0)
         layer.attention.kv_cache.v_cache.fill_(0)

-    for codebook_idx in range(model.config.num_codebooks):
-        input_pos = torch.tensor([codebook_idx], device=x.device, dtype=torch.long)
-        logits = model.forward_generate_fast(x, input_pos)
+    input_pos = torch.tensor([0], device=hidden_states.device, dtype=torch.long)
+    model.forward_generate_fast(hidden_states, input_pos)
+    a = codebooks[0] - model.tokenizer.semantic_begin_id
+    a[a < 0] = 0
+    hidden_states = model.fast_embeddings(a)
+    codebooks.append(a)
+
+    for codebook_idx in range(1, model.config.num_codebooks):
+        input_pos = torch.tensor(
+            [codebook_idx], device=hidden_states.device, dtype=torch.long
+        )
+        logits = model.forward_generate_fast(hidden_states, input_pos)
         a = sample(
             logits,
             previous_tokens=(
@@ -127,10 +300,17 @@ def decode_one_token_ar(
             ),
             **sampling_kwargs,
         )[0]
-        x = model.fast_embeddings(a)
+        hidden_states = model.fast_embeddings(a)
         codebooks.append(a)

-    return torch.stack(codebooks, dim=0)
+    codebooks = torch.stack(codebooks, dim=0)
+    # semantic_ids_tensor = torch.tensor(semantic_ids, device=codebooks.device)
+    # codebooks[1:, :] = torch.masked_fill(
+    #     codebooks[1:, :], ~torch.isin(codebooks[:1, :], semantic_ids_tensor), CODEBOOK_PAD_TOKEN_ID
+    # )
+
+    # print(codebooks)
+    return codebooks


 def decode_one_token_naive(
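
Note: the rewritten `decode_one_token_ar` shifts the first sampled token from global vocabulary ids into the semantic-codebook range before feeding it to the fast embeddings. A sketch of that offset-and-clamp step (the ids below are invented for illustration; the real offset comes from `model.tokenizer.semantic_begin_id`):

import torch

semantic_begin_id = 32000  # hypothetical vocab id of <|semantic:0|>
codebook0 = torch.tensor([[32005], [17]])  # one semantic, one non-semantic token

a = codebook0 - semantic_begin_id  # shift into [0, num_semantic_tokens)
a[a < 0] = 0                       # non-semantic ids (e.g. <|im_end|>) clamp to 0
print(a)  # tensor([[5], [0]])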
@@ -174,7 +354,7 @@ def decode_n_tokens(
     cur_token: torch.Tensor,
     input_pos: torch.Tensor,
     num_new_tokens: int,
-    im_end_id: int = 4,
+    semantic_ids: list,
     decode_one_token=decode_one_token_naive,
     **sampling_kwargs,
 ):
@@ -204,6 +384,7 @@
                 x=cur_token,
                 input_pos=input_pos,
                 previous_tokens=window,
+                semantic_ids=semantic_ids,
                 **sampling_kwargs,
             )

@@ -213,7 +394,7 @@
             model.config.num_codebooks + 1, -1
         )

-        if cur_token[0, 0, -1] == im_end_id:
+        if cur_token[0, 0, -1] == model.tokenizer.get_token_id(IM_END_TOKEN):
             break

     return previous_tokens[:, : i + 1]
@@ -226,7 +407,6 @@
     model: NaiveTransformer,
     prompt: torch.Tensor,
     max_new_tokens: int,
-    im_end_id: int = 4,
     decode_one_token=decode_one_token_naive,
     **sampling_kwargs,
 ) -> torch.Tensor:
@@ -236,12 +416,28 @@

     # create an empty tensor of the expected final shape and fill in the current tokens
     T = prompt.size(1)
+    # semantic_id = model.tokenizer.convert_tokens_to_ids("<|semantic|>")
+    semantic_ids = [
+        model.tokenizer.get_token_id(f"<|semantic:{i}|>") for i in range(1024)
+    ]
+
+    if max_new_tokens:
+        if T + max_new_tokens > model.config.max_seq_len:
+            max_new_tokens = model.config.max_seq_len - T
+            logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
+
+        T_new = T + max_new_tokens
+    else:
+        T_new = model.config.max_seq_len
+        max_new_tokens = T_new - T

     device, dtype = prompt.device, prompt.dtype

     codebook_dim = 1 + model.config.num_codebooks
     # create an empty tensor of the expected final shape and fill in the current tokens
-    empty = torch.empty((codebook_dim, max_new_tokens), dtype=dtype, device=device)
+    empty = torch.empty(
+        (codebook_dim, model.config.max_seq_len), dtype=dtype, device=device
+    )
     empty[:, :T] = prompt
     seq = empty
     input_pos = torch.arange(0, T, device=device)
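
Note: the new budget logic in `generate` clamps the requested token count to the model's context window and now pre-allocates a buffer of the full `max_seq_len` rather than `max_new_tokens`. The arithmetic in isolation, with toy numbers standing in for `model.config.max_seq_len`:

max_seq_len = 4096
T = 4000              # prompt length
max_new_tokens = 512  # requested

if max_new_tokens:
    if T + max_new_tokens > max_seq_len:
        max_new_tokens = max_seq_len - T  # truncated to 96
    T_new = T + max_new_tokens
else:
    T_new = max_seq_len                   # no request: fill the window
    max_new_tokens = T_new - T

print(max_new_tokens, T_new)  # 96 4096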
@@ -254,7 +450,11 @@
     )

     next_token = prefill_decode(
-        model, prompt.view(1, codebook_dim, -1), input_pos, **sampling_kwargs
+        model,
+        prompt.view(1, codebook_dim, -1),
+        input_pos,
+        semantic_ids=semantic_ids,
+        **sampling_kwargs,
     )
     seq[:, T : T + 1] = next_token

@@ -264,8 +464,8 @@
         next_token.view(1, codebook_dim, -1),
         input_pos,
         max_new_tokens - 1,
-        im_end_id=im_end_id,
         decode_one_token=decode_one_token,
+        semantic_ids=semantic_ids,
         **sampling_kwargs,
     )
     # x = torch.cat(generated_tokens, dim=1)
@@ -275,6 +475,142 @@
     return seq


+def decode_n_tokens_agent(
+    model: NaiveTransformer,
+    cur_token: torch.Tensor,
+    input_pos: torch.Tensor,
+    num_new_tokens: int,
+    semantic_ids: list,
+    im_end_id: int = 4,
+    decode_one_token=decode_one_token_naive_agent,
+    early_stop_threshold: float = 0.6,
+    **sampling_kwargs,
+):
+    batch_size = cur_token.size(0)
+    previous_tokens = torch.zeros(
+        (batch_size, model.config.num_codebooks + 1, model.config.max_seq_len),
+        dtype=torch.int,
+        device=cur_token.device,
+    )
+    finished = torch.zeros(batch_size, dtype=torch.bool, device=cur_token.device)
+    finished = finished | (cur_token[:, 0, -1] == im_end_id)
+    start_time = time.time()
+
+    for i in tqdm(range(num_new_tokens), desc="Decoding: ", total=num_new_tokens):
+        # We need to get windowed repeat penalty
+        win_size = 16
+        if i < win_size:
+            window = previous_tokens[:, :, :win_size]
+        else:
+            window = previous_tokens[:, :, i - win_size : i]
+
+        with sdpa_kernel(
+            SDPBackend.MATH
+        ):  # Actually better for Inductor to codegen attention here
+            next_token = decode_one_token(
+                model=model,
+                x=cur_token,
+                input_pos=input_pos,
+                previous_tokens=window,
+                semantic_ids=semantic_ids,
+                **sampling_kwargs,
+            )
+
+        input_pos += 1
+        cur_token = next_token.view(batch_size, model.config.num_codebooks + 1, -1)
+        previous_tokens[:, :, i : i + 1] = next_token.view(
+            batch_size, model.config.num_codebooks + 1, -1
+        )
+
+        yield cur_token.cpu()
+
+        finished = finished | (cur_token[:, 0, -1] == im_end_id)
+        if finished.all() or (
+            0 < early_stop_threshold < 1
+            and finished.sum() >= round(batch_size * early_stop_threshold)
+        ):
+            break
+
+    total_time = time.time() - start_time
+    generated_tokens = i + 1
+    tokens_per_second = (generated_tokens / total_time) * batch_size
+    logger.info(
+        f"Decoded {generated_tokens} x {batch_size} tokens in {total_time:.2f}s ({tokens_per_second:.2f} tokens/s)"
+    )
+
+
+@torch.no_grad()
+@torch.inference_mode()
+def generate_agent(
+    *,
+    model: BaseTransformer,
+    prompt: torch.Tensor,
+    max_new_tokens: int,
+    semantic_ids: list,
+    im_end_id: int = 4,
+    decode_one_token=decode_one_token_naive_agent,
+    num_samples: int = 1,
+    early_stop_threshold: float = 0.6,
+    **sampling_kwargs,
+):
+    """
+    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
+    """
+
+    # create an empty tensor of the expected final shape and fill in the current tokens
+    T = prompt.size(1)
+    prompt = prompt[None].repeat(num_samples, 1, 1)
+
+    if T >= model.config.max_seq_len:
+        raise ValueError(
+            f"Input sequence length {T} exceeds max_seq_len {model.config.max_seq_len}"
+        )
+
+    if max_new_tokens:
+        if T + max_new_tokens > model.config.max_seq_len:
+            max_new_tokens = model.config.max_seq_len - T
+            logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
+
+        T_new = T + max_new_tokens
+    else:
+        T_new = model.config.max_seq_len
+        max_new_tokens = T_new - T
+
+    device, dtype = prompt.device, prompt.dtype
+
+    codebook_dim = 1 + model.config.num_codebooks
+    input_pos = torch.arange(0, T, device=device)
+
+    # Use non-accelerated version for now, to avoid compilation overhead
+    prefill_decode = (
+        decode_one_token_naive_agent
+        if isinstance(model, NaiveTransformer)
+        else decode_one_token_ar_agent
+    )
+    next_token = prefill_decode(
+        model,
+        prompt,
+        input_pos,
+        semantic_ids=semantic_ids,
+        **sampling_kwargs,
+    ).view(num_samples, codebook_dim, -1)
+    yield next_token.cpu()
+
+    input_pos = torch.tensor([T], device=device, dtype=torch.int)
+
+    yield from decode_n_tokens_agent(
+        model,
+        next_token,
+        input_pos,
+        max_new_tokens - 1,
+        im_end_id=im_end_id,
+        semantic_ids=semantic_ids,
+        decode_one_token=decode_one_token,
+        early_stop_threshold=early_stop_threshold,
+        **sampling_kwargs,
+    )
+
+
 def encode_tokens(
     tokenizer,
     string,
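
Note: `decode_n_tokens_agent` above stops as soon as a fraction (`early_stop_threshold`) of the batch has emitted `im_end_id`, not only when every sample is finished. A minimal sketch of that stopping rule:

import torch

batch_size = 8
early_stop_threshold = 0.6
# Suppose 5 of the 8 samples have already produced <|im_end|>:
finished = torch.tensor([1, 1, 0, 1, 0, 1, 1, 0], dtype=torch.bool)

stop = finished.all() or (
    0 < early_stop_threshold < 1
    and finished.sum() >= round(batch_size * early_stop_threshold)
)
print(stop)  # tensor(True): 5 >= round(8 * 0.6) == 5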
@@ -283,75 +619,77 @@ def encode_tokens(
     num_codebooks=4,
 ):
     string = clean_text(string)
-    string = f"<|im_start|>user\n{string}<|im_end|><|im_start|>assistant\n"

-    new_tokens = tokenizer.encode(
-        string,
-        add_special_tokens=False,
-        max_length=10**6,
-        truncation=False,
+    messages = []
+    messages.append(
+        Message(
+            role="user",
+            parts=[TextPart(text=string)],
+            cal_loss=False,
+        )
     )
-    tokens = torch.tensor([new_tokens], dtype=torch.int, device=device)

-    # Codebooks
-    zeros = (
-        torch.ones((num_codebooks, tokens.size(1)), dtype=torch.int, device=device)
-        * CODEBOOK_PAD_TOKEN_ID
-    )
-    prompt = torch.cat((tokens, zeros), dim=0)
+    if prompt_tokens is not None:
+        if prompt_tokens.ndim == 3:
+            assert (
+                prompt_tokens.shape[0] == 1
+            ), "3D prompt tokens should have shape (1, num_codebooks, seq_len)"
+            prompt_tokens = prompt_tokens[0]

-    if prompt_tokens is None:
-        return prompt
+        assert prompt_tokens.ndim == 2, "Prompt tokens should be 2D tensor"

-    # Get prompt tokens
-    if prompt_tokens.ndim == 3:
-        assert (
-            prompt_tokens.shape[0] == 1
-        ), f"3 dim prompt tokens should have shape (1, num_codebooks, seq_len)"
-        prompt_tokens = prompt_tokens[0]
+        if prompt_tokens.shape[0] > num_codebooks:
+            logger.warning(
+                f"Prompt tokens shape {prompt_tokens.shape} is larger than num_codebooks {num_codebooks}, getting first {num_codebooks} codebooks"
+            )
+            prompt_tokens = prompt_tokens[:num_codebooks]

-    assert prompt_tokens.ndim == 2
-    data = prompt_tokens + 1
+        vq_part = VQPart(codes=prompt_tokens.to(device))

-    if prompt_tokens.shape[0] > num_codebooks:
-        logger.warning(
-            f"Prompt tokens shape {prompt_tokens.shape} is larger than num_codebooks {num_codebooks}, getting first {num_codebooks} codebooks"
+        messages.append(
+            Message(
+                role="assistant",
+                parts=[TextPart(text="<|voice|>"), vq_part],
+                cal_loss=False,
+            )
+        )
+    else:
+        messages.append(
+            Message(
+                role="assistant",
+                parts=[TextPart(text="<|voice|>")],
+                cal_loss=False,
+                add_im_end=False,
+            )
         )
-        data = data[:num_codebooks]
-
-    # Add pad token for each codebook
-    data = torch.cat(
-        (data, torch.zeros((data.size(0), 1), dtype=torch.int, device=device)),
-        dim=1,
-    )

-    # Since 1.0, we use <|semantic|>
-    s0_token_id = tokenizer.convert_tokens_to_ids("<|semantic|>")
-    end_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
-    main_token_ids = (
-        torch.ones((1, data.size(1)), dtype=torch.int, device=device) * s0_token_id
+    conversation = Conversation(messages=messages)
+    # conversation.visualize(tokenizer)
+    encoded = conversation.encode_for_inference(
+        tokenizer=tokenizer,
+        num_codebooks=num_codebooks,
     )
-    main_token_ids[0, -1] = end_token_id
-
-    data = torch.cat((main_token_ids, data), dim=0)
-    prompt = torch.cat((prompt, data), dim=1)

-    return prompt
+    return encoded.to(device)


-def load_model(checkpoint_path, device, precision, compile=False):
+def load_model(checkpoint_path, device, precision, compile=False, is_agent=False):
     model: Union[NaiveTransformer, DualARTransformer] = BaseTransformer.from_pretrained(
-        checkpoint_path, load_weights=True
+        checkpoint_path, load_weights=True, is_agent=is_agent
     )

     model = model.to(device=device, dtype=precision)
     logger.info(f"Restored model from checkpoint")

     if isinstance(model, DualARTransformer):
-        decode_one_token = decode_one_token_ar
+        decode_one_token = (
+            decode_one_token_ar_agent if is_agent else decode_one_token_ar
+        )
         logger.info("Using DualARTransformer")
     else:
-        decode_one_token = decode_one_token_naive
+        decode_one_token = (
+            decode_one_token_naive_agent if is_agent else decode_one_token_naive
+        )
         logger.info("Using NaiveTransformer")

     if compile:
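
Note: `encode_tokens` no longer hand-assembles the `<|im_start|>…` string plus padded codebook rows; it builds a `Conversation` of `Message` objects and lets `encode_for_inference` produce the `(1 + num_codebooks, seq_len)` prompt tensor. Condensed from the hunk above into a helper (uses only the `fish_speech.conversation` types and calls shown in the diff; runnable only with fish_speech installed):

from fish_speech.conversation import Conversation, Message, TextPart, VQPart

def build_prompt(tokenizer, text, prompt_tokens=None, num_codebooks=4, device="cpu"):
    # The user turn carries the text to be spoken.
    messages = [Message(role="user", parts=[TextPart(text=text)], cal_loss=False)]
    if prompt_tokens is not None:
        # Reference audio codes ride along as a VQPart on a closed assistant turn.
        messages.append(Message(
            role="assistant",
            parts=[TextPart(text="<|voice|>"), VQPart(codes=prompt_tokens.to(device))],
            cal_loss=False,
        ))
    else:
        # No reference: leave the assistant turn open for generation.
        messages.append(Message(
            role="assistant",
            parts=[TextPart(text="<|voice|>")],
            cal_loss=False,
            add_im_end=False,
        ))
    encoded = Conversation(messages=messages).encode_for_inference(
        tokenizer=tokenizer, num_codebooks=num_codebooks
    )
    return encoded.to(device)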
@@ -406,11 +744,26 @@ def generate_long(

     model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
     tokenizer = model.tokenizer
-    im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
+    im_end_id = tokenizer.get_token_id("<|im_end|>")

     encoded = []
     texts = split_text(text, chunk_length) if iterative_prompt else [text]
-    encoded_prompts = []
+    encoded_prompts = [
+        Conversation(
+            messages=[
+                Message(
+                    role="system",
+                    parts=[TextPart(text="Speak out the provided text.")],
+                    cal_loss=False,
+                )
+            ]
+        )
+        .encode_for_inference(
+            tokenizer=tokenizer,
+            num_codebooks=model.config.num_codebooks,
+        )
+        .to(device)
+    ]

     if use_prompt:
         for idx, (t, c) in enumerate(zip(prompt_text, prompt_tokens)):
@@ -489,7 +842,6 @@
                 model=model,
                 prompt=cat_encoded,
                 max_new_tokens=max_new_tokens,
-                im_end_id=im_end_id,
                 decode_one_token=decode_one_token,
                 temperature=temperature,
                 top_p=top_p,
@@ -519,12 +871,11 @@
             )

             # Put the generated tokens
-            # since there is <im_end> and <eos> tokens, we remove last 2 tokens
-            codes = y[1:, prompt_length:-1].clone()
-            codes = codes - 1
+            # since there is <im_end>, we remove last token
+            codes = y[1:, prompt_length + 1 :].clone()
             assert (codes >= 0).all(), f"Negative code found"

-            decoded = y[:, prompt_length:-1].clone()
+            decoded = y[:, prompt_length:].clone()
             # But for global encoding, we should keep the <im_end> token

             global_encoded.append(decoded)
@@ -563,7 +914,9 @@ def launch_thread_safe_queue(
         )
         with torch.device(device):
             model.setup_caches(
-                max_batch_size=1, max_seq_len=2048, dtype=next(model.parameters()).dtype
+                max_batch_size=1,
+                max_seq_len=model.config.max_seq_len,
+                dtype=next(model.parameters()).dtype,
             )
         init_event.set()

@@ -591,6 +944,60 @@
     return input_queue


+def launch_thread_safe_queue_agent(
+    checkpoint_path,
+    device,
+    precision,
+    compile: bool = False,
+):
+    input_queue = queue.Queue()
+    init_event = threading.Event()
+
+    tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
+    config = BaseModelArgs.from_pretrained(checkpoint_path)
+
+    def worker():
+        model, decode_one_token = load_model(
+            checkpoint_path, device, precision, compile=compile, is_agent=True
+        )
+
+        with torch.device(device):
+            model.setup_caches(
+                max_batch_size=1,
+                max_seq_len=model.config.max_seq_len,
+                dtype=next(model.parameters()).dtype,
+            )
+        init_event.set()
+
+        while True:
+            item: GenerateRequest | None = input_queue.get()
+            if item is None:
+                break
+
+            kwargs = item.request
+            response_queue = item.response_queue
+
+            try:
+                for token in generate_agent(
+                    model=model,
+                    decode_one_token=decode_one_token,
+                    **kwargs,
+                ):
+                    response_queue.put(token)
+
+                response_queue.put("stop")
+            except Exception as e:
+                import traceback
+
+                logger.exception(f"Error in worker: {traceback.format_exc()}")
+                response_queue.put("error")
+
+    threading.Thread(target=worker, daemon=True).start()
+    init_event.wait()
+
+    return input_queue, tokenizer, config
+
+
 @click.command()
 @click.option(
     "--text",
@@ -650,7 +1057,12 @@ def main(
     model, decode_one_token = load_model(
         checkpoint_path, device, precision, compile=compile
     )
-
+    with torch.device(device):
+        model.setup_caches(
+            max_batch_size=1,
+            max_seq_len=model.config.max_seq_len,
+            dtype=next(model.parameters()).dtype,
+        )
     if torch.cuda.is_available():
         torch.cuda.synchronize()