xinference 1.6.0.post1__py3-none-any.whl → 1.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of xinference might be problematic.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +79 -2
- xinference/client/restful/restful_client.py +65 -3
- xinference/conftest.py +0 -7
- xinference/core/media_interface.py +132 -8
- xinference/core/model.py +44 -6
- xinference/core/scheduler.py +1 -10
- xinference/core/supervisor.py +8 -17
- xinference/core/worker.py +5 -27
- xinference/deploy/cmdline.py +6 -2
- xinference/model/audio/chattts.py +24 -39
- xinference/model/audio/cosyvoice.py +18 -30
- xinference/model/audio/funasr.py +42 -0
- xinference/model/audio/model_spec.json +71 -1
- xinference/model/audio/model_spec_modelscope.json +76 -2
- xinference/model/audio/utils.py +75 -0
- xinference/model/core.py +1 -0
- xinference/model/embedding/__init__.py +74 -18
- xinference/model/embedding/core.py +98 -589
- xinference/model/embedding/embed_family.py +133 -0
- xinference/{thirdparty/omnilmm/train → model/embedding/flag}/__init__.py +1 -1
- xinference/model/embedding/flag/core.py +282 -0
- xinference/model/embedding/model_spec.json +24 -0
- xinference/model/embedding/model_spec_modelscope.json +24 -0
- xinference/model/embedding/sentence_transformers/__init__.py +13 -0
- xinference/model/embedding/sentence_transformers/core.py +399 -0
- xinference/model/embedding/vllm/core.py +95 -0
- xinference/model/image/model_spec.json +30 -3
- xinference/model/image/model_spec_modelscope.json +41 -2
- xinference/model/image/stable_diffusion/core.py +144 -53
- xinference/model/llm/__init__.py +6 -54
- xinference/model/llm/core.py +19 -5
- xinference/model/llm/llama_cpp/core.py +59 -3
- xinference/model/llm/llama_cpp/memory.py +457 -0
- xinference/model/llm/llm_family.json +247 -402
- xinference/model/llm/llm_family.py +88 -16
- xinference/model/llm/llm_family_modelscope.json +260 -421
- xinference/model/llm/llm_family_openmind_hub.json +0 -34
- xinference/model/llm/sglang/core.py +8 -0
- xinference/model/llm/transformers/__init__.py +27 -6
- xinference/model/llm/transformers/chatglm.py +4 -2
- xinference/model/llm/transformers/core.py +49 -28
- xinference/model/llm/transformers/deepseek_v2.py +6 -49
- xinference/model/llm/transformers/gemma3.py +119 -164
- xinference/model/llm/transformers/multimodal/__init__.py +13 -0
- xinference/model/llm/transformers/{cogagent.py → multimodal/cogagent.py} +58 -95
- xinference/model/llm/transformers/multimodal/core.py +205 -0
- xinference/model/llm/transformers/{deepseek_vl2.py → multimodal/deepseek_vl2.py} +59 -120
- xinference/model/llm/transformers/multimodal/gemma3.py +117 -0
- xinference/model/llm/transformers/{glm4v.py → multimodal/glm4v.py} +57 -93
- xinference/model/llm/transformers/multimodal/intern_vl.py +412 -0
- xinference/model/llm/transformers/{minicpmv26.py → multimodal/minicpmv26.py} +55 -102
- xinference/model/llm/transformers/{ovis2.py → multimodal/ovis2.py} +114 -175
- xinference/model/llm/transformers/{qwen-omni.py → multimodal/qwen-omni.py} +82 -167
- xinference/model/llm/transformers/multimodal/qwen2_audio.py +131 -0
- xinference/model/llm/transformers/{qwen2_vl.py → multimodal/qwen2_vl.py} +224 -256
- xinference/model/llm/transformers/opt.py +4 -2
- xinference/model/llm/transformers/utils.py +6 -37
- xinference/model/llm/utils.py +11 -0
- xinference/model/llm/vllm/core.py +7 -0
- xinference/model/rerank/core.py +91 -3
- xinference/model/rerank/model_spec.json +24 -0
- xinference/model/rerank/model_spec_modelscope.json +24 -0
- xinference/model/rerank/utils.py +20 -2
- xinference/model/utils.py +38 -1
- xinference/model/video/diffusers.py +65 -3
- xinference/model/video/model_spec.json +31 -4
- xinference/model/video/model_spec_modelscope.json +32 -4
- xinference/web/ui/build/asset-manifest.json +6 -6
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/css/main.013f296b.css +2 -0
- xinference/web/ui/build/static/css/main.013f296b.css.map +1 -0
- xinference/web/ui/build/static/js/main.8a9e3ba0.js +3 -0
- xinference/web/ui/build/static/js/main.8a9e3ba0.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/34cfbfb7836e136ba3261cfd411cc554bf99ba24b35dcceebeaa4f008cb3c9dc.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/55b9fb40b57fa926e8f05f31c2f96467e76e5ad62f033dca97c03f9e8c4eb4fe.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/567e49df411efb24425d289bb484758cb57067ca54f8b5c67fe4505f698deb96.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/6595880facebca7ceace6f17cf21c3a5a9219a2f52fb0ba9f3cf1131eddbcf6b.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/aa998bc2d9c11853add6b8a2e08f50327f56d8824ccaaec92d6dde1b305f0d85.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/c748246b1d7bcebc16153be69f37e955bb2145526c47dd425aeeff70d3004dbc.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/e31234e95d60a5a7883fbcd70de2475dc1c88c90705df1a530abb68f86f80a51.json +1 -0
- xinference/web/ui/src/locales/en.json +21 -8
- xinference/web/ui/src/locales/ja.json +224 -0
- xinference/web/ui/src/locales/ko.json +224 -0
- xinference/web/ui/src/locales/zh.json +21 -8
- {xinference-1.6.0.post1.dist-info → xinference-1.7.0.dist-info}/METADATA +14 -11
- {xinference-1.6.0.post1.dist-info → xinference-1.7.0.dist-info}/RECORD +93 -100
- {xinference-1.6.0.post1.dist-info → xinference-1.7.0.dist-info}/WHEEL +1 -1
- xinference/model/llm/transformers/cogvlm2.py +0 -442
- xinference/model/llm/transformers/cogvlm2_video.py +0 -333
- xinference/model/llm/transformers/deepseek_vl.py +0 -280
- xinference/model/llm/transformers/glm_edge_v.py +0 -213
- xinference/model/llm/transformers/intern_vl.py +0 -526
- xinference/model/llm/transformers/internlm2.py +0 -94
- xinference/model/llm/transformers/minicpmv25.py +0 -193
- xinference/model/llm/transformers/omnilmm.py +0 -132
- xinference/model/llm/transformers/qwen2_audio.py +0 -179
- xinference/model/llm/transformers/qwen_vl.py +0 -360
- xinference/thirdparty/omnilmm/LICENSE +0 -201
- xinference/thirdparty/omnilmm/chat.py +0 -218
- xinference/thirdparty/omnilmm/constants.py +0 -4
- xinference/thirdparty/omnilmm/conversation.py +0 -332
- xinference/thirdparty/omnilmm/model/__init__.py +0 -1
- xinference/thirdparty/omnilmm/model/omnilmm.py +0 -595
- xinference/thirdparty/omnilmm/model/resampler.py +0 -166
- xinference/thirdparty/omnilmm/model/utils.py +0 -578
- xinference/thirdparty/omnilmm/train/train_utils.py +0 -150
- xinference/thirdparty/omnilmm/utils.py +0 -134
- xinference/web/ui/build/static/css/main.337afe76.css +0 -2
- xinference/web/ui/build/static/css/main.337afe76.css.map +0 -1
- xinference/web/ui/build/static/js/main.ae579a97.js +0 -3
- xinference/web/ui/build/static/js/main.ae579a97.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/12e02ee790dbf57ead09a241a93bb5f893393aa36628ca741d44390e836a103f.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/2fdc61dcb6a9d1fbcb44be592d0e87d8c3f21297a7327559ef5345665f8343f7.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/3d596a3e8dd6430d7ce81d164e32c31f8d47cfa5f725c328a298754d78563e14.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/5c08e2cd07809ed3e41486b16652253404cbb63a3ff8d0366ee50f57e2413cea.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/8472e58a31720892d534f3febda31f746b25ec4aa60787eef34217b074e67965.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/dc249829767b8abcbc3677e0b07b6d3ecbfdfe6d08cfe23a665eb33373a9aa9d.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/f91af913d7f91c410719ab13136aaed3aaf0f8dda06652f25c42cb5231587398.json +0 -1
- /xinference/{thirdparty/omnilmm → model/embedding/vllm}/__init__.py +0 -0
- /xinference/web/ui/build/static/js/{main.ae579a97.js.LICENSE.txt → main.8a9e3ba0.js.LICENSE.txt} +0 -0
- {xinference-1.6.0.post1.dist-info → xinference-1.7.0.dist-info}/entry_points.txt +0 -0
- {xinference-1.6.0.post1.dist-info → xinference-1.7.0.dist-info}/licenses/LICENSE +0 -0
- {xinference-1.6.0.post1.dist-info → xinference-1.7.0.dist-info}/top_level.txt +0 -0
@@ -785,40 +785,6 @@
             "</s>"
         ]
     },
-    {
-        "version": 1,
-        "context_length": 8192,
-        "model_name": "cogvlm2",
-        "model_lang": [
-            "en",
-            "zh"
-        ],
-        "model_ability": [
-            "chat",
-            "vision"
-        ],
-        "model_description": "CogVLM2 have achieved good results in many lists compared to the previous generation of CogVLM open source models. Its excellent performance can compete with some non-open source models.",
-        "model_specs": [
-            {
-                "model_format": "pytorch",
-                "model_size_in_billions": 20,
-                "quantizations": [
-                    "none"
-                ],
-                "model_id": "AI-Research/cogvlm2-llama3-chinese-chat-19b",
-                "model_hub": "openmind_hub"
-            }
-        ],
-        "chat_template": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = '<|begin_of_text|>' + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% else %}{{ '<|end_of_text|>' }}{% endif %}",
-        "stop_token_ids": [
-            128001,
-            128009
-        ],
-        "stop": [
-            "<|end_of_text|>",
-            "<|eot_id|>"
-        ]
-    },
     {
         "version": 1,
         "context_length": 8192,
@@ -14,6 +14,7 @@
 import importlib.util
 import json
 import logging
+import multiprocessing
 import sys
 import threading
 import time
@@ -107,7 +108,11 @@ SGLANG_SUPPORTED_CHAT_MODELS = [
     "deepseek-r1-distill-qwen",
     "deepseek-r1-distill-llama",
     "deepseek-v3",
+    "deepseek-v3-0324",
     "deepseek-r1",
+    "deepseek-r1-0528",
+    "deepseek-r1-0528-qwen3",
+    "deepseek-prover-v2",
     "DianJin-R1",
     "qwen3",
     "HuatuoGPT-o1-Qwen2.5",
@@ -184,6 +189,9 @@ class SGLANGModel(LLM):
         if sgl_port is None:
             raise ValueError("Failed to find a port for sglang")
 
+        # fork may cause sglang stuck, force set to spawn
+        multiprocessing.set_start_method("spawn")
+
         if self._n_worker > 1:
             # distributed inference
             self._model_config["nnodes"] = self._n_worker
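The hunk above forces Python's multiprocessing start method to "spawn" before sglang launches its server processes. A minimal standalone sketch of the same pattern is below; the worker function and the force=True flag are illustrative only, not taken from xinference:

import multiprocessing


def _square(x):
    return x * x


if __name__ == "__main__":
    # "spawn" starts each child in a fresh interpreter instead of forking the
    # parent, which avoids the hangs fork can cause in threaded/CUDA runtimes.
    multiprocessing.set_start_method("spawn", force=True)
    with multiprocessing.Pool(processes=2) as pool:
        print(pool.map(_square, [1, 2, 3]))  # [1, 4, 9]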
@@ -16,12 +16,33 @@
 import importlib
 import os
 import pkgutil
+from typing import Dict
 
-
+
+def import_submodules(package_path: str, package_name: str, globals_dict: Dict) -> None:
+    """
+    Recursively import all classes in submodules and subpackages
+    """
+    for _, module_name, is_pkg in pkgutil.iter_modules([package_path]):
+        full_module_name = f"{package_name}.{module_name}"
+
+        if module_name.startswith(
+            ("_", "test_")
+        ):  # Skip the modules which start with "_" or "test_"
+            continue
+
+        module = importlib.import_module(full_module_name)
+        globals_dict[module_name] = module
+
+        # If it's a pkg, recursive processing
+        if is_pkg:
+            subpackage_path = os.path.join(package_path, module_name)
+            import_submodules(subpackage_path, full_module_name, globals_dict)
+
+
+# Get the path and name of the current package
 __path__ = [os.path.dirname(os.path.abspath(__file__))]
+__package__ = __name__
 
-#
-
-    if not module_name.startswith("_"):  # Skip modules starting with underscore
-        module = importlib.import_module(f"{__name__}.{module_name}")
-        globals()[module_name] = module
+# Automatic import of all sub-modules and sub-packages
+import_submodules(__path__[0], __package__, globals())
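For reference, pkgutil.iter_modules is what drives the new recursive importer above; a small illustration of what it yields for a package directory (the path below is hypothetical):

import pkgutil

# Yields (finder, module_name, is_pkg) for each entry directly under the path;
# is_pkg is True for sub-packages, which import_submodules then recurses into.
for _finder, module_name, is_pkg in pkgutil.iter_modules(["/tmp/example_pkg"]):
    print(module_name, is_pkg)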
@@ -22,17 +22,19 @@ import torch
 
 from ....core.scheduler import InferenceRequest
 from ....types import ChatCompletion, ChatCompletionChunk, LoRA, PytorchGenerateConfig
-from ..llm_family import LLMFamilyV1, LLMSpecV1
+from ..llm_family import LLMFamilyV1, LLMSpecV1, register_transformer
 from ..utils import (
     GLM4_TOOL_CALL_FAMILY,
     generate_chat_completion,
     generate_completion_chunk,
 )
-from .core import PytorchChatModel, PytorchModelConfig
+from .core import PytorchChatModel, PytorchModelConfig, register_non_default_model
 
 logger = logging.getLogger(__name__)
 
 
+@register_transformer
+@register_non_default_model("glm4-chat", "glm4-chat-1m")
 class ChatglmPytorchChatModel(PytorchChatModel):
     def __init__(
         self,
@@ -16,7 +16,7 @@ import json
 import logging
 import os
 from functools import lru_cache
-from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union
 
 import torch
 
@@ -45,36 +45,17 @@ from ..utils import (
     QWEN_TOOL_CALL_FAMILY,
     ChatModelMixin,
 )
-from .utils import
+from .utils import (
+    _get_pad_param,
+    get_context_length,
+    get_max_src_len,
+    pad_prefill_tokens,
+)
 
 logger = logging.getLogger(__name__)
 
-
-
-    "glm4-chat",
-    "glm4-chat-1m",
-    "qwen-vl-chat",
-    "OmniLMM",
-    "deepseek-vl-chat",
-    "cogvlm2",
-    "cogvlm2-video-llama3-chat",
-    "MiniCPM-Llama3-V-2_5",
-    "MiniCPM-V-2.6",
-    "glm-4v",
-    "qwen2-audio",
-    "qwen2-audio-instruct",
-    "deepseek-v2",
-    "deepseek-v2-chat",
-    "deepseek-v2.5",
-    "deepseek-v2-chat-0628",
-    "glm-edge-v",
-    "QvQ-72B-Preview",
-    "cogagent",
-    "gemma-3-1b-it",
-    "gemma-3-it",
-    "Ovis2",
-    "deepseek-vl2",
-]
+# !!!!! Do not add model_name to this list, use `register_non_default_model` below instead!
+NON_DEFAULT_MODEL_LIST: List[str] = []
 
 
 # Define the decorator to support multiple names registration
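The hard-coded non-default model list is replaced by decorator-based registration (register_non_default_model, used together with register_transformer in the chatglm, deepseek-v2 and gemma3 hunks of this diff). A minimal sketch of how such a multi-name registration decorator can work; this is an illustration, not xinference's actual implementation:

from typing import List, Type

NON_DEFAULT_MODEL_LIST: List[str] = []


def register_non_default_model(*model_names: str):
    # Class decorator: record every given model name, then return the class unchanged.
    def wrapper(cls: Type) -> Type:
        for name in model_names:
            if name not in NON_DEFAULT_MODEL_LIST:
                NON_DEFAULT_MODEL_LIST.append(name)
        return cls

    return wrapper


@register_non_default_model("glm4-chat", "glm4-chat-1m")
class ExampleChatModel:  # placeholder class, for illustration only
    pass


print(NON_DEFAULT_MODEL_LIST)  # ['glm4-chat', 'glm4-chat-1m']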
@@ -551,6 +532,36 @@ class PytorchModel(LLM):
     def prepare_sanitize_generate_config(self, req: InferenceRequest):
         return self._sanitize_generate_config(req.generate_config)
 
+    def merge_kv_cache(self, past_cache, new_cache):
+        from torch.nn.functional import pad
+        from transformers import DynamicCache
+
+        _, seq_len_idx = self.get_batch_size_and_seq_len_indexes_from_kv()
+        past_seq_len = past_cache[0][0].shape[seq_len_idx]
+        new_seq_len = new_cache[0][0].shape[seq_len_idx]
+        if past_seq_len != new_seq_len:
+            padding_target = new_cache if past_seq_len > new_seq_len else past_cache
+            padding_len = abs(past_seq_len - new_seq_len)
+            pad_param = _get_pad_param(seq_len_idx, padding_len)
+            for idx in range(len(padding_target)):
+                k = padding_target.key_cache[idx]
+                v = padding_target.value_cache[idx]
+                _k = pad(k, pad_param)
+                _v = pad(v, pad_param)
+                padding_target.key_cache[idx] = _k
+                padding_target.value_cache[idx] = _v
+
+        ret_kv = DynamicCache()
+        for idx in range(len(past_cache)):
+            k1, k2 = new_cache.key_cache[idx], past_cache.key_cache[idx]
+            v1, v2 = new_cache.value_cache[idx], past_cache.value_cache[idx]
+            ret_kv.update(
+                torch.cat((k1, k2), 0).contiguous(),
+                torch.cat((v1, v2), 0).contiguous(),
+                idx,
+            )
+        return ret_kv
+
     def prepare_batch_inference(self, req_list: List[InferenceRequest]):
         # check some parameters
         for r in req_list:
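merge_kv_cache pads the shorter cache along the sequence dimension and then concatenates both caches along the batch dimension. The padding step in isolation, with toy tensors (left-padding is assumed here for illustration; in the real code the pad layout comes from `_get_pad_param`):

import torch
from torch.nn.functional import pad

# Toy key tensors shaped (batch, heads, seq_len, head_dim).
k_short = torch.zeros(1, 8, 5, 64)
k_long = torch.zeros(1, 8, 9, 64)

# F.pad pads dimensions from the last one backwards: (0, 0) leaves head_dim
# untouched, (4, 0) left-pads the seq_len dimension by 4 positions.
pad_param = (0, 0, 4, 0)
k_short_padded = pad(k_short, pad_param)
assert k_short_padded.shape == k_long.shape

# With equal seq_len, the two caches can be stacked along the batch dimension.
merged = torch.cat((k_long, k_short_padded), dim=0).contiguous()
print(merged.shape)  # torch.Size([2, 8, 9, 64])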
@@ -642,6 +653,16 @@ class PytorchModel(LLM):
         )
         self.handle_batch_inference_results(req_list)
 
+    def build_reduced_kv_cache(self, cache, skipped_indexes: Set[int]):
+        batch_size = cache.key_cache[0].shape[0]
+        batch_slices = [num for num in range(batch_size) if num not in skipped_indexes]
+        for idx in range(len(cache)):
+            cache.key_cache[idx] = cache.key_cache[idx][batch_slices, ::].contiguous()
+            cache.value_cache[idx] = cache.value_cache[idx][
+                batch_slices, ::
+            ].contiguous()
+        return cache
+
 
 class PytorchChatModel(PytorchModel, ChatModelMixin):
     def __init__(
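build_reduced_kv_cache drops the batch rows belonging to requests that have finished. The slicing trick on its own, with a toy tensor (the index set is made up for the example):

import torch

skipped_indexes = {1}  # pretend request 1 has finished and should be dropped
k = torch.randn(3, 8, 4, 64)  # (batch, heads, seq_len, head_dim)

batch_slices = [i for i in range(k.shape[0]) if i not in skipped_indexes]
k_reduced = k[batch_slices, ::].contiguous()  # keep only the live requests
print(k_reduced.shape)  # torch.Size([2, 8, 4, 64])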
@@ -15,59 +15,16 @@ import logging
 
 import torch
 
-from ..llm_family import LLMFamilyV1, LLMSpecV1
-from .core import PytorchChatModel,
+from ..llm_family import LLMFamilyV1, LLMSpecV1, register_transformer
+from .core import PytorchChatModel, register_non_default_model
 
 logger = logging.getLogger(__name__)
 
 
-
-
-
-
-                AutoModelForCausalLM,
-                AutoTokenizer,
-                GenerationConfig,
-            )
-        except ImportError:
-            error_message = "Failed to import module 'transformers'"
-            installation_guide = [
-                "Please make sure 'transformers' is installed. ",
-                "You can install it by `pip install transformers`\n",
-            ]
-
-            raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
-
-        tokenizer = AutoTokenizer.from_pretrained(
-            self.model_path,
-            trust_remote_code=kwargs["trust_remote_code"],
-        )
-        model = AutoModelForCausalLM.from_pretrained(
-            self.model_path,
-            attn_implementation="eager",
-            torch_dtype=torch.bfloat16,
-            trust_remote_code=True,
-            device_map="auto",
-            **kwargs,
-        )
-        model.generation_config = GenerationConfig.from_pretrained(self.model_path)
-        model.generation_config.pad_token_id = model.generation_config.eos_token_id
-        return model, tokenizer
-
-    @classmethod
-    def match_json(
-        cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
-    ) -> bool:
-        if llm_spec.model_format != "pytorch":
-            return False
-        model_family = llm_family.model_family or llm_family.model_name
-        if "deepseek-v2" not in model_family:
-            return False
-        if "generate" not in llm_family.model_ability:
-            return False
-        return True
-
-
+@register_transformer
+@register_non_default_model(
+    "deepseek-v2-chat", "deepseek-v2.5", "deepseek-v2-chat-0628"
+)
 class DeepSeekV2PytorchChatModel(PytorchChatModel):
     def _load_model(self, **kwargs):
         try:
@@ -11,29 +11,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-
 import logging
-import
-
-from
-
-from
-from ....types import (
-    ChatCompletion,
-    ChatCompletionChunk,
-    ChatCompletionMessage,
-    CompletionChunk,
-    PytorchModelConfig,
-)
-from ..llm_family import LLMFamilyV1, LLMSpecV1
-from ..utils import generate_chat_completion, generate_completion_chunk
-from .core import PytorchChatModel, PytorchGenerateConfig
-from .utils import cache_clean
+from typing import Dict, List, Set
+
+from ....core.scheduler import InferenceRequest
+from ..llm_family import LLMFamilyV1, LLMSpecV1, register_transformer
+from .core import PytorchChatModel, register_non_default_model
 
 logger = logging.getLogger(__name__)
 
 
+@register_transformer
+@register_non_default_model("gemma-3-1b-it")
 class Gemma3TextChatModel(PytorchChatModel):
     @classmethod
     def match_json(
@@ -46,163 +35,129 @@ class Gemma3TextChatModel(PytorchChatModel):
             return True
         return False
 
+    def _load_model(self, **kwargs):
+        import torch
+        from transformers import AutoModelForCausalLM, AutoTokenizer
 
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self._tokenizer = None
-        self._model = None
-        self._device = None
-        self._processor = None
-
-    @classmethod
-    def match_json(
-        cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
-    ) -> bool:
-        if model_spec.model_format not in ["pytorch", "gptq", "awq"]:
-            return False
-        llm_family = model_family.model_family or model_family.model_name
-        if "gemma-3-it".lower() in llm_family.lower():
-            return True
-        return False
-
-    def _sanitize_model_config(
-        self, pytorch_model_config: Optional[PytorchModelConfig]
-    ) -> PytorchModelConfig:
-        pytorch_model_config = super()._sanitize_model_config(pytorch_model_config)
-        assert pytorch_model_config is not None
-        pytorch_model_config.setdefault("min_pixels", 256 * 28 * 28)
-        pytorch_model_config.setdefault("max_pixels", 1280 * 28 * 28)
-        return pytorch_model_config
-
-    def load(self):
-        from transformers import AutoProcessor, Gemma3ForConditionalGeneration
-
-        device = self._pytorch_model_config.get("device", "auto")
-        device = select_device(device)
-        self._device = device
-        # for multiple GPU, set back to auto to make multiple devices work
-        device = "auto" if device == "cuda" else device
-        min_pixels = self._pytorch_model_config.get("min_pixels")
-        max_pixels = self._pytorch_model_config.get("max_pixels")
-        kwargs = self.apply_bnb_quantization()
-        self._processor = AutoProcessor.from_pretrained(
+        tokenizer = AutoTokenizer.from_pretrained(
             self.model_path,
-
-
+            trust_remote_code=kwargs["trust_remote_code"],
+            revision=kwargs["revision"],
         )
-
-
-            self.model_path,
-
-
-    @cache_clean
-    def chat(
-        self,
-        messages: List[ChatCompletionMessage],  # type: ignore
-        generate_config: Optional[PytorchGenerateConfig] = None,
-    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
-        messages = self._transform_messages(messages)
-
-        generate_config = generate_config if generate_config else {}
-
-        stream = generate_config.get("stream", False) if generate_config else False
-
-        if stream:
-            it = self._generate_stream(messages, generate_config)
-            return self._to_chat_completion_chunks(it)
-        else:
-            c = self._generate(messages, generate_config)
-            return c
-
-    def _generate(
-        self, messages: List, config: PytorchGenerateConfig = {}
-    ) -> ChatCompletion:
-        inputs = self._processor.apply_chat_template(
-            messages,
-            add_generation_prompt=True,
-            tokenize=True,
-            return_dict=True,
-            return_tensors="pt",
-        ).to(self._device)
-        input_len = inputs["input_ids"].shape[-1]
-
-        generation = self._model.generate(
-            **inputs,
-            do_sample=False,
-            max_new_tokens=config.get("max_tokens", 512),
-            temperature=config.get("temperature", 1),
+        kwargs["torch_dtype"] = torch.bfloat16
+        model = AutoModelForCausalLM.from_pretrained(
+            self.model_path,
+            **kwargs,
         )
-
-
-        decoded = self._processor.decode(generation, skip_special_tokens=True)
-        return generate_chat_completion(self.model_uid, decoded)
+        self._device = model.device
+        return model, tokenizer
 
-    def
-
-    ) -> Iterator[CompletionChunk]:
-        from threading import Thread
+    def _get_full_prompt(self, messages: List[Dict], tools, generate_config: dict):
+        return messages
 
-
+    def build_prefill_kwargs(self, prompts: List, req_list: List[InferenceRequest]):
+        """
+        Note that it is important to prepare `past_key_values` for gemma3 prefill phase
+        """
+        from transformers import HybridCache
 
-        inputs = self.
-
-            add_generation_prompt=True,
+        inputs = self._tokenizer.apply_chat_template(
+            prompts,
             tokenize=True,
-
+            add_generation_prompt=True,
             return_tensors="pt",
+            return_dict=True,
+            padding=True,
         ).to(self._device)
 
-
-
-
+        for i, r in enumerate(req_list):
+            r.prompt_tokens = inputs["input_ids"][i].tolist()
+
+        batch_size = len(prompts)
+        max_cache_len = self.get_context_len()
+        kv = HybridCache(
+            self._model.config,
+            max_batch_size=batch_size,
+            max_cache_len=max_cache_len,
+            dtype=self._model.dtype,
+            device=self._device,
+        )
+        return {**inputs, "past_key_values": kv}
+
+    def merge_kv_cache(self, past_cache, new_cache):
+        """
+        Note that: DO NOT use the `update` func of `HybridCache`, that is unrelated to KV cache merging.
+        """
+        import torch
+        from transformers import HybridCache
+
+        max_cache_len = new_cache.max_cache_len
+        batch_size = past_cache.max_batch_size + new_cache.max_batch_size
+
+        kv_batch = HybridCache(
+            self._model.config,
+            max_batch_size=batch_size,
+            max_cache_len=max_cache_len,
+            dtype=self._model.dtype,
+            device=self._device,
         )
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        for
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            has_content=False,
+        new_ks = [
+            torch.cat([nk, pk], dim=0).contiguous()
+            for nk, pk in zip(new_cache.key_cache, past_cache.key_cache)
+        ]
+        new_vs = [
+            torch.cat([nv, pv], dim=0).contiguous()
+            for nv, pv in zip(new_cache.value_cache, past_cache.value_cache)
+        ]
+
+        kv_batch.key_cache.clear()
+        kv_batch.value_cache.clear()
+        kv_batch.key_cache.extend(new_ks)
+        kv_batch.value_cache.extend(new_vs)
+
+        return kv_batch
+
+    def build_decode_attention_mask(
+        self, batch_size: int, seq_length: int, reqs: List[InferenceRequest]
+    ):
+        """
+        In Gemma3's inference script, attention_mask is handled internally for decode phase.
+        """
+        return None
+
+    def build_decode_position_ids(
+        self, batch_size: int, seq_length: int, reqs: List[InferenceRequest]
+    ):
+        """
+        In Gemma3's inference script, position_ids is handled internally for decode phase.
+        """
+        return None
+
+    def build_reduced_kv_cache(self, cache, skipped_indexes: Set[int]):
+        from transformers import HybridCache
+
+        batch_slices = [
+            num for num in range(cache.max_batch_size) if num not in skipped_indexes
+        ]
+        batch_size = len(batch_slices)
+
+        kv_batch = HybridCache(
+            self._model.config,
+            max_batch_size=batch_size,
+            max_cache_len=cache.max_cache_len,
+            dtype=self._model.dtype,
+            device=self._device,
         )
+
+        ks = cache.key_cache
+        vs = cache.value_cache
+
+        new_ks = [_k[batch_slices, ::].contiguous() for _k in ks]
+        new_vs = [_v[batch_slices, ::].contiguous() for _v in vs]
+        kv_batch.key_cache.clear()
+        kv_batch.value_cache.clear()
+        kv_batch.key_cache.extend(new_ks)
+        kv_batch.value_cache.extend(new_vs)
+
+        return kv_batch
@@ -0,0 +1,13 @@
+# Copyright 2022-2025 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.