xinference-0.12.3-py3-none-any.whl → xinference-0.13.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +6 -6
- xinference/client/restful/restful_client.py +0 -2
- xinference/core/model.py +21 -4
- xinference/core/scheduler.py +2 -0
- xinference/core/worker.py +74 -45
- xinference/deploy/utils.py +33 -2
- xinference/model/llm/__init__.py +5 -0
- xinference/model/llm/llm_family.json +240 -1
- xinference/model/llm/llm_family.py +32 -8
- xinference/model/llm/llm_family_modelscope.json +192 -0
- xinference/model/llm/mlx/__init__.py +13 -0
- xinference/model/llm/mlx/core.py +408 -0
- xinference/model/llm/pytorch/chatglm.py +2 -9
- xinference/model/llm/pytorch/cogvlm2.py +206 -21
- xinference/model/llm/pytorch/core.py +213 -40
- xinference/model/llm/pytorch/glm4v.py +171 -15
- xinference/model/llm/pytorch/qwen_vl.py +168 -7
- xinference/model/llm/pytorch/utils.py +53 -62
- xinference/model/llm/utils.py +24 -5
- xinference/model/rerank/core.py +5 -0
- xinference/thirdparty/deepseek_vl/serve/__init__.py +13 -0
- xinference/thirdparty/deepseek_vl/serve/app_deepseek.py +510 -0
- xinference/thirdparty/deepseek_vl/serve/app_modules/__init__.py +13 -0
- xinference/thirdparty/deepseek_vl/serve/app_modules/gradio_utils.py +94 -0
- xinference/thirdparty/deepseek_vl/serve/app_modules/overwrites.py +81 -0
- xinference/thirdparty/deepseek_vl/serve/app_modules/presets.py +96 -0
- xinference/thirdparty/deepseek_vl/serve/app_modules/utils.py +229 -0
- xinference/thirdparty/deepseek_vl/serve/inference.py +170 -0
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/main.0fb6f3ab.js +3 -0
- xinference/web/ui/build/static/js/main.0fb6f3ab.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/0f6b391abec76271137faad13a3793fe7acc1024e8cd2269c147b653ecd3a73b.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/1130403f9e46f5738a23b45ac59b57de8f360c908c713e2c0670c2cce9bd367a.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/1444c41a4d04494f1cbc2d8c1537df107b451cb569cb2c1fbf5159f3a4841a5f.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/2c63090c842376cdd368c3ded88a333ef40d94785747651343040a6f7872a223.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/30a0c79d8025d6441eb75b2df5bc2750a14f30119c869ef02570d294dff65c2f.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/40486e655c3c5801f087e2cf206c0b5511aaa0dfdba78046b7181bf9c17e54c5.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/6450605fac003812485f6251b9f0caafbf2e5bfc3bbe2f000050d9e2fdb8dcd3.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/8a9742ddd8ba8546ef42dc14caca443f2b4524fabed7bf269e0eff3b7b64ee7d.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/9375a35b05d56989b2755bf72161fa707c92f28569d33765a75f91a568fda6e9.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/b5507cd57f16a3a230aa0128e39fe103e928de139ea29e2679e4c64dcbba3b3a.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/d6c643278a0b28320e6f33a60f5fb64c053997cbdc39a60e53ccc574688ade9e.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/d779b915f83f9c7b5a72515b6932fdd114f1822cef90ae01cc0d12bca59abc2d.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/d87824cb266194447a9c0c69ebab2d507bfc3e3148976173760d18c035e9dd26.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/d93730e2b5d7e8c957b4d0965d2ed1dac9045a649adbd47c220d11f255d4b1e0.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/e656dc00b4d8b387f0a81ba8fc558767df1601c66369e2eb86a5ef27cf080572.json +1 -0
- {xinference-0.12.3.dist-info → xinference-0.13.0.dist-info}/METADATA +4 -1
- {xinference-0.12.3.dist-info → xinference-0.13.0.dist-info}/RECORD +55 -44
- xinference/web/ui/build/static/js/main.77dd47c3.js +0 -3
- xinference/web/ui/build/static/js/main.77dd47c3.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/0cd591866aa345566e0b63fb51ff2043e163a770af6fdc2f3bad395d046353e2.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/37c1476717199863bbba1530e3513a9368f8f73001b75b4a85c2075956308027.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/3da7d55e87882a4af923e187b1351160e34ca102f589086439c15131a227fb6e.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/3fa1f69162f9c6dc0f6a6e21b64d49d6b8e6fa8dfa59a82cf829931c5f97d99f.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/46edc1fe657dfedb2e673148332bb442c6eb98f09f2592c389209e376510afa5.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/62e257ed9016471035fa1a7da57c9e2a4250974ed566b4d1295873d747c68eb2.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/72bcecc71c5267250edeb89608859d449b586f13ff9923a5e70e7172976ec403.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/82db357f3fd5b32215d747ee593f69ff06c95ad6cde37f71a96c8290aaab64c0.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/935efd2867664c58230378fdf2ff1ea85e58d853b7214014e20dfbca8dab7b05.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/bc6da27195ec4607bb472bf61f97c928ad4966fa64e4c2247661bedb7400abba.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/c2abe75f04ad82fba68f35ed9cbe2e287762c876684fddccccfa73f739489b65.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/e606671420d2937102c3c34b4b04056c11736408c1d3347b8cf42dfe61fb394b.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/f118f99c22b713c678c1209c4e1dd43fe86e3f6e801a4c0c35d3bbf41fd05fe6.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/f51bf63ddaa7afd125ef2254a105789333eecc1c94fdf5157a9b88ef7ad0a5bd.json +0 -1
- /xinference/web/ui/build/static/js/{main.77dd47c3.js.LICENSE.txt → main.0fb6f3ab.js.LICENSE.txt} +0 -0
- {xinference-0.12.3.dist-info → xinference-0.13.0.dist-info}/LICENSE +0 -0
- {xinference-0.12.3.dist-info → xinference-0.13.0.dist-info}/WHEEL +0 -0
- {xinference-0.12.3.dist-info → xinference-0.13.0.dist-info}/entry_points.txt +0 -0
- {xinference-0.12.3.dist-info → xinference-0.13.0.dist-info}/top_level.txt +0 -0
xinference/model/llm/mlx/core.py (new file)

@@ -0,0 +1,408 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import platform
+import sys
+import time
+import uuid
+from typing import Dict, Iterable, Iterator, List, Optional, TypedDict, Union
+
+from ....fields import max_tokens_field
+from ....types import (
+    ChatCompletion,
+    ChatCompletionChunk,
+    ChatCompletionMessage,
+    Completion,
+    CompletionChoice,
+    CompletionChunk,
+    CompletionUsage,
+    LoRA,
+)
+from ..core import LLM
+from ..llm_family import LLMFamilyV1, LLMSpecV1
+from ..utils import ChatModelMixin
+
+logger = logging.getLogger(__name__)
+
+
+class MLXModelConfig(TypedDict, total=False):
+    revision: Optional[str]
+    max_gpu_memory: str
+    trust_remote_code: bool
+
+
+class MLXGenerateConfig(TypedDict, total=False):
+    max_tokens: int
+    temperature: float
+    repetition_penalty: Optional[float]
+    repetition_context_size: Optional[float]
+    top_p: float
+    logit_bias: Optional[Dict[int, float]]
+    stop: Optional[Union[str, List[str]]]
+    stop_token_ids: Optional[Union[int, List[int]]]
+    stream: bool
+    stream_options: Optional[Union[dict, None]]
+
+
+class MLXModel(LLM):
+    def __init__(
+        self,
+        model_uid: str,
+        model_family: "LLMFamilyV1",
+        model_spec: "LLMSpecV1",
+        quantization: str,
+        model_path: str,
+        model_config: Optional[MLXModelConfig] = None,
+        peft_model: Optional[List[LoRA]] = None,
+    ):
+        super().__init__(model_uid, model_family, model_spec, quantization, model_path)
+        self._use_fast_tokenizer = True
+        self._model_config: MLXModelConfig = self._sanitize_model_config(model_config)
+        if peft_model is not None:
+            raise ValueError("MLX engine has not supported lora yet")
+
+    def _sanitize_model_config(
+        self, model_config: Optional[MLXModelConfig]
+    ) -> MLXModelConfig:
+        if model_config is None:
+            model_config = MLXModelConfig()
+        model_config.setdefault("revision", self.model_spec.model_revision)
+        model_config.setdefault("trust_remote_code", True)
+        return model_config
+
+    def _sanitize_generate_config(
+        self,
+        generate_config: Optional[MLXGenerateConfig],
+    ) -> MLXGenerateConfig:
+        if generate_config is None:
+            generate_config = MLXGenerateConfig()
+
+        generate_config.setdefault("max_tokens", max_tokens_field.default)
+        # default config is adapted from
+        # https://github.com/ml-explore/mlx-examples/blob/f212b770d8b5143e23102eda20400ae43340f844/llms/mlx_lm/utils.py#L129
+        generate_config.setdefault("temperature", 0.0)
+        generate_config.setdefault("repetition_penalty", None)
+        generate_config.setdefault("repetition_context_size", 20)
+        generate_config.setdefault("top_p", 1.0)
+        generate_config.setdefault("logit_bias", None)
+        return generate_config
+
+    def _load_model(self, **kwargs):
+        try:
+            from mlx_lm import load
+        except ImportError:
+            error_message = "Failed to import module 'mlx_lm'"
+            installation_guide = [
+                "Please make sure 'mlx_lm' is installed. ",
+                "You can install it by `pip install mlx_lm`\n",
+            ]
+
+            raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
+
+        tokenizer_config = dict(
+            use_fast=self._use_fast_tokenizer,
+            trust_remote_code=kwargs["trust_remote_code"],
+            revision=kwargs["revision"],
+        )
+        logger.debug(
+            "loading model with tokenizer config: %s, model config: %s",
+            tokenizer_config,
+            self._model_config,
+        )
+
+        return load(
+            self.model_path,
+            tokenizer_config=tokenizer_config,
+            model_config=self._model_config,
+        )
+
+    def load(self):
+        kwargs = {}
+        kwargs["revision"] = self._model_config.get(
+            "revision", self.model_spec.model_revision
+        )
+        kwargs["trust_remote_code"] = self._model_config.get("trust_remote_code")
+
+        self._model, self._tokenizer = self._load_model(**kwargs)
+
+    @classmethod
+    def match(
+        cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
+    ) -> bool:
+        if llm_spec.model_format not in ["mlx"]:
+            return False
+        if sys.platform != "darwin" or platform.processor() != "arm":
+            # only work for Mac M chips
+            return False
+        if "generate" not in llm_family.model_ability:
+            return False
+        return True
+
+    def _generate_stream(self, prompt: str, kwargs: MLXGenerateConfig):
+        import mlx.core as mx
+        from mlx_lm.utils import generate_step
+
+        model = self._model
+        model_uid = self.model_uid
+        tokenizer = self._tokenizer
+        max_tokens = kwargs["max_tokens"]
+        chunk_id = str(uuid.uuid4())
+        stop_token_ids = kwargs.get("stop_token_ids", [])
+        stream = kwargs.get("stream", False)
+        stream_options = kwargs.pop("stream_options", None)
+        include_usage = (
+            stream_options["include_usage"]
+            if isinstance(stream_options, dict)
+            else False
+        )
+
+        prompt_tokens = mx.array(tokenizer.encode(prompt))
+        input_echo_len = len(prompt_tokens)
+
+        i = 0
+        start = time.time()
+        output = ""
+        for (token, _), i in zip(
+            generate_step(
+                prompt_tokens,
+                model,
+                temp=kwargs["temperature"],
+                repetition_penalty=kwargs["repetition_penalty"],
+                repetition_context_size=kwargs["repetition_context_size"],
+                top_p=kwargs["top_p"],
+                logit_bias=kwargs["logit_bias"],
+            ),
+            range(max_tokens),
+        ):
+            if token == tokenizer.eos_token_id or token in stop_token_ids:  # type: ignore
+                break
+
+            # Yield the last segment if streaming
+            out = tokenizer.decode(
+                token,
+                skip_special_tokens=True,
+                spaces_between_special_tokens=False,
+                clean_up_tokenization_spaces=True,
+            )
+
+            if stream:
+                # this special character is mainly for qwen
+                out = out.strip("�")
+                output = out
+            else:
+                output += out
+
+            completion_choice = CompletionChoice(
+                text=output, index=0, logprobs=None, finish_reason=None
+            )
+            completion_chunk = CompletionChunk(
+                id=chunk_id,
+                object="text_completion",
+                created=int(time.time()),
+                model=model_uid,
+                choices=[completion_choice],
+            )
+            completion_usage = CompletionUsage(
+                prompt_tokens=input_echo_len,
+                completion_tokens=i,
+                total_tokens=(input_echo_len + i),
+            )
+
+            yield completion_chunk, completion_usage
+
+        logger.info(
+            f"Average generation speed: {i / (time.time() - start):.2f} tokens/s."
+        )
+
+        if i == max_tokens - 1:
+            finish_reason = "length"
+        else:
+            finish_reason = "stop"
+
+        if stream:
+            completion_choice = CompletionChoice(
+                text="", index=0, logprobs=None, finish_reason=finish_reason
+            )
+        else:
+            completion_choice = CompletionChoice(
+                text=output, index=0, logprobs=None, finish_reason=finish_reason
+            )
+
+        completion_chunk = CompletionChunk(
+            id=chunk_id,
+            object="text_completion",
+            created=int(time.time()),
+            model=model_uid,
+            choices=[completion_choice],
+        )
+        completion_usage = CompletionUsage(
+            prompt_tokens=input_echo_len,
+            completion_tokens=i,
+            total_tokens=(input_echo_len + i),
+        )
+
+        yield completion_chunk, completion_usage
+
+        if include_usage:
+            completion_chunk = CompletionChunk(
+                id=chunk_id,
+                object="text_completion",
+                created=int(time.time()),
+                model=model_uid,
+                choices=[],
+            )
+            completion_usage = CompletionUsage(
+                prompt_tokens=input_echo_len,
+                completion_tokens=i,
+                total_tokens=(input_echo_len + i),
+            )
+            yield completion_chunk, completion_usage
+
+    def generate(
+        self, prompt: str, generate_config: Optional[MLXGenerateConfig] = None
+    ) -> Union[Completion, Iterator[CompletionChunk]]:
+        def generator_wrapper(
+            prompt: str, generate_config: MLXGenerateConfig
+        ) -> Iterator[CompletionChunk]:
+            for completion_chunk, completion_usage in self._generate_stream(
+                prompt,
+                generate_config,
+            ):
+                completion_chunk["usage"] = completion_usage
+                yield completion_chunk
+
+        logger.debug(
+            "Enter generate, prompt: %s, generate config: %s", prompt, generate_config
+        )
+
+        generate_config = self._sanitize_generate_config(generate_config)
+
+        assert self._model is not None
+        assert self._tokenizer is not None
+
+        stream = generate_config.get("stream", False)
+        if not stream:
+            for completion_chunk, completion_usage in self._generate_stream(
+                prompt,
+                generate_config,
+            ):
+                pass
+            completion = Completion(
+                id=completion_chunk["id"],
+                object=completion_chunk["object"],
+                created=completion_chunk["created"],
+                model=completion_chunk["model"],
+                choices=completion_chunk["choices"],
+                usage=completion_usage,
+            )
+            return completion
+        else:
+            return generator_wrapper(prompt, generate_config)
+
+
+class MLXChatModel(MLXModel, ChatModelMixin):
+    def __init__(
+        self,
+        model_uid: str,
+        model_family: "LLMFamilyV1",
+        model_spec: "LLMSpecV1",
+        quantization: str,
+        model_path: str,
+        model_config: Optional[MLXModelConfig] = None,
+        peft_model: Optional[List[LoRA]] = None,
+    ):
+        super().__init__(
+            model_uid,
+            model_family,
+            model_spec,
+            quantization,
+            model_path,
+            model_config,
+            peft_model,
+        )
+
+    def _sanitize_generate_config(
+        self,
+        generate_config: Optional[MLXGenerateConfig],
+    ) -> MLXGenerateConfig:
+        generate_config = super()._sanitize_generate_config(generate_config)
+        if (
+            (not generate_config.get("stop"))
+            and self.model_family.prompt_style
+            and self.model_family.prompt_style.stop
+        ):
+            generate_config["stop"] = self.model_family.prompt_style.stop.copy()
+        if (
+            generate_config.get("stop_token_ids", None) is None
+            and self.model_family.prompt_style
+            and self.model_family.prompt_style.stop_token_ids
+        ):
+            generate_config[
+                "stop_token_ids"
+            ] = self.model_family.prompt_style.stop_token_ids.copy()
+
+        return generate_config
+
+    @classmethod
+    def match(
+        cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
+    ) -> bool:
+        if llm_spec.model_format not in ["mlx"]:
+            return False
+        if sys.platform != "darwin" or platform.processor() != "arm":
+            # only work for Mac M chips
+            return False
+        if "chat" not in llm_family.model_ability:
+            return False
+        return True
+
+    def chat(
+        self,
+        prompt: str,
+        system_prompt: Optional[str] = None,
+        chat_history: Optional[List[ChatCompletionMessage]] = None,
+        generate_config: Optional[MLXGenerateConfig] = None,
+    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
+        tools = generate_config.pop("tools", []) if generate_config else None  # type: ignore
+        full_prompt = self.get_full_prompt(
+            self.model_family, prompt, system_prompt, chat_history, tools
+        )
+
+        generate_config = self._sanitize_generate_config(generate_config)
+        # TODO(codingl2k1): qwen hacky to set stop for function call.
+        model_family = self.model_family.model_family or self.model_family.model_name
+        if tools and model_family in ["qwen-chat", "qwen1.5-chat"]:
+            stop = generate_config.get("stop")
+            if isinstance(stop, str):
+                generate_config["stop"] = [stop, "Observation:"]
+            elif isinstance(stop, Iterable):
+                assert not isinstance(stop, str)
+                generate_config["stop"] = list(stop) + ["Observation:"]
+            else:
+                generate_config["stop"] = "Observation:"
+
+        stream = generate_config.get("stream", False)
+        if stream:
+            it = self.generate(full_prompt, generate_config)
+            assert isinstance(it, Iterator)
+            return self._to_chat_completion_chunks(it)
+        else:
+            c = self.generate(full_prompt, generate_config)
+            assert not isinstance(c, Iterator)
+            if tools:
+                return self._tool_calls_completion(
+                    self.model_family, self.model_uid, c, tools
+                )
+            return self._to_chat_completion(c)
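The new module above is what makes Apple-silicon Macs a first-class backend: MLXModel.match() only accepts model_format == "mlx" on darwin/arm. A minimal client-side sketch of exercising it, assuming a locally running Xinference endpoint; the model name, engine label, and quantization value below are illustrative assumptions, not taken from this diff:

from xinference.client import Client

client = Client("http://127.0.0.1:9997")

# Hypothetical launch: "mlx" as model_format is what MLXModel.match() checks;
# the model name, engine label and quantization value are assumptions here.
model_uid = client.launch_model(
    model_name="qwen1.5-chat",
    model_engine="MLX",
    model_format="mlx",
    quantization="4-bit",
)
model = client.get_model(model_uid)

# Chat goes through MLXChatModel.chat(); max_tokens/temperature map onto MLXGenerateConfig.
print(model.chat("Hello!", generate_config={"max_tokens": 64, "temperature": 0.0}))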
xinference/model/llm/pytorch/chatglm.py

@@ -29,6 +29,7 @@ from ....types import (
     PytorchGenerateConfig,
 )
 from ..llm_family import LLMFamilyV1, LLMSpecV1
+from ..utils import GLM4_TOOL_CALL_FAMILY
 from .core import PytorchChatModel, PytorchModelConfig
 
 
@@ -103,7 +104,7 @@ class ChatglmPytorchChatModel(PytorchChatModel):
         if tools is None:
             return False
         tool_choice = generate_config.pop("tool_choice", "none")
-        if self.model_family.model_name
+        if self.model_family.model_name in GLM4_TOOL_CALL_FAMILY:
             chat_history[:] = self.process_messages(
                 chat_history, tools=tools, tool_choice=tool_choice
             )
@@ -335,14 +336,6 @@ class ChatglmPytorchChatModel(PytorchChatModel):
             ),
         )
 
-    @staticmethod
-    def require_attention_mask():
-        """
-        GLM4 needs to use attention mask and position ids during inference.
-        Otherwise, the inference result would be not available.
-        """
-        return True
-
     def prepare_sanitize_generate_config(self, req: InferenceRequest):
         """
         Set temperature and top_p to 0.8 by default
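The chatglm.py change gates tool-call preprocessing on the new GLM4_TOOL_CALL_FAMILY list imported from ..utils, so only GLM4-family models get their chat history rewritten by process_messages(). A hedged sketch of the request shape this affects (model uid and tool schema are illustrative); note that tools and tool_choice travel inside generate_config, matching the generate_config.pop(...) calls in the diff:

from xinference.client import Client

client = Client("http://127.0.0.1:9997")
model = client.get_model("glm4-chat")  # assumes a GLM4 chat model was launched under this uid

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Look up the current weather for a city.",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
]
# tools/tool_choice are consumed by the generate_config.pop("tools"/"tool_choice")
# calls shown above before the prompt is built.
resp = model.chat(
    "What's the weather in Paris?",
    generate_config={"tools": tools, "tool_choice": "auto", "max_tokens": 256},
)
print(resp["choices"][0]["message"])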
xinference/model/llm/pytorch/cogvlm2.py

@@ -23,6 +23,7 @@ import requests
 import torch
 from PIL import Image
 
+from ....core.scheduler import InferenceRequest
 from ....model.utils import select_device
 from ....types import (
     ChatCompletion,
@@ -35,11 +36,30 @@ from ....types import (
 )
 from ..llm_family import LLMFamilyV1, LLMSpecV1
 from .core import PytorchChatModel, PytorchGenerateConfig
+from .utils import get_max_src_len
 
 logger = logging.getLogger(__name__)
 
-
-
+
+LANGUAGE_TOKEN_TYPE = 0
+VISION_TOKEN_TYPE = 1
+
+
+def recur_move_to(item, tgt, criterion_func):
+    """
+    This function is copied from https://github.com/THUDM/CogVLM2/blob/main/basic_demo/cli_demo_batch_inference.py
+    """
+    if criterion_func(item):
+        device_copy = item.to(tgt)
+        return device_copy
+    elif isinstance(item, list):
+        return [recur_move_to(v, tgt, criterion_func) for v in item]
+    elif isinstance(item, tuple):
+        return tuple([recur_move_to(v, tgt, criterion_func) for v in item])
+    elif isinstance(item, dict):
+        return {k: recur_move_to(v, tgt, criterion_func) for k, v in item.items()}
+    else:
+        return item
 
 
 class CogVLM2Model(PytorchChatModel):
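The recur_move_to helper added above (copied from the CogVLM2 batch-inference demo) walks nested lists, tuples, and dicts and calls .to(tgt) only on leaves that satisfy the criterion. A standalone sketch of the two passes the engine later performs on a prefill batch, first to a device and then to a floating-point dtype:

import torch

def recur_move_to(item, tgt, criterion_func):
    # Same logic as the helper added to cogvlm2.py: recursively visit nested
    # lists/tuples/dicts and call .to(tgt) on leaves matching the criterion.
    if criterion_func(item):
        return item.to(tgt)
    elif isinstance(item, list):
        return [recur_move_to(v, tgt, criterion_func) for v in item]
    elif isinstance(item, tuple):
        return tuple(recur_move_to(v, tgt, criterion_func) for v in item)
    elif isinstance(item, dict):
        return {k: recur_move_to(v, tgt, criterion_func) for k, v in item.items()}
    else:
        return item

batch = {
    "input_ids": torch.zeros(2, 8, dtype=torch.long),
    "images": [[torch.rand(3, 4, 4)]],  # nested list of image tensors
    "meta": "not a tensor",             # non-tensor leaves pass through untouched
}
# Pass 1: move every tensor to the target device ("cpu" here; "cuda"/"mps" in practice).
batch = recur_move_to(batch, "cpu", lambda x: isinstance(x, torch.Tensor))
# Pass 2: cast only floating-point tensors, mirroring the dtype pass in build_prefill_kwargs.
batch = recur_move_to(
    batch,
    torch.float16,
    lambda x: isinstance(x, torch.Tensor) and torch.is_floating_point(x),
)
print(batch["images"][0][0].dtype)  # torch.float16; input_ids stays int64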
@@ -176,11 +196,33 @@ class CogVLM2Model(PytorchChatModel):
                         content["image_url"]["url"]
                     )
                 assistant = chat_history[i + 1]["content"]
-
-
-                query = query + f" {assistant}"
+                history.append((user, assistant))
+                query = assistant  # type: ignore
         return query, history, [pixel_values]
 
+    def get_query_and_history(
+        self,
+        prompt: Union[str, List[Dict]],
+        system_prompt: Optional[str] = None,
+        chat_history: Optional[List[ChatCompletionMessage]] = None,
+    ):
+        content, image = self._message_content_to_cogvlm2(prompt)
+
+        history = []
+        history_image = None
+        if chat_history:
+            query, history, history_image = self._history_content_to_cogvlm2(
+                system_prompt, chat_history  # type: ignore
+            )
+
+        if image and history_image:
+            history = []
+            query = content
+        else:
+            image = image if image else history_image
+            query = content
+        return query, image, history
+
     def chat(
         self,
         prompt: Union[str, List[Dict]],
@@ -198,22 +240,9 @@ class CogVLM2Model(PytorchChatModel):
             else 512,
         }
 
-
-
-
-        query = ""
-        history_image = None
-        if chat_history:
-            query, history, history_image = self._history_content_to_cogvlm2(
-                system_prompt, chat_history
-            )
-
-        if image and history_image:
-            history = []
-            query = system_prompt + f" USER: {content} ASSISTANT:"
-        else:
-            image = image if image else history_image
-            query = query + f" USER: {content} ASSISTANT:"
+        query, image, history = self.get_query_and_history(
+            prompt, system_prompt=system_prompt, chat_history=chat_history
+        )
 
         input_by_model = self._model.build_conversation_input_ids(
             self._tokenizer,
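Both the rewritten history handling and the new get_query_and_history() consume OpenAI-style content parts (note the content["image_url"]["url"] access above). A hedged sketch of such a multimodal chat request through the client; the model uid and image URL are illustrative:

from xinference.client import Client

client = Client("http://127.0.0.1:9997")
model = client.get_model("cogvlm2")  # assumes a CogVLM2 model was launched under this uid

# The prompt is a list of content parts, which get_query_and_history() turns
# into a (query, image, history) triple before build_conversation_input_ids.
prompt = [
    {"type": "text", "text": "Describe this image in one sentence."},
    {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
]
resp = model.chat(prompt, generate_config={"max_tokens": 128})
print(resp["choices"][0]["message"]["content"])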
@@ -319,3 +348,159 @@ class CogVLM2Model(PytorchChatModel):
             ),
         )
         yield chunk
+
+    @staticmethod
+    def build_position_ids(x, attention_mask=None):
+        """
+        Copied from https://huggingface.co/THUDM/cogvlm2-llama3-chinese-chat-19B-int4/blob/main/modeling_cogvlm.py
+        """
+        # Fix: follow the official open-source implementation
+        if attention_mask is not None:
+            tmp = x.clone()
+            tmp[~(attention_mask.bool())] = -1
+        else:
+            tmp = x.clone()
+        # image boi eoi token as LANGUAGE_TOKEN_TYPE
+        is_boi_eoi = torch.zeros_like(x, dtype=torch.bool)
+        is_boi_eoi[:, 1:] |= (tmp[:, 1:] == VISION_TOKEN_TYPE) & (
+            tmp[:, :-1] == LANGUAGE_TOKEN_TYPE
+        )
+        is_boi_eoi[:, 0] |= tmp[:, 0] == VISION_TOKEN_TYPE
+        is_boi_eoi[:, :-1] |= (tmp[:, :-1] == VISION_TOKEN_TYPE) & (
+            tmp[:, 1:] == LANGUAGE_TOKEN_TYPE
+        )
+        is_boi_eoi[:, -1] |= tmp[:, -1] == VISION_TOKEN_TYPE
+        tmp[is_boi_eoi] = LANGUAGE_TOKEN_TYPE
+        # final position ids
+        y = torch.zeros_like(x, dtype=torch.long)
+        y[:, 1:] = (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE) | (
+            (tmp[:, 1:] == VISION_TOKEN_TYPE) & (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE)
+        )
+        y = y.cumsum(dim=-1)
+        return y
+
+    def get_dtype(self):
+        return self._torch_type
+
+    def _get_full_prompt(self, prompt, system_prompt, chat_history, tools):
+        query, image, history = self.get_query_and_history(
+            prompt, system_prompt=system_prompt, chat_history=chat_history
+        )
+
+        input_by_model: dict = self._model.build_conversation_input_ids(
+            self._tokenizer,
+            query=query,
+            history=history,
+            images=image,
+            template_version="chat",
+        )
+        return {
+            "input_ids": input_by_model["input_ids"],  # seq_len
+            "token_type_ids": input_by_model["token_type_ids"],  # seq_len
+            "attention_mask": input_by_model["attention_mask"],  # seq_len
+            "images": input_by_model["images"],
+        }
+
+    def prepare_sanitize_generate_config(self, req: InferenceRequest):
+        """
+        See https://huggingface.co/THUDM/cogvlm2-llama3-chat-19B/blob/main/generation_config.json
+        """
+        raw_config = req.inference_kwargs.get("raw_params", {})
+        temperature = raw_config.get("temperature", None)
+        if temperature is None:
+            raw_config["temperature"] = 0.6
+        top_p = raw_config.get("top_p", None)
+        if top_p is None:
+            raw_config["top_p"] = 0.9
+        return raw_config
+
+    def build_prefill_kwargs(self, prompts: List, req_list: List[InferenceRequest]):
+        context_len = self.get_context_len()
+        assert isinstance(prompts[0], dict)
+        images = []
+        max_length = float("-inf")
+        for i, feature in enumerate(prompts):
+            req = req_list[i]
+            if "images" in feature:
+                images.append(feature.pop("images", None))
+            max_src_len = get_max_src_len(context_len, req)
+            input_ids = feature["input_ids"][-max_src_len:]
+            req.prompt_tokens = input_ids.tolist()
+            feature["input_ids"] = input_ids
+            feature["token_type_ids"] = feature["token_type_ids"][-max_src_len:]
+            feature["attention_mask"] = feature["attention_mask"][-max_src_len:]
+            req.extra_kwargs["attention_mask_seq_len"] = feature[
+                "attention_mask"
+            ].shape[0]
+            max_length = max(len(input_ids), max_length)
+
+        def pad_to_max_length_internal(feature, max_len, idx):
+            padding_length = max_len - len(feature["input_ids"])
+            req_list[idx].padding_len = padding_length
+            feature["input_ids"] = torch.cat(
+                [torch.full((padding_length,), 0), feature["input_ids"]]
+            )
+            feature["token_type_ids"] = torch.cat(
+                [
+                    torch.zeros(padding_length, dtype=torch.long),
+                    feature["token_type_ids"],
+                ]
+            )
+            feature["attention_mask"] = torch.cat(
+                [
+                    torch.zeros(padding_length, dtype=torch.long),
+                    feature["attention_mask"],
+                ]
+            )
+            return feature
+
+        features = [
+            pad_to_max_length_internal(feature, max_length, i)
+            for i, feature in enumerate(prompts)
+        ]
+        batch = {
+            key: torch.stack([feature[key] for feature in features])
+            for key in features[0].keys()
+        }
+
+        position_ids = self.build_position_ids(batch["token_type_ids"])
+        batch["position_ids"] = position_ids
+
+        for i in range(len(prompts)):
+            req = req_list[i]
+            req.extra_kwargs["max_position_id"] = position_ids[i : i + 1, -1].item()
+
+        if images:
+            batch["images"] = images
+
+        batch = recur_move_to(
+            batch, self._device, lambda x: isinstance(x, torch.Tensor)
+        )
+        dtype = self.get_dtype()
+        if dtype:
+            batch = recur_move_to(
+                batch,
+                dtype,
+                lambda x: isinstance(x, torch.Tensor) and torch.is_floating_point(x),
+            )
+        return batch
+
+    def build_decode_token_type_ids(
+        self, batch_size: int, seq_length: int, reqs: List[InferenceRequest]
+    ):
+        token_type_ids = torch.full(
+            (batch_size, 1), fill_value=1, dtype=torch.long, device=self._device
+        )
+        return token_type_ids
+
+    def build_decode_position_ids(
+        self, batch_size: int, seq_length: int, reqs: List[InferenceRequest]
+    ):
+        tmp = []
+        for r in reqs:
+            r.extra_kwargs["max_position_id"] += 1
+            tmp.append(r.extra_kwargs["max_position_id"])
+        position_ids = torch.as_tensor(
+            tmp, device=self._device, dtype=torch.long
+        ).unsqueeze(1)
+        return position_ids
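A quick way to see what the copied build_position_ids() produces: interior vision tokens collapse onto a shared position, while language tokens and the run's boundary (boi/eoi-like) tokens each advance the counter. A small sketch, assuming xinference 0.13.0 is installed so the class can be imported:

import torch

from xinference.model.llm.pytorch.cogvlm2 import CogVLM2Model

# 0 = LANGUAGE_TOKEN_TYPE, 1 = VISION_TOKEN_TYPE; one batch row of 8 tokens
token_type_ids = torch.tensor([[0, 0, 1, 1, 1, 1, 0, 0]])
# build_position_ids is a @staticmethod, so no model instance is needed.
position_ids = CogVLM2Model.build_position_ids(token_type_ids)
print(position_ids)  # tensor([[0, 1, 2, 3, 3, 4, 5, 6]]) for this input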