xinference 0.12.3__py3-none-any.whl → 0.13.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (71)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +6 -6
  3. xinference/client/restful/restful_client.py +0 -2
  4. xinference/core/model.py +21 -4
  5. xinference/core/scheduler.py +2 -0
  6. xinference/core/worker.py +74 -45
  7. xinference/deploy/utils.py +33 -2
  8. xinference/model/llm/__init__.py +5 -0
  9. xinference/model/llm/llm_family.json +240 -1
  10. xinference/model/llm/llm_family.py +32 -8
  11. xinference/model/llm/llm_family_modelscope.json +192 -0
  12. xinference/model/llm/mlx/__init__.py +13 -0
  13. xinference/model/llm/mlx/core.py +408 -0
  14. xinference/model/llm/pytorch/chatglm.py +2 -9
  15. xinference/model/llm/pytorch/cogvlm2.py +206 -21
  16. xinference/model/llm/pytorch/core.py +213 -40
  17. xinference/model/llm/pytorch/glm4v.py +171 -15
  18. xinference/model/llm/pytorch/qwen_vl.py +168 -7
  19. xinference/model/llm/pytorch/utils.py +53 -62
  20. xinference/model/llm/utils.py +24 -5
  21. xinference/model/rerank/core.py +5 -0
  22. xinference/thirdparty/deepseek_vl/serve/__init__.py +13 -0
  23. xinference/thirdparty/deepseek_vl/serve/app_deepseek.py +510 -0
  24. xinference/thirdparty/deepseek_vl/serve/app_modules/__init__.py +13 -0
  25. xinference/thirdparty/deepseek_vl/serve/app_modules/gradio_utils.py +94 -0
  26. xinference/thirdparty/deepseek_vl/serve/app_modules/overwrites.py +81 -0
  27. xinference/thirdparty/deepseek_vl/serve/app_modules/presets.py +96 -0
  28. xinference/thirdparty/deepseek_vl/serve/app_modules/utils.py +229 -0
  29. xinference/thirdparty/deepseek_vl/serve/inference.py +170 -0
  30. xinference/web/ui/build/asset-manifest.json +3 -3
  31. xinference/web/ui/build/index.html +1 -1
  32. xinference/web/ui/build/static/js/main.0fb6f3ab.js +3 -0
  33. xinference/web/ui/build/static/js/main.0fb6f3ab.js.map +1 -0
  34. xinference/web/ui/node_modules/.cache/babel-loader/0f6b391abec76271137faad13a3793fe7acc1024e8cd2269c147b653ecd3a73b.json +1 -0
  35. xinference/web/ui/node_modules/.cache/babel-loader/1130403f9e46f5738a23b45ac59b57de8f360c908c713e2c0670c2cce9bd367a.json +1 -0
  36. xinference/web/ui/node_modules/.cache/babel-loader/1444c41a4d04494f1cbc2d8c1537df107b451cb569cb2c1fbf5159f3a4841a5f.json +1 -0
  37. xinference/web/ui/node_modules/.cache/babel-loader/2c63090c842376cdd368c3ded88a333ef40d94785747651343040a6f7872a223.json +1 -0
  38. xinference/web/ui/node_modules/.cache/babel-loader/30a0c79d8025d6441eb75b2df5bc2750a14f30119c869ef02570d294dff65c2f.json +1 -0
  39. xinference/web/ui/node_modules/.cache/babel-loader/40486e655c3c5801f087e2cf206c0b5511aaa0dfdba78046b7181bf9c17e54c5.json +1 -0
  40. xinference/web/ui/node_modules/.cache/babel-loader/6450605fac003812485f6251b9f0caafbf2e5bfc3bbe2f000050d9e2fdb8dcd3.json +1 -0
  41. xinference/web/ui/node_modules/.cache/babel-loader/8a9742ddd8ba8546ef42dc14caca443f2b4524fabed7bf269e0eff3b7b64ee7d.json +1 -0
  42. xinference/web/ui/node_modules/.cache/babel-loader/9375a35b05d56989b2755bf72161fa707c92f28569d33765a75f91a568fda6e9.json +1 -0
  43. xinference/web/ui/node_modules/.cache/babel-loader/b5507cd57f16a3a230aa0128e39fe103e928de139ea29e2679e4c64dcbba3b3a.json +1 -0
  44. xinference/web/ui/node_modules/.cache/babel-loader/d6c643278a0b28320e6f33a60f5fb64c053997cbdc39a60e53ccc574688ade9e.json +1 -0
  45. xinference/web/ui/node_modules/.cache/babel-loader/d779b915f83f9c7b5a72515b6932fdd114f1822cef90ae01cc0d12bca59abc2d.json +1 -0
  46. xinference/web/ui/node_modules/.cache/babel-loader/d87824cb266194447a9c0c69ebab2d507bfc3e3148976173760d18c035e9dd26.json +1 -0
  47. xinference/web/ui/node_modules/.cache/babel-loader/d93730e2b5d7e8c957b4d0965d2ed1dac9045a649adbd47c220d11f255d4b1e0.json +1 -0
  48. xinference/web/ui/node_modules/.cache/babel-loader/e656dc00b4d8b387f0a81ba8fc558767df1601c66369e2eb86a5ef27cf080572.json +1 -0
  49. {xinference-0.12.3.dist-info → xinference-0.13.0.dist-info}/METADATA +4 -1
  50. {xinference-0.12.3.dist-info → xinference-0.13.0.dist-info}/RECORD +55 -44
  51. xinference/web/ui/build/static/js/main.77dd47c3.js +0 -3
  52. xinference/web/ui/build/static/js/main.77dd47c3.js.map +0 -1
  53. xinference/web/ui/node_modules/.cache/babel-loader/0cd591866aa345566e0b63fb51ff2043e163a770af6fdc2f3bad395d046353e2.json +0 -1
  54. xinference/web/ui/node_modules/.cache/babel-loader/37c1476717199863bbba1530e3513a9368f8f73001b75b4a85c2075956308027.json +0 -1
  55. xinference/web/ui/node_modules/.cache/babel-loader/3da7d55e87882a4af923e187b1351160e34ca102f589086439c15131a227fb6e.json +0 -1
  56. xinference/web/ui/node_modules/.cache/babel-loader/3fa1f69162f9c6dc0f6a6e21b64d49d6b8e6fa8dfa59a82cf829931c5f97d99f.json +0 -1
  57. xinference/web/ui/node_modules/.cache/babel-loader/46edc1fe657dfedb2e673148332bb442c6eb98f09f2592c389209e376510afa5.json +0 -1
  58. xinference/web/ui/node_modules/.cache/babel-loader/62e257ed9016471035fa1a7da57c9e2a4250974ed566b4d1295873d747c68eb2.json +0 -1
  59. xinference/web/ui/node_modules/.cache/babel-loader/72bcecc71c5267250edeb89608859d449b586f13ff9923a5e70e7172976ec403.json +0 -1
  60. xinference/web/ui/node_modules/.cache/babel-loader/82db357f3fd5b32215d747ee593f69ff06c95ad6cde37f71a96c8290aaab64c0.json +0 -1
  61. xinference/web/ui/node_modules/.cache/babel-loader/935efd2867664c58230378fdf2ff1ea85e58d853b7214014e20dfbca8dab7b05.json +0 -1
  62. xinference/web/ui/node_modules/.cache/babel-loader/bc6da27195ec4607bb472bf61f97c928ad4966fa64e4c2247661bedb7400abba.json +0 -1
  63. xinference/web/ui/node_modules/.cache/babel-loader/c2abe75f04ad82fba68f35ed9cbe2e287762c876684fddccccfa73f739489b65.json +0 -1
  64. xinference/web/ui/node_modules/.cache/babel-loader/e606671420d2937102c3c34b4b04056c11736408c1d3347b8cf42dfe61fb394b.json +0 -1
  65. xinference/web/ui/node_modules/.cache/babel-loader/f118f99c22b713c678c1209c4e1dd43fe86e3f6e801a4c0c35d3bbf41fd05fe6.json +0 -1
  66. xinference/web/ui/node_modules/.cache/babel-loader/f51bf63ddaa7afd125ef2254a105789333eecc1c94fdf5157a9b88ef7ad0a5bd.json +0 -1
  67. /xinference/web/ui/build/static/js/{main.77dd47c3.js.LICENSE.txt → main.0fb6f3ab.js.LICENSE.txt} +0 -0
  68. {xinference-0.12.3.dist-info → xinference-0.13.0.dist-info}/LICENSE +0 -0
  69. {xinference-0.12.3.dist-info → xinference-0.13.0.dist-info}/WHEEL +0 -0
  70. {xinference-0.12.3.dist-info → xinference-0.13.0.dist-info}/entry_points.txt +0 -0
  71. {xinference-0.12.3.dist-info → xinference-0.13.0.dist-info}/top_level.txt +0 -0
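Note (not part of the diff): the headline change in 0.13.0 is the new MLX engine (xinference/model/llm/mlx/core.py, plus the "mlx" model-format entries added to llm_family.json). Below is a minimal sketch of launching an MLX-format model through the RESTful client on an Apple-silicon Mac; the endpoint, model name, engine string, and quantization tag are illustrative assumptions, not values taken from this diff.

# Illustrative sketch only -- model name, quantization, and endpoint are assumptions.
from xinference.client import Client

client = Client("http://localhost:9997")       # a running xinference supervisor
model_uid = client.launch_model(
    model_name="qwen2-instruct",               # hypothetical model registered with an mlx spec
    model_engine="MLX",                        # engine name assumed
    model_format="mlx",
    quantization="4-bit",                      # hypothetical quantization tag
)
model = client.get_model(model_uid)
print(model.chat("Hello"))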
xinference/model/llm/mlx/core.py (new file)
@@ -0,0 +1,408 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import platform
+import sys
+import time
+import uuid
+from typing import Dict, Iterable, Iterator, List, Optional, TypedDict, Union
+
+from ....fields import max_tokens_field
+from ....types import (
+    ChatCompletion,
+    ChatCompletionChunk,
+    ChatCompletionMessage,
+    Completion,
+    CompletionChoice,
+    CompletionChunk,
+    CompletionUsage,
+    LoRA,
+)
+from ..core import LLM
+from ..llm_family import LLMFamilyV1, LLMSpecV1
+from ..utils import ChatModelMixin
+
+logger = logging.getLogger(__name__)
+
+
+class MLXModelConfig(TypedDict, total=False):
+    revision: Optional[str]
+    max_gpu_memory: str
+    trust_remote_code: bool
+
+
+class MLXGenerateConfig(TypedDict, total=False):
+    max_tokens: int
+    temperature: float
+    repetition_penalty: Optional[float]
+    repetition_context_size: Optional[float]
+    top_p: float
+    logit_bias: Optional[Dict[int, float]]
+    stop: Optional[Union[str, List[str]]]
+    stop_token_ids: Optional[Union[int, List[int]]]
+    stream: bool
+    stream_options: Optional[Union[dict, None]]
+
+
+class MLXModel(LLM):
+    def __init__(
+        self,
+        model_uid: str,
+        model_family: "LLMFamilyV1",
+        model_spec: "LLMSpecV1",
+        quantization: str,
+        model_path: str,
+        model_config: Optional[MLXModelConfig] = None,
+        peft_model: Optional[List[LoRA]] = None,
+    ):
+        super().__init__(model_uid, model_family, model_spec, quantization, model_path)
+        self._use_fast_tokenizer = True
+        self._model_config: MLXModelConfig = self._sanitize_model_config(model_config)
+        if peft_model is not None:
+            raise ValueError("MLX engine has not supported lora yet")
+
+    def _sanitize_model_config(
+        self, model_config: Optional[MLXModelConfig]
+    ) -> MLXModelConfig:
+        if model_config is None:
+            model_config = MLXModelConfig()
+        model_config.setdefault("revision", self.model_spec.model_revision)
+        model_config.setdefault("trust_remote_code", True)
+        return model_config
+
+    def _sanitize_generate_config(
+        self,
+        generate_config: Optional[MLXGenerateConfig],
+    ) -> MLXGenerateConfig:
+        if generate_config is None:
+            generate_config = MLXGenerateConfig()
+
+        generate_config.setdefault("max_tokens", max_tokens_field.default)
+        # default config is adapted from
+        # https://github.com/ml-explore/mlx-examples/blob/f212b770d8b5143e23102eda20400ae43340f844/llms/mlx_lm/utils.py#L129
+        generate_config.setdefault("temperature", 0.0)
+        generate_config.setdefault("repetition_penalty", None)
+        generate_config.setdefault("repetition_context_size", 20)
+        generate_config.setdefault("top_p", 1.0)
+        generate_config.setdefault("logit_bias", None)
+        return generate_config
+
+    def _load_model(self, **kwargs):
+        try:
+            from mlx_lm import load
+        except ImportError:
+            error_message = "Failed to import module 'mlx_lm'"
+            installation_guide = [
+                "Please make sure 'mlx_lm' is installed. ",
+                "You can install it by `pip install mlx_lm`\n",
+            ]
+
+            raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
+
+        tokenizer_config = dict(
+            use_fast=self._use_fast_tokenizer,
+            trust_remote_code=kwargs["trust_remote_code"],
+            revision=kwargs["revision"],
+        )
+        logger.debug(
+            "loading model with tokenizer config: %s, model config: %s",
+            tokenizer_config,
+            self._model_config,
+        )
+
+        return load(
+            self.model_path,
+            tokenizer_config=tokenizer_config,
+            model_config=self._model_config,
+        )
+
+    def load(self):
+        kwargs = {}
+        kwargs["revision"] = self._model_config.get(
+            "revision", self.model_spec.model_revision
+        )
+        kwargs["trust_remote_code"] = self._model_config.get("trust_remote_code")
+
+        self._model, self._tokenizer = self._load_model(**kwargs)
+
+    @classmethod
+    def match(
+        cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
+    ) -> bool:
+        if llm_spec.model_format not in ["mlx"]:
+            return False
+        if sys.platform != "darwin" or platform.processor() != "arm":
+            # only work for Mac M chips
+            return False
+        if "generate" not in llm_family.model_ability:
+            return False
+        return True
+
+    def _generate_stream(self, prompt: str, kwargs: MLXGenerateConfig):
+        import mlx.core as mx
+        from mlx_lm.utils import generate_step
+
+        model = self._model
+        model_uid = self.model_uid
+        tokenizer = self._tokenizer
+        max_tokens = kwargs["max_tokens"]
+        chunk_id = str(uuid.uuid4())
+        stop_token_ids = kwargs.get("stop_token_ids", [])
+        stream = kwargs.get("stream", False)
+        stream_options = kwargs.pop("stream_options", None)
+        include_usage = (
+            stream_options["include_usage"]
+            if isinstance(stream_options, dict)
+            else False
+        )
+
+        prompt_tokens = mx.array(tokenizer.encode(prompt))
+        input_echo_len = len(prompt_tokens)
+
+        i = 0
+        start = time.time()
+        output = ""
+        for (token, _), i in zip(
+            generate_step(
+                prompt_tokens,
+                model,
+                temp=kwargs["temperature"],
+                repetition_penalty=kwargs["repetition_penalty"],
+                repetition_context_size=kwargs["repetition_context_size"],
+                top_p=kwargs["top_p"],
+                logit_bias=kwargs["logit_bias"],
+            ),
+            range(max_tokens),
+        ):
+            if token == tokenizer.eos_token_id or token in stop_token_ids:  # type: ignore
+                break
+
+            # Yield the last segment if streaming
+            out = tokenizer.decode(
+                token,
+                skip_special_tokens=True,
+                spaces_between_special_tokens=False,
+                clean_up_tokenization_spaces=True,
+            )
+
+            if stream:
+                # this special character is mainly for qwen
+                out = out.strip("�")
+                output = out
+            else:
+                output += out
+
+            completion_choice = CompletionChoice(
+                text=output, index=0, logprobs=None, finish_reason=None
+            )
+            completion_chunk = CompletionChunk(
+                id=chunk_id,
+                object="text_completion",
+                created=int(time.time()),
+                model=model_uid,
+                choices=[completion_choice],
+            )
+            completion_usage = CompletionUsage(
+                prompt_tokens=input_echo_len,
+                completion_tokens=i,
+                total_tokens=(input_echo_len + i),
+            )
+
+            yield completion_chunk, completion_usage
+
+        logger.info(
+            f"Average generation speed: {i / (time.time() - start):.2f} tokens/s."
+        )
+
+        if i == max_tokens - 1:
+            finish_reason = "length"
+        else:
+            finish_reason = "stop"
+
+        if stream:
+            completion_choice = CompletionChoice(
+                text="", index=0, logprobs=None, finish_reason=finish_reason
+            )
+        else:
+            completion_choice = CompletionChoice(
+                text=output, index=0, logprobs=None, finish_reason=finish_reason
+            )
+
+        completion_chunk = CompletionChunk(
+            id=chunk_id,
+            object="text_completion",
+            created=int(time.time()),
+            model=model_uid,
+            choices=[completion_choice],
+        )
+        completion_usage = CompletionUsage(
+            prompt_tokens=input_echo_len,
+            completion_tokens=i,
+            total_tokens=(input_echo_len + i),
+        )
+
+        yield completion_chunk, completion_usage
+
+        if include_usage:
+            completion_chunk = CompletionChunk(
+                id=chunk_id,
+                object="text_completion",
+                created=int(time.time()),
+                model=model_uid,
+                choices=[],
+            )
+            completion_usage = CompletionUsage(
+                prompt_tokens=input_echo_len,
+                completion_tokens=i,
+                total_tokens=(input_echo_len + i),
+            )
+            yield completion_chunk, completion_usage
+
+    def generate(
+        self, prompt: str, generate_config: Optional[MLXGenerateConfig] = None
+    ) -> Union[Completion, Iterator[CompletionChunk]]:
+        def generator_wrapper(
+            prompt: str, generate_config: MLXGenerateConfig
+        ) -> Iterator[CompletionChunk]:
+            for completion_chunk, completion_usage in self._generate_stream(
+                prompt,
+                generate_config,
+            ):
+                completion_chunk["usage"] = completion_usage
+                yield completion_chunk
+
+        logger.debug(
+            "Enter generate, prompt: %s, generate config: %s", prompt, generate_config
+        )
+
+        generate_config = self._sanitize_generate_config(generate_config)
+
+        assert self._model is not None
+        assert self._tokenizer is not None
+
+        stream = generate_config.get("stream", False)
+        if not stream:
+            for completion_chunk, completion_usage in self._generate_stream(
+                prompt,
+                generate_config,
+            ):
+                pass
+            completion = Completion(
+                id=completion_chunk["id"],
+                object=completion_chunk["object"],
+                created=completion_chunk["created"],
+                model=completion_chunk["model"],
+                choices=completion_chunk["choices"],
+                usage=completion_usage,
+            )
+            return completion
+        else:
+            return generator_wrapper(prompt, generate_config)
+
+
+class MLXChatModel(MLXModel, ChatModelMixin):
+    def __init__(
+        self,
+        model_uid: str,
+        model_family: "LLMFamilyV1",
+        model_spec: "LLMSpecV1",
+        quantization: str,
+        model_path: str,
+        model_config: Optional[MLXModelConfig] = None,
+        peft_model: Optional[List[LoRA]] = None,
+    ):
+        super().__init__(
+            model_uid,
+            model_family,
+            model_spec,
+            quantization,
+            model_path,
+            model_config,
+            peft_model,
+        )
+
+    def _sanitize_generate_config(
+        self,
+        generate_config: Optional[MLXGenerateConfig],
+    ) -> MLXGenerateConfig:
+        generate_config = super()._sanitize_generate_config(generate_config)
+        if (
+            (not generate_config.get("stop"))
+            and self.model_family.prompt_style
+            and self.model_family.prompt_style.stop
+        ):
+            generate_config["stop"] = self.model_family.prompt_style.stop.copy()
+        if (
+            generate_config.get("stop_token_ids", None) is None
+            and self.model_family.prompt_style
+            and self.model_family.prompt_style.stop_token_ids
+        ):
+            generate_config[
+                "stop_token_ids"
+            ] = self.model_family.prompt_style.stop_token_ids.copy()
+
+        return generate_config
+
+    @classmethod
+    def match(
+        cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
+    ) -> bool:
+        if llm_spec.model_format not in ["mlx"]:
+            return False
+        if sys.platform != "darwin" or platform.processor() != "arm":
+            # only work for Mac M chips
+            return False
+        if "chat" not in llm_family.model_ability:
+            return False
+        return True
+
+    def chat(
+        self,
+        prompt: str,
+        system_prompt: Optional[str] = None,
+        chat_history: Optional[List[ChatCompletionMessage]] = None,
+        generate_config: Optional[MLXGenerateConfig] = None,
+    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
+        tools = generate_config.pop("tools", []) if generate_config else None  # type: ignore
+        full_prompt = self.get_full_prompt(
+            self.model_family, prompt, system_prompt, chat_history, tools
+        )
+
+        generate_config = self._sanitize_generate_config(generate_config)
+        # TODO(codingl2k1): qwen hacky to set stop for function call.
+        model_family = self.model_family.model_family or self.model_family.model_name
+        if tools and model_family in ["qwen-chat", "qwen1.5-chat"]:
+            stop = generate_config.get("stop")
+            if isinstance(stop, str):
+                generate_config["stop"] = [stop, "Observation:"]
+            elif isinstance(stop, Iterable):
+                assert not isinstance(stop, str)
+                generate_config["stop"] = list(stop) + ["Observation:"]
+            else:
+                generate_config["stop"] = "Observation:"
+
+        stream = generate_config.get("stream", False)
+        if stream:
+            it = self.generate(full_prompt, generate_config)
+            assert isinstance(it, Iterator)
+            return self._to_chat_completion_chunks(it)
+        else:
+            c = self.generate(full_prompt, generate_config)
+            assert not isinstance(c, Iterator)
+            if tools:
+                return self._tool_calls_completion(
+                    self.model_family, self.model_uid, c, tools
+                )
+            return self._to_chat_completion(c)
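For orientation (not part of the diff): MLXModel._generate_stream above is essentially a thin wrapper around mlx_lm's load / generate_step loop. The standalone sketch below mirrors those calls outside xinference; the model path and prompt are hypothetical, and the generate_step keyword arguments follow the pinned mlx-examples commit referenced in the code above, so newer mlx_lm releases may expose a different signature.

# Minimal sketch of the mlx_lm loop that MLXModel._generate_stream wraps.
# Requires macOS on Apple silicon with `pip install mlx_lm`; the model path is a
# hypothetical local MLX-converted checkpoint.
import mlx.core as mx
from mlx_lm import load
from mlx_lm.utils import generate_step

model, tokenizer = load("/path/to/mlx-converted-model")
prompt_tokens = mx.array(tokenizer.encode("Hello, my name is"))

output = ""
for (token, _), _ in zip(
    generate_step(prompt_tokens, model, temp=0.0, top_p=1.0),  # greedy decoding
    range(64),  # max_tokens budget, as in the wrapper above
):
    if token == tokenizer.eos_token_id:
        break
    output += tokenizer.decode(token, skip_special_tokens=True)
print(output)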
xinference/model/llm/pytorch/chatglm.py
@@ -29,6 +29,7 @@ from ....types import (
     PytorchGenerateConfig,
 )
 from ..llm_family import LLMFamilyV1, LLMSpecV1
+from ..utils import GLM4_TOOL_CALL_FAMILY
 from .core import PytorchChatModel, PytorchModelConfig
 
 
@@ -103,7 +104,7 @@ class ChatglmPytorchChatModel(PytorchChatModel):
         if tools is None:
             return False
         tool_choice = generate_config.pop("tool_choice", "none")
-        if self.model_family.model_name == "glm4-chat":
+        if self.model_family.model_name in GLM4_TOOL_CALL_FAMILY:
            chat_history[:] = self.process_messages(
                chat_history, tools=tools, tool_choice=tool_choice
            )
@@ -335,14 +336,6 @@ class ChatglmPytorchChatModel(PytorchChatModel):
            ),
        )
 
-    @staticmethod
-    def require_attention_mask():
-        """
-        GLM4 needs to use attention mask and position ids during inference.
-        Otherwise, the inference result would be not available.
-        """
-        return True
-
     def prepare_sanitize_generate_config(self, req: InferenceRequest):
         """
         Set temperature and top_p to 0.8 by default
xinference/model/llm/pytorch/cogvlm2.py
@@ -23,6 +23,7 @@ import requests
 import torch
 from PIL import Image
 
+from ....core.scheduler import InferenceRequest
 from ....model.utils import select_device
 from ....types import (
     ChatCompletion,
@@ -35,11 +36,30 @@ from ....types import (
 )
 from ..llm_family import LLMFamilyV1, LLMSpecV1
 from .core import PytorchChatModel, PytorchGenerateConfig
+from .utils import get_max_src_len
 
 logger = logging.getLogger(__name__)
 
-IMAGENET_MEAN = (0.485, 0.456, 0.406)
-IMAGENET_STD = (0.229, 0.224, 0.225)
+
+LANGUAGE_TOKEN_TYPE = 0
+VISION_TOKEN_TYPE = 1
+
+
+def recur_move_to(item, tgt, criterion_func):
+    """
+    This function is copied from https://github.com/THUDM/CogVLM2/blob/main/basic_demo/cli_demo_batch_inference.py
+    """
+    if criterion_func(item):
+        device_copy = item.to(tgt)
+        return device_copy
+    elif isinstance(item, list):
+        return [recur_move_to(v, tgt, criterion_func) for v in item]
+    elif isinstance(item, tuple):
+        return tuple([recur_move_to(v, tgt, criterion_func) for v in item])
+    elif isinstance(item, dict):
+        return {k: recur_move_to(v, tgt, criterion_func) for k, v in item.items()}
+    else:
+        return item
 
 
 class CogVLM2Model(PytorchChatModel):
@@ -176,11 +196,33 @@ class CogVLM2Model(PytorchChatModel):
                            content["image_url"]["url"]
                        )
            assistant = chat_history[i + 1]["content"]
-            query = query + f" USER: {user} ASSISTANT:"
-            history.append((query, assistant))
-            query = query + f" {assistant}"
+            history.append((user, assistant))
+            query = assistant  # type: ignore
        return query, history, [pixel_values]
 
+    def get_query_and_history(
+        self,
+        prompt: Union[str, List[Dict]],
+        system_prompt: Optional[str] = None,
+        chat_history: Optional[List[ChatCompletionMessage]] = None,
+    ):
+        content, image = self._message_content_to_cogvlm2(prompt)
+
+        history = []
+        history_image = None
+        if chat_history:
+            query, history, history_image = self._history_content_to_cogvlm2(
+                system_prompt, chat_history  # type: ignore
+            )
+
+        if image and history_image:
+            history = []
+            query = content
+        else:
+            image = image if image else history_image
+            query = content
+        return query, image, history
+
     def chat(
         self,
         prompt: Union[str, List[Dict]],
@@ -198,22 +240,9 @@ class CogVLM2Model(PytorchChatModel):
            else 512,
        }
 
-        content, image = self._message_content_to_cogvlm2(prompt)
-
-        history = []
-        query = ""
-        history_image = None
-        if chat_history:
-            query, history, history_image = self._history_content_to_cogvlm2(
-                system_prompt, chat_history
-            )
-
-        if image and history_image:
-            history = []
-            query = system_prompt + f" USER: {content} ASSISTANT:"
-        else:
-            image = image if image else history_image
-            query = query + f" USER: {content} ASSISTANT:"
+        query, image, history = self.get_query_and_history(
+            prompt, system_prompt=system_prompt, chat_history=chat_history
+        )
 
        input_by_model = self._model.build_conversation_input_ids(
            self._tokenizer,
@@ -319,3 +348,159 @@ class CogVLM2Model(PytorchChatModel):
            ),
        )
        yield chunk
+
+    @staticmethod
+    def build_position_ids(x, attention_mask=None):
+        """
+        Copied from https://huggingface.co/THUDM/cogvlm2-llama3-chinese-chat-19B-int4/blob/main/modeling_cogvlm.py
+        """
+        # Fix: follow the official open-source implementation
+        if attention_mask is not None:
+            tmp = x.clone()
+            tmp[~(attention_mask.bool())] = -1
+        else:
+            tmp = x.clone()
+        # image boi eoi token as LANGUAGE_TOKEN_TYPE
+        is_boi_eoi = torch.zeros_like(x, dtype=torch.bool)
+        is_boi_eoi[:, 1:] |= (tmp[:, 1:] == VISION_TOKEN_TYPE) & (
+            tmp[:, :-1] == LANGUAGE_TOKEN_TYPE
+        )
+        is_boi_eoi[:, 0] |= tmp[:, 0] == VISION_TOKEN_TYPE
+        is_boi_eoi[:, :-1] |= (tmp[:, :-1] == VISION_TOKEN_TYPE) & (
+            tmp[:, 1:] == LANGUAGE_TOKEN_TYPE
+        )
+        is_boi_eoi[:, -1] |= tmp[:, -1] == VISION_TOKEN_TYPE
+        tmp[is_boi_eoi] = LANGUAGE_TOKEN_TYPE
+        # final position ids
+        y = torch.zeros_like(x, dtype=torch.long)
+        y[:, 1:] = (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE) | (
+            (tmp[:, 1:] == VISION_TOKEN_TYPE) & (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE)
+        )
+        y = y.cumsum(dim=-1)
+        return y
+
+    def get_dtype(self):
+        return self._torch_type
+
+    def _get_full_prompt(self, prompt, system_prompt, chat_history, tools):
+        query, image, history = self.get_query_and_history(
+            prompt, system_prompt=system_prompt, chat_history=chat_history
+        )
+
+        input_by_model: dict = self._model.build_conversation_input_ids(
+            self._tokenizer,
+            query=query,
+            history=history,
+            images=image,
+            template_version="chat",
+        )
+        return {
+            "input_ids": input_by_model["input_ids"],  # seq_len
+            "token_type_ids": input_by_model["token_type_ids"],  # seq_len
+            "attention_mask": input_by_model["attention_mask"],  # seq_len
+            "images": input_by_model["images"],
+        }
+
+    def prepare_sanitize_generate_config(self, req: InferenceRequest):
+        """
+        See https://huggingface.co/THUDM/cogvlm2-llama3-chat-19B/blob/main/generation_config.json
+        """
+        raw_config = req.inference_kwargs.get("raw_params", {})
+        temperature = raw_config.get("temperature", None)
+        if temperature is None:
+            raw_config["temperature"] = 0.6
+        top_p = raw_config.get("top_p", None)
+        if top_p is None:
+            raw_config["top_p"] = 0.9
+        return raw_config
+
+    def build_prefill_kwargs(self, prompts: List, req_list: List[InferenceRequest]):
+        context_len = self.get_context_len()
+        assert isinstance(prompts[0], dict)
+        images = []
+        max_length = float("-inf")
+        for i, feature in enumerate(prompts):
+            req = req_list[i]
+            if "images" in feature:
+                images.append(feature.pop("images", None))
+            max_src_len = get_max_src_len(context_len, req)
+            input_ids = feature["input_ids"][-max_src_len:]
+            req.prompt_tokens = input_ids.tolist()
+            feature["input_ids"] = input_ids
+            feature["token_type_ids"] = feature["token_type_ids"][-max_src_len:]
+            feature["attention_mask"] = feature["attention_mask"][-max_src_len:]
+            req.extra_kwargs["attention_mask_seq_len"] = feature[
+                "attention_mask"
+            ].shape[0]
+            max_length = max(len(input_ids), max_length)
+
+        def pad_to_max_length_internal(feature, max_len, idx):
+            padding_length = max_len - len(feature["input_ids"])
+            req_list[idx].padding_len = padding_length
+            feature["input_ids"] = torch.cat(
+                [torch.full((padding_length,), 0), feature["input_ids"]]
+            )
+            feature["token_type_ids"] = torch.cat(
+                [
+                    torch.zeros(padding_length, dtype=torch.long),
+                    feature["token_type_ids"],
+                ]
+            )
+            feature["attention_mask"] = torch.cat(
+                [
+                    torch.zeros(padding_length, dtype=torch.long),
+                    feature["attention_mask"],
+                ]
+            )
+            return feature
+
+        features = [
+            pad_to_max_length_internal(feature, max_length, i)
+            for i, feature in enumerate(prompts)
+        ]
+        batch = {
+            key: torch.stack([feature[key] for feature in features])
+            for key in features[0].keys()
+        }
+
+        position_ids = self.build_position_ids(batch["token_type_ids"])
+        batch["position_ids"] = position_ids
+
+        for i in range(len(prompts)):
+            req = req_list[i]
+            req.extra_kwargs["max_position_id"] = position_ids[i : i + 1, -1].item()
+
+        if images:
+            batch["images"] = images
+
+        batch = recur_move_to(
+            batch, self._device, lambda x: isinstance(x, torch.Tensor)
+        )
+        dtype = self.get_dtype()
+        if dtype:
+            batch = recur_move_to(
+                batch,
+                dtype,
+                lambda x: isinstance(x, torch.Tensor) and torch.is_floating_point(x),
+            )
+        return batch
+
+    def build_decode_token_type_ids(
+        self, batch_size: int, seq_length: int, reqs: List[InferenceRequest]
+    ):
+        token_type_ids = torch.full(
+            (batch_size, 1), fill_value=1, dtype=torch.long, device=self._device
+        )
+        return token_type_ids
+
+    def build_decode_position_ids(
+        self, batch_size: int, seq_length: int, reqs: List[InferenceRequest]
+    ):
+        tmp = []
+        for r in reqs:
+            r.extra_kwargs["max_position_id"] += 1
+            tmp.append(r.extra_kwargs["max_position_id"])
+        position_ids = torch.as_tensor(
+            tmp, device=self._device, dtype=torch.long
+        ).unsqueeze(1)
+        return position_ids
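A small illustration (not part of the diff) of what the new build_position_ids helper computes: vision patch tokens belonging to one image collapse onto a single position id, while the boi/eoi markers and language tokens keep counting up. Assuming xinference 0.13.0 is installed so the module path matches the file above:

# Toy check of CogVLM2Model.build_position_ids on a fake token_type_ids row:
# one language token, then boi + 3 vision patches + eoi, then two language tokens.
import torch
from xinference.model.llm.pytorch.cogvlm2 import CogVLM2Model

token_type_ids = torch.tensor([[0, 1, 1, 1, 1, 1, 0, 0]])
print(CogVLM2Model.build_position_ids(token_type_ids))
# tensor([[0, 1, 2, 2, 2, 3, 4, 5]]) -- the three inner vision patches share one position id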