xinference 1.1.0__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic. Click here for more details.
- xinference/_compat.py +2 -0
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +72 -66
- xinference/core/model.py +78 -25
- xinference/core/supervisor.py +81 -10
- xinference/core/utils.py +12 -8
- xinference/core/worker.py +32 -0
- xinference/model/audio/core.py +5 -0
- xinference/model/audio/cosyvoice.py +25 -3
- xinference/model/audio/f5tts.py +15 -10
- xinference/model/audio/f5tts_mlx.py +260 -0
- xinference/model/audio/fish_speech.py +35 -111
- xinference/model/audio/model_spec.json +19 -3
- xinference/model/audio/model_spec_modelscope.json +9 -0
- xinference/model/audio/utils.py +32 -0
- xinference/model/image/core.py +69 -1
- xinference/model/image/model_spec.json +145 -4
- xinference/model/image/model_spec_modelscope.json +150 -4
- xinference/model/image/stable_diffusion/core.py +45 -13
- xinference/model/llm/__init__.py +2 -0
- xinference/model/llm/llm_family.json +143 -0
- xinference/model/llm/llm_family.py +15 -36
- xinference/model/llm/llm_family_modelscope.json +148 -0
- xinference/model/llm/mlx/core.py +37 -32
- xinference/model/llm/transformers/cogagent.py +272 -0
- xinference/model/llm/transformers/core.py +2 -0
- xinference/model/llm/transformers/qwen2_vl.py +12 -1
- xinference/model/llm/utils.py +28 -3
- xinference/model/llm/vllm/core.py +48 -9
- xinference/model/llm/vllm/xavier/__init__.py +13 -0
- xinference/model/llm/vllm/xavier/allocator.py +74 -0
- xinference/model/llm/vllm/xavier/block.py +112 -0
- xinference/model/llm/vllm/xavier/block_manager.py +71 -0
- xinference/model/llm/vllm/xavier/block_tracker.py +116 -0
- xinference/model/llm/vllm/xavier/engine.py +247 -0
- xinference/model/llm/vllm/xavier/executor.py +132 -0
- xinference/model/llm/vllm/xavier/scheduler.py +422 -0
- xinference/model/llm/vllm/xavier/test/__init__.py +13 -0
- xinference/model/llm/vllm/xavier/test/test_xavier.py +122 -0
- xinference/model/llm/vllm/xavier/transfer.py +298 -0
- xinference/model/video/diffusers.py +14 -0
- xinference/model/video/model_spec.json +15 -0
- xinference/model/video/model_spec_modelscope.json +16 -0
- xinference/thirdparty/cosyvoice/bin/average_model.py +92 -0
- xinference/thirdparty/cosyvoice/bin/export_jit.py +12 -2
- xinference/thirdparty/cosyvoice/bin/export_onnx.py +112 -0
- xinference/thirdparty/cosyvoice/bin/export_trt.sh +9 -0
- xinference/thirdparty/cosyvoice/bin/inference.py +5 -7
- xinference/thirdparty/cosyvoice/bin/train.py +42 -8
- xinference/thirdparty/cosyvoice/cli/cosyvoice.py +96 -25
- xinference/thirdparty/cosyvoice/cli/frontend.py +77 -30
- xinference/thirdparty/cosyvoice/cli/model.py +330 -80
- xinference/thirdparty/cosyvoice/dataset/dataset.py +6 -2
- xinference/thirdparty/cosyvoice/dataset/processor.py +76 -14
- xinference/thirdparty/cosyvoice/flow/decoder.py +92 -13
- xinference/thirdparty/cosyvoice/flow/flow.py +99 -9
- xinference/thirdparty/cosyvoice/flow/flow_matching.py +110 -13
- xinference/thirdparty/cosyvoice/flow/length_regulator.py +5 -4
- xinference/thirdparty/cosyvoice/hifigan/discriminator.py +140 -0
- xinference/thirdparty/cosyvoice/hifigan/generator.py +58 -42
- xinference/thirdparty/cosyvoice/hifigan/hifigan.py +67 -0
- xinference/thirdparty/cosyvoice/llm/llm.py +139 -6
- xinference/thirdparty/cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
- xinference/thirdparty/cosyvoice/tokenizer/tokenizer.py +279 -0
- xinference/thirdparty/cosyvoice/transformer/embedding.py +2 -2
- xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +7 -7
- xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +318 -0
- xinference/thirdparty/cosyvoice/utils/common.py +28 -1
- xinference/thirdparty/cosyvoice/utils/executor.py +69 -7
- xinference/thirdparty/cosyvoice/utils/file_utils.py +2 -12
- xinference/thirdparty/cosyvoice/utils/frontend_utils.py +9 -5
- xinference/thirdparty/cosyvoice/utils/losses.py +20 -0
- xinference/thirdparty/cosyvoice/utils/scheduler.py +1 -2
- xinference/thirdparty/cosyvoice/utils/train_utils.py +101 -45
- xinference/thirdparty/fish_speech/fish_speech/conversation.py +94 -83
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +63 -20
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +1 -26
- xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +1 -1
- xinference/thirdparty/fish_speech/fish_speech/tokenizer.py +152 -0
- xinference/thirdparty/fish_speech/fish_speech/train.py +2 -2
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1 -1
- xinference/thirdparty/fish_speech/tools/{post_api.py → api_client.py} +7 -13
- xinference/thirdparty/fish_speech/tools/api_server.py +98 -0
- xinference/thirdparty/fish_speech/tools/download_models.py +5 -5
- xinference/thirdparty/fish_speech/tools/fish_e2e.py +2 -2
- xinference/thirdparty/fish_speech/tools/inference_engine/__init__.py +192 -0
- xinference/thirdparty/fish_speech/tools/inference_engine/reference_loader.py +125 -0
- xinference/thirdparty/fish_speech/tools/inference_engine/utils.py +39 -0
- xinference/thirdparty/fish_speech/tools/inference_engine/vq_manager.py +57 -0
- xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +2 -2
- xinference/thirdparty/fish_speech/tools/llama/generate.py +117 -89
- xinference/thirdparty/fish_speech/tools/run_webui.py +104 -0
- xinference/thirdparty/fish_speech/tools/schema.py +11 -28
- xinference/thirdparty/fish_speech/tools/server/agent/__init__.py +57 -0
- xinference/thirdparty/fish_speech/tools/server/agent/generate.py +119 -0
- xinference/thirdparty/fish_speech/tools/server/agent/generation_utils.py +122 -0
- xinference/thirdparty/fish_speech/tools/server/agent/pre_generation_utils.py +72 -0
- xinference/thirdparty/fish_speech/tools/server/api_utils.py +75 -0
- xinference/thirdparty/fish_speech/tools/server/exception_handler.py +27 -0
- xinference/thirdparty/fish_speech/tools/server/inference.py +45 -0
- xinference/thirdparty/fish_speech/tools/server/model_manager.py +122 -0
- xinference/thirdparty/fish_speech/tools/server/model_utils.py +129 -0
- xinference/thirdparty/fish_speech/tools/server/views.py +246 -0
- xinference/thirdparty/fish_speech/tools/webui/__init__.py +173 -0
- xinference/thirdparty/fish_speech/tools/webui/inference.py +91 -0
- xinference/thirdparty/fish_speech/tools/webui/variables.py +14 -0
- xinference/thirdparty/matcha/utils/utils.py +2 -2
- xinference/types.py +13 -0
- xinference/web/ui/build/asset-manifest.json +6 -6
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/css/main.51a587ff.css +2 -0
- xinference/web/ui/build/static/css/main.51a587ff.css.map +1 -0
- xinference/web/ui/build/static/js/main.1eb206d1.js +3 -0
- xinference/web/ui/build/static/js/main.1eb206d1.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/03c4052f1b91f6ba0c5389bdcf49c43319b4076c08e4b8585dab312538ae290a.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/1786b83003b8e9605a0f5f855a185d4d16e38fc893dfb326a2a9cca206b4240a.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/17cbc181dd674b9150b80c73ed6a82656de0082d857f6e5f66d9716129ac0b38.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/185ceb8872d562e032b47e79df6a45670e06345b8ed70aad1a131e0476783c5c.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/2213d49de260e1f67c888081b18f120f5225462b829ae57c9e05a05cec83689d.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/26b8c9f34b0bed789b3a833767672e39302d1e0c09b4276f4d58d1df7b6bd93b.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/2b484da66c724d0d56a40849c109327408796a668b1381511b6e9e03baa48658.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/2cbbbce9b84df73330d4c42b82436ed881b3847628f2fbc346aa62e2859fd88c.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/2ec9b14431ed33ce6901bf9f27007be4e6e472709c99d6e22b50ce528e4b78ee.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/3b966db018f96be4a055d6ca205f0990d4d0b370e2980c17d8bca2c9a021819c.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/3eefb411b24c2b3ce053570ef50daccf154022f0e168be5ed0fec21394baf9f4.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/522b229e3cac219123f0d69673f5570e191c2d2a505dc65b312d336eae2279c0.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/52e45f17ba300580ea3fcc9f9228ccba194bb092b76f25e9255af311f8b05aab.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/5a0bc4631f936459afc1a3b1d3ec2420118b1f00e11f60ccac3e08088f3f27a8.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/611fa2c6c53b66039991d06dfb0473b5ab37fc63b4564e0f6e1718523768a045.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/6329bc76c406fe5eb305412383fbde5950f847bb5e43261f73f37622c365acb4.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/63c8e07687ea53a4f8a910ee5e42e0eb26cd1acbfbe820f3e3248a786ee51401.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/69b2d5001684174ec9da57e07914eed3eac4960018bceb6cbfa801d861301d7c.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/710c1acda69e561e30a933b98c6a56d50197868b15c21e2aad55ab6d46649eb6.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/720deca1fce5a1dc5056048fa8258fd138a82ea855f350b6613f104a73fb761f.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/76a23b92d26a499c57e61eea2b895fbc9771bd0849a72e66f8e633192017978b.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/858063f23b34dfe600254eb5afd85518b0002ec4b30b7386616c45600826e3b2.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/920b82c1c89124cf217109eeedbfcd3aae3b917be50c9dfb6bbb4ce26bdfd2e7.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/94d8b7aeb0076f2ce07db598cea0e87b13bc8d5614eb530b8d6e696c2daf6f88.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/9e917fe7022d01b2ccbe5cc0ce73d70bb72bee584ff293bad71bdff6695dee28.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/9f28fdb8399f1d0474f0aca86f1658dc94f5bf0c90f6146352de150692de8862.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/a0dfafa06b2bb7cba8cad41c482503f61944f759f4318139362602ef5cc47ccb.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/afb8084f539534cd594755ea2205ecd5bd1f62dddcfdf75a2eace59a28131278.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/b57b1438b77294c1f3f6cfce12ac487d8106c6f016975ba0aec94d98997e2e1e.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/b9917b0bf8e4d55ccbac1c334aa04d6ff3c5b6ed9e5d38b9ea2c687fa7d3f5a9.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/bbcc94b0149963d1d6f267ee1f4f03d3925b758392ce2f516c3fe8af0e0169fc.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/bdee44abeadc4abc17d41c52eb49c6e19a4b1a267b6e16876ce91bdeeebfc52d.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/beb112b70f4a56db95920a9e20efb6c97c37b68450716730217a9ee1a9ae92be.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/c88db97be0cdf440193b3995996e83510a04cb00048135485fc0e26d197e80b5.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/d49e5314d34310a62d01a03067ce1bec5da00abce84c5196aa9c6842fa79a430.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/d7664d18c4ddbad9c3a6a31b91f7c00fb0dde804608674a9860ee50f33e54708.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/d9072c318b819b7c90a0f7e9cc0b6413b4dbeb8e9859898e53d75ea882fcde99.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/db16a983bc08a05f0439cc61ca0840e49e1d8400eef678909f16c032a418a3d6.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/dc249829767b8abcbc3677e0b07b6d3ecbfdfe6d08cfe23a665eb33373a9aa9d.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/e242c583c2dbc2784f0fcf513523975f7d5df447e106c1c17e49e8578a6fc3ed.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/eac5f1296513e69e4b96f750ddccd4d0264e2bae4e4c449144e83274a48698d9.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/ed57202cb79649bb716400436590245547df241988fc7c8e1d85d132299542d2.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f125bf72e773a14cdaebd0c343e80adb909d12e317ee5c00cd4a57442fbe2c62.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f91af913d7f91c410719ab13136aaed3aaf0f8dda06652f25c42cb5231587398.json +1 -0
- xinference/web/ui/node_modules/.package-lock.json +67 -3
- xinference/web/ui/node_modules/@babel/runtime/package.json +592 -538
- xinference/web/ui/node_modules/html-parse-stringify/package.json +50 -0
- xinference/web/ui/node_modules/i18next/dist/esm/package.json +1 -0
- xinference/web/ui/node_modules/i18next/package.json +129 -0
- xinference/web/ui/node_modules/react-i18next/.eslintrc.json +74 -0
- xinference/web/ui/node_modules/react-i18next/dist/es/package.json +1 -0
- xinference/web/ui/node_modules/react-i18next/package.json +162 -0
- xinference/web/ui/node_modules/void-elements/package.json +34 -0
- xinference/web/ui/package-lock.json +69 -3
- xinference/web/ui/package.json +2 -0
- xinference/web/ui/src/locales/en.json +186 -0
- xinference/web/ui/src/locales/zh.json +186 -0
- {xinference-1.1.0.dist-info → xinference-1.2.0.dist-info}/METADATA +19 -11
- {xinference-1.1.0.dist-info → xinference-1.2.0.dist-info}/RECORD +178 -111
- xinference/thirdparty/cosyvoice/bin/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/bin/export_trt.py +0 -8
- xinference/thirdparty/cosyvoice/flow/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/hifigan/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/llm/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/api.py +0 -943
- xinference/thirdparty/fish_speech/tools/msgpack_api.py +0 -95
- xinference/thirdparty/fish_speech/tools/webui.py +0 -548
- xinference/web/ui/build/static/css/main.5061c4c3.css +0 -2
- xinference/web/ui/build/static/css/main.5061c4c3.css.map +0 -1
- xinference/web/ui/build/static/js/main.4eb4ee80.js +0 -3
- xinference/web/ui/build/static/js/main.4eb4ee80.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/07ce9e632e6aff24d7aa3ad8e48224433bbfeb0d633fca723453f1fcae0c9f1c.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/1130403f9e46f5738a23b45ac59b57de8f360c908c713e2c0670c2cce9bd367a.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/131091b25d26b17cdca187d7542a21475c211138d900cf667682260e76ef9463.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/1f269fb2a368363c1cb2237825f1dba093b6bdd8c44cc05954fd19ec2c1fff03.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/331312668fa8bd3d7401818f4a25fa98135d7f61371cd6bfff78b18cf4fbdd92.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/40f17338fc75ae095de7d2b4d8eae0d5ca0193a7e2bcece4ee745b22a7a2f4b7.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/4de9a6942c5f1749d6cbfdd54279699975f16016b182848bc253886f52ec2ec3.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/822586ed1077201b64b954f12f25e3f9b45678c1acbabe53d8af3ca82ca71f33.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/8c5eeb02f772d02cbe8b89c05428d0dd41a97866f75f7dc1c2164a67f5a1cf98.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/8d33354bd2100c8602afc3341f131a88cc36aaeecd5a4b365ed038514708e350.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/9375a35b05d56989b2755bf72161fa707c92f28569d33765a75f91a568fda6e9.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/a158a9ffa0c9b169aee53dd4a0c44501a596755b4e4f6ede7746d65a72e2a71f.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/c7bf40bab396765f67d0fed627ed3665890608b2d0edaa3e8cb7cfc96310db45.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/d6c643278a0b28320e6f33a60f5fb64c053997cbdc39a60e53ccc574688ade9e.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/e42b72d4cc1ea412ebecbb8d040dc6c6bfee462c33903c2f1f3facb602ad742e.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/e64b7e8cedcf43d4c95deba60ec1341855c887705805bb62431693118b870c69.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/f5039ddbeb815c51491a1989532006b96fc3ae49c6c60e3c097f875b4ae915ae.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/f72f011744c4649fabddca6f7a9327861ac0a315a89b1a2e62a39774e7863845.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/feabb04b4aa507102da0a64398a40818e878fd1df9b75dda8461b3e1e7ff3f11.json +0 -1
- /xinference/web/ui/build/static/js/{main.4eb4ee80.js.LICENSE.txt → main.1eb206d1.js.LICENSE.txt} +0 -0
- {xinference-1.1.0.dist-info → xinference-1.2.0.dist-info}/LICENSE +0 -0
- {xinference-1.1.0.dist-info → xinference-1.2.0.dist-info}/WHEEL +0 -0
- {xinference-1.1.0.dist-info → xinference-1.2.0.dist-info}/entry_points.txt +0 -0
- {xinference-1.1.0.dist-info → xinference-1.2.0.dist-info}/top_level.txt +0 -0
xinference/model/llm/mlx/core.py
CHANGED
|
@@ -173,7 +173,9 @@ class MLXModel(LLM):
|
|
|
173
173
|
return False
|
|
174
174
|
return True
|
|
175
175
|
|
|
176
|
-
def _get_prompt_cache(
|
|
176
|
+
def _get_prompt_cache(
|
|
177
|
+
self, prompt, lora_name: Optional[str] = None, model: Any = None
|
|
178
|
+
):
|
|
177
179
|
from mlx_lm.models.cache import make_prompt_cache
|
|
178
180
|
|
|
179
181
|
assert self._prompt_cache is not None
|
|
@@ -185,7 +187,9 @@ class MLXModel(LLM):
|
|
|
185
187
|
or self._prompt_cache.tokens != prompt[:cache_len]
|
|
186
188
|
):
|
|
187
189
|
self._prompt_cache.model_key = model_key
|
|
188
|
-
self._prompt_cache.cache = make_prompt_cache(
|
|
190
|
+
self._prompt_cache.cache = make_prompt_cache(
|
|
191
|
+
model or self._model, self._max_kv_size
|
|
192
|
+
)
|
|
189
193
|
self._prompt_cache.tokens = []
|
|
190
194
|
logger.debug("Making new prompt cache for %s", self.model_uid)
|
|
191
195
|
else:
|
|
@@ -458,6 +462,8 @@ class MLXVisionModel(MLXModel, ChatModelMixin):
|
|
|
458
462
|
|
|
459
463
|
raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
|
|
460
464
|
|
|
465
|
+
self._prompt_cache = PromptCache()
|
|
466
|
+
|
|
461
467
|
return load(self.model_path)
|
|
462
468
|
|
|
463
469
|
def load(self):
|
|
@@ -476,23 +482,10 @@ class MLXVisionModel(MLXModel, ChatModelMixin):
|
|
|
476
482
|
from mlx_lm.utils import GenerationResponse
|
|
477
483
|
from mlx_vlm.utils import generate_step
|
|
478
484
|
|
|
479
|
-
max_tokens = kwargs.pop("max_tokens")
|
|
480
485
|
inputs = kwargs["prompt_token_ids"]
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
kwargs =
|
|
484
|
-
k: v
|
|
485
|
-
for k, v in zip(
|
|
486
|
-
[
|
|
487
|
-
"image_grid_thw",
|
|
488
|
-
"image_sizes",
|
|
489
|
-
"aspect_ratio_ids",
|
|
490
|
-
"aspect_ratio_mask",
|
|
491
|
-
"cross_attention_mask",
|
|
492
|
-
],
|
|
493
|
-
inputs[3:],
|
|
494
|
-
)
|
|
495
|
-
}
|
|
486
|
+
|
|
487
|
+
max_tokens = kwargs.pop("max_tokens")
|
|
488
|
+
input_ids, pixel_values, mask, kwargs = inputs
|
|
496
489
|
|
|
497
490
|
tokenizer = self._processor.tokenizer
|
|
498
491
|
detokenizer = self._processor.detokenizer
|
|
@@ -538,27 +531,39 @@ class MLXVisionModel(MLXModel, ChatModelMixin):
|
|
|
538
531
|
def _prepare_inputs(
|
|
539
532
|
self, prompt: Union[str, Dict[str, Any]], kwargs
|
|
540
533
|
) -> Tuple[Any, int]:
|
|
534
|
+
import mlx.core as mx
|
|
541
535
|
from mlx_vlm import prepare_inputs
|
|
542
536
|
|
|
543
537
|
prompt_str = prompt.get("prompt") # type: ignore
|
|
544
538
|
images = prompt.get("multi_modal_data", {}).get("image") # type: ignore
|
|
545
539
|
if images and not isinstance(images, list):
|
|
546
540
|
images = [images]
|
|
547
|
-
|
|
548
|
-
|
|
541
|
+
resize_shape = kwargs.pop("resize_shape", None)
|
|
542
|
+
image_token_index = getattr(self._model.config, "image_token_index", None)
|
|
543
|
+
|
|
544
|
+
processor = self._processor
|
|
545
|
+
tokenizer = processor if hasattr(processor, "encode") else processor.tokenizer
|
|
546
|
+
prompt_tokens = mx.array(tokenizer.encode(prompt_str))
|
|
547
|
+
|
|
548
|
+
if not images:
|
|
549
|
+
input_ids = prompt_tokens[None, :]
|
|
550
|
+
pixel_values = mask = None
|
|
551
|
+
kwargs = {}
|
|
552
|
+
input_token_len = input_ids.size
|
|
549
553
|
else:
|
|
550
|
-
|
|
551
|
-
|
|
552
|
-
|
|
553
|
-
|
|
554
|
-
|
|
555
|
-
|
|
556
|
-
|
|
557
|
-
|
|
558
|
-
|
|
559
|
-
|
|
560
|
-
|
|
561
|
-
|
|
554
|
+
inputs = prepare_inputs(
|
|
555
|
+
processor, images, prompt_str, image_token_index, resize_shape
|
|
556
|
+
)
|
|
557
|
+
input_ids = inputs["input_ids"]
|
|
558
|
+
pixel_values = inputs["pixel_values"]
|
|
559
|
+
mask = inputs["attention_mask"]
|
|
560
|
+
kwargs = {
|
|
561
|
+
k: v
|
|
562
|
+
for k, v in inputs.items()
|
|
563
|
+
if k not in ["input_ids", "pixel_values", "attention_mask"]
|
|
564
|
+
}
|
|
565
|
+
input_token_len = int(mask.sum())
|
|
566
|
+
return (input_ids, pixel_values, mask, kwargs), input_token_len
|
|
562
567
|
|
|
563
568
|
def chat(
|
|
564
569
|
self,
|
|
@@ -0,0 +1,272 @@
|
|
|
1
|
+
# Copyright 2022-2023 XProbe Inc.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
import logging
|
|
15
|
+
import re
|
|
16
|
+
import uuid
|
|
17
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
18
|
+
from typing import Dict, Iterator, List, Literal, Optional, Union
|
|
19
|
+
|
|
20
|
+
import torch
|
|
21
|
+
|
|
22
|
+
from ....model.utils import select_device
|
|
23
|
+
from ....types import (
|
|
24
|
+
ChatCompletion,
|
|
25
|
+
ChatCompletionChunk,
|
|
26
|
+
CogagentGenerateConfig,
|
|
27
|
+
CompletionChunk,
|
|
28
|
+
)
|
|
29
|
+
from ..llm_family import LLMFamilyV1, LLMSpecV1
|
|
30
|
+
from ..utils import (
|
|
31
|
+
_decode_image,
|
|
32
|
+
generate_chat_completion,
|
|
33
|
+
generate_completion_chunk,
|
|
34
|
+
parse_messages,
|
|
35
|
+
)
|
|
36
|
+
from .core import PytorchChatModel
|
|
37
|
+
from .utils import cache_clean
|
|
38
|
+
|
|
39
|
+
logger = logging.getLogger(__name__)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class CogAgentChatModel(PytorchChatModel):
|
|
43
|
+
def __init__(self, *args, **kwargs):
|
|
44
|
+
super().__init__(*args, **kwargs)
|
|
45
|
+
self._torch_type = None
|
|
46
|
+
self._device = None
|
|
47
|
+
self._tokenizer = None
|
|
48
|
+
self._model = None
|
|
49
|
+
self._platform: Literal["Mac", "WIN", "Mobile"] | None = "Mac"
|
|
50
|
+
self._format: Literal[
|
|
51
|
+
"(Answer in Action-Operation-Sensitive format.)",
|
|
52
|
+
"(Answer in Status-Plan-Action-Operation format.)",
|
|
53
|
+
"(Answer in Status-Action-Operation-Sensitive format.)",
|
|
54
|
+
"(Answer in Status-Action-Operation format.)",
|
|
55
|
+
"(Answer in Action-Operation format.)",
|
|
56
|
+
] | None = "(Answer in Action-Operation-Sensitive format.)"
|
|
57
|
+
|
|
58
|
+
@classmethod
|
|
59
|
+
def match(
|
|
60
|
+
cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
|
|
61
|
+
) -> bool:
|
|
62
|
+
family = model_family.model_family or model_family.model_name
|
|
63
|
+
if "cogagent" in family.lower():
|
|
64
|
+
return True
|
|
65
|
+
return False
|
|
66
|
+
|
|
67
|
+
def load(self, **kwargs):
|
|
68
|
+
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
|
69
|
+
|
|
70
|
+
device = self._pytorch_model_config.get("device", "auto")
|
|
71
|
+
self._device = select_device(device)
|
|
72
|
+
|
|
73
|
+
self._tokenizer = AutoTokenizer.from_pretrained(
|
|
74
|
+
self.model_path, trust_remote_code=True
|
|
75
|
+
)
|
|
76
|
+
if self.quantization == "4-bit":
|
|
77
|
+
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
|
|
78
|
+
elif self.quantization == "8-bit":
|
|
79
|
+
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
|
|
80
|
+
else:
|
|
81
|
+
quantization_config = None
|
|
82
|
+
|
|
83
|
+
self._model = AutoModelForCausalLM.from_pretrained(
|
|
84
|
+
self.model_path,
|
|
85
|
+
torch_dtype=torch.bfloat16,
|
|
86
|
+
trust_remote_code=True,
|
|
87
|
+
device_map=self._device,
|
|
88
|
+
quantization_config=quantization_config,
|
|
89
|
+
).eval()
|
|
90
|
+
|
|
91
|
+
def _message_content_to_cogagent(self, content):
|
|
92
|
+
assert isinstance(content, list)
|
|
93
|
+
texts = []
|
|
94
|
+
image_urls = []
|
|
95
|
+
for c in content:
|
|
96
|
+
c_type = c.get("type")
|
|
97
|
+
if c_type == "text":
|
|
98
|
+
texts.append(c["text"])
|
|
99
|
+
elif c_type == "image_url":
|
|
100
|
+
image_urls.append(c["image_url"]["url"])
|
|
101
|
+
image_futures = []
|
|
102
|
+
with ThreadPoolExecutor() as executor:
|
|
103
|
+
for image_url in image_urls:
|
|
104
|
+
fut = executor.submit(_decode_image, image_url)
|
|
105
|
+
image_futures.append(fut)
|
|
106
|
+
images = [fut.result() for fut in image_futures]
|
|
107
|
+
text = " ".join(texts)
|
|
108
|
+
if len(images) == 0:
|
|
109
|
+
raise RuntimeError(
|
|
110
|
+
"CogAgent requires image input to perform GUI Agent tasks. Pure text-based interaction cannot execute such tasks."
|
|
111
|
+
)
|
|
112
|
+
elif len(images) == 1:
|
|
113
|
+
return text, images[-1]
|
|
114
|
+
else:
|
|
115
|
+
logger.warning(
|
|
116
|
+
"There are multiple images in the prompt, CogAgent will automatically use the most recently provided image as the input."
|
|
117
|
+
)
|
|
118
|
+
return text, images[-1]
|
|
119
|
+
|
|
120
|
+
def _history_content_to_cogagent(self, chat_history: List[Dict]):
|
|
121
|
+
grounded_pattern = r"Grounded Operation:\s*(.*)"
|
|
122
|
+
action_pattern = r"Action:\s*(.*)"
|
|
123
|
+
|
|
124
|
+
def extract_operations(_content: str):
|
|
125
|
+
"""extract grounded operation and action operation"""
|
|
126
|
+
_history_step = []
|
|
127
|
+
_history_action = []
|
|
128
|
+
|
|
129
|
+
matches_history = re.search(grounded_pattern, _content)
|
|
130
|
+
matches_actions = re.search(action_pattern, _content)
|
|
131
|
+
|
|
132
|
+
if matches_history:
|
|
133
|
+
grounded_operation = matches_history.group(1)
|
|
134
|
+
_history_step.append(grounded_operation)
|
|
135
|
+
if matches_actions:
|
|
136
|
+
action_operation = matches_actions.group(1)
|
|
137
|
+
_history_action.append(action_operation)
|
|
138
|
+
|
|
139
|
+
return _history_step, _history_action
|
|
140
|
+
|
|
141
|
+
history_step = []
|
|
142
|
+
history_action = []
|
|
143
|
+
|
|
144
|
+
for i in range(0, len(chat_history) - 1, 2):
|
|
145
|
+
content = chat_history[i + 1].get("content")
|
|
146
|
+
if isinstance(content, str): # 如果内容是字符串
|
|
147
|
+
steps, actions = extract_operations(content)
|
|
148
|
+
history_step.extend(steps)
|
|
149
|
+
history_action.extend(actions)
|
|
150
|
+
|
|
151
|
+
elif isinstance(content, list): # 如果内容是列表
|
|
152
|
+
for c in content:
|
|
153
|
+
c_content = c.get("content")
|
|
154
|
+
if isinstance(c_content, str): # 确保是字符串类型
|
|
155
|
+
steps, actions = extract_operations(c_content)
|
|
156
|
+
history_step.extend(steps)
|
|
157
|
+
history_action.extend(actions)
|
|
158
|
+
|
|
159
|
+
return history_step, history_action
|
|
160
|
+
|
|
161
|
+
def get_query_and_history(
|
|
162
|
+
self,
|
|
163
|
+
prompt: Union[str, List[Dict]],
|
|
164
|
+
chat_history: Optional[List[Dict]] = None,
|
|
165
|
+
):
|
|
166
|
+
task, image = self._message_content_to_cogagent(prompt)
|
|
167
|
+
|
|
168
|
+
history_step, history_action = [], []
|
|
169
|
+
|
|
170
|
+
if chat_history:
|
|
171
|
+
history_step, history_action = self._history_content_to_cogagent(
|
|
172
|
+
chat_history
|
|
173
|
+
)
|
|
174
|
+
|
|
175
|
+
# Verify history lengths match
|
|
176
|
+
if len(history_step) != len(history_action):
|
|
177
|
+
raise ValueError("Mismatch in lengths of history_step and history_action.")
|
|
178
|
+
|
|
179
|
+
# Format history steps for output
|
|
180
|
+
history_str = "\nHistory steps: "
|
|
181
|
+
for index, (step, action) in enumerate(zip(history_step, history_action)):
|
|
182
|
+
history_str += f"\n{index}. {step}\t{action}"
|
|
183
|
+
|
|
184
|
+
# Compose the query with task, platform, and selected format instructions
|
|
185
|
+
query = f"Task: {task}{history_str}\n{self._platform}{self._format}"
|
|
186
|
+
logger.info(f"query:{query}")
|
|
187
|
+
return query, image
|
|
188
|
+
|
|
189
|
+
@cache_clean
|
|
190
|
+
def chat(
|
|
191
|
+
self,
|
|
192
|
+
messages: List[Dict],
|
|
193
|
+
generate_config: Optional[CogagentGenerateConfig] = None,
|
|
194
|
+
) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
|
|
195
|
+
if generate_config is not None:
|
|
196
|
+
self._platform = generate_config.pop("platform", self._platform)
|
|
197
|
+
self._format = generate_config.pop("format", self._format)
|
|
198
|
+
|
|
199
|
+
sanitize_generate_config = self._sanitize_generate_config(generate_config)
|
|
200
|
+
stream = sanitize_generate_config.get("stream")
|
|
201
|
+
sanitized_config = {
|
|
202
|
+
"max_length": sanitize_generate_config.get("max_tokens", 512),
|
|
203
|
+
"top_k": sanitize_generate_config.get("top_k", 1),
|
|
204
|
+
"do_sample": True,
|
|
205
|
+
}
|
|
206
|
+
prompt, _, chat_history = parse_messages(messages)
|
|
207
|
+
|
|
208
|
+
query, image = self.get_query_and_history(prompt, chat_history)
|
|
209
|
+
|
|
210
|
+
full_context_kwargs = {
|
|
211
|
+
"return_tensors": "pt",
|
|
212
|
+
"return_dict": True,
|
|
213
|
+
}
|
|
214
|
+
assert self.model_family.chat_template is not None
|
|
215
|
+
inputs = self.get_full_context(
|
|
216
|
+
[{"role": "user", "image": image, "content": query}],
|
|
217
|
+
self.model_family.chat_template,
|
|
218
|
+
self._tokenizer,
|
|
219
|
+
tokenize=True,
|
|
220
|
+
**full_context_kwargs,
|
|
221
|
+
)
|
|
222
|
+
inputs.to(self._model.device)
|
|
223
|
+
|
|
224
|
+
if stream:
|
|
225
|
+
it = self._streaming_chat_response(inputs, sanitized_config)
|
|
226
|
+
return self._to_chat_completion_chunks(it)
|
|
227
|
+
else:
|
|
228
|
+
# Generate response
|
|
229
|
+
with torch.no_grad():
|
|
230
|
+
outputs = self._model.generate(**inputs, **sanitized_config)
|
|
231
|
+
outputs = outputs[:, inputs["input_ids"].shape[1] :]
|
|
232
|
+
response = self._tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
233
|
+
|
|
234
|
+
return generate_chat_completion(self.model_uid, response)
|
|
235
|
+
|
|
236
|
+
def _streaming_chat_response(
|
|
237
|
+
self, inputs: Dict, config: Dict
|
|
238
|
+
) -> Iterator[CompletionChunk]:
|
|
239
|
+
from threading import Thread
|
|
240
|
+
|
|
241
|
+
from transformers import TextIteratorStreamer
|
|
242
|
+
|
|
243
|
+
streamer = TextIteratorStreamer(
|
|
244
|
+
self._tokenizer, skip_prompt=True, skip_special_tokens=True
|
|
245
|
+
)
|
|
246
|
+
generation_kwargs = {**inputs, **config}
|
|
247
|
+
|
|
248
|
+
thread = Thread(target=self._model.generate, kwargs=generation_kwargs)
|
|
249
|
+
thread.start()
|
|
250
|
+
|
|
251
|
+
completion_id = str(uuid.uuid1())
|
|
252
|
+
for new_text in streamer:
|
|
253
|
+
yield generate_completion_chunk(
|
|
254
|
+
chunk_text=new_text,
|
|
255
|
+
finish_reason=None,
|
|
256
|
+
chunk_id=completion_id,
|
|
257
|
+
model_uid=self.model_uid,
|
|
258
|
+
prompt_tokens=-1,
|
|
259
|
+
completion_tokens=-1,
|
|
260
|
+
total_tokens=-1,
|
|
261
|
+
)
|
|
262
|
+
yield generate_completion_chunk(
|
|
263
|
+
chunk_text=None,
|
|
264
|
+
finish_reason="stop",
|
|
265
|
+
chunk_id=completion_id,
|
|
266
|
+
model_uid=self.model_uid,
|
|
267
|
+
prompt_tokens=-1,
|
|
268
|
+
completion_tokens=-1,
|
|
269
|
+
total_tokens=-1,
|
|
270
|
+
has_choice=True,
|
|
271
|
+
has_content=False,
|
|
272
|
+
)
|
|
@@ -17,6 +17,7 @@ import sys
|
|
|
17
17
|
import uuid
|
|
18
18
|
from typing import Iterator, List, Optional, Union
|
|
19
19
|
|
|
20
|
+
from ....device_utils import is_npu_available
|
|
20
21
|
from ....model.utils import select_device
|
|
21
22
|
from ....types import (
|
|
22
23
|
ChatCompletion,
|
|
@@ -47,6 +48,8 @@ class Qwen2VLChatModel(PytorchChatModel):
|
|
|
47
48
|
llm_family = model_family.model_family or model_family.model_name
|
|
48
49
|
if "qwen2-vl-instruct".lower() in llm_family.lower():
|
|
49
50
|
return True
|
|
51
|
+
if "qvq-72b-preview".lower() in llm_family.lower():
|
|
52
|
+
return True
|
|
50
53
|
return False
|
|
51
54
|
|
|
52
55
|
def load(self):
|
|
@@ -71,6 +74,14 @@ class Qwen2VLChatModel(PytorchChatModel):
|
|
|
71
74
|
attn_implementation="flash_attention_2",
|
|
72
75
|
trust_remote_code=True,
|
|
73
76
|
).eval()
|
|
77
|
+
elif is_npu_available():
|
|
78
|
+
# Ascend do not support bf16
|
|
79
|
+
self._model = Qwen2VLForConditionalGeneration.from_pretrained(
|
|
80
|
+
self.model_path,
|
|
81
|
+
device_map="auto",
|
|
82
|
+
trust_remote_code=True,
|
|
83
|
+
torch_dtype="float16",
|
|
84
|
+
).eval()
|
|
74
85
|
else:
|
|
75
86
|
self._model = Qwen2VLForConditionalGeneration.from_pretrained(
|
|
76
87
|
self.model_path, device_map=device, trust_remote_code=True
|
|
@@ -112,7 +123,7 @@ class Qwen2VLChatModel(PytorchChatModel):
|
|
|
112
123
|
padding=True,
|
|
113
124
|
return_tensors="pt",
|
|
114
125
|
)
|
|
115
|
-
inputs = inputs.to(
|
|
126
|
+
inputs = inputs.to(self._device)
|
|
116
127
|
|
|
117
128
|
# Inference: Generation of the output
|
|
118
129
|
generated_ids = self._model.generate(
|
xinference/model/llm/utils.py
CHANGED
|
@@ -52,6 +52,7 @@ QWEN_TOOL_CALL_FAMILY = [
|
|
|
52
52
|
"qwen2-instruct",
|
|
53
53
|
"qwen2-moe-instruct",
|
|
54
54
|
"qwen2.5-instruct",
|
|
55
|
+
"qwen2.5-coder-instruct",
|
|
55
56
|
]
|
|
56
57
|
|
|
57
58
|
GLM4_TOOL_CALL_FAMILY = [
|
|
@@ -96,13 +97,18 @@ class ChatModelMixin:
|
|
|
96
97
|
return rendered
|
|
97
98
|
|
|
98
99
|
def get_full_context(
|
|
99
|
-
self,
|
|
100
|
-
|
|
100
|
+
self,
|
|
101
|
+
messages: List,
|
|
102
|
+
chat_template: str,
|
|
103
|
+
tokenizer=None,
|
|
104
|
+
tokenize=False,
|
|
105
|
+
**kwargs,
|
|
106
|
+
):
|
|
101
107
|
if tokenizer is not None:
|
|
102
108
|
try:
|
|
103
109
|
full_context = tokenizer.apply_chat_template(
|
|
104
110
|
messages,
|
|
105
|
-
tokenize=
|
|
111
|
+
tokenize=tokenize,
|
|
106
112
|
chat_template=chat_template,
|
|
107
113
|
add_generation_prompt=True,
|
|
108
114
|
**kwargs,
|
|
@@ -118,6 +124,25 @@ class ChatModelMixin:
|
|
|
118
124
|
# Compilation function uses a cache to avoid recompiling the same template
|
|
119
125
|
return self._build_from_raw_template(messages, chat_template, **kwargs)
|
|
120
126
|
|
|
127
|
+
@staticmethod
|
|
128
|
+
def convert_messages_with_content_list_to_str_conversion(
|
|
129
|
+
messages: List[Dict],
|
|
130
|
+
) -> List[Dict]:
|
|
131
|
+
"""
|
|
132
|
+
Handles messages with content list conversion, in order to support Cline, see GH#2659 .
|
|
133
|
+
"""
|
|
134
|
+
for message in messages:
|
|
135
|
+
texts = ""
|
|
136
|
+
msg_content = message.get("content")
|
|
137
|
+
if msg_content:
|
|
138
|
+
if isinstance(msg_content, str):
|
|
139
|
+
texts = msg_content
|
|
140
|
+
elif isinstance(msg_content, list):
|
|
141
|
+
texts = "\n".join(item.get("text", "") for item in msg_content)
|
|
142
|
+
if texts:
|
|
143
|
+
message["content"] = texts
|
|
144
|
+
return messages
|
|
145
|
+
|
|
121
146
|
@staticmethod
|
|
122
147
|
def get_specific_prompt(model_family: str, messages: List[ChatCompletionMessage]):
|
|
123
148
|
"""
|
|
@@ -70,6 +70,7 @@ class VLLMModelConfig(TypedDict, total=False):
|
|
|
70
70
|
max_model_len: Optional[int]
|
|
71
71
|
limit_mm_per_prompt: Optional[Dict[str, int]]
|
|
72
72
|
guided_decoding_backend: Optional[str]
|
|
73
|
+
scheduling_policy: Optional[str]
|
|
73
74
|
|
|
74
75
|
|
|
75
76
|
class VLLMGenerateConfig(TypedDict, total=False):
|
|
@@ -155,6 +156,7 @@ if VLLM_INSTALLED and vllm.__version__ >= "0.3.0":
|
|
|
155
156
|
VLLM_SUPPORTED_MODELS.append("qwen2.5-coder")
|
|
156
157
|
VLLM_SUPPORTED_CHAT_MODELS.append("qwen2.5-coder-instruct")
|
|
157
158
|
VLLM_SUPPORTED_CHAT_MODELS.append("QwQ-32B-Preview")
|
|
159
|
+
VLLM_SUPPORTED_CHAT_MODELS.append("marco-o1")
|
|
158
160
|
|
|
159
161
|
|
|
160
162
|
if VLLM_INSTALLED and vllm.__version__ >= "0.3.2":
|
|
@@ -187,10 +189,14 @@ if VLLM_INSTALLED and vllm.__version__ > "0.5.3":
|
|
|
187
189
|
if VLLM_INSTALLED and vllm.__version__ >= "0.6.1":
|
|
188
190
|
VLLM_SUPPORTED_VISION_MODEL_LIST.append("internvl2")
|
|
189
191
|
|
|
192
|
+
if VLLM_INSTALLED and vllm.__version__ >= "0.6.2":
|
|
193
|
+
VLLM_SUPPORTED_CHAT_MODELS.append("minicpm3-4b")
|
|
194
|
+
|
|
190
195
|
if VLLM_INSTALLED and vllm.__version__ >= "0.6.3":
|
|
191
196
|
VLLM_SUPPORTED_MODELS.append("llama-3.2-vision")
|
|
192
197
|
VLLM_SUPPORTED_VISION_MODEL_LIST.append("llama-3.2-vision-instruct")
|
|
193
198
|
VLLM_SUPPORTED_VISION_MODEL_LIST.append("qwen2-vl-instruct")
|
|
199
|
+
VLLM_SUPPORTED_VISION_MODEL_LIST.append("QvQ-72B-Preview")
|
|
194
200
|
|
|
195
201
|
|
|
196
202
|
class VLLMModel(LLM):
|
|
@@ -219,6 +225,10 @@ class VLLMModel(LLM):
|
|
|
219
225
|
self._engine = None
|
|
220
226
|
self.lora_modules = peft_model
|
|
221
227
|
self.lora_requests: List[LoRARequest] = []
|
|
228
|
+
self._xavier_config = None
|
|
229
|
+
|
|
230
|
+
def set_xavier_config(self, value: Optional[Dict]):
|
|
231
|
+
self._xavier_config = value # type: ignore
|
|
222
232
|
|
|
223
233
|
def load(self):
|
|
224
234
|
try:
|
|
@@ -244,7 +254,6 @@ class VLLMModel(LLM):
|
|
|
244
254
|
multiprocessing.set_start_method("fork", force=True)
|
|
245
255
|
|
|
246
256
|
self._model_config = self._sanitize_model_config(self._model_config)
|
|
247
|
-
|
|
248
257
|
if self.lora_modules is None:
|
|
249
258
|
self.lora_requests = []
|
|
250
259
|
else:
|
|
@@ -265,13 +274,34 @@ class VLLMModel(LLM):
|
|
|
265
274
|
f"Enable lora: {enable_lora}. Lora count: {max_loras}."
|
|
266
275
|
)
|
|
267
276
|
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
277
|
+
if self._xavier_config is not None:
|
|
278
|
+
from .xavier.engine import XavierEngine
|
|
279
|
+
|
|
280
|
+
# Enabling Xavier means that `enable_prefix_caching` is enabled by default.
|
|
281
|
+
self._model_config.setdefault("enable_prefix_caching", True)
|
|
282
|
+
xavier_transfer_block_num = self._model_config.pop(
|
|
283
|
+
"xavier_transfer_block_num", 512
|
|
284
|
+
)
|
|
285
|
+
self._xavier_config["transfer_block_num"] = xavier_transfer_block_num
|
|
286
|
+
engine_args = AsyncEngineArgs(
|
|
287
|
+
model=self.model_path,
|
|
288
|
+
enable_lora=enable_lora,
|
|
289
|
+
max_loras=max_loras,
|
|
290
|
+
**self._model_config,
|
|
291
|
+
)
|
|
292
|
+
|
|
293
|
+
logger.debug(f"Start xavier for vllm with config: {self._xavier_config}")
|
|
294
|
+
self._engine = XavierEngine.from_engine_args(
|
|
295
|
+
engine_args, xavier_config=self._xavier_config
|
|
296
|
+
)
|
|
297
|
+
else:
|
|
298
|
+
engine_args = AsyncEngineArgs(
|
|
299
|
+
model=self.model_path,
|
|
300
|
+
enable_lora=enable_lora,
|
|
301
|
+
max_loras=max_loras,
|
|
302
|
+
**self._model_config,
|
|
303
|
+
)
|
|
304
|
+
self._engine = AsyncLLMEngine.from_engine_args(engine_args)
|
|
275
305
|
|
|
276
306
|
self._check_health_task = None
|
|
277
307
|
if hasattr(self._engine, "check_health"):
|
|
@@ -289,6 +319,9 @@ class VLLMModel(LLM):
|
|
|
289
319
|
model_executor.shutdown()
|
|
290
320
|
self._engine = None
|
|
291
321
|
|
|
322
|
+
async def init_xavier(self):
|
|
323
|
+
await self._engine.init_xavier()
|
|
324
|
+
|
|
292
325
|
async def _check_healthy(self, interval: int = 30):
|
|
293
326
|
from vllm.engine.async_llm_engine import AsyncEngineDeadError
|
|
294
327
|
|
|
@@ -327,7 +360,9 @@ class VLLMModel(LLM):
|
|
|
327
360
|
model_config.setdefault("quantization", None)
|
|
328
361
|
model_config.setdefault("max_model_len", None)
|
|
329
362
|
model_config.setdefault("guided_decoding_backend", "outlines")
|
|
330
|
-
|
|
363
|
+
# Add scheduling policy if vLLM version is 0.6.3 or higher
|
|
364
|
+
if vllm.__version__ >= "0.6.3":
|
|
365
|
+
model_config.setdefault("scheduling_policy", "fcfs")
|
|
331
366
|
return model_config
|
|
332
367
|
|
|
333
368
|
@staticmethod
|
|
@@ -769,6 +804,7 @@ class VLLMChatModel(VLLMModel, ChatModelMixin):
|
|
|
769
804
|
generate_config: Optional[Dict] = None,
|
|
770
805
|
request_id: Optional[str] = None,
|
|
771
806
|
) -> Union[ChatCompletion, AsyncGenerator[ChatCompletionChunk, None]]:
|
|
807
|
+
messages = self.convert_messages_with_content_list_to_str_conversion(messages)
|
|
772
808
|
tools = generate_config.pop("tools", []) if generate_config else None
|
|
773
809
|
model_family = self.model_family.model_family or self.model_family.model_name
|
|
774
810
|
full_context_kwargs = {}
|
|
@@ -859,6 +895,9 @@ class VLLMVisionModel(VLLMModel, ChatModelMixin):
|
|
|
859
895
|
"image": 2, # default 2 images all chat
|
|
860
896
|
}
|
|
861
897
|
)
|
|
898
|
+
# Add scheduling policy if vLLM version is 0.6.3 or higher
|
|
899
|
+
if vllm.__version__ >= "0.6.3":
|
|
900
|
+
model_config.setdefault("scheduling_policy", "fcfs")
|
|
862
901
|
|
|
863
902
|
return model_config
|
|
864
903
|
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2022-2025 XProbe Inc.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|