xinference 1.6.0.post1__py3-none-any.whl → 1.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (87)
  1. xinference/_version.py +3 -3
  2. xinference/client/restful/restful_client.py +1 -1
  3. xinference/conftest.py +0 -7
  4. xinference/core/media_interface.py +9 -8
  5. xinference/core/model.py +13 -6
  6. xinference/core/scheduler.py +1 -10
  7. xinference/core/worker.py +0 -10
  8. xinference/model/audio/model_spec.json +53 -1
  9. xinference/model/audio/model_spec_modelscope.json +57 -1
  10. xinference/model/embedding/core.py +19 -11
  11. xinference/model/image/model_spec.json +10 -1
  12. xinference/model/image/model_spec_modelscope.json +20 -0
  13. xinference/model/llm/__init__.py +6 -54
  14. xinference/model/llm/core.py +19 -5
  15. xinference/model/llm/llama_cpp/core.py +59 -3
  16. xinference/model/llm/llama_cpp/memory.py +455 -0
  17. xinference/model/llm/llm_family.json +185 -397
  18. xinference/model/llm/llm_family.py +88 -16
  19. xinference/model/llm/llm_family_modelscope.json +199 -421
  20. xinference/model/llm/llm_family_openmind_hub.json +0 -34
  21. xinference/model/llm/sglang/core.py +4 -0
  22. xinference/model/llm/transformers/__init__.py +27 -6
  23. xinference/model/llm/transformers/chatglm.py +4 -2
  24. xinference/model/llm/transformers/core.py +49 -28
  25. xinference/model/llm/transformers/deepseek_v2.py +6 -49
  26. xinference/model/llm/transformers/gemma3.py +119 -164
  27. xinference/{thirdparty/omnilmm/train → model/llm/transformers/multimodal}/__init__.py +1 -1
  28. xinference/model/llm/transformers/{cogagent.py → multimodal/cogagent.py} +58 -95
  29. xinference/model/llm/transformers/multimodal/core.py +205 -0
  30. xinference/model/llm/transformers/{deepseek_vl2.py → multimodal/deepseek_vl2.py} +59 -120
  31. xinference/model/llm/transformers/multimodal/gemma3.py +117 -0
  32. xinference/model/llm/transformers/{glm4v.py → multimodal/glm4v.py} +57 -93
  33. xinference/model/llm/transformers/multimodal/intern_vl.py +412 -0
  34. xinference/model/llm/transformers/{minicpmv26.py → multimodal/minicpmv26.py} +55 -102
  35. xinference/model/llm/transformers/{ovis2.py → multimodal/ovis2.py} +114 -175
  36. xinference/model/llm/transformers/{qwen-omni.py → multimodal/qwen-omni.py} +82 -167
  37. xinference/model/llm/transformers/multimodal/qwen2_audio.py +131 -0
  38. xinference/model/llm/transformers/{qwen2_vl.py → multimodal/qwen2_vl.py} +224 -256
  39. xinference/model/llm/transformers/opt.py +4 -2
  40. xinference/model/llm/transformers/utils.py +6 -37
  41. xinference/model/llm/vllm/core.py +4 -0
  42. xinference/model/rerank/core.py +7 -1
  43. xinference/model/rerank/utils.py +17 -0
  44. xinference/web/ui/build/asset-manifest.json +3 -3
  45. xinference/web/ui/build/index.html +1 -1
  46. xinference/web/ui/build/static/js/main.ddf9eaee.js +3 -0
  47. xinference/web/ui/build/static/js/main.ddf9eaee.js.map +1 -0
  48. xinference/web/ui/node_modules/.cache/babel-loader/12e637ed5fa9ca6491b03892b6949c03afd4960fe36ac25744488e7e1982aa19.json +1 -0
  49. xinference/web/ui/node_modules/.cache/babel-loader/567e49df411efb24425d289bb484758cb57067ca54f8b5c67fe4505f698deb96.json +1 -0
  50. xinference/web/ui/node_modules/.cache/babel-loader/77ac2665a784e99501ae95d32ef5937837a0439a47e965d291b38e99cb619f5b.json +1 -0
  51. xinference/web/ui/node_modules/.cache/babel-loader/d4ed4e82bfe69915999ec83f5feaa4301c75ecc6bdf1c78f2d03e4671ecbefc8.json +1 -0
  52. xinference/web/ui/src/locales/en.json +3 -1
  53. xinference/web/ui/src/locales/zh.json +3 -1
  54. {xinference-1.6.0.post1.dist-info → xinference-1.6.1.dist-info}/METADATA +6 -4
  55. {xinference-1.6.0.post1.dist-info → xinference-1.6.1.dist-info}/RECORD +60 -76
  56. {xinference-1.6.0.post1.dist-info → xinference-1.6.1.dist-info}/WHEEL +1 -1
  57. xinference/model/llm/transformers/cogvlm2.py +0 -442
  58. xinference/model/llm/transformers/cogvlm2_video.py +0 -333
  59. xinference/model/llm/transformers/deepseek_vl.py +0 -280
  60. xinference/model/llm/transformers/glm_edge_v.py +0 -213
  61. xinference/model/llm/transformers/intern_vl.py +0 -526
  62. xinference/model/llm/transformers/internlm2.py +0 -94
  63. xinference/model/llm/transformers/minicpmv25.py +0 -193
  64. xinference/model/llm/transformers/omnilmm.py +0 -132
  65. xinference/model/llm/transformers/qwen2_audio.py +0 -179
  66. xinference/model/llm/transformers/qwen_vl.py +0 -360
  67. xinference/thirdparty/omnilmm/LICENSE +0 -201
  68. xinference/thirdparty/omnilmm/__init__.py +0 -0
  69. xinference/thirdparty/omnilmm/chat.py +0 -218
  70. xinference/thirdparty/omnilmm/constants.py +0 -4
  71. xinference/thirdparty/omnilmm/conversation.py +0 -332
  72. xinference/thirdparty/omnilmm/model/__init__.py +0 -1
  73. xinference/thirdparty/omnilmm/model/omnilmm.py +0 -595
  74. xinference/thirdparty/omnilmm/model/resampler.py +0 -166
  75. xinference/thirdparty/omnilmm/model/utils.py +0 -578
  76. xinference/thirdparty/omnilmm/train/train_utils.py +0 -150
  77. xinference/thirdparty/omnilmm/utils.py +0 -134
  78. xinference/web/ui/build/static/js/main.ae579a97.js +0 -3
  79. xinference/web/ui/build/static/js/main.ae579a97.js.map +0 -1
  80. xinference/web/ui/node_modules/.cache/babel-loader/2fdc61dcb6a9d1fbcb44be592d0e87d8c3f21297a7327559ef5345665f8343f7.json +0 -1
  81. xinference/web/ui/node_modules/.cache/babel-loader/3d596a3e8dd6430d7ce81d164e32c31f8d47cfa5f725c328a298754d78563e14.json +0 -1
  82. xinference/web/ui/node_modules/.cache/babel-loader/5c08e2cd07809ed3e41486b16652253404cbb63a3ff8d0366ee50f57e2413cea.json +0 -1
  83. xinference/web/ui/node_modules/.cache/babel-loader/8472e58a31720892d534f3febda31f746b25ec4aa60787eef34217b074e67965.json +0 -1
  84. /xinference/web/ui/build/static/js/{main.ae579a97.js.LICENSE.txt → main.ddf9eaee.js.LICENSE.txt} +0 -0
  85. {xinference-1.6.0.post1.dist-info → xinference-1.6.1.dist-info}/entry_points.txt +0 -0
  86. {xinference-1.6.0.post1.dist-info → xinference-1.6.1.dist-info}/licenses/LICENSE +0 -0
  87. {xinference-1.6.0.post1.dist-info → xinference-1.6.1.dist-info}/top_level.txt +0 -0

xinference/model/llm/transformers/{ovis2.py → multimodal/ovis2.py}
@@ -1,4 +1,4 @@
- # Copyright 2022-2023 XProbe Inc.
+ # Copyright 2022-2025 XProbe Inc.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -12,33 +12,26 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  import logging
- import uuid
- from typing import Dict, Iterator, List, Optional, Union
+ from threading import Thread
+ from typing import Any, Dict, Iterator, List, Tuple

  import torch
  from PIL import Image

- from ....types import (
-     ChatCompletion,
-     ChatCompletionChunk,
-     ChatCompletionMessage,
-     CompletionChunk,
- )
- from ..llm_family import LLMFamilyV1, LLMSpecV1
- from ..utils import generate_chat_completion, generate_completion_chunk
- from .core import PytorchChatModel, PytorchGenerateConfig
- from .utils import cache_clean
+ from ...llm_family import LLMFamilyV1, LLMSpecV1, register_transformer
+ from ..core import register_non_default_model
+ from .core import PytorchMultiModalModel

  logger = logging.getLogger(__name__)


- class Ovis2ChatModel(PytorchChatModel):
-     def __init__(self, *args, **kwargs):
-         super().__init__(*args, **kwargs)
-         self._tokenizer = None
-         self._model = None
-         self._device = None
-         self._processor = None
+ @register_transformer
+ @register_non_default_model("Ovis2")
+ class Ovis2ChatModel(PytorchMultiModalModel):
+     def __init__(self, *args, **kws):
+         super().__init__(*args, **kws)
+         self._text_tokenizer = None
+         self._visual_tokenizer = None

      @classmethod
      def match_json(
@@ -51,127 +44,28 @@ class Ovis2ChatModel(PytorchChatModel):
              return True
          return False

-     def load(self):
+     def decide_device(self):
+         pass
+
+     def load_processor(self):
+         pass
+
+     def load_multimodal_model(self):
          from transformers import AutoModelForCausalLM

-         # load model
+         kwargs = self.apply_bnb_quantization()
          self._model = AutoModelForCausalLM.from_pretrained(
              self.model_path,
              torch_dtype=torch.bfloat16,
              multimodal_max_length=32768,
              trust_remote_code=True,
+             **kwargs,
          ).cuda()
          self._text_tokenizer = self._model.get_text_tokenizer()
          self._visual_tokenizer = self._model.get_visual_tokenizer()

-     @cache_clean
-     def chat(
-         self,
-         messages: List[ChatCompletionMessage],  # type: ignore
-         generate_config: Optional[PytorchGenerateConfig] = None,
-     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
-         messages = self._transform_messages(messages)
-
-         generate_config = generate_config if generate_config else {}
-
-         stream = generate_config.get("stream", False) if generate_config else False
-
-         if stream:
-             # raise NotImplementedError("Stream is not supported for Ovis2 model.")
-             it = self._generate_stream(messages, generate_config)
-             return self._to_chat_completion_chunks(it)
-         else:
-             c = self._generate(messages, generate_config)
-             return c
-
-     def _generate(
-         self, messages: List, config: PytorchGenerateConfig = {}
-     ) -> ChatCompletion:
-         input_ids, attention_mask, pixel_values, gen_kwargs = self._generate_chat_data(
-             messages, config
-         )
-
-         # generate output
-         with torch.inference_mode():
-             gen_kwargs.update(
-                 dict(
-                     pixel_values=pixel_values,
-                     attention_mask=attention_mask,
-                 )
-             )
-
-             output_ids = self._model.generate(
-                 input_ids,
-                 **gen_kwargs,
-             )[0]
-         output = self._text_tokenizer.decode(output_ids, skip_special_tokens=True)
-         return generate_chat_completion(self.model_uid, output)
-
-     def _generate_stream(
-         self, messages: List, config: PytorchGenerateConfig = {}
-     ) -> Iterator[CompletionChunk]:
-         from threading import Thread
-
-         from transformers import TextIteratorStreamer
-
-         input_ids, attention_mask, pixel_values, gen_kwargs = self._generate_chat_data(
-             messages, config
-         )
-
-         _, inputs_embeds, _, attention_mask = self._model.merge_multimodal(
-             text_input_ids=input_ids,
-             text_attention_masks=attention_mask,
-             text_labels=None,
-             pixel_values=pixel_values,
-             left_padding=True,
-         )
-
-         streamer = TextIteratorStreamer(
-             self._text_tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True
-         )
-
-         gen_kwargs.update(
-             dict(
-                 inputs_embeds=inputs_embeds,
-                 attention_mask=attention_mask,
-                 streamer=streamer,
-             )
-         )
-
-         inputs_embeds = inputs_embeds.detach()
-         torch.cuda.empty_cache()
-
-         thread = Thread(target=self._model.llm.generate, kwargs=gen_kwargs)
-         thread.start()
-
-         completion_id = str(uuid.uuid1())
-
-         for new_text in streamer:
-             yield generate_completion_chunk(
-                 chunk_text=new_text,
-                 finish_reason=None,
-                 chunk_id=completion_id,
-                 model_uid=self.model_uid,
-                 prompt_tokens=-1,
-                 completion_tokens=-1,
-                 total_tokens=-1,
-                 has_choice=True,
-                 has_content=True,
-             )
-
-         yield generate_completion_chunk(
-             chunk_text=None,
-             finish_reason="stop",
-             chunk_id=completion_id,
-             model_uid=self.model_uid,
-             prompt_tokens=-1,
-             completion_tokens=-1,
-             total_tokens=-1,
-             has_choice=True,
-             has_content=False,
-         )
-
-     def parse_messages_ovis(self, messages: List[Dict]) -> List[Dict]:
+     @staticmethod
+     def _parse_messages_ovis(messages: List[Dict]) -> List[Dict]:
          ovis_msgs = []
          for mess in messages:
              contents = mess["content"]
@@ -189,12 +83,52 @@ class Ovis2ChatModel(PytorchChatModel):

          return ovis_msgs

-     def _generate_chat_data(
-         self, messages: List[Dict], config: PytorchGenerateConfig = {}
-     ):
+     @staticmethod
+     def _convert_video_tensors_to_pil(video_inputs: List) -> List[Image.Image]:
+         """Convert video tensors to a list of PIL images"""
+         from torchvision import transforms
+
+         to_pil = transforms.ToPILImage()
+         pil_images = []
+
+         for video_tensor_4d in video_inputs:
+             if isinstance(video_tensor_4d, torch.Tensor):
+                 # Verify it's a 4D tensor
+                 if video_tensor_4d.ndim == 4:
+                     # Iterate through the first dimension (frames) of 4D tensor
+                     for i in range(video_tensor_4d.size(0)):
+                         frame_tensor_3d = video_tensor_4d[
+                             i
+                         ]  # Get 3D frame tensor [C, H, W]
+                         # Ensure tensor is on CPU before conversion
+                         if frame_tensor_3d.is_cuda:
+                             frame_tensor_3d = frame_tensor_3d.cpu()
+                         try:
+                             pil_image = to_pil(frame_tensor_3d)
+                             pil_images.append(pil_image)
+                         except Exception as e:
+                             logger.error(
+                                 f"Error converting frame {i} to PIL Image: {e}"
+                             )
+                             # Can choose to skip this frame or handle error differently
+                 else:
+                     logger.warning(
+                         f"Expected 4D tensor in video_inputs, but got {video_tensor_4d.ndim}D. Skipping this tensor."
+                     )
+             elif isinstance(video_tensor_4d, Image.Image):
+                 # If fetch_video returns Image list, add directly
+                 pil_images.append(video_tensor_4d)
+             else:
+                 logger.warning(
+                     f"Unexpected type in video_inputs: {type(video_tensor_4d)}. Skipping."
+                 )
+
+         return pil_images
+
+     def _generate_chat_data(self, messages: List[Dict]):
          from qwen_vl_utils import process_vision_info

-         messages_ovis = self.parse_messages_ovis(messages)
+         messages_ovis = self._parse_messages_ovis(messages)
          max_partition = None
          prompt = messages_ovis[-1]["value"]

@@ -246,57 +180,62 @@ class Ovis2ChatModel(PytorchChatModel):
          )
          pixel_values = [pixel_values]

-         gen_kwargs = dict(
-             max_new_tokens=config.get("max_tokens", 1024),
+         return input_ids, attention_mask, pixel_values
+
+     def build_generate_kwargs(
+         self,
+         generate_config: Dict,
+     ) -> Dict[str, Any]:
+         return dict(
+             max_new_tokens=generate_config.get("max_tokens", 1024),
              do_sample=False,
              top_p=None,
              top_k=None,
-             temperature=config.get("temperature", None),
+             temperature=generate_config.get("temperature", None),
              repetition_penalty=None,
              eos_token_id=self._model.generation_config.eos_token_id,
              pad_token_id=self._text_tokenizer.pad_token_id,
              use_cache=True,
          )

-         return input_ids, attention_mask, pixel_values, gen_kwargs
+     def build_inputs_from_messages(
+         self,
+         messages: List[Dict],
+         generate_config: Dict,
+     ):
+         msgs = self._transform_messages(messages)
+         input_ids, attention_mask, pixel_values = self._generate_chat_data(msgs)
+         _, inputs_embeds, _, attention_mask = self._model.merge_multimodal(
+             text_input_ids=input_ids,
+             text_attention_masks=attention_mask,
+             text_labels=None,
+             pixel_values=pixel_values,
+             left_padding=True,
+         )
+         inputs_embeds = inputs_embeds.detach()
+         torch.cuda.empty_cache()
+         return dict(
+             input_ids=input_ids,
+             inputs_embeds=inputs_embeds,
+             attention_mask=attention_mask,
+         )

-     def _convert_video_tensors_to_pil(self, video_inputs: List) -> List[Image.Image]:
-         """Convert video tensors to a list of PIL images"""
-         from torchvision import transforms
+     def build_streaming_iter(
+         self,
+         messages: List[Dict],
+         generate_config: Dict,
+     ) -> Tuple[Iterator, int]:
+         from transformers import TextIteratorStreamer

-         to_pil = transforms.ToPILImage()
-         pil_images = []
+         streamer = TextIteratorStreamer(
+             self._text_tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True
+         )
+         config = self.build_generate_kwargs(generate_config)
+         inputs = self.build_inputs_from_messages(messages, generate_config)
+         input_ids = inputs.pop("input_ids")

-         for video_tensor_4d in video_inputs:
-             if isinstance(video_tensor_4d, torch.Tensor):
-                 # Verify it's a 4D tensor
-                 if video_tensor_4d.ndim == 4:
-                     # Iterate through the first dimension (frames) of 4D tensor
-                     for i in range(video_tensor_4d.size(0)):
-                         frame_tensor_3d = video_tensor_4d[
-                             i
-                         ]  # Get 3D frame tensor [C, H, W]
-                         # Ensure tensor is on CPU before conversion
-                         if frame_tensor_3d.is_cuda:
-                             frame_tensor_3d = frame_tensor_3d.cpu()
-                         try:
-                             pil_image = to_pil(frame_tensor_3d)
-                             pil_images.append(pil_image)
-                         except Exception as e:
-                             logger.error(
-                                 f"Error converting frame {i} to PIL Image: {e}"
-                             )
-                             # Can choose to skip this frame or handle error differently
-                 else:
-                     logger.warning(
-                         f"Expected 4D tensor in video_inputs, but got {video_tensor_4d.ndim}D. Skipping this tensor."
-                     )
-             elif isinstance(video_tensor_4d, Image.Image):
-                 # If fetch_video returns Image list, add directly
-                 pil_images.append(video_tensor_4d)
-             else:
-                 logger.warning(
-                     f"Unexpected type in video_inputs: {type(video_tensor_4d)}. Skipping."
-                 )
+         gen_kwargs = dict(**inputs, **config, streamer=streamer)

-         return pil_images
+         thread = Thread(target=self._model.llm.generate, kwargs=gen_kwargs)
+         thread.start()
+         return streamer, len(input_ids[0])
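
Note: after this refactor the Ovis2 module only supplies model-specific hooks; the shared chat/streaming driver lives in the new xinference/model/llm/transformers/multimodal/core.py (item 29 in the file list, +205 lines), which is not part of this diff. The sketch below shows how such a template base class could wire those hooks together; only the hook signatures are taken from the diff, while the class name, load()/chat() bodies, and control flow are assumptions.

# Hypothetical sketch only -- the shipped base class is PytorchMultiModalModel in
# xinference/model/llm/transformers/multimodal/core.py, which this diff does not show.
from typing import Any, Dict, Iterator, List, Tuple


class MultiModalTemplateSketch:
    """Assumed template-method flow that per-model subclasses (Ovis2, qwen2.5-omni, ...) plug into."""

    # Hooks each subclass overrides; signatures copied from the diff above.
    def decide_device(self) -> None: ...
    def load_processor(self) -> None: ...
    def load_multimodal_model(self) -> None: ...
    def build_inputs_from_messages(self, messages: List[Dict], generate_config: Dict): ...
    def build_generate_kwargs(self, generate_config: Dict) -> Dict[str, Any]: ...
    def build_streaming_iter(self, messages: List[Dict], generate_config: Dict) -> Tuple[Iterator, int]: ...
    def generate_non_streaming(self, messages: List[Dict], generate_config: Dict): ...

    # Shared plumbing; these bodies are assumptions, not the shipped code.
    def load(self) -> None:
        # Each model fills in device selection, processor loading and weight loading.
        self.decide_device()
        self.load_processor()
        self.load_multimodal_model()

    def chat(self, messages: List[Dict], generate_config: Dict):
        generate_config = generate_config or {}
        if generate_config.get("stream", False):
            # The streamer yields raw text pieces; the prompt length lets the base
            # class report token usage while wrapping the pieces into chat chunks.
            streamer, _prompt_length = self.build_streaming_iter(messages, generate_config)
            return streamer
        return self.generate_non_streaming(messages, generate_config)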

xinference/model/llm/transformers/{qwen-omni.py → multimodal/qwen-omni.py}
@@ -11,49 +11,36 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
-
  import base64
  import importlib.util
  import io
  import logging
- import sys
  import time
  import uuid
- from typing import Dict, Iterator, List, Optional, Union
+ from threading import Thread
+ from typing import Any, Dict, Iterator, List, Optional, Tuple

- from ....model.utils import select_device
- from ....types import (
+ from .....model.utils import select_device
+ from .....types import (
      ChatCompletion,
      ChatCompletionAudio,
      ChatCompletionChoice,
-     ChatCompletionChunk,
-     ChatCompletionMessage,
-     CompletionChunk,
      CompletionUsage,
  )
- from ..llm_family import LLMFamilyV1, LLMSpecV1, register_transformer
- from ..utils import generate_completion_chunk
- from .core import PytorchChatModel, PytorchGenerateConfig, register_non_default_model
- from .utils import cache_clean
+ from ...llm_family import LLMFamilyV1, LLMSpecV1, register_transformer
+ from ..core import PytorchGenerateConfig, register_non_default_model
+ from .core import PytorchMultiModalModel

  logger = logging.getLogger(__name__)

- DEFAULT_SYSTEM_PROMPT = (
-     "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, "
-     "capable of perceiving auditory and visual inputs, as well as generating text and speech."
- )
-

  @register_transformer
  @register_non_default_model("qwen2.5-omni")
- class Qwen2_5OmniChatModel(PytorchChatModel):
-     def __init__(self, *args, **kwargs):
-         super().__init__(*args, **kwargs)
-
-         self._tokenizer = None
-         self._model = None
-         self._device = None
-         self._processor = None
+ class Qwen2_5OmniChatModel(PytorchMultiModalModel):
+     DEFAULT_SYSTEM_PROMPT = (
+         "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, "
+         "capable of perceiving auditory and visual inputs, as well as generating text and speech."
+     )

      @classmethod
      def match_json(
@@ -66,23 +53,24 @@ class Qwen2_5OmniChatModel(PytorchChatModel):
              return True
          return False

-     def load(self):
-         logger.debug(
-             "Try to load model, current python: %s, sys path: %s",
-             sys.executable,
-             sys.path,
-         )
-
-         from transformers import (
-             Qwen2_5OmniForConditionalGeneration,
-             Qwen2_5OmniProcessor,
-         )
-
+     def decide_device(self):
          device = self._pytorch_model_config.get("device", "auto")
          device = select_device(device)
          self._device = device
+
+     def load_processor(self):
+         from transformers import Qwen2_5OmniProcessor
+
+         self._processor = Qwen2_5OmniProcessor.from_pretrained(
+             self.model_path, trust_remote_code=True
+         )
+         self._tokenizer = self._processor.tokenizer
+
+     def load_multimodal_model(self):
+         from transformers import Qwen2_5OmniForConditionalGeneration
+
          # for multiple GPU, set back to auto to make multiple devices work
-         device = "auto" if device == "cuda" else device
+         device = "auto" if self._device == "cuda" else self._device
          flash_attn_installed = importlib.util.find_spec("flash_attn") is not None
          kwargs = (
              {}
@@ -92,10 +80,6 @@ class Qwen2_5OmniChatModel(PytorchChatModel):
          kwargs = self.apply_bnb_quantization(kwargs)
          logger.debug("Loading model with extra kwargs: %s", kwargs)

-         self._processor = Qwen2_5OmniProcessor.from_pretrained(
-             self.model_path, trust_remote_code=True
-         )
-         self._tokenizer = self._processor.tokenizer
          self._model = Qwen2_5OmniForConditionalGeneration.from_pretrained(
              self.model_path,
              torch_dtype="auto",
@@ -104,28 +88,9 @@ class Qwen2_5OmniChatModel(PytorchChatModel):
              **kwargs,
          )

-     @cache_clean
-     def chat(
-         self,
-         messages: List[Dict],
-         generate_config: Optional[PytorchGenerateConfig] = None,
-     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
-         messages = self._transform_messages(messages)
-
-         generate_config = generate_config if generate_config else {}
-
-         stream = generate_config.get("stream", False) if generate_config else False
-
-         if stream:
-             it = self._generate_stream(messages, generate_config)
-             return self._to_chat_completion_chunks(it)
-         else:
-             c = self._generate(messages, generate_config)
-             return c
-
      def _transform_messages(
          self,
-         messages: Union[List[ChatCompletionMessage], List[dict]],
+         messages: List[dict],  # type: ignore
      ):
          messages = super()._transform_messages(messages)
          if messages[0]["role"] != "system":
@@ -133,23 +98,24 @@ class Qwen2_5OmniChatModel(PytorchChatModel):
                  0,
                  {
                      "role": "system",
-                     "content": [{"type": "text", "text": DEFAULT_SYSTEM_PROMPT}],  # type: ignore
+                     "content": [{"type": "text", "text": self.DEFAULT_SYSTEM_PROMPT}],  # type: ignore
                  },
              )
          else:
              logger.debug("Force to set system prompt")
-             messages[0]["content"] = [{"type": "text", "text": DEFAULT_SYSTEM_PROMPT}]  # type: ignore
+             messages[0]["content"] = [{"type": "text", "text": self.DEFAULT_SYSTEM_PROMPT}]  # type: ignore
          return messages

-     def _generate(
-         self, messages: List, config: PytorchGenerateConfig = {}
-     ) -> ChatCompletion:
-         import soundfile as sf
+     def build_inputs_from_messages(
+         self,
+         messages: List[Dict],
+         generate_config: Dict,
+     ):
          from qwen_omni_utils import process_mm_info

-         use_audio_in_video = config.get("use_audio_in_video", True)
-         voice = config.get("voice", "Chelsie")
+         use_audio_in_video = generate_config.get("use_audio_in_video", True)

+         messages = self._transform_messages(messages)
          text = self._processor.apply_chat_template(
              messages, tokenize=False, add_generation_prompt=True
          )
@@ -169,15 +135,54 @@ class Qwen2_5OmniChatModel(PytorchChatModel):
              use_audio_in_video=use_audio_in_video,
          )
          inputs = inputs.to(self._device)
+         return inputs

-         # Inference: Generation of the output
-         generated_ids, audio = self._model.generate(
-             **inputs,
-             speaker=voice,
-             max_new_tokens=config.get("max_tokens", 512),
-             temperature=config.get("temperature", 1),
-             use_audio_in_video=use_audio_in_video,
+     def build_generate_kwargs(
+         self,
+         generate_config: Dict,
+     ) -> Dict[str, Any]:
+         voice = generate_config.get("voice", "Chelsie")
+         return {
+             "max_new_tokens": generate_config.get("max_tokens", 512),
+             "temperature": generate_config.get("temperature", 1),
+             "speaker": voice,
+         }
+
+     def build_streaming_iter(
+         self,
+         messages: List[Dict],
+         generate_config: Dict,
+     ) -> Tuple[Iterator, int]:
+         from transformers import TextIteratorStreamer
+
+         tokenizer = self._tokenizer
+         streamer = TextIteratorStreamer(
+             tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True
          )
+
+         config = self.build_generate_kwargs(generate_config)
+         inputs = self.build_inputs_from_messages(messages, generate_config)
+         gen_kwargs = dict(**inputs, **config, streamer=streamer)
+         thread = Thread(target=self._model.generate, kwargs=gen_kwargs)
+         thread.start()
+         return streamer, len(inputs.input_ids[0])
+
+     def generate_non_streaming(
+         self,
+         messages: List[Dict],
+         generate_config: Optional[PytorchGenerateConfig] = None,
+     ) -> ChatCompletion:
+         """
+         Special case for qwen2.5-omni, since it has audio output
+         """
+         import soundfile as sf
+
+         generate_config = generate_config if generate_config else {}  # type: ignore
+         config = self.build_generate_kwargs(generate_config)  # type: ignore
+         inputs = self.build_inputs_from_messages(messages, generate_config)  # type: ignore
+         use_audio_in_video = generate_config.get("use_audio_in_video", True)
+         gen_kwargs = dict(**inputs, **config, use_audio_in_video=use_audio_in_video)
+         generated_ids, audio = self._model.generate(**gen_kwargs)
          generated_ids_trimmed = [
              out_ids[len(in_ids) :]
              for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
@@ -223,93 +228,3 @@ class Qwen2_5OmniChatModel(PytorchChatModel):
                  prompt_tokens=-1, completion_tokens=-1, total_tokens=-1
              ),
          )
-
-     def _generate_stream(
-         self, messages: List, config: PytorchGenerateConfig = {}
-     ) -> Iterator[CompletionChunk]:
-         from threading import Thread
-
-         from qwen_omni_utils import process_mm_info
-         from transformers import TextIteratorStreamer
-
-         use_audio_in_video = config.get("use_audio_in_video", True)
-         voice = config.get("voice", "Chelsie")
-
-         text = self._processor.apply_chat_template(
-             messages, tokenize=False, add_generation_prompt=True
-         )
-         audios, images, videos = process_mm_info(
-             messages, use_audio_in_video=use_audio_in_video
-         )
-         logger.debug(
-             "Text, audio, image, video: %s, %s, %s, %s", text, audios, images, videos
-         )
-         inputs = self._processor(
-             text=text,
-             images=images,
-             audio=audios,
-             videos=videos,
-             padding=True,
-             return_tensors="pt",
-             use_audio_in_video=use_audio_in_video,
-         )
-         inputs = inputs.to(self._device)
-
-         tokenizer = self._tokenizer
-         streamer = TextIteratorStreamer(
-             tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True
-         )
-
-         # TODO(xuye): Cannot find a way to streaming output,
-         # will implement it when it's supported
-
-         gen_kwargs = {
-             "max_new_tokens": config.get("max_tokens", 512),
-             "temperature": config.get("temperature", 1),
-             "streamer": streamer,
-             "speaker": voice,
-             **inputs,
-         }
-         error = None
-
-         def model_generate():
-             try:
-                 return self._model.generate(**gen_kwargs)
-             except Exception:
-                 nonlocal error
-                 error = sys.exc_info()
-                 streamer.end()
-                 raise
-
-         thread = Thread(target=model_generate)
-         thread.start()
-
-         completion_id = str(uuid.uuid1())
-         for new_text in streamer:
-             yield generate_completion_chunk(
-                 chunk_text=new_text,
-                 finish_reason=None,
-                 chunk_id=completion_id,
-                 model_uid=self.model_uid,
-                 prompt_tokens=-1,
-                 completion_tokens=-1,
-                 total_tokens=-1,
-                 has_choice=True,
-                 has_content=True,
-             )
-
-         if error:
-             _, err, tb = error  # type: ignore
-             raise err.with_traceback(tb)
-
-         yield generate_completion_chunk(
-             chunk_text=None,
-             finish_reason="stop",
-             chunk_id=completion_id,
-             model_uid=self.model_uid,
-             prompt_tokens=-1,
-             completion_tokens=-1,
-             total_tokens=-1,
-             has_choice=True,
-             has_content=False,
-         )
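
qwen2.5-omni keeps a dedicated generate_non_streaming path because model.generate returns a waveform alongside the token IDs, and the unchanged tail of that method (context not shown in the hunk above) packs the audio into the ChatCompletionAudio field of the response. Below is a rough, self-contained sketch of the base64 WAV encoding step such a response could use; the helper name and the 24000 Hz sample rate are assumptions, not taken from this diff.

# Illustrative sketch of encoding a generated waveform as a base64 WAV payload.
# The function name and the 24000 Hz sample rate are assumptions for this example.
import base64
import io

import numpy as np
import soundfile as sf


def waveform_to_base64_wav(audio: np.ndarray, sample_rate: int = 24000) -> str:
    """Write a mono float waveform to an in-memory WAV file and base64-encode it."""
    buffer = io.BytesIO()
    sf.write(buffer, audio.reshape(-1), samplerate=sample_rate, format="WAV")
    return base64.b64encode(buffer.getvalue()).decode()


if __name__ == "__main__":
    # One second of silence stands in for the waveform returned by model.generate.
    silent_second = np.zeros(24000, dtype=np.float32)
    print(waveform_to_base64_wav(silent_second)[:60], "...")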