xinference 0.16.3__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic. Click here for more details.
- xinference/_compat.py +24 -2
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +219 -77
- xinference/client/restful/restful_client.py +47 -2
- xinference/constants.py +1 -0
- xinference/core/chat_interface.py +6 -1
- xinference/core/model.py +124 -34
- xinference/core/supervisor.py +180 -12
- xinference/core/utils.py +73 -4
- xinference/core/worker.py +102 -4
- xinference/deploy/cmdline.py +3 -1
- xinference/deploy/test/test_cmdline.py +56 -0
- xinference/isolation.py +24 -0
- xinference/model/audio/__init__.py +12 -0
- xinference/model/audio/core.py +37 -4
- xinference/model/audio/cosyvoice.py +39 -6
- xinference/model/audio/f5tts.py +200 -0
- xinference/model/audio/f5tts_mlx.py +260 -0
- xinference/model/audio/fish_speech.py +70 -110
- xinference/model/audio/melotts.py +110 -0
- xinference/model/audio/model_spec.json +179 -3
- xinference/model/audio/model_spec_modelscope.json +27 -0
- xinference/model/audio/utils.py +32 -0
- xinference/model/audio/whisper.py +35 -10
- xinference/model/audio/whisper_mlx.py +208 -0
- xinference/model/embedding/core.py +322 -6
- xinference/model/embedding/model_spec.json +8 -1
- xinference/model/embedding/model_spec_modelscope.json +9 -1
- xinference/model/image/core.py +69 -1
- xinference/model/image/model_spec.json +145 -4
- xinference/model/image/model_spec_modelscope.json +150 -4
- xinference/model/image/stable_diffusion/core.py +50 -15
- xinference/model/llm/__init__.py +6 -2
- xinference/model/llm/llm_family.json +1055 -93
- xinference/model/llm/llm_family.py +15 -36
- xinference/model/llm/llm_family_modelscope.json +1031 -78
- xinference/model/llm/memory.py +1 -1
- xinference/model/llm/mlx/core.py +285 -47
- xinference/model/llm/sglang/core.py +2 -0
- xinference/model/llm/transformers/chatglm.py +9 -5
- xinference/model/llm/transformers/cogagent.py +272 -0
- xinference/model/llm/transformers/core.py +3 -0
- xinference/model/llm/transformers/glm_edge_v.py +230 -0
- xinference/model/llm/transformers/qwen2_vl.py +12 -1
- xinference/model/llm/transformers/utils.py +16 -8
- xinference/model/llm/utils.py +55 -4
- xinference/model/llm/vllm/core.py +137 -12
- xinference/model/llm/vllm/xavier/__init__.py +13 -0
- xinference/model/llm/vllm/xavier/allocator.py +74 -0
- xinference/model/llm/vllm/xavier/block.py +111 -0
- xinference/model/llm/vllm/xavier/block_manager.py +71 -0
- xinference/model/llm/vllm/xavier/block_tracker.py +129 -0
- xinference/model/llm/vllm/xavier/collective.py +74 -0
- xinference/model/llm/vllm/xavier/collective_manager.py +147 -0
- xinference/model/llm/vllm/xavier/engine.py +247 -0
- xinference/model/llm/vllm/xavier/executor.py +134 -0
- xinference/model/llm/vllm/xavier/scheduler.py +438 -0
- xinference/model/llm/vllm/xavier/test/__init__.py +13 -0
- xinference/model/llm/vllm/xavier/test/test_xavier.py +147 -0
- xinference/model/llm/vllm/xavier/transfer.py +319 -0
- xinference/model/rerank/core.py +11 -4
- xinference/model/video/diffusers.py +14 -0
- xinference/model/video/model_spec.json +15 -0
- xinference/model/video/model_spec_modelscope.json +16 -0
- xinference/thirdparty/cosyvoice/bin/average_model.py +92 -0
- xinference/thirdparty/cosyvoice/bin/export_jit.py +12 -2
- xinference/thirdparty/cosyvoice/bin/export_onnx.py +112 -0
- xinference/thirdparty/cosyvoice/bin/export_trt.sh +9 -0
- xinference/thirdparty/cosyvoice/bin/inference.py +5 -7
- xinference/thirdparty/cosyvoice/bin/spk2info.pt +0 -0
- xinference/thirdparty/cosyvoice/bin/train.py +42 -8
- xinference/thirdparty/cosyvoice/cli/cosyvoice.py +96 -25
- xinference/thirdparty/cosyvoice/cli/frontend.py +77 -30
- xinference/thirdparty/cosyvoice/cli/model.py +330 -80
- xinference/thirdparty/cosyvoice/dataset/dataset.py +6 -2
- xinference/thirdparty/cosyvoice/dataset/processor.py +76 -14
- xinference/thirdparty/cosyvoice/flow/decoder.py +92 -13
- xinference/thirdparty/cosyvoice/flow/flow.py +99 -9
- xinference/thirdparty/cosyvoice/flow/flow_matching.py +110 -13
- xinference/thirdparty/cosyvoice/flow/length_regulator.py +5 -4
- xinference/thirdparty/cosyvoice/hifigan/discriminator.py +140 -0
- xinference/thirdparty/cosyvoice/hifigan/generator.py +58 -42
- xinference/thirdparty/cosyvoice/hifigan/hifigan.py +67 -0
- xinference/thirdparty/cosyvoice/llm/llm.py +139 -6
- xinference/thirdparty/cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
- xinference/thirdparty/cosyvoice/tokenizer/tokenizer.py +279 -0
- xinference/thirdparty/cosyvoice/transformer/embedding.py +2 -2
- xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +7 -7
- xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +318 -0
- xinference/thirdparty/cosyvoice/utils/common.py +28 -1
- xinference/thirdparty/cosyvoice/utils/executor.py +69 -7
- xinference/thirdparty/cosyvoice/utils/file_utils.py +2 -12
- xinference/thirdparty/cosyvoice/utils/frontend_utils.py +9 -5
- xinference/thirdparty/cosyvoice/utils/losses.py +20 -0
- xinference/thirdparty/cosyvoice/utils/scheduler.py +1 -2
- xinference/thirdparty/cosyvoice/utils/train_utils.py +101 -45
- xinference/thirdparty/f5_tts/api.py +166 -0
- xinference/thirdparty/f5_tts/configs/E2TTS_Base_train.yaml +44 -0
- xinference/thirdparty/f5_tts/configs/E2TTS_Small_train.yaml +44 -0
- xinference/thirdparty/f5_tts/configs/F5TTS_Base_train.yaml +46 -0
- xinference/thirdparty/f5_tts/configs/F5TTS_Small_train.yaml +46 -0
- xinference/thirdparty/f5_tts/eval/README.md +49 -0
- xinference/thirdparty/f5_tts/eval/ecapa_tdnn.py +330 -0
- xinference/thirdparty/f5_tts/eval/eval_infer_batch.py +207 -0
- xinference/thirdparty/f5_tts/eval/eval_infer_batch.sh +13 -0
- xinference/thirdparty/f5_tts/eval/eval_librispeech_test_clean.py +84 -0
- xinference/thirdparty/f5_tts/eval/eval_seedtts_testset.py +84 -0
- xinference/thirdparty/f5_tts/eval/utils_eval.py +405 -0
- xinference/thirdparty/f5_tts/infer/README.md +191 -0
- xinference/thirdparty/f5_tts/infer/SHARED.md +74 -0
- xinference/thirdparty/f5_tts/infer/examples/basic/basic.toml +11 -0
- xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_en.wav +0 -0
- xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_zh.wav +0 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/country.flac +0 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/main.flac +0 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/story.toml +19 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/story.txt +1 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/town.flac +0 -0
- xinference/thirdparty/f5_tts/infer/examples/vocab.txt +2545 -0
- xinference/thirdparty/f5_tts/infer/infer_cli.py +226 -0
- xinference/thirdparty/f5_tts/infer/infer_gradio.py +851 -0
- xinference/thirdparty/f5_tts/infer/speech_edit.py +193 -0
- xinference/thirdparty/f5_tts/infer/utils_infer.py +538 -0
- xinference/thirdparty/f5_tts/model/__init__.py +10 -0
- xinference/thirdparty/f5_tts/model/backbones/README.md +20 -0
- xinference/thirdparty/f5_tts/model/backbones/dit.py +163 -0
- xinference/thirdparty/f5_tts/model/backbones/mmdit.py +146 -0
- xinference/thirdparty/f5_tts/model/backbones/unett.py +219 -0
- xinference/thirdparty/f5_tts/model/cfm.py +285 -0
- xinference/thirdparty/f5_tts/model/dataset.py +319 -0
- xinference/thirdparty/f5_tts/model/modules.py +658 -0
- xinference/thirdparty/f5_tts/model/trainer.py +366 -0
- xinference/thirdparty/f5_tts/model/utils.py +185 -0
- xinference/thirdparty/f5_tts/scripts/count_max_epoch.py +33 -0
- xinference/thirdparty/f5_tts/scripts/count_params_gflops.py +39 -0
- xinference/thirdparty/f5_tts/socket_server.py +159 -0
- xinference/thirdparty/f5_tts/train/README.md +77 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_csv_wavs.py +139 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_emilia.py +230 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_libritts.py +92 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_ljspeech.py +65 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_wenetspeech4tts.py +125 -0
- xinference/thirdparty/f5_tts/train/finetune_cli.py +174 -0
- xinference/thirdparty/f5_tts/train/finetune_gradio.py +1846 -0
- xinference/thirdparty/f5_tts/train/train.py +75 -0
- xinference/thirdparty/fish_speech/fish_speech/conversation.py +266 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +2 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +2 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +2 -2
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ko_KR.json +123 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +2 -1
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +137 -29
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +9 -9
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +1 -1
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +17 -11
- xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +1 -1
- xinference/thirdparty/fish_speech/fish_speech/tokenizer.py +152 -0
- xinference/thirdparty/fish_speech/fish_speech/train.py +2 -2
- xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +2 -1
- xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +22 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +1 -1
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +2 -2
- xinference/thirdparty/fish_speech/tools/{post_api.py → api_client.py} +34 -18
- xinference/thirdparty/fish_speech/tools/api_server.py +98 -0
- xinference/thirdparty/fish_speech/tools/download_models.py +5 -5
- xinference/thirdparty/fish_speech/tools/e2e_webui.py +232 -0
- xinference/thirdparty/fish_speech/tools/fish_e2e.py +298 -0
- xinference/thirdparty/fish_speech/tools/inference_engine/__init__.py +192 -0
- xinference/thirdparty/fish_speech/tools/inference_engine/reference_loader.py +125 -0
- xinference/thirdparty/fish_speech/tools/inference_engine/utils.py +39 -0
- xinference/thirdparty/fish_speech/tools/inference_engine/vq_manager.py +57 -0
- xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +2 -2
- xinference/thirdparty/fish_speech/tools/llama/generate.py +484 -72
- xinference/thirdparty/fish_speech/tools/run_webui.py +104 -0
- xinference/thirdparty/fish_speech/tools/schema.py +170 -0
- xinference/thirdparty/fish_speech/tools/server/agent/__init__.py +57 -0
- xinference/thirdparty/fish_speech/tools/server/agent/generate.py +119 -0
- xinference/thirdparty/fish_speech/tools/server/agent/generation_utils.py +122 -0
- xinference/thirdparty/fish_speech/tools/server/agent/pre_generation_utils.py +72 -0
- xinference/thirdparty/fish_speech/tools/server/api_utils.py +75 -0
- xinference/thirdparty/fish_speech/tools/server/exception_handler.py +27 -0
- xinference/thirdparty/fish_speech/tools/server/inference.py +45 -0
- xinference/thirdparty/fish_speech/tools/server/model_manager.py +122 -0
- xinference/thirdparty/fish_speech/tools/server/model_utils.py +129 -0
- xinference/thirdparty/fish_speech/tools/server/views.py +246 -0
- xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +7 -1
- xinference/thirdparty/fish_speech/tools/vqgan/inference.py +2 -3
- xinference/thirdparty/fish_speech/tools/webui/__init__.py +173 -0
- xinference/thirdparty/fish_speech/tools/webui/inference.py +91 -0
- xinference/thirdparty/fish_speech/tools/webui/variables.py +14 -0
- xinference/thirdparty/matcha/utils/utils.py +2 -2
- xinference/thirdparty/melo/api.py +135 -0
- xinference/thirdparty/melo/app.py +61 -0
- xinference/thirdparty/melo/attentions.py +459 -0
- xinference/thirdparty/melo/commons.py +160 -0
- xinference/thirdparty/melo/configs/config.json +94 -0
- xinference/thirdparty/melo/data/example/metadata.list +20 -0
- xinference/thirdparty/melo/data_utils.py +413 -0
- xinference/thirdparty/melo/download_utils.py +67 -0
- xinference/thirdparty/melo/infer.py +25 -0
- xinference/thirdparty/melo/init_downloads.py +14 -0
- xinference/thirdparty/melo/losses.py +58 -0
- xinference/thirdparty/melo/main.py +36 -0
- xinference/thirdparty/melo/mel_processing.py +174 -0
- xinference/thirdparty/melo/models.py +1030 -0
- xinference/thirdparty/melo/modules.py +598 -0
- xinference/thirdparty/melo/monotonic_align/__init__.py +16 -0
- xinference/thirdparty/melo/monotonic_align/core.py +46 -0
- xinference/thirdparty/melo/preprocess_text.py +135 -0
- xinference/thirdparty/melo/split_utils.py +174 -0
- xinference/thirdparty/melo/text/__init__.py +35 -0
- xinference/thirdparty/melo/text/chinese.py +199 -0
- xinference/thirdparty/melo/text/chinese_bert.py +107 -0
- xinference/thirdparty/melo/text/chinese_mix.py +253 -0
- xinference/thirdparty/melo/text/cleaner.py +36 -0
- xinference/thirdparty/melo/text/cleaner_multiling.py +110 -0
- xinference/thirdparty/melo/text/cmudict.rep +129530 -0
- xinference/thirdparty/melo/text/cmudict_cache.pickle +0 -0
- xinference/thirdparty/melo/text/english.py +284 -0
- xinference/thirdparty/melo/text/english_bert.py +39 -0
- xinference/thirdparty/melo/text/english_utils/abbreviations.py +35 -0
- xinference/thirdparty/melo/text/english_utils/number_norm.py +97 -0
- xinference/thirdparty/melo/text/english_utils/time_norm.py +47 -0
- xinference/thirdparty/melo/text/es_phonemizer/base.py +140 -0
- xinference/thirdparty/melo/text/es_phonemizer/cleaner.py +109 -0
- xinference/thirdparty/melo/text/es_phonemizer/es_symbols.json +79 -0
- xinference/thirdparty/melo/text/es_phonemizer/es_symbols.txt +1 -0
- xinference/thirdparty/melo/text/es_phonemizer/es_symbols_v2.json +83 -0
- xinference/thirdparty/melo/text/es_phonemizer/es_to_ipa.py +12 -0
- xinference/thirdparty/melo/text/es_phonemizer/example_ipa.txt +400 -0
- xinference/thirdparty/melo/text/es_phonemizer/gruut_wrapper.py +253 -0
- xinference/thirdparty/melo/text/es_phonemizer/punctuation.py +174 -0
- xinference/thirdparty/melo/text/es_phonemizer/spanish_symbols.txt +1 -0
- xinference/thirdparty/melo/text/es_phonemizer/test.ipynb +124 -0
- xinference/thirdparty/melo/text/fr_phonemizer/base.py +140 -0
- xinference/thirdparty/melo/text/fr_phonemizer/cleaner.py +122 -0
- xinference/thirdparty/melo/text/fr_phonemizer/en_symbols.json +78 -0
- xinference/thirdparty/melo/text/fr_phonemizer/example_ipa.txt +1 -0
- xinference/thirdparty/melo/text/fr_phonemizer/fr_symbols.json +89 -0
- xinference/thirdparty/melo/text/fr_phonemizer/fr_to_ipa.py +30 -0
- xinference/thirdparty/melo/text/fr_phonemizer/french_abbreviations.py +48 -0
- xinference/thirdparty/melo/text/fr_phonemizer/french_symbols.txt +1 -0
- xinference/thirdparty/melo/text/fr_phonemizer/gruut_wrapper.py +258 -0
- xinference/thirdparty/melo/text/fr_phonemizer/punctuation.py +172 -0
- xinference/thirdparty/melo/text/french.py +94 -0
- xinference/thirdparty/melo/text/french_bert.py +39 -0
- xinference/thirdparty/melo/text/japanese.py +647 -0
- xinference/thirdparty/melo/text/japanese_bert.py +49 -0
- xinference/thirdparty/melo/text/ko_dictionary.py +44 -0
- xinference/thirdparty/melo/text/korean.py +192 -0
- xinference/thirdparty/melo/text/opencpop-strict.txt +429 -0
- xinference/thirdparty/melo/text/spanish.py +122 -0
- xinference/thirdparty/melo/text/spanish_bert.py +39 -0
- xinference/thirdparty/melo/text/symbols.py +290 -0
- xinference/thirdparty/melo/text/tone_sandhi.py +769 -0
- xinference/thirdparty/melo/train.py +635 -0
- xinference/thirdparty/melo/train.sh +19 -0
- xinference/thirdparty/melo/transforms.py +209 -0
- xinference/thirdparty/melo/utils.py +424 -0
- xinference/types.py +17 -1
- xinference/web/ui/build/asset-manifest.json +6 -6
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/css/main.51a587ff.css +2 -0
- xinference/web/ui/build/static/css/main.51a587ff.css.map +1 -0
- xinference/web/ui/build/static/js/main.b0936c54.js +3 -0
- xinference/web/ui/build/static/js/main.b0936c54.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/03c4052f1b91f6ba0c5389bdcf49c43319b4076c08e4b8585dab312538ae290a.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/1786b83003b8e9605a0f5f855a185d4d16e38fc893dfb326a2a9cca206b4240a.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/17cbc181dd674b9150b80c73ed6a82656de0082d857f6e5f66d9716129ac0b38.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/185ceb8872d562e032b47e79df6a45670e06345b8ed70aad1a131e0476783c5c.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/26b8c9f34b0bed789b3a833767672e39302d1e0c09b4276f4d58d1df7b6bd93b.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/2b484da66c724d0d56a40849c109327408796a668b1381511b6e9e03baa48658.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/2cbbbce9b84df73330d4c42b82436ed881b3847628f2fbc346aa62e2859fd88c.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/2ec9b14431ed33ce6901bf9f27007be4e6e472709c99d6e22b50ce528e4b78ee.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/3b966db018f96be4a055d6ca205f0990d4d0b370e2980c17d8bca2c9a021819c.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/3eefb411b24c2b3ce053570ef50daccf154022f0e168be5ed0fec21394baf9f4.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/522b229e3cac219123f0d69673f5570e191c2d2a505dc65b312d336eae2279c0.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/52e45f17ba300580ea3fcc9f9228ccba194bb092b76f25e9255af311f8b05aab.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/5a0bc4631f936459afc1a3b1d3ec2420118b1f00e11f60ccac3e08088f3f27a8.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/611fa2c6c53b66039991d06dfb0473b5ab37fc63b4564e0f6e1718523768a045.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/6329bc76c406fe5eb305412383fbde5950f847bb5e43261f73f37622c365acb4.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/63c8e07687ea53a4f8a910ee5e42e0eb26cd1acbfbe820f3e3248a786ee51401.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/69b2d5001684174ec9da57e07914eed3eac4960018bceb6cbfa801d861301d7c.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/710c1acda69e561e30a933b98c6a56d50197868b15c21e2aad55ab6d46649eb6.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/720deca1fce5a1dc5056048fa8258fd138a82ea855f350b6613f104a73fb761f.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/76a23b92d26a499c57e61eea2b895fbc9771bd0849a72e66f8e633192017978b.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/858063f23b34dfe600254eb5afd85518b0002ec4b30b7386616c45600826e3b2.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/920b82c1c89124cf217109eeedbfcd3aae3b917be50c9dfb6bbb4ce26bdfd2e7.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/94d8b7aeb0076f2ce07db598cea0e87b13bc8d5614eb530b8d6e696c2daf6f88.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/9e917fe7022d01b2ccbe5cc0ce73d70bb72bee584ff293bad71bdff6695dee28.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/9f28fdb8399f1d0474f0aca86f1658dc94f5bf0c90f6146352de150692de8862.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/a0dfafa06b2bb7cba8cad41c482503f61944f759f4318139362602ef5cc47ccb.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/a3ff866acddf34917a7ee399e0e571a4dfd8ba66d5057db885f243e16a6eb17d.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/afb8084f539534cd594755ea2205ecd5bd1f62dddcfdf75a2eace59a28131278.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/b57b1438b77294c1f3f6cfce12ac487d8106c6f016975ba0aec94d98997e2e1e.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/b9917b0bf8e4d55ccbac1c334aa04d6ff3c5b6ed9e5d38b9ea2c687fa7d3f5a9.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/bbcc94b0149963d1d6f267ee1f4f03d3925b758392ce2f516c3fe8af0e0169fc.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/bdee44abeadc4abc17d41c52eb49c6e19a4b1a267b6e16876ce91bdeeebfc52d.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/beb112b70f4a56db95920a9e20efb6c97c37b68450716730217a9ee1a9ae92be.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/c88db97be0cdf440193b3995996e83510a04cb00048135485fc0e26d197e80b5.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/d49e5314d34310a62d01a03067ce1bec5da00abce84c5196aa9c6842fa79a430.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/d7664d18c4ddbad9c3a6a31b91f7c00fb0dde804608674a9860ee50f33e54708.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/d9072c318b819b7c90a0f7e9cc0b6413b4dbeb8e9859898e53d75ea882fcde99.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/db16a983bc08a05f0439cc61ca0840e49e1d8400eef678909f16c032a418a3d6.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/dc249829767b8abcbc3677e0b07b6d3ecbfdfe6d08cfe23a665eb33373a9aa9d.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/e242c583c2dbc2784f0fcf513523975f7d5df447e106c1c17e49e8578a6fc3ed.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/eac5f1296513e69e4b96f750ddccd4d0264e2bae4e4c449144e83274a48698d9.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/ed57202cb79649bb716400436590245547df241988fc7c8e1d85d132299542d2.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f125bf72e773a14cdaebd0c343e80adb909d12e317ee5c00cd4a57442fbe2c62.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f91af913d7f91c410719ab13136aaed3aaf0f8dda06652f25c42cb5231587398.json +1 -0
- xinference/web/ui/node_modules/.package-lock.json +67 -3
- xinference/web/ui/node_modules/@babel/runtime/package.json +592 -538
- xinference/web/ui/node_modules/html-parse-stringify/package.json +50 -0
- xinference/web/ui/node_modules/i18next/dist/esm/package.json +1 -0
- xinference/web/ui/node_modules/i18next/package.json +129 -0
- xinference/web/ui/node_modules/react-i18next/.eslintrc.json +74 -0
- xinference/web/ui/node_modules/react-i18next/dist/es/package.json +1 -0
- xinference/web/ui/node_modules/react-i18next/package.json +162 -0
- xinference/web/ui/node_modules/void-elements/package.json +34 -0
- xinference/web/ui/package-lock.json +69 -3
- xinference/web/ui/package.json +2 -0
- xinference/web/ui/src/locales/en.json +186 -0
- xinference/web/ui/src/locales/zh.json +186 -0
- {xinference-0.16.3.dist-info → xinference-1.2.1.dist-info}/METADATA +96 -36
- {xinference-0.16.3.dist-info → xinference-1.2.1.dist-info}/RECORD +335 -146
- {xinference-0.16.3.dist-info → xinference-1.2.1.dist-info}/WHEEL +1 -1
- xinference/thirdparty/cosyvoice/bin/export_trt.py +0 -8
- xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/api.py +0 -440
- xinference/thirdparty/fish_speech/tools/commons.py +0 -35
- xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/msgpack_api.py +0 -34
- xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/webui.py +0 -485
- xinference/web/ui/build/static/css/main.5061c4c3.css +0 -2
- xinference/web/ui/build/static/css/main.5061c4c3.css.map +0 -1
- xinference/web/ui/build/static/js/main.2f269bb3.js +0 -3
- xinference/web/ui/build/static/js/main.2f269bb3.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/07ce9e632e6aff24d7aa3ad8e48224433bbfeb0d633fca723453f1fcae0c9f1c.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/1130403f9e46f5738a23b45ac59b57de8f360c908c713e2c0670c2cce9bd367a.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/131091b25d26b17cdca187d7542a21475c211138d900cf667682260e76ef9463.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/1f269fb2a368363c1cb2237825f1dba093b6bdd8c44cc05954fd19ec2c1fff03.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/331312668fa8bd3d7401818f4a25fa98135d7f61371cd6bfff78b18cf4fbdd92.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/40f17338fc75ae095de7d2b4d8eae0d5ca0193a7e2bcece4ee745b22a7a2f4b7.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/4de9a6942c5f1749d6cbfdd54279699975f16016b182848bc253886f52ec2ec3.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/822586ed1077201b64b954f12f25e3f9b45678c1acbabe53d8af3ca82ca71f33.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/8d33354bd2100c8602afc3341f131a88cc36aaeecd5a4b365ed038514708e350.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/9375a35b05d56989b2755bf72161fa707c92f28569d33765a75f91a568fda6e9.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/a158a9ffa0c9b169aee53dd4a0c44501a596755b4e4f6ede7746d65a72e2a71f.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/c7bf40bab396765f67d0fed627ed3665890608b2d0edaa3e8cb7cfc96310db45.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/d6c643278a0b28320e6f33a60f5fb64c053997cbdc39a60e53ccc574688ade9e.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/e42b72d4cc1ea412ebecbb8d040dc6c6bfee462c33903c2f1f3facb602ad742e.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/e64b7e8cedcf43d4c95deba60ec1341855c887705805bb62431693118b870c69.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/f5039ddbeb815c51491a1989532006b96fc3ae49c6c60e3c097f875b4ae915ae.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/f72f011744c4649fabddca6f7a9327861ac0a315a89b1a2e62a39774e7863845.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/feabb04b4aa507102da0a64398a40818e878fd1df9b75dda8461b3e1e7ff3f11.json +0 -1
- /xinference/thirdparty/{cosyvoice/bin → f5_tts}/__init__.py +0 -0
- /xinference/thirdparty/{cosyvoice/flow → melo}/__init__.py +0 -0
- /xinference/thirdparty/{cosyvoice/hifigan → melo/text/english_utils}/__init__.py +0 -0
- /xinference/thirdparty/{cosyvoice/llm → melo/text/es_phonemizer}/__init__.py +0 -0
- /xinference/thirdparty/{fish_speech/fish_speech/configs → melo/text/fr_phonemizer}/__init__.py +0 -0
- /xinference/web/ui/build/static/js/{main.2f269bb3.js.LICENSE.txt → main.b0936c54.js.LICENSE.txt} +0 -0
- {xinference-0.16.3.dist-info → xinference-1.2.1.dist-info}/LICENSE +0 -0
- {xinference-0.16.3.dist-info → xinference-1.2.1.dist-info}/entry_points.txt +0 -0
- {xinference-0.16.3.dist-info → xinference-1.2.1.dist-info}/top_level.txt +0 -0
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import dataclasses
|
|
1
2
|
import json
|
|
2
3
|
import math
|
|
3
4
|
from collections import OrderedDict
|
|
@@ -15,7 +16,7 @@ from torch.nn.attention import SDPBackend, sdpa_kernel
|
|
|
15
16
|
from torch.utils.checkpoint import checkpoint
|
|
16
17
|
from transformers import AutoTokenizer
|
|
17
18
|
|
|
18
|
-
from fish_speech.
|
|
19
|
+
from fish_speech.tokenizer import SEMANTIC_TOKENS, FishTokenizer
|
|
19
20
|
from fish_speech.utils import RankedLogger
|
|
20
21
|
|
|
21
22
|
from .lora import LoraConfig, setup_lora
|
|
@@ -57,6 +58,11 @@ class BaseModelArgs:
|
|
|
57
58
|
# Initialize the model
|
|
58
59
|
initializer_range: float = 0.02
|
|
59
60
|
|
|
61
|
+
# Dummy vars
|
|
62
|
+
is_reward_model: bool = False
|
|
63
|
+
share_codebook_embeddings: bool = True
|
|
64
|
+
scale_codebook_embeddings: bool = False
|
|
65
|
+
|
|
60
66
|
def __post_init__(self):
|
|
61
67
|
if self.n_local_heads == -1:
|
|
62
68
|
self.n_local_heads = self.n_head
|
|
@@ -100,6 +106,28 @@ class NaiveModelArgs(BaseModelArgs):
|
|
|
100
106
|
class DualARModelArgs(BaseModelArgs):
|
|
101
107
|
model_type: str = "dual_ar"
|
|
102
108
|
n_fast_layer: int = 4
|
|
109
|
+
fast_dim: int | None = None
|
|
110
|
+
fast_n_head: int | None = None
|
|
111
|
+
fast_n_local_heads: int | None = None
|
|
112
|
+
fast_head_dim: int | None = None
|
|
113
|
+
fast_intermediate_size: int | None = None
|
|
114
|
+
fast_attention_qkv_bias: bool | None = None
|
|
115
|
+
|
|
116
|
+
def __post_init__(self):
|
|
117
|
+
super().__post_init__()
|
|
118
|
+
|
|
119
|
+
self.fast_dim = self.fast_dim or self.dim
|
|
120
|
+
self.fast_n_head = self.fast_n_head or self.n_head
|
|
121
|
+
self.fast_n_local_heads = self.fast_n_local_heads or self.n_local_heads
|
|
122
|
+
self.fast_head_dim = self.fast_head_dim or self.head_dim
|
|
123
|
+
self.fast_intermediate_size = (
|
|
124
|
+
self.fast_intermediate_size or self.intermediate_size
|
|
125
|
+
)
|
|
126
|
+
self.fast_attention_qkv_bias = (
|
|
127
|
+
self.fast_attention_qkv_bias
|
|
128
|
+
if self.fast_attention_qkv_bias is not None
|
|
129
|
+
else self.attention_qkv_bias
|
|
130
|
+
)
|
|
103
131
|
|
|
104
132
|
|
|
105
133
|
class KVCache(nn.Module):
|
|
@@ -137,13 +165,17 @@ class BaseTransformerForwardResult:
|
|
|
137
165
|
|
|
138
166
|
class BaseTransformer(nn.Module):
|
|
139
167
|
def __init__(
|
|
140
|
-
self,
|
|
168
|
+
self,
|
|
169
|
+
config: BaseModelArgs,
|
|
170
|
+
tokenizer: FishTokenizer | AutoTokenizer,
|
|
171
|
+
init_weights: bool = True,
|
|
141
172
|
) -> None:
|
|
142
173
|
super().__init__()
|
|
143
174
|
self.config = config
|
|
144
175
|
self.tokenizer = tokenizer
|
|
145
|
-
|
|
146
|
-
|
|
176
|
+
self.semantic_token_ids = [
|
|
177
|
+
tokenizer.get_token_id(SEMANTIC_TOKEN) for SEMANTIC_TOKEN in SEMANTIC_TOKENS
|
|
178
|
+
]
|
|
147
179
|
|
|
148
180
|
# Slow transformer
|
|
149
181
|
self.embeddings = nn.Embedding(
|
|
@@ -218,8 +250,10 @@ class BaseTransformer(nn.Module):
|
|
|
218
250
|
vocab_embeds = [self.embeddings(x[:, 0])]
|
|
219
251
|
for i in range(self.config.num_codebooks):
|
|
220
252
|
emb = self.codebook_embeddings(x[:, i + 1] + i * self.config.codebook_size)
|
|
221
|
-
|
|
222
|
-
|
|
253
|
+
semantic_token_ids_tensor = torch.tensor(
|
|
254
|
+
self.semantic_token_ids, device=x.device
|
|
255
|
+
)
|
|
256
|
+
emb[~torch.isin(x[:, 0], semantic_token_ids_tensor)] = 0
|
|
223
257
|
|
|
224
258
|
x = torch.stack(vocab_embeds, dim=3)
|
|
225
259
|
x = x.sum(dim=3)
|
|
@@ -267,20 +301,45 @@ class BaseTransformer(nn.Module):
|
|
|
267
301
|
|
|
268
302
|
def forward_generate(
|
|
269
303
|
self,
|
|
270
|
-
|
|
304
|
+
inp: Tensor,
|
|
271
305
|
input_pos: Optional[Tensor] = None,
|
|
306
|
+
vq_masks: Optional[Tensor] = None, # this is not used in fact
|
|
272
307
|
return_all: bool = False,
|
|
273
308
|
) -> BaseTransformerForwardResult:
|
|
274
309
|
# This is used for generation, optimized for torch compile
|
|
275
|
-
assert (
|
|
276
|
-
|
|
277
|
-
), "Please call setup_caches before forward_generate"
|
|
310
|
+
# assert (
|
|
311
|
+
# self.max_seq_len != -1 and self.max_batch_size != -1
|
|
312
|
+
# ), "Please call setup_caches before forward_generate"
|
|
313
|
+
|
|
314
|
+
embeds = []
|
|
315
|
+
for i in range(self.config.num_codebooks):
|
|
316
|
+
if self.config.share_codebook_embeddings:
|
|
317
|
+
_tokens = inp[:, i + 1] + i * self.config.codebook_size
|
|
318
|
+
else:
|
|
319
|
+
_tokens = inp[:, i + 1]
|
|
278
320
|
|
|
279
|
-
|
|
321
|
+
emb = self.codebook_embeddings(_tokens)
|
|
322
|
+
embeds.append(emb)
|
|
280
323
|
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
324
|
+
vq_embeds_sum = torch.stack(embeds, dim=1).sum(dim=1)
|
|
325
|
+
# if self.config.use_codebook_mlp:
|
|
326
|
+
# vq_embeds_sum = vq_embeds_sum / self.config.num_codebooks
|
|
327
|
+
# vq_embeds_sum = self.codebook_mlp(vq_embeds_sum)
|
|
328
|
+
|
|
329
|
+
vq_masks = (inp[:, 0] >= self.tokenizer.semantic_begin_id) & (
|
|
330
|
+
inp[:, 0] <= self.tokenizer.semantic_end_id
|
|
331
|
+
)
|
|
332
|
+
|
|
333
|
+
vq_embeds_sum[~vq_masks] = 0
|
|
334
|
+
x = self.embeddings(inp[:, 0]) + vq_embeds_sum
|
|
335
|
+
|
|
336
|
+
if input_pos is None:
|
|
337
|
+
input_pos = torch.arange(inp.shape[-1], device=x.device)
|
|
338
|
+
max_seq_len = inp.shape[-1]
|
|
339
|
+
else:
|
|
340
|
+
max_seq_len = self.max_seq_len
|
|
341
|
+
|
|
342
|
+
mask = self.causal_mask[None, None, input_pos, :max_seq_len] # (B, N, Q, K)
|
|
284
343
|
freqs_cis = self.freqs_cis[input_pos]
|
|
285
344
|
|
|
286
345
|
for layer in self.layers:
|
|
@@ -293,7 +352,9 @@ class BaseTransformer(nn.Module):
|
|
|
293
352
|
# We got slow_out here
|
|
294
353
|
slow_out = self.norm(x)
|
|
295
354
|
|
|
296
|
-
if self.config.
|
|
355
|
+
if self.config.is_reward_model:
|
|
356
|
+
token_logits = self.score_output(slow_out)
|
|
357
|
+
elif self.config.tie_word_embeddings:
|
|
297
358
|
token_logits = F.linear(slow_out, self.embeddings.weight)
|
|
298
359
|
else:
|
|
299
360
|
token_logits = self.output(slow_out)
|
|
@@ -321,6 +382,7 @@ class BaseTransformer(nn.Module):
|
|
|
321
382
|
max_length: int | None = None,
|
|
322
383
|
lora_config: LoraConfig | None = None,
|
|
323
384
|
rope_base: int | None = None,
|
|
385
|
+
is_agent: bool = False,
|
|
324
386
|
) -> "BaseTransformer":
|
|
325
387
|
config = BaseModelArgs.from_pretrained(str(path))
|
|
326
388
|
if max_length is not None:
|
|
@@ -339,7 +401,12 @@ class BaseTransformer(nn.Module):
|
|
|
339
401
|
case _:
|
|
340
402
|
raise ValueError(f"Unknown model type: {config.model_type}")
|
|
341
403
|
|
|
342
|
-
|
|
404
|
+
if is_agent:
|
|
405
|
+
tokenizer = AutoTokenizer.from_pretrained(str(path))
|
|
406
|
+
else:
|
|
407
|
+
tokenizer_path = str(path) + "/tokenizer.tiktoken"
|
|
408
|
+
tokenizer = FishTokenizer(tokenizer_path)
|
|
409
|
+
|
|
343
410
|
log.info(f"Loading model from {path}, config: {config}")
|
|
344
411
|
model = model_cls(config, tokenizer=tokenizer)
|
|
345
412
|
|
|
@@ -369,7 +436,10 @@ class BaseTransformer(nn.Module):
|
|
|
369
436
|
model = simple_quantizer.convert_for_runtime()
|
|
370
437
|
|
|
371
438
|
weights = torch.load(
|
|
372
|
-
Path(path) / "model.pth",
|
|
439
|
+
Path(path) / "model.pth",
|
|
440
|
+
map_location="cpu",
|
|
441
|
+
mmap=True,
|
|
442
|
+
weights_only=True,
|
|
373
443
|
)
|
|
374
444
|
|
|
375
445
|
if "state_dict" in weights:
|
|
@@ -422,7 +492,7 @@ class BaseTransformer(nn.Module):
|
|
|
422
492
|
|
|
423
493
|
|
|
424
494
|
class NaiveTransformer(BaseTransformer):
|
|
425
|
-
def __init__(self, config: NaiveModelArgs, tokenizer:
|
|
495
|
+
def __init__(self, config: NaiveModelArgs, tokenizer: FishTokenizer) -> None:
|
|
426
496
|
super().__init__(config, init_weights=False, tokenizer=tokenizer)
|
|
427
497
|
|
|
428
498
|
self.codebook_norm = RMSNorm(config.dim, eps=config.norm_eps)
|
|
@@ -468,23 +538,49 @@ class NaiveTransformer(BaseTransformer):
|
|
|
468
538
|
|
|
469
539
|
|
|
470
540
|
class DualARTransformer(BaseTransformer):
|
|
471
|
-
def __init__(self, config: NaiveModelArgs, tokenizer:
|
|
541
|
+
def __init__(self, config: NaiveModelArgs, tokenizer: FishTokenizer) -> None:
|
|
472
542
|
super().__init__(config, init_weights=False, tokenizer=tokenizer)
|
|
473
543
|
|
|
544
|
+
# Project to fast dim if needed
|
|
545
|
+
if config.fast_dim is not None and config.fast_dim != config.dim:
|
|
546
|
+
self.fast_project_in = nn.Linear(config.dim, config.fast_dim)
|
|
547
|
+
else:
|
|
548
|
+
self.fast_project_in = nn.Identity()
|
|
549
|
+
|
|
474
550
|
# Fast transformer
|
|
475
|
-
self.fast_embeddings = nn.Embedding(config.codebook_size, config.
|
|
551
|
+
self.fast_embeddings = nn.Embedding(config.codebook_size, config.fast_dim)
|
|
476
552
|
|
|
477
553
|
# The equivalent bs is so large that sdpa doesn't work
|
|
554
|
+
override_config = dataclasses.replace(
|
|
555
|
+
config,
|
|
556
|
+
dim=config.fast_dim,
|
|
557
|
+
n_head=config.fast_n_head,
|
|
558
|
+
n_local_heads=config.fast_n_local_heads,
|
|
559
|
+
head_dim=config.fast_head_dim,
|
|
560
|
+
intermediate_size=config.fast_intermediate_size,
|
|
561
|
+
attention_qkv_bias=config.fast_attention_qkv_bias,
|
|
562
|
+
)
|
|
563
|
+
|
|
478
564
|
self.fast_layers = nn.ModuleList(
|
|
479
|
-
TransformerBlock(
|
|
565
|
+
TransformerBlock(override_config, use_sdpa=False)
|
|
566
|
+
for _ in range(config.n_fast_layer)
|
|
480
567
|
)
|
|
481
|
-
self.fast_norm = RMSNorm(config.
|
|
568
|
+
self.fast_norm = RMSNorm(config.fast_dim, eps=config.norm_eps)
|
|
482
569
|
self.fast_output = nn.Linear(
|
|
483
|
-
config.
|
|
570
|
+
config.fast_dim,
|
|
484
571
|
config.codebook_size,
|
|
485
572
|
bias=False,
|
|
486
573
|
)
|
|
487
574
|
|
|
575
|
+
self.register_buffer(
|
|
576
|
+
"fast_freqs_cis",
|
|
577
|
+
precompute_freqs_cis(
|
|
578
|
+
config.num_codebooks,
|
|
579
|
+
config.fast_dim // config.fast_n_head,
|
|
580
|
+
config.rope_base,
|
|
581
|
+
),
|
|
582
|
+
persistent=False,
|
|
583
|
+
)
|
|
488
584
|
self.apply(self._init_weights)
|
|
489
585
|
|
|
490
586
|
def setup_caches(
|
|
@@ -492,7 +588,7 @@ class DualARTransformer(BaseTransformer):
|
|
|
492
588
|
):
|
|
493
589
|
super().setup_caches(max_batch_size, max_seq_len, dtype)
|
|
494
590
|
|
|
495
|
-
head_dim = self.config.
|
|
591
|
+
head_dim = self.config.fast_dim // self.config.fast_n_head
|
|
496
592
|
|
|
497
593
|
# Fast transformer
|
|
498
594
|
# The max seq len here is the number of codebooks
|
|
@@ -500,7 +596,7 @@ class DualARTransformer(BaseTransformer):
|
|
|
500
596
|
b.attention.kv_cache = KVCache(
|
|
501
597
|
max_batch_size,
|
|
502
598
|
self.config.num_codebooks,
|
|
503
|
-
self.config.
|
|
599
|
+
self.config.fast_n_local_heads,
|
|
504
600
|
head_dim,
|
|
505
601
|
dtype=dtype,
|
|
506
602
|
)
|
|
@@ -513,13 +609,13 @@ class DualARTransformer(BaseTransformer):
|
|
|
513
609
|
parent_result = super().forward(inp, key_padding_mask)
|
|
514
610
|
token_logits = parent_result.logits
|
|
515
611
|
x = parent_result.hidden_states
|
|
612
|
+
x = self.fast_project_in(x)
|
|
516
613
|
|
|
517
614
|
# Fast transformer
|
|
518
615
|
fast_seq_len = self.config.num_codebooks
|
|
519
616
|
fast_mask = self.causal_mask[
|
|
520
617
|
None, None, :fast_seq_len, :fast_seq_len
|
|
521
618
|
] # (B, N, Q, K)
|
|
522
|
-
fast_freqs_cis = self.freqs_cis[:fast_seq_len]
|
|
523
619
|
|
|
524
620
|
# Drop the last token and rotate left
|
|
525
621
|
codebooks = inp[:, 1:-1, 1:]
|
|
@@ -542,9 +638,11 @@ class DualARTransformer(BaseTransformer):
|
|
|
542
638
|
|
|
543
639
|
for layer in self.fast_layers:
|
|
544
640
|
if self.config.use_gradient_checkpointing and self.training:
|
|
545
|
-
x = checkpoint(
|
|
641
|
+
x = checkpoint(
|
|
642
|
+
layer, x, self.fast_freqs_cis, fast_mask, use_reentrant=True
|
|
643
|
+
)
|
|
546
644
|
else:
|
|
547
|
-
x = layer(x, fast_freqs_cis, fast_mask)
|
|
645
|
+
x = layer(x, self.fast_freqs_cis, fast_mask)
|
|
548
646
|
|
|
549
647
|
# unflatten the batch and num_codebooks
|
|
550
648
|
fast_out = self.fast_norm(x)
|
|
@@ -584,7 +682,7 @@ class DualARTransformer(BaseTransformer):
|
|
|
584
682
|
fast_mask = self.causal_mask[
|
|
585
683
|
None, None, input_pos, : self.config.num_codebooks
|
|
586
684
|
] # (B, N, Q, K)
|
|
587
|
-
fast_freqs_cis = self.
|
|
685
|
+
fast_freqs_cis = self.fast_freqs_cis[input_pos]
|
|
588
686
|
|
|
589
687
|
for layer in self.fast_layers:
|
|
590
688
|
x = layer(x, fast_freqs_cis, fast_mask, input_pos=input_pos)
|
|
@@ -595,6 +693,16 @@ class DualARTransformer(BaseTransformer):
|
|
|
595
693
|
|
|
596
694
|
return codebook_logits
|
|
597
695
|
|
|
696
|
+
def forward_generate(
|
|
697
|
+
self,
|
|
698
|
+
x: Tensor,
|
|
699
|
+
input_pos: Optional[Tensor] = None,
|
|
700
|
+
vq_masks: Optional[Tensor] = None,
|
|
701
|
+
) -> TransformerForwardResult:
|
|
702
|
+
x = super().forward_generate(x, input_pos, vq_masks)
|
|
703
|
+
x.hidden_states = self.fast_project_in(x.hidden_states)
|
|
704
|
+
return x
|
|
705
|
+
|
|
598
706
|
|
|
599
707
|
class TransformerBlock(nn.Module):
|
|
600
708
|
def __init__(self, config: BaseModelArgs, use_sdpa: bool = True) -> None:
|
|
@@ -102,8 +102,8 @@ class FishConvNet(nn.Module):
|
|
|
102
102
|
self.conv = weight_norm(self.conv, name=name, dim=dim)
|
|
103
103
|
return self
|
|
104
104
|
|
|
105
|
-
def
|
|
106
|
-
self.conv = remove_parametrizations(self.conv)
|
|
105
|
+
def remove_parametrizations(self, name="weight"):
|
|
106
|
+
self.conv = remove_parametrizations(self.conv, name)
|
|
107
107
|
return self
|
|
108
108
|
|
|
109
109
|
|
|
@@ -128,8 +128,8 @@ class FishTransConvNet(nn.Module):
|
|
|
128
128
|
self.conv = weight_norm(self.conv, name=name, dim=dim)
|
|
129
129
|
return self
|
|
130
130
|
|
|
131
|
-
def
|
|
132
|
-
self.conv = remove_parametrizations(self.conv)
|
|
131
|
+
def remove_parametrizations(self, name="weight"):
|
|
132
|
+
self.conv = remove_parametrizations(self.conv, name)
|
|
133
133
|
return self
|
|
134
134
|
|
|
135
135
|
|
|
@@ -178,9 +178,9 @@ class ResBlock1(torch.nn.Module):
|
|
|
178
178
|
|
|
179
179
|
def remove_parametrizations(self):
|
|
180
180
|
for conv in self.convs1:
|
|
181
|
-
remove_parametrizations(
|
|
181
|
+
conv.remove_parametrizations()
|
|
182
182
|
for conv in self.convs2:
|
|
183
|
-
remove_parametrizations(
|
|
183
|
+
conv.remove_parametrizations()
|
|
184
184
|
|
|
185
185
|
|
|
186
186
|
class ParallelBlock(nn.Module):
|
|
@@ -288,11 +288,11 @@ class HiFiGANGenerator(nn.Module):
|
|
|
288
288
|
|
|
289
289
|
def remove_parametrizations(self):
|
|
290
290
|
for up in self.ups:
|
|
291
|
-
remove_parametrizations(
|
|
291
|
+
up.remove_parametrizations()
|
|
292
292
|
for block in self.resblocks:
|
|
293
293
|
block.remove_parametrizations()
|
|
294
|
-
|
|
295
|
-
|
|
294
|
+
self.conv_pre.remove_parametrizations()
|
|
295
|
+
self.conv_post.remove_parametrizations()
|
|
296
296
|
|
|
297
297
|
|
|
298
298
|
# DropPath copied from timm library
|
|
@@ -1,19 +1,8 @@
|
|
|
1
1
|
import re
|
|
2
2
|
|
|
3
3
|
SYMBOLS_MAPPING = {
|
|
4
|
-
"“": "'",
|
|
5
|
-
"”": "'",
|
|
6
4
|
"‘": "'",
|
|
7
5
|
"’": "'",
|
|
8
|
-
"【": "",
|
|
9
|
-
"】": "",
|
|
10
|
-
"[": "",
|
|
11
|
-
"]": "",
|
|
12
|
-
"(": "",
|
|
13
|
-
")": "",
|
|
14
|
-
"(": "",
|
|
15
|
-
")": "",
|
|
16
|
-
"・": "·",
|
|
17
6
|
}
|
|
18
7
|
|
|
19
8
|
REPLACE_SYMBOL_REGEX = re.compile(
|
|
@@ -21,6 +10,17 @@ REPLACE_SYMBOL_REGEX = re.compile(
|
|
|
21
10
|
)
|
|
22
11
|
|
|
23
12
|
|
|
13
|
+
EMOJI_REGEX = re.compile(
|
|
14
|
+
"["
|
|
15
|
+
"\U0001F600-\U0001F64F" # emoticons
|
|
16
|
+
"\U0001F300-\U0001F5FF" # symbols & pictographs
|
|
17
|
+
"\U0001F680-\U0001F6FF" # transport & map symbols
|
|
18
|
+
"\U0001F1E0-\U0001F1FF" # flags (iOS)
|
|
19
|
+
"]+",
|
|
20
|
+
flags=re.UNICODE,
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
|
|
24
24
|
def clean_text(text):
|
|
25
25
|
# Clean the text
|
|
26
26
|
text = text.strip()
|
|
@@ -28,4 +28,10 @@ def clean_text(text):
|
|
|
28
28
|
# Replace all chinese symbols with their english counterparts
|
|
29
29
|
text = REPLACE_SYMBOL_REGEX.sub(lambda x: SYMBOLS_MAPPING[x.group()], text)
|
|
30
30
|
|
|
31
|
+
# Remove emojis
|
|
32
|
+
text = EMOJI_REGEX.sub(r"", text)
|
|
33
|
+
|
|
34
|
+
# Remove continuous periods (...) and commas (,,,)
|
|
35
|
+
text = re.sub(r"[,]{2,}", lambda m: m.group()[0], text)
|
|
36
|
+
|
|
31
37
|
return text
|
|
@@ -0,0 +1,152 @@
|
|
|
1
|
+
import base64
|
|
2
|
+
import json
|
|
3
|
+
import logging
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
|
|
6
|
+
import tiktoken
|
|
7
|
+
|
|
8
|
+
logger = logging.getLogger(__name__)
|
|
9
|
+
|
|
10
|
+
# This is a modified version of the default pattern from GPT-4o, that better handles punctuations.
|
|
11
|
+
FISH_TIKTOKEN_PATTERN = "|".join(
|
|
12
|
+
[
|
|
13
|
+
r"(?i:'s|'t|'re|'ve|'m|'ll|'d)",
|
|
14
|
+
r"\p{P}",
|
|
15
|
+
r"[^\r\n\p{L}\p{N}]?\p{L}+",
|
|
16
|
+
r"\p{N}",
|
|
17
|
+
r" ?[^\s\p{L}\p{N}]+[\r\n]*",
|
|
18
|
+
r"\s*[\r\n]+",
|
|
19
|
+
r"\s+(\?!\S)",
|
|
20
|
+
r"\s+",
|
|
21
|
+
]
|
|
22
|
+
)
|
|
23
|
+
TIKTOKEN_MAX_ENCODE_CHARS = 400_000
|
|
24
|
+
|
|
25
|
+
BOS_TOKEN = "<|begin_of_text|>"
|
|
26
|
+
EOS_TOKEN = "<|end_of_text|>"
|
|
27
|
+
PAD_TOKEN = "<|pad|>"
|
|
28
|
+
IM_START_TOKEN = "<|im_start|>"
|
|
29
|
+
IM_END_TOKEN = "<|im_end|>"
|
|
30
|
+
|
|
31
|
+
MODALITY_TEXT_TOKEN = "<|text|>"
|
|
32
|
+
MODALITY_VOICE_TOKEN = "<|voice|>"
|
|
33
|
+
MODALITY_INTERLEAVE_TOKEN = "<|interleave|>"
|
|
34
|
+
MODALITY_TOKENS = {
|
|
35
|
+
"text": MODALITY_TEXT_TOKEN,
|
|
36
|
+
"voice": MODALITY_VOICE_TOKEN,
|
|
37
|
+
"interleave": MODALITY_INTERLEAVE_TOKEN,
|
|
38
|
+
}
|
|
39
|
+
|
|
40
|
+
PLACEHOLDER_TOKEN = [""] * 4
|
|
41
|
+
for i in range(4):
|
|
42
|
+
PLACEHOLDER_TOKEN[i] = f"<|placeholder:{i}|>"
|
|
43
|
+
|
|
44
|
+
SEMANTIC_TOKEN_TEMPLATE = "<|semantic:{i}|>"
|
|
45
|
+
SEMANTIC_TOKENS = [SEMANTIC_TOKEN_TEMPLATE.format(i=i) for i in range(1024)]
|
|
46
|
+
|
|
47
|
+
# Warning: when you add a new special token, you should only add it to the end of the list.
|
|
48
|
+
ALL_SPECIAL_TOKENS = [
|
|
49
|
+
BOS_TOKEN,
|
|
50
|
+
EOS_TOKEN,
|
|
51
|
+
PAD_TOKEN,
|
|
52
|
+
IM_START_TOKEN,
|
|
53
|
+
IM_END_TOKEN,
|
|
54
|
+
PLACEHOLDER_TOKEN[0],
|
|
55
|
+
PLACEHOLDER_TOKEN[1],
|
|
56
|
+
PLACEHOLDER_TOKEN[2],
|
|
57
|
+
PLACEHOLDER_TOKEN[3],
|
|
58
|
+
MODALITY_TEXT_TOKEN,
|
|
59
|
+
MODALITY_VOICE_TOKEN,
|
|
60
|
+
MODALITY_INTERLEAVE_TOKEN,
|
|
61
|
+
*SEMANTIC_TOKENS,
|
|
62
|
+
]
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class FishTokenizer:
|
|
66
|
+
def __init__(self, model_path: str) -> None:
|
|
67
|
+
mergeable_ranks = self.load_tiktoken_bpe(model_path)
|
|
68
|
+
special_token_begin = len(mergeable_ranks)
|
|
69
|
+
self.all_special_tokens_with_ids = {
|
|
70
|
+
token: special_token_begin + i for i, token in enumerate(ALL_SPECIAL_TOKENS)
|
|
71
|
+
}
|
|
72
|
+
self.semantic_id_to_token_id = {
|
|
73
|
+
i: self.all_special_tokens_with_ids[token]
|
|
74
|
+
for i, token in enumerate(SEMANTIC_TOKENS)
|
|
75
|
+
}
|
|
76
|
+
self.semantic_begin_id = self.all_special_tokens_with_ids[SEMANTIC_TOKENS[0]]
|
|
77
|
+
self.semantic_end_id = self.all_special_tokens_with_ids[SEMANTIC_TOKENS[-1]]
|
|
78
|
+
|
|
79
|
+
self.tkt_model = tiktoken.core.Encoding(
|
|
80
|
+
name=Path(model_path).stem,
|
|
81
|
+
pat_str=FISH_TIKTOKEN_PATTERN,
|
|
82
|
+
mergeable_ranks=mergeable_ranks,
|
|
83
|
+
special_tokens=self.all_special_tokens_with_ids,
|
|
84
|
+
)
|
|
85
|
+
|
|
86
|
+
@staticmethod
|
|
87
|
+
def load_tiktoken_bpe(tiktoken_bpe_file: str) -> dict[bytes, int]:
|
|
88
|
+
data = {}
|
|
89
|
+
for line in open(tiktoken_bpe_file).read().splitlines():
|
|
90
|
+
if not line:
|
|
91
|
+
continue
|
|
92
|
+
token, rank = line.split()
|
|
93
|
+
data[base64.b64decode(token)] = int(rank)
|
|
94
|
+
return data
|
|
95
|
+
|
|
96
|
+
def get_token_id(self, token: str) -> int:
|
|
97
|
+
return self.all_special_tokens_with_ids[token]
|
|
98
|
+
|
|
99
|
+
def encode(self, s: str, allowed_special: bool | set[str] = True) -> list[int]:
|
|
100
|
+
assert isinstance(s, str)
|
|
101
|
+
|
|
102
|
+
subs = []
|
|
103
|
+
for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS):
|
|
104
|
+
subs.append(s[i : i + TIKTOKEN_MAX_ENCODE_CHARS])
|
|
105
|
+
|
|
106
|
+
if allowed_special is True:
|
|
107
|
+
allowed_special = self.tkt_model.special_tokens_set
|
|
108
|
+
elif allowed_special is False:
|
|
109
|
+
allowed_special = set()
|
|
110
|
+
|
|
111
|
+
return sum(
|
|
112
|
+
self.tkt_model.encode_batch(
|
|
113
|
+
subs, allowed_special=allowed_special, disallowed_special=set()
|
|
114
|
+
),
|
|
115
|
+
start=[],
|
|
116
|
+
)
|
|
117
|
+
|
|
118
|
+
def decode(self, tokens: list[int]) -> str:
|
|
119
|
+
return self.tkt_model.decode(tokens)
|
|
120
|
+
|
|
121
|
+
def save_pretrained(self, path: str):
|
|
122
|
+
path = Path(path)
|
|
123
|
+
path.mkdir(parents=True, exist_ok=True)
|
|
124
|
+
|
|
125
|
+
with open(path / "tokenizer.tiktoken", "w") as f:
|
|
126
|
+
for token, rank in self.tkt_model._mergeable_ranks.items():
|
|
127
|
+
f.write(f"{base64.b64encode(token).decode()} {rank}\n")
|
|
128
|
+
|
|
129
|
+
with open(path / "special_tokens.json", "w") as f:
|
|
130
|
+
json.dump(
|
|
131
|
+
self.all_special_tokens_with_ids,
|
|
132
|
+
f,
|
|
133
|
+
indent=2,
|
|
134
|
+
ensure_ascii=False,
|
|
135
|
+
)
|
|
136
|
+
|
|
137
|
+
@staticmethod
|
|
138
|
+
def from_pretrained(path: str):
|
|
139
|
+
return FishTokenizer(Path(path) / "tokenizer.tiktoken")
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
if __name__ == "__main__":
|
|
143
|
+
tokenizer = FishTokenizer("data/mpacks/v1.4-pretrain/tokenizer.all.tiktoken")
|
|
144
|
+
tokenizer.save_pretrained("checkpoints/fish-speech-0.5B")
|
|
145
|
+
tokenizer = FishTokenizer.from_pretrained("checkpoints/fish-speech-0.5B")
|
|
146
|
+
|
|
147
|
+
print(
|
|
148
|
+
[
|
|
149
|
+
tokenizer.decode([i])
|
|
150
|
+
for i in tokenizer.encode(f"{BOS_TOKEN}你好,世界!{EOS_TOKEN}")
|
|
151
|
+
]
|
|
152
|
+
)
|
|
@@ -6,7 +6,7 @@ from typing import Optional
|
|
|
6
6
|
|
|
7
7
|
import hydra
|
|
8
8
|
import lightning as L
|
|
9
|
-
|
|
9
|
+
import pyrootutils
|
|
10
10
|
import torch
|
|
11
11
|
from lightning import Callback, LightningDataModule, LightningModule, Trainer
|
|
12
12
|
from lightning.pytorch.loggers import Logger
|
|
@@ -18,7 +18,7 @@ os.environ.pop("SLURM_JOB_NAME", None)
|
|
|
18
18
|
os.environ.pop("SLURM_NTASKS_PER_NODE", None)
|
|
19
19
|
|
|
20
20
|
# register eval resolver and root
|
|
21
|
-
|
|
21
|
+
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
|
22
22
|
|
|
23
23
|
# Allow TF32 on Ampere GPUs
|
|
24
24
|
torch.set_float32_matmul_precision("high")
|
|
@@ -5,7 +5,7 @@ from .instantiators import instantiate_callbacks, instantiate_loggers
|
|
|
5
5
|
from .logger import RankedLogger
|
|
6
6
|
from .logging_utils import log_hyperparameters
|
|
7
7
|
from .rich_utils import enforce_tags, print_config_tree
|
|
8
|
-
from .utils import extras, get_metric_value, task_wrapper
|
|
8
|
+
from .utils import extras, get_metric_value, set_seed, task_wrapper
|
|
9
9
|
|
|
10
10
|
__all__ = [
|
|
11
11
|
"enforce_tags",
|
|
@@ -20,4 +20,5 @@ __all__ = [
|
|
|
20
20
|
"braceexpand",
|
|
21
21
|
"get_latest_checkpoint",
|
|
22
22
|
"autocast_exclude_mps",
|
|
23
|
+
"set_seed",
|
|
23
24
|
]
|
|
@@ -1,7 +1,10 @@
|
|
|
1
|
+
import random
|
|
1
2
|
import warnings
|
|
2
3
|
from importlib.util import find_spec
|
|
3
4
|
from typing import Callable
|
|
4
5
|
|
|
6
|
+
import numpy as np
|
|
7
|
+
import torch
|
|
5
8
|
from omegaconf import DictConfig
|
|
6
9
|
|
|
7
10
|
from .logger import RankedLogger
|
|
@@ -112,3 +115,22 @@ def get_metric_value(metric_dict: dict, metric_name: str) -> float:
|
|
|
112
115
|
log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")
|
|
113
116
|
|
|
114
117
|
return metric_value
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def set_seed(seed: int):
|
|
121
|
+
if seed < 0:
|
|
122
|
+
seed = -seed
|
|
123
|
+
if seed > (1 << 31):
|
|
124
|
+
seed = 1 << 31
|
|
125
|
+
|
|
126
|
+
random.seed(seed)
|
|
127
|
+
np.random.seed(seed)
|
|
128
|
+
torch.manual_seed(seed)
|
|
129
|
+
|
|
130
|
+
if torch.cuda.is_available():
|
|
131
|
+
torch.cuda.manual_seed(seed)
|
|
132
|
+
torch.cuda.manual_seed_all(seed)
|
|
133
|
+
|
|
134
|
+
if torch.backends.cudnn.is_available():
|
|
135
|
+
torch.backends.cudnn.deterministic = True
|
|
136
|
+
torch.backends.cudnn.benchmark = False
|
|
@@ -114,7 +114,7 @@ class Seafoam(Base):
|
|
|
114
114
|
block_title_text_weight="600",
|
|
115
115
|
block_border_width="3px",
|
|
116
116
|
block_shadow="*shadow_drop_lg",
|
|
117
|
-
button_shadow="*shadow_drop_lg",
|
|
117
|
+
# button_shadow="*shadow_drop_lg",
|
|
118
118
|
button_small_padding="0px",
|
|
119
119
|
button_large_padding="3px",
|
|
120
120
|
)
|
|
@@ -176,7 +176,7 @@ def change_infer(
|
|
|
176
176
|
p_infer = subprocess.Popen(
|
|
177
177
|
[
|
|
178
178
|
PYTHON,
|
|
179
|
-
"tools/
|
|
179
|
+
"tools/run_webui.py",
|
|
180
180
|
"--decoder-checkpoint-path",
|
|
181
181
|
infer_decoder_model,
|
|
182
182
|
"--decoder-config-name",
|
|
@@ -794,7 +794,7 @@ with gr.Blocks(
|
|
|
794
794
|
value="VQGAN",
|
|
795
795
|
)
|
|
796
796
|
with gr.Row():
|
|
797
|
-
with gr.
|
|
797
|
+
with gr.Column():
|
|
798
798
|
with gr.Tab(label=i18n("VQGAN Configuration")) as vqgan_page:
|
|
799
799
|
gr.HTML("You don't need to train this model!")
|
|
800
800
|
|