xinference 1.0.1__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff compares the contents of publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.

Potentially problematic release: this version of xinference might be problematic.
Files changed (343)
  1. xinference/_compat.py +2 -0
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +77 -71
  4. xinference/core/chat_interface.py +6 -1
  5. xinference/core/model.py +79 -19
  6. xinference/core/supervisor.py +172 -10
  7. xinference/core/utils.py +12 -8
  8. xinference/core/worker.py +102 -4
  9. xinference/deploy/cmdline.py +3 -1
  10. xinference/deploy/test/test_cmdline.py +56 -0
  11. xinference/isolation.py +24 -0
  12. xinference/model/audio/core.py +16 -0
  13. xinference/model/audio/cosyvoice.py +39 -6
  14. xinference/model/audio/f5tts.py +200 -0
  15. xinference/model/audio/f5tts_mlx.py +260 -0
  16. xinference/model/audio/fish_speech.py +36 -111
  17. xinference/model/audio/melotts.py +110 -0
  18. xinference/model/audio/model_spec.json +99 -3
  19. xinference/model/audio/model_spec_modelscope.json +27 -0
  20. xinference/model/audio/utils.py +32 -0
  21. xinference/model/audio/whisper.py +35 -10
  22. xinference/model/embedding/core.py +203 -142
  23. xinference/model/embedding/model_spec.json +7 -0
  24. xinference/model/embedding/model_spec_modelscope.json +8 -0
  25. xinference/model/image/core.py +69 -1
  26. xinference/model/image/model_spec.json +145 -4
  27. xinference/model/image/model_spec_modelscope.json +150 -4
  28. xinference/model/image/stable_diffusion/core.py +45 -13
  29. xinference/model/llm/__init__.py +4 -2
  30. xinference/model/llm/llm_family.json +536 -53
  31. xinference/model/llm/llm_family.py +15 -36
  32. xinference/model/llm/llm_family_modelscope.json +454 -20
  33. xinference/model/llm/memory.py +1 -1
  34. xinference/model/llm/mlx/core.py +248 -52
  35. xinference/model/llm/sglang/core.py +1 -0
  36. xinference/model/llm/transformers/chatglm.py +9 -5
  37. xinference/model/llm/transformers/cogagent.py +272 -0
  38. xinference/model/llm/transformers/core.py +2 -0
  39. xinference/model/llm/transformers/qwen2_vl.py +12 -1
  40. xinference/model/llm/transformers/utils.py +16 -8
  41. xinference/model/llm/utils.py +36 -4
  42. xinference/model/llm/vllm/core.py +53 -10
  43. xinference/model/llm/vllm/xavier/__init__.py +13 -0
  44. xinference/model/llm/vllm/xavier/allocator.py +74 -0
  45. xinference/model/llm/vllm/xavier/block.py +111 -0
  46. xinference/model/llm/vllm/xavier/block_manager.py +71 -0
  47. xinference/model/llm/vllm/xavier/block_tracker.py +129 -0
  48. xinference/model/llm/vllm/xavier/collective.py +74 -0
  49. xinference/model/llm/vllm/xavier/collective_manager.py +147 -0
  50. xinference/model/llm/vllm/xavier/engine.py +247 -0
  51. xinference/model/llm/vllm/xavier/executor.py +134 -0
  52. xinference/model/llm/vllm/xavier/scheduler.py +438 -0
  53. xinference/model/llm/vllm/xavier/test/__init__.py +13 -0
  54. xinference/model/llm/vllm/xavier/test/test_xavier.py +147 -0
  55. xinference/model/llm/vllm/xavier/transfer.py +319 -0
  56. xinference/model/video/diffusers.py +14 -0
  57. xinference/model/video/model_spec.json +15 -0
  58. xinference/model/video/model_spec_modelscope.json +16 -0
  59. xinference/thirdparty/cosyvoice/bin/average_model.py +92 -0
  60. xinference/thirdparty/cosyvoice/bin/export_jit.py +12 -2
  61. xinference/thirdparty/cosyvoice/bin/export_onnx.py +112 -0
  62. xinference/thirdparty/cosyvoice/bin/export_trt.sh +9 -0
  63. xinference/thirdparty/cosyvoice/bin/inference.py +5 -7
  64. xinference/thirdparty/cosyvoice/bin/spk2info.pt +0 -0
  65. xinference/thirdparty/cosyvoice/bin/train.py +42 -8
  66. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +96 -25
  67. xinference/thirdparty/cosyvoice/cli/frontend.py +77 -30
  68. xinference/thirdparty/cosyvoice/cli/model.py +330 -80
  69. xinference/thirdparty/cosyvoice/dataset/dataset.py +6 -2
  70. xinference/thirdparty/cosyvoice/dataset/processor.py +76 -14
  71. xinference/thirdparty/cosyvoice/flow/decoder.py +92 -13
  72. xinference/thirdparty/cosyvoice/flow/flow.py +99 -9
  73. xinference/thirdparty/cosyvoice/flow/flow_matching.py +110 -13
  74. xinference/thirdparty/cosyvoice/flow/length_regulator.py +5 -4
  75. xinference/thirdparty/cosyvoice/hifigan/discriminator.py +140 -0
  76. xinference/thirdparty/cosyvoice/hifigan/generator.py +58 -42
  77. xinference/thirdparty/cosyvoice/hifigan/hifigan.py +67 -0
  78. xinference/thirdparty/cosyvoice/llm/llm.py +139 -6
  79. xinference/thirdparty/cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
  80. xinference/thirdparty/cosyvoice/tokenizer/tokenizer.py +279 -0
  81. xinference/thirdparty/cosyvoice/transformer/embedding.py +2 -2
  82. xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +7 -7
  83. xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +318 -0
  84. xinference/thirdparty/cosyvoice/utils/common.py +28 -1
  85. xinference/thirdparty/cosyvoice/utils/executor.py +69 -7
  86. xinference/thirdparty/cosyvoice/utils/file_utils.py +2 -12
  87. xinference/thirdparty/cosyvoice/utils/frontend_utils.py +9 -5
  88. xinference/thirdparty/cosyvoice/utils/losses.py +20 -0
  89. xinference/thirdparty/cosyvoice/utils/scheduler.py +1 -2
  90. xinference/thirdparty/cosyvoice/utils/train_utils.py +101 -45
  91. xinference/thirdparty/f5_tts/api.py +166 -0
  92. xinference/thirdparty/f5_tts/configs/E2TTS_Base_train.yaml +44 -0
  93. xinference/thirdparty/f5_tts/configs/E2TTS_Small_train.yaml +44 -0
  94. xinference/thirdparty/f5_tts/configs/F5TTS_Base_train.yaml +46 -0
  95. xinference/thirdparty/f5_tts/configs/F5TTS_Small_train.yaml +46 -0
  96. xinference/thirdparty/f5_tts/eval/README.md +49 -0
  97. xinference/thirdparty/f5_tts/eval/ecapa_tdnn.py +330 -0
  98. xinference/thirdparty/f5_tts/eval/eval_infer_batch.py +207 -0
  99. xinference/thirdparty/f5_tts/eval/eval_infer_batch.sh +13 -0
  100. xinference/thirdparty/f5_tts/eval/eval_librispeech_test_clean.py +84 -0
  101. xinference/thirdparty/f5_tts/eval/eval_seedtts_testset.py +84 -0
  102. xinference/thirdparty/f5_tts/eval/utils_eval.py +405 -0
  103. xinference/thirdparty/f5_tts/infer/README.md +191 -0
  104. xinference/thirdparty/f5_tts/infer/SHARED.md +74 -0
  105. xinference/thirdparty/f5_tts/infer/examples/basic/basic.toml +11 -0
  106. xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_en.wav +0 -0
  107. xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_zh.wav +0 -0
  108. xinference/thirdparty/f5_tts/infer/examples/multi/country.flac +0 -0
  109. xinference/thirdparty/f5_tts/infer/examples/multi/main.flac +0 -0
  110. xinference/thirdparty/f5_tts/infer/examples/multi/story.toml +19 -0
  111. xinference/thirdparty/f5_tts/infer/examples/multi/story.txt +1 -0
  112. xinference/thirdparty/f5_tts/infer/examples/multi/town.flac +0 -0
  113. xinference/thirdparty/f5_tts/infer/examples/vocab.txt +2545 -0
  114. xinference/thirdparty/f5_tts/infer/infer_cli.py +226 -0
  115. xinference/thirdparty/f5_tts/infer/infer_gradio.py +851 -0
  116. xinference/thirdparty/f5_tts/infer/speech_edit.py +193 -0
  117. xinference/thirdparty/f5_tts/infer/utils_infer.py +538 -0
  118. xinference/thirdparty/f5_tts/model/__init__.py +10 -0
  119. xinference/thirdparty/f5_tts/model/backbones/README.md +20 -0
  120. xinference/thirdparty/f5_tts/model/backbones/dit.py +163 -0
  121. xinference/thirdparty/f5_tts/model/backbones/mmdit.py +146 -0
  122. xinference/thirdparty/f5_tts/model/backbones/unett.py +219 -0
  123. xinference/thirdparty/f5_tts/model/cfm.py +285 -0
  124. xinference/thirdparty/f5_tts/model/dataset.py +319 -0
  125. xinference/thirdparty/f5_tts/model/modules.py +658 -0
  126. xinference/thirdparty/f5_tts/model/trainer.py +366 -0
  127. xinference/thirdparty/f5_tts/model/utils.py +185 -0
  128. xinference/thirdparty/f5_tts/scripts/count_max_epoch.py +33 -0
  129. xinference/thirdparty/f5_tts/scripts/count_params_gflops.py +39 -0
  130. xinference/thirdparty/f5_tts/socket_server.py +159 -0
  131. xinference/thirdparty/f5_tts/train/README.md +77 -0
  132. xinference/thirdparty/f5_tts/train/datasets/prepare_csv_wavs.py +139 -0
  133. xinference/thirdparty/f5_tts/train/datasets/prepare_emilia.py +230 -0
  134. xinference/thirdparty/f5_tts/train/datasets/prepare_libritts.py +92 -0
  135. xinference/thirdparty/f5_tts/train/datasets/prepare_ljspeech.py +65 -0
  136. xinference/thirdparty/f5_tts/train/datasets/prepare_wenetspeech4tts.py +125 -0
  137. xinference/thirdparty/f5_tts/train/finetune_cli.py +174 -0
  138. xinference/thirdparty/f5_tts/train/finetune_gradio.py +1846 -0
  139. xinference/thirdparty/f5_tts/train/train.py +75 -0
  140. xinference/thirdparty/fish_speech/fish_speech/conversation.py +94 -83
  141. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +63 -20
  142. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +1 -26
  143. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +1 -1
  144. xinference/thirdparty/fish_speech/fish_speech/tokenizer.py +152 -0
  145. xinference/thirdparty/fish_speech/fish_speech/train.py +2 -2
  146. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1 -1
  147. xinference/thirdparty/fish_speech/tools/{post_api.py → api_client.py} +7 -13
  148. xinference/thirdparty/fish_speech/tools/api_server.py +98 -0
  149. xinference/thirdparty/fish_speech/tools/download_models.py +5 -5
  150. xinference/thirdparty/fish_speech/tools/fish_e2e.py +2 -2
  151. xinference/thirdparty/fish_speech/tools/inference_engine/__init__.py +192 -0
  152. xinference/thirdparty/fish_speech/tools/inference_engine/reference_loader.py +125 -0
  153. xinference/thirdparty/fish_speech/tools/inference_engine/utils.py +39 -0
  154. xinference/thirdparty/fish_speech/tools/inference_engine/vq_manager.py +57 -0
  155. xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +2 -2
  156. xinference/thirdparty/fish_speech/tools/llama/generate.py +117 -89
  157. xinference/thirdparty/fish_speech/tools/run_webui.py +104 -0
  158. xinference/thirdparty/fish_speech/tools/schema.py +11 -28
  159. xinference/thirdparty/fish_speech/tools/server/agent/__init__.py +57 -0
  160. xinference/thirdparty/fish_speech/tools/server/agent/generate.py +119 -0
  161. xinference/thirdparty/fish_speech/tools/server/agent/generation_utils.py +122 -0
  162. xinference/thirdparty/fish_speech/tools/server/agent/pre_generation_utils.py +72 -0
  163. xinference/thirdparty/fish_speech/tools/server/api_utils.py +75 -0
  164. xinference/thirdparty/fish_speech/tools/server/exception_handler.py +27 -0
  165. xinference/thirdparty/fish_speech/tools/server/inference.py +45 -0
  166. xinference/thirdparty/fish_speech/tools/server/model_manager.py +122 -0
  167. xinference/thirdparty/fish_speech/tools/server/model_utils.py +129 -0
  168. xinference/thirdparty/fish_speech/tools/server/views.py +246 -0
  169. xinference/thirdparty/fish_speech/tools/webui/__init__.py +173 -0
  170. xinference/thirdparty/fish_speech/tools/webui/inference.py +91 -0
  171. xinference/thirdparty/fish_speech/tools/webui/variables.py +14 -0
  172. xinference/thirdparty/matcha/utils/utils.py +2 -2
  173. xinference/thirdparty/melo/api.py +135 -0
  174. xinference/thirdparty/melo/app.py +61 -0
  175. xinference/thirdparty/melo/attentions.py +459 -0
  176. xinference/thirdparty/melo/commons.py +160 -0
  177. xinference/thirdparty/melo/configs/config.json +94 -0
  178. xinference/thirdparty/melo/data/example/metadata.list +20 -0
  179. xinference/thirdparty/melo/data_utils.py +413 -0
  180. xinference/thirdparty/melo/download_utils.py +67 -0
  181. xinference/thirdparty/melo/infer.py +25 -0
  182. xinference/thirdparty/melo/init_downloads.py +14 -0
  183. xinference/thirdparty/melo/losses.py +58 -0
  184. xinference/thirdparty/melo/main.py +36 -0
  185. xinference/thirdparty/melo/mel_processing.py +174 -0
  186. xinference/thirdparty/melo/models.py +1030 -0
  187. xinference/thirdparty/melo/modules.py +598 -0
  188. xinference/thirdparty/melo/monotonic_align/__init__.py +16 -0
  189. xinference/thirdparty/melo/monotonic_align/core.py +46 -0
  190. xinference/thirdparty/melo/preprocess_text.py +135 -0
  191. xinference/thirdparty/melo/split_utils.py +174 -0
  192. xinference/thirdparty/melo/text/__init__.py +35 -0
  193. xinference/thirdparty/melo/text/chinese.py +199 -0
  194. xinference/thirdparty/melo/text/chinese_bert.py +107 -0
  195. xinference/thirdparty/melo/text/chinese_mix.py +253 -0
  196. xinference/thirdparty/melo/text/cleaner.py +36 -0
  197. xinference/thirdparty/melo/text/cleaner_multiling.py +110 -0
  198. xinference/thirdparty/melo/text/cmudict.rep +129530 -0
  199. xinference/thirdparty/melo/text/cmudict_cache.pickle +0 -0
  200. xinference/thirdparty/melo/text/english.py +284 -0
  201. xinference/thirdparty/melo/text/english_bert.py +39 -0
  202. xinference/thirdparty/melo/text/english_utils/abbreviations.py +35 -0
  203. xinference/thirdparty/melo/text/english_utils/number_norm.py +97 -0
  204. xinference/thirdparty/melo/text/english_utils/time_norm.py +47 -0
  205. xinference/thirdparty/melo/text/es_phonemizer/base.py +140 -0
  206. xinference/thirdparty/melo/text/es_phonemizer/cleaner.py +109 -0
  207. xinference/thirdparty/melo/text/es_phonemizer/es_symbols.json +79 -0
  208. xinference/thirdparty/melo/text/es_phonemizer/es_symbols.txt +1 -0
  209. xinference/thirdparty/melo/text/es_phonemizer/es_symbols_v2.json +83 -0
  210. xinference/thirdparty/melo/text/es_phonemizer/es_to_ipa.py +12 -0
  211. xinference/thirdparty/melo/text/es_phonemizer/example_ipa.txt +400 -0
  212. xinference/thirdparty/melo/text/es_phonemizer/gruut_wrapper.py +253 -0
  213. xinference/thirdparty/melo/text/es_phonemizer/punctuation.py +174 -0
  214. xinference/thirdparty/melo/text/es_phonemizer/spanish_symbols.txt +1 -0
  215. xinference/thirdparty/melo/text/es_phonemizer/test.ipynb +124 -0
  216. xinference/thirdparty/melo/text/fr_phonemizer/base.py +140 -0
  217. xinference/thirdparty/melo/text/fr_phonemizer/cleaner.py +122 -0
  218. xinference/thirdparty/melo/text/fr_phonemizer/en_symbols.json +78 -0
  219. xinference/thirdparty/melo/text/fr_phonemizer/example_ipa.txt +1 -0
  220. xinference/thirdparty/melo/text/fr_phonemizer/fr_symbols.json +89 -0
  221. xinference/thirdparty/melo/text/fr_phonemizer/fr_to_ipa.py +30 -0
  222. xinference/thirdparty/melo/text/fr_phonemizer/french_abbreviations.py +48 -0
  223. xinference/thirdparty/melo/text/fr_phonemizer/french_symbols.txt +1 -0
  224. xinference/thirdparty/melo/text/fr_phonemizer/gruut_wrapper.py +258 -0
  225. xinference/thirdparty/melo/text/fr_phonemizer/punctuation.py +172 -0
  226. xinference/thirdparty/melo/text/french.py +94 -0
  227. xinference/thirdparty/melo/text/french_bert.py +39 -0
  228. xinference/thirdparty/melo/text/japanese.py +647 -0
  229. xinference/thirdparty/melo/text/japanese_bert.py +49 -0
  230. xinference/thirdparty/melo/text/ko_dictionary.py +44 -0
  231. xinference/thirdparty/melo/text/korean.py +192 -0
  232. xinference/thirdparty/melo/text/opencpop-strict.txt +429 -0
  233. xinference/thirdparty/melo/text/spanish.py +122 -0
  234. xinference/thirdparty/melo/text/spanish_bert.py +39 -0
  235. xinference/thirdparty/melo/text/symbols.py +290 -0
  236. xinference/thirdparty/melo/text/tone_sandhi.py +769 -0
  237. xinference/thirdparty/melo/train.py +635 -0
  238. xinference/thirdparty/melo/train.sh +19 -0
  239. xinference/thirdparty/melo/transforms.py +209 -0
  240. xinference/thirdparty/melo/utils.py +424 -0
  241. xinference/types.py +15 -0
  242. xinference/web/ui/build/asset-manifest.json +6 -6
  243. xinference/web/ui/build/index.html +1 -1
  244. xinference/web/ui/build/static/css/main.51a587ff.css +2 -0
  245. xinference/web/ui/build/static/css/main.51a587ff.css.map +1 -0
  246. xinference/web/ui/build/static/js/main.b0936c54.js +3 -0
  247. xinference/web/ui/build/static/js/main.b0936c54.js.map +1 -0
  248. xinference/web/ui/node_modules/.cache/babel-loader/03c4052f1b91f6ba0c5389bdcf49c43319b4076c08e4b8585dab312538ae290a.json +1 -0
  249. xinference/web/ui/node_modules/.cache/babel-loader/1786b83003b8e9605a0f5f855a185d4d16e38fc893dfb326a2a9cca206b4240a.json +1 -0
  250. xinference/web/ui/node_modules/.cache/babel-loader/17cbc181dd674b9150b80c73ed6a82656de0082d857f6e5f66d9716129ac0b38.json +1 -0
  251. xinference/web/ui/node_modules/.cache/babel-loader/185ceb8872d562e032b47e79df6a45670e06345b8ed70aad1a131e0476783c5c.json +1 -0
  252. xinference/web/ui/node_modules/.cache/babel-loader/26b8c9f34b0bed789b3a833767672e39302d1e0c09b4276f4d58d1df7b6bd93b.json +1 -0
  253. xinference/web/ui/node_modules/.cache/babel-loader/2b484da66c724d0d56a40849c109327408796a668b1381511b6e9e03baa48658.json +1 -0
  254. xinference/web/ui/node_modules/.cache/babel-loader/2cbbbce9b84df73330d4c42b82436ed881b3847628f2fbc346aa62e2859fd88c.json +1 -0
  255. xinference/web/ui/node_modules/.cache/babel-loader/2ec9b14431ed33ce6901bf9f27007be4e6e472709c99d6e22b50ce528e4b78ee.json +1 -0
  256. xinference/web/ui/node_modules/.cache/babel-loader/3b966db018f96be4a055d6ca205f0990d4d0b370e2980c17d8bca2c9a021819c.json +1 -0
  257. xinference/web/ui/node_modules/.cache/babel-loader/3eefb411b24c2b3ce053570ef50daccf154022f0e168be5ed0fec21394baf9f4.json +1 -0
  258. xinference/web/ui/node_modules/.cache/babel-loader/522b229e3cac219123f0d69673f5570e191c2d2a505dc65b312d336eae2279c0.json +1 -0
  259. xinference/web/ui/node_modules/.cache/babel-loader/52e45f17ba300580ea3fcc9f9228ccba194bb092b76f25e9255af311f8b05aab.json +1 -0
  260. xinference/web/ui/node_modules/.cache/babel-loader/5a0bc4631f936459afc1a3b1d3ec2420118b1f00e11f60ccac3e08088f3f27a8.json +1 -0
  261. xinference/web/ui/node_modules/.cache/babel-loader/611fa2c6c53b66039991d06dfb0473b5ab37fc63b4564e0f6e1718523768a045.json +1 -0
  262. xinference/web/ui/node_modules/.cache/babel-loader/6329bc76c406fe5eb305412383fbde5950f847bb5e43261f73f37622c365acb4.json +1 -0
  263. xinference/web/ui/node_modules/.cache/babel-loader/63c8e07687ea53a4f8a910ee5e42e0eb26cd1acbfbe820f3e3248a786ee51401.json +1 -0
  264. xinference/web/ui/node_modules/.cache/babel-loader/69b2d5001684174ec9da57e07914eed3eac4960018bceb6cbfa801d861301d7c.json +1 -0
  265. xinference/web/ui/node_modules/.cache/babel-loader/710c1acda69e561e30a933b98c6a56d50197868b15c21e2aad55ab6d46649eb6.json +1 -0
  266. xinference/web/ui/node_modules/.cache/babel-loader/720deca1fce5a1dc5056048fa8258fd138a82ea855f350b6613f104a73fb761f.json +1 -0
  267. xinference/web/ui/node_modules/.cache/babel-loader/76a23b92d26a499c57e61eea2b895fbc9771bd0849a72e66f8e633192017978b.json +1 -0
  268. xinference/web/ui/node_modules/.cache/babel-loader/858063f23b34dfe600254eb5afd85518b0002ec4b30b7386616c45600826e3b2.json +1 -0
  269. xinference/web/ui/node_modules/.cache/babel-loader/920b82c1c89124cf217109eeedbfcd3aae3b917be50c9dfb6bbb4ce26bdfd2e7.json +1 -0
  270. xinference/web/ui/node_modules/.cache/babel-loader/94d8b7aeb0076f2ce07db598cea0e87b13bc8d5614eb530b8d6e696c2daf6f88.json +1 -0
  271. xinference/web/ui/node_modules/.cache/babel-loader/9e917fe7022d01b2ccbe5cc0ce73d70bb72bee584ff293bad71bdff6695dee28.json +1 -0
  272. xinference/web/ui/node_modules/.cache/babel-loader/9f28fdb8399f1d0474f0aca86f1658dc94f5bf0c90f6146352de150692de8862.json +1 -0
  273. xinference/web/ui/node_modules/.cache/babel-loader/a0dfafa06b2bb7cba8cad41c482503f61944f759f4318139362602ef5cc47ccb.json +1 -0
  274. xinference/web/ui/node_modules/.cache/babel-loader/a3ff866acddf34917a7ee399e0e571a4dfd8ba66d5057db885f243e16a6eb17d.json +1 -0
  275. xinference/web/ui/node_modules/.cache/babel-loader/afb8084f539534cd594755ea2205ecd5bd1f62dddcfdf75a2eace59a28131278.json +1 -0
  276. xinference/web/ui/node_modules/.cache/babel-loader/b57b1438b77294c1f3f6cfce12ac487d8106c6f016975ba0aec94d98997e2e1e.json +1 -0
  277. xinference/web/ui/node_modules/.cache/babel-loader/b9917b0bf8e4d55ccbac1c334aa04d6ff3c5b6ed9e5d38b9ea2c687fa7d3f5a9.json +1 -0
  278. xinference/web/ui/node_modules/.cache/babel-loader/bbcc94b0149963d1d6f267ee1f4f03d3925b758392ce2f516c3fe8af0e0169fc.json +1 -0
  279. xinference/web/ui/node_modules/.cache/babel-loader/bdee44abeadc4abc17d41c52eb49c6e19a4b1a267b6e16876ce91bdeeebfc52d.json +1 -0
  280. xinference/web/ui/node_modules/.cache/babel-loader/beb112b70f4a56db95920a9e20efb6c97c37b68450716730217a9ee1a9ae92be.json +1 -0
  281. xinference/web/ui/node_modules/.cache/babel-loader/c88db97be0cdf440193b3995996e83510a04cb00048135485fc0e26d197e80b5.json +1 -0
  282. xinference/web/ui/node_modules/.cache/babel-loader/d49e5314d34310a62d01a03067ce1bec5da00abce84c5196aa9c6842fa79a430.json +1 -0
  283. xinference/web/ui/node_modules/.cache/babel-loader/d7664d18c4ddbad9c3a6a31b91f7c00fb0dde804608674a9860ee50f33e54708.json +1 -0
  284. xinference/web/ui/node_modules/.cache/babel-loader/d9072c318b819b7c90a0f7e9cc0b6413b4dbeb8e9859898e53d75ea882fcde99.json +1 -0
  285. xinference/web/ui/node_modules/.cache/babel-loader/db16a983bc08a05f0439cc61ca0840e49e1d8400eef678909f16c032a418a3d6.json +1 -0
  286. xinference/web/ui/node_modules/.cache/babel-loader/dc249829767b8abcbc3677e0b07b6d3ecbfdfe6d08cfe23a665eb33373a9aa9d.json +1 -0
  287. xinference/web/ui/node_modules/.cache/babel-loader/e242c583c2dbc2784f0fcf513523975f7d5df447e106c1c17e49e8578a6fc3ed.json +1 -0
  288. xinference/web/ui/node_modules/.cache/babel-loader/eac5f1296513e69e4b96f750ddccd4d0264e2bae4e4c449144e83274a48698d9.json +1 -0
  289. xinference/web/ui/node_modules/.cache/babel-loader/ed57202cb79649bb716400436590245547df241988fc7c8e1d85d132299542d2.json +1 -0
  290. xinference/web/ui/node_modules/.cache/babel-loader/f125bf72e773a14cdaebd0c343e80adb909d12e317ee5c00cd4a57442fbe2c62.json +1 -0
  291. xinference/web/ui/node_modules/.cache/babel-loader/f91af913d7f91c410719ab13136aaed3aaf0f8dda06652f25c42cb5231587398.json +1 -0
  292. xinference/web/ui/node_modules/.package-lock.json +67 -3
  293. xinference/web/ui/node_modules/@babel/runtime/package.json +592 -538
  294. xinference/web/ui/node_modules/html-parse-stringify/package.json +50 -0
  295. xinference/web/ui/node_modules/i18next/dist/esm/package.json +1 -0
  296. xinference/web/ui/node_modules/i18next/package.json +129 -0
  297. xinference/web/ui/node_modules/react-i18next/.eslintrc.json +74 -0
  298. xinference/web/ui/node_modules/react-i18next/dist/es/package.json +1 -0
  299. xinference/web/ui/node_modules/react-i18next/package.json +162 -0
  300. xinference/web/ui/node_modules/void-elements/package.json +34 -0
  301. xinference/web/ui/package-lock.json +69 -3
  302. xinference/web/ui/package.json +2 -0
  303. xinference/web/ui/src/locales/en.json +186 -0
  304. xinference/web/ui/src/locales/zh.json +186 -0
  305. {xinference-1.0.1.dist-info → xinference-1.2.1.dist-info}/METADATA +68 -32
  306. {xinference-1.0.1.dist-info → xinference-1.2.1.dist-info}/RECORD +316 -122
  307. xinference/thirdparty/cosyvoice/bin/export_trt.py +0 -8
  308. xinference/thirdparty/fish_speech/tools/api.py +0 -943
  309. xinference/thirdparty/fish_speech/tools/msgpack_api.py +0 -95
  310. xinference/thirdparty/fish_speech/tools/webui.py +0 -548
  311. xinference/web/ui/build/static/css/main.5061c4c3.css +0 -2
  312. xinference/web/ui/build/static/css/main.5061c4c3.css.map +0 -1
  313. xinference/web/ui/build/static/js/main.2f269bb3.js +0 -3
  314. xinference/web/ui/build/static/js/main.2f269bb3.js.map +0 -1
  315. xinference/web/ui/node_modules/.cache/babel-loader/07ce9e632e6aff24d7aa3ad8e48224433bbfeb0d633fca723453f1fcae0c9f1c.json +0 -1
  316. xinference/web/ui/node_modules/.cache/babel-loader/1130403f9e46f5738a23b45ac59b57de8f360c908c713e2c0670c2cce9bd367a.json +0 -1
  317. xinference/web/ui/node_modules/.cache/babel-loader/131091b25d26b17cdca187d7542a21475c211138d900cf667682260e76ef9463.json +0 -1
  318. xinference/web/ui/node_modules/.cache/babel-loader/1f269fb2a368363c1cb2237825f1dba093b6bdd8c44cc05954fd19ec2c1fff03.json +0 -1
  319. xinference/web/ui/node_modules/.cache/babel-loader/331312668fa8bd3d7401818f4a25fa98135d7f61371cd6bfff78b18cf4fbdd92.json +0 -1
  320. xinference/web/ui/node_modules/.cache/babel-loader/40f17338fc75ae095de7d2b4d8eae0d5ca0193a7e2bcece4ee745b22a7a2f4b7.json +0 -1
  321. xinference/web/ui/node_modules/.cache/babel-loader/4de9a6942c5f1749d6cbfdd54279699975f16016b182848bc253886f52ec2ec3.json +0 -1
  322. xinference/web/ui/node_modules/.cache/babel-loader/822586ed1077201b64b954f12f25e3f9b45678c1acbabe53d8af3ca82ca71f33.json +0 -1
  323. xinference/web/ui/node_modules/.cache/babel-loader/8d33354bd2100c8602afc3341f131a88cc36aaeecd5a4b365ed038514708e350.json +0 -1
  324. xinference/web/ui/node_modules/.cache/babel-loader/9375a35b05d56989b2755bf72161fa707c92f28569d33765a75f91a568fda6e9.json +0 -1
  325. xinference/web/ui/node_modules/.cache/babel-loader/a158a9ffa0c9b169aee53dd4a0c44501a596755b4e4f6ede7746d65a72e2a71f.json +0 -1
  326. xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +0 -1
  327. xinference/web/ui/node_modules/.cache/babel-loader/c7bf40bab396765f67d0fed627ed3665890608b2d0edaa3e8cb7cfc96310db45.json +0 -1
  328. xinference/web/ui/node_modules/.cache/babel-loader/d6c643278a0b28320e6f33a60f5fb64c053997cbdc39a60e53ccc574688ade9e.json +0 -1
  329. xinference/web/ui/node_modules/.cache/babel-loader/e42b72d4cc1ea412ebecbb8d040dc6c6bfee462c33903c2f1f3facb602ad742e.json +0 -1
  330. xinference/web/ui/node_modules/.cache/babel-loader/e64b7e8cedcf43d4c95deba60ec1341855c887705805bb62431693118b870c69.json +0 -1
  331. xinference/web/ui/node_modules/.cache/babel-loader/f5039ddbeb815c51491a1989532006b96fc3ae49c6c60e3c097f875b4ae915ae.json +0 -1
  332. xinference/web/ui/node_modules/.cache/babel-loader/f72f011744c4649fabddca6f7a9327861ac0a315a89b1a2e62a39774e7863845.json +0 -1
  333. xinference/web/ui/node_modules/.cache/babel-loader/feabb04b4aa507102da0a64398a40818e878fd1df9b75dda8461b3e1e7ff3f11.json +0 -1
  334. /xinference/thirdparty/{cosyvoice/bin → f5_tts}/__init__.py +0 -0
  335. /xinference/thirdparty/{cosyvoice/flow → melo}/__init__.py +0 -0
  336. /xinference/thirdparty/{cosyvoice/hifigan → melo/text/english_utils}/__init__.py +0 -0
  337. /xinference/thirdparty/{cosyvoice/llm → melo/text/es_phonemizer}/__init__.py +0 -0
  338. /xinference/thirdparty/{fish_speech/tools → melo/text/fr_phonemizer}/__init__.py +0 -0
  339. /xinference/web/ui/build/static/js/{main.2f269bb3.js.LICENSE.txt → main.b0936c54.js.LICENSE.txt} +0 -0
  340. {xinference-1.0.1.dist-info → xinference-1.2.1.dist-info}/LICENSE +0 -0
  341. {xinference-1.0.1.dist-info → xinference-1.2.1.dist-info}/WHEEL +0 -0
  342. {xinference-1.0.1.dist-info → xinference-1.2.1.dist-info}/entry_points.txt +0 -0
  343. {xinference-1.0.1.dist-info → xinference-1.2.1.dist-info}/top_level.txt +0 -0
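
Beyond the vendored third-party trees, the user-facing additions in this release include new audio model families (F5-TTS plus an MLX variant, MeloTTS), a CogAgent transformers backend, and a vLLM "xavier" KV-cache transfer layer. A sketch of driving one of the new audio families through the Python client; the server URL, the "F5-TTS" model name, and the speech() call are assumptions based on xinference's client API, not guaranteed by this diff:

    # Sketch only: assumes a locally running xinference server and that the
    # new family is registered under the name "F5-TTS".
    from xinference.client import Client

    client = Client("http://localhost:9997")
    model_uid = client.launch_model(model_name="F5-TTS", model_type="audio")
    model = client.get_model(model_uid)

    # speech() returns encoded audio bytes (mp3 by default in the client API)
    audio = model.speech("Hello from the 1.2.1 audio stack.")
    with open("out.mp3", "wb") as f:
        f.write(audio)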
xinference/thirdparty/f5_tts/train/train.py

@@ -0,0 +1,75 @@
+ # training script.
+
+ import os
+ from importlib.resources import files
+
+ import hydra
+
+ from f5_tts.model import CFM, DiT, Trainer, UNetT
+ from f5_tts.model.dataset import load_dataset
+ from f5_tts.model.utils import get_tokenizer
+
+ os.chdir(str(files("f5_tts").joinpath("../..")))  # change working directory to root of project (local editable)
+
+
+ @hydra.main(version_base="1.3", config_path=str(files("f5_tts").joinpath("configs")), config_name=None)
+ def main(cfg):
+     tokenizer = cfg.model.tokenizer
+     mel_spec_type = cfg.model.mel_spec.mel_spec_type
+     exp_name = f"{cfg.model.name}_{mel_spec_type}_{cfg.model.tokenizer}_{cfg.datasets.name}"
+
+     # set text tokenizer
+     if tokenizer != "custom":
+         tokenizer_path = cfg.datasets.name
+     else:
+         tokenizer_path = cfg.model.tokenizer_path
+     vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
+
+     # set model
+     if "F5TTS" in cfg.model.name:
+         model_cls = DiT
+     elif "E2TTS" in cfg.model.name:
+         model_cls = UNetT
+     wandb_resume_id = None
+
+     model = CFM(
+         transformer=model_cls(**cfg.model.arch, text_num_embeds=vocab_size, mel_dim=cfg.model.mel_spec.n_mel_channels),
+         mel_spec_kwargs=cfg.model.mel_spec,
+         vocab_char_map=vocab_char_map,
+     )
+
+     # init trainer
+     trainer = Trainer(
+         model,
+         epochs=cfg.optim.epochs,
+         learning_rate=cfg.optim.learning_rate,
+         num_warmup_updates=cfg.optim.num_warmup_updates,
+         save_per_updates=cfg.ckpts.save_per_updates,
+         checkpoint_path=str(files("f5_tts").joinpath(f"../../{cfg.ckpts.save_dir}")),
+         batch_size=cfg.datasets.batch_size_per_gpu,
+         batch_size_type=cfg.datasets.batch_size_type,
+         max_samples=cfg.datasets.max_samples,
+         grad_accumulation_steps=cfg.optim.grad_accumulation_steps,
+         max_grad_norm=cfg.optim.max_grad_norm,
+         logger=cfg.ckpts.logger,
+         wandb_project="CFM-TTS",
+         wandb_run_name=exp_name,
+         wandb_resume_id=wandb_resume_id,
+         last_per_steps=cfg.ckpts.last_per_steps,
+         log_samples=True,
+         bnb_optimizer=cfg.optim.bnb_optimizer,
+         mel_spec_type=mel_spec_type,
+         is_local_vocoder=cfg.model.vocoder.is_local,
+         local_vocoder_path=cfg.model.vocoder.local_path,
+     )
+
+     train_dataset = load_dataset(cfg.datasets.name, tokenizer, mel_spec_kwargs=cfg.model.mel_spec)
+     trainer.train(
+         train_dataset,
+         num_workers=cfg.datasets.num_workers,
+         resumable_with_seed=666,  # seed for shuffling dataset
+     )
+
+
+ if __name__ == "__main__":
+     main()
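
With config_name=None in the @hydra.main decorator, the config must be supplied at launch time, e.g. one of the YAMLs added under f5_tts/configs in this release (F5TTS_Base_train.yaml and friends). A minimal sketch, assuming those configs, that loads the same config via hydra's compose API without starting a run:

    # Sketch only: the config name is an assumption based on
    # f5_tts/configs/F5TTS_Base_train.yaml added in this release.
    from importlib.resources import files

    from hydra import compose, initialize_config_dir

    config_dir = str(files("f5_tts").joinpath("configs"))
    with initialize_config_dir(version_base="1.3", config_dir=config_dir):
        cfg = compose(config_name="F5TTS_Base_train")

    # Fields referenced by main() above
    print(cfg.model.name, cfg.optim.epochs, cfg.datasets.name)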
xinference/thirdparty/fish_speech/fish_speech/conversation.py

@@ -2,41 +2,10 @@ from dataclasses import dataclass, field
  from typing import Literal
 
  import torch
- from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerFast
-
- IM_START_TOKEN = "<|im_start|>"
- IM_END_TOKEN = "<|im_end|>"
- SEMANTIC_TOKEN = "<|semantic|>"
- MEL_TOKEN = "<|mel|>"
- PHONEME_START_TOKEN = "<|phoneme_start|>"
- PHONEME_END_TOKEN = "<|phoneme_end|>"
- ALL_SPECIAL_TOKENS = [
-     IM_START_TOKEN,
-     IM_END_TOKEN,
-     SEMANTIC_TOKEN,
-     MEL_TOKEN,
-     PHONEME_START_TOKEN,
-     PHONEME_END_TOKEN,
- ]
-
- CODEBOOK_PAD_TOKEN_ID = 0
-
-
- class FishTokenizerConfig(PretrainedConfig):
-     share_codebook_embeddings: bool = True
-     codebook_size: int = 1024
-     num_codebooks: int = 8
 
+ from .tokenizer import MODALITY_TOKENS, FishTokenizer
 
- class FishTokenizerFast(PreTrainedTokenizerFast):
-     def __init__(self, *args, **kwargs):
-         super().__init__(*args, **kwargs)
-         self.share_codebook_embeddings = kwargs.pop("share_codebook_embeddings", True)
-         self.codebook_size = kwargs.pop("codebook_size", 1024)
-         self.num_codebooks = kwargs.pop("num_codebooks", 8)
-
-
- AutoTokenizer.register(FishTokenizerConfig, fast_tokenizer_class=FishTokenizerFast)
+ CODEBOOK_PAD_TOKEN_ID = 0
 
 
  @dataclass(kw_only=True)
@@ -54,77 +23,72 @@ class TextPart(BasePart):
      text: str
 
 
- @dataclass(kw_only=True)
- class MelPart(BasePart):
-     mels: torch.Tensor
-
-
  @dataclass(kw_only=True)
  class EncodedMessage:
      tokens: torch.Tensor
      labels: torch.Tensor
+     vq_mask_tokens: torch.Tensor | None = None
+     vq_mask_labels: torch.Tensor | None = None
      vq_parts: list[torch.Tensor]
-     mel_parts: list[torch.Tensor]
      vq_require_losses: torch.Tensor | None = None
 
 
  @dataclass(kw_only=True)
  class Message:
      role: Literal["system", "user", "assistant"]
-     parts: list[VQPart | TextPart | MelPart] = field(default_factory=list)
+     parts: list[VQPart | TextPart] = field(default_factory=list)
      add_im_start: bool = True
      add_im_end: bool = True
      cal_loss: bool = False
+     modality: Literal["text", "voice", "interleave"] | None = None
 
      # By default, ignore the loss of the auto-generated im_start token
      ignore_im_start_loss: bool = True
 
      def encode(
          self: "Message",
-         tokenizer: AutoTokenizer,
+         tokenizer: FishTokenizer,
      ) -> EncodedMessage:
          all_tokens = []
          all_labels = []
 
          # Multi-modal tokens
          vq_parts = []
-         mel_parts = []
-
-         semantic_id, mel_id = tokenizer.convert_tokens_to_ids(
-             [SEMANTIC_TOKEN, MEL_TOKEN]
-         )
+         vq_masks = []
 
          parts = self.parts.copy()
          if self.add_im_start:
-             parts.insert(0, TextPart(text=f"<|im_start|>{self.role}\n"))
+             modality_token = MODALITY_TOKENS[self.modality] if self.modality else ""
+             parts.insert(0, TextPart(text=f"<|im_start|>{self.role}\n{modality_token}"))
 
          if self.add_im_end:
              parts.append(TextPart(text="<|im_end|>"))
 
          for part in parts:
              if isinstance(part, TextPart):
-                 tokens = tokenizer.encode(
-                     part.text,
-                     add_special_tokens=False,
-                     truncation=False,
-                     return_tensors="pt",
-                 ).int()[0]
+                 tokens = torch.tensor(
+                     tokenizer.encode(part.text),
+                     dtype=torch.int,
+                 )
              elif isinstance(part, VQPart):
-                 tokens = torch.zeros(part.codes.shape[1], dtype=torch.int) + semantic_id
-                 codes = part.codes.clone() + 1
-
-                 if getattr(tokenizer, "share_codebook_embeddings", True) is False:
-                     for i in range(len(codes)):
-                         codes[i] += tokenizer.codebook_size * i
-
-                 vq_parts.append(codes)
-             elif isinstance(part, MelPart):
-                 tokens = torch.zeros(part.mels.shape[1], dtype=torch.int) + mel_id
-                 mel_parts.append(part.mels)
+                 curr_codes = part.codes.clone()
+                 tokens = torch.tensor(
+                     [
+                         tokenizer.semantic_id_to_token_id[i.item()]
+                         for i in curr_codes[0].int()
+                     ],
+                     dtype=torch.int,
+                 )
+                 vq_parts.append(curr_codes)
              else:
                  raise ValueError(f"Unsupported part type: {type(part)}")
 
              all_tokens.append(tokens)
+             if isinstance(part, VQPart):
+                 vq_masks.append(torch.ones_like(tokens, dtype=torch.bool))
+             else:
+                 vq_masks.append(torch.zeros_like(tokens, dtype=torch.bool))
+
              if self.cal_loss:
                  all_labels.append(tokens.clone())
              else:
@@ -132,7 +96,9 @@ class Message:
 
          tokens = torch.cat(all_tokens, dim=0)
          labels = torch.cat(all_labels, dim=0)
-         assert tokens.shape == labels.shape
+         vq_masks = torch.cat(vq_masks, dim=0)
+
+         assert tokens.shape == labels.shape == vq_masks.shape
 
          if self.ignore_im_start_loss and self.add_im_start:
              labels[: len(all_tokens[0])] = -100
@@ -141,7 +107,8 @@ class Message:
              tokens=tokens,
              labels=labels,
              vq_parts=vq_parts,
-             mel_parts=mel_parts,
+             vq_mask_tokens=vq_masks,
+             vq_mask_labels=vq_masks,
          )
 
 
@@ -149,17 +116,23 @@ class Message:
  class Conversation:
      messages: list[Message]
 
+     def __init__(self: "Conversation", messages: list[Message] | None = None):
+         self.messages = messages or []
+
      def encode(
          self: "Conversation",
-         tokenizer: AutoTokenizer,
+         tokenizer: FishTokenizer,
          add_shift: bool = True,
+         ignore_loss_tokens: list[str] = [],
      ) -> EncodedMessage:
          # Build the input_ids and labels
          tokens = []
          labels = []
          vq_parts = []
-         mel_parts = []
+         vq_mask_tokens = []
+         vq_mask_labels = []
          vq_require_losses = []
+         ignore_loss_token_ids = [tokenizer.get_token_id(i) for i in ignore_loss_tokens]
 
          for message in self.messages:
              encoded = message.encode(
@@ -168,16 +141,25 @@ class Conversation:
              tokens.append(encoded.tokens)
              labels.append(encoded.labels)
              vq_parts.extend(encoded.vq_parts)
-             mel_parts.extend(encoded.mel_parts)
+             vq_mask_tokens.append(encoded.vq_mask_tokens)
+             vq_mask_labels.append(encoded.vq_mask_labels)
              vq_require_losses.extend([message.cal_loss] * len(encoded.vq_parts))
 
          tokens = torch.cat(tokens, dim=0)
          labels = torch.cat(labels, dim=0)
+         vq_mask_tokens = torch.cat(vq_mask_tokens, dim=0)
+         vq_mask_labels = torch.cat(vq_mask_labels, dim=0)
          vq_require_losses = torch.tensor(vq_require_losses, dtype=torch.bool)
 
          if add_shift:
              tokens = tokens[:-1]
              labels = labels[1:]
+             vq_mask_tokens = vq_mask_tokens[:-1]
+             vq_mask_labels = vq_mask_labels[1:]
+
+         for i in ignore_loss_token_ids:
+             assert i != -100 and i is not None
+             labels[labels == i] = -100
 
          assert tokens.dtype in [
              torch.int,
@@ -188,15 +170,18 @@ class Conversation:
              tokens=tokens,
              labels=labels,
              vq_parts=vq_parts,
-             mel_parts=mel_parts,
+             vq_mask_tokens=vq_mask_tokens,
+             vq_mask_labels=vq_mask_labels,
              vq_require_losses=vq_require_losses,
          )
 
      def encode_for_inference(
          self: "Conversation",
-         tokenizer: AutoTokenizer,
+         tokenizer: FishTokenizer,
          num_codebooks: int,
      ) -> EncodedMessage:
+         # self.visualize(tokenizer)
+
          encoded = self.encode(tokenizer, add_shift=False)
          tokens = encoded.tokens
          values = torch.zeros((num_codebooks + 1, len(tokens)), dtype=torch.int)
@@ -205,24 +190,47 @@ class Conversation:
          if encoded.vq_parts is None or len(encoded.vq_parts) == 0:
              return values
 
-         semantic_id, mel_id = tokenizer.convert_tokens_to_ids(
-             [SEMANTIC_TOKEN, MEL_TOKEN]
-         )
          vq_parts = encoded.vq_parts
+         vq_parts = [part.to(values.device) for part in vq_parts]
          vq_parts = torch.cat(vq_parts, dim=1)
-         values[1:, tokens == semantic_id] = vq_parts
+         values[0, encoded.vq_mask_tokens] = vq_parts[0] + tokenizer.semantic_begin_id
+         values[1:, encoded.vq_mask_tokens] = vq_parts
+
          return values
 
-     def visualize(self: "Conversation", tokenizer: AutoTokenizer):
-         encoded = self.encode(tokenizer, add_shift=False)
+     def visualize(
+         self: "Conversation",
+         tokenizer: FishTokenizer,
+         ignore_loss_tokens: list[str] = [],
+     ):
+         encoded = self.encode(
+             tokenizer, add_shift=False, ignore_loss_tokens=ignore_loss_tokens
+         )
 
-         print_in_blue = lambda x: print("\033[94m" + x + "\033[0m", end="")
-         print_in_green = lambda x: print("\033[92m" + x + "\033[0m", end="")
+         # Colors for alternating tokens
+         colors = {
+             "blue": "\033[94m",  # Light blue
+             "cyan": "\033[96m",  # Cyan
+             "green": "\033[92m",  # Light green
+             "dark_green": "\033[32m",  # Dark green
+         }
+         blue_idx = 0
+         green_idx = 0
+
+         def print_in_blue(x):
+             nonlocal blue_idx
+             color = colors["blue"] if blue_idx % 2 == 0 else colors["cyan"]
+             print(f"{color}{x}\033[0m", end="")
+             blue_idx += 1
+
+         def print_in_green(x):
+             nonlocal green_idx
+             color = colors["green"] if green_idx % 2 == 0 else colors["dark_green"]
+             print(f"{color}{x}\033[0m", end="")
+             green_idx += 1
 
          for tok, lab in zip(encoded.tokens, encoded.labels):
-             val = tokenizer.decode(tok, skip_special_tokens=False)
-             if val == "\n":
-                 val = "\\n\n"
+             val = tokenizer.decode([tok])
 
              if lab == -100:
                  print_in_green(val)
@@ -231,6 +239,9 @@ class Conversation:
 
          print()
 
+     def append(self: "Conversation", message: Message):
+         self.messages.append(message)
+
 
  if __name__ == "__main__":
      message0 = Message(
@@ -248,7 +259,7 @@ if __name__ == "__main__":
          cal_loss=True,
      )
      conversation = Conversation([message0, message1])
-     tokenizer = AutoTokenizer.from_pretrained("checkpoints/Qwen2-1.5B-Instruct")
+     tokenizer = FishTokenizer.from_pretrained("checkpoints/Qwen2-1.5B-Instruct")
      conversation.visualize(tokenizer)
 
      encoded = conversation.encode(tokenizer)
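
The MelPart code path is gone: messages now carry only text and VQ parts, and encoding returns boolean VQ masks aligned with tokens and labels. A minimal sketch of the reworked API; the tokenizer path is hypothetical and must point at a fish-speech checkpoint's tokenizer.tiktoken file:

    # Sketch of the conversation API above, under the stated assumptions.
    from fish_speech.conversation import Conversation, Message, TextPart
    from fish_speech.tokenizer import FishTokenizer

    tokenizer = FishTokenizer("checkpoints/fish-speech/tokenizer.tiktoken")

    conv = Conversation()  # new no-arg constructor
    conv.append(Message(role="user", parts=[TextPart(text="Hello")], modality="text"))
    conv.append(Message(role="assistant", parts=[TextPart(text="Hi!")], cal_loss=True))

    encoded = conv.encode(tokenizer)
    # tokens, labels, and the new vq masks stay aligned after the shift
    assert encoded.tokens.shape == encoded.labels.shape == encoded.vq_mask_tokens.shape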
xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py

@@ -16,7 +16,7 @@ from torch.nn.attention import SDPBackend, sdpa_kernel
  from torch.utils.checkpoint import checkpoint
  from transformers import AutoTokenizer
 
- from fish_speech.conversation import SEMANTIC_TOKEN
+ from fish_speech.tokenizer import SEMANTIC_TOKENS, FishTokenizer
  from fish_speech.utils import RankedLogger
 
  from .lora import LoraConfig, setup_lora
@@ -61,6 +61,7 @@ class BaseModelArgs:
      # Dummy vars
      is_reward_model: bool = False
      share_codebook_embeddings: bool = True
+     scale_codebook_embeddings: bool = False
 
      def __post_init__(self):
          if self.n_local_heads == -1:
@@ -164,13 +165,17 @@ class BaseTransformerForwardResult:
 
  class BaseTransformer(nn.Module):
      def __init__(
-         self, config: BaseModelArgs, tokenizer: AutoTokenizer, init_weights: bool = True
+         self,
+         config: BaseModelArgs,
+         tokenizer: FishTokenizer | AutoTokenizer,
+         init_weights: bool = True,
      ) -> None:
          super().__init__()
          self.config = config
          self.tokenizer = tokenizer
-
-         self.semantic_token_id = tokenizer.convert_tokens_to_ids(SEMANTIC_TOKEN)
+         self.semantic_token_ids = [
+             tokenizer.get_token_id(SEMANTIC_TOKEN) for SEMANTIC_TOKEN in SEMANTIC_TOKENS
+         ]
 
          # Slow transformer
          self.embeddings = nn.Embedding(
@@ -245,8 +250,10 @@ class BaseTransformer(nn.Module):
          vocab_embeds = [self.embeddings(x[:, 0])]
          for i in range(self.config.num_codebooks):
              emb = self.codebook_embeddings(x[:, i + 1] + i * self.config.codebook_size)
-             emb[x[:, 0] != self.semantic_token_id] = 0
-             vocab_embeds.append(emb)
+             semantic_token_ids_tensor = torch.tensor(
+                 self.semantic_token_ids, device=x.device
+             )
+             emb[~torch.isin(x[:, 0], semantic_token_ids_tensor)] = 0
              vocab_embeds.append(emb)
          x = torch.stack(vocab_embeds, dim=3)
          x = x.sum(dim=3)
@@ -294,20 +301,45 @@ class BaseTransformer(nn.Module):
 
      def forward_generate(
          self,
-         x: Tensor,
+         inp: Tensor,
          input_pos: Optional[Tensor] = None,
+         vq_masks: Optional[Tensor] = None,  # this is not used in fact
          return_all: bool = False,
      ) -> BaseTransformerForwardResult:
          # This is used for generation, optimized for torch compile
-         assert (
-             self.max_seq_len != -1 and self.max_batch_size != -1
-         ), "Please call setup_caches before forward_generate"
+         # assert (
+         #     self.max_seq_len != -1 and self.max_batch_size != -1
+         # ), "Please call setup_caches before forward_generate"
 
-         x = self.embed(x)
+         embeds = []
+         for i in range(self.config.num_codebooks):
+             if self.config.share_codebook_embeddings:
+                 _tokens = inp[:, i + 1] + i * self.config.codebook_size
+             else:
+                 _tokens = inp[:, i + 1]
 
-         mask = self.causal_mask[
-             None, None, input_pos, : self.max_seq_len
-         ]  # (B, N, Q, K)
+             emb = self.codebook_embeddings(_tokens)
+             embeds.append(emb)
+
+         vq_embeds_sum = torch.stack(embeds, dim=1).sum(dim=1)
+         # if self.config.use_codebook_mlp:
+         #     vq_embeds_sum = vq_embeds_sum / self.config.num_codebooks
+         #     vq_embeds_sum = self.codebook_mlp(vq_embeds_sum)
+
+         vq_masks = (inp[:, 0] >= self.tokenizer.semantic_begin_id) & (
+             inp[:, 0] <= self.tokenizer.semantic_end_id
+         )
+
+         vq_embeds_sum[~vq_masks] = 0
+         x = self.embeddings(inp[:, 0]) + vq_embeds_sum
+
+         if input_pos is None:
+             input_pos = torch.arange(inp.shape[-1], device=x.device)
+             max_seq_len = inp.shape[-1]
+         else:
+             max_seq_len = self.max_seq_len
+
+         mask = self.causal_mask[None, None, input_pos, :max_seq_len]  # (B, N, Q, K)
          freqs_cis = self.freqs_cis[input_pos]
 
          for layer in self.layers:
@@ -320,7 +352,9 @@ class BaseTransformer(nn.Module):
          # We got slow_out here
          slow_out = self.norm(x)
 
-         if self.config.tie_word_embeddings:
+         if self.config.is_reward_model:
+             token_logits = self.score_output(slow_out)
+         elif self.config.tie_word_embeddings:
              token_logits = F.linear(slow_out, self.embeddings.weight)
          else:
              token_logits = self.output(slow_out)
@@ -348,6 +382,7 @@ class BaseTransformer(nn.Module):
          max_length: int | None = None,
          lora_config: LoraConfig | None = None,
          rope_base: int | None = None,
+         is_agent: bool = False,
      ) -> "BaseTransformer":
          config = BaseModelArgs.from_pretrained(str(path))
          if max_length is not None:
@@ -366,7 +401,12 @@ class BaseTransformer(nn.Module):
              case _:
                  raise ValueError(f"Unknown model type: {config.model_type}")
 
-         tokenizer = AutoTokenizer.from_pretrained(str(path))
+         if is_agent:
+             tokenizer = AutoTokenizer.from_pretrained(str(path))
+         else:
+             tokenizer_path = str(path) + "/tokenizer.tiktoken"
+             tokenizer = FishTokenizer(tokenizer_path)
+
          log.info(f"Loading model from {path}, config: {config}")
          model = model_cls(config, tokenizer=tokenizer)
 
@@ -452,7 +492,7 @@ class BaseTransformer(nn.Module):
 
 
  class NaiveTransformer(BaseTransformer):
-     def __init__(self, config: NaiveModelArgs, tokenizer: AutoTokenizer) -> None:
+     def __init__(self, config: NaiveModelArgs, tokenizer: FishTokenizer) -> None:
          super().__init__(config, init_weights=False, tokenizer=tokenizer)
 
          self.codebook_norm = RMSNorm(config.dim, eps=config.norm_eps)
@@ -498,7 +538,7 @@ class NaiveTransformer(BaseTransformer):
 
 
  class DualARTransformer(BaseTransformer):
-     def __init__(self, config: NaiveModelArgs, tokenizer: AutoTokenizer) -> None:
+     def __init__(self, config: NaiveModelArgs, tokenizer: FishTokenizer) -> None:
          super().__init__(config, init_weights=False, tokenizer=tokenizer)
 
          # Project to fast dim if needed
@@ -654,9 +694,12 @@ class DualARTransformer(BaseTransformer):
          return codebook_logits
 
      def forward_generate(
-         self, x: Tensor, input_pos: Optional[Tensor] = None
+         self,
+         x: Tensor,
+         input_pos: Optional[Tensor] = None,
+         vq_masks: Optional[Tensor] = None,
      ) -> TransformerForwardResult:
-         x = super().forward_generate(x, input_pos)
+         x = super().forward_generate(x, input_pos, vq_masks)
          x.hidden_states = self.fast_project_in(x.hidden_states)
          return x
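
The scalar semantic_token_id check becomes a membership test over a range of ids, handled with torch.isin in embed() and with semantic_begin_id/semantic_end_id bounds in forward_generate(). A standalone illustration of the isin masking, with invented ids and shapes:

    # Standalone illustration of the torch.isin masking in embed() above;
    # token ids and embedding sizes here are made up for the example.
    import torch

    semantic_token_ids = torch.tensor([7, 8, 9])   # stand-in for tokenizer ids
    first_row = torch.tensor([3, 7, 9, 4])         # x[:, 0]: one token id per position
    emb = torch.randn(4, 16)                       # one codebook's embeddings

    emb[~torch.isin(first_row, semantic_token_ids)] = 0
    print(emb.abs().sum(dim=1))                    # positions 0 and 3 are zeroed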
 
xinference/thirdparty/fish_speech/fish_speech/text/clean.py

@@ -1,33 +1,8 @@
  import re
 
  SYMBOLS_MAPPING = {
-     "\n": "",
-     "…": ".",
-     "“": "'",
-     "”": "'",
      "‘": "'",
      "’": "'",
-     "【": "",
-     "】": "",
-     "[": "",
-     "]": "",
-     "(": "",
-     ")": "",
-     "(": "",
-     ")": "",
-     "・": "",
-     "·": "",
-     "「": "'",
-     "」": "'",
-     "《": "'",
-     "》": "'",
-     "—": "",
-     "~": "",
-     "~": "",
-     ":": ",",
-     ";": ",",
-     ";": ",",
-     ":": ",",
  }
 
  REPLACE_SYMBOL_REGEX = re.compile(
@@ -57,6 +32,6 @@ def clean_text(text):
      text = EMOJI_REGEX.sub(r"", text)
 
      # Remove continuous periods (...) and commas (,,,)
-     text = re.sub(r"[.,]{2,}", lambda m: m.group()[0], text)
+     text = re.sub(r"[,]{2,}", lambda m: m.group()[0], text)
 
      return text
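
Besides the slimmed-down symbol map, the behavioral change here is that runs of periods are no longer collapsed, only runs of commas. A standalone before/after sketch of the two regexes:

    # Comparison of the old and new collapsing rules from the hunk above.
    import re

    text = "Wait... yes,,, ok"
    old = re.sub(r"[.,]{2,}", lambda m: m.group()[0], text)  # -> "Wait. yes, ok"
    new = re.sub(r"[,]{2,}", lambda m: m.group()[0], text)   # -> "Wait... yes, ok"
    print(old, "|", new)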
xinference/thirdparty/fish_speech/fish_speech/text/spliter.py

@@ -4,7 +4,7 @@ import string
  from fish_speech.text.clean import clean_text
 
 
- def utf_8_len(text):
+ def utf_8_len(text: str):
      return len(text.encode("utf-8"))
 
 
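
The annotation only documents the input type; the helper measures UTF-8 bytes rather than characters, e.g.:

    # utf_8_len counts encoded bytes, so multi-byte characters weigh more.
    def utf_8_len(text: str) -> int:
        return len(text.encode("utf-8"))

    assert utf_8_len("abc") == 3
    assert utf_8_len("你好") == 6  # two CJK characters, three bytes each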