xinference 0.16.3__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic. Click here for more details.

Files changed (373) hide show
  1. xinference/_compat.py +24 -2
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +219 -77
  4. xinference/client/restful/restful_client.py +47 -2
  5. xinference/constants.py +1 -0
  6. xinference/core/chat_interface.py +6 -1
  7. xinference/core/model.py +124 -34
  8. xinference/core/supervisor.py +180 -12
  9. xinference/core/utils.py +73 -4
  10. xinference/core/worker.py +102 -4
  11. xinference/deploy/cmdline.py +3 -1
  12. xinference/deploy/test/test_cmdline.py +56 -0
  13. xinference/isolation.py +24 -0
  14. xinference/model/audio/__init__.py +12 -0
  15. xinference/model/audio/core.py +37 -4
  16. xinference/model/audio/cosyvoice.py +39 -6
  17. xinference/model/audio/f5tts.py +200 -0
  18. xinference/model/audio/f5tts_mlx.py +260 -0
  19. xinference/model/audio/fish_speech.py +70 -110
  20. xinference/model/audio/melotts.py +110 -0
  21. xinference/model/audio/model_spec.json +179 -3
  22. xinference/model/audio/model_spec_modelscope.json +27 -0
  23. xinference/model/audio/utils.py +32 -0
  24. xinference/model/audio/whisper.py +35 -10
  25. xinference/model/audio/whisper_mlx.py +208 -0
  26. xinference/model/embedding/core.py +322 -6
  27. xinference/model/embedding/model_spec.json +8 -1
  28. xinference/model/embedding/model_spec_modelscope.json +9 -1
  29. xinference/model/image/core.py +69 -1
  30. xinference/model/image/model_spec.json +145 -4
  31. xinference/model/image/model_spec_modelscope.json +150 -4
  32. xinference/model/image/stable_diffusion/core.py +50 -15
  33. xinference/model/llm/__init__.py +6 -2
  34. xinference/model/llm/llm_family.json +1055 -93
  35. xinference/model/llm/llm_family.py +15 -36
  36. xinference/model/llm/llm_family_modelscope.json +1031 -78
  37. xinference/model/llm/memory.py +1 -1
  38. xinference/model/llm/mlx/core.py +285 -47
  39. xinference/model/llm/sglang/core.py +2 -0
  40. xinference/model/llm/transformers/chatglm.py +9 -5
  41. xinference/model/llm/transformers/cogagent.py +272 -0
  42. xinference/model/llm/transformers/core.py +3 -0
  43. xinference/model/llm/transformers/glm_edge_v.py +230 -0
  44. xinference/model/llm/transformers/qwen2_vl.py +12 -1
  45. xinference/model/llm/transformers/utils.py +16 -8
  46. xinference/model/llm/utils.py +55 -4
  47. xinference/model/llm/vllm/core.py +137 -12
  48. xinference/model/llm/vllm/xavier/__init__.py +13 -0
  49. xinference/model/llm/vllm/xavier/allocator.py +74 -0
  50. xinference/model/llm/vllm/xavier/block.py +111 -0
  51. xinference/model/llm/vllm/xavier/block_manager.py +71 -0
  52. xinference/model/llm/vllm/xavier/block_tracker.py +129 -0
  53. xinference/model/llm/vllm/xavier/collective.py +74 -0
  54. xinference/model/llm/vllm/xavier/collective_manager.py +147 -0
  55. xinference/model/llm/vllm/xavier/engine.py +247 -0
  56. xinference/model/llm/vllm/xavier/executor.py +134 -0
  57. xinference/model/llm/vllm/xavier/scheduler.py +438 -0
  58. xinference/model/llm/vllm/xavier/test/__init__.py +13 -0
  59. xinference/model/llm/vllm/xavier/test/test_xavier.py +147 -0
  60. xinference/model/llm/vllm/xavier/transfer.py +319 -0
  61. xinference/model/rerank/core.py +11 -4
  62. xinference/model/video/diffusers.py +14 -0
  63. xinference/model/video/model_spec.json +15 -0
  64. xinference/model/video/model_spec_modelscope.json +16 -0
  65. xinference/thirdparty/cosyvoice/bin/average_model.py +92 -0
  66. xinference/thirdparty/cosyvoice/bin/export_jit.py +12 -2
  67. xinference/thirdparty/cosyvoice/bin/export_onnx.py +112 -0
  68. xinference/thirdparty/cosyvoice/bin/export_trt.sh +9 -0
  69. xinference/thirdparty/cosyvoice/bin/inference.py +5 -7
  70. xinference/thirdparty/cosyvoice/bin/spk2info.pt +0 -0
  71. xinference/thirdparty/cosyvoice/bin/train.py +42 -8
  72. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +96 -25
  73. xinference/thirdparty/cosyvoice/cli/frontend.py +77 -30
  74. xinference/thirdparty/cosyvoice/cli/model.py +330 -80
  75. xinference/thirdparty/cosyvoice/dataset/dataset.py +6 -2
  76. xinference/thirdparty/cosyvoice/dataset/processor.py +76 -14
  77. xinference/thirdparty/cosyvoice/flow/decoder.py +92 -13
  78. xinference/thirdparty/cosyvoice/flow/flow.py +99 -9
  79. xinference/thirdparty/cosyvoice/flow/flow_matching.py +110 -13
  80. xinference/thirdparty/cosyvoice/flow/length_regulator.py +5 -4
  81. xinference/thirdparty/cosyvoice/hifigan/discriminator.py +140 -0
  82. xinference/thirdparty/cosyvoice/hifigan/generator.py +58 -42
  83. xinference/thirdparty/cosyvoice/hifigan/hifigan.py +67 -0
  84. xinference/thirdparty/cosyvoice/llm/llm.py +139 -6
  85. xinference/thirdparty/cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
  86. xinference/thirdparty/cosyvoice/tokenizer/tokenizer.py +279 -0
  87. xinference/thirdparty/cosyvoice/transformer/embedding.py +2 -2
  88. xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +7 -7
  89. xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +318 -0
  90. xinference/thirdparty/cosyvoice/utils/common.py +28 -1
  91. xinference/thirdparty/cosyvoice/utils/executor.py +69 -7
  92. xinference/thirdparty/cosyvoice/utils/file_utils.py +2 -12
  93. xinference/thirdparty/cosyvoice/utils/frontend_utils.py +9 -5
  94. xinference/thirdparty/cosyvoice/utils/losses.py +20 -0
  95. xinference/thirdparty/cosyvoice/utils/scheduler.py +1 -2
  96. xinference/thirdparty/cosyvoice/utils/train_utils.py +101 -45
  97. xinference/thirdparty/f5_tts/api.py +166 -0
  98. xinference/thirdparty/f5_tts/configs/E2TTS_Base_train.yaml +44 -0
  99. xinference/thirdparty/f5_tts/configs/E2TTS_Small_train.yaml +44 -0
  100. xinference/thirdparty/f5_tts/configs/F5TTS_Base_train.yaml +46 -0
  101. xinference/thirdparty/f5_tts/configs/F5TTS_Small_train.yaml +46 -0
  102. xinference/thirdparty/f5_tts/eval/README.md +49 -0
  103. xinference/thirdparty/f5_tts/eval/ecapa_tdnn.py +330 -0
  104. xinference/thirdparty/f5_tts/eval/eval_infer_batch.py +207 -0
  105. xinference/thirdparty/f5_tts/eval/eval_infer_batch.sh +13 -0
  106. xinference/thirdparty/f5_tts/eval/eval_librispeech_test_clean.py +84 -0
  107. xinference/thirdparty/f5_tts/eval/eval_seedtts_testset.py +84 -0
  108. xinference/thirdparty/f5_tts/eval/utils_eval.py +405 -0
  109. xinference/thirdparty/f5_tts/infer/README.md +191 -0
  110. xinference/thirdparty/f5_tts/infer/SHARED.md +74 -0
  111. xinference/thirdparty/f5_tts/infer/examples/basic/basic.toml +11 -0
  112. xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_en.wav +0 -0
  113. xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_zh.wav +0 -0
  114. xinference/thirdparty/f5_tts/infer/examples/multi/country.flac +0 -0
  115. xinference/thirdparty/f5_tts/infer/examples/multi/main.flac +0 -0
  116. xinference/thirdparty/f5_tts/infer/examples/multi/story.toml +19 -0
  117. xinference/thirdparty/f5_tts/infer/examples/multi/story.txt +1 -0
  118. xinference/thirdparty/f5_tts/infer/examples/multi/town.flac +0 -0
  119. xinference/thirdparty/f5_tts/infer/examples/vocab.txt +2545 -0
  120. xinference/thirdparty/f5_tts/infer/infer_cli.py +226 -0
  121. xinference/thirdparty/f5_tts/infer/infer_gradio.py +851 -0
  122. xinference/thirdparty/f5_tts/infer/speech_edit.py +193 -0
  123. xinference/thirdparty/f5_tts/infer/utils_infer.py +538 -0
  124. xinference/thirdparty/f5_tts/model/__init__.py +10 -0
  125. xinference/thirdparty/f5_tts/model/backbones/README.md +20 -0
  126. xinference/thirdparty/f5_tts/model/backbones/dit.py +163 -0
  127. xinference/thirdparty/f5_tts/model/backbones/mmdit.py +146 -0
  128. xinference/thirdparty/f5_tts/model/backbones/unett.py +219 -0
  129. xinference/thirdparty/f5_tts/model/cfm.py +285 -0
  130. xinference/thirdparty/f5_tts/model/dataset.py +319 -0
  131. xinference/thirdparty/f5_tts/model/modules.py +658 -0
  132. xinference/thirdparty/f5_tts/model/trainer.py +366 -0
  133. xinference/thirdparty/f5_tts/model/utils.py +185 -0
  134. xinference/thirdparty/f5_tts/scripts/count_max_epoch.py +33 -0
  135. xinference/thirdparty/f5_tts/scripts/count_params_gflops.py +39 -0
  136. xinference/thirdparty/f5_tts/socket_server.py +159 -0
  137. xinference/thirdparty/f5_tts/train/README.md +77 -0
  138. xinference/thirdparty/f5_tts/train/datasets/prepare_csv_wavs.py +139 -0
  139. xinference/thirdparty/f5_tts/train/datasets/prepare_emilia.py +230 -0
  140. xinference/thirdparty/f5_tts/train/datasets/prepare_libritts.py +92 -0
  141. xinference/thirdparty/f5_tts/train/datasets/prepare_ljspeech.py +65 -0
  142. xinference/thirdparty/f5_tts/train/datasets/prepare_wenetspeech4tts.py +125 -0
  143. xinference/thirdparty/f5_tts/train/finetune_cli.py +174 -0
  144. xinference/thirdparty/f5_tts/train/finetune_gradio.py +1846 -0
  145. xinference/thirdparty/f5_tts/train/train.py +75 -0
  146. xinference/thirdparty/fish_speech/fish_speech/conversation.py +266 -1
  147. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +2 -1
  148. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +2 -1
  149. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +2 -2
  150. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ko_KR.json +123 -0
  151. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +2 -1
  152. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +137 -29
  153. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +9 -9
  154. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +1 -1
  155. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +17 -11
  156. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +1 -1
  157. xinference/thirdparty/fish_speech/fish_speech/tokenizer.py +152 -0
  158. xinference/thirdparty/fish_speech/fish_speech/train.py +2 -2
  159. xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +2 -1
  160. xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +22 -0
  161. xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +1 -1
  162. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +2 -2
  163. xinference/thirdparty/fish_speech/tools/{post_api.py → api_client.py} +34 -18
  164. xinference/thirdparty/fish_speech/tools/api_server.py +98 -0
  165. xinference/thirdparty/fish_speech/tools/download_models.py +5 -5
  166. xinference/thirdparty/fish_speech/tools/e2e_webui.py +232 -0
  167. xinference/thirdparty/fish_speech/tools/fish_e2e.py +298 -0
  168. xinference/thirdparty/fish_speech/tools/inference_engine/__init__.py +192 -0
  169. xinference/thirdparty/fish_speech/tools/inference_engine/reference_loader.py +125 -0
  170. xinference/thirdparty/fish_speech/tools/inference_engine/utils.py +39 -0
  171. xinference/thirdparty/fish_speech/tools/inference_engine/vq_manager.py +57 -0
  172. xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +2 -2
  173. xinference/thirdparty/fish_speech/tools/llama/generate.py +484 -72
  174. xinference/thirdparty/fish_speech/tools/run_webui.py +104 -0
  175. xinference/thirdparty/fish_speech/tools/schema.py +170 -0
  176. xinference/thirdparty/fish_speech/tools/server/agent/__init__.py +57 -0
  177. xinference/thirdparty/fish_speech/tools/server/agent/generate.py +119 -0
  178. xinference/thirdparty/fish_speech/tools/server/agent/generation_utils.py +122 -0
  179. xinference/thirdparty/fish_speech/tools/server/agent/pre_generation_utils.py +72 -0
  180. xinference/thirdparty/fish_speech/tools/server/api_utils.py +75 -0
  181. xinference/thirdparty/fish_speech/tools/server/exception_handler.py +27 -0
  182. xinference/thirdparty/fish_speech/tools/server/inference.py +45 -0
  183. xinference/thirdparty/fish_speech/tools/server/model_manager.py +122 -0
  184. xinference/thirdparty/fish_speech/tools/server/model_utils.py +129 -0
  185. xinference/thirdparty/fish_speech/tools/server/views.py +246 -0
  186. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +7 -1
  187. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +2 -3
  188. xinference/thirdparty/fish_speech/tools/webui/__init__.py +173 -0
  189. xinference/thirdparty/fish_speech/tools/webui/inference.py +91 -0
  190. xinference/thirdparty/fish_speech/tools/webui/variables.py +14 -0
  191. xinference/thirdparty/matcha/utils/utils.py +2 -2
  192. xinference/thirdparty/melo/api.py +135 -0
  193. xinference/thirdparty/melo/app.py +61 -0
  194. xinference/thirdparty/melo/attentions.py +459 -0
  195. xinference/thirdparty/melo/commons.py +160 -0
  196. xinference/thirdparty/melo/configs/config.json +94 -0
  197. xinference/thirdparty/melo/data/example/metadata.list +20 -0
  198. xinference/thirdparty/melo/data_utils.py +413 -0
  199. xinference/thirdparty/melo/download_utils.py +67 -0
  200. xinference/thirdparty/melo/infer.py +25 -0
  201. xinference/thirdparty/melo/init_downloads.py +14 -0
  202. xinference/thirdparty/melo/losses.py +58 -0
  203. xinference/thirdparty/melo/main.py +36 -0
  204. xinference/thirdparty/melo/mel_processing.py +174 -0
  205. xinference/thirdparty/melo/models.py +1030 -0
  206. xinference/thirdparty/melo/modules.py +598 -0
  207. xinference/thirdparty/melo/monotonic_align/__init__.py +16 -0
  208. xinference/thirdparty/melo/monotonic_align/core.py +46 -0
  209. xinference/thirdparty/melo/preprocess_text.py +135 -0
  210. xinference/thirdparty/melo/split_utils.py +174 -0
  211. xinference/thirdparty/melo/text/__init__.py +35 -0
  212. xinference/thirdparty/melo/text/chinese.py +199 -0
  213. xinference/thirdparty/melo/text/chinese_bert.py +107 -0
  214. xinference/thirdparty/melo/text/chinese_mix.py +253 -0
  215. xinference/thirdparty/melo/text/cleaner.py +36 -0
  216. xinference/thirdparty/melo/text/cleaner_multiling.py +110 -0
  217. xinference/thirdparty/melo/text/cmudict.rep +129530 -0
  218. xinference/thirdparty/melo/text/cmudict_cache.pickle +0 -0
  219. xinference/thirdparty/melo/text/english.py +284 -0
  220. xinference/thirdparty/melo/text/english_bert.py +39 -0
  221. xinference/thirdparty/melo/text/english_utils/abbreviations.py +35 -0
  222. xinference/thirdparty/melo/text/english_utils/number_norm.py +97 -0
  223. xinference/thirdparty/melo/text/english_utils/time_norm.py +47 -0
  224. xinference/thirdparty/melo/text/es_phonemizer/base.py +140 -0
  225. xinference/thirdparty/melo/text/es_phonemizer/cleaner.py +109 -0
  226. xinference/thirdparty/melo/text/es_phonemizer/es_symbols.json +79 -0
  227. xinference/thirdparty/melo/text/es_phonemizer/es_symbols.txt +1 -0
  228. xinference/thirdparty/melo/text/es_phonemizer/es_symbols_v2.json +83 -0
  229. xinference/thirdparty/melo/text/es_phonemizer/es_to_ipa.py +12 -0
  230. xinference/thirdparty/melo/text/es_phonemizer/example_ipa.txt +400 -0
  231. xinference/thirdparty/melo/text/es_phonemizer/gruut_wrapper.py +253 -0
  232. xinference/thirdparty/melo/text/es_phonemizer/punctuation.py +174 -0
  233. xinference/thirdparty/melo/text/es_phonemizer/spanish_symbols.txt +1 -0
  234. xinference/thirdparty/melo/text/es_phonemizer/test.ipynb +124 -0
  235. xinference/thirdparty/melo/text/fr_phonemizer/base.py +140 -0
  236. xinference/thirdparty/melo/text/fr_phonemizer/cleaner.py +122 -0
  237. xinference/thirdparty/melo/text/fr_phonemizer/en_symbols.json +78 -0
  238. xinference/thirdparty/melo/text/fr_phonemizer/example_ipa.txt +1 -0
  239. xinference/thirdparty/melo/text/fr_phonemizer/fr_symbols.json +89 -0
  240. xinference/thirdparty/melo/text/fr_phonemizer/fr_to_ipa.py +30 -0
  241. xinference/thirdparty/melo/text/fr_phonemizer/french_abbreviations.py +48 -0
  242. xinference/thirdparty/melo/text/fr_phonemizer/french_symbols.txt +1 -0
  243. xinference/thirdparty/melo/text/fr_phonemizer/gruut_wrapper.py +258 -0
  244. xinference/thirdparty/melo/text/fr_phonemizer/punctuation.py +172 -0
  245. xinference/thirdparty/melo/text/french.py +94 -0
  246. xinference/thirdparty/melo/text/french_bert.py +39 -0
  247. xinference/thirdparty/melo/text/japanese.py +647 -0
  248. xinference/thirdparty/melo/text/japanese_bert.py +49 -0
  249. xinference/thirdparty/melo/text/ko_dictionary.py +44 -0
  250. xinference/thirdparty/melo/text/korean.py +192 -0
  251. xinference/thirdparty/melo/text/opencpop-strict.txt +429 -0
  252. xinference/thirdparty/melo/text/spanish.py +122 -0
  253. xinference/thirdparty/melo/text/spanish_bert.py +39 -0
  254. xinference/thirdparty/melo/text/symbols.py +290 -0
  255. xinference/thirdparty/melo/text/tone_sandhi.py +769 -0
  256. xinference/thirdparty/melo/train.py +635 -0
  257. xinference/thirdparty/melo/train.sh +19 -0
  258. xinference/thirdparty/melo/transforms.py +209 -0
  259. xinference/thirdparty/melo/utils.py +424 -0
  260. xinference/types.py +17 -1
  261. xinference/web/ui/build/asset-manifest.json +6 -6
  262. xinference/web/ui/build/index.html +1 -1
  263. xinference/web/ui/build/static/css/main.51a587ff.css +2 -0
  264. xinference/web/ui/build/static/css/main.51a587ff.css.map +1 -0
  265. xinference/web/ui/build/static/js/main.b0936c54.js +3 -0
  266. xinference/web/ui/build/static/js/main.b0936c54.js.map +1 -0
  267. xinference/web/ui/node_modules/.cache/babel-loader/03c4052f1b91f6ba0c5389bdcf49c43319b4076c08e4b8585dab312538ae290a.json +1 -0
  268. xinference/web/ui/node_modules/.cache/babel-loader/1786b83003b8e9605a0f5f855a185d4d16e38fc893dfb326a2a9cca206b4240a.json +1 -0
  269. xinference/web/ui/node_modules/.cache/babel-loader/17cbc181dd674b9150b80c73ed6a82656de0082d857f6e5f66d9716129ac0b38.json +1 -0
  270. xinference/web/ui/node_modules/.cache/babel-loader/185ceb8872d562e032b47e79df6a45670e06345b8ed70aad1a131e0476783c5c.json +1 -0
  271. xinference/web/ui/node_modules/.cache/babel-loader/26b8c9f34b0bed789b3a833767672e39302d1e0c09b4276f4d58d1df7b6bd93b.json +1 -0
  272. xinference/web/ui/node_modules/.cache/babel-loader/2b484da66c724d0d56a40849c109327408796a668b1381511b6e9e03baa48658.json +1 -0
  273. xinference/web/ui/node_modules/.cache/babel-loader/2cbbbce9b84df73330d4c42b82436ed881b3847628f2fbc346aa62e2859fd88c.json +1 -0
  274. xinference/web/ui/node_modules/.cache/babel-loader/2ec9b14431ed33ce6901bf9f27007be4e6e472709c99d6e22b50ce528e4b78ee.json +1 -0
  275. xinference/web/ui/node_modules/.cache/babel-loader/3b966db018f96be4a055d6ca205f0990d4d0b370e2980c17d8bca2c9a021819c.json +1 -0
  276. xinference/web/ui/node_modules/.cache/babel-loader/3eefb411b24c2b3ce053570ef50daccf154022f0e168be5ed0fec21394baf9f4.json +1 -0
  277. xinference/web/ui/node_modules/.cache/babel-loader/522b229e3cac219123f0d69673f5570e191c2d2a505dc65b312d336eae2279c0.json +1 -0
  278. xinference/web/ui/node_modules/.cache/babel-loader/52e45f17ba300580ea3fcc9f9228ccba194bb092b76f25e9255af311f8b05aab.json +1 -0
  279. xinference/web/ui/node_modules/.cache/babel-loader/5a0bc4631f936459afc1a3b1d3ec2420118b1f00e11f60ccac3e08088f3f27a8.json +1 -0
  280. xinference/web/ui/node_modules/.cache/babel-loader/611fa2c6c53b66039991d06dfb0473b5ab37fc63b4564e0f6e1718523768a045.json +1 -0
  281. xinference/web/ui/node_modules/.cache/babel-loader/6329bc76c406fe5eb305412383fbde5950f847bb5e43261f73f37622c365acb4.json +1 -0
  282. xinference/web/ui/node_modules/.cache/babel-loader/63c8e07687ea53a4f8a910ee5e42e0eb26cd1acbfbe820f3e3248a786ee51401.json +1 -0
  283. xinference/web/ui/node_modules/.cache/babel-loader/69b2d5001684174ec9da57e07914eed3eac4960018bceb6cbfa801d861301d7c.json +1 -0
  284. xinference/web/ui/node_modules/.cache/babel-loader/710c1acda69e561e30a933b98c6a56d50197868b15c21e2aad55ab6d46649eb6.json +1 -0
  285. xinference/web/ui/node_modules/.cache/babel-loader/720deca1fce5a1dc5056048fa8258fd138a82ea855f350b6613f104a73fb761f.json +1 -0
  286. xinference/web/ui/node_modules/.cache/babel-loader/76a23b92d26a499c57e61eea2b895fbc9771bd0849a72e66f8e633192017978b.json +1 -0
  287. xinference/web/ui/node_modules/.cache/babel-loader/858063f23b34dfe600254eb5afd85518b0002ec4b30b7386616c45600826e3b2.json +1 -0
  288. xinference/web/ui/node_modules/.cache/babel-loader/920b82c1c89124cf217109eeedbfcd3aae3b917be50c9dfb6bbb4ce26bdfd2e7.json +1 -0
  289. xinference/web/ui/node_modules/.cache/babel-loader/94d8b7aeb0076f2ce07db598cea0e87b13bc8d5614eb530b8d6e696c2daf6f88.json +1 -0
  290. xinference/web/ui/node_modules/.cache/babel-loader/9e917fe7022d01b2ccbe5cc0ce73d70bb72bee584ff293bad71bdff6695dee28.json +1 -0
  291. xinference/web/ui/node_modules/.cache/babel-loader/9f28fdb8399f1d0474f0aca86f1658dc94f5bf0c90f6146352de150692de8862.json +1 -0
  292. xinference/web/ui/node_modules/.cache/babel-loader/a0dfafa06b2bb7cba8cad41c482503f61944f759f4318139362602ef5cc47ccb.json +1 -0
  293. xinference/web/ui/node_modules/.cache/babel-loader/a3ff866acddf34917a7ee399e0e571a4dfd8ba66d5057db885f243e16a6eb17d.json +1 -0
  294. xinference/web/ui/node_modules/.cache/babel-loader/afb8084f539534cd594755ea2205ecd5bd1f62dddcfdf75a2eace59a28131278.json +1 -0
  295. xinference/web/ui/node_modules/.cache/babel-loader/b57b1438b77294c1f3f6cfce12ac487d8106c6f016975ba0aec94d98997e2e1e.json +1 -0
  296. xinference/web/ui/node_modules/.cache/babel-loader/b9917b0bf8e4d55ccbac1c334aa04d6ff3c5b6ed9e5d38b9ea2c687fa7d3f5a9.json +1 -0
  297. xinference/web/ui/node_modules/.cache/babel-loader/bbcc94b0149963d1d6f267ee1f4f03d3925b758392ce2f516c3fe8af0e0169fc.json +1 -0
  298. xinference/web/ui/node_modules/.cache/babel-loader/bdee44abeadc4abc17d41c52eb49c6e19a4b1a267b6e16876ce91bdeeebfc52d.json +1 -0
  299. xinference/web/ui/node_modules/.cache/babel-loader/beb112b70f4a56db95920a9e20efb6c97c37b68450716730217a9ee1a9ae92be.json +1 -0
  300. xinference/web/ui/node_modules/.cache/babel-loader/c88db97be0cdf440193b3995996e83510a04cb00048135485fc0e26d197e80b5.json +1 -0
  301. xinference/web/ui/node_modules/.cache/babel-loader/d49e5314d34310a62d01a03067ce1bec5da00abce84c5196aa9c6842fa79a430.json +1 -0
  302. xinference/web/ui/node_modules/.cache/babel-loader/d7664d18c4ddbad9c3a6a31b91f7c00fb0dde804608674a9860ee50f33e54708.json +1 -0
  303. xinference/web/ui/node_modules/.cache/babel-loader/d9072c318b819b7c90a0f7e9cc0b6413b4dbeb8e9859898e53d75ea882fcde99.json +1 -0
  304. xinference/web/ui/node_modules/.cache/babel-loader/db16a983bc08a05f0439cc61ca0840e49e1d8400eef678909f16c032a418a3d6.json +1 -0
  305. xinference/web/ui/node_modules/.cache/babel-loader/dc249829767b8abcbc3677e0b07b6d3ecbfdfe6d08cfe23a665eb33373a9aa9d.json +1 -0
  306. xinference/web/ui/node_modules/.cache/babel-loader/e242c583c2dbc2784f0fcf513523975f7d5df447e106c1c17e49e8578a6fc3ed.json +1 -0
  307. xinference/web/ui/node_modules/.cache/babel-loader/eac5f1296513e69e4b96f750ddccd4d0264e2bae4e4c449144e83274a48698d9.json +1 -0
  308. xinference/web/ui/node_modules/.cache/babel-loader/ed57202cb79649bb716400436590245547df241988fc7c8e1d85d132299542d2.json +1 -0
  309. xinference/web/ui/node_modules/.cache/babel-loader/f125bf72e773a14cdaebd0c343e80adb909d12e317ee5c00cd4a57442fbe2c62.json +1 -0
  310. xinference/web/ui/node_modules/.cache/babel-loader/f91af913d7f91c410719ab13136aaed3aaf0f8dda06652f25c42cb5231587398.json +1 -0
  311. xinference/web/ui/node_modules/.package-lock.json +67 -3
  312. xinference/web/ui/node_modules/@babel/runtime/package.json +592 -538
  313. xinference/web/ui/node_modules/html-parse-stringify/package.json +50 -0
  314. xinference/web/ui/node_modules/i18next/dist/esm/package.json +1 -0
  315. xinference/web/ui/node_modules/i18next/package.json +129 -0
  316. xinference/web/ui/node_modules/react-i18next/.eslintrc.json +74 -0
  317. xinference/web/ui/node_modules/react-i18next/dist/es/package.json +1 -0
  318. xinference/web/ui/node_modules/react-i18next/package.json +162 -0
  319. xinference/web/ui/node_modules/void-elements/package.json +34 -0
  320. xinference/web/ui/package-lock.json +69 -3
  321. xinference/web/ui/package.json +2 -0
  322. xinference/web/ui/src/locales/en.json +186 -0
  323. xinference/web/ui/src/locales/zh.json +186 -0
  324. {xinference-0.16.3.dist-info → xinference-1.2.1.dist-info}/METADATA +96 -36
  325. {xinference-0.16.3.dist-info → xinference-1.2.1.dist-info}/RECORD +335 -146
  326. {xinference-0.16.3.dist-info → xinference-1.2.1.dist-info}/WHEEL +1 -1
  327. xinference/thirdparty/cosyvoice/bin/export_trt.py +0 -8
  328. xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
  329. xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
  330. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
  331. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
  332. xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
  333. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
  334. xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
  335. xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
  336. xinference/thirdparty/fish_speech/tools/api.py +0 -440
  337. xinference/thirdparty/fish_speech/tools/commons.py +0 -35
  338. xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
  339. xinference/thirdparty/fish_speech/tools/msgpack_api.py +0 -34
  340. xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
  341. xinference/thirdparty/fish_speech/tools/webui.py +0 -485
  342. xinference/web/ui/build/static/css/main.5061c4c3.css +0 -2
  343. xinference/web/ui/build/static/css/main.5061c4c3.css.map +0 -1
  344. xinference/web/ui/build/static/js/main.2f269bb3.js +0 -3
  345. xinference/web/ui/build/static/js/main.2f269bb3.js.map +0 -1
  346. xinference/web/ui/node_modules/.cache/babel-loader/07ce9e632e6aff24d7aa3ad8e48224433bbfeb0d633fca723453f1fcae0c9f1c.json +0 -1
  347. xinference/web/ui/node_modules/.cache/babel-loader/1130403f9e46f5738a23b45ac59b57de8f360c908c713e2c0670c2cce9bd367a.json +0 -1
  348. xinference/web/ui/node_modules/.cache/babel-loader/131091b25d26b17cdca187d7542a21475c211138d900cf667682260e76ef9463.json +0 -1
  349. xinference/web/ui/node_modules/.cache/babel-loader/1f269fb2a368363c1cb2237825f1dba093b6bdd8c44cc05954fd19ec2c1fff03.json +0 -1
  350. xinference/web/ui/node_modules/.cache/babel-loader/331312668fa8bd3d7401818f4a25fa98135d7f61371cd6bfff78b18cf4fbdd92.json +0 -1
  351. xinference/web/ui/node_modules/.cache/babel-loader/40f17338fc75ae095de7d2b4d8eae0d5ca0193a7e2bcece4ee745b22a7a2f4b7.json +0 -1
  352. xinference/web/ui/node_modules/.cache/babel-loader/4de9a6942c5f1749d6cbfdd54279699975f16016b182848bc253886f52ec2ec3.json +0 -1
  353. xinference/web/ui/node_modules/.cache/babel-loader/822586ed1077201b64b954f12f25e3f9b45678c1acbabe53d8af3ca82ca71f33.json +0 -1
  354. xinference/web/ui/node_modules/.cache/babel-loader/8d33354bd2100c8602afc3341f131a88cc36aaeecd5a4b365ed038514708e350.json +0 -1
  355. xinference/web/ui/node_modules/.cache/babel-loader/9375a35b05d56989b2755bf72161fa707c92f28569d33765a75f91a568fda6e9.json +0 -1
  356. xinference/web/ui/node_modules/.cache/babel-loader/a158a9ffa0c9b169aee53dd4a0c44501a596755b4e4f6ede7746d65a72e2a71f.json +0 -1
  357. xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +0 -1
  358. xinference/web/ui/node_modules/.cache/babel-loader/c7bf40bab396765f67d0fed627ed3665890608b2d0edaa3e8cb7cfc96310db45.json +0 -1
  359. xinference/web/ui/node_modules/.cache/babel-loader/d6c643278a0b28320e6f33a60f5fb64c053997cbdc39a60e53ccc574688ade9e.json +0 -1
  360. xinference/web/ui/node_modules/.cache/babel-loader/e42b72d4cc1ea412ebecbb8d040dc6c6bfee462c33903c2f1f3facb602ad742e.json +0 -1
  361. xinference/web/ui/node_modules/.cache/babel-loader/e64b7e8cedcf43d4c95deba60ec1341855c887705805bb62431693118b870c69.json +0 -1
  362. xinference/web/ui/node_modules/.cache/babel-loader/f5039ddbeb815c51491a1989532006b96fc3ae49c6c60e3c097f875b4ae915ae.json +0 -1
  363. xinference/web/ui/node_modules/.cache/babel-loader/f72f011744c4649fabddca6f7a9327861ac0a315a89b1a2e62a39774e7863845.json +0 -1
  364. xinference/web/ui/node_modules/.cache/babel-loader/feabb04b4aa507102da0a64398a40818e878fd1df9b75dda8461b3e1e7ff3f11.json +0 -1
  365. /xinference/thirdparty/{cosyvoice/bin → f5_tts}/__init__.py +0 -0
  366. /xinference/thirdparty/{cosyvoice/flow → melo}/__init__.py +0 -0
  367. /xinference/thirdparty/{cosyvoice/hifigan → melo/text/english_utils}/__init__.py +0 -0
  368. /xinference/thirdparty/{cosyvoice/llm → melo/text/es_phonemizer}/__init__.py +0 -0
  369. /xinference/thirdparty/{fish_speech/fish_speech/configs → melo/text/fr_phonemizer}/__init__.py +0 -0
  370. /xinference/web/ui/build/static/js/{main.2f269bb3.js.LICENSE.txt → main.b0936c54.js.LICENSE.txt} +0 -0
  371. {xinference-0.16.3.dist-info → xinference-1.2.1.dist-info}/LICENSE +0 -0
  372. {xinference-0.16.3.dist-info → xinference-1.2.1.dist-info}/entry_points.txt +0 -0
  373. {xinference-0.16.3.dist-info → xinference-1.2.1.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,4 @@
1
+ import dataclasses
1
2
  import json
2
3
  import math
3
4
  from collections import OrderedDict
@@ -15,7 +16,7 @@ from torch.nn.attention import SDPBackend, sdpa_kernel
15
16
  from torch.utils.checkpoint import checkpoint
16
17
  from transformers import AutoTokenizer
17
18
 
18
- from fish_speech.conversation import SEMANTIC_TOKEN
19
+ from fish_speech.tokenizer import SEMANTIC_TOKENS, FishTokenizer
19
20
  from fish_speech.utils import RankedLogger
20
21
 
21
22
  from .lora import LoraConfig, setup_lora
@@ -57,6 +58,11 @@ class BaseModelArgs:
57
58
  # Initialize the model
58
59
  initializer_range: float = 0.02
59
60
 
61
+ # Dummy vars
62
+ is_reward_model: bool = False
63
+ share_codebook_embeddings: bool = True
64
+ scale_codebook_embeddings: bool = False
65
+
60
66
  def __post_init__(self):
61
67
  if self.n_local_heads == -1:
62
68
  self.n_local_heads = self.n_head
@@ -100,6 +106,28 @@ class NaiveModelArgs(BaseModelArgs):
100
106
  class DualARModelArgs(BaseModelArgs):
101
107
  model_type: str = "dual_ar"
102
108
  n_fast_layer: int = 4
109
+ fast_dim: int | None = None
110
+ fast_n_head: int | None = None
111
+ fast_n_local_heads: int | None = None
112
+ fast_head_dim: int | None = None
113
+ fast_intermediate_size: int | None = None
114
+ fast_attention_qkv_bias: bool | None = None
115
+
116
+ def __post_init__(self):
117
+ super().__post_init__()
118
+
119
+ self.fast_dim = self.fast_dim or self.dim
120
+ self.fast_n_head = self.fast_n_head or self.n_head
121
+ self.fast_n_local_heads = self.fast_n_local_heads or self.n_local_heads
122
+ self.fast_head_dim = self.fast_head_dim or self.head_dim
123
+ self.fast_intermediate_size = (
124
+ self.fast_intermediate_size or self.intermediate_size
125
+ )
126
+ self.fast_attention_qkv_bias = (
127
+ self.fast_attention_qkv_bias
128
+ if self.fast_attention_qkv_bias is not None
129
+ else self.attention_qkv_bias
130
+ )
103
131
 
104
132
 
105
133
  class KVCache(nn.Module):
@@ -137,13 +165,17 @@ class BaseTransformerForwardResult:
137
165
 
138
166
  class BaseTransformer(nn.Module):
139
167
  def __init__(
140
- self, config: BaseModelArgs, tokenizer: AutoTokenizer, init_weights: bool = True
168
+ self,
169
+ config: BaseModelArgs,
170
+ tokenizer: FishTokenizer | AutoTokenizer,
171
+ init_weights: bool = True,
141
172
  ) -> None:
142
173
  super().__init__()
143
174
  self.config = config
144
175
  self.tokenizer = tokenizer
145
-
146
- self.semantic_token_id = tokenizer.convert_tokens_to_ids(SEMANTIC_TOKEN)
176
+ self.semantic_token_ids = [
177
+ tokenizer.get_token_id(SEMANTIC_TOKEN) for SEMANTIC_TOKEN in SEMANTIC_TOKENS
178
+ ]
147
179
 
148
180
  # Slow transformer
149
181
  self.embeddings = nn.Embedding(
@@ -218,8 +250,10 @@ class BaseTransformer(nn.Module):
218
250
  vocab_embeds = [self.embeddings(x[:, 0])]
219
251
  for i in range(self.config.num_codebooks):
220
252
  emb = self.codebook_embeddings(x[:, i + 1] + i * self.config.codebook_size)
221
- emb[x[:, 0] != self.semantic_token_id] = 0
222
- vocab_embeds.append(emb)
253
+ semantic_token_ids_tensor = torch.tensor(
254
+ self.semantic_token_ids, device=x.device
255
+ )
256
+ emb[~torch.isin(x[:, 0], semantic_token_ids_tensor)] = 0
223
257
 
224
258
  x = torch.stack(vocab_embeds, dim=3)
225
259
  x = x.sum(dim=3)
@@ -267,20 +301,45 @@ class BaseTransformer(nn.Module):
267
301
 
268
302
  def forward_generate(
269
303
  self,
270
- x: Tensor,
304
+ inp: Tensor,
271
305
  input_pos: Optional[Tensor] = None,
306
+ vq_masks: Optional[Tensor] = None, # this is not used in fact
272
307
  return_all: bool = False,
273
308
  ) -> BaseTransformerForwardResult:
274
309
  # This is used for generation, optimized for torch compile
275
- assert (
276
- self.max_seq_len != -1 and self.max_batch_size != -1
277
- ), "Please call setup_caches before forward_generate"
310
+ # assert (
311
+ # self.max_seq_len != -1 and self.max_batch_size != -1
312
+ # ), "Please call setup_caches before forward_generate"
313
+
314
+ embeds = []
315
+ for i in range(self.config.num_codebooks):
316
+ if self.config.share_codebook_embeddings:
317
+ _tokens = inp[:, i + 1] + i * self.config.codebook_size
318
+ else:
319
+ _tokens = inp[:, i + 1]
278
320
 
279
- x = self.embed(x)
321
+ emb = self.codebook_embeddings(_tokens)
322
+ embeds.append(emb)
280
323
 
281
- mask = self.causal_mask[
282
- None, None, input_pos, : self.max_seq_len
283
- ] # (B, N, Q, K)
324
+ vq_embeds_sum = torch.stack(embeds, dim=1).sum(dim=1)
325
+ # if self.config.use_codebook_mlp:
326
+ # vq_embeds_sum = vq_embeds_sum / self.config.num_codebooks
327
+ # vq_embeds_sum = self.codebook_mlp(vq_embeds_sum)
328
+
329
+ vq_masks = (inp[:, 0] >= self.tokenizer.semantic_begin_id) & (
330
+ inp[:, 0] <= self.tokenizer.semantic_end_id
331
+ )
332
+
333
+ vq_embeds_sum[~vq_masks] = 0
334
+ x = self.embeddings(inp[:, 0]) + vq_embeds_sum
335
+
336
+ if input_pos is None:
337
+ input_pos = torch.arange(inp.shape[-1], device=x.device)
338
+ max_seq_len = inp.shape[-1]
339
+ else:
340
+ max_seq_len = self.max_seq_len
341
+
342
+ mask = self.causal_mask[None, None, input_pos, :max_seq_len] # (B, N, Q, K)
284
343
  freqs_cis = self.freqs_cis[input_pos]
285
344
 
286
345
  for layer in self.layers:
@@ -293,7 +352,9 @@ class BaseTransformer(nn.Module):
293
352
  # We got slow_out here
294
353
  slow_out = self.norm(x)
295
354
 
296
- if self.config.tie_word_embeddings:
355
+ if self.config.is_reward_model:
356
+ token_logits = self.score_output(slow_out)
357
+ elif self.config.tie_word_embeddings:
297
358
  token_logits = F.linear(slow_out, self.embeddings.weight)
298
359
  else:
299
360
  token_logits = self.output(slow_out)
@@ -321,6 +382,7 @@ class BaseTransformer(nn.Module):
321
382
  max_length: int | None = None,
322
383
  lora_config: LoraConfig | None = None,
323
384
  rope_base: int | None = None,
385
+ is_agent: bool = False,
324
386
  ) -> "BaseTransformer":
325
387
  config = BaseModelArgs.from_pretrained(str(path))
326
388
  if max_length is not None:
@@ -339,7 +401,12 @@ class BaseTransformer(nn.Module):
339
401
  case _:
340
402
  raise ValueError(f"Unknown model type: {config.model_type}")
341
403
 
342
- tokenizer = AutoTokenizer.from_pretrained(str(path))
404
+ if is_agent:
405
+ tokenizer = AutoTokenizer.from_pretrained(str(path))
406
+ else:
407
+ tokenizer_path = str(path) + "/tokenizer.tiktoken"
408
+ tokenizer = FishTokenizer(tokenizer_path)
409
+
343
410
  log.info(f"Loading model from {path}, config: {config}")
344
411
  model = model_cls(config, tokenizer=tokenizer)
345
412
 
@@ -369,7 +436,10 @@ class BaseTransformer(nn.Module):
369
436
  model = simple_quantizer.convert_for_runtime()
370
437
 
371
438
  weights = torch.load(
372
- Path(path) / "model.pth", map_location="cpu", mmap=True
439
+ Path(path) / "model.pth",
440
+ map_location="cpu",
441
+ mmap=True,
442
+ weights_only=True,
373
443
  )
374
444
 
375
445
  if "state_dict" in weights:
@@ -422,7 +492,7 @@ class BaseTransformer(nn.Module):
422
492
 
423
493
 
424
494
  class NaiveTransformer(BaseTransformer):
425
- def __init__(self, config: NaiveModelArgs, tokenizer: AutoTokenizer) -> None:
495
+ def __init__(self, config: NaiveModelArgs, tokenizer: FishTokenizer) -> None:
426
496
  super().__init__(config, init_weights=False, tokenizer=tokenizer)
427
497
 
428
498
  self.codebook_norm = RMSNorm(config.dim, eps=config.norm_eps)
@@ -468,23 +538,49 @@ class NaiveTransformer(BaseTransformer):
468
538
 
469
539
 
470
540
  class DualARTransformer(BaseTransformer):
471
- def __init__(self, config: NaiveModelArgs, tokenizer: AutoTokenizer) -> None:
541
+ def __init__(self, config: NaiveModelArgs, tokenizer: FishTokenizer) -> None:
472
542
  super().__init__(config, init_weights=False, tokenizer=tokenizer)
473
543
 
544
+ # Project to fast dim if needed
545
+ if config.fast_dim is not None and config.fast_dim != config.dim:
546
+ self.fast_project_in = nn.Linear(config.dim, config.fast_dim)
547
+ else:
548
+ self.fast_project_in = nn.Identity()
549
+
474
550
  # Fast transformer
475
- self.fast_embeddings = nn.Embedding(config.codebook_size, config.dim)
551
+ self.fast_embeddings = nn.Embedding(config.codebook_size, config.fast_dim)
476
552
 
477
553
  # The equivalent bs is so large that sdpa doesn't work
554
+ override_config = dataclasses.replace(
555
+ config,
556
+ dim=config.fast_dim,
557
+ n_head=config.fast_n_head,
558
+ n_local_heads=config.fast_n_local_heads,
559
+ head_dim=config.fast_head_dim,
560
+ intermediate_size=config.fast_intermediate_size,
561
+ attention_qkv_bias=config.fast_attention_qkv_bias,
562
+ )
563
+
478
564
  self.fast_layers = nn.ModuleList(
479
- TransformerBlock(config, use_sdpa=False) for _ in range(config.n_fast_layer)
565
+ TransformerBlock(override_config, use_sdpa=False)
566
+ for _ in range(config.n_fast_layer)
480
567
  )
481
- self.fast_norm = RMSNorm(config.dim, eps=config.norm_eps)
568
+ self.fast_norm = RMSNorm(config.fast_dim, eps=config.norm_eps)
482
569
  self.fast_output = nn.Linear(
483
- config.dim,
570
+ config.fast_dim,
484
571
  config.codebook_size,
485
572
  bias=False,
486
573
  )
487
574
 
575
+ self.register_buffer(
576
+ "fast_freqs_cis",
577
+ precompute_freqs_cis(
578
+ config.num_codebooks,
579
+ config.fast_dim // config.fast_n_head,
580
+ config.rope_base,
581
+ ),
582
+ persistent=False,
583
+ )
488
584
  self.apply(self._init_weights)
489
585
 
490
586
  def setup_caches(
@@ -492,7 +588,7 @@ class DualARTransformer(BaseTransformer):
492
588
  ):
493
589
  super().setup_caches(max_batch_size, max_seq_len, dtype)
494
590
 
495
- head_dim = self.config.dim // self.config.n_head
591
+ head_dim = self.config.fast_dim // self.config.fast_n_head
496
592
 
497
593
  # Fast transformer
498
594
  # The max seq len here is the number of codebooks
@@ -500,7 +596,7 @@ class DualARTransformer(BaseTransformer):
500
596
  b.attention.kv_cache = KVCache(
501
597
  max_batch_size,
502
598
  self.config.num_codebooks,
503
- self.config.n_local_heads,
599
+ self.config.fast_n_local_heads,
504
600
  head_dim,
505
601
  dtype=dtype,
506
602
  )
@@ -513,13 +609,13 @@ class DualARTransformer(BaseTransformer):
513
609
  parent_result = super().forward(inp, key_padding_mask)
514
610
  token_logits = parent_result.logits
515
611
  x = parent_result.hidden_states
612
+ x = self.fast_project_in(x)
516
613
 
517
614
  # Fast transformer
518
615
  fast_seq_len = self.config.num_codebooks
519
616
  fast_mask = self.causal_mask[
520
617
  None, None, :fast_seq_len, :fast_seq_len
521
618
  ] # (B, N, Q, K)
522
- fast_freqs_cis = self.freqs_cis[:fast_seq_len]
523
619
 
524
620
  # Drop the last token and rotate left
525
621
  codebooks = inp[:, 1:-1, 1:]
@@ -542,9 +638,11 @@ class DualARTransformer(BaseTransformer):
542
638
 
543
639
  for layer in self.fast_layers:
544
640
  if self.config.use_gradient_checkpointing and self.training:
545
- x = checkpoint(layer, x, fast_freqs_cis, fast_mask, use_reentrant=True)
641
+ x = checkpoint(
642
+ layer, x, self.fast_freqs_cis, fast_mask, use_reentrant=True
643
+ )
546
644
  else:
547
- x = layer(x, fast_freqs_cis, fast_mask)
645
+ x = layer(x, self.fast_freqs_cis, fast_mask)
548
646
 
549
647
  # unflatten the batch and num_codebooks
550
648
  fast_out = self.fast_norm(x)
@@ -584,7 +682,7 @@ class DualARTransformer(BaseTransformer):
584
682
  fast_mask = self.causal_mask[
585
683
  None, None, input_pos, : self.config.num_codebooks
586
684
  ] # (B, N, Q, K)
587
- fast_freqs_cis = self.freqs_cis[input_pos]
685
+ fast_freqs_cis = self.fast_freqs_cis[input_pos]
588
686
 
589
687
  for layer in self.fast_layers:
590
688
  x = layer(x, fast_freqs_cis, fast_mask, input_pos=input_pos)
@@ -595,6 +693,16 @@ class DualARTransformer(BaseTransformer):
595
693
 
596
694
  return codebook_logits
597
695
 
696
+ def forward_generate(
697
+ self,
698
+ x: Tensor,
699
+ input_pos: Optional[Tensor] = None,
700
+ vq_masks: Optional[Tensor] = None,
701
+ ) -> TransformerForwardResult:
702
+ x = super().forward_generate(x, input_pos, vq_masks)
703
+ x.hidden_states = self.fast_project_in(x.hidden_states)
704
+ return x
705
+
598
706
 
599
707
  class TransformerBlock(nn.Module):
600
708
  def __init__(self, config: BaseModelArgs, use_sdpa: bool = True) -> None:
@@ -102,8 +102,8 @@ class FishConvNet(nn.Module):
102
102
  self.conv = weight_norm(self.conv, name=name, dim=dim)
103
103
  return self
104
104
 
105
- def remove_weight_norm(self):
106
- self.conv = remove_parametrizations(self.conv)
105
+ def remove_parametrizations(self, name="weight"):
106
+ self.conv = remove_parametrizations(self.conv, name)
107
107
  return self
108
108
 
109
109
 
@@ -128,8 +128,8 @@ class FishTransConvNet(nn.Module):
128
128
  self.conv = weight_norm(self.conv, name=name, dim=dim)
129
129
  return self
130
130
 
131
- def remove_weight_norm(self):
132
- self.conv = remove_parametrizations(self.conv)
131
+ def remove_parametrizations(self, name="weight"):
132
+ self.conv = remove_parametrizations(self.conv, name)
133
133
  return self
134
134
 
135
135
 
@@ -178,9 +178,9 @@ class ResBlock1(torch.nn.Module):
178
178
 
179
179
  def remove_parametrizations(self):
180
180
  for conv in self.convs1:
181
- remove_parametrizations(conv, tensor_name="weight")
181
+ conv.remove_parametrizations()
182
182
  for conv in self.convs2:
183
- remove_parametrizations(conv, tensor_name="weight")
183
+ conv.remove_parametrizations()
184
184
 
185
185
 
186
186
  class ParallelBlock(nn.Module):
@@ -288,11 +288,11 @@ class HiFiGANGenerator(nn.Module):
288
288
 
289
289
  def remove_parametrizations(self):
290
290
  for up in self.ups:
291
- remove_parametrizations(up, tensor_name="weight")
291
+ up.remove_parametrizations()
292
292
  for block in self.resblocks:
293
293
  block.remove_parametrizations()
294
- remove_parametrizations(self.conv_pre, tensor_name="weight")
295
- remove_parametrizations(self.conv_post, tensor_name="weight")
294
+ self.conv_pre.remove_parametrizations()
295
+ self.conv_post.remove_parametrizations()
296
296
 
297
297
 
298
298
  # DropPath copied from timm library
@@ -99,7 +99,7 @@ class DownsampleFiniteScalarQuantize(nn.Module):
99
99
  if diff > 0:
100
100
  result.z = F.pad(result.z, (left, right))
101
101
  elif diff < 0:
102
- result.z = result.z[..., left:-right]
102
+ result.z = result.z[..., -left:right]
103
103
 
104
104
  return result
105
105
 
@@ -1,19 +1,8 @@
1
1
  import re
2
2
 
3
3
  SYMBOLS_MAPPING = {
4
- "“": "'",
5
- "”": "'",
6
4
  "‘": "'",
7
5
  "’": "'",
8
- "【": "",
9
- "】": "",
10
- "[": "",
11
- "]": "",
12
- "(": "",
13
- ")": "",
14
- "(": "",
15
- ")": "",
16
- "・": "·",
17
6
  }
18
7
 
19
8
  REPLACE_SYMBOL_REGEX = re.compile(
@@ -21,6 +10,17 @@ REPLACE_SYMBOL_REGEX = re.compile(
21
10
  )
22
11
 
23
12
 
13
+ EMOJI_REGEX = re.compile(
14
+ "["
15
+ "\U0001F600-\U0001F64F" # emoticons
16
+ "\U0001F300-\U0001F5FF" # symbols & pictographs
17
+ "\U0001F680-\U0001F6FF" # transport & map symbols
18
+ "\U0001F1E0-\U0001F1FF" # flags (iOS)
19
+ "]+",
20
+ flags=re.UNICODE,
21
+ )
22
+
23
+
24
24
  def clean_text(text):
25
25
  # Clean the text
26
26
  text = text.strip()
@@ -28,4 +28,10 @@ def clean_text(text):
28
28
  # Replace all chinese symbols with their english counterparts
29
29
  text = REPLACE_SYMBOL_REGEX.sub(lambda x: SYMBOLS_MAPPING[x.group()], text)
30
30
 
31
+ # Remove emojis
32
+ text = EMOJI_REGEX.sub(r"", text)
33
+
34
+ # Remove continuous periods (...) and commas (,,,)
35
+ text = re.sub(r"[,]{2,}", lambda m: m.group()[0], text)
36
+
31
37
  return text
@@ -4,7 +4,7 @@ import string
4
4
  from fish_speech.text.clean import clean_text
5
5
 
6
6
 
7
- def utf_8_len(text):
7
+ def utf_8_len(text: str):
8
8
  return len(text.encode("utf-8"))
9
9
 
10
10
 
@@ -0,0 +1,152 @@
1
+ import base64
2
+ import json
3
+ import logging
4
+ from pathlib import Path
5
+
6
+ import tiktoken
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+ # This is a modified version of the default pattern from GPT-4o, that better handles punctuations.
11
+ FISH_TIKTOKEN_PATTERN = "|".join(
12
+ [
13
+ r"(?i:'s|'t|'re|'ve|'m|'ll|'d)",
14
+ r"\p{P}",
15
+ r"[^\r\n\p{L}\p{N}]?\p{L}+",
16
+ r"\p{N}",
17
+ r" ?[^\s\p{L}\p{N}]+[\r\n]*",
18
+ r"\s*[\r\n]+",
19
+ r"\s+(\?!\S)",
20
+ r"\s+",
21
+ ]
22
+ )
23
+ TIKTOKEN_MAX_ENCODE_CHARS = 400_000
24
+
25
+ BOS_TOKEN = "<|begin_of_text|>"
26
+ EOS_TOKEN = "<|end_of_text|>"
27
+ PAD_TOKEN = "<|pad|>"
28
+ IM_START_TOKEN = "<|im_start|>"
29
+ IM_END_TOKEN = "<|im_end|>"
30
+
31
+ MODALITY_TEXT_TOKEN = "<|text|>"
32
+ MODALITY_VOICE_TOKEN = "<|voice|>"
33
+ MODALITY_INTERLEAVE_TOKEN = "<|interleave|>"
34
+ MODALITY_TOKENS = {
35
+ "text": MODALITY_TEXT_TOKEN,
36
+ "voice": MODALITY_VOICE_TOKEN,
37
+ "interleave": MODALITY_INTERLEAVE_TOKEN,
38
+ }
39
+
40
+ PLACEHOLDER_TOKEN = [""] * 4
41
+ for i in range(4):
42
+ PLACEHOLDER_TOKEN[i] = f"<|placeholder:{i}|>"
43
+
44
+ SEMANTIC_TOKEN_TEMPLATE = "<|semantic:{i}|>"
45
+ SEMANTIC_TOKENS = [SEMANTIC_TOKEN_TEMPLATE.format(i=i) for i in range(1024)]
46
+
47
+ # Warning: when you add a new special token, you should only add it to the end of the list.
48
+ ALL_SPECIAL_TOKENS = [
49
+ BOS_TOKEN,
50
+ EOS_TOKEN,
51
+ PAD_TOKEN,
52
+ IM_START_TOKEN,
53
+ IM_END_TOKEN,
54
+ PLACEHOLDER_TOKEN[0],
55
+ PLACEHOLDER_TOKEN[1],
56
+ PLACEHOLDER_TOKEN[2],
57
+ PLACEHOLDER_TOKEN[3],
58
+ MODALITY_TEXT_TOKEN,
59
+ MODALITY_VOICE_TOKEN,
60
+ MODALITY_INTERLEAVE_TOKEN,
61
+ *SEMANTIC_TOKENS,
62
+ ]
63
+
64
+
65
+ class FishTokenizer:
66
+ def __init__(self, model_path: str) -> None:
67
+ mergeable_ranks = self.load_tiktoken_bpe(model_path)
68
+ special_token_begin = len(mergeable_ranks)
69
+ self.all_special_tokens_with_ids = {
70
+ token: special_token_begin + i for i, token in enumerate(ALL_SPECIAL_TOKENS)
71
+ }
72
+ self.semantic_id_to_token_id = {
73
+ i: self.all_special_tokens_with_ids[token]
74
+ for i, token in enumerate(SEMANTIC_TOKENS)
75
+ }
76
+ self.semantic_begin_id = self.all_special_tokens_with_ids[SEMANTIC_TOKENS[0]]
77
+ self.semantic_end_id = self.all_special_tokens_with_ids[SEMANTIC_TOKENS[-1]]
78
+
79
+ self.tkt_model = tiktoken.core.Encoding(
80
+ name=Path(model_path).stem,
81
+ pat_str=FISH_TIKTOKEN_PATTERN,
82
+ mergeable_ranks=mergeable_ranks,
83
+ special_tokens=self.all_special_tokens_with_ids,
84
+ )
85
+
86
+ @staticmethod
87
+ def load_tiktoken_bpe(tiktoken_bpe_file: str) -> dict[bytes, int]:
88
+ data = {}
89
+ for line in open(tiktoken_bpe_file).read().splitlines():
90
+ if not line:
91
+ continue
92
+ token, rank = line.split()
93
+ data[base64.b64decode(token)] = int(rank)
94
+ return data
95
+
96
+ def get_token_id(self, token: str) -> int:
97
+ return self.all_special_tokens_with_ids[token]
98
+
99
+ def encode(self, s: str, allowed_special: bool | set[str] = True) -> list[int]:
100
+ assert isinstance(s, str)
101
+
102
+ subs = []
103
+ for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS):
104
+ subs.append(s[i : i + TIKTOKEN_MAX_ENCODE_CHARS])
105
+
106
+ if allowed_special is True:
107
+ allowed_special = self.tkt_model.special_tokens_set
108
+ elif allowed_special is False:
109
+ allowed_special = set()
110
+
111
+ return sum(
112
+ self.tkt_model.encode_batch(
113
+ subs, allowed_special=allowed_special, disallowed_special=set()
114
+ ),
115
+ start=[],
116
+ )
117
+
118
+ def decode(self, tokens: list[int]) -> str:
119
+ return self.tkt_model.decode(tokens)
120
+
121
+ def save_pretrained(self, path: str):
122
+ path = Path(path)
123
+ path.mkdir(parents=True, exist_ok=True)
124
+
125
+ with open(path / "tokenizer.tiktoken", "w") as f:
126
+ for token, rank in self.tkt_model._mergeable_ranks.items():
127
+ f.write(f"{base64.b64encode(token).decode()} {rank}\n")
128
+
129
+ with open(path / "special_tokens.json", "w") as f:
130
+ json.dump(
131
+ self.all_special_tokens_with_ids,
132
+ f,
133
+ indent=2,
134
+ ensure_ascii=False,
135
+ )
136
+
137
+ @staticmethod
138
+ def from_pretrained(path: str):
139
+ return FishTokenizer(Path(path) / "tokenizer.tiktoken")
140
+
141
+
142
+ if __name__ == "__main__":
143
+ tokenizer = FishTokenizer("data/mpacks/v1.4-pretrain/tokenizer.all.tiktoken")
144
+ tokenizer.save_pretrained("checkpoints/fish-speech-0.5B")
145
+ tokenizer = FishTokenizer.from_pretrained("checkpoints/fish-speech-0.5B")
146
+
147
+ print(
148
+ [
149
+ tokenizer.decode([i])
150
+ for i in tokenizer.encode(f"{BOS_TOKEN}你好,世界!{EOS_TOKEN}")
151
+ ]
152
+ )
@@ -6,7 +6,7 @@ from typing import Optional
6
6
 
7
7
  import hydra
8
8
  import lightning as L
9
- # import pyrootutils
9
+ import pyrootutils
10
10
  import torch
11
11
  from lightning import Callback, LightningDataModule, LightningModule, Trainer
12
12
  from lightning.pytorch.loggers import Logger
@@ -18,7 +18,7 @@ os.environ.pop("SLURM_JOB_NAME", None)
18
18
  os.environ.pop("SLURM_NTASKS_PER_NODE", None)
19
19
 
20
20
  # register eval resolver and root
21
- # pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
21
+ pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
22
22
 
23
23
  # Allow TF32 on Ampere GPUs
24
24
  torch.set_float32_matmul_precision("high")
@@ -5,7 +5,7 @@ from .instantiators import instantiate_callbacks, instantiate_loggers
5
5
  from .logger import RankedLogger
6
6
  from .logging_utils import log_hyperparameters
7
7
  from .rich_utils import enforce_tags, print_config_tree
8
- from .utils import extras, get_metric_value, task_wrapper
8
+ from .utils import extras, get_metric_value, set_seed, task_wrapper
9
9
 
10
10
  __all__ = [
11
11
  "enforce_tags",
@@ -20,4 +20,5 @@ __all__ = [
20
20
  "braceexpand",
21
21
  "get_latest_checkpoint",
22
22
  "autocast_exclude_mps",
23
+ "set_seed",
23
24
  ]
@@ -1,7 +1,10 @@
1
+ import random
1
2
  import warnings
2
3
  from importlib.util import find_spec
3
4
  from typing import Callable
4
5
 
6
+ import numpy as np
7
+ import torch
5
8
  from omegaconf import DictConfig
6
9
 
7
10
  from .logger import RankedLogger
@@ -112,3 +115,22 @@ def get_metric_value(metric_dict: dict, metric_name: str) -> float:
112
115
  log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")
113
116
 
114
117
  return metric_value
118
+
119
+
120
+ def set_seed(seed: int):
121
+ if seed < 0:
122
+ seed = -seed
123
+ if seed > (1 << 31):
124
+ seed = 1 << 31
125
+
126
+ random.seed(seed)
127
+ np.random.seed(seed)
128
+ torch.manual_seed(seed)
129
+
130
+ if torch.cuda.is_available():
131
+ torch.cuda.manual_seed(seed)
132
+ torch.cuda.manual_seed_all(seed)
133
+
134
+ if torch.backends.cudnn.is_available():
135
+ torch.backends.cudnn.deterministic = True
136
+ torch.backends.cudnn.benchmark = False
@@ -114,7 +114,7 @@ class Seafoam(Base):
114
114
  block_title_text_weight="600",
115
115
  block_border_width="3px",
116
116
  block_shadow="*shadow_drop_lg",
117
- button_shadow="*shadow_drop_lg",
117
+ # button_shadow="*shadow_drop_lg",
118
118
  button_small_padding="0px",
119
119
  button_large_padding="3px",
120
120
  )
@@ -176,7 +176,7 @@ def change_infer(
176
176
  p_infer = subprocess.Popen(
177
177
  [
178
178
  PYTHON,
179
- "tools/webui.py",
179
+ "tools/run_webui.py",
180
180
  "--decoder-checkpoint-path",
181
181
  infer_decoder_model,
182
182
  "--decoder-config-name",
@@ -794,7 +794,7 @@ with gr.Blocks(
794
794
  value="VQGAN",
795
795
  )
796
796
  with gr.Row():
797
- with gr.Tabs():
797
+ with gr.Column():
798
798
  with gr.Tab(label=i18n("VQGAN Configuration")) as vqgan_page:
799
799
  gr.HTML("You don't need to train this model!")
800
800