xinference 1.5.0.post2__py3-none-any.whl → 1.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +107 -11
- xinference/client/restful/restful_client.py +51 -11
- xinference/constants.py +5 -1
- xinference/core/media_interface.py +758 -0
- xinference/core/model.py +49 -9
- xinference/core/supervisor.py +1 -1
- xinference/core/utils.py +1 -1
- xinference/core/worker.py +33 -39
- xinference/deploy/cmdline.py +17 -0
- xinference/deploy/utils.py +0 -3
- xinference/model/audio/__init__.py +16 -27
- xinference/model/audio/core.py +2 -1
- xinference/model/audio/cosyvoice.py +4 -2
- xinference/model/audio/model_spec.json +63 -46
- xinference/model/audio/model_spec_modelscope.json +31 -14
- xinference/model/embedding/__init__.py +16 -24
- xinference/model/image/__init__.py +15 -25
- xinference/model/llm/__init__.py +40 -115
- xinference/model/llm/core.py +29 -6
- xinference/model/llm/llama_cpp/core.py +30 -347
- xinference/model/llm/llm_family.json +1674 -2203
- xinference/model/llm/llm_family.py +71 -7
- xinference/model/llm/llm_family_csghub.json +0 -32
- xinference/model/llm/llm_family_modelscope.json +1838 -2016
- xinference/model/llm/llm_family_openmind_hub.json +19 -325
- xinference/model/llm/lmdeploy/core.py +7 -2
- xinference/model/llm/mlx/core.py +23 -7
- xinference/model/llm/reasoning_parser.py +281 -5
- xinference/model/llm/sglang/core.py +39 -11
- xinference/model/llm/transformers/chatglm.py +9 -2
- xinference/model/llm/transformers/cogagent.py +10 -12
- xinference/model/llm/transformers/cogvlm2.py +6 -3
- xinference/model/llm/transformers/cogvlm2_video.py +3 -6
- xinference/model/llm/transformers/core.py +58 -60
- xinference/model/llm/transformers/deepseek_v2.py +4 -2
- xinference/model/llm/transformers/deepseek_vl.py +10 -4
- xinference/model/llm/transformers/deepseek_vl2.py +9 -4
- xinference/model/llm/transformers/gemma3.py +4 -5
- xinference/model/llm/transformers/glm4v.py +3 -21
- xinference/model/llm/transformers/glm_edge_v.py +3 -20
- xinference/model/llm/transformers/intern_vl.py +3 -6
- xinference/model/llm/transformers/internlm2.py +1 -1
- xinference/model/llm/transformers/minicpmv25.py +4 -2
- xinference/model/llm/transformers/minicpmv26.py +5 -3
- xinference/model/llm/transformers/omnilmm.py +1 -1
- xinference/model/llm/transformers/opt.py +1 -1
- xinference/model/llm/transformers/ovis2.py +302 -0
- xinference/model/llm/transformers/qwen-omni.py +8 -1
- xinference/model/llm/transformers/qwen2_audio.py +3 -1
- xinference/model/llm/transformers/qwen2_vl.py +5 -1
- xinference/model/llm/transformers/qwen_vl.py +5 -2
- xinference/model/llm/utils.py +96 -45
- xinference/model/llm/vllm/core.py +108 -24
- xinference/model/llm/vllm/distributed_executor.py +8 -7
- xinference/model/llm/vllm/xavier/allocator.py +1 -1
- xinference/model/llm/vllm/xavier/block_manager.py +1 -1
- xinference/model/llm/vllm/xavier/block_tracker.py +3 -3
- xinference/model/llm/vllm/xavier/executor.py +1 -1
- xinference/model/llm/vllm/xavier/test/test_xavier.py +2 -11
- xinference/model/rerank/__init__.py +13 -24
- xinference/model/video/__init__.py +15 -25
- xinference/model/video/core.py +3 -3
- xinference/model/video/diffusers.py +157 -13
- xinference/model/video/model_spec.json +100 -0
- xinference/model/video/model_spec_modelscope.json +104 -0
- xinference/thirdparty/cosyvoice/bin/average_model.py +5 -4
- xinference/thirdparty/cosyvoice/bin/export_jit.py +50 -20
- xinference/thirdparty/cosyvoice/bin/export_onnx.py +136 -51
- xinference/thirdparty/cosyvoice/bin/inference.py +15 -5
- xinference/thirdparty/cosyvoice/bin/train.py +7 -2
- xinference/thirdparty/cosyvoice/cli/cosyvoice.py +72 -52
- xinference/thirdparty/cosyvoice/cli/frontend.py +58 -58
- xinference/thirdparty/cosyvoice/cli/model.py +140 -155
- xinference/thirdparty/cosyvoice/dataset/processor.py +9 -5
- xinference/thirdparty/cosyvoice/flow/decoder.py +656 -54
- xinference/thirdparty/cosyvoice/flow/flow.py +69 -11
- xinference/thirdparty/cosyvoice/flow/flow_matching.py +167 -63
- xinference/thirdparty/cosyvoice/flow/length_regulator.py +1 -0
- xinference/thirdparty/cosyvoice/hifigan/discriminator.py +91 -1
- xinference/thirdparty/cosyvoice/hifigan/f0_predictor.py +4 -1
- xinference/thirdparty/cosyvoice/hifigan/generator.py +4 -1
- xinference/thirdparty/cosyvoice/hifigan/hifigan.py +2 -2
- xinference/thirdparty/cosyvoice/llm/llm.py +198 -18
- xinference/thirdparty/cosyvoice/transformer/embedding.py +12 -4
- xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +124 -21
- xinference/thirdparty/cosyvoice/utils/class_utils.py +13 -0
- xinference/thirdparty/cosyvoice/utils/common.py +1 -1
- xinference/thirdparty/cosyvoice/utils/file_utils.py +40 -2
- xinference/thirdparty/cosyvoice/utils/frontend_utils.py +7 -0
- xinference/thirdparty/cosyvoice/utils/mask.py +4 -0
- xinference/thirdparty/cosyvoice/utils/train_utils.py +5 -1
- xinference/thirdparty/matcha/hifigan/xutils.py +3 -3
- xinference/types.py +2 -71
- xinference/web/ui/build/asset-manifest.json +6 -6
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/css/{main.0f6523be.css → main.337afe76.css} +2 -2
- xinference/web/ui/build/static/css/main.337afe76.css.map +1 -0
- xinference/web/ui/build/static/js/main.ae579a97.js +3 -0
- xinference/web/ui/build/static/js/main.ae579a97.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/0196a4b09e3264614e54360d5f832c46b31d964ec58296765ebff191ace6adbf.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/12e02ee790dbf57ead09a241a93bb5f893393aa36628ca741d44390e836a103f.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/18fa271456b31cded36c05c4c71c6b2b1cf4e4128c1e32f0e45d8b9f21764397.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/2fdc61dcb6a9d1fbcb44be592d0e87d8c3f21297a7327559ef5345665f8343f7.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/3d596a3e8dd6430d7ce81d164e32c31f8d47cfa5f725c328a298754d78563e14.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/5c08e2cd07809ed3e41486b16652253404cbb63a3ff8d0366ee50f57e2413cea.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/6798e126f3bc5f95a4c16a9c2ad52ffe77970c62406d83e20604dfda7ffd2247.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/8472e58a31720892d534f3febda31f746b25ec4aa60787eef34217b074e67965.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/b617f7d21a95045fc57b26a9373551740f1978a826134cbf705c3a1bf8714a93.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/c1506cb142151366074975f30fa1ff9cd6e5e978b62a4b074dfc16fe08d70d75.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/c5c7c2cd1b863ce41adff2c4737bba06eef3a1acf28288cb83d992060f6b8923.json +1 -0
- xinference/web/ui/src/locales/en.json +7 -4
- xinference/web/ui/src/locales/zh.json +7 -4
- {xinference-1.5.0.post2.dist-info → xinference-1.6.0.dist-info}/METADATA +56 -36
- {xinference-1.5.0.post2.dist-info → xinference-1.6.0.dist-info}/RECORD +120 -121
- {xinference-1.5.0.post2.dist-info → xinference-1.6.0.dist-info}/WHEEL +1 -1
- xinference/core/image_interface.py +0 -377
- xinference/model/llm/transformers/compression.py +0 -258
- xinference/model/llm/transformers/yi_vl.py +0 -239
- xinference/thirdparty/cosyvoice/bin/export_trt.sh +0 -9
- xinference/web/ui/build/static/css/main.0f6523be.css.map +0 -1
- xinference/web/ui/build/static/js/main.4b67a723.js +0 -3
- xinference/web/ui/build/static/js/main.4b67a723.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/0f0adb2283a8f469d097a7a0ebb754624fa52414c83b83696c41f2e6a737ceda.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/51709f5d3e53bcf19e613662ef9b91fb9174942c5518987a248348dd4e1e0e02.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/8157db83995c671eb57abc316c337f867d1dc63fb83520bb4ff351fee57dcce2.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/8f9af2979e45d4648f0cfae108363e58ee421c29a9d4e7329b6f06d9adfd4133.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/9c8b1a86e7c65b2b2599a205e30920652d6c2105f926508ef5bcf29a3ef4ce76.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/b8551e9775a01b28ae674125c688febe763732ea969ae344512e64ea01bf632e.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/e4ba658c6b3b0490910acdae0c535a892257efb61539a24adf8038fc653bd22f.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/efe7cd132c27a8f9fd5352a394c491fd5fb0da0348cf9fcbd923164a32365eab.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/f04f666b77b44d7be3e16034d6b0074de2ba9c254f1fae15222b3148608fa8b3.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/f199e8173f6409a5802ed44acb95f218388131136504b2e9132129e150c92f9a.json +0 -1
- /xinference/web/ui/build/static/js/{main.4b67a723.js.LICENSE.txt → main.ae579a97.js.LICENSE.txt} +0 -0
- {xinference-1.5.0.post2.dist-info → xinference-1.6.0.dist-info}/entry_points.txt +0 -0
- {xinference-1.5.0.post2.dist-info → xinference-1.6.0.dist-info}/licenses/LICENSE +0 -0
- {xinference-1.5.0.post2.dist-info → xinference-1.6.0.dist-info}/top_level.txt +0 -0
xinference/model/llm/mlx/core.py
CHANGED
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import importlib.util
 import logging
 import platform
 import sys
@@ -160,7 +160,10 @@ class MLXModel(LLM):

     def load(self):
         reasoning_content = self._model_config.pop("reasoning_content")
-        self.
+        enable_thinking = self._model_config.pop("enable_thinking", True)
+        self.prepare_parse_reasoning_content(
+            reasoning_content, enable_thinking=enable_thinking
+        )

         kwargs = {}
         kwargs["revision"] = self._model_config.get(
@@ -172,7 +175,11 @@ class MLXModel(LLM):
         self._model, self._tokenizer = self._load_model(**kwargs)

     @classmethod
-    def
+    def check_lib(cls) -> bool:
+        return importlib.util.find_spec("mlx_lm") is not None
+
+    @classmethod
+    def match_json(
         cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         if llm_spec.model_format not in ["mlx"]:
@@ -423,7 +430,7 @@ class MLXChatModel(MLXModel, ChatModelMixin):
         return generate_config

     @classmethod
-    def
+    def match_json(
         cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         if llm_spec.model_format not in ["mlx"]:
@@ -445,7 +452,9 @@ class MLXChatModel(MLXModel, ChatModelMixin):
     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
         model_family = self.model_family.model_family or self.model_family.model_name
         tools = generate_config.pop("tools", []) if generate_config else None
-        full_context_kwargs =
+        full_context_kwargs = (
+            self._get_chat_template_kwargs_from_generate_config(generate_config, self.reasoning_parser) or {}  # type: ignore
+        )
         if tools:
             if (
                 model_family in QWEN_TOOL_CALL_FAMILY
@@ -476,7 +485,11 @@ class MLXChatModel(MLXModel, ChatModelMixin):

 class MLXVisionModel(MLXModel, ChatModelMixin):
     @classmethod
-    def
+    def check_lib(cls) -> bool:
+        return importlib.util.find_spec("mlx_vlm") is not None
+
+    @classmethod
+    def match_json(
         cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         if llm_spec.model_format not in ["mlx"]:
@@ -623,7 +636,10 @@ class MLXVisionModel(MLXModel, ChatModelMixin):
         if "internvl2" not in model_family.lower():
             from qwen_vl_utils import process_vision_info

-        full_context_kwargs =
+        full_context_kwargs = (
+            self._get_chat_template_kwargs_from_generate_config(generate_config, self.reasoning_parser)  # type: ignore
+            or {}
+        )
         if tools and model_family in QWEN_TOOL_CALL_FAMILY:
             full_context_kwargs["tools"] = tools
         assert self.model_family.chat_template is not None
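
Note: the MLX backend changes above split library detection into a dedicated check_lib() and pop an enable_thinking flag (default True) from the model config before configuring the reasoning parser. A minimal sketch of that pattern, with illustrative helper names that are not part of xinference's API:

    # Sketch only: mirrors the pattern introduced in MLXModel, not the real class.
    import importlib.util


    def mlx_lm_available() -> bool:
        # Same idea as the new MLXModel.check_lib(): the backend is usable only
        # when mlx_lm can be imported.
        return importlib.util.find_spec("mlx_lm") is not None


    def pop_reasoning_options(model_config: dict) -> tuple:
        # Mirrors MLXModel.load(): enable_thinking defaults to True for MLX;
        # the default shown for reasoning_content is illustrative.
        reasoning_content = model_config.pop("reasoning_content", False)
        enable_thinking = model_config.pop("enable_thinking", True)
        return reasoning_content, enable_thinking


    if __name__ == "__main__":
        cfg = {"reasoning_content": True, "enable_thinking": False}
        print(mlx_lm_available(), pop_reasoning_options(cfg))
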
xinference/model/llm/reasoning_parser.py
CHANGED
@@ -1,20 +1,33 @@
 import re
-from typing import Optional, Tuple, Union
+from typing import Any, AsyncGenerator, Dict, Iterator, List, Optional, Tuple, Union

-from ...types import
+from ...types import (
+    ChatCompletionChunk,
+    ChatCompletionChunkDelta,
+    CompletionChoice,
+    CompletionChunk,
+)


 class ReasoningParser:
     """Reasoning parser for reasoning model."""

     def __init__(
-        self,
+        self,
+        reasoning_content: bool = False,
+        reasoning_start_tag: str = "",
+        reasoning_end_tag: str = "",
+        enable_thinking: bool = True,
     ):
+        self.reasoning_content = reasoning_content
         self.reasoning_start_tag = reasoning_start_tag
         self.reasoning_end_tag = reasoning_end_tag
         self.reasoning_regex = re.compile(
             rf"{self.reasoning_start_tag}(.*?){self.reasoning_end_tag}", re.DOTALL
         )
+        # enable_thinking can be set to False only for hybrid model
+        # e.g. qwen3, which can support both thinking and non-thinking
+        self.enable_thinking = enable_thinking

     def extract_reasoning_content_streaming(
         self,
@@ -62,9 +75,9 @@ class ReasoningParser:
                 delta["content"] = None
                 return delta
             elif self.reasoning_start_tag in delta_text:
+                start_idx = delta_text.find(self.reasoning_start_tag)
                 if self.reasoning_end_tag in delta_text:
                     # <think> in delta, </think> in delta, extract reasoning content
-                    start_idx = delta_text.find(self.reasoning_start_tag)
                     end_idx = delta_text.find(self.reasoning_end_tag)
                     reasoning_content = delta_text[
                         start_idx + len(self.reasoning_start_tag) : end_idx
@@ -79,7 +92,10 @@ class ReasoningParser:
                 else:
                     # <think> in delta, no </think> in delta,
                     # reasoning content continues
-
+                    reasoning_content = delta_text[
+                        start_idx + len(self.reasoning_start_tag) :
+                    ]
+                    delta["reasoning_content"] = reasoning_content
                     delta["content"] = None
                     return delta
             else:
@@ -142,3 +158,263 @@ class ReasoningParser:
         if len(final_output) == 0:
             return reasoning_content, ""
         return reasoning_content, final_output
+
+    def check_content_parser(self) -> bool:
+        """Check if the parser should extract reasoning content.
+
+        Returns:
+            bool: True if reasoning content should be extracted, False otherwise
+        """
+        return self.reasoning_content
+
+    def _create_chat_completion_chunk(
+        self, chunk: Union[Dict[str, Any], CompletionChunk], content: str
+    ) -> ChatCompletionChunk:
+        """Helper method to create a ChatCompletionChunk with specified content.
+
+        Args:
+            chunk: The original chunk to copy metadata from
+            content: The content to include in the chunk
+
+        Returns:
+            ChatCompletionChunk: A new chat completion chunk
+        """
+        return ChatCompletionChunk(
+            id="chat" + chunk["id"],
+            model=chunk["model"],
+            created=chunk["created"],
+            object="chat.completion.chunk",
+            choices=[
+                {
+                    "index": 0,
+                    "delta": {
+                        "content": content,
+                    },
+                    "finish_reason": None,
+                }
+            ],
+        )
+
+    def _create_completion_chunk(
+        self, chunk: Union[Dict[str, Any], CompletionChunk], text: str
+    ) -> CompletionChunk:
+        """Helper method to create a CompletionChunk with specified text.
+
+        Args:
+            chunk: The original chunk to copy metadata from
+            text: The text to include in the chunk
+
+        Returns:
+            CompletionChunk: A new completion chunk
+        """
+        return CompletionChunk(
+            id=chunk["id"],
+            model=chunk["model"],
+            created=chunk["created"],
+            object="text_completion",
+            choices=[
+                {
+                    "index": 0,
+                    "text": text,
+                    "logprobs": None,
+                    "finish_reason": None,
+                }
+            ],
+        )
+
+    async def prepare_reasoning_content_streaming(
+        self, chunks: AsyncGenerator[CompletionChunk, None]
+    ):
+        """Process the chunks from model output, check if the first chunk contains reasoning_start_tag,
+        if not, add a chunk with the tag at the beginning.
+
+        Args:
+            chunks (AsyncGenerator[CompletionChunk, None]): Chunks from model output
+
+        Yields:
+            AsyncGenerator[CompletionChunk, None]: Processed chunks
+        """
+
+        # If reasoning_start_tag is not set, or disable thinking for hybrid model like qwen3,
+        # yield chunks as is
+        if not self.reasoning_start_tag or not self.enable_thinking:
+            async for chunk in chunks:
+                yield chunk
+            return
+
+        # If chunks is empty, return
+        if not chunks:
+            return
+
+        # Flag to identify the first chunk
+        is_first_chunk = True
+
+        async for chunk in chunks:
+            if is_first_chunk:
+                # Reset the flag after processing the first chunk
+                is_first_chunk = False
+                choices = chunk.get("choices")
+                if not choices or not choices[0]:
+                    continue
+                if (
+                    chunk.get("object") == "chat.completion.chunk"
+                    and "delta" in choices[0]
+                ):
+                    # For chat completion chunks with delta format
+                    delta = choices[0].get("delta")
+                    if delta is None:
+                        continue
+                    assert isinstance(delta, dict)
+                    text = delta.get("content")
+                    if text is None:
+                        continue
+                    # If the first chunk doesn't contain the reasoning_start_tag
+                    if self.reasoning_start_tag not in text:
+                        # Create and yield chunks with reasoning_start_tag and newline
+                        yield self._create_chat_completion_chunk(
+                            chunk, f"{self.reasoning_start_tag}\n"
+                        )
+                else:
+                    # For standard completion chunks
+                    text = choices[0].get("text")
+                    if text is None:
+                        continue
+                    # If the first chunk doesn't contain the reasoning_start_tag
+                    if self.reasoning_start_tag not in text:
+                        # Create and yield chunks with reasoning_start_tag and newline
+                        yield self._create_completion_chunk(
+                            chunk, f"{self.reasoning_start_tag}\n"
+                        )
+                # Yield the original first chunk
+                yield chunk
+            else:
+                # For non-first chunks, yield directly
+                yield chunk
+
+    def prepare_reasoning_content_sync(self, chunks: Iterator[CompletionChunk]):
+        """Process the chunks from model output, check if the first chunk contains reasoning_start_tag,
+        if not, add a chunk with the tag at the beginning. This is a synchronous version of
+        prepare_reasoning_content_streaming.
+
+        Args:
+            chunks (Iterator[CompletionChunk]): Chunks from model output
+
+        Returns:
+            Iterator[CompletionChunk]: Processed chunks
+        """
+        # If reasoning_start_tag is not set, or disable thinking for hybrid model like qwen3,
+        # yield chunks as is
+        if not self.reasoning_start_tag or not self.enable_thinking:
+            for chunk in chunks:
+                yield chunk
+            return
+
+        # Flag to identify the first chunk
+        is_first_chunk = True
+
+        for chunk in chunks:
+            if is_first_chunk:
+                # Reset the flag after processing the first chunk
+                is_first_chunk = False
+                choices = chunk.get("choices")
+                if not choices or not choices[0]:
+                    continue
+                if (
+                    chunk.get("object") == "chat.completion.chunk"
+                    and "delta" in choices[0]
+                ):
+                    # For chat completion chunks with delta format
+                    delta = choices[0].get("delta")
+                    if delta is None:
+                        continue
+                    assert isinstance(delta, dict)
+                    text = delta.get("content")
+                    if text is None:
+                        continue
+                    # If the first chunk doesn't contain the reasoning_start_tag
+                    if self.reasoning_start_tag not in text:
+                        # Create and yield chunks with reasoning_start_tag and newline
+                        yield self._create_chat_completion_chunk(
+                            chunk, f"{self.reasoning_start_tag}\n"
+                        )
+                else:
+                    # For standard completion chunks
+                    text = choices[0].get("text")
+                    if text is None:
+                        continue
+                    # If the first chunk doesn't contain the reasoning_start_tag
+                    if self.reasoning_start_tag not in text:
+                        # Create and yield chunks with reasoning_start_tag and newline
+                        yield self._create_completion_chunk(
+                            chunk, f"{self.reasoning_start_tag}\n"
+                        )
+                # Yield the original first chunk
+                yield chunk
+            else:
+                # For non-first chunks, yield directly
+                yield chunk
+
+    def prepare_reasoning_content(self, completion):
+        """Ensures that the model output string starts with the reasoning_start_tag.
+
+        If the model_output is not a string (e.g., CompletionChoice), it extracts
+        the text content. If the reasoning_start_tag is not found in the text,
+        it prepends the tag to the text.
+
+        Args:
+            completion: The completion object containing model output,
+                which can be either a chat completion or a standard completion.
+        """
+        if not self.reasoning_start_tag or not self.enable_thinking:
+            return completion
+
+        if completion.get("object") == "chat.completion" and completion.get("choices"):
+            text = completion["choices"][0]["message"]["content"]
+            if self.reasoning_start_tag not in text:
+                text = f"{self.reasoning_start_tag}\n{text}"
+                completion["choices"][0]["message"]["content"] = text
+            return completion
+
+        text = completion["choices"][0]["text"]
+        if self.reasoning_start_tag not in text:
+            text = f"{self.reasoning_start_tag}\n{text}"
+            completion["choices"][0]["text"] = text
+        return completion
+
+    def prepare_first_reasoning_content_chunk(
+        self,
+        chunk: CompletionChunk,
+    ) -> List[ChatCompletionChunk]:
+        """Prepares the first chunk of a completion by adding reasoning_start_tag if needed.
+
+        This function checks if the first chunk contains the reasoning_start_tag. If not,
+        it creates two new chunks containing the reasoning_start_tag and a newline character
+        that will be inserted before the original chunk.
+
+        Args:
+            chunk (CompletionChunk): The first chunk of a completion to check and possibly modify
+
+        Returns:
+            List[ChatCompletionChunk]: A list of new chunks to insert before the original chunk,
+                or an empty list if no modification is needed
+        """
+        chunks: List[ChatCompletionChunk] = []
+        if not self.reasoning_start_tag or not self.enable_thinking:
+            return chunks
+
+        choices = chunk.get("choices")
+        if not choices or not choices[0]:
+            return chunks
+        text = choices[0].get("text")
+        if not text:
+            return chunks
+
+        if self.reasoning_start_tag not in text:
+            # Create chunks with reasoning_start_tag and newline
+            chunks.append(
+                self._create_chat_completion_chunk(
+                    chunk, f"{self.reasoning_start_tag}\n"
+                )
+            )
+
+        return chunks
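
Note: the new ReasoningParser helpers normalize model output that omits the opening reasoning tag. A minimal usage sketch, assuming a parser configured with <think>/</think> tags (the tag values are model-specific and supplied by the caller, not fixed by this diff) and an installed xinference 1.6.0:

    from xinference.model.llm.reasoning_parser import ReasoningParser

    parser = ReasoningParser(
        reasoning_content=True,
        reasoning_start_tag="<think>",
        reasoning_end_tag="</think>",
        enable_thinking=True,
    )

    # prepare_reasoning_content() prepends the start tag when the model omitted it,
    # so downstream extraction always sees a well-formed reasoning block.
    completion = {
        "object": "chat.completion",
        "choices": [{"message": {"content": "The answer is 42.</think>42"}}],
    }
    completion = parser.prepare_reasoning_content(completion)
    print(completion["choices"][0]["message"]["content"])  # now starts with "<think>\n"
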
xinference/model/llm/sglang/core.py
CHANGED
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import importlib.util
 import json
 import logging
 import sys
@@ -101,12 +101,17 @@ SGLANG_SUPPORTED_CHAT_MODELS = [
     "deepseek-v2-chat-0628",
     "qwen2.5-instruct",
     "qwen2.5-coder-instruct",
+    "XiYanSQL-QwenCoder-2504",
     "QwQ-32B-Preview",
     "QwQ-32B",
     "deepseek-r1-distill-qwen",
     "deepseek-r1-distill-llama",
     "deepseek-v3",
     "deepseek-r1",
+    "DianJin-R1",
+    "qwen3",
+    "HuatuoGPT-o1-Qwen2.5",
+    "HuatuoGPT-o1-LLaMA-3.1",
 ]
 SGLANG_SUPPORTED_VISION_MODEL_LIST = [
     "qwen2.5-vl-instruct",
@@ -154,7 +159,10 @@ class SGLANGModel(LLM):

         self._model_config = self._sanitize_model_config(self._model_config)
         reasoning_content = self._model_config.pop("reasoning_content")
-        self.
+        enable_thinking = self._model_config.pop("enable_thinking", False)
+        self.prepare_parse_reasoning_content(
+            reasoning_content, enable_thinking=enable_thinking
+        )

         # Fix: GH#2169
         if sgl.__version__ >= "0.2.14":
@@ -297,7 +305,11 @@ class SGLANGModel(LLM):
         return generate_config

     @classmethod
-    def
+    def check_lib(cls) -> bool:
+        return importlib.util.find_spec("sglang") is not None
+
+    @classmethod
+    def match_json(
         cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         if not cls._has_cuda_device():
@@ -435,6 +447,7 @@ class SGLANGModel(LLM):
     async def async_generate(
         self,
         prompt: str,
+        *,
         image_data: Optional[Union[List[str], str]] = None,
         generate_config: Optional[SGLANGGenerateConfig] = None,
         request_id: Optional[str] = None,
@@ -524,7 +537,7 @@ class SGLANGModel(LLM):

 class SGLANGChatModel(SGLANGModel, ChatModelMixin):
     @classmethod
-    def
+    def match_json(
         cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         if llm_spec.model_format not in ["pytorch", "gptq", "awq", "fp8"]:
@@ -551,6 +564,7 @@ class SGLANGChatModel(SGLANGModel, ChatModelMixin):
         if self.model_family.stop:
             if (not generate_config.get("stop")) and self.model_family.stop:
                 generate_config["stop"] = self.model_family.stop.copy()
+        generate_config.pop("chat_template_kwargs", None)
         return generate_config

     async def async_chat(
@@ -560,23 +574,31 @@ class SGLANGChatModel(SGLANGModel, ChatModelMixin):
         request_id: Optional[str] = None,
     ) -> Union[ChatCompletion, AsyncGenerator[ChatCompletionChunk, None]]:
         assert self.model_family.chat_template is not None
-
+        full_context_kwargs = (
+            self._get_chat_template_kwargs_from_generate_config(
+                generate_config, self.reasoning_parser
+            )
+            or {}
+        )
+        full_prompt = self.get_full_context(
+            messages, self.model_family.chat_template, **full_context_kwargs
+        )

         generate_config = self._sanitize_chat_config(generate_config)
         stream = generate_config.get("stream", None)
         if stream:
-            agen = await self.async_generate(full_prompt, generate_config)  # type: ignore
+            agen = await self.async_generate(full_prompt, generate_config=generate_config)  # type: ignore
             assert isinstance(agen, AsyncGenerator)
             return self._async_to_chat_completion_chunks(agen, self.reasoning_parser)
         else:
-            c = await self.async_generate(full_prompt, generate_config)  # type: ignore
+            c = await self.async_generate(full_prompt, generate_config=generate_config)  # type: ignore
             assert not isinstance(c, AsyncGenerator)
             return self._to_chat_completion(c, self.reasoning_parser)


 class SGLANGVisionModel(SGLANGModel, ChatModelMixin):
     @classmethod
-    def
+    def match_json(
         cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         if not cls._has_cuda_device():
@@ -627,7 +649,13 @@ class SGLANGVisionModel(SGLANGModel, ChatModelMixin):
             self.model_family.chat_template if self.model_family.chat_template else ""
         )

-
+        full_context_kwargs = (
+            self._get_chat_template_kwargs_from_generate_config(
+                generate_config, self.reasoning_parser
+            )
+            or {}
+        )
+        prompt = self.get_full_context(messages, chat_template, **full_context_kwargs)
         images, video_inputs = process_vision_info(messages)
         if video_inputs:
             raise ValueError("Not support video input now.")
@@ -650,10 +678,10 @@ class SGLANGVisionModel(SGLANGModel, ChatModelMixin):
         generate_config = self._sanitize_chat_config(generate_config)
         stream = generate_config.get("stream", None)
         if stream:
-            agen = await self.async_generate(prompt, base64_images, generate_config)  # type: ignore
+            agen = await self.async_generate(prompt, image_data=base64_images, generate_config=generate_config)  # type: ignore
             assert isinstance(agen, AsyncGenerator)
             return self._async_to_chat_completion_chunks(agen, self.reasoning_parser)
         else:
-            c = await self.async_generate(prompt, base64_images, generate_config)  # type: ignore
+            c = await self.async_generate(prompt, image_data=base64_images, generate_config=generate_config)  # type: ignore
             assert not isinstance(c, AsyncGenerator)
             return self._to_chat_completion(c, self.reasoning_parser)
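
Note: both the SGLang and MLX chat paths now read chat_template_kwargs from generate_config via _get_chat_template_kwargs_from_generate_config, and SGLANGChatModel pops the key before the request reaches the engine. A hedged sketch of how a client request might carry it for a hybrid model such as qwen3; the classes below are xinference's documented RESTful client, but the accepted option names should be verified against the 1.6.0 docs:

    from xinference.client import RESTfulClient

    client = RESTfulClient("http://127.0.0.1:9997")
    model = client.get_model("qwen3")  # assumes a launched model registered under this UID

    response = model.chat(
        messages=[{"role": "user", "content": "Briefly explain KV cache."}],
        generate_config={
            # Per the SGLANGChatModel change, chat_template_kwargs is consumed when
            # rendering the chat template and then popped before generation.
            "chat_template_kwargs": {"enable_thinking": False},
            "max_tokens": 256,
        },
    )
    print(response["choices"][0]["message"]["content"])
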
xinference/model/llm/transformers/chatglm.py
CHANGED
@@ -84,7 +84,7 @@ class ChatglmPytorchChatModel(PytorchChatModel):
         return model, tokenizer

     @classmethod
-    def
+    def match_json(
         cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         if llm_spec.model_format != "pytorch":
@@ -462,6 +462,12 @@ class ChatglmPytorchChatModel(PytorchChatModel):
         tools = list(tools) if tools is not None else None
         tool_choice = r.generate_config.get("tool_choice", "none")

+        full_context_kwargs = (
+            self._get_chat_template_kwargs_from_generate_config(
+                r.generate_config, self.reasoning_parser
+            )
+            or {}
+        )
         r.prompt = self._process_messages(
             r.prompt, tools=tools, tool_choice=tool_choice
         )
@@ -469,6 +475,7 @@ class ChatglmPytorchChatModel(PytorchChatModel):
             r.prompt,
             self.model_family.chat_template,  # type: ignore
             tokenizer=self._tokenizer,
+            **full_context_kwargs,
         )
         if tools:
             r.tools = tools
@@ -501,7 +508,7 @@ class ChatglmPytorchChatModel(PytorchChatModel):

         if "<bos_stream>" in req.completion:
             bos_pos = req.completion.index("<bos_stream>")
-            results.
+            results.extend(
                 self._get_first_chat_completion_chunk(req.completion[bos_pos + 1])
             )

xinference/model/llm/transformers/cogagent.py
CHANGED
@@ -46,8 +46,8 @@ class CogAgentChatModel(PytorchChatModel):
         self._device = None
         self._tokenizer = None
         self._model = None
-        self._platform: Literal["Mac", "WIN", "Mobile"] | None = "Mac"
-        self._format: Literal[
+        self._platform: Literal["Mac", "WIN", "Mobile"] | None = "Mac"  # type: ignore
+        self._format: Literal[  # type: ignore
             "(Answer in Action-Operation-Sensitive format.)",
             "(Answer in Status-Plan-Action-Operation format.)",
             "(Answer in Status-Action-Operation-Sensitive format.)",
@@ -56,7 +56,7 @@ class CogAgentChatModel(PytorchChatModel):
         ] | None = "(Answer in Action-Operation-Sensitive format.)"

     @classmethod
-    def
+    def match_json(
         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         family = model_family.model_family or model_family.model_name
@@ -64,8 +64,8 @@ class CogAgentChatModel(PytorchChatModel):
             return True
         return False

-    def load(self
-        from transformers import AutoModelForCausalLM, AutoTokenizer
+    def load(self):
+        from transformers import AutoModelForCausalLM, AutoTokenizer

         device = self._pytorch_model_config.get("device", "auto")
         self._device = select_device(device)
@@ -73,19 +73,14 @@ class CogAgentChatModel(PytorchChatModel):
         self._tokenizer = AutoTokenizer.from_pretrained(
             self.model_path, trust_remote_code=True
         )
-
-            quantization_config = BitsAndBytesConfig(load_in_4bit=True)
-        elif self.quantization == "8-bit":
-            quantization_config = BitsAndBytesConfig(load_in_8bit=True)
-        else:
-            quantization_config = None
+        kwargs = self.apply_bnb_quantization()

         self._model = AutoModelForCausalLM.from_pretrained(
             self.model_path,
             torch_dtype=torch.bfloat16,
             trust_remote_code=True,
             device_map=self._device,
-
+            **kwargs,
         ).eval()

     def _message_content_to_cogagent(self, content):
@@ -211,6 +206,9 @@ class CogAgentChatModel(PytorchChatModel):
             "return_tensors": "pt",
             "return_dict": True,
         }
+        full_context_kwargs.update(
+            self._get_chat_template_kwargs_from_generate_config(generate_config, self.reasoning_parser) or {}  # type: ignore
+        )
         assert self.model_family.chat_template is not None
         inputs = self.get_full_context(
             [{"role": "user", "image": image, "content": query}],
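
Note: cogagent.py and cogvlm2.py replace their inline 4-bit/8-bit BitsAndBytesConfig branches with a shared apply_bnb_quantization() helper whose result is splatted into from_pretrained(). A sketch of what such a helper plausibly returns, reconstructed from the branches removed above; the real helper lives in xinference's transformers base class and may differ in detail:

    from transformers import BitsAndBytesConfig


    def apply_bnb_quantization_sketch(quantization: str) -> dict:
        # Returns from_pretrained() kwargs so callers can simply splat them
        # (**kwargs), which is how the new load() methods use the helper.
        if quantization == "4-bit":
            return {"quantization_config": BitsAndBytesConfig(load_in_4bit=True)}
        if quantization == "8-bit":
            return {"quantization_config": BitsAndBytesConfig(load_in_8bit=True)}
        return {}
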
xinference/model/llm/transformers/cogvlm2.py
CHANGED
@@ -64,7 +64,7 @@ class CogVLM2Model(PytorchChatModel):
         self._model = None

     @classmethod
-    def
+    def match_json(
         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         family = model_family.model_family or model_family.model_name
@@ -72,7 +72,7 @@ class CogVLM2Model(PytorchChatModel):
             return True
         return False

-    def load(self
+    def load(self):
         from transformers import AutoModelForCausalLM, AutoTokenizer
         from transformers.generation import GenerationConfig

@@ -88,6 +88,8 @@ class CogVLM2Model(PytorchChatModel):
             self._model, self._tokenizer = self._load_tensorizer()
             return

+        kwargs = self.apply_bnb_quantization()
+
         self._tokenizer = AutoTokenizer.from_pretrained(
             self.model_path,
             trust_remote_code=True,
@@ -99,6 +101,7 @@ class CogVLM2Model(PytorchChatModel):
             trust_remote_code=True,
             low_cpu_mem_usage=True,
             device_map="auto",
+            **kwargs
         ).eval()

         # Specify hyperparameters for generation
@@ -313,7 +316,7 @@ class CogVLM2Model(PytorchChatModel):
     def get_dtype(self):
         return self._torch_type

-    def _get_full_prompt(self, messages: List[Dict], tools):
+    def _get_full_prompt(self, messages: List[Dict], tools, generate_config: dict):  # type: ignore
         prompt, system_prompt, chat_history = parse_messages(messages)
         system_prompt = system_prompt or ""
         query, image, history = self.get_query_and_history(