xinference 0.15.0__py3-none-any.whl → 0.15.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic.

Files changed (83)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +204 -1
  3. xinference/client/restful/restful_client.py +4 -2
  4. xinference/core/image_interface.py +28 -0
  5. xinference/core/model.py +28 -0
  6. xinference/core/supervisor.py +6 -0
  7. xinference/model/audio/fish_speech.py +9 -9
  8. xinference/model/audio/model_spec.json +9 -9
  9. xinference/model/audio/whisper.py +4 -1
  10. xinference/model/image/core.py +2 -1
  11. xinference/model/image/model_spec.json +16 -4
  12. xinference/model/image/model_spec_modelscope.json +16 -4
  13. xinference/model/image/sdapi.py +136 -0
  14. xinference/model/image/stable_diffusion/core.py +148 -20
  15. xinference/model/llm/__init__.py +8 -0
  16. xinference/model/llm/llm_family.json +393 -0
  17. xinference/model/llm/llm_family.py +3 -1
  18. xinference/model/llm/llm_family_modelscope.json +408 -3
  19. xinference/model/llm/sglang/core.py +3 -0
  20. xinference/model/llm/transformers/chatglm.py +1 -1
  21. xinference/model/llm/transformers/core.py +6 -0
  22. xinference/model/llm/transformers/deepseek_v2.py +340 -0
  23. xinference/model/llm/transformers/qwen2_audio.py +168 -0
  24. xinference/model/llm/transformers/qwen2_vl.py +31 -5
  25. xinference/model/llm/utils.py +104 -84
  26. xinference/model/llm/vllm/core.py +8 -0
  27. xinference/thirdparty/fish_speech/fish_speech/configs/firefly_gan_vq.yaml +2 -3
  28. xinference/thirdparty/fish_speech/fish_speech/configs/text2semantic_finetune.yaml +1 -1
  29. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +1 -1
  30. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +1 -1
  31. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +1 -1
  32. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +1 -1
  33. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +1 -1
  34. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +2 -2
  35. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +0 -3
  36. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +169 -198
  37. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +4 -27
  38. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +9 -47
  39. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +2 -2
  40. xinference/thirdparty/fish_speech/fish_speech/train.py +2 -0
  41. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +12 -10
  42. xinference/thirdparty/fish_speech/tools/api.py +79 -134
  43. xinference/thirdparty/fish_speech/tools/commons.py +35 -0
  44. xinference/thirdparty/fish_speech/tools/download_models.py +3 -3
  45. xinference/thirdparty/fish_speech/tools/file.py +17 -0
  46. xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +1 -1
  47. xinference/thirdparty/fish_speech/tools/llama/generate.py +29 -24
  48. xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +1 -1
  49. xinference/thirdparty/fish_speech/tools/llama/quantize.py +2 -2
  50. xinference/thirdparty/fish_speech/tools/msgpack_api.py +34 -0
  51. xinference/thirdparty/fish_speech/tools/post_api.py +85 -44
  52. xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +1 -1
  53. xinference/thirdparty/fish_speech/tools/smart_pad.py +16 -3
  54. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +2 -2
  55. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +4 -2
  56. xinference/thirdparty/fish_speech/tools/webui.py +12 -146
  57. xinference/types.py +7 -4
  58. xinference/web/ui/build/asset-manifest.json +6 -6
  59. xinference/web/ui/build/index.html +1 -1
  60. xinference/web/ui/build/static/css/{main.632e9148.css → main.5061c4c3.css} +2 -2
  61. xinference/web/ui/build/static/css/main.5061c4c3.css.map +1 -0
  62. xinference/web/ui/build/static/js/{main.9cfafbd6.js → main.754740c0.js} +3 -3
  63. xinference/web/ui/build/static/js/main.754740c0.js.map +1 -0
  64. xinference/web/ui/node_modules/.cache/babel-loader/cd90b08d177025dfe84209596fc51878f8a86bcaa6a240848a3d2e5fd4c7ff24.json +1 -0
  65. xinference/web/ui/node_modules/.cache/babel-loader/e42b72d4cc1ea412ebecbb8d040dc6c6bfee462c33903c2f1f3facb602ad742e.json +1 -0
  66. {xinference-0.15.0.dist-info → xinference-0.15.1.dist-info}/METADATA +9 -3
  67. {xinference-0.15.0.dist-info → xinference-0.15.1.dist-info}/RECORD +72 -74
  68. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +0 -442
  69. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +0 -44
  70. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +0 -115
  71. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +0 -225
  72. xinference/thirdparty/fish_speech/tools/auto_rerank.py +0 -159
  73. xinference/thirdparty/fish_speech/tools/gen_ref.py +0 -36
  74. xinference/thirdparty/fish_speech/tools/merge_asr_files.py +0 -55
  75. xinference/web/ui/build/static/css/main.632e9148.css.map +0 -1
  76. xinference/web/ui/build/static/js/main.9cfafbd6.js.map +0 -1
  77. xinference/web/ui/node_modules/.cache/babel-loader/01d6d198156bacbd436c51435edbd4b2cacd47a79db929105eba30f74b67d48d.json +0 -1
  78. xinference/web/ui/node_modules/.cache/babel-loader/59eb25f514afcc4fefd1b309d192b2455f1e0aec68a9de598ca4b2333fe2c774.json +0 -1
  79. /xinference/web/ui/build/static/js/{main.9cfafbd6.js.LICENSE.txt → main.754740c0.js.LICENSE.txt} +0 -0
  80. {xinference-0.15.0.dist-info → xinference-0.15.1.dist-info}/LICENSE +0 -0
  81. {xinference-0.15.0.dist-info → xinference-0.15.1.dist-info}/WHEEL +0 -0
  82. {xinference-0.15.0.dist-info → xinference-0.15.1.dist-info}/entry_points.txt +0 -0
  83. {xinference-0.15.0.dist-info → xinference-0.15.1.dist-info}/top_level.txt +0 -0
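
The file listing above can be reproduced locally by comparing the two wheel archives. A minimal sketch, assuming both wheels have been downloaded into the working directory under their standard file names (not the registry's display names in the title):

import zipfile

def wheel_entries(path):
    # Map each archive member to its raw bytes.
    with zipfile.ZipFile(path) as zf:
        return {info.filename: zf.read(info.filename) for info in zf.infolist()}

old = wheel_entries("xinference-0.15.0-py3-none-any.whl")
new = wheel_entries("xinference-0.15.1-py3-none-any.whl")

changed = sorted(name for name in old.keys() & new.keys() if old[name] != new[name])
added = sorted(new.keys() - old.keys())
removed = sorted(old.keys() - new.keys())
print(f"{len(changed)} changed, {len(added)} added, {len(removed)} removed")

The hunks reproduced below are the new deepseek_v2.py and qwen2_audio.py backends and the modified qwen2_vl.py.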
--- /dev/null
+++ b/xinference/model/llm/transformers/deepseek_v2.py
@@ -0,0 +1,340 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import uuid
+from typing import Dict, Iterator, List, Optional, Union
+
+import torch
+
+from ....types import (
+    ChatCompletion,
+    ChatCompletionChunk,
+    Completion,
+    CompletionChunk,
+    PytorchGenerateConfig,
+)
+from ..llm_family import LLMFamilyV1, LLMSpecV1
+from ..utils import (
+    generate_chat_completion,
+    generate_completion,
+    generate_completion_chunk,
+)
+from .core import PytorchChatModel, PytorchModel
+
+logger = logging.getLogger(__name__)
+
+
+class DeepSeekV2PytorchModel(PytorchModel):
+    def _load_model(self, **kwargs):
+        try:
+            from transformers import (
+                AutoModelForCausalLM,
+                AutoTokenizer,
+                GenerationConfig,
+            )
+        except ImportError:
+            error_message = "Failed to import module 'transformers'"
+            installation_guide = [
+                "Please make sure 'transformers' is installed. ",
+                "You can install it by `pip install transformers`\n",
+            ]
+
+            raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
+
+        tokenizer = AutoTokenizer.from_pretrained(
+            self.model_path,
+            trust_remote_code=kwargs["trust_remote_code"],
+        )
+        model = AutoModelForCausalLM.from_pretrained(
+            self.model_path,
+            attn_implementation="eager",
+            torch_dtype=torch.bfloat16,
+            trust_remote_code=True,
+            device_map="auto",
+        )
+        model.generation_config = GenerationConfig.from_pretrained(self.model_path)
+        model.generation_config.pad_token_id = model.generation_config.eos_token_id
+        return model, tokenizer
+
+    @classmethod
+    def match(
+        cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
+    ) -> bool:
+        if llm_spec.model_format != "pytorch":
+            return False
+        model_family = llm_family.model_family or llm_family.model_name
+        if "deepseek-v2" not in model_family:
+            return False
+        if "generate" not in llm_family.model_ability:
+            return False
+        return True
+
+    def generate(
+        self, prompt: str, generate_config: Optional[PytorchGenerateConfig] = None
+    ) -> Union[Completion, Iterator[CompletionChunk]]:
+        input_tensor = self._tokenizer(prompt, return_tensors="pt")
+        generate_config = self._sanitize_generate_config(generate_config)
+        default_generate_config = self._model.generation_config
+        generate_kwargs = {
+            "input_ids": input_tensor["input_ids"].cuda(),
+            "attention_mask": input_tensor["attention_mask"].cuda(),
+            "temperature": float(
+                generate_config.get("temperature", default_generate_config.temperature)
+            ),
+            "repetition_penalty": float(generate_config.get("repetition_penalty", 1.0)),
+            "top_p": float(generate_config.get("top_p", default_generate_config.top_p)),
+            "top_k": int(generate_config.get("top_k", -1)),
+            "max_new_tokens": generate_config.get("max_tokens", 512),
+            "bos_token_id": default_generate_config.bos_token_id,
+            "do_sample": default_generate_config.do_sample,
+            "eos_token_id": default_generate_config.eos_token_id,
+        }
+
+        stream = generate_config.get("stream", False)
+        if stream:
+            return self._generate_stream(generate_kwargs, input_tensor)
+        else:
+            return self._generate(generate_kwargs, input_tensor)
+
+    def _generate(self, generate_kwargs, input_ids) -> Completion:
+        prompt_tokens = len(input_ids[0])
+        logger.info(f"generate_kwargs:{generate_kwargs}")
+        generation_output = self._model.generate(**generate_kwargs)
+        completion_tokens = len(generation_output[0])
+        response = self._tokenizer.decode(
+            generation_output[0], skip_special_tokens=True
+        )
+        return generate_completion(
+            self.model_uid,
+            response,
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens,
+        )
+
+    def _generate_stream(self, generate_kwargs, input_ids):
+        from threading import Thread
+
+        from transformers import TextIteratorStreamer
+
+        # Initialize the streamer
+        streamer = TextIteratorStreamer(
+            self._tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=10
+        )
+        # Define the generation configuration
+        generate_kwargs["streamer"] = streamer
+        # Start the model chat in a separate thread
+        thread = Thread(
+            target=self._model.generate,
+            kwargs=generate_kwargs,
+        )
+        thread.start()
+
+        completion_id = str(uuid.uuid1())
+        prompt_tokens = len(input_ids[0])
+        total_tokens, completion_tokens = 0, 0
+        # Loop through the streamer to get the new text as it is generated
+        for i, new_text in enumerate(streamer):
+            completion_tokens = i
+            total_tokens = prompt_tokens + completion_tokens
+            yield generate_completion_chunk(
+                chunk_text=new_text,
+                finish_reason=None,
+                chunk_id=completion_id,
+                model_uid=self.model_uid,
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=total_tokens,
+            )
+        yield generate_completion_chunk(
+            chunk_text=None,
+            finish_reason="stop",
+            chunk_id=completion_id,
+            model_uid=self.model_uid,
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=total_tokens,
+            has_choice=True,
+            has_content=False,
+        )
+
+
+class DeepSeekV2PytorchChatModel(PytorchChatModel):
+    def _load_model(self, **kwargs):
+        try:
+            from transformers import (
+                AutoModelForCausalLM,
+                AutoTokenizer,
+                GenerationConfig,
+            )
+        except ImportError:
+            error_message = "Failed to import module 'transformers'"
+            installation_guide = [
+                "Please make sure 'transformers' is installed. ",
+                "You can install it by `pip install transformers`\n",
+            ]
+
+            raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
+
+        tokenizer = AutoTokenizer.from_pretrained(
+            self.model_path,
+            trust_remote_code=kwargs["trust_remote_code"],
+        )
+        logger.info(f"kwargs:{kwargs}")
+        model = AutoModelForCausalLM.from_pretrained(
+            self.model_path,
+            attn_implementation="eager",
+            torch_dtype=torch.bfloat16,
+            trust_remote_code=True,
+            device_map="auto",
+        )
+        model.generation_config = GenerationConfig.from_pretrained(self.model_path)
+        model.generation_config.pad_token_id = model.generation_config.eos_token_id
+        return model, tokenizer
+
+    @classmethod
+    def match(
+        cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
+    ) -> bool:
+        if llm_spec.model_format != "pytorch":
+            return False
+        model_family = llm_family.model_family or llm_family.model_name
+        if "deepseek-v2" not in model_family:
+            return False
+        if "chat" not in llm_family.model_ability:
+            return False
+        return True
+
+    def chat(
+        self,
+        messages: List[Dict],
+        generate_config: Optional[PytorchGenerateConfig] = None,
+    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
+        assert self.model_family.chat_template is not None
+        full_prompt = self.get_full_context(
+            messages,
+            self.model_family.chat_template,
+            tokenizer=self._tokenizer,
+        )
+        input_tensor = self._tokenizer.encode(
+            full_prompt,
+            padding=False,
+            truncation=False,
+            max_length=None,
+            add_special_tokens=False,
+            return_tensors="pt",
+        )
+
+        generate_config = self._sanitize_generate_config(generate_config)
+        default_generate_config = self._model.generation_config
+        generate_kwargs = {
+            "input_ids": input_tensor.cuda(),
+            "temperature": float(
+                generate_config.get("temperature", default_generate_config.temperature)
+            ),
+            "repetition_penalty": float(generate_config.get("repetition_penalty", 1.0)),
+            "top_p": float(generate_config.get("top_p", default_generate_config.top_p)),
+            "top_k": int(generate_config.get("top_k", -1)),
+            "max_new_tokens": generate_config.get("max_tokens", 512),
+            "bos_token_id": default_generate_config.bos_token_id,
+            "do_sample": default_generate_config.do_sample,
+            "eos_token_id": default_generate_config.eos_token_id,
+        }
+
+        stream = generate_config.get("stream", False)
+        stream_options = generate_config.get("stream_options", None)
+        include_usage = (
+            stream_options["include_usage"]
+            if isinstance(stream_options, dict)
+            else False
+        )
+        if stream:
+            chunk = self._generate_stream(generate_kwargs, input_tensor, include_usage)
+            return self._to_chat_completion_chunks(chunk)
+        else:
+            return self._generate(generate_kwargs, input_tensor)
+
+    def _generate(self, generate_kwargs, input_ids) -> ChatCompletion:
+        prompt_tokens = len(input_ids[0])
+        generation_output = self._model.generate(**generate_kwargs)
+        completion_tokens = len(generation_output[0])
+        response = self._tokenizer.decode(
+            generation_output[0][input_ids.shape[1] :], skip_special_tokens=True
+        )
+        return generate_chat_completion(
+            self.model_uid,
+            response,
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens,
+        )
+
+    def _generate_stream(self, generate_kwargs, input_ids, include_usage):
+        from threading import Thread
+
+        from transformers import TextIteratorStreamer
+
+        # Initialize the streamer
+        streamer = TextIteratorStreamer(
+            self._tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=10
+        )
+        # Define the generation configuration
+        generate_kwargs["streamer"] = streamer
+        # Start the model chat in a separate thread
+        thread = Thread(
+            target=self._model.generate,
+            kwargs=generate_kwargs,
+        )
+        thread.start()
+
+        completion_id = str(uuid.uuid1())
+        prompt_tokens = len(input_ids[0])
+        total_tokens, completion_tokens = 0, 0
+        # Loop through the streamer to get the new text as it is generated
+        for i, new_text in enumerate(streamer):
+            completion_tokens = max(completion_tokens, len(streamer.token_cache))
+            total_tokens = prompt_tokens + completion_tokens
+            yield generate_completion_chunk(
+                chunk_text=new_text,
+                finish_reason=None,
+                chunk_id=completion_id,
+                model_uid=self.model_uid,
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=total_tokens,
+            )
+        yield generate_completion_chunk(
+            chunk_text=None,
+            finish_reason="stop",
+            chunk_id=completion_id,
+            model_uid=self.model_uid,
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=total_tokens,
+            has_choice=True,
+            has_content=False,
+        )
+
+        if include_usage:
+            yield generate_completion_chunk(
+                chunk_text=None,
+                finish_reason=None,
+                chunk_id=completion_id,
+                model_uid=self.model_uid,
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=total_tokens,
+                has_choice=False,
+                has_content=False,
+            )
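
The streaming branch of DeepSeekV2PytorchChatModel.chat above is driven entirely by keys read out of generate_config. An illustrative config for that path (values are arbitrary and not defaults taken from the diff):

generate_config = {
    "temperature": 0.7,
    "top_p": 0.9,
    "max_tokens": 256,
    "stream": True,
    "stream_options": {"include_usage": True},  # request a trailing usage chunk
}

With stream set, chat() returns an iterator of chat completion chunks; include_usage adds one final chunk that carries only the token counts (has_choice=False, has_content=False).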
--- /dev/null
+++ b/xinference/model/llm/transformers/qwen2_audio.py
@@ -0,0 +1,168 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import uuid
+from io import BytesIO
+from typing import Dict, Iterator, List, Optional, Union
+from urllib.request import urlopen
+
+import numpy as np
+
+from ....model.utils import select_device
+from ....types import ChatCompletion, ChatCompletionChunk, CompletionChunk
+from ..llm_family import LLMFamilyV1, LLMSpecV1
+from ..utils import generate_chat_completion, generate_completion_chunk
+from .core import PytorchChatModel, PytorchGenerateConfig
+
+logger = logging.getLogger(__name__)
+
+
+class Qwen2AudioChatModel(PytorchChatModel):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._processor = None
+        self._model = None
+        self._device = None
+
+    @classmethod
+    def match(
+        cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
+    ) -> bool:
+        llm_family = model_family.model_family or model_family.model_name
+        if "qwen2-audio".lower() in llm_family.lower():
+            return True
+        return False
+
+    def load(self):
+        from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration
+
+        device = self._pytorch_model_config.get("device", "auto")
+        device = select_device(device)
+        self._device = device
+        # for multiple GPU, set back to auto to make multiple devices work
+        device = "auto" if device == "cuda" else device
+
+        self._processor = AutoProcessor.from_pretrained(
+            self.model_path,
+            device_map=device,
+            # trust_remote_code=True,
+            code_revision=self.model_spec.model_revision,
+        )
+        self._model = Qwen2AudioForConditionalGeneration.from_pretrained(
+            self.model_path,
+            device_map=device,
+            # trust_remote_code=True,
+            revision=self.model_spec.model_revision,
+        )
+
+    def _transform_messages(
+        self,
+        messages: List[Dict],
+    ):
+        import librosa
+
+        text = self._processor.apply_chat_template(
+            messages, add_generation_prompt=True, tokenize=False
+        )
+        audios: List[np.ndarray] = []
+        for msg in messages:
+            content = msg["content"]
+            if isinstance(content, List):
+                for item in content:  # type: ignore
+                    if item.get("type") == "audio" and "audio_url" in item:
+                        audio = librosa.load(
+                            BytesIO(urlopen(item["audio_url"]).read()),
+                            sr=self._processor.feature_extractor.sampling_rate,
+                        )[0]
+                        audios.append(audio)
+
+        return text, audios
+
+    def chat(
+        self,
+        messages: List[Dict],
+        generate_config: Optional[PytorchGenerateConfig] = None,
+    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
+        text, audios = self._transform_messages(messages)
+        inputs = self._processor(
+            text=text, audios=audios, return_tensors="pt", padding=True
+        )
+        inputs.input_ids = inputs.input_ids.to(self._device)
+        generate_config = generate_config if generate_config else {}
+        stream = generate_config.get("stream", False) if generate_config else False
+
+        if stream:
+            it = self._generate_stream(inputs, generate_config)
+            return self._to_chat_completion_chunks(it)
+        else:
+            c = self._generate(inputs, generate_config)
+            return c
+
+    def _generate(self, inputs, config: PytorchGenerateConfig = {}) -> ChatCompletion:
+        generate_ids = self._model.generate(
+            **inputs,
+            max_length=config.get("max_tokens", 512),
+        )
+        generate_ids = generate_ids[:, inputs.input_ids.size(1) :]
+        response = self._processor.batch_decode(
+            generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
+        )[0]
+        return generate_chat_completion(self.model_uid, response)
+
+    def _generate_stream(
+        self, inputs, config: PytorchGenerateConfig = {}
+    ) -> Iterator[CompletionChunk]:
+        from threading import Thread
+
+        from transformers import TextIteratorStreamer
+
+        tokenizer = self._processor.tokenizer
+        streamer = TextIteratorStreamer(
+            tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True
+        )
+
+        gen_kwargs = {
+            "max_new_tokens": config.get("max_tokens", 512),
+            "streamer": streamer,
+            **inputs,
+        }
+
+        thread = Thread(target=self._model.generate, kwargs=gen_kwargs)
+        thread.start()
+
+        completion_id = str(uuid.uuid1())
+        for new_text in streamer:
+            yield generate_completion_chunk(
+                chunk_text=new_text,
+                finish_reason=None,
+                chunk_id=completion_id,
+                model_uid=self.model_uid,
+                prompt_tokens=-1,
+                completion_tokens=-1,
+                total_tokens=-1,
+                has_choice=True,
+                has_content=True,
+            )
+
+        yield generate_completion_chunk(
+            chunk_text=None,
+            finish_reason="stop",
+            chunk_id=completion_id,
+            model_uid=self.model_uid,
+            prompt_tokens=-1,
+            completion_tokens=-1,
+            total_tokens=-1,
+            has_choice=True,
+            has_content=False,
+        )
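
Qwen2AudioChatModel._transform_messages above pulls audio out of OpenAI-style content parts: any item with type "audio" and an audio_url key is fetched with urlopen and decoded by librosa at the processor's sampling rate, so remote URLs must be reachable from the server. An illustrative message payload it would accept (the URL is a placeholder, not taken from the diff):

messages = [
    {
        "role": "user",
        "content": [
            {"type": "audio", "audio_url": "https://example.com/sample.wav"},
            {"type": "text", "text": "What is being said in this clip?"},
        ],
    }
]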
--- a/xinference/model/llm/transformers/qwen2_vl.py
+++ b/xinference/model/llm/transformers/qwen2_vl.py
@@ -11,7 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import importlib.util
 import logging
+import sys
 import uuid
 from typing import Iterator, List, Optional, Union

@@ -59,9 +61,19 @@ class Qwen2VLChatModel(PytorchChatModel):
             self.model_path, trust_remote_code=True
         )
         self._tokenizer = self._processor.tokenizer
-        self._model = Qwen2VLForConditionalGeneration.from_pretrained(
-            self.model_path, device_map=device, trust_remote_code=True
-        ).eval()
+        flash_attn_installed = importlib.util.find_spec("flash_attn") is not None
+        if flash_attn_installed:
+            self._model = Qwen2VLForConditionalGeneration.from_pretrained(
+                self.model_path,
+                torch_dtype="bfloat16",
+                device_map=device,
+                attn_implementation="flash_attention_2",
+                trust_remote_code=True,
+            ).eval()
+        else:
+            self._model = Qwen2VLForConditionalGeneration.from_pretrained(
+                self.model_path, device_map=device, trust_remote_code=True
+            ).eval()

     def _transform_messages(
         self,
@@ -177,8 +189,18 @@ class Qwen2VLChatModel(PytorchChatModel):
             "streamer": streamer,
             **inputs,
         }
-
-        thread = Thread(target=self._model.generate, kwargs=gen_kwargs)
+        error = None
+
+        def model_generate():
+            try:
+                return self._model.generate(**gen_kwargs)
+            except Exception:
+                nonlocal error
+                error = sys.exc_info()
+                streamer.end()
+                raise
+
+        thread = Thread(target=model_generate)
         thread.start()

         completion_id = str(uuid.uuid1())
@@ -195,6 +217,10 @@ class Qwen2VLChatModel(PytorchChatModel):
                 has_content=True,
             )

+        if error:
+            _, err, tb = error  # type: ignore
+            raise err.with_traceback(tb)
+
         yield generate_completion_chunk(
             chunk_text=None,
             finish_reason="stop",