xinference 0.9.3__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic. See the package's registry page for more details.
- xinference/_version.py +3 -3
- xinference/api/oauth2/auth_service.py +47 -18
- xinference/api/oauth2/types.py +1 -0
- xinference/api/restful_api.py +16 -11
- xinference/client/restful/restful_client.py +12 -2
- xinference/conftest.py +13 -2
- xinference/constants.py +2 -0
- xinference/core/supervisor.py +32 -1
- xinference/core/worker.py +139 -20
- xinference/deploy/cmdline.py +119 -20
- xinference/model/llm/__init__.py +6 -0
- xinference/model/llm/llm_family.json +711 -10
- xinference/model/llm/llm_family_modelscope.json +557 -7
- xinference/model/llm/pytorch/chatglm.py +2 -1
- xinference/model/llm/pytorch/core.py +2 -0
- xinference/model/llm/pytorch/deepseek_vl.py +232 -0
- xinference/model/llm/pytorch/internlm2.py +2 -1
- xinference/model/llm/pytorch/omnilmm.py +153 -0
- xinference/model/llm/sglang/__init__.py +13 -0
- xinference/model/llm/sglang/core.py +365 -0
- xinference/model/llm/utils.py +46 -13
- xinference/model/llm/vllm/core.py +10 -0
- xinference/thirdparty/deepseek_vl/__init__.py +31 -0
- xinference/thirdparty/deepseek_vl/models/__init__.py +28 -0
- xinference/thirdparty/deepseek_vl/models/clip_encoder.py +242 -0
- xinference/thirdparty/deepseek_vl/models/image_processing_vlm.py +208 -0
- xinference/thirdparty/deepseek_vl/models/modeling_vlm.py +170 -0
- xinference/thirdparty/deepseek_vl/models/processing_vlm.py +390 -0
- xinference/thirdparty/deepseek_vl/models/projector.py +100 -0
- xinference/thirdparty/deepseek_vl/models/sam.py +593 -0
- xinference/thirdparty/deepseek_vl/models/siglip_vit.py +681 -0
- xinference/thirdparty/deepseek_vl/utils/__init__.py +18 -0
- xinference/thirdparty/deepseek_vl/utils/conversation.py +348 -0
- xinference/thirdparty/deepseek_vl/utils/io.py +78 -0
- xinference/thirdparty/omnilmm/__init__.py +0 -0
- xinference/thirdparty/omnilmm/chat.py +216 -0
- xinference/thirdparty/omnilmm/constants.py +4 -0
- xinference/thirdparty/omnilmm/conversation.py +332 -0
- xinference/thirdparty/omnilmm/model/__init__.py +1 -0
- xinference/thirdparty/omnilmm/model/omnilmm.py +594 -0
- xinference/thirdparty/omnilmm/model/resampler.py +166 -0
- xinference/thirdparty/omnilmm/model/utils.py +563 -0
- xinference/thirdparty/omnilmm/train/__init__.py +13 -0
- xinference/thirdparty/omnilmm/train/train_utils.py +150 -0
- xinference/thirdparty/omnilmm/utils.py +134 -0
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/main.98516614.js +3 -0
- xinference/web/ui/build/static/js/main.98516614.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/139969fd25258eb7decc9505f30b779089bba50c402bb5c663008477c7bff73b.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/3f357ab57b8e7fade54c667f0e0ebf2787566f72bfdca0fea14e395b5c203753.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/9d7c49815d97539207e5aab2fb967591b5fed7791218a0762539efc9491f36af.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/d0d0b591d9adaf42b83ad6633f8b7c118541a4b80ea957c303d3bf9b86fbad0a.json +1 -0
- {xinference-0.9.3.dist-info → xinference-0.10.0.dist-info}/METADATA +21 -5
- {xinference-0.9.3.dist-info → xinference-0.10.0.dist-info}/RECORD +60 -31
- xinference/web/ui/build/static/js/main.66b1c4fb.js +0 -3
- xinference/web/ui/build/static/js/main.66b1c4fb.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/c2124cfe036b26befcbd386d1d17743b1a58d0b7a041a17bb67f9924400d63c3.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/fd4a8ae5d192331af1bedd1d2d70efcc569708ee6cc4cb479b225d059482aa81.json +0 -1
- /xinference/web/ui/build/static/js/{main.66b1c4fb.js.LICENSE.txt → main.98516614.js.LICENSE.txt} +0 -0
- {xinference-0.9.3.dist-info → xinference-0.10.0.dist-info}/LICENSE +0 -0
- {xinference-0.9.3.dist-info → xinference-0.10.0.dist-info}/WHEEL +0 -0
- {xinference-0.9.3.dist-info → xinference-0.10.0.dist-info}/entry_points.txt +0 -0
- {xinference-0.9.3.dist-info → xinference-0.10.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,365 @@
|
|
|
1
|
+
# Copyright 2022-2024 XProbe Inc.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import logging
|
|
16
|
+
import time
|
|
17
|
+
import uuid
|
|
18
|
+
from typing import AsyncGenerator, Dict, List, Optional, TypedDict, Union
|
|
19
|
+
|
|
20
|
+
from ....constants import XINFERENCE_ENABLE_SGLANG
|
|
21
|
+
from ....types import (
|
|
22
|
+
ChatCompletion,
|
|
23
|
+
ChatCompletionChunk,
|
|
24
|
+
ChatCompletionMessage,
|
|
25
|
+
Completion,
|
|
26
|
+
CompletionChoice,
|
|
27
|
+
CompletionChunk,
|
|
28
|
+
CompletionUsage,
|
|
29
|
+
)
|
|
30
|
+
from .. import LLM, LLMFamilyV1, LLMSpecV1
|
|
31
|
+
from ..llm_family import CustomLLMFamilyV1
|
|
32
|
+
from ..utils import ChatModelMixin
|
|
33
|
+
|
|
34
|
+
logger = logging.getLogger(__name__)


# Keyword options forwarded to ``sglang.Runtime`` when the model is loaded.
# Every key is optional (total=False); missing keys receive defaults at load time.
SGLANGModelConfig = TypedDict(
    "SGLANGModelConfig",
    {
        "tokenizer_mode": str,
        "trust_remote_code": bool,
        "tp_size": int,
        "mem_fraction_static": float,
        "log_level": str,
        "attention_reduce_in_fp32": bool,  # For gemma
    },
    total=False,
)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
# Sampling options accepted by SGLang generation. All keys are optional
# (total=False); defaults are applied before the request is sent.
SGLANGGenerateConfig = TypedDict(
    "SGLANGGenerateConfig",
    {
        "presence_penalty": float,
        "frequency_penalty": float,
        "temperature": float,
        "top_p": float,
        "top_k": int,
        "max_new_tokens": int,
        "stop": Optional[Union[str, List[str]]],
        "ignore_eos": bool,
        "stream": bool,
    },
    total=False,
)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
# sglang is an optional dependency: record whether it can be imported so that
# match() can refuse this backend when it is missing.
try:
    import sglang  # noqa: F401
except ImportError:
    SGLANG_INSTALLED = False
else:
    SGLANG_INSTALLED = True

# Model names (or custom-model families) served by SGLANGModel.
SGLANG_SUPPORTED_MODELS = ["llama-2", "mistral-v0.1", "mixtral-v0.1"]

# Model names (or custom-model families) served by SGLANGChatModel.
SGLANG_SUPPORTED_CHAT_MODELS = [
    "llama-2-chat",
    "qwen-chat",
    "qwen1.5-chat",
    "mistral-instruct-v0.1",
    "mistral-instruct-v0.2",
    "mixtral-instruct-v0.1",
    "gemma-it",
]
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class SGLANGModel(LLM):
    """Text-generation model served by an in-process SGLang ``Runtime``.

    ``load()`` starts the runtime; ``async_generate()`` sends a prompt to it
    verbatim and returns either a ``Completion`` or an async stream of
    ``CompletionChunk`` objects.
    """

    def __init__(
        self,
        model_uid: str,
        model_family: "LLMFamilyV1",
        model_spec: "LLMSpecV1",
        quantization: str,
        model_path: str,
        model_config: Optional[SGLANGModelConfig],
    ):
        super().__init__(model_uid, model_family, model_spec, quantization, model_path)
        # Raw user options; defaults are filled in by load().
        self._model_config = model_config
        # The sglang.Runtime instance; created by load().
        self._engine = None

    @staticmethod
    def _import_sglang():
        """Import and return the ``sglang`` module.

        Raises:
            ImportError: with installation instructions if sglang is missing.
        """
        try:
            import sglang as sgl
        except ImportError:
            error_message = "Failed to import module 'sglang'"
            installation_guide = [
                "Please make sure 'sglang' is installed. ",
                "You can install it by `pip install 'sglang[all]'`\n",
            ]

            raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
        return sgl

    def load(self):
        """Fill in model-config defaults and start the SGLang runtime.

        Raises:
            ImportError: if sglang is not installed.
        """
        sgl = self._import_sglang()

        self._model_config = self._sanitize_model_config(self._model_config)
        logger.info(
            "Loading %s with following model config: %s",
            self.model_uid,
            self._model_config,
        )

        self._engine = sgl.Runtime(
            model_path=self.model_path,
            tokenizer_path=self.model_path,
            **self._model_config,
        )

    def _sanitize_model_config(
        self, model_config: Optional[SGLANGModelConfig]
    ) -> SGLANGModelConfig:
        """Return ``model_config`` with defaults for every ``sgl.Runtime`` option.

        Values the caller supplied are kept; only missing (or None
        ``mem_fraction_static``) entries receive defaults.
        """
        if model_config is None:
            model_config = SGLANGModelConfig()

        cuda_count = self._get_cuda_count()
        model_config.setdefault("tokenizer_mode", "auto")
        model_config.setdefault("trust_remote_code", True)
        model_config.setdefault("tp_size", cuda_count)
        # See https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/server_args.py#L37
        # Only pick a default when the caller did not set one. Read with get():
        # popping here would silently discard a user-supplied value.
        if model_config.get("mem_fraction_static") is None:
            tp_size = model_config.get("tp_size", cuda_count)
            # Larger tensor-parallel groups reserve less of each GPU statically.
            if tp_size >= 8:
                model_config["mem_fraction_static"] = 0.80
            elif tp_size >= 4:
                model_config["mem_fraction_static"] = 0.82
            elif tp_size >= 2:
                model_config["mem_fraction_static"] = 0.85
            else:
                model_config["mem_fraction_static"] = 0.90
        model_config.setdefault("log_level", "info")
        model_config.setdefault("attention_reduce_in_fp32", False)

        return model_config

    @staticmethod
    def _sanitize_generate_config(
        generate_config: Optional[SGLANGGenerateConfig] = None,
    ) -> SGLANGGenerateConfig:
        """Return a copy of ``generate_config`` with SGLang sampling defaults.

        The OpenAI-style ``max_tokens`` key is translated to ``max_new_tokens``.
        The caller's dict is left unmodified.
        """
        # Copy: keys are popped from the result here and in async_generate().
        generate_config = (
            SGLANGGenerateConfig() if generate_config is None else generate_config.copy()
        )

        generate_config.setdefault("presence_penalty", 0.0)
        generate_config.setdefault("frequency_penalty", 0.0)
        generate_config.setdefault("temperature", 1.0)
        generate_config.setdefault("top_p", 1.0)
        generate_config.setdefault("top_k", -1)
        # See https://github.com/sgl-project/sglang/blob/main/python/sglang/lang/ir.py#L120
        # SGLang's default of 16 new tokens is too small, so default to 256.
        # An explicit ``max_tokens=None`` also falls back to 256 rather than
        # reaching SGLang as None.
        max_tokens = generate_config.pop("max_tokens", None)  # type: ignore
        if max_tokens is None:
            max_tokens = 256
        generate_config.setdefault("max_new_tokens", max_tokens)
        generate_config.setdefault("stop", [])
        generate_config.setdefault("stream", False)
        generate_config.setdefault("ignore_eos", False)

        return generate_config

    @classmethod
    def match(
        cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
    ) -> bool:
        """Return True if this backend can serve the given generate-model spec."""
        if not XINFERENCE_ENABLE_SGLANG:
            return False
        if not cls._has_cuda_device():
            return False
        if not cls._is_linux():
            return False
        if llm_spec.model_format not in ["pytorch", "gptq", "awq"]:
            return False
        if llm_spec.model_format == "pytorch":
            if quantization not in ("none", None):
                return False
        if llm_spec.model_format in ["gptq", "awq"]:
            # Only 4-bit GPTQ/AWQ weights are supported (e.g. 8-bit is rejected).
            # A missing quantization cannot be 4-bit; reject it instead of
            # raising TypeError on the ``in`` test.
            if not quantization or "4" not in quantization:
                return False
        if isinstance(llm_family, CustomLLMFamilyV1):
            if llm_family.model_family not in SGLANG_SUPPORTED_MODELS:
                return False
        else:
            if llm_family.model_name not in SGLANG_SUPPORTED_MODELS:
                return False
        if "generate" not in llm_family.model_ability:
            return False
        return SGLANG_INSTALLED

    @staticmethod
    def _convert_state_to_completion_chunk(
        request_id: str, model: str, output_text: str, meta_info: Dict
    ) -> CompletionChunk:
        """Build a streaming ``CompletionChunk`` from one piece of output text.

        ``meta_info`` must provide ``prompt_tokens`` and ``completion_tokens``.
        """
        choices: List[CompletionChoice] = [
            CompletionChoice(
                text=output_text,
                index=0,
                logprobs=None,
                finish_reason=None,
            )
        ]
        chunk = CompletionChunk(
            id=request_id,
            object="text_completion",
            created=int(time.time()),
            model=model,
            choices=choices,
        )
        prompt_tokens = meta_info["prompt_tokens"]
        completion_tokens = meta_info["completion_tokens"]
        chunk["usage"] = CompletionUsage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        )
        return chunk

    @staticmethod
    def _convert_state_to_completion(
        request_id: str, model: str, output_text: str, meta_info: Dict
    ) -> Completion:
        """Build a non-streaming ``Completion`` from the full output text.

        ``meta_info`` must provide ``prompt_tokens`` and ``completion_tokens``.
        """
        choices = [
            CompletionChoice(
                text=output_text,
                index=0,
                logprobs=None,
                finish_reason=None,
            )
        ]

        usage = CompletionUsage(
            prompt_tokens=meta_info["prompt_tokens"],
            completion_tokens=meta_info["completion_tokens"],
            total_tokens=meta_info["prompt_tokens"] + meta_info["completion_tokens"],
        )
        return Completion(
            id=request_id,
            object="text_completion",
            created=int(time.time()),
            model=model,
            choices=choices,
            usage=usage,
        )

    async def async_generate(
        self,
        prompt: str,
        generate_config: Optional[SGLANGGenerateConfig] = None,
    ) -> Union[Completion, AsyncGenerator[CompletionChunk, None]]:
        """Generate a completion for ``prompt``.

        Args:
            prompt: text sent to the model verbatim.
            generate_config: sampling options; defaults are applied by
                ``_sanitize_generate_config``.

        Returns:
            A ``Completion`` when ``stream`` is false, otherwise an async
            generator of ``CompletionChunk`` objects.

        Raises:
            ImportError: if sglang is not installed.
        """
        sgl = self._import_sglang()

        @sgl.function
        def pipeline(s, question):
            # Append the prompt as-is. Wrapping it in ``user()``/``assistant()``
            # adds SGLang's chat-role markup. That double-templates prompts
            # which SGLANGChatModel has already rendered, and it turns plain
            # completions into chat turns.
            s += question
            s += sgl.gen("answer")

        sanitized_generate_config = self._sanitize_generate_config(generate_config)
        logger.debug(
            "Enter generate, prompt: %s, generate config: %s", prompt, generate_config
        )
        stream = sanitized_generate_config.pop("stream")
        request_id = str(uuid.uuid1())
        state = pipeline.run(
            question=prompt,
            backend=self._engine,
            stream=stream,
            **sanitized_generate_config,
        )
        if not stream:
            # NOTE(review): reading state["answer"] presumably waits for the
            # generation to finish and may block the event loop — confirm.
            return self._convert_state_to_completion(
                request_id,
                model=self.model_uid,
                output_text=state["answer"],
                meta_info=state.get_meta_info(name="answer"),
            )

        async def stream_results() -> AsyncGenerator[CompletionChunk, None]:
            async for out, meta_info in state.text_async_iter(
                var_name="answer", return_meta_data=True
            ):
                yield self._convert_state_to_completion_chunk(
                    request_id, self.model_uid, output_text=out, meta_info=meta_info
                )

        return stream_results()
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
class SGLANGChatModel(SGLANGModel, ChatModelMixin):
    """Chat model served by SGLang.

    The conversation is rendered with the model family's ``prompt_style``, and
    the resulting prompt is passed to ``async_generate``.
    """

    @classmethod
    def match(
        cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
    ) -> bool:
        """Return True if this backend can serve the given chat-model spec."""
        if not XINFERENCE_ENABLE_SGLANG:
            return False
        # Apply the same platform checks as SGLANGModel.match, so the chat
        # variant is not chosen on hosts without CUDA or on non-Linux systems.
        if not cls._has_cuda_device():
            return False
        if not cls._is_linux():
            return False
        if llm_spec.model_format not in ["pytorch", "gptq", "awq"]:
            return False
        if llm_spec.model_format == "pytorch":
            if quantization not in ("none", None):
                return False
        if llm_spec.model_format in ["gptq", "awq"]:
            # Only 4-bit GPTQ/AWQ weights are supported (e.g. 8-bit is rejected).
            # A missing quantization cannot be 4-bit; reject it instead of
            # raising TypeError on the ``in`` test.
            if not quantization or "4" not in quantization:
                return False
        if isinstance(llm_family, CustomLLMFamilyV1):
            if llm_family.model_family not in SGLANG_SUPPORTED_CHAT_MODELS:
                return False
        else:
            if llm_family.model_name not in SGLANG_SUPPORTED_CHAT_MODELS:
                return False
        if "chat" not in llm_family.model_ability:
            return False
        return SGLANG_INSTALLED

    def _sanitize_chat_config(
        self,
        generate_config: Optional[Dict] = None,
    ) -> Dict:
        """Return a copy of ``generate_config`` with the prompt style's stop words.

        ``stop`` is taken from ``prompt_style.stop`` only when the caller did
        not supply a non-empty one. The caller's dict is left unmodified.
        """
        generate_config = dict(generate_config) if generate_config else {}
        prompt_style = self.model_family.prompt_style
        if prompt_style and prompt_style.stop and not generate_config.get("stop"):
            generate_config["stop"] = prompt_style.stop.copy()
        return generate_config

    async def async_chat(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        chat_history: Optional[List[ChatCompletionMessage]] = None,
        generate_config: Optional[Dict] = None,
    ) -> Union[ChatCompletion, AsyncGenerator[ChatCompletionChunk, None]]:
        """Answer ``prompt`` given an optional system prompt and chat history.

        Returns:
            A ``ChatCompletion``, or an async generator of
            ``ChatCompletionChunk`` when ``generate_config["stream"]`` is truthy.
        """
        assert self.model_family.prompt_style is not None
        prompt_style = self.model_family.prompt_style.copy()
        if system_prompt:
            prompt_style.system_prompt = system_prompt
        chat_history = chat_history or []
        full_prompt = self.get_prompt(prompt, chat_history, prompt_style)

        generate_config = self._sanitize_chat_config(generate_config)
        stream = generate_config.get("stream", None)
        if stream:
            agen = await self.async_generate(full_prompt, generate_config)  # type: ignore
            assert isinstance(agen, AsyncGenerator)
            return self._async_to_chat_completion_chunks(agen)
        else:
            c = await self.async_generate(full_prompt, generate_config)  # type: ignore
            assert not isinstance(c, AsyncGenerator)
            return self._to_chat_completion(c)
|
xinference/model/llm/utils.py
CHANGED
|
@@ -411,6 +411,26 @@ Begin!"""
|
|
|
411
411
|
if content:
|
|
412
412
|
ret += content + "<end_of_turn>\n"
|
|
413
413
|
return ret
|
|
414
|
+
elif prompt_style.style_name == "CodeShell":
|
|
415
|
+
ret = ""
|
|
416
|
+
for message in chat_history:
|
|
417
|
+
content = message["content"]
|
|
418
|
+
role = get_role(message["role"])
|
|
419
|
+
if content:
|
|
420
|
+
ret += f"{role}{content}|<end>|"
|
|
421
|
+
else:
|
|
422
|
+
ret += f"{role}".rstrip()
|
|
423
|
+
return ret
|
|
424
|
+
elif prompt_style.style_name == "MINICPM-2B":
|
|
425
|
+
ret = ""
|
|
426
|
+
for message in chat_history:
|
|
427
|
+
content = message["content"] or ""
|
|
428
|
+
role = get_role(message["role"])
|
|
429
|
+
if role == "user":
|
|
430
|
+
ret += "<用户>" + content.strip()
|
|
431
|
+
else:
|
|
432
|
+
ret += "<AI>" + content.strip()
|
|
433
|
+
return ret
|
|
414
434
|
else:
|
|
415
435
|
raise ValueError(f"Invalid prompt style: {prompt_style.style_name}")
|
|
416
436
|
|
|
@@ -451,6 +471,7 @@ Begin!"""
|
|
|
451
471
|
"index": i,
|
|
452
472
|
"delta": {
|
|
453
473
|
"role": "assistant",
|
|
474
|
+
"content": "",
|
|
454
475
|
},
|
|
455
476
|
"finish_reason": None,
|
|
456
477
|
}
|
|
@@ -535,33 +556,46 @@ Begin!"""
|
|
|
535
556
|
# Refer to:
|
|
536
557
|
# https://github.com/QwenLM/Qwen/blob/main/examples/react_prompt.md
|
|
537
558
|
# https://github.com/QwenLM/Qwen/blob/main/openai_api.py#L297
|
|
538
|
-
func_name, func_args = "", ""
|
|
559
|
+
func_name, func_args, content = "", "", ""
|
|
539
560
|
i = text.rfind("\nAction:")
|
|
540
561
|
j = text.rfind("\nAction Input:")
|
|
541
562
|
k = text.rfind("\nObservation:")
|
|
563
|
+
t = max(
|
|
564
|
+
text.rfind("\nThought:", 0, i), text.rfind("Thought:", 0, i)
|
|
565
|
+
) # find the last thought just before Action, considering the Thought at the very beginning
|
|
542
566
|
if 0 <= i < j: # If the text has `Action` and `Action input`,
|
|
543
567
|
if k < j: # but does not contain `Observation`,
|
|
544
568
|
# then it is likely that `Observation` is omitted by the LLM,
|
|
545
569
|
# because the output text may have discarded the stop word.
|
|
546
570
|
text = text.rstrip() + "\nObservation:" # Add it back.
|
|
547
571
|
k = text.rfind("\nObservation:")
|
|
548
|
-
if 0 <= i < j < k:
|
|
572
|
+
if 0 <= t < i < j < k:
|
|
549
573
|
func_name = text[i + len("\nAction:") : j].strip()
|
|
550
574
|
func_args = text[j + len("\nAction Input:") : k].strip()
|
|
575
|
+
content = text[
|
|
576
|
+
t + len("\nThought:") : i
|
|
577
|
+
].strip() # len("\nThought:") and len("Thought:") both are OK since there is a space after :
|
|
551
578
|
if func_name:
|
|
552
|
-
return
|
|
553
|
-
z = text.rfind("\nFinal Answer: ")
|
|
554
|
-
if z >= 0:
|
|
555
|
-
text = text[z + len("\nFinal Answer: ") :]
|
|
579
|
+
return content, func_name, json.loads(func_args)
|
|
556
580
|
except Exception as e:
|
|
557
581
|
logger.error("Eval tool calls completion failed: %s", e)
|
|
582
|
+
t = max(text.rfind("\nThought:"), text.rfind("Thought:"))
|
|
583
|
+
z = max(text.rfind("\nFinal Answer:"), text.rfind("Final Answer:"))
|
|
584
|
+
if z >= 0:
|
|
585
|
+
text = text[
|
|
586
|
+
z + len("\nFinal Answer:") :
|
|
587
|
+
] # len("\nFinal Answer::") and len("Final Answer::") both are OK since there is a space after :
|
|
588
|
+
else:
|
|
589
|
+
text = text[
|
|
590
|
+
t + len("\nThought:") :
|
|
591
|
+
] # There is only Thought: no Final Answer:
|
|
558
592
|
return text, None, None
|
|
559
593
|
|
|
560
594
|
@classmethod
|
|
561
595
|
def _tool_calls_completion(cls, model_family, model_uid, c, tools):
|
|
562
596
|
_id = str(uuid.uuid4())
|
|
563
597
|
family = model_family.model_family or model_family.model_name
|
|
564
|
-
if "gorilla-openfunctions-v1"
|
|
598
|
+
if family in ["gorilla-openfunctions-v1", "gorilla-openfunctions-v2"]:
|
|
565
599
|
content, func, args = cls._eval_gorilla_openfunctions_arguments(c, tools)
|
|
566
600
|
elif "chatglm3" == family:
|
|
567
601
|
content, func, args = cls._eval_chatglm3_arguments(c, tools)
|
|
@@ -573,13 +607,10 @@ Begin!"""
|
|
|
573
607
|
)
|
|
574
608
|
logger.debug("Tool call content: %s, func: %s, args: %s", content, func, args)
|
|
575
609
|
|
|
576
|
-
if
|
|
577
|
-
m = {"role": "assistant", "content": content, "tool_calls": []}
|
|
578
|
-
finish_reason = "stop"
|
|
579
|
-
else:
|
|
610
|
+
if func:
|
|
580
611
|
m = {
|
|
581
612
|
"role": "assistant",
|
|
582
|
-
"content":
|
|
613
|
+
"content": content,
|
|
583
614
|
"tool_calls": [
|
|
584
615
|
{
|
|
585
616
|
"id": f"call_{_id}",
|
|
@@ -592,7 +623,9 @@ Begin!"""
|
|
|
592
623
|
],
|
|
593
624
|
}
|
|
594
625
|
finish_reason = "tool_calls"
|
|
595
|
-
|
|
626
|
+
else:
|
|
627
|
+
m = {"role": "assistant", "content": content, "tool_calls": []}
|
|
628
|
+
finish_reason = "stop"
|
|
596
629
|
return {
|
|
597
630
|
"id": "chat" + f"cmpl-{_id}",
|
|
598
631
|
"model": model_uid,
|
|
@@ -86,6 +86,7 @@ VLLM_SUPPORTED_CHAT_MODELS = [
|
|
|
86
86
|
"vicuna-v1.3",
|
|
87
87
|
"vicuna-v1.5",
|
|
88
88
|
"baichuan-chat",
|
|
89
|
+
"baichuan-2-chat",
|
|
89
90
|
"internlm-chat-7b",
|
|
90
91
|
"internlm-chat-8k",
|
|
91
92
|
"internlm-chat-20b",
|
|
@@ -99,10 +100,19 @@ VLLM_SUPPORTED_CHAT_MODELS = [
|
|
|
99
100
|
"mistral-instruct-v0.2",
|
|
100
101
|
"mixtral-instruct-v0.1",
|
|
101
102
|
"chatglm3",
|
|
103
|
+
"deepseek-chat",
|
|
104
|
+
"deepseek-coder-instruct",
|
|
102
105
|
]
|
|
103
106
|
def _vllm_version_at_least(minimum: str) -> bool:
    """Return True if vLLM is installed and its version is >= ``minimum``.

    Versions are compared by their numeric release components. Plain string
    comparison would order ``"0.10.0"`` before ``"0.3.0"``. A ``+local``
    suffix and any trailing non-numeric piece (e.g. ``post1``) are ignored,
    so pre/post releases compare like their base release.
    """
    if not VLLM_INSTALLED:
        return False

    import re

    def _release(version: str) -> tuple:
        parts = []
        for piece in version.split("+", 1)[0].split("."):
            m = re.match(r"\d+", piece)
            if m is None:
                break
            parts.append(int(m.group()))
        return tuple(parts)

    installed, required = _release(vllm.__version__), _release(minimum)
    # Pad with zeros so that e.g. "0.4" compares equal to "0.4.0".
    width = max(len(installed), len(required))
    installed += (0,) * (width - len(installed))
    required += (0,) * (width - len(required))
    return installed >= required


if _vllm_version_at_least("0.3.0"):
    VLLM_SUPPORTED_CHAT_MODELS.append("qwen1.5-chat")

if _vllm_version_at_least("0.3.2"):
    VLLM_SUPPORTED_CHAT_MODELS.append("gemma-it")

if _vllm_version_at_least("0.3.3"):
    VLLM_SUPPORTED_CHAT_MODELS.append("orion-chat")
    VLLM_SUPPORTED_CHAT_MODELS.append("orion-chat-rag")
|
|
115
|
+
|
|
106
116
|
|
|
107
117
|
class VLLMModel(LLM):
|
|
108
118
|
def __init__(
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
# Copyright (c) 2023-2024 DeepSeek.
|
|
2
|
+
#
|
|
3
|
+
# Permission is hereby granted, free of charge, to any person obtaining a copy of
|
|
4
|
+
# this software and associated documentation files (the "Software"), to deal in
|
|
5
|
+
# the Software without restriction, including without limitation the rights to
|
|
6
|
+
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
|
7
|
+
# the Software, and to permit persons to whom the Software is furnished to do so,
|
|
8
|
+
# subject to the following conditions:
|
|
9
|
+
#
|
|
10
|
+
# The above copyright notice and this permission notice shall be included in all
|
|
11
|
+
# copies or substantial portions of the Software.
|
|
12
|
+
#
|
|
13
|
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
14
|
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
|
15
|
+
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
|
16
|
+
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
|
17
|
+
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
|
18
|
+
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
# On Python 3.10 or newer the ABC aliases (collections.Mapping, ...) no longer
# exist in the ``collections`` namespace. Restore them from collections.abc for
# vendored code that still uses the old names. Note that this patches the
# stdlib module process-wide.
import logging
import sys

if sys.version_info >= (3, 10):
    # Log instead of print(): importing a library must not write to stdout.
    logging.getLogger(__name__).debug(
        "Python version is above 3.10, patching the collections module."
    )
    # Monkey patch collections
    import collections
    import collections.abc

    for type_name in collections.abc.__all__:
        setattr(collections, type_name, getattr(collections.abc, type_name))
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
# Copyright (c) 2023-2024 DeepSeek.
|
|
2
|
+
#
|
|
3
|
+
# Permission is hereby granted, free of charge, to any person obtaining a copy of
|
|
4
|
+
# this software and associated documentation files (the "Software"), to deal in
|
|
5
|
+
# the Software without restriction, including without limitation the rights to
|
|
6
|
+
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
|
7
|
+
# the Software, and to permit persons to whom the Software is furnished to do so,
|
|
8
|
+
# subject to the following conditions:
|
|
9
|
+
#
|
|
10
|
+
# The above copyright notice and this permission notice shall be included in all
|
|
11
|
+
# copies or substantial portions of the Software.
|
|
12
|
+
#
|
|
13
|
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
14
|
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
|
15
|
+
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
|
16
|
+
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
|
17
|
+
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
|
18
|
+
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
|
19
|
+
|
|
20
|
+
from .image_processing_vlm import VLMImageProcessor
|
|
21
|
+
from .modeling_vlm import MultiModalityCausalLM
|
|
22
|
+
from .processing_vlm import VLChatProcessor
|
|
23
|
+
|
|
24
|
+
# Public names re-exported by this package.
__all__ = ["VLMImageProcessor", "VLChatProcessor", "MultiModalityCausalLM"]
|