xinference 0.16.3__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic. Click here for more details.
- xinference/_compat.py +24 -2
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +219 -77
- xinference/client/restful/restful_client.py +47 -2
- xinference/constants.py +1 -0
- xinference/core/chat_interface.py +6 -1
- xinference/core/model.py +124 -34
- xinference/core/supervisor.py +180 -12
- xinference/core/utils.py +73 -4
- xinference/core/worker.py +102 -4
- xinference/deploy/cmdline.py +3 -1
- xinference/deploy/test/test_cmdline.py +56 -0
- xinference/isolation.py +24 -0
- xinference/model/audio/__init__.py +12 -0
- xinference/model/audio/core.py +37 -4
- xinference/model/audio/cosyvoice.py +39 -6
- xinference/model/audio/f5tts.py +200 -0
- xinference/model/audio/f5tts_mlx.py +260 -0
- xinference/model/audio/fish_speech.py +70 -110
- xinference/model/audio/melotts.py +110 -0
- xinference/model/audio/model_spec.json +179 -3
- xinference/model/audio/model_spec_modelscope.json +27 -0
- xinference/model/audio/utils.py +32 -0
- xinference/model/audio/whisper.py +35 -10
- xinference/model/audio/whisper_mlx.py +208 -0
- xinference/model/embedding/core.py +322 -6
- xinference/model/embedding/model_spec.json +8 -1
- xinference/model/embedding/model_spec_modelscope.json +9 -1
- xinference/model/image/core.py +69 -1
- xinference/model/image/model_spec.json +145 -4
- xinference/model/image/model_spec_modelscope.json +150 -4
- xinference/model/image/stable_diffusion/core.py +50 -15
- xinference/model/llm/__init__.py +6 -2
- xinference/model/llm/llm_family.json +1055 -93
- xinference/model/llm/llm_family.py +15 -36
- xinference/model/llm/llm_family_modelscope.json +1031 -78
- xinference/model/llm/memory.py +1 -1
- xinference/model/llm/mlx/core.py +285 -47
- xinference/model/llm/sglang/core.py +2 -0
- xinference/model/llm/transformers/chatglm.py +9 -5
- xinference/model/llm/transformers/cogagent.py +272 -0
- xinference/model/llm/transformers/core.py +3 -0
- xinference/model/llm/transformers/glm_edge_v.py +230 -0
- xinference/model/llm/transformers/qwen2_vl.py +12 -1
- xinference/model/llm/transformers/utils.py +16 -8
- xinference/model/llm/utils.py +55 -4
- xinference/model/llm/vllm/core.py +137 -12
- xinference/model/llm/vllm/xavier/__init__.py +13 -0
- xinference/model/llm/vllm/xavier/allocator.py +74 -0
- xinference/model/llm/vllm/xavier/block.py +111 -0
- xinference/model/llm/vllm/xavier/block_manager.py +71 -0
- xinference/model/llm/vllm/xavier/block_tracker.py +129 -0
- xinference/model/llm/vllm/xavier/collective.py +74 -0
- xinference/model/llm/vllm/xavier/collective_manager.py +147 -0
- xinference/model/llm/vllm/xavier/engine.py +247 -0
- xinference/model/llm/vllm/xavier/executor.py +134 -0
- xinference/model/llm/vllm/xavier/scheduler.py +438 -0
- xinference/model/llm/vllm/xavier/test/__init__.py +13 -0
- xinference/model/llm/vllm/xavier/test/test_xavier.py +147 -0
- xinference/model/llm/vllm/xavier/transfer.py +319 -0
- xinference/model/rerank/core.py +11 -4
- xinference/model/video/diffusers.py +14 -0
- xinference/model/video/model_spec.json +15 -0
- xinference/model/video/model_spec_modelscope.json +16 -0
- xinference/thirdparty/cosyvoice/bin/average_model.py +92 -0
- xinference/thirdparty/cosyvoice/bin/export_jit.py +12 -2
- xinference/thirdparty/cosyvoice/bin/export_onnx.py +112 -0
- xinference/thirdparty/cosyvoice/bin/export_trt.sh +9 -0
- xinference/thirdparty/cosyvoice/bin/inference.py +5 -7
- xinference/thirdparty/cosyvoice/bin/spk2info.pt +0 -0
- xinference/thirdparty/cosyvoice/bin/train.py +42 -8
- xinference/thirdparty/cosyvoice/cli/cosyvoice.py +96 -25
- xinference/thirdparty/cosyvoice/cli/frontend.py +77 -30
- xinference/thirdparty/cosyvoice/cli/model.py +330 -80
- xinference/thirdparty/cosyvoice/dataset/dataset.py +6 -2
- xinference/thirdparty/cosyvoice/dataset/processor.py +76 -14
- xinference/thirdparty/cosyvoice/flow/decoder.py +92 -13
- xinference/thirdparty/cosyvoice/flow/flow.py +99 -9
- xinference/thirdparty/cosyvoice/flow/flow_matching.py +110 -13
- xinference/thirdparty/cosyvoice/flow/length_regulator.py +5 -4
- xinference/thirdparty/cosyvoice/hifigan/discriminator.py +140 -0
- xinference/thirdparty/cosyvoice/hifigan/generator.py +58 -42
- xinference/thirdparty/cosyvoice/hifigan/hifigan.py +67 -0
- xinference/thirdparty/cosyvoice/llm/llm.py +139 -6
- xinference/thirdparty/cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
- xinference/thirdparty/cosyvoice/tokenizer/tokenizer.py +279 -0
- xinference/thirdparty/cosyvoice/transformer/embedding.py +2 -2
- xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +7 -7
- xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +318 -0
- xinference/thirdparty/cosyvoice/utils/common.py +28 -1
- xinference/thirdparty/cosyvoice/utils/executor.py +69 -7
- xinference/thirdparty/cosyvoice/utils/file_utils.py +2 -12
- xinference/thirdparty/cosyvoice/utils/frontend_utils.py +9 -5
- xinference/thirdparty/cosyvoice/utils/losses.py +20 -0
- xinference/thirdparty/cosyvoice/utils/scheduler.py +1 -2
- xinference/thirdparty/cosyvoice/utils/train_utils.py +101 -45
- xinference/thirdparty/f5_tts/api.py +166 -0
- xinference/thirdparty/f5_tts/configs/E2TTS_Base_train.yaml +44 -0
- xinference/thirdparty/f5_tts/configs/E2TTS_Small_train.yaml +44 -0
- xinference/thirdparty/f5_tts/configs/F5TTS_Base_train.yaml +46 -0
- xinference/thirdparty/f5_tts/configs/F5TTS_Small_train.yaml +46 -0
- xinference/thirdparty/f5_tts/eval/README.md +49 -0
- xinference/thirdparty/f5_tts/eval/ecapa_tdnn.py +330 -0
- xinference/thirdparty/f5_tts/eval/eval_infer_batch.py +207 -0
- xinference/thirdparty/f5_tts/eval/eval_infer_batch.sh +13 -0
- xinference/thirdparty/f5_tts/eval/eval_librispeech_test_clean.py +84 -0
- xinference/thirdparty/f5_tts/eval/eval_seedtts_testset.py +84 -0
- xinference/thirdparty/f5_tts/eval/utils_eval.py +405 -0
- xinference/thirdparty/f5_tts/infer/README.md +191 -0
- xinference/thirdparty/f5_tts/infer/SHARED.md +74 -0
- xinference/thirdparty/f5_tts/infer/examples/basic/basic.toml +11 -0
- xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_en.wav +0 -0
- xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_zh.wav +0 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/country.flac +0 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/main.flac +0 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/story.toml +19 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/story.txt +1 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/town.flac +0 -0
- xinference/thirdparty/f5_tts/infer/examples/vocab.txt +2545 -0
- xinference/thirdparty/f5_tts/infer/infer_cli.py +226 -0
- xinference/thirdparty/f5_tts/infer/infer_gradio.py +851 -0
- xinference/thirdparty/f5_tts/infer/speech_edit.py +193 -0
- xinference/thirdparty/f5_tts/infer/utils_infer.py +538 -0
- xinference/thirdparty/f5_tts/model/__init__.py +10 -0
- xinference/thirdparty/f5_tts/model/backbones/README.md +20 -0
- xinference/thirdparty/f5_tts/model/backbones/dit.py +163 -0
- xinference/thirdparty/f5_tts/model/backbones/mmdit.py +146 -0
- xinference/thirdparty/f5_tts/model/backbones/unett.py +219 -0
- xinference/thirdparty/f5_tts/model/cfm.py +285 -0
- xinference/thirdparty/f5_tts/model/dataset.py +319 -0
- xinference/thirdparty/f5_tts/model/modules.py +658 -0
- xinference/thirdparty/f5_tts/model/trainer.py +366 -0
- xinference/thirdparty/f5_tts/model/utils.py +185 -0
- xinference/thirdparty/f5_tts/scripts/count_max_epoch.py +33 -0
- xinference/thirdparty/f5_tts/scripts/count_params_gflops.py +39 -0
- xinference/thirdparty/f5_tts/socket_server.py +159 -0
- xinference/thirdparty/f5_tts/train/README.md +77 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_csv_wavs.py +139 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_emilia.py +230 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_libritts.py +92 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_ljspeech.py +65 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_wenetspeech4tts.py +125 -0
- xinference/thirdparty/f5_tts/train/finetune_cli.py +174 -0
- xinference/thirdparty/f5_tts/train/finetune_gradio.py +1846 -0
- xinference/thirdparty/f5_tts/train/train.py +75 -0
- xinference/thirdparty/fish_speech/fish_speech/conversation.py +266 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +2 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +2 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +2 -2
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ko_KR.json +123 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +2 -1
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +137 -29
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +9 -9
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +1 -1
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +17 -11
- xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +1 -1
- xinference/thirdparty/fish_speech/fish_speech/tokenizer.py +152 -0
- xinference/thirdparty/fish_speech/fish_speech/train.py +2 -2
- xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +2 -1
- xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +22 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +1 -1
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +2 -2
- xinference/thirdparty/fish_speech/tools/{post_api.py → api_client.py} +34 -18
- xinference/thirdparty/fish_speech/tools/api_server.py +98 -0
- xinference/thirdparty/fish_speech/tools/download_models.py +5 -5
- xinference/thirdparty/fish_speech/tools/e2e_webui.py +232 -0
- xinference/thirdparty/fish_speech/tools/fish_e2e.py +298 -0
- xinference/thirdparty/fish_speech/tools/inference_engine/__init__.py +192 -0
- xinference/thirdparty/fish_speech/tools/inference_engine/reference_loader.py +125 -0
- xinference/thirdparty/fish_speech/tools/inference_engine/utils.py +39 -0
- xinference/thirdparty/fish_speech/tools/inference_engine/vq_manager.py +57 -0
- xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +2 -2
- xinference/thirdparty/fish_speech/tools/llama/generate.py +484 -72
- xinference/thirdparty/fish_speech/tools/run_webui.py +104 -0
- xinference/thirdparty/fish_speech/tools/schema.py +170 -0
- xinference/thirdparty/fish_speech/tools/server/agent/__init__.py +57 -0
- xinference/thirdparty/fish_speech/tools/server/agent/generate.py +119 -0
- xinference/thirdparty/fish_speech/tools/server/agent/generation_utils.py +122 -0
- xinference/thirdparty/fish_speech/tools/server/agent/pre_generation_utils.py +72 -0
- xinference/thirdparty/fish_speech/tools/server/api_utils.py +75 -0
- xinference/thirdparty/fish_speech/tools/server/exception_handler.py +27 -0
- xinference/thirdparty/fish_speech/tools/server/inference.py +45 -0
- xinference/thirdparty/fish_speech/tools/server/model_manager.py +122 -0
- xinference/thirdparty/fish_speech/tools/server/model_utils.py +129 -0
- xinference/thirdparty/fish_speech/tools/server/views.py +246 -0
- xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +7 -1
- xinference/thirdparty/fish_speech/tools/vqgan/inference.py +2 -3
- xinference/thirdparty/fish_speech/tools/webui/__init__.py +173 -0
- xinference/thirdparty/fish_speech/tools/webui/inference.py +91 -0
- xinference/thirdparty/fish_speech/tools/webui/variables.py +14 -0
- xinference/thirdparty/matcha/utils/utils.py +2 -2
- xinference/thirdparty/melo/api.py +135 -0
- xinference/thirdparty/melo/app.py +61 -0
- xinference/thirdparty/melo/attentions.py +459 -0
- xinference/thirdparty/melo/commons.py +160 -0
- xinference/thirdparty/melo/configs/config.json +94 -0
- xinference/thirdparty/melo/data/example/metadata.list +20 -0
- xinference/thirdparty/melo/data_utils.py +413 -0
- xinference/thirdparty/melo/download_utils.py +67 -0
- xinference/thirdparty/melo/infer.py +25 -0
- xinference/thirdparty/melo/init_downloads.py +14 -0
- xinference/thirdparty/melo/losses.py +58 -0
- xinference/thirdparty/melo/main.py +36 -0
- xinference/thirdparty/melo/mel_processing.py +174 -0
- xinference/thirdparty/melo/models.py +1030 -0
- xinference/thirdparty/melo/modules.py +598 -0
- xinference/thirdparty/melo/monotonic_align/__init__.py +16 -0
- xinference/thirdparty/melo/monotonic_align/core.py +46 -0
- xinference/thirdparty/melo/preprocess_text.py +135 -0
- xinference/thirdparty/melo/split_utils.py +174 -0
- xinference/thirdparty/melo/text/__init__.py +35 -0
- xinference/thirdparty/melo/text/chinese.py +199 -0
- xinference/thirdparty/melo/text/chinese_bert.py +107 -0
- xinference/thirdparty/melo/text/chinese_mix.py +253 -0
- xinference/thirdparty/melo/text/cleaner.py +36 -0
- xinference/thirdparty/melo/text/cleaner_multiling.py +110 -0
- xinference/thirdparty/melo/text/cmudict.rep +129530 -0
- xinference/thirdparty/melo/text/cmudict_cache.pickle +0 -0
- xinference/thirdparty/melo/text/english.py +284 -0
- xinference/thirdparty/melo/text/english_bert.py +39 -0
- xinference/thirdparty/melo/text/english_utils/abbreviations.py +35 -0
- xinference/thirdparty/melo/text/english_utils/number_norm.py +97 -0
- xinference/thirdparty/melo/text/english_utils/time_norm.py +47 -0
- xinference/thirdparty/melo/text/es_phonemizer/base.py +140 -0
- xinference/thirdparty/melo/text/es_phonemizer/cleaner.py +109 -0
- xinference/thirdparty/melo/text/es_phonemizer/es_symbols.json +79 -0
- xinference/thirdparty/melo/text/es_phonemizer/es_symbols.txt +1 -0
- xinference/thirdparty/melo/text/es_phonemizer/es_symbols_v2.json +83 -0
- xinference/thirdparty/melo/text/es_phonemizer/es_to_ipa.py +12 -0
- xinference/thirdparty/melo/text/es_phonemizer/example_ipa.txt +400 -0
- xinference/thirdparty/melo/text/es_phonemizer/gruut_wrapper.py +253 -0
- xinference/thirdparty/melo/text/es_phonemizer/punctuation.py +174 -0
- xinference/thirdparty/melo/text/es_phonemizer/spanish_symbols.txt +1 -0
- xinference/thirdparty/melo/text/es_phonemizer/test.ipynb +124 -0
- xinference/thirdparty/melo/text/fr_phonemizer/base.py +140 -0
- xinference/thirdparty/melo/text/fr_phonemizer/cleaner.py +122 -0
- xinference/thirdparty/melo/text/fr_phonemizer/en_symbols.json +78 -0
- xinference/thirdparty/melo/text/fr_phonemizer/example_ipa.txt +1 -0
- xinference/thirdparty/melo/text/fr_phonemizer/fr_symbols.json +89 -0
- xinference/thirdparty/melo/text/fr_phonemizer/fr_to_ipa.py +30 -0
- xinference/thirdparty/melo/text/fr_phonemizer/french_abbreviations.py +48 -0
- xinference/thirdparty/melo/text/fr_phonemizer/french_symbols.txt +1 -0
- xinference/thirdparty/melo/text/fr_phonemizer/gruut_wrapper.py +258 -0
- xinference/thirdparty/melo/text/fr_phonemizer/punctuation.py +172 -0
- xinference/thirdparty/melo/text/french.py +94 -0
- xinference/thirdparty/melo/text/french_bert.py +39 -0
- xinference/thirdparty/melo/text/japanese.py +647 -0
- xinference/thirdparty/melo/text/japanese_bert.py +49 -0
- xinference/thirdparty/melo/text/ko_dictionary.py +44 -0
- xinference/thirdparty/melo/text/korean.py +192 -0
- xinference/thirdparty/melo/text/opencpop-strict.txt +429 -0
- xinference/thirdparty/melo/text/spanish.py +122 -0
- xinference/thirdparty/melo/text/spanish_bert.py +39 -0
- xinference/thirdparty/melo/text/symbols.py +290 -0
- xinference/thirdparty/melo/text/tone_sandhi.py +769 -0
- xinference/thirdparty/melo/train.py +635 -0
- xinference/thirdparty/melo/train.sh +19 -0
- xinference/thirdparty/melo/transforms.py +209 -0
- xinference/thirdparty/melo/utils.py +424 -0
- xinference/types.py +17 -1
- xinference/web/ui/build/asset-manifest.json +6 -6
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/css/main.51a587ff.css +2 -0
- xinference/web/ui/build/static/css/main.51a587ff.css.map +1 -0
- xinference/web/ui/build/static/js/main.b0936c54.js +3 -0
- xinference/web/ui/build/static/js/main.b0936c54.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/03c4052f1b91f6ba0c5389bdcf49c43319b4076c08e4b8585dab312538ae290a.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/1786b83003b8e9605a0f5f855a185d4d16e38fc893dfb326a2a9cca206b4240a.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/17cbc181dd674b9150b80c73ed6a82656de0082d857f6e5f66d9716129ac0b38.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/185ceb8872d562e032b47e79df6a45670e06345b8ed70aad1a131e0476783c5c.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/26b8c9f34b0bed789b3a833767672e39302d1e0c09b4276f4d58d1df7b6bd93b.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/2b484da66c724d0d56a40849c109327408796a668b1381511b6e9e03baa48658.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/2cbbbce9b84df73330d4c42b82436ed881b3847628f2fbc346aa62e2859fd88c.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/2ec9b14431ed33ce6901bf9f27007be4e6e472709c99d6e22b50ce528e4b78ee.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/3b966db018f96be4a055d6ca205f0990d4d0b370e2980c17d8bca2c9a021819c.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/3eefb411b24c2b3ce053570ef50daccf154022f0e168be5ed0fec21394baf9f4.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/522b229e3cac219123f0d69673f5570e191c2d2a505dc65b312d336eae2279c0.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/52e45f17ba300580ea3fcc9f9228ccba194bb092b76f25e9255af311f8b05aab.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/5a0bc4631f936459afc1a3b1d3ec2420118b1f00e11f60ccac3e08088f3f27a8.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/611fa2c6c53b66039991d06dfb0473b5ab37fc63b4564e0f6e1718523768a045.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/6329bc76c406fe5eb305412383fbde5950f847bb5e43261f73f37622c365acb4.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/63c8e07687ea53a4f8a910ee5e42e0eb26cd1acbfbe820f3e3248a786ee51401.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/69b2d5001684174ec9da57e07914eed3eac4960018bceb6cbfa801d861301d7c.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/710c1acda69e561e30a933b98c6a56d50197868b15c21e2aad55ab6d46649eb6.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/720deca1fce5a1dc5056048fa8258fd138a82ea855f350b6613f104a73fb761f.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/76a23b92d26a499c57e61eea2b895fbc9771bd0849a72e66f8e633192017978b.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/858063f23b34dfe600254eb5afd85518b0002ec4b30b7386616c45600826e3b2.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/920b82c1c89124cf217109eeedbfcd3aae3b917be50c9dfb6bbb4ce26bdfd2e7.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/94d8b7aeb0076f2ce07db598cea0e87b13bc8d5614eb530b8d6e696c2daf6f88.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/9e917fe7022d01b2ccbe5cc0ce73d70bb72bee584ff293bad71bdff6695dee28.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/9f28fdb8399f1d0474f0aca86f1658dc94f5bf0c90f6146352de150692de8862.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/a0dfafa06b2bb7cba8cad41c482503f61944f759f4318139362602ef5cc47ccb.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/a3ff866acddf34917a7ee399e0e571a4dfd8ba66d5057db885f243e16a6eb17d.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/afb8084f539534cd594755ea2205ecd5bd1f62dddcfdf75a2eace59a28131278.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/b57b1438b77294c1f3f6cfce12ac487d8106c6f016975ba0aec94d98997e2e1e.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/b9917b0bf8e4d55ccbac1c334aa04d6ff3c5b6ed9e5d38b9ea2c687fa7d3f5a9.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/bbcc94b0149963d1d6f267ee1f4f03d3925b758392ce2f516c3fe8af0e0169fc.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/bdee44abeadc4abc17d41c52eb49c6e19a4b1a267b6e16876ce91bdeeebfc52d.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/beb112b70f4a56db95920a9e20efb6c97c37b68450716730217a9ee1a9ae92be.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/c88db97be0cdf440193b3995996e83510a04cb00048135485fc0e26d197e80b5.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/d49e5314d34310a62d01a03067ce1bec5da00abce84c5196aa9c6842fa79a430.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/d7664d18c4ddbad9c3a6a31b91f7c00fb0dde804608674a9860ee50f33e54708.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/d9072c318b819b7c90a0f7e9cc0b6413b4dbeb8e9859898e53d75ea882fcde99.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/db16a983bc08a05f0439cc61ca0840e49e1d8400eef678909f16c032a418a3d6.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/dc249829767b8abcbc3677e0b07b6d3ecbfdfe6d08cfe23a665eb33373a9aa9d.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/e242c583c2dbc2784f0fcf513523975f7d5df447e106c1c17e49e8578a6fc3ed.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/eac5f1296513e69e4b96f750ddccd4d0264e2bae4e4c449144e83274a48698d9.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/ed57202cb79649bb716400436590245547df241988fc7c8e1d85d132299542d2.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f125bf72e773a14cdaebd0c343e80adb909d12e317ee5c00cd4a57442fbe2c62.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f91af913d7f91c410719ab13136aaed3aaf0f8dda06652f25c42cb5231587398.json +1 -0
- xinference/web/ui/node_modules/.package-lock.json +67 -3
- xinference/web/ui/node_modules/@babel/runtime/package.json +592 -538
- xinference/web/ui/node_modules/html-parse-stringify/package.json +50 -0
- xinference/web/ui/node_modules/i18next/dist/esm/package.json +1 -0
- xinference/web/ui/node_modules/i18next/package.json +129 -0
- xinference/web/ui/node_modules/react-i18next/.eslintrc.json +74 -0
- xinference/web/ui/node_modules/react-i18next/dist/es/package.json +1 -0
- xinference/web/ui/node_modules/react-i18next/package.json +162 -0
- xinference/web/ui/node_modules/void-elements/package.json +34 -0
- xinference/web/ui/package-lock.json +69 -3
- xinference/web/ui/package.json +2 -0
- xinference/web/ui/src/locales/en.json +186 -0
- xinference/web/ui/src/locales/zh.json +186 -0
- {xinference-0.16.3.dist-info → xinference-1.2.1.dist-info}/METADATA +96 -36
- {xinference-0.16.3.dist-info → xinference-1.2.1.dist-info}/RECORD +335 -146
- {xinference-0.16.3.dist-info → xinference-1.2.1.dist-info}/WHEEL +1 -1
- xinference/thirdparty/cosyvoice/bin/export_trt.py +0 -8
- xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/api.py +0 -440
- xinference/thirdparty/fish_speech/tools/commons.py +0 -35
- xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/msgpack_api.py +0 -34
- xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/webui.py +0 -485
- xinference/web/ui/build/static/css/main.5061c4c3.css +0 -2
- xinference/web/ui/build/static/css/main.5061c4c3.css.map +0 -1
- xinference/web/ui/build/static/js/main.2f269bb3.js +0 -3
- xinference/web/ui/build/static/js/main.2f269bb3.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/07ce9e632e6aff24d7aa3ad8e48224433bbfeb0d633fca723453f1fcae0c9f1c.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/1130403f9e46f5738a23b45ac59b57de8f360c908c713e2c0670c2cce9bd367a.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/131091b25d26b17cdca187d7542a21475c211138d900cf667682260e76ef9463.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/1f269fb2a368363c1cb2237825f1dba093b6bdd8c44cc05954fd19ec2c1fff03.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/331312668fa8bd3d7401818f4a25fa98135d7f61371cd6bfff78b18cf4fbdd92.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/40f17338fc75ae095de7d2b4d8eae0d5ca0193a7e2bcece4ee745b22a7a2f4b7.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/4de9a6942c5f1749d6cbfdd54279699975f16016b182848bc253886f52ec2ec3.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/822586ed1077201b64b954f12f25e3f9b45678c1acbabe53d8af3ca82ca71f33.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/8d33354bd2100c8602afc3341f131a88cc36aaeecd5a4b365ed038514708e350.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/9375a35b05d56989b2755bf72161fa707c92f28569d33765a75f91a568fda6e9.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/a158a9ffa0c9b169aee53dd4a0c44501a596755b4e4f6ede7746d65a72e2a71f.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/c7bf40bab396765f67d0fed627ed3665890608b2d0edaa3e8cb7cfc96310db45.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/d6c643278a0b28320e6f33a60f5fb64c053997cbdc39a60e53ccc574688ade9e.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/e42b72d4cc1ea412ebecbb8d040dc6c6bfee462c33903c2f1f3facb602ad742e.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/e64b7e8cedcf43d4c95deba60ec1341855c887705805bb62431693118b870c69.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/f5039ddbeb815c51491a1989532006b96fc3ae49c6c60e3c097f875b4ae915ae.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/f72f011744c4649fabddca6f7a9327861ac0a315a89b1a2e62a39774e7863845.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/feabb04b4aa507102da0a64398a40818e878fd1df9b75dda8461b3e1e7ff3f11.json +0 -1
- /xinference/thirdparty/{cosyvoice/bin → f5_tts}/__init__.py +0 -0
- /xinference/thirdparty/{cosyvoice/flow → melo}/__init__.py +0 -0
- /xinference/thirdparty/{cosyvoice/hifigan → melo/text/english_utils}/__init__.py +0 -0
- /xinference/thirdparty/{cosyvoice/llm → melo/text/es_phonemizer}/__init__.py +0 -0
- /xinference/thirdparty/{fish_speech/fish_speech/configs → melo/text/fr_phonemizer}/__init__.py +0 -0
- /xinference/web/ui/build/static/js/{main.2f269bb3.js.LICENSE.txt → main.b0936c54.js.LICENSE.txt} +0 -0
- {xinference-0.16.3.dist-info → xinference-1.2.1.dist-info}/LICENSE +0 -0
- {xinference-0.16.3.dist-info → xinference-1.2.1.dist-info}/entry_points.txt +0 -0
- {xinference-0.16.3.dist-info → xinference-1.2.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from argparse import ArgumentParser
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
|
|
5
|
+
import pyrootutils
|
|
6
|
+
import torch
|
|
7
|
+
from loguru import logger
|
|
8
|
+
|
|
9
|
+
# Locate the project root marked by a ".project-root" file and, because
# pythonpath=True, presumably add it to the import path. This call sits
# before the `tools.*` imports below, which presumably depend on it to
# resolve when the script is run directly.
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
|
10
|
+
|
|
11
|
+
from tools.inference_engine import TTSInferenceEngine
|
|
12
|
+
from tools.llama.generate import launch_thread_safe_queue
|
|
13
|
+
from tools.schema import ServeTTSRequest
|
|
14
|
+
from tools.vqgan.inference import load_model as load_decoder_model
|
|
15
|
+
from tools.webui import build_app
|
|
16
|
+
from tools.webui.inference import get_inference_wrapper
|
|
17
|
+
|
|
18
|
+
# Configure the einx library through its environment variable at import
# time. Setting EINX_FILTER_TRACEBACK to "false" presumably disables einx's
# traceback filtering so errors keep their full stack (TODO confirm against
# the einx documentation).
os.environ.update(EINX_FILTER_TRACEBACK="false")
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def parse_args(argv=None):
    """Parse the command-line options for the Fish Speech web UI.

    Args:
        argv: Optional sequence of argument strings to parse instead of the
            process command line. The default ``None`` reads ``sys.argv[1:]``,
            so existing zero-argument callers behave exactly as before.

    Returns:
        An ``argparse.Namespace`` with ``llama_checkpoint_path`` and
        ``decoder_checkpoint_path`` (``Path`` objects; argparse applies
        ``type`` to string defaults too), ``decoder_config_name``, ``device``,
        ``half``, ``compile``, ``max_gradio_length`` and ``theme``.
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--llama-checkpoint-path",
        type=Path,
        default="checkpoints/fish-speech-1.5",
        help="Directory of the Llama (text-to-semantic) checkpoint.",
    )
    parser.add_argument(
        "--decoder-checkpoint-path",
        type=Path,
        default="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
        help="Path to the VQ-GAN decoder weights.",
    )
    parser.add_argument(
        "--decoder-config-name",
        type=str,
        default="firefly_gan_vq",
        help="Config name used to build the VQ-GAN decoder.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Torch device to run on; may be changed at startup if unavailable.",
    )
    parser.add_argument(
        "--half",
        action="store_true",
        help="Use float16 instead of the default bfloat16.",
    )
    parser.add_argument(
        "--compile",
        action="store_true",
        help="Enable model compilation in the Llama queue and inference engine.",
    )
    # NOTE(review): not read anywhere in this script; kept for CLI compatibility.
    parser.add_argument("--max-gradio-length", type=int, default=0)
    parser.add_argument(
        "--theme",
        type=str,
        default="light",
        help="Theme name passed to the web UI builder.",
    )

    return parser.parse_args(argv)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
if __name__ == "__main__":
|
|
45
|
+
args = parse_args()
|
|
46
|
+
args.precision = torch.half if args.half else torch.bfloat16
|
|
47
|
+
|
|
48
|
+
# Check if MPS or CUDA is available
|
|
49
|
+
if torch.backends.mps.is_available():
|
|
50
|
+
args.device = "mps"
|
|
51
|
+
logger.info("mps is available, running on mps.")
|
|
52
|
+
elif not torch.cuda.is_available():
|
|
53
|
+
logger.info("CUDA is not available, running on CPU.")
|
|
54
|
+
args.device = "cpu"
|
|
55
|
+
|
|
56
|
+
logger.info("Loading Llama model...")
|
|
57
|
+
llama_queue = launch_thread_safe_queue(
|
|
58
|
+
checkpoint_path=args.llama_checkpoint_path,
|
|
59
|
+
device=args.device,
|
|
60
|
+
precision=args.precision,
|
|
61
|
+
compile=args.compile,
|
|
62
|
+
)
|
|
63
|
+
|
|
64
|
+
logger.info("Loading VQ-GAN model...")
|
|
65
|
+
decoder_model = load_decoder_model(
|
|
66
|
+
config_name=args.decoder_config_name,
|
|
67
|
+
checkpoint_path=args.decoder_checkpoint_path,
|
|
68
|
+
device=args.device,
|
|
69
|
+
)
|
|
70
|
+
|
|
71
|
+
logger.info("Decoder model loaded, warming up...")
|
|
72
|
+
|
|
73
|
+
# Create the inference engine
|
|
74
|
+
inference_engine = TTSInferenceEngine(
|
|
75
|
+
llama_queue=llama_queue,
|
|
76
|
+
decoder_model=decoder_model,
|
|
77
|
+
compile=args.compile,
|
|
78
|
+
precision=args.precision,
|
|
79
|
+
)
|
|
80
|
+
|
|
81
|
+
# Dry run to check if the model is loaded correctly and avoid the first-time latency
|
|
82
|
+
list(
|
|
83
|
+
inference_engine.inference(
|
|
84
|
+
ServeTTSRequest(
|
|
85
|
+
text="Hello world.",
|
|
86
|
+
references=[],
|
|
87
|
+
reference_id=None,
|
|
88
|
+
max_new_tokens=1024,
|
|
89
|
+
chunk_length=200,
|
|
90
|
+
top_p=0.7,
|
|
91
|
+
repetition_penalty=1.5,
|
|
92
|
+
temperature=0.7,
|
|
93
|
+
format="wav",
|
|
94
|
+
)
|
|
95
|
+
)
|
|
96
|
+
)
|
|
97
|
+
|
|
98
|
+
logger.info("Warming up done, launching the web UI...")
|
|
99
|
+
|
|
100
|
+
# Get the inference function with the immutable arguments
|
|
101
|
+
inference_fct = get_inference_wrapper(inference_engine)
|
|
102
|
+
|
|
103
|
+
app = build_app(inference_fct, args.theme)
|
|
104
|
+
app.launch(show_api=True)
|
|
@@ -0,0 +1,170 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import queue
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Annotated, Literal
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
from pydantic import BaseModel, Field, conint, conlist
|
|
8
|
+
from pydantic.functional_validators import SkipValidation
|
|
9
|
+
|
|
10
|
+
from fish_speech.conversation import Message, TextPart, VQPart
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class ServeVQPart(BaseModel):
    """Message part holding VQ codes as nested integer lists.

    ``codes`` is wrapped in ``SkipValidation``, so pydantic stores the value
    without checking its elements.
    """

    type: Literal["vq"] = Field(default="vq")
    codes: SkipValidation[list[list[int]]] = Field(...)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class ServeTextPart(BaseModel):
    """Message part holding plain text."""

    type: Literal["text"] = Field(default="text")
    text: str = Field(...)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class ServeAudioPart(BaseModel):
    """Message part holding raw audio bytes."""

    type: Literal["audio"] = Field(default="audio")
    audio: bytes = Field(...)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@dataclass(init=True, repr=True, eq=True)
class ASRPackRequest:
    """A queued speech-recognition job.

    Attributes are positional in declaration order: the audio tensor, the
    queue on which the result is to be delivered, and the language code.
    """

    audio: torch.Tensor
    result_queue: queue.Queue
    language: str
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class ServeASRRequest(BaseModel):
    """Speech-recognition request over one or more audio clips."""

    # Each entry is expected to be uncompressed PCM float16 audio.
    audios: list[bytes] = Field(...)
    sample_rate: int = Field(default=44100)
    language: Literal["zh", "en", "ja", "auto"] = Field(default="auto")
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class ServeASRTranscription(BaseModel):
    """One transcription result: text, duration and a ``huge_gap`` flag."""

    text: str = Field(...)
    duration: float = Field(...)
    huge_gap: bool = Field(...)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class ServeASRSegment(BaseModel):
    """A transcribed text span with its start and end times."""

    text: str = Field(...)
    start: float = Field(...)
    end: float = Field(...)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class ServeTimedASRResponse(BaseModel):
    """Transcription with per-segment timing and the total duration."""

    text: str = Field(...)
    segments: list[ServeASRSegment] = Field(...)
    duration: float = Field(...)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class ServeASRResponse(BaseModel):
    """Batch speech-recognition response: a list of transcriptions."""

    transcriptions: list[ServeASRTranscription] = Field(...)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class ServeMessage(BaseModel):
    """A chat message as exchanged over the API: a role plus ordered parts."""

    role: Literal["system", "assistant", "user"] = Field(...)
    parts: list[ServeVQPart | ServeTextPart] = Field(...)

    def to_conversation_message(self):
        """Convert this message into a ``fish_speech.conversation.Message``.

        Assistant messages get ``modality = "voice"``. Text parts become
        ``TextPart`` objects; VQ parts become ``VQPart`` objects whose codes
        are converted to a ``torch.int`` tensor.

        Raises:
            ValueError: if a part is neither a text part nor a VQ part.
        """
        converted = Message(role=self.role, parts=[])
        if self.role == "assistant":
            converted.modality = "voice"

        for item in self.parts:
            if isinstance(item, ServeTextPart):
                piece = TextPart(text=item.text)
            elif isinstance(item, ServeVQPart):
                piece = VQPart(codes=torch.tensor(item.codes, dtype=torch.int))
            else:
                raise ValueError(f"Unsupported part type: {item}")
            converted.parts.append(piece)

        return converted
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class ServeChatRequest(BaseModel):
    """Chat request: the conversation so far plus sampling options."""

    # At least one message is required. The constraint is expressed with
    # Field: a `conlist(...)` type placed inside Annotated metadata is not a
    # constraint pydantic v2 recognises, so it was silently ignored.
    messages: Annotated[list[ServeMessage], Field(min_length=1)]
    max_new_tokens: int = 1024
    top_p: float = 0.7
    repetition_penalty: float = 1.2
    temperature: float = 0.7
    streaming: bool = False
    num_samples: int = 1
    early_stop_threshold: float = 1.0
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
class ServeVQGANEncodeRequest(BaseModel):
    """Request to encode audio clips into VQ tokens."""

    # Each entry is an encoded audio file (wav, mp3, etc.).
    audios: list[bytes] = Field(...)
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
class ServeVQGANEncodeResponse(BaseModel):
    """VQ tokens per input clip; element validation is skipped."""

    tokens: SkipValidation[list[list[list[int]]]] = Field(...)
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
class ServeVQGANDecodeRequest(BaseModel):
    """Request to decode VQ tokens back into audio; element validation is skipped."""

    tokens: SkipValidation[list[list[list[int]]]] = Field(...)
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
class ServeVQGANDecodeResponse(BaseModel):
    """Decoded audio, one entry per requested token sequence."""

    # Each entry is PCM float16 audio.
    audios: list[bytes] = Field(...)
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
class ServeForwardMessage(BaseModel):
    """A plain role/content message pair."""

    role: str = Field(...)
    content: str = Field(...)
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
class ServeResponse(BaseModel):
    """Complete (non-streaming) chat result.

    ``messages`` holds the generated messages, ``finish_reason`` why
    generation ended, and ``stats`` free-form timing/count values.
    """

    messages: list[ServeMessage] = Field(...)
    finish_reason: Literal["stop", "error"] | None = Field(default=None)
    stats: dict[str, int | float | str] = Field(default={})
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
class ServeStreamDelta(BaseModel):
    """Incremental update in a streamed response: an optional role and/or part."""

    role: Literal["system", "assistant", "user"] | None = Field(default=None)
    part: ServeVQPart | ServeTextPart | None = Field(default=None)
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
class ServeStreamResponse(BaseModel):
    """One streamed chunk, tagged with the sample it belongs to.

    Carries either a ``delta`` or, at the end, a ``finish_reason`` with
    ``stats``.
    """

    sample_id: int = Field(default=0)
    delta: ServeStreamDelta | None = Field(default=None)
    finish_reason: Literal["stop", "error"] | None = Field(default=None)
    stats: dict[str, int | float | str] | None = Field(default=None)
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
class ServeReferenceAudio(BaseModel):
    """Reference audio clip paired with its transcript, for in-context cloning."""

    audio: bytes = Field(...)
    text: str = Field(...)

    def __repr__(self) -> str:
        # Report only the byte length so large audio payloads stay out of logs.
        size = len(self.audio)
        return f"ServeReferenceAudio(text={self.text!r}, audio_size={size})"
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
class ServeTTSRequest(BaseModel):
    """Text-to-speech request.

    ``chunk_length``, ``top_p``, ``repetition_penalty`` and ``temperature``
    are range-checked in strict mode.
    """

    # Pydantic v2 configuration; replaces the deprecated nested `class Config`.
    # Allows arbitrary types for pytorch related types.
    model_config = {"arbitrary_types_allowed": True}

    text: str
    # The range is expressed with Field: a `conint(...)` type placed inside
    # Annotated metadata is not recognised by pydantic v2 and was silently
    # ignored, so 100..300 was never enforced.
    chunk_length: Annotated[int, Field(ge=100, le=300, strict=True)] = 200
    # Audio format
    format: Literal["wav", "pcm", "mp3"] = "wav"
    # Reference audios for in-context learning
    references: list[ServeReferenceAudio] = []
    # Reference id
    # For example, if you want to use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
    # just pass 7f92f8afb8ec43bf81429cc1c9199cb1
    reference_id: str | None = None
    seed: int | None = None
    use_memory_cache: Literal["on", "off"] = "off"
    # Normalize text for en & zh; this increases stability for numbers
    normalize: bool = True
    # not usually used below
    streaming: bool = False
    max_new_tokens: int = 1024
    top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
    repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
    temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
import struct
|
|
2
|
+
from functools import partial
|
|
3
|
+
|
|
4
|
+
import ormsgpack
|
|
5
|
+
|
|
6
|
+
from tools.server.agent.generate import generate_responses
|
|
7
|
+
from tools.server.agent.pre_generation_utils import prepare_messages
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def execute_request(input_queue, tokenizer, config, request, device):
    """Run one chat request end to end, yielding its responses.

    The request's messages are encoded into a prompt by ``prepare_messages``.
    Generation is then delegated to ``generate_responses``, whose items are
    re-yielded unchanged: ``ServeStreamResponse`` chunks when streaming,
    otherwise a final ``ServeResponse``.
    """
    encoded_prompt, end_token_id = prepare_messages(request, tokenizer, config)
    responses = generate_responses(
        input_queue, tokenizer, config, request, encoded_prompt, end_token_id, device
    )
    yield from responses
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def response_generator(req, llama_queue, tokenizer, config, device):
    """Non-streaming chat handler: return only the first produced item.

    For non-streaming requests the underlying generator yields just the
    aggregated final response, so its first item is the whole result.
    """
    return next(iter(execute_request(llama_queue, tokenizer, config, req, device)))
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
async def streaming_generator(req, llama_queue, tokenizer, config, device, json_mode):
    """Streaming chat handler: serialise each response chunk as it is produced.

    With ``json_mode`` every chunk is emitted as a server-sent-event line
    (``data: <json>\\n\\n``). Otherwise it is msgpack-encoded and prefixed with
    its byte length packed via ``struct`` format ``"I"`` (native byte order).

    NOTE(review): the iteration below runs a synchronous generator that waits
    on a blocking ``queue.get()``, so the event loop is blocked while waiting
    for each token step.
    """
    for chunk in execute_request(llama_queue, tokenizer, config, req, device):
        if not json_mode:
            payload = ormsgpack.packb(chunk, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)
            yield struct.pack("I", len(payload)) + payload
            continue
        payload = chunk.model_dump_json().encode("utf-8")
        yield b"data: " + payload + b"\n\n"
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def get_response_generator(
    llama_queue, tokenizer, config, req, device, json_mode
) -> partial:
    """Pick the chat handler that matches ``req.streaming``.

    Returns a zero-argument ``functools.partial``. Streaming requests are
    bound to ``streaming_generator`` (which also receives ``json_mode``);
    all other requests are bound to ``response_generator``.
    """
    if req.streaming:
        return partial(
            streaming_generator, req, llama_queue, tokenizer, config, device, json_mode
        )
    return partial(response_generator, req, llama_queue, tokenizer, config, device)
|
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
import time
|
|
2
|
+
|
|
3
|
+
from tools.schema import ServeMessage, ServeResponse, ServeStreamResponse
|
|
4
|
+
from tools.server.agent.generation_utils import (
|
|
5
|
+
initialize_decode_buffers,
|
|
6
|
+
process_response_tokens,
|
|
7
|
+
send_reset_buffer,
|
|
8
|
+
)
|
|
9
|
+
from tools.server.agent.pre_generation_utils import (
|
|
10
|
+
create_generation_request,
|
|
11
|
+
send_generation_request,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def generate_responses(
    input_queue, tokenizer, config, request, prompt, im_end_id, device
):
    """Drive one generation request and yield its responses.

    The request is enqueued for the model worker. Step results are then read
    from the reply queue (a blocking ``get``) until one of two things happens:
    the worker sends a bare ``"stop"``/``"error"`` marker, or every sample is
    marked finished. Streaming requests yield the per-step
    ``ServeStreamResponse`` chunks as they appear. In every case the items
    returned by ``finalize_response`` are yielded last.

    ``stats`` tracks ``start_time`` and ``tokens_count`` (steps read from the
    queue). The helpers add ``time_to_first_token``, ``total_time`` and
    ``total_tokens``.
    """
    started_at = time.time()
    stats = {"start_time": started_at, "tokens_count": 0}

    response_queue = send_generation_request(
        input_queue, create_generation_request(prompt, request, im_end_id, device)
    )
    decode_buffer, parts, finished = initialize_decode_buffers(request.num_samples)

    finish_reason = None
    while finish_reason is None:
        response = response_queue.get()

        if response in ["stop", "error"]:
            # Abnormal finish or error reported directly by the worker.
            finish_reason = response
        else:
            step_responses = process_response_tokens(
                response,
                tokenizer,
                config,
                request,
                decode_buffer,
                parts,
                finished,
                im_end_id,
                stats,
                started_at,
                stats["tokens_count"] == 0,
            )
            if request.streaming:
                yield from step_responses

            stats["tokens_count"] += 1

            if all(finished):
                finish_reason = "stop"

    yield from finalize_response(
        request, finished, decode_buffer, tokenizer, parts, stats, finish_reason
    )
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def finalize_response(
    request, finished, decode_buffer, tokenizer, parts, stats, finish_reason
):
    """Build the closing responses of a generation run.

    First, any text still buffered for each sample is flushed. Then
    ``stats["total_time"]`` (milliseconds since ``start_time``) and
    ``stats["total_tokens"]`` are recorded. Streaming requests get one
    terminal ``ServeStreamResponse`` per sample that is not yet finished.
    Non-streaming requests get a single ``ServeResponse`` holding one
    assistant message per sample.
    """
    responses = []
    for sample_id in range(request.num_samples):
        responses += send_reset_buffer(
            sample_id, decode_buffer, tokenizer, parts, request
        )

    stats["total_time"] = (time.time() - stats["start_time"]) * 1000
    stats["total_tokens"] = stats["tokens_count"]

    if not request.streaming:
        assistant_messages = [
            ServeMessage(role="assistant", parts=parts[i])
            for i in range(request.num_samples)
        ]
        responses.append(
            ServeResponse(
                messages=assistant_messages,
                finish_reason=finish_reason,
                stats=stats,
            )
        )
        return responses

    responses.extend(
        ServeStreamResponse(
            finish_reason=finish_reason, stats=stats, sample_id=sample_id
        )
        for sample_id in range(request.num_samples)
        if not finished[sample_id]
    )
    return responses
|
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
import time
|
|
2
|
+
|
|
3
|
+
from tools.schema import (
|
|
4
|
+
ServeStreamDelta,
|
|
5
|
+
ServeStreamResponse,
|
|
6
|
+
ServeTextPart,
|
|
7
|
+
ServeVQPart,
|
|
8
|
+
)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def initialize_decode_buffers(num_samples):
    """Create fresh per-sample decoding state.

    Returns:
        A tuple ``(decode_buffer, parts, finished)``:
        - ``decode_buffer``: independent empty lists of pending text tokens;
        - ``parts``: independent empty lists of accumulated message parts;
        - ``finished``: one ``False`` flag per sample.
    """
    samples = range(num_samples)
    return (
        [[] for _ in samples],
        [[] for _ in samples],
        [False] * num_samples,
    )
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request):
    """Flush the pending text tokens of one sample.

    The buffered token ids for ``sample_id`` are decoded into one
    ``ServeTextPart``, and the buffer is then cleared. When streaming, the part
    is returned inside a ``ServeStreamResponse`` tagged with ``sample_id``.
    Otherwise it is appended to ``parts[sample_id]`` and nothing is returned.

    Returns:
        A list with zero or one ``ServeStreamResponse`` (empty when nothing
        was buffered or when not streaming).
    """
    if len(decode_buffer[sample_id]) == 0:
        return []

    part = ServeTextPart(text=tokenizer.decode(decode_buffer[sample_id]))

    responses = []
    if request.streaming:
        # Tag the delta with its sample. The default sample_id=0 would
        # attribute text from other samples to sample 0 when num_samples > 1.
        responses.append(
            ServeStreamResponse(sample_id=sample_id, delta=ServeStreamDelta(part=part))
        )
    else:
        parts[sample_id].append(part)

    decode_buffer[sample_id] = []
    return responses
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def handle_semantic_tokens(tokens, config, sample_id, parts, request):
    """Turn one step of semantic tokens into VQ codes for a sample.

    Row 0 of ``tokens`` is dropped; the remaining rows are treated as
    per-codebook ids. When ``config.share_codebook_embeddings`` is false, row
    ``i`` has ``config.codebook_size * i`` subtracted from it.

    When streaming, the codes are returned as a single ``ServeStreamResponse``
    holding a ``ServeVQPart``. Otherwise they are merged into
    ``parts[sample_id]``: either by extending a trailing ``ServeVQPart``
    codebook by codebook, or by appending a new one, and ``[]`` is returned.

    NOTE(review): the merge path calls ``.item()`` on each codebook row, so it
    presumably assumes exactly one step per call (one value per row).
    """
    codes = tokens[1:].clone()

    if not config.share_codebook_embeddings:
        for idx in range(len(codes)):
            codes[idx] -= config.codebook_size * idx

    if request.streaming:
        return [
            ServeStreamResponse(
                sample_id=sample_id,
                delta=ServeStreamDelta(part=ServeVQPart(codes=codes.tolist())),
            )
        ]

    sample_parts = parts[sample_id]
    if sample_parts and isinstance(sample_parts[-1], ServeVQPart):
        trailing_codes = sample_parts[-1].codes
        for codebook_id, value in enumerate(codes):
            trailing_codes[codebook_id].append(value.item())
    else:
        sample_parts.append(ServeVQPart(codes=codes.tolist()))

    return []
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def process_response_tokens(
    response,
    tokenizer,
    config,
    request,
    decode_buffer,
    parts,
    finished,
    im_end_id,
    stats,
    start,
    is_first_token,
):
    """Dispatch one decoding step's tokens for every unfinished sample.

    For each sample, the token at ``tokens[0]`` is classified as one of three
    cases:
    - ``im_end_id``: the sample is marked finished, its text buffer is
      flushed and, when streaming, a ``finish_reason="stop"`` chunk is added;
    - within ``[semantic_begin_id, semantic_end_id]``: pending text is flushed,
      then the tokens are handled as VQ codes;
    - anything else: ``tokens[0, 0]`` is buffered as a text token.

    On the first step, ``stats["time_to_first_token"]`` is set in
    milliseconds since ``start``.

    Returns:
        The list of responses produced during this step.
    """
    responses = []
    for sample_id, tokens in enumerate(response):
        if finished[sample_id]:
            continue

        head = tokens[0]

        if head == im_end_id:
            finished[sample_id] = True
            responses += send_reset_buffer(
                sample_id, decode_buffer, tokenizer, parts, request
            )
            if request.streaming:
                responses.append(
                    ServeStreamResponse(
                        sample_id=sample_id,
                        finish_reason="stop",
                        stats=stats,
                    )
                )
            continue

        if tokenizer.semantic_begin_id <= head <= tokenizer.semantic_end_id:
            responses += send_reset_buffer(
                sample_id, decode_buffer, tokenizer, parts, request
            )
            responses += handle_semantic_tokens(
                tokens, config, sample_id, parts, request
            )
        else:
            decode_buffer[sample_id].append(tokens[0, 0])

        if is_first_token:
            stats["time_to_first_token"] = (time.time() - start) * 1000

    return responses
|
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
import queue
|
|
2
|
+
|
|
3
|
+
from fish_speech.conversation import Conversation, Message
|
|
4
|
+
from fish_speech.tokenizer import IM_END_TOKEN
|
|
5
|
+
from tools.llama.generate import GenerateRequest
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def prepare_messages(request, tokenizer, config):
    """Build and encode the inference prompt from a chat request.

    The request's messages are converted to conversation messages. The last
    one is then adjusted so the model knows how to continue:
    - ``user``: an open assistant message (voice modality, no end marker) is
      appended;
    - ``raw``: the message is left open at both ends and set to voice;
    - ``assistant``: the message is left open at the end.

    Returns:
        ``(prompt, im_end_id)``: the encoded conversation and the token id of
        ``IM_END_TOKEN``.

    Raises:
        ValueError: if there are no messages or the last role is not one of
        the three above.
    """
    messages = [m.to_conversation_message() for m in request.messages]
    if not messages:
        raise ValueError("At least one message is required")

    tail = messages[-1]
    role = tail.role
    if role == "user":
        messages.append(
            Message(role="assistant", parts=[], add_im_end=False, modality="voice")
        )
    elif role == "raw":
        tail.add_im_start = False
        tail.add_im_end = False
        tail.modality = "voice"
    elif role == "assistant":
        tail.add_im_end = False
    else:
        raise ValueError("The last message must be from the assistant, user or raw")

    prompt = Conversation(messages=messages).encode_for_inference(
        tokenizer=tokenizer, num_codebooks=config.num_codebooks
    )
    return prompt, tokenizer.get_token_id(IM_END_TOKEN)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def create_generation_request(prompt, request, im_end_id, device):
    """Pack the prompt and sampling options into a generation request dict.

    The prompt is moved to ``device`` with ``.to``. The remaining values are
    copied from the request, alongside ``im_end_id``.
    """
    return dict(
        prompt=prompt.to(device),
        max_new_tokens=request.max_new_tokens,
        im_end_id=im_end_id,
        temperature=request.temperature,
        top_p=request.top_p,
        repetition_penalty=request.repetition_penalty,
        num_samples=request.num_samples,
        early_stop_threshold=request.early_stop_threshold,
    )
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def send_generation_request(input_queue, req):
    """Enqueue ``req`` for the generation worker and return its reply queue.

    A new ``queue.Queue`` is created for the replies and sent along with the
    request inside a ``GenerateRequest``.
    """
    reply_queue = queue.Queue()
    input_queue.put(GenerateRequest(req, reply_queue))
    return reply_queue
|