xinference 1.0.1__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xinference/_compat.py +2 -0
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +77 -71
- xinference/core/chat_interface.py +6 -1
- xinference/core/model.py +79 -19
- xinference/core/supervisor.py +172 -10
- xinference/core/utils.py +12 -8
- xinference/core/worker.py +102 -4
- xinference/deploy/cmdline.py +3 -1
- xinference/deploy/test/test_cmdline.py +56 -0
- xinference/isolation.py +24 -0
- xinference/model/audio/core.py +16 -0
- xinference/model/audio/cosyvoice.py +39 -6
- xinference/model/audio/f5tts.py +200 -0
- xinference/model/audio/f5tts_mlx.py +260 -0
- xinference/model/audio/fish_speech.py +36 -111
- xinference/model/audio/melotts.py +110 -0
- xinference/model/audio/model_spec.json +99 -3
- xinference/model/audio/model_spec_modelscope.json +27 -0
- xinference/model/audio/utils.py +32 -0
- xinference/model/audio/whisper.py +35 -10
- xinference/model/embedding/core.py +203 -142
- xinference/model/embedding/model_spec.json +7 -0
- xinference/model/embedding/model_spec_modelscope.json +8 -0
- xinference/model/image/core.py +69 -1
- xinference/model/image/model_spec.json +145 -4
- xinference/model/image/model_spec_modelscope.json +150 -4
- xinference/model/image/stable_diffusion/core.py +45 -13
- xinference/model/llm/__init__.py +4 -2
- xinference/model/llm/llm_family.json +536 -53
- xinference/model/llm/llm_family.py +15 -36
- xinference/model/llm/llm_family_modelscope.json +454 -20
- xinference/model/llm/memory.py +1 -1
- xinference/model/llm/mlx/core.py +248 -52
- xinference/model/llm/sglang/core.py +1 -0
- xinference/model/llm/transformers/chatglm.py +9 -5
- xinference/model/llm/transformers/cogagent.py +272 -0
- xinference/model/llm/transformers/core.py +2 -0
- xinference/model/llm/transformers/qwen2_vl.py +12 -1
- xinference/model/llm/transformers/utils.py +16 -8
- xinference/model/llm/utils.py +36 -4
- xinference/model/llm/vllm/core.py +53 -10
- xinference/model/llm/vllm/xavier/__init__.py +13 -0
- xinference/model/llm/vllm/xavier/allocator.py +74 -0
- xinference/model/llm/vllm/xavier/block.py +111 -0
- xinference/model/llm/vllm/xavier/block_manager.py +71 -0
- xinference/model/llm/vllm/xavier/block_tracker.py +129 -0
- xinference/model/llm/vllm/xavier/collective.py +74 -0
- xinference/model/llm/vllm/xavier/collective_manager.py +147 -0
- xinference/model/llm/vllm/xavier/engine.py +247 -0
- xinference/model/llm/vllm/xavier/executor.py +134 -0
- xinference/model/llm/vllm/xavier/scheduler.py +438 -0
- xinference/model/llm/vllm/xavier/test/__init__.py +13 -0
- xinference/model/llm/vllm/xavier/test/test_xavier.py +147 -0
- xinference/model/llm/vllm/xavier/transfer.py +319 -0
- xinference/model/video/diffusers.py +14 -0
- xinference/model/video/model_spec.json +15 -0
- xinference/model/video/model_spec_modelscope.json +16 -0
- xinference/thirdparty/cosyvoice/bin/average_model.py +92 -0
- xinference/thirdparty/cosyvoice/bin/export_jit.py +12 -2
- xinference/thirdparty/cosyvoice/bin/export_onnx.py +112 -0
- xinference/thirdparty/cosyvoice/bin/export_trt.sh +9 -0
- xinference/thirdparty/cosyvoice/bin/inference.py +5 -7
- xinference/thirdparty/cosyvoice/bin/spk2info.pt +0 -0
- xinference/thirdparty/cosyvoice/bin/train.py +42 -8
- xinference/thirdparty/cosyvoice/cli/cosyvoice.py +96 -25
- xinference/thirdparty/cosyvoice/cli/frontend.py +77 -30
- xinference/thirdparty/cosyvoice/cli/model.py +330 -80
- xinference/thirdparty/cosyvoice/dataset/dataset.py +6 -2
- xinference/thirdparty/cosyvoice/dataset/processor.py +76 -14
- xinference/thirdparty/cosyvoice/flow/decoder.py +92 -13
- xinference/thirdparty/cosyvoice/flow/flow.py +99 -9
- xinference/thirdparty/cosyvoice/flow/flow_matching.py +110 -13
- xinference/thirdparty/cosyvoice/flow/length_regulator.py +5 -4
- xinference/thirdparty/cosyvoice/hifigan/discriminator.py +140 -0
- xinference/thirdparty/cosyvoice/hifigan/generator.py +58 -42
- xinference/thirdparty/cosyvoice/hifigan/hifigan.py +67 -0
- xinference/thirdparty/cosyvoice/llm/llm.py +139 -6
- xinference/thirdparty/cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
- xinference/thirdparty/cosyvoice/tokenizer/tokenizer.py +279 -0
- xinference/thirdparty/cosyvoice/transformer/embedding.py +2 -2
- xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +7 -7
- xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +318 -0
- xinference/thirdparty/cosyvoice/utils/common.py +28 -1
- xinference/thirdparty/cosyvoice/utils/executor.py +69 -7
- xinference/thirdparty/cosyvoice/utils/file_utils.py +2 -12
- xinference/thirdparty/cosyvoice/utils/frontend_utils.py +9 -5
- xinference/thirdparty/cosyvoice/utils/losses.py +20 -0
- xinference/thirdparty/cosyvoice/utils/scheduler.py +1 -2
- xinference/thirdparty/cosyvoice/utils/train_utils.py +101 -45
- xinference/thirdparty/f5_tts/api.py +166 -0
- xinference/thirdparty/f5_tts/configs/E2TTS_Base_train.yaml +44 -0
- xinference/thirdparty/f5_tts/configs/E2TTS_Small_train.yaml +44 -0
- xinference/thirdparty/f5_tts/configs/F5TTS_Base_train.yaml +46 -0
- xinference/thirdparty/f5_tts/configs/F5TTS_Small_train.yaml +46 -0
- xinference/thirdparty/f5_tts/eval/README.md +49 -0
- xinference/thirdparty/f5_tts/eval/ecapa_tdnn.py +330 -0
- xinference/thirdparty/f5_tts/eval/eval_infer_batch.py +207 -0
- xinference/thirdparty/f5_tts/eval/eval_infer_batch.sh +13 -0
- xinference/thirdparty/f5_tts/eval/eval_librispeech_test_clean.py +84 -0
- xinference/thirdparty/f5_tts/eval/eval_seedtts_testset.py +84 -0
- xinference/thirdparty/f5_tts/eval/utils_eval.py +405 -0
- xinference/thirdparty/f5_tts/infer/README.md +191 -0
- xinference/thirdparty/f5_tts/infer/SHARED.md +74 -0
- xinference/thirdparty/f5_tts/infer/examples/basic/basic.toml +11 -0
- xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_en.wav +0 -0
- xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_zh.wav +0 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/country.flac +0 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/main.flac +0 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/story.toml +19 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/story.txt +1 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/town.flac +0 -0
- xinference/thirdparty/f5_tts/infer/examples/vocab.txt +2545 -0
- xinference/thirdparty/f5_tts/infer/infer_cli.py +226 -0
- xinference/thirdparty/f5_tts/infer/infer_gradio.py +851 -0
- xinference/thirdparty/f5_tts/infer/speech_edit.py +193 -0
- xinference/thirdparty/f5_tts/infer/utils_infer.py +538 -0
- xinference/thirdparty/f5_tts/model/__init__.py +10 -0
- xinference/thirdparty/f5_tts/model/backbones/README.md +20 -0
- xinference/thirdparty/f5_tts/model/backbones/dit.py +163 -0
- xinference/thirdparty/f5_tts/model/backbones/mmdit.py +146 -0
- xinference/thirdparty/f5_tts/model/backbones/unett.py +219 -0
- xinference/thirdparty/f5_tts/model/cfm.py +285 -0
- xinference/thirdparty/f5_tts/model/dataset.py +319 -0
- xinference/thirdparty/f5_tts/model/modules.py +658 -0
- xinference/thirdparty/f5_tts/model/trainer.py +366 -0
- xinference/thirdparty/f5_tts/model/utils.py +185 -0
- xinference/thirdparty/f5_tts/scripts/count_max_epoch.py +33 -0
- xinference/thirdparty/f5_tts/scripts/count_params_gflops.py +39 -0
- xinference/thirdparty/f5_tts/socket_server.py +159 -0
- xinference/thirdparty/f5_tts/train/README.md +77 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_csv_wavs.py +139 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_emilia.py +230 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_libritts.py +92 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_ljspeech.py +65 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_wenetspeech4tts.py +125 -0
- xinference/thirdparty/f5_tts/train/finetune_cli.py +174 -0
- xinference/thirdparty/f5_tts/train/finetune_gradio.py +1846 -0
- xinference/thirdparty/f5_tts/train/train.py +75 -0
- xinference/thirdparty/fish_speech/fish_speech/conversation.py +94 -83
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +63 -20
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +1 -26
- xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +1 -1
- xinference/thirdparty/fish_speech/fish_speech/tokenizer.py +152 -0
- xinference/thirdparty/fish_speech/fish_speech/train.py +2 -2
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1 -1
- xinference/thirdparty/fish_speech/tools/{post_api.py → api_client.py} +7 -13
- xinference/thirdparty/fish_speech/tools/api_server.py +98 -0
- xinference/thirdparty/fish_speech/tools/download_models.py +5 -5
- xinference/thirdparty/fish_speech/tools/fish_e2e.py +2 -2
- xinference/thirdparty/fish_speech/tools/inference_engine/__init__.py +192 -0
- xinference/thirdparty/fish_speech/tools/inference_engine/reference_loader.py +125 -0
- xinference/thirdparty/fish_speech/tools/inference_engine/utils.py +39 -0
- xinference/thirdparty/fish_speech/tools/inference_engine/vq_manager.py +57 -0
- xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +2 -2
- xinference/thirdparty/fish_speech/tools/llama/generate.py +117 -89
- xinference/thirdparty/fish_speech/tools/run_webui.py +104 -0
- xinference/thirdparty/fish_speech/tools/schema.py +11 -28
- xinference/thirdparty/fish_speech/tools/server/agent/__init__.py +57 -0
- xinference/thirdparty/fish_speech/tools/server/agent/generate.py +119 -0
- xinference/thirdparty/fish_speech/tools/server/agent/generation_utils.py +122 -0
- xinference/thirdparty/fish_speech/tools/server/agent/pre_generation_utils.py +72 -0
- xinference/thirdparty/fish_speech/tools/server/api_utils.py +75 -0
- xinference/thirdparty/fish_speech/tools/server/exception_handler.py +27 -0
- xinference/thirdparty/fish_speech/tools/server/inference.py +45 -0
- xinference/thirdparty/fish_speech/tools/server/model_manager.py +122 -0
- xinference/thirdparty/fish_speech/tools/server/model_utils.py +129 -0
- xinference/thirdparty/fish_speech/tools/server/views.py +246 -0
- xinference/thirdparty/fish_speech/tools/webui/__init__.py +173 -0
- xinference/thirdparty/fish_speech/tools/webui/inference.py +91 -0
- xinference/thirdparty/fish_speech/tools/webui/variables.py +14 -0
- xinference/thirdparty/matcha/utils/utils.py +2 -2
- xinference/thirdparty/melo/api.py +135 -0
- xinference/thirdparty/melo/app.py +61 -0
- xinference/thirdparty/melo/attentions.py +459 -0
- xinference/thirdparty/melo/commons.py +160 -0
- xinference/thirdparty/melo/configs/config.json +94 -0
- xinference/thirdparty/melo/data/example/metadata.list +20 -0
- xinference/thirdparty/melo/data_utils.py +413 -0
- xinference/thirdparty/melo/download_utils.py +67 -0
- xinference/thirdparty/melo/infer.py +25 -0
- xinference/thirdparty/melo/init_downloads.py +14 -0
- xinference/thirdparty/melo/losses.py +58 -0
- xinference/thirdparty/melo/main.py +36 -0
- xinference/thirdparty/melo/mel_processing.py +174 -0
- xinference/thirdparty/melo/models.py +1030 -0
- xinference/thirdparty/melo/modules.py +598 -0
- xinference/thirdparty/melo/monotonic_align/__init__.py +16 -0
- xinference/thirdparty/melo/monotonic_align/core.py +46 -0
- xinference/thirdparty/melo/preprocess_text.py +135 -0
- xinference/thirdparty/melo/split_utils.py +174 -0
- xinference/thirdparty/melo/text/__init__.py +35 -0
- xinference/thirdparty/melo/text/chinese.py +199 -0
- xinference/thirdparty/melo/text/chinese_bert.py +107 -0
- xinference/thirdparty/melo/text/chinese_mix.py +253 -0
- xinference/thirdparty/melo/text/cleaner.py +36 -0
- xinference/thirdparty/melo/text/cleaner_multiling.py +110 -0
- xinference/thirdparty/melo/text/cmudict.rep +129530 -0
- xinference/thirdparty/melo/text/cmudict_cache.pickle +0 -0
- xinference/thirdparty/melo/text/english.py +284 -0
- xinference/thirdparty/melo/text/english_bert.py +39 -0
- xinference/thirdparty/melo/text/english_utils/abbreviations.py +35 -0
- xinference/thirdparty/melo/text/english_utils/number_norm.py +97 -0
- xinference/thirdparty/melo/text/english_utils/time_norm.py +47 -0
- xinference/thirdparty/melo/text/es_phonemizer/base.py +140 -0
- xinference/thirdparty/melo/text/es_phonemizer/cleaner.py +109 -0
- xinference/thirdparty/melo/text/es_phonemizer/es_symbols.json +79 -0
- xinference/thirdparty/melo/text/es_phonemizer/es_symbols.txt +1 -0
- xinference/thirdparty/melo/text/es_phonemizer/es_symbols_v2.json +83 -0
- xinference/thirdparty/melo/text/es_phonemizer/es_to_ipa.py +12 -0
- xinference/thirdparty/melo/text/es_phonemizer/example_ipa.txt +400 -0
- xinference/thirdparty/melo/text/es_phonemizer/gruut_wrapper.py +253 -0
- xinference/thirdparty/melo/text/es_phonemizer/punctuation.py +174 -0
- xinference/thirdparty/melo/text/es_phonemizer/spanish_symbols.txt +1 -0
- xinference/thirdparty/melo/text/es_phonemizer/test.ipynb +124 -0
- xinference/thirdparty/melo/text/fr_phonemizer/base.py +140 -0
- xinference/thirdparty/melo/text/fr_phonemizer/cleaner.py +122 -0
- xinference/thirdparty/melo/text/fr_phonemizer/en_symbols.json +78 -0
- xinference/thirdparty/melo/text/fr_phonemizer/example_ipa.txt +1 -0
- xinference/thirdparty/melo/text/fr_phonemizer/fr_symbols.json +89 -0
- xinference/thirdparty/melo/text/fr_phonemizer/fr_to_ipa.py +30 -0
- xinference/thirdparty/melo/text/fr_phonemizer/french_abbreviations.py +48 -0
- xinference/thirdparty/melo/text/fr_phonemizer/french_symbols.txt +1 -0
- xinference/thirdparty/melo/text/fr_phonemizer/gruut_wrapper.py +258 -0
- xinference/thirdparty/melo/text/fr_phonemizer/punctuation.py +172 -0
- xinference/thirdparty/melo/text/french.py +94 -0
- xinference/thirdparty/melo/text/french_bert.py +39 -0
- xinference/thirdparty/melo/text/japanese.py +647 -0
- xinference/thirdparty/melo/text/japanese_bert.py +49 -0
- xinference/thirdparty/melo/text/ko_dictionary.py +44 -0
- xinference/thirdparty/melo/text/korean.py +192 -0
- xinference/thirdparty/melo/text/opencpop-strict.txt +429 -0
- xinference/thirdparty/melo/text/spanish.py +122 -0
- xinference/thirdparty/melo/text/spanish_bert.py +39 -0
- xinference/thirdparty/melo/text/symbols.py +290 -0
- xinference/thirdparty/melo/text/tone_sandhi.py +769 -0
- xinference/thirdparty/melo/train.py +635 -0
- xinference/thirdparty/melo/train.sh +19 -0
- xinference/thirdparty/melo/transforms.py +209 -0
- xinference/thirdparty/melo/utils.py +424 -0
- xinference/types.py +15 -0
- xinference/web/ui/build/asset-manifest.json +6 -6
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/css/main.51a587ff.css +2 -0
- xinference/web/ui/build/static/css/main.51a587ff.css.map +1 -0
- xinference/web/ui/build/static/js/main.b0936c54.js +3 -0
- xinference/web/ui/build/static/js/main.b0936c54.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/03c4052f1b91f6ba0c5389bdcf49c43319b4076c08e4b8585dab312538ae290a.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/1786b83003b8e9605a0f5f855a185d4d16e38fc893dfb326a2a9cca206b4240a.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/17cbc181dd674b9150b80c73ed6a82656de0082d857f6e5f66d9716129ac0b38.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/185ceb8872d562e032b47e79df6a45670e06345b8ed70aad1a131e0476783c5c.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/26b8c9f34b0bed789b3a833767672e39302d1e0c09b4276f4d58d1df7b6bd93b.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/2b484da66c724d0d56a40849c109327408796a668b1381511b6e9e03baa48658.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/2cbbbce9b84df73330d4c42b82436ed881b3847628f2fbc346aa62e2859fd88c.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/2ec9b14431ed33ce6901bf9f27007be4e6e472709c99d6e22b50ce528e4b78ee.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/3b966db018f96be4a055d6ca205f0990d4d0b370e2980c17d8bca2c9a021819c.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/3eefb411b24c2b3ce053570ef50daccf154022f0e168be5ed0fec21394baf9f4.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/522b229e3cac219123f0d69673f5570e191c2d2a505dc65b312d336eae2279c0.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/52e45f17ba300580ea3fcc9f9228ccba194bb092b76f25e9255af311f8b05aab.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/5a0bc4631f936459afc1a3b1d3ec2420118b1f00e11f60ccac3e08088f3f27a8.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/611fa2c6c53b66039991d06dfb0473b5ab37fc63b4564e0f6e1718523768a045.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/6329bc76c406fe5eb305412383fbde5950f847bb5e43261f73f37622c365acb4.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/63c8e07687ea53a4f8a910ee5e42e0eb26cd1acbfbe820f3e3248a786ee51401.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/69b2d5001684174ec9da57e07914eed3eac4960018bceb6cbfa801d861301d7c.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/710c1acda69e561e30a933b98c6a56d50197868b15c21e2aad55ab6d46649eb6.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/720deca1fce5a1dc5056048fa8258fd138a82ea855f350b6613f104a73fb761f.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/76a23b92d26a499c57e61eea2b895fbc9771bd0849a72e66f8e633192017978b.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/858063f23b34dfe600254eb5afd85518b0002ec4b30b7386616c45600826e3b2.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/920b82c1c89124cf217109eeedbfcd3aae3b917be50c9dfb6bbb4ce26bdfd2e7.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/94d8b7aeb0076f2ce07db598cea0e87b13bc8d5614eb530b8d6e696c2daf6f88.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/9e917fe7022d01b2ccbe5cc0ce73d70bb72bee584ff293bad71bdff6695dee28.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/9f28fdb8399f1d0474f0aca86f1658dc94f5bf0c90f6146352de150692de8862.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/a0dfafa06b2bb7cba8cad41c482503f61944f759f4318139362602ef5cc47ccb.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/a3ff866acddf34917a7ee399e0e571a4dfd8ba66d5057db885f243e16a6eb17d.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/afb8084f539534cd594755ea2205ecd5bd1f62dddcfdf75a2eace59a28131278.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/b57b1438b77294c1f3f6cfce12ac487d8106c6f016975ba0aec94d98997e2e1e.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/b9917b0bf8e4d55ccbac1c334aa04d6ff3c5b6ed9e5d38b9ea2c687fa7d3f5a9.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/bbcc94b0149963d1d6f267ee1f4f03d3925b758392ce2f516c3fe8af0e0169fc.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/bdee44abeadc4abc17d41c52eb49c6e19a4b1a267b6e16876ce91bdeeebfc52d.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/beb112b70f4a56db95920a9e20efb6c97c37b68450716730217a9ee1a9ae92be.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/c88db97be0cdf440193b3995996e83510a04cb00048135485fc0e26d197e80b5.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/d49e5314d34310a62d01a03067ce1bec5da00abce84c5196aa9c6842fa79a430.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/d7664d18c4ddbad9c3a6a31b91f7c00fb0dde804608674a9860ee50f33e54708.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/d9072c318b819b7c90a0f7e9cc0b6413b4dbeb8e9859898e53d75ea882fcde99.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/db16a983bc08a05f0439cc61ca0840e49e1d8400eef678909f16c032a418a3d6.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/dc249829767b8abcbc3677e0b07b6d3ecbfdfe6d08cfe23a665eb33373a9aa9d.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/e242c583c2dbc2784f0fcf513523975f7d5df447e106c1c17e49e8578a6fc3ed.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/eac5f1296513e69e4b96f750ddccd4d0264e2bae4e4c449144e83274a48698d9.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/ed57202cb79649bb716400436590245547df241988fc7c8e1d85d132299542d2.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f125bf72e773a14cdaebd0c343e80adb909d12e317ee5c00cd4a57442fbe2c62.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f91af913d7f91c410719ab13136aaed3aaf0f8dda06652f25c42cb5231587398.json +1 -0
- xinference/web/ui/node_modules/.package-lock.json +67 -3
- xinference/web/ui/node_modules/@babel/runtime/package.json +592 -538
- xinference/web/ui/node_modules/html-parse-stringify/package.json +50 -0
- xinference/web/ui/node_modules/i18next/dist/esm/package.json +1 -0
- xinference/web/ui/node_modules/i18next/package.json +129 -0
- xinference/web/ui/node_modules/react-i18next/.eslintrc.json +74 -0
- xinference/web/ui/node_modules/react-i18next/dist/es/package.json +1 -0
- xinference/web/ui/node_modules/react-i18next/package.json +162 -0
- xinference/web/ui/node_modules/void-elements/package.json +34 -0
- xinference/web/ui/package-lock.json +69 -3
- xinference/web/ui/package.json +2 -0
- xinference/web/ui/src/locales/en.json +186 -0
- xinference/web/ui/src/locales/zh.json +186 -0
- {xinference-1.0.1.dist-info → xinference-1.2.1.dist-info}/METADATA +68 -32
- {xinference-1.0.1.dist-info → xinference-1.2.1.dist-info}/RECORD +316 -122
- xinference/thirdparty/cosyvoice/bin/export_trt.py +0 -8
- xinference/thirdparty/fish_speech/tools/api.py +0 -943
- xinference/thirdparty/fish_speech/tools/msgpack_api.py +0 -95
- xinference/thirdparty/fish_speech/tools/webui.py +0 -548
- xinference/web/ui/build/static/css/main.5061c4c3.css +0 -2
- xinference/web/ui/build/static/css/main.5061c4c3.css.map +0 -1
- xinference/web/ui/build/static/js/main.2f269bb3.js +0 -3
- xinference/web/ui/build/static/js/main.2f269bb3.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/07ce9e632e6aff24d7aa3ad8e48224433bbfeb0d633fca723453f1fcae0c9f1c.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/1130403f9e46f5738a23b45ac59b57de8f360c908c713e2c0670c2cce9bd367a.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/131091b25d26b17cdca187d7542a21475c211138d900cf667682260e76ef9463.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/1f269fb2a368363c1cb2237825f1dba093b6bdd8c44cc05954fd19ec2c1fff03.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/331312668fa8bd3d7401818f4a25fa98135d7f61371cd6bfff78b18cf4fbdd92.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/40f17338fc75ae095de7d2b4d8eae0d5ca0193a7e2bcece4ee745b22a7a2f4b7.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/4de9a6942c5f1749d6cbfdd54279699975f16016b182848bc253886f52ec2ec3.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/822586ed1077201b64b954f12f25e3f9b45678c1acbabe53d8af3ca82ca71f33.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/8d33354bd2100c8602afc3341f131a88cc36aaeecd5a4b365ed038514708e350.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/9375a35b05d56989b2755bf72161fa707c92f28569d33765a75f91a568fda6e9.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/a158a9ffa0c9b169aee53dd4a0c44501a596755b4e4f6ede7746d65a72e2a71f.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/c7bf40bab396765f67d0fed627ed3665890608b2d0edaa3e8cb7cfc96310db45.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/d6c643278a0b28320e6f33a60f5fb64c053997cbdc39a60e53ccc574688ade9e.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/e42b72d4cc1ea412ebecbb8d040dc6c6bfee462c33903c2f1f3facb602ad742e.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/e64b7e8cedcf43d4c95deba60ec1341855c887705805bb62431693118b870c69.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/f5039ddbeb815c51491a1989532006b96fc3ae49c6c60e3c097f875b4ae915ae.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/f72f011744c4649fabddca6f7a9327861ac0a315a89b1a2e62a39774e7863845.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/feabb04b4aa507102da0a64398a40818e878fd1df9b75dda8461b3e1e7ff3f11.json +0 -1
- /xinference/thirdparty/{cosyvoice/bin → f5_tts}/__init__.py +0 -0
- /xinference/thirdparty/{cosyvoice/flow → melo}/__init__.py +0 -0
- /xinference/thirdparty/{cosyvoice/hifigan → melo/text/english_utils}/__init__.py +0 -0
- /xinference/thirdparty/{cosyvoice/llm → melo/text/es_phonemizer}/__init__.py +0 -0
- /xinference/thirdparty/{fish_speech/tools → melo/text/fr_phonemizer}/__init__.py +0 -0
- /xinference/web/ui/build/static/js/{main.2f269bb3.js.LICENSE.txt → main.b0936c54.js.LICENSE.txt} +0 -0
- {xinference-1.0.1.dist-info → xinference-1.2.1.dist-info}/LICENSE +0 -0
- {xinference-1.0.1.dist-info → xinference-1.2.1.dist-info}/WHEEL +0 -0
- {xinference-1.0.1.dist-info → xinference-1.2.1.dist-info}/entry_points.txt +0 -0
- {xinference-1.0.1.dist-info → xinference-1.2.1.dist-info}/top_level.txt +0 -0
xinference/thirdparty/fish_speech/tools/inference_engine/utils.py

```diff
@@ -0,0 +1,39 @@
+import io
+import wave
+from dataclasses import dataclass
+from typing import Literal, Optional, Tuple
+
+import numpy as np
+
+from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
+
+
+@dataclass
+class InferenceResult:
+    code: Literal["header", "segment", "error", "final"]
+    audio: Optional[Tuple[int, np.ndarray | bytes]]
+    error: Optional[Exception]
+
+
+def normalize_text(user_input: str, use_normalization: bool) -> str:
+    """Normalize user input text if needed."""
+    if use_normalization:
+        return ChnNormedText(raw_text=user_input).normalize()
+    else:
+        return user_input
+
+
+def wav_chunk_header(
+    sample_rate: int = 44100, bit_depth: int = 16, channels: int = 1
+) -> bytes:
+    buffer = io.BytesIO()
+
+    with wave.open(buffer, "wb") as wav_file:
+        wav_file.setnchannels(channels)
+        wav_file.setsampwidth(bit_depth // 8)
+        wav_file.setframerate(sample_rate)
+
+    wav_header_bytes = buffer.getvalue()
+    buffer.close()
+
+    return wav_header_bytes
```
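`wav_chunk_header` writes only the RIFF/WAV header, which lets the engine stream raw PCM chunks after it without knowing the total audio length up front. A minimal sketch of that streaming pattern, assuming the vendored tree's `tools.` package layout and a hypothetical `pcm_chunks` iterable of 16-bit mono PCM bytes:

```python
from tools.inference_engine.utils import wav_chunk_header  # path per the vendored layout

def stream_wav(pcm_chunks, sample_rate: int = 44100):
    # First yield the WAV header, then the raw PCM payload chunks.
    # `pcm_chunks` is a hypothetical iterable of 16-bit mono PCM bytes.
    yield wav_chunk_header(sample_rate=sample_rate, bit_depth=16, channels=1)
    for chunk in pcm_chunks:
        yield chunk
```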
xinference/thirdparty/fish_speech/tools/inference_engine/vq_manager.py

```diff
@@ -0,0 +1,57 @@
+from typing import Callable
+
+import torch
+from loguru import logger
+
+from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
+
+
+class VQManager:
+
+    def __init__(self):
+        # Make Pylance happy (attribut/method not defined...)
+        self.decoder_model: FireflyArchitecture
+        self.load_audio: Callable
+
+    def decode_vq_tokens(self, codes):
+        feature_lengths = torch.tensor(
+            [codes.shape[1]], device=self.decoder_model.device
+        )
+        logger.info(f"VQ features: {codes.shape}")
+
+        if isinstance(self.decoder_model, FireflyArchitecture):
+            return self.decoder_model.decode(
+                indices=codes[None],
+                feature_lengths=feature_lengths,
+            )[0].squeeze()
+
+        raise ValueError(f"Unknown model type: {type(self.decoder_model)}")
+
+    def encode_reference(self, reference_audio, enable_reference_audio):
+        if enable_reference_audio and reference_audio is not None:
+            # Load audios, and prepare basic info here
+            reference_audio_content = self.load_audio(
+                reference_audio, self.decoder_model.spec_transform.sample_rate
+            )
+
+            audios = torch.from_numpy(reference_audio_content).to(
+                self.decoder_model.device
+            )[None, None, :]
+            audio_lengths = torch.tensor(
+                [audios.shape[2]], device=self.decoder_model.device, dtype=torch.long
+            )
+            logger.info(
+                f"Loaded audio with {audios.shape[2] / self.decoder_model.spec_transform.sample_rate:.2f} seconds"
+            )
+
+            # VQ Encoder
+            if isinstance(self.decoder_model, FireflyArchitecture):
+                prompt_tokens = self.decoder_model.encode(audios, audio_lengths)[0][0]
+                logger.info(f"Encoded prompt: {prompt_tokens.shape}")
+            else:
+                raise ValueError(f"Unknown model type: {type(self.decoder_model)}")
+        else:
+            prompt_tokens = None
+            logger.info("No reference audio provided")
+
+        return prompt_tokens
```
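`VQManager` only annotates `decoder_model` and `load_audio` in `__init__` without assigning them, so it is written as a mixin: the composing class must supply both. A hedged sketch of that contract; the `MyEngine` name and constructor are illustrative, not the package's actual wiring:

```python
from tools.inference_engine.vq_manager import VQManager  # path per the vendored layout

class MyEngine(VQManager):  # hypothetical host class
    def __init__(self, decoder_model, load_audio):
        super().__init__()
        # Provide the attributes VQManager declares but never assigns.
        self.decoder_model = decoder_model
        self.load_audio = load_audio

# engine = MyEngine(decoder_model=..., load_audio=...)
# prompt_tokens = engine.encode_reference(ref_audio, enable_reference_audio=True)
```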
xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py

```diff
@@ -1,11 +1,11 @@
-
+import pyrootutils
 import torch
 import torch.nn.functional as F
 from matplotlib import pyplot as plt
 from transformers import AutoTokenizer
 
 # register eval resolver and root
-
+pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
 
 from torch.utils.data import DataLoader
 
```
xinference/thirdparty/fish_speech/tools/llama/generate.py

```diff
@@ -17,9 +17,16 @@ from loguru import logger
 from tqdm import tqdm
 from transformers import AutoTokenizer
 
-from fish_speech.conversation import
+from fish_speech.conversation import (
+    CODEBOOK_PAD_TOKEN_ID,
+    Conversation,
+    Message,
+    TextPart,
+    VQPart,
+)
 from fish_speech.models.text2semantic.llama import BaseModelArgs
 from fish_speech.text import clean_text, split_text
+from fish_speech.tokenizer import IM_END_TOKEN, FishTokenizer
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 torch._inductor.config.coordinate_descent_tuning = True
@@ -145,8 +152,8 @@ def decode_one_token_ar_agent(
     model: DualARTransformer,
     x: torch.Tensor,
     input_pos: torch.Tensor,
+    semantic_ids: list,
     previous_tokens: torch.Tensor = None,
-    semantic_id: int = 32003,
     **sampling_kwargs,
 ) -> torch.Tensor:
     # print(x, input_pos)
@@ -190,19 +197,13 @@ def decode_one_token_ar_agent(
     codebooks.append(a)
 
     codebooks = torch.stack(codebooks, dim=1)
+    semantic_ids_tensor = torch.tensor(semantic_ids, device=codebooks.device)
     codebooks[:, 1:, :] = torch.masked_fill(
-        codebooks[:, 1:, :],
+        codebooks[:, 1:, :],
+        ~torch.isin(codebooks[:, :1, :], semantic_ids_tensor),
+        CODEBOOK_PAD_TOKEN_ID,
     )
 
-    # for i in range(codebooks.size(1) - 1):
-    #     codebooks[:, i + 1, :] = torch.masked_fill(
-    #         codebooks[:, i + 1, :],
-    #         codebooks[:, :1, :] != semantic_id,
-    #         CODEBOOK_PAD_TOKEN_ID + i * 1024,
-    #     )
-
-    # print(codebooks)
-
     return codebooks
 
 
@@ -210,8 +211,8 @@ def decode_one_token_naive_agent(
     model: NaiveTransformer,
     x: torch.Tensor,
     input_pos: torch.Tensor,
+    semantic_ids: list,
     previous_tokens: torch.Tensor = None,
-    semantic_id: int = 32003,
     **sampling_kwargs,
 ) -> torch.Tensor:
     x = model.forward_generate(x, input_pos)
@@ -236,8 +237,11 @@ def decode_one_token_naive_agent(
     )
 
     codebooks = torch.stack(codebooks, dim=1)
+    semantic_ids_tensor = torch.tensor(semantic_ids, device=codebooks.device)
     codebooks[:, 1:, :] = torch.masked_fill(
-        codebooks[:, 1:, :],
+        codebooks[:, 1:, :],
+        ~torch.isin(codebooks[:, :1, :], semantic_ids_tensor),
+        CODEBOOK_PAD_TOKEN_ID,
    )
 
     return codebooks
```
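Across these hunks the per-codebook masking switches from comparing the main token stream against a single `semantic_id` to membership in the full set of per-code semantic token ids (fish-speech 1.5 uses one `<|semantic:i|>` token per code; the `semantic_ids` list is built further down in this same file). A self-contained toy sketch of what the new `torch.isin`/`masked_fill` combination does; all values are illustrative:

```python
import torch

CODEBOOK_PAD_TOKEN_ID = 0          # toy value for illustration
semantic_ids = [101, 102, 103]     # toy stand-ins for <|semantic:i|> token ids

# codebooks: (batch, 1 + num_codebooks, seq); row 0 is the main token stream.
codebooks = torch.tensor([[[101, 7, 103],      # main tokens
                           [11, 12, 13],       # codebook 1
                           [21, 22, 23]]])     # codebook 2

semantic_ids_tensor = torch.tensor(semantic_ids)
# Wherever the main token is NOT a semantic token, pad out the codebook rows.
codebooks[:, 1:, :] = torch.masked_fill(
    codebooks[:, 1:, :],
    ~torch.isin(codebooks[:, :1, :], semantic_ids_tensor),
    CODEBOOK_PAD_TOKEN_ID,
)
print(codebooks)
# tensor([[[101,   7, 103],
#          [ 11,   0,  13],
#          [ 21,   0,  23]]])
```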
```diff
@@ -247,8 +251,8 @@ def decode_one_token_ar(
     model: DualARTransformer,
     x: torch.Tensor,
     input_pos: torch.Tensor,
+    semantic_ids: list,
     previous_tokens: torch.Tensor = None,
-    semantic_id: int = 0,
     **sampling_kwargs,
 ) -> torch.Tensor:
     x = model.forward_generate(x, input_pos)
@@ -261,21 +265,32 @@ def decode_one_token_ar(
     codebooks = [
         sample(
             x.logits,
-            previous_tokens=
+            previous_tokens=(
+                previous_tokens[0] if previous_tokens is not None else None
+            ),  # Disable repetition penalty for the token codebook
             **sampling_kwargs_main,
         )[0]
     ]
 
-
+    hidden_states = x.hidden_states
 
     # Cleanup the cache
     for layer in model.fast_layers:
         layer.attention.kv_cache.k_cache.fill_(0)
         layer.attention.kv_cache.v_cache.fill_(0)
 
-
-
-
+    input_pos = torch.tensor([0], device=hidden_states.device, dtype=torch.long)
+    model.forward_generate_fast(hidden_states, input_pos)
+    a = codebooks[0] - model.tokenizer.semantic_begin_id
+    a[a < 0] = 0
+    hidden_states = model.fast_embeddings(a)
+    codebooks.append(a)
+
+    for codebook_idx in range(1, model.config.num_codebooks):
+        input_pos = torch.tensor(
+            [codebook_idx], device=hidden_states.device, dtype=torch.long
+        )
+        logits = model.forward_generate_fast(hidden_states, input_pos)
         a = sample(
             logits,
             previous_tokens=(
@@ -285,14 +300,16 @@ def decode_one_token_ar(
             ),
             **sampling_kwargs,
         )[0]
-
+        hidden_states = model.fast_embeddings(a)
         codebooks.append(a)
 
     codebooks = torch.stack(codebooks, dim=0)
-
-
-    )
+    # semantic_ids_tensor = torch.tensor(semantic_ids, device=codebooks.device)
+    # codebooks[1:, :] = torch.masked_fill(
+    #     codebooks[1:, :], ~torch.isin(codebooks[:1, :], semantic_ids_tensor), CODEBOOK_PAD_TOKEN_ID
+    # )
 
+    # print(codebooks)
     return codebooks
 
 
@@ -337,9 +354,8 @@ def decode_n_tokens(
     cur_token: torch.Tensor,
     input_pos: torch.Tensor,
     num_new_tokens: int,
-
+    semantic_ids: list,
     decode_one_token=decode_one_token_naive,
-    semantic_id: int = 0,
     **sampling_kwargs,
 ):
     previous_tokens = torch.zeros(
@@ -368,7 +384,7 @@ def decode_n_tokens(
             x=cur_token,
             input_pos=input_pos,
             previous_tokens=window,
-
+            semantic_ids=semantic_ids,
             **sampling_kwargs,
         )
 
@@ -378,7 +394,7 @@ def decode_n_tokens(
             model.config.num_codebooks + 1, -1
         )
 
-        if cur_token[0, 0, -1] ==
+        if cur_token[0, 0, -1] == model.tokenizer.get_token_id(IM_END_TOKEN):
             break
 
     return previous_tokens[:, : i + 1]
@@ -391,7 +407,6 @@ def generate(
     model: NaiveTransformer,
     prompt: torch.Tensor,
    max_new_tokens: int,
-    im_end_id: int = 4,
     decode_one_token=decode_one_token_naive,
     **sampling_kwargs,
 ) -> torch.Tensor:
@@ -401,7 +416,10 @@ def generate(
 
     # create an empty tensor of the expected final shape and fill in the current tokens
     T = prompt.size(1)
-    semantic_id = model.tokenizer.convert_tokens_to_ids("<|semantic|>")
+    # semantic_id = model.tokenizer.convert_tokens_to_ids("<|semantic|>")
+    semantic_ids = [
+        model.tokenizer.get_token_id(f"<|semantic:{i}|>") for i in range(1024)
+    ]
 
     if max_new_tokens:
         if T + max_new_tokens > model.config.max_seq_len:
@@ -435,7 +453,7 @@ def generate(
         model,
         prompt.view(1, codebook_dim, -1),
         input_pos,
-
+        semantic_ids=semantic_ids,
         **sampling_kwargs,
     )
     seq[:, T : T + 1] = next_token
@@ -446,9 +464,8 @@ def generate(
         next_token.view(1, codebook_dim, -1),
         input_pos,
         max_new_tokens - 1,
-        im_end_id=im_end_id,
         decode_one_token=decode_one_token,
-
+        semantic_ids=semantic_ids,
         **sampling_kwargs,
     )
     # x = torch.cat(generated_tokens, dim=1)
@@ -463,8 +480,8 @@ def decode_n_tokens_agent(
     cur_token: torch.Tensor,
     input_pos: torch.Tensor,
     num_new_tokens: int,
+    semantic_ids: list,
     im_end_id: int = 4,
-    semantic_id: int = 32003,
     decode_one_token=decode_one_token_naive_agent,
     early_stop_threshold: float = 0.6,
     **sampling_kwargs,
@@ -495,7 +512,7 @@ def decode_n_tokens_agent(
             x=cur_token,
             input_pos=input_pos,
             previous_tokens=window,
-
+            semantic_ids=semantic_ids,
             **sampling_kwargs,
         )
 
@@ -529,8 +546,8 @@ def generate_agent(
     model: BaseTransformer,
     prompt: torch.Tensor,
     max_new_tokens: int,
+    semantic_ids: list,
     im_end_id: int = 4,
-    semantic_id: int = 32003,
     decode_one_token=decode_one_token_naive_agent,
     num_samples: int = 1,
     early_stop_threshold: float = 0.6,
@@ -574,7 +591,7 @@ def generate_agent(
         model,
         prompt,
         input_pos,
-
+        semantic_ids=semantic_ids,
         **sampling_kwargs,
     ).view(num_samples, codebook_dim, -1)
     yield next_token.cpu()
@@ -587,7 +604,7 @@ def generate_agent(
         input_pos,
         max_new_tokens - 1,
         im_end_id=im_end_id,
-
+        semantic_ids=semantic_ids,
         decode_one_token=decode_one_token,
         early_stop_threshold=early_stop_threshold,
         **sampling_kwargs,
@@ -602,65 +619,63 @@ def encode_tokens(
     num_codebooks=4,
 ):
     string = clean_text(string)
-    string = f"<|im_start|>user\n{string}<|im_end|><|im_start|>assistant\n"
 
-
-
-
-
-
+    messages = []
+    messages.append(
+        Message(
+            role="user",
+            parts=[TextPart(text=string)],
+            cal_loss=False,
+        )
     )
-    tokens = torch.tensor([new_tokens], dtype=torch.int, device=device)
 
-
-
-
-
-
-
+    if prompt_tokens is not None:
+        if prompt_tokens.ndim == 3:
+            assert (
+                prompt_tokens.shape[0] == 1
+            ), "3D prompt tokens should have shape (1, num_codebooks, seq_len)"
+            prompt_tokens = prompt_tokens[0]
 
-
-        return prompt
+        assert prompt_tokens.ndim == 2, "Prompt tokens should be 2D tensor"
 
-
-
-
-
-
-        prompt_tokens = prompt_tokens[0]
+        if prompt_tokens.shape[0] > num_codebooks:
+            logger.warning(
+                f"Prompt tokens shape {prompt_tokens.shape} is larger than num_codebooks {num_codebooks}, getting first {num_codebooks} codebooks"
+            )
+            prompt_tokens = prompt_tokens[:num_codebooks]
 
-
-        data = prompt_tokens + 1
+        vq_part = VQPart(codes=prompt_tokens.to(device))
 
-
-
-
+        messages.append(
+            Message(
+                role="assistant",
+                parts=[TextPart(text="<|voice|>"), vq_part],
+                cal_loss=False,
+            )
+        )
+    else:
+        messages.append(
+            Message(
+                role="assistant",
+                parts=[TextPart(text="<|voice|>")],
+                cal_loss=False,
+                add_im_end=False,
+            )
         )
-        data = data[:num_codebooks]
-
-        # Add pad token for each codebook
-        data = torch.cat(
-            (data, torch.zeros((data.size(0), 1), dtype=torch.int, device=device)),
-            dim=1,
-        )
 
-
-
-
-
-
+    conversation = Conversation(messages=messages)
+    # conversation.visualize(tokenizer)
+    encoded = conversation.encode_for_inference(
+        tokenizer=tokenizer,
+        num_codebooks=num_codebooks,
     )
-    main_token_ids[0, -1] = end_token_id
-
-    data = torch.cat((main_token_ids, data), dim=0)
-    prompt = torch.cat((prompt, data), dim=1)
 
-    return
+    return encoded.to(device)
 
 
 def load_model(checkpoint_path, device, precision, compile=False, is_agent=False):
     model: Union[NaiveTransformer, DualARTransformer] = BaseTransformer.from_pretrained(
-        checkpoint_path, load_weights=True
+        checkpoint_path, load_weights=True, is_agent=is_agent
     )
 
     model = model.to(device=device, dtype=precision)
```
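`encode_tokens` now builds the prompt through `fish_speech.conversation` objects instead of hand-formatted `<|im_start|>` strings. A hedged sketch of the resulting call pattern, using only the constructors visible in this hunk (the `build_prompt` wrapper and its argument values are illustrative):

```python
from fish_speech.conversation import Conversation, Message, TextPart

def build_prompt(tokenizer, text: str, num_codebooks: int = 8):
    """Sketch: encode a text-only TTS prompt the way encode_tokens now does."""
    messages = [
        Message(role="user", parts=[TextPart(text=text)], cal_loss=False),
        # The assistant turn is left open (no <|im_end|>) for generation.
        Message(
            role="assistant",
            parts=[TextPart(text="<|voice|>")],
            cal_loss=False,
            add_im_end=False,
        ),
    ]
    return Conversation(messages=messages).encode_for_inference(
        tokenizer=tokenizer, num_codebooks=num_codebooks
    )
```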
```diff
@@ -729,11 +744,26 @@ def generate_long(
 
     model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
     tokenizer = model.tokenizer
-    im_end_id = tokenizer.
+    im_end_id = tokenizer.get_token_id("<|im_end|>")
 
     encoded = []
     texts = split_text(text, chunk_length) if iterative_prompt else [text]
-    encoded_prompts = [
+    encoded_prompts = [
+        Conversation(
+            messages=[
+                Message(
+                    role="system",
+                    parts=[TextPart(text="Speak out the provided text.")],
+                    cal_loss=False,
+                )
+            ]
+        )
+        .encode_for_inference(
+            tokenizer=tokenizer,
+            num_codebooks=model.config.num_codebooks,
+        )
+        .to(device)
+    ]
 
     if use_prompt:
         for idx, (t, c) in enumerate(zip(prompt_text, prompt_tokens)):
@@ -812,7 +842,6 @@ def generate_long(
             model=model,
             prompt=cat_encoded,
             max_new_tokens=max_new_tokens,
-            im_end_id=im_end_id,
             decode_one_token=decode_one_token,
             temperature=temperature,
             top_p=top_p,
@@ -842,12 +871,11 @@ def generate_long(
         )
 
         # Put the generated tokens
-        # since there is <im_end
-        codes = y[1:, prompt_length
-        codes = codes - 1
+        # since there is <im_end>, we remove last token
+        codes = y[1:, prompt_length + 1 :].clone()
         assert (codes >= 0).all(), f"Negative code found"
 
-        decoded = y[:, prompt_length
+        decoded = y[:, prompt_length:].clone()
         # But for global encoding, we should keep the <im_end> token
 
         global_encoded.append(decoded)
```
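In the updated `generate_long`, `y` stacks the main token stream (row 0) on top of the codebook rows, so `codes` keeps only the codebook rows of the newly generated region while `decoded` keeps all rows from the prompt boundary on; the old `codes = codes - 1` offset is gone because `decode_one_token_ar` above now rebases codes via `semantic_begin_id`. A toy illustration of the slicing (shapes only, values hypothetical):

```python
import torch

num_codebooks, prompt_length, new_tokens = 2, 3, 4
# Row 0: main token ids; rows 1..: codebook codes.
y = torch.arange((num_codebooks + 1) * (prompt_length + new_tokens)).view(
    num_codebooks + 1, prompt_length + new_tokens
)
codes = y[1:, prompt_length + 1 :]  # generated codebook codes only
decoded = y[:, prompt_length:]      # keeps row 0 and <|im_end|> for global re-encoding
print(codes.shape, decoded.shape)   # torch.Size([2, 3]) torch.Size([3, 4])
```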
xinference/thirdparty/fish_speech/tools/run_webui.py

```diff
@@ -0,0 +1,104 @@
+import os
+from argparse import ArgumentParser
+from pathlib import Path
+
+import pyrootutils
+import torch
+from loguru import logger
+
+pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+
+from tools.inference_engine import TTSInferenceEngine
+from tools.llama.generate import launch_thread_safe_queue
+from tools.schema import ServeTTSRequest
+from tools.vqgan.inference import load_model as load_decoder_model
+from tools.webui import build_app
+from tools.webui.inference import get_inference_wrapper
+
+# Make einx happy
+os.environ["EINX_FILTER_TRACEBACK"] = "false"
+
+
+def parse_args():
+    parser = ArgumentParser()
+    parser.add_argument(
+        "--llama-checkpoint-path",
+        type=Path,
+        default="checkpoints/fish-speech-1.5",
+    )
+    parser.add_argument(
+        "--decoder-checkpoint-path",
+        type=Path,
+        default="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
+    )
+    parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
+    parser.add_argument("--device", type=str, default="cuda")
+    parser.add_argument("--half", action="store_true")
+    parser.add_argument("--compile", action="store_true")
+    parser.add_argument("--max-gradio-length", type=int, default=0)
+    parser.add_argument("--theme", type=str, default="light")
+
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    args.precision = torch.half if args.half else torch.bfloat16
+
+    # Check if MPS or CUDA is available
+    if torch.backends.mps.is_available():
+        args.device = "mps"
+        logger.info("mps is available, running on mps.")
+    elif not torch.cuda.is_available():
+        logger.info("CUDA is not available, running on CPU.")
+        args.device = "cpu"
+
+    logger.info("Loading Llama model...")
+    llama_queue = launch_thread_safe_queue(
+        checkpoint_path=args.llama_checkpoint_path,
+        device=args.device,
+        precision=args.precision,
+        compile=args.compile,
+    )
+
+    logger.info("Loading VQ-GAN model...")
+    decoder_model = load_decoder_model(
+        config_name=args.decoder_config_name,
+        checkpoint_path=args.decoder_checkpoint_path,
+        device=args.device,
+    )
+
+    logger.info("Decoder model loaded, warming up...")
+
+    # Create the inference engine
+    inference_engine = TTSInferenceEngine(
+        llama_queue=llama_queue,
+        decoder_model=decoder_model,
+        compile=args.compile,
+        precision=args.precision,
+    )
+
+    # Dry run to check if the model is loaded correctly and avoid the first-time latency
+    list(
+        inference_engine.inference(
+            ServeTTSRequest(
+                text="Hello world.",
+                references=[],
+                reference_id=None,
+                max_new_tokens=1024,
+                chunk_length=200,
+                top_p=0.7,
+                repetition_penalty=1.5,
+                temperature=0.7,
+                format="wav",
+            )
+        )
+    )
+
+    logger.info("Warming up done, launching the web UI...")
+
+    # Get the inference function with the immutable arguments
+    inference_fct = get_inference_wrapper(inference_engine)
+
+    app = build_app(inference_fct, args.theme)
+    app.launch(show_api=True)
```