xinference-1.6.0-py3-none-any.whl → xinference-1.6.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of xinference might be problematic.
Files changed (87)
  1. xinference/_version.py +3 -3
  2. xinference/client/restful/restful_client.py +1 -1
  3. xinference/conftest.py +0 -7
  4. xinference/core/media_interface.py +9 -8
  5. xinference/core/model.py +13 -6
  6. xinference/core/scheduler.py +1 -10
  7. xinference/core/worker.py +0 -10
  8. xinference/model/audio/model_spec.json +53 -1
  9. xinference/model/audio/model_spec_modelscope.json +57 -1
  10. xinference/model/embedding/core.py +19 -11
  11. xinference/model/image/model_spec.json +10 -1
  12. xinference/model/image/model_spec_modelscope.json +20 -0
  13. xinference/model/llm/__init__.py +6 -54
  14. xinference/model/llm/core.py +19 -5
  15. xinference/model/llm/llama_cpp/core.py +59 -3
  16. xinference/model/llm/llama_cpp/memory.py +455 -0
  17. xinference/model/llm/llm_family.json +185 -397
  18. xinference/model/llm/llm_family.py +88 -16
  19. xinference/model/llm/llm_family_modelscope.json +199 -421
  20. xinference/model/llm/llm_family_openmind_hub.json +0 -34
  21. xinference/model/llm/sglang/core.py +4 -0
  22. xinference/model/llm/transformers/__init__.py +27 -6
  23. xinference/model/llm/transformers/chatglm.py +4 -2
  24. xinference/model/llm/transformers/core.py +49 -28
  25. xinference/model/llm/transformers/deepseek_v2.py +6 -49
  26. xinference/model/llm/transformers/gemma3.py +119 -164
  27. xinference/{thirdparty/omnilmm/train → model/llm/transformers/multimodal}/__init__.py +1 -1
  28. xinference/model/llm/transformers/{cogagent.py → multimodal/cogagent.py} +58 -95
  29. xinference/model/llm/transformers/multimodal/core.py +205 -0
  30. xinference/model/llm/transformers/{deepseek_vl2.py → multimodal/deepseek_vl2.py} +59 -120
  31. xinference/model/llm/transformers/multimodal/gemma3.py +117 -0
  32. xinference/model/llm/transformers/{glm4v.py → multimodal/glm4v.py} +57 -93
  33. xinference/model/llm/transformers/multimodal/intern_vl.py +412 -0
  34. xinference/model/llm/transformers/{minicpmv26.py → multimodal/minicpmv26.py} +55 -102
  35. xinference/model/llm/transformers/{ovis2.py → multimodal/ovis2.py} +114 -175
  36. xinference/model/llm/transformers/{qwen-omni.py → multimodal/qwen-omni.py} +82 -167
  37. xinference/model/llm/transformers/multimodal/qwen2_audio.py +131 -0
  38. xinference/model/llm/transformers/{qwen2_vl.py → multimodal/qwen2_vl.py} +224 -256
  39. xinference/model/llm/transformers/opt.py +4 -2
  40. xinference/model/llm/transformers/utils.py +6 -37
  41. xinference/model/llm/vllm/core.py +4 -0
  42. xinference/model/rerank/core.py +7 -1
  43. xinference/model/rerank/utils.py +17 -0
  44. xinference/web/ui/build/asset-manifest.json +3 -3
  45. xinference/web/ui/build/index.html +1 -1
  46. xinference/web/ui/build/static/js/main.ddf9eaee.js +3 -0
  47. xinference/web/ui/build/static/js/main.ddf9eaee.js.map +1 -0
  48. xinference/web/ui/node_modules/.cache/babel-loader/12e637ed5fa9ca6491b03892b6949c03afd4960fe36ac25744488e7e1982aa19.json +1 -0
  49. xinference/web/ui/node_modules/.cache/babel-loader/567e49df411efb24425d289bb484758cb57067ca54f8b5c67fe4505f698deb96.json +1 -0
  50. xinference/web/ui/node_modules/.cache/babel-loader/77ac2665a784e99501ae95d32ef5937837a0439a47e965d291b38e99cb619f5b.json +1 -0
  51. xinference/web/ui/node_modules/.cache/babel-loader/d4ed4e82bfe69915999ec83f5feaa4301c75ecc6bdf1c78f2d03e4671ecbefc8.json +1 -0
  52. xinference/web/ui/src/locales/en.json +3 -1
  53. xinference/web/ui/src/locales/zh.json +3 -1
  54. {xinference-1.6.0.dist-info → xinference-1.6.1.dist-info}/METADATA +16 -14
  55. {xinference-1.6.0.dist-info → xinference-1.6.1.dist-info}/RECORD +60 -76
  56. {xinference-1.6.0.dist-info → xinference-1.6.1.dist-info}/WHEEL +1 -1
  57. xinference/model/llm/transformers/cogvlm2.py +0 -442
  58. xinference/model/llm/transformers/cogvlm2_video.py +0 -333
  59. xinference/model/llm/transformers/deepseek_vl.py +0 -280
  60. xinference/model/llm/transformers/glm_edge_v.py +0 -213
  61. xinference/model/llm/transformers/intern_vl.py +0 -526
  62. xinference/model/llm/transformers/internlm2.py +0 -94
  63. xinference/model/llm/transformers/minicpmv25.py +0 -193
  64. xinference/model/llm/transformers/omnilmm.py +0 -132
  65. xinference/model/llm/transformers/qwen2_audio.py +0 -179
  66. xinference/model/llm/transformers/qwen_vl.py +0 -360
  67. xinference/thirdparty/omnilmm/LICENSE +0 -201
  68. xinference/thirdparty/omnilmm/__init__.py +0 -0
  69. xinference/thirdparty/omnilmm/chat.py +0 -218
  70. xinference/thirdparty/omnilmm/constants.py +0 -4
  71. xinference/thirdparty/omnilmm/conversation.py +0 -332
  72. xinference/thirdparty/omnilmm/model/__init__.py +0 -1
  73. xinference/thirdparty/omnilmm/model/omnilmm.py +0 -595
  74. xinference/thirdparty/omnilmm/model/resampler.py +0 -166
  75. xinference/thirdparty/omnilmm/model/utils.py +0 -578
  76. xinference/thirdparty/omnilmm/train/train_utils.py +0 -150
  77. xinference/thirdparty/omnilmm/utils.py +0 -134
  78. xinference/web/ui/build/static/js/main.ae579a97.js +0 -3
  79. xinference/web/ui/build/static/js/main.ae579a97.js.map +0 -1
  80. xinference/web/ui/node_modules/.cache/babel-loader/2fdc61dcb6a9d1fbcb44be592d0e87d8c3f21297a7327559ef5345665f8343f7.json +0 -1
  81. xinference/web/ui/node_modules/.cache/babel-loader/3d596a3e8dd6430d7ce81d164e32c31f8d47cfa5f725c328a298754d78563e14.json +0 -1
  82. xinference/web/ui/node_modules/.cache/babel-loader/5c08e2cd07809ed3e41486b16652253404cbb63a3ff8d0366ee50f57e2413cea.json +0 -1
  83. xinference/web/ui/node_modules/.cache/babel-loader/8472e58a31720892d534f3febda31f746b25ec4aa60787eef34217b074e67965.json +0 -1
  84. /xinference/web/ui/build/static/js/{main.ae579a97.js.LICENSE.txt → main.ddf9eaee.js.LICENSE.txt} +0 -0
  85. {xinference-1.6.0.dist-info → xinference-1.6.1.dist-info}/entry_points.txt +0 -0
  86. {xinference-1.6.0.dist-info → xinference-1.6.1.dist-info}/licenses/LICENSE +0 -0
  87. {xinference-1.6.0.dist-info → xinference-1.6.1.dist-info}/top_level.txt +0 -0
xinference/model/llm/transformers/{cogagent.py → multimodal/cogagent.py}

@@ -1,4 +1,4 @@
- # Copyright 2022-2023 XProbe Inc.
+ # Copyright 2022-2025 XProbe Inc.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -13,47 +13,36 @@
  # limitations under the License.
  import logging
  import re
- import uuid
  from concurrent.futures import ThreadPoolExecutor
- from typing import Dict, Iterator, List, Literal, Optional, Union
+ from threading import Thread
+ from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union

  import torch

- from ....model.utils import select_device
- from ....types import (
-     ChatCompletion,
-     ChatCompletionChunk,
-     CogagentGenerateConfig,
-     CompletionChunk,
- )
- from ..llm_family import LLMFamilyV1, LLMSpecV1
- from ..utils import (
-     _decode_image,
-     generate_chat_completion,
-     generate_completion_chunk,
-     parse_messages,
- )
- from .core import PytorchChatModel
- from .utils import cache_clean
+ from .....model.utils import select_device
+ from ...llm_family import LLMFamilyV1, LLMSpecV1, register_transformer
+ from ...utils import _decode_image, parse_messages
+ from ..core import register_non_default_model
+ from .core import PytorchMultiModalModel

  logger = logging.getLogger(__name__)


- class CogAgentChatModel(PytorchChatModel):
-     def __init__(self, *args, **kwargs):
-         super().__init__(*args, **kwargs)
-         self._torch_type = None
-         self._device = None
-         self._tokenizer = None
-         self._model = None
-         self._platform: Literal["Mac", "WIN", "Mobile"] | None = "Mac"  # type: ignore
-         self._format: Literal[  # type: ignore
-             "(Answer in Action-Operation-Sensitive format.)",
-             "(Answer in Status-Plan-Action-Operation format.)",
-             "(Answer in Status-Action-Operation-Sensitive format.)",
-             "(Answer in Status-Action-Operation format.)",
-             "(Answer in Action-Operation format.)",
-         ] | None = "(Answer in Action-Operation-Sensitive format.)"
+ @register_transformer
+ @register_non_default_model("cogagent")
+ class CogAgentChatModel(PytorchMultiModalModel):
+     def __init__(self, *args, **kws):
+         super().__init__(*args, **kws)
+         self._platform: Optional[Literal["Mac", "WIN", "Mobile"]] = "Mac"
+         self._format: Optional[
+             Literal[
+                 "(Answer in Action-Operation-Sensitive format.)",
+                 "(Answer in Status-Plan-Action-Operation format.)",
+                 "(Answer in Status-Action-Operation-Sensitive format.)",
+                 "(Answer in Status-Action-Operation format.)",
+                 "(Answer in Action-Operation format.)",
+             ]
+         ] = "(Answer in Action-Operation-Sensitive format.)"

      @classmethod
      def match_json(
@@ -64,17 +53,21 @@ class CogAgentChatModel(PytorchChatModel):
              return True
          return False

-     def load(self):
-         from transformers import AutoModelForCausalLM, AutoTokenizer
-
+     def decide_device(self):
          device = self._pytorch_model_config.get("device", "auto")
          self._device = select_device(device)

+     def load_processor(self):
+         from transformers import AutoTokenizer
+
          self._tokenizer = AutoTokenizer.from_pretrained(
              self.model_path, trust_remote_code=True
          )
-         kwargs = self.apply_bnb_quantization()

+     def load_multimodal_model(self):
+         from transformers import AutoModelForCausalLM
+
+         kwargs = self.apply_bnb_quantization()
          self._model = AutoModelForCausalLM.from_pretrained(
              self.model_path,
              torch_dtype=torch.bfloat16,
@@ -153,7 +146,7 @@ class CogAgentChatModel(PytorchChatModel):

          return history_step, history_action

-     def get_query_and_history(
+     def _get_query_and_history(
          self,
          prompt: Union[str, List[Dict]],
          chat_history: Optional[List[Dict]] = None,
@@ -181,26 +174,14 @@ class CogAgentChatModel(PytorchChatModel):
          logger.info(f"query:{query}")
          return query, image

-     @cache_clean
-     def chat(
+     def build_inputs_from_messages(
          self,
          messages: List[Dict],
-         generate_config: Optional[CogagentGenerateConfig] = None,
-     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
-         if generate_config is not None:
-             self._platform = generate_config.pop("platform", self._platform)
-             self._format = generate_config.pop("format", self._format)
-
-         sanitize_generate_config = self._sanitize_generate_config(generate_config)
-         stream = sanitize_generate_config.get("stream")
-         sanitized_config = {
-             "max_length": sanitize_generate_config.get("max_tokens", 512),
-             "top_k": sanitize_generate_config.get("top_k", 1),
-             "do_sample": True,
-         }
+         generate_config: Dict,
+     ):
          prompt, _, chat_history = parse_messages(messages)

-         query, image = self.get_query_and_history(prompt, chat_history)
+         query, image = self._get_query_and_history(prompt, chat_history)

          full_context_kwargs = {
              "return_tensors": "pt",
@@ -218,53 +199,35 @@ class CogAgentChatModel(PytorchChatModel):
              **full_context_kwargs,
          )
          inputs.to(self._model.device)
+         return inputs

-         if stream:
-             it = self._streaming_chat_response(inputs, sanitized_config)
-             return self._to_chat_completion_chunks(it)
-         else:
-             # Generate response
-             with torch.no_grad():
-                 outputs = self._model.generate(**inputs, **sanitized_config)
-                 outputs = outputs[:, inputs["input_ids"].shape[1] :]
-                 response = self._tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-             return generate_chat_completion(self.model_uid, response)
-
-     def _streaming_chat_response(
-         self, inputs: Dict, config: Dict
-     ) -> Iterator[CompletionChunk]:
-         from threading import Thread
+     def build_generate_kwargs(
+         self,
+         generate_config: Dict,
+     ) -> Dict[str, Any]:
+         generate_config = {} if generate_config is None else generate_config
+         self._platform = generate_config.pop("platform", self._platform)
+         self._format = generate_config.pop("format", self._format)
+         return {
+             "max_length": generate_config.get("max_tokens", 512),
+             "top_k": generate_config.get("top_k", 1),
+             "do_sample": True,
+         }

+     def build_streaming_iter(
+         self,
+         messages: List[Dict],
+         generate_config: Dict,
+     ) -> Tuple[Iterator, int]:
          from transformers import TextIteratorStreamer

+         config = self.build_generate_kwargs(generate_config)
+         inputs = self.build_inputs_from_messages(messages, generate_config)
          streamer = TextIteratorStreamer(
              self._tokenizer, skip_prompt=True, skip_special_tokens=True
          )
-         generation_kwargs = {**inputs, **config}
+         generation_kwargs = {**inputs, **config, "streamer": streamer}

          thread = Thread(target=self._model.generate, kwargs=generation_kwargs)
          thread.start()
-
-         completion_id = str(uuid.uuid1())
-         for new_text in streamer:
-             yield generate_completion_chunk(
-                 chunk_text=new_text,
-                 finish_reason=None,
-                 chunk_id=completion_id,
-                 model_uid=self.model_uid,
-                 prompt_tokens=-1,
-                 completion_tokens=-1,
-                 total_tokens=-1,
-             )
-         yield generate_completion_chunk(
-             chunk_text=None,
-             finish_reason="stop",
-             chunk_id=completion_id,
-             model_uid=self.model_uid,
-             prompt_tokens=-1,
-             completion_tokens=-1,
-             total_tokens=-1,
-             has_choice=True,
-             has_content=False,
-         )
+         return streamer, len(inputs.input_ids[0])
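
The refactor above strips CogAgentChatModel down to the new template hooks: build_streaming_iter wires the tokenized inputs into a TextIteratorStreamer, launches generate() on a background thread, and returns the streamer together with the prompt length, while chunk assembly and usage accounting move to the shared base class introduced below. A minimal standalone sketch of that transformers streaming pattern, with a placeholder model name and prompt that are not part of this diff:

# Illustrative only: the generic transformers streaming pattern that the new
# build_streaming_iter relies on. Model name and prompt are placeholders.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_name = "sshleifer/tiny-gpt2"  # placeholder model; any causal LM works
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

inputs = tokenizer("Describe this screenshot:", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() runs on a background thread; the streamer yields decoded text
# pieces to the caller as tokens are produced.
thread = Thread(
    target=model.generate,
    kwargs={**inputs, "max_new_tokens": 32, "streamer": streamer},
)
thread.start()
for piece in streamer:
    print(piece, end="", flush=True)
thread.join()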
xinference/model/llm/transformers/multimodal/core.py (new file)

@@ -0,0 +1,205 @@
+ # Copyright 2022-2025 XProbe Inc.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import uuid
+ from abc import abstractmethod
+ from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
+
+ from .....types import (
+     ChatCompletion,
+     ChatCompletionChunk,
+     CompletionChunk,
+     PytorchGenerateConfig,
+ )
+ from ...utils import generate_chat_completion, generate_completion_chunk
+ from ..core import PytorchChatModel
+ from ..utils import cache_clean
+
+
+ class PytorchMultiModalModel(PytorchChatModel):
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self._tokenizer = None
+         self._device = None
+         self._processor = None
+         self._model = None
+
+     @abstractmethod
+     def decide_device(self):
+         """
+         Update self._device
+         """
+         pass
+
+     @abstractmethod
+     def load_processor(self):
+         """
+         Load self._processor and self._tokenizer
+         """
+         pass
+
+     @abstractmethod
+     def load_multimodal_model(self):
+         """
+         Load self._model
+         """
+         pass
+
+     def load(self):
+         self.decide_device()
+         self.load_processor()
+         self.load_multimodal_model()
+
+     @abstractmethod
+     def build_inputs_from_messages(
+         self,
+         messages: List[Dict],
+         generate_config: Dict,
+     ):
+         """
+         Convert from input OpenAI-formatted messages to
+         actual parameters needed for inference,
+         e.g. input_ids, attention_masks, etc.
+         """
+         pass
+
+     @abstractmethod
+     def build_generate_kwargs(
+         self,
+         generate_config: Dict,
+     ) -> Dict[str, Any]:
+         """
+         Hyperparameters needed for generation,
+         e.g. temperature, max_new_tokens, etc.
+         """
+         pass
+
+     @abstractmethod
+     def build_streaming_iter(
+         self,
+         messages: List[Dict],
+         generate_config: Dict,
+     ) -> Tuple[Iterator, int]:
+         """
+         Return the iterator needed for streaming inference and the length of prompt token for statisticians.
+         The length of prompt token usually comes from the input_ids.
+         In this interface you need to call the `build_inputs_from_messages` and `build_generate_kwargs`.
+         """
+         pass
+
+     def get_stop_strs(self) -> List[str]:
+         return []
+
+     def check_conditions(self, new_text: str) -> Tuple[str, bool]:
+         stop_strs = self.get_stop_strs()
+         for ss in stop_strs:
+             if new_text.endswith(ss):
+                 new_text = new_text[: -len(ss)]
+                 break
+         return new_text, False
+
+     def generate_non_streaming(
+         self,
+         messages: List[Dict],
+         generate_config: Optional[PytorchGenerateConfig] = None,
+     ) -> ChatCompletion:
+         generate_config = generate_config if generate_config else {}  # type: ignore
+         streamer, prompt_tokens = self.build_streaming_iter(messages, generate_config)  # type: ignore
+         completion_tokens, total_tokens = 0, 0
+         res = ""
+         for i, new_text in enumerate(streamer):
+             new_text, should_stop = self.check_conditions(new_text)
+             if should_stop:
+                 break
+             completion_tokens = i
+             total_tokens = prompt_tokens + completion_tokens
+             res += new_text
+         return generate_chat_completion(
+             self.model_uid,
+             res,
+             prompt_tokens=prompt_tokens,
+             completion_tokens=completion_tokens if prompt_tokens != -1 else -1,
+             total_tokens=total_tokens if prompt_tokens != -1 else -1,
+         )
+
+     def generate_streaming(
+         self,
+         messages: List[Dict],
+         generate_config: Optional[PytorchGenerateConfig] = None,
+     ) -> Iterator[CompletionChunk]:
+         generate_config = generate_config if generate_config else {}  # type: ignore
+         streamer, prompt_tokens = self.build_streaming_iter(messages, generate_config)  # type: ignore
+         stream_options = generate_config.pop("stream_options", None)
+         include_usage = (
+             stream_options["include_usage"]
+             if isinstance(stream_options, dict)
+             else False
+         )
+
+         completion_id = str(uuid.uuid1())
+         completion_tokens, total_tokens = 0, 0
+         for i, new_text in enumerate(streamer):
+             new_text, should_stop = self.check_conditions(new_text)
+             if should_stop:
+                 break
+             completion_tokens = i
+             total_tokens = prompt_tokens + completion_tokens
+             yield generate_completion_chunk(
+                 chunk_text=new_text,
+                 finish_reason=None,
+                 chunk_id=completion_id,
+                 model_uid=self.model_uid,
+                 prompt_tokens=prompt_tokens,
+                 completion_tokens=completion_tokens if prompt_tokens != -1 else -1,
+                 total_tokens=total_tokens if prompt_tokens != -1 else -1,
+                 has_choice=True,
+                 has_content=True,
+             )
+         yield generate_completion_chunk(
+             chunk_text=None,
+             finish_reason="stop",
+             chunk_id=completion_id,
+             model_uid=self.model_uid,
+             prompt_tokens=prompt_tokens,
+             completion_tokens=completion_tokens if prompt_tokens != -1 else -1,
+             total_tokens=total_tokens if prompt_tokens != -1 else -1,
+             has_choice=True,
+             has_content=False,
+         )
+         if include_usage:
+             yield generate_completion_chunk(
+                 chunk_text=None,
+                 finish_reason=None,
+                 chunk_id=completion_id,
+                 model_uid=self.model_uid,
+                 prompt_tokens=prompt_tokens,
+                 completion_tokens=completion_tokens if prompt_tokens != -1 else -1,
+                 total_tokens=total_tokens if prompt_tokens != -1 else -1,
+                 has_choice=False,
+                 has_content=False,
+             )
+
+     @cache_clean
+     def chat(
+         self,
+         messages: List[Dict],
+         generate_config: Optional[PytorchGenerateConfig] = None,
+     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
+         stream = generate_config.get("stream", False) if generate_config else False
+         return (
+             self._to_chat_completion_chunks(
+                 self.generate_streaming(messages, generate_config)
+             )
+             if stream
+             else self.generate_non_streaming(messages, generate_config)
+         )
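
The new PytorchMultiModalModel base class turns load() and chat() into a template method: subclasses implement only device selection, processor/model loading, and the three build_* hooks, while streaming, stop-string trimming, and usage chunks are handled once here. A hedged sketch of what a concrete subclass might look like; the class name, processor calls, and defaults below are assumptions for illustration, not code shipped in this release:

# Illustrative only: a hypothetical subclass showing how the PytorchMultiModalModel
# hooks are meant to be filled in.
from threading import Thread
from typing import Any, Dict, Iterator, List, Tuple

from xinference.model.llm.transformers.multimodal.core import PytorchMultiModalModel


class MyVLChatModel(PytorchMultiModalModel):  # hypothetical subclass
    def decide_device(self):
        # the real models go through select_device(); "cuda" is a placeholder
        self._device = "cuda"

    def load_processor(self):
        from transformers import AutoProcessor, AutoTokenizer

        self._processor = AutoProcessor.from_pretrained(self.model_path)
        self._tokenizer = AutoTokenizer.from_pretrained(self.model_path)

    def load_multimodal_model(self):
        from transformers import AutoModelForCausalLM

        self._model = AutoModelForCausalLM.from_pretrained(self.model_path).to(self._device)

    def build_inputs_from_messages(self, messages: List[Dict], generate_config: Dict):
        # turn OpenAI-style messages into model tensors (input_ids, attention_mask, ...)
        text = self._processor.apply_chat_template(messages, add_generation_prompt=True)
        return self._processor(text=text, return_tensors="pt").to(self._device)

    def build_generate_kwargs(self, generate_config: Dict) -> Dict[str, Any]:
        return {"max_new_tokens": generate_config.get("max_tokens", 512)}

    def build_streaming_iter(
        self, messages: List[Dict], generate_config: Dict
    ) -> Tuple[Iterator, int]:
        from transformers import TextIteratorStreamer

        inputs = self.build_inputs_from_messages(messages, generate_config)
        kwargs = self.build_generate_kwargs(generate_config)
        streamer = TextIteratorStreamer(
            self._tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        Thread(
            target=self._model.generate,
            kwargs={**inputs, **kwargs, "streamer": streamer},
        ).start()
        return streamer, len(inputs["input_ids"][0])

With these hooks in place, the inherited chat() dispatches to generate_streaming or generate_non_streaming depending on the stream flag in generate_config, applying check_conditions and get_stop_strs to every streamed piece.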
xinference/model/llm/transformers/{deepseek_vl2.py → multimodal/deepseek_vl2.py}

@@ -1,4 +1,4 @@
- # Copyright 2022-2023 XProbe Inc.
+ # Copyright 2022-2025 XProbe Inc.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -13,32 +13,28 @@
  # limitations under the License.
  import base64
  import logging
- import os.path
+ import os
  import tempfile
- import uuid
  from concurrent.futures import ThreadPoolExecutor
  from io import BytesIO
- from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
+ from typing import Any, Dict, Iterator, List, Tuple

  import requests
  import torch

- from ....model.utils import select_device
- from ....types import ChatCompletion, ChatCompletionChunk, CompletionChunk
- from ..llm_family import LLMFamilyV1, LLMSpecV1
- from ..utils import generate_chat_completion, generate_completion_chunk
- from .core import PytorchChatModel, PytorchGenerateConfig
- from .utils import cache_clean
+ from .....model.utils import select_device
+ from ...llm_family import LLMFamilyV1, LLMSpecV1, register_transformer
+ from ..core import register_non_default_model
+ from .core import PytorchMultiModalModel

  logger = logging.getLogger(__name__)


- class DeepSeekVL2ChatModel(PytorchChatModel):
+ @register_transformer
+ @register_non_default_model("deepseek-vl2")
+ class DeepSeekVL2ChatModel(PytorchMultiModalModel):
      def __init__(self, *args, **kwargs):
          super().__init__(*args, **kwargs)
-         self._tokenizer = None
-         self._model = None
-         self._vl_chat_processor = None
          self._type = None

      @classmethod
@@ -50,25 +46,26 @@ class DeepSeekVL2ChatModel(PytorchChatModel):
              return True
          return False

-     def load(self):
-         from transformers import AutoModelForCausalLM
-
-         from ....thirdparty.deepseek_vl2.models import (
-             DeepseekVLV2ForCausalLM,
-             DeepseekVLV2Processor,
-         )
-
+     def decide_device(self):
          self._device = self._pytorch_model_config.get("device", "auto")
          self._device = select_device(self._device)
          self._type = torch.bfloat16
-         kwargs = self.apply_bnb_quantization()
+
+     def load_processor(self):
+         from .....thirdparty.deepseek_vl2.models import DeepseekVLV2Processor

          # specify the path to the model
-         self._vl_chat_processor: DeepseekVLV2Processor = DeepseekVLV2Processor.from_pretrained(  # type: ignore
+         self._processor: DeepseekVLV2Processor = DeepseekVLV2Processor.from_pretrained(  # type: ignore
              self.model_path
          )
-         self._tokenizer = self._vl_chat_processor.tokenizer
+         self._tokenizer = self._processor.tokenizer
+
+     def load_multimodal_model(self):
+         from transformers import AutoModelForCausalLM
+
+         from .....thirdparty.deepseek_vl2.models import DeepseekVLV2ForCausalLM

+         kwargs = self.apply_bnb_quantization()
          vl_gpt: DeepseekVLV2ForCausalLM = AutoModelForCausalLM.from_pretrained(  # type: ignore
              self.model_path,
              trust_remote_code=True,
@@ -138,29 +135,24 @@ class DeepSeekVL2ChatModel(PytorchChatModel):
                  elif c_type == "text":
                      new_content.append(c["text"])
              if images:
-                 new_content.insert(0, "<image_placeholder>")
                  images = _download(images)
              return "".join(new_content), images
          return content, []

-     @cache_clean
-     def chat(
-         self,
-         messages: List[Dict],
-         generate_config: Optional[PytorchGenerateConfig] = None,
-     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
-         if not generate_config:
-             generate_config = {}
+     def get_stop_strs(self) -> List[str]:
+         conversation = self._processor.new_chat_template()
+         stop_str = conversation.sep2
+         return [stop_str]

-         stream = generate_config.get("stream", False)
-         stream_options = generate_config.pop("stream_options", None)
-         include_usage = (
-             stream_options["include_usage"]
-             if isinstance(stream_options, dict)
-             else False
-         )
+     def build_generate_kwargs(self, generate_config: Dict):
+         max_new_tokens = generate_config.get("max_tokens", 512)
+         return {"max_new_tokens": max_new_tokens}

-         prompt = ""
+     def build_inputs_from_messages(
+         self,
+         messages: List[Dict],
+         generate_config: Dict,
+     ):
          deepseek_messages = []
          for i, message in enumerate(messages):
              role = message["role"]
@@ -183,8 +175,6 @@ class DeepSeekVL2ChatModel(PytorchChatModel):
                  msg["images"] = images
                  deepseek_messages.append(msg)
                  deepseek_messages.append({"role": "<|Assistant|>", "content": ""})
-                 if i == len(messages) - 1:
-                     prompt = "<image>\n<|ref|>" + content + "<|/ref|>"
              elif role == "assistant":
                  deepseek_messages.append({"role": "<|Assistant|>", "content": content})
              else:
@@ -192,11 +182,11 @@ class DeepSeekVL2ChatModel(PytorchChatModel):
                      f"Unexpected message in messages: role: {role}, message: {message}"
                  )

-         from ....thirdparty.deepseek_vl2.utils.io import load_pil_images
+         from .....thirdparty.deepseek_vl2.utils.io import load_pil_images

          # load images and prepare for inputs
          pil_images = load_pil_images(deepseek_messages)
-         prepare_inputs = self._vl_chat_processor(
+         prepare_inputs = self._processor(
              conversations=deepseek_messages,
              images=pil_images,
              force_batchify=True,
@@ -205,88 +195,37 @@ class DeepSeekVL2ChatModel(PytorchChatModel):

          # run image encoder to get the image embeddings
          inputs_embeds = self._model.prepare_inputs_embeds(**prepare_inputs)
-
-         max_new_tokens = generate_config.get("max_tokens", 512)
-         conversation = self._vl_chat_processor.new_chat_template()
-         stop_str = conversation.sep2
-
-         streamer = self._model.language.generate(
+         return dict(
+             input_ids=prepare_inputs.input_ids,
              inputs_embeds=inputs_embeds,
              attention_mask=prepare_inputs.attention_mask,
              pad_token_id=self._tokenizer.eos_token_id,
              bos_token_id=self._tokenizer.bos_token_id,
              eos_token_id=self._tokenizer.eos_token_id,
-             max_new_tokens=max_new_tokens,
+         )
+
+     def build_streaming_iter(
+         self,
+         messages: List[Dict],
+         generate_config: Dict,
+     ) -> Tuple[Iterator, int]:
+         _inputs = self.build_inputs_from_messages(messages, generate_config)
+         configs = self.build_generate_kwargs(generate_config)
+         streamer = self._model.language.generate(
+             **_inputs,
+             **configs,
              do_sample=False,
              use_cache=True,
          )
+         return streamer, len(_inputs["input_ids"][0])

-         if stream:
-             it = self._generate_stream(streamer, stop_str, include_usage, prompt)
-             return self._to_chat_completion_chunks(it)
-         else:
-             return self._generate(streamer, stop_str)
-
-     def _generate(self, streamer, stop_str) -> ChatCompletion:
-         generated_text = ""
-
-         for new_text in streamer:
-             if isinstance(new_text, torch.Tensor):
-                 new_text = self._tokenizer.decode(
-                     new_text.cpu().tolist(), skip_special_tokens=True
-                 )
-
-             if new_text.endswith(stop_str):
-                 new_text = new_text[: -len(stop_str)]
-
-             generated_text += new_text
-
-         return generate_chat_completion(self.model_uid, generated_text)
-
-     def _generate_stream(
-         self, streamer, stop_str, include_usage, prompt
-     ) -> Iterator[CompletionChunk]:
-         completion_id = str(uuid.uuid1())
-         prompt_tokens, completion_tokens, total_tokens = 0, 0, 0
-         input_ids = self._tokenizer(prompt).input_ids
-         prompt_tokens = len(input_ids)
-         for i, new_text in enumerate(streamer):
-             if new_text.endswith(stop_str):
-                 new_text = new_text[: -len(stop_str)]
-             completion_tokens = i
-             total_tokens = prompt_tokens + completion_tokens
-             yield generate_completion_chunk(
-                 chunk_text=new_text,
-                 finish_reason=None,
-                 chunk_id=completion_id,
-                 model_uid=self.model_uid,
-                 prompt_tokens=prompt_tokens,
-                 completion_tokens=completion_tokens,
-                 total_tokens=total_tokens,
-                 has_choice=True,
-                 has_content=True,
+     def check_conditions(self, new_text: str) -> Tuple[str, bool]:
+         stop_str = self.get_stop_strs()[0]
+         if isinstance(new_text, torch.Tensor):
+             new_text = self._tokenizer.decode(
+                 new_text.cpu().tolist(), skip_special_tokens=True
              )
-         yield generate_completion_chunk(
-             chunk_text=None,
-             finish_reason="stop",
-             chunk_id=completion_id,
-             model_uid=self.model_uid,
-             prompt_tokens=prompt_tokens,
-             completion_tokens=completion_tokens,
-             total_tokens=total_tokens,
-             has_choice=True,
-             has_content=False,
-         )

-         if include_usage:
-             yield generate_completion_chunk(
-                 chunk_text=None,
-                 finish_reason=None,
-                 chunk_id=completion_id,
-                 model_uid=self.model_uid,
-                 prompt_tokens=prompt_tokens,
-                 completion_tokens=completion_tokens,
-                 total_tokens=total_tokens,
-                 has_choice=False,
-                 has_content=False,
-             )
+         if new_text.endswith(stop_str):
+             new_text = new_text[: -len(stop_str)]
+         return new_text, False
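
With DeepSeekVL2ChatModel, the per-model _generate and _generate_stream loops disappear: get_stop_strs exposes the conversation separator and check_conditions trims it from each streamed piece before the shared base class assembles the completion chunks. A tiny standalone illustration of that trimming behaviour; the stop string here is made up for the example:

# Standalone illustration of the stop-string trimming that check_conditions
# performs on each streamed piece; the stop string is a made-up placeholder.
def trim_stop(new_text: str, stop_strs: list) -> str:
    for ss in stop_strs:
        if new_text.endswith(ss):
            return new_text[: -len(ss)]
    return new_text


assert trim_stop("Hello.<|end_of_sentence|>", ["<|end_of_sentence|>"]) == "Hello."
assert trim_stop("Hello.", ["<|end_of_sentence|>"]) == "Hello."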