xinference 1.6.0.post1__py3-none-any.whl → 1.7.0__py3-none-any.whl

This diff shows the changes between publicly available package versions as they appear in their public registries. It is provided for informational purposes only.

Potentially problematic release.


This version of xinference might be problematic.

Files changed (124)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +79 -2
  3. xinference/client/restful/restful_client.py +65 -3
  4. xinference/conftest.py +0 -7
  5. xinference/core/media_interface.py +132 -8
  6. xinference/core/model.py +44 -6
  7. xinference/core/scheduler.py +1 -10
  8. xinference/core/supervisor.py +8 -17
  9. xinference/core/worker.py +5 -27
  10. xinference/deploy/cmdline.py +6 -2
  11. xinference/model/audio/chattts.py +24 -39
  12. xinference/model/audio/cosyvoice.py +18 -30
  13. xinference/model/audio/funasr.py +42 -0
  14. xinference/model/audio/model_spec.json +71 -1
  15. xinference/model/audio/model_spec_modelscope.json +76 -2
  16. xinference/model/audio/utils.py +75 -0
  17. xinference/model/core.py +1 -0
  18. xinference/model/embedding/__init__.py +74 -18
  19. xinference/model/embedding/core.py +98 -589
  20. xinference/model/embedding/embed_family.py +133 -0
  21. xinference/{thirdparty/omnilmm/train → model/embedding/flag}/__init__.py +1 -1
  22. xinference/model/embedding/flag/core.py +282 -0
  23. xinference/model/embedding/model_spec.json +24 -0
  24. xinference/model/embedding/model_spec_modelscope.json +24 -0
  25. xinference/model/embedding/sentence_transformers/__init__.py +13 -0
  26. xinference/model/embedding/sentence_transformers/core.py +399 -0
  27. xinference/model/embedding/vllm/core.py +95 -0
  28. xinference/model/image/model_spec.json +30 -3
  29. xinference/model/image/model_spec_modelscope.json +41 -2
  30. xinference/model/image/stable_diffusion/core.py +144 -53
  31. xinference/model/llm/__init__.py +6 -54
  32. xinference/model/llm/core.py +19 -5
  33. xinference/model/llm/llama_cpp/core.py +59 -3
  34. xinference/model/llm/llama_cpp/memory.py +457 -0
  35. xinference/model/llm/llm_family.json +247 -402
  36. xinference/model/llm/llm_family.py +88 -16
  37. xinference/model/llm/llm_family_modelscope.json +260 -421
  38. xinference/model/llm/llm_family_openmind_hub.json +0 -34
  39. xinference/model/llm/sglang/core.py +8 -0
  40. xinference/model/llm/transformers/__init__.py +27 -6
  41. xinference/model/llm/transformers/chatglm.py +4 -2
  42. xinference/model/llm/transformers/core.py +49 -28
  43. xinference/model/llm/transformers/deepseek_v2.py +6 -49
  44. xinference/model/llm/transformers/gemma3.py +119 -164
  45. xinference/model/llm/transformers/multimodal/__init__.py +13 -0
  46. xinference/model/llm/transformers/{cogagent.py → multimodal/cogagent.py} +58 -95
  47. xinference/model/llm/transformers/multimodal/core.py +205 -0
  48. xinference/model/llm/transformers/{deepseek_vl2.py → multimodal/deepseek_vl2.py} +59 -120
  49. xinference/model/llm/transformers/multimodal/gemma3.py +117 -0
  50. xinference/model/llm/transformers/{glm4v.py → multimodal/glm4v.py} +57 -93
  51. xinference/model/llm/transformers/multimodal/intern_vl.py +412 -0
  52. xinference/model/llm/transformers/{minicpmv26.py → multimodal/minicpmv26.py} +55 -102
  53. xinference/model/llm/transformers/{ovis2.py → multimodal/ovis2.py} +114 -175
  54. xinference/model/llm/transformers/{qwen-omni.py → multimodal/qwen-omni.py} +82 -167
  55. xinference/model/llm/transformers/multimodal/qwen2_audio.py +131 -0
  56. xinference/model/llm/transformers/{qwen2_vl.py → multimodal/qwen2_vl.py} +224 -256
  57. xinference/model/llm/transformers/opt.py +4 -2
  58. xinference/model/llm/transformers/utils.py +6 -37
  59. xinference/model/llm/utils.py +11 -0
  60. xinference/model/llm/vllm/core.py +7 -0
  61. xinference/model/rerank/core.py +91 -3
  62. xinference/model/rerank/model_spec.json +24 -0
  63. xinference/model/rerank/model_spec_modelscope.json +24 -0
  64. xinference/model/rerank/utils.py +20 -2
  65. xinference/model/utils.py +38 -1
  66. xinference/model/video/diffusers.py +65 -3
  67. xinference/model/video/model_spec.json +31 -4
  68. xinference/model/video/model_spec_modelscope.json +32 -4
  69. xinference/web/ui/build/asset-manifest.json +6 -6
  70. xinference/web/ui/build/index.html +1 -1
  71. xinference/web/ui/build/static/css/main.013f296b.css +2 -0
  72. xinference/web/ui/build/static/css/main.013f296b.css.map +1 -0
  73. xinference/web/ui/build/static/js/main.8a9e3ba0.js +3 -0
  74. xinference/web/ui/build/static/js/main.8a9e3ba0.js.map +1 -0
  75. xinference/web/ui/node_modules/.cache/babel-loader/34cfbfb7836e136ba3261cfd411cc554bf99ba24b35dcceebeaa4f008cb3c9dc.json +1 -0
  76. xinference/web/ui/node_modules/.cache/babel-loader/55b9fb40b57fa926e8f05f31c2f96467e76e5ad62f033dca97c03f9e8c4eb4fe.json +1 -0
  77. xinference/web/ui/node_modules/.cache/babel-loader/567e49df411efb24425d289bb484758cb57067ca54f8b5c67fe4505f698deb96.json +1 -0
  78. xinference/web/ui/node_modules/.cache/babel-loader/6595880facebca7ceace6f17cf21c3a5a9219a2f52fb0ba9f3cf1131eddbcf6b.json +1 -0
  79. xinference/web/ui/node_modules/.cache/babel-loader/aa998bc2d9c11853add6b8a2e08f50327f56d8824ccaaec92d6dde1b305f0d85.json +1 -0
  80. xinference/web/ui/node_modules/.cache/babel-loader/c748246b1d7bcebc16153be69f37e955bb2145526c47dd425aeeff70d3004dbc.json +1 -0
  81. xinference/web/ui/node_modules/.cache/babel-loader/e31234e95d60a5a7883fbcd70de2475dc1c88c90705df1a530abb68f86f80a51.json +1 -0
  82. xinference/web/ui/src/locales/en.json +21 -8
  83. xinference/web/ui/src/locales/ja.json +224 -0
  84. xinference/web/ui/src/locales/ko.json +224 -0
  85. xinference/web/ui/src/locales/zh.json +21 -8
  86. {xinference-1.6.0.post1.dist-info → xinference-1.7.0.dist-info}/METADATA +14 -11
  87. {xinference-1.6.0.post1.dist-info → xinference-1.7.0.dist-info}/RECORD +93 -100
  88. {xinference-1.6.0.post1.dist-info → xinference-1.7.0.dist-info}/WHEEL +1 -1
  89. xinference/model/llm/transformers/cogvlm2.py +0 -442
  90. xinference/model/llm/transformers/cogvlm2_video.py +0 -333
  91. xinference/model/llm/transformers/deepseek_vl.py +0 -280
  92. xinference/model/llm/transformers/glm_edge_v.py +0 -213
  93. xinference/model/llm/transformers/intern_vl.py +0 -526
  94. xinference/model/llm/transformers/internlm2.py +0 -94
  95. xinference/model/llm/transformers/minicpmv25.py +0 -193
  96. xinference/model/llm/transformers/omnilmm.py +0 -132
  97. xinference/model/llm/transformers/qwen2_audio.py +0 -179
  98. xinference/model/llm/transformers/qwen_vl.py +0 -360
  99. xinference/thirdparty/omnilmm/LICENSE +0 -201
  100. xinference/thirdparty/omnilmm/chat.py +0 -218
  101. xinference/thirdparty/omnilmm/constants.py +0 -4
  102. xinference/thirdparty/omnilmm/conversation.py +0 -332
  103. xinference/thirdparty/omnilmm/model/__init__.py +0 -1
  104. xinference/thirdparty/omnilmm/model/omnilmm.py +0 -595
  105. xinference/thirdparty/omnilmm/model/resampler.py +0 -166
  106. xinference/thirdparty/omnilmm/model/utils.py +0 -578
  107. xinference/thirdparty/omnilmm/train/train_utils.py +0 -150
  108. xinference/thirdparty/omnilmm/utils.py +0 -134
  109. xinference/web/ui/build/static/css/main.337afe76.css +0 -2
  110. xinference/web/ui/build/static/css/main.337afe76.css.map +0 -1
  111. xinference/web/ui/build/static/js/main.ae579a97.js +0 -3
  112. xinference/web/ui/build/static/js/main.ae579a97.js.map +0 -1
  113. xinference/web/ui/node_modules/.cache/babel-loader/12e02ee790dbf57ead09a241a93bb5f893393aa36628ca741d44390e836a103f.json +0 -1
  114. xinference/web/ui/node_modules/.cache/babel-loader/2fdc61dcb6a9d1fbcb44be592d0e87d8c3f21297a7327559ef5345665f8343f7.json +0 -1
  115. xinference/web/ui/node_modules/.cache/babel-loader/3d596a3e8dd6430d7ce81d164e32c31f8d47cfa5f725c328a298754d78563e14.json +0 -1
  116. xinference/web/ui/node_modules/.cache/babel-loader/5c08e2cd07809ed3e41486b16652253404cbb63a3ff8d0366ee50f57e2413cea.json +0 -1
  117. xinference/web/ui/node_modules/.cache/babel-loader/8472e58a31720892d534f3febda31f746b25ec4aa60787eef34217b074e67965.json +0 -1
  118. xinference/web/ui/node_modules/.cache/babel-loader/dc249829767b8abcbc3677e0b07b6d3ecbfdfe6d08cfe23a665eb33373a9aa9d.json +0 -1
  119. xinference/web/ui/node_modules/.cache/babel-loader/f91af913d7f91c410719ab13136aaed3aaf0f8dda06652f25c42cb5231587398.json +0 -1
  120. /xinference/{thirdparty/omnilmm → model/embedding/vllm}/__init__.py +0 -0
  121. /xinference/web/ui/build/static/js/{main.ae579a97.js.LICENSE.txt → main.8a9e3ba0.js.LICENSE.txt} +0 -0
  122. {xinference-1.6.0.post1.dist-info → xinference-1.7.0.dist-info}/entry_points.txt +0 -0
  123. {xinference-1.6.0.post1.dist-info → xinference-1.7.0.dist-info}/licenses/LICENSE +0 -0
  124. {xinference-1.6.0.post1.dist-info → xinference-1.7.0.dist-info}/top_level.txt +0 -0
xinference/model/llm/transformers/minicpmv25.py
@@ -1,193 +0,0 @@
- # Copyright 2022-2023 XProbe Inc.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import json
- import logging
- import uuid
- from concurrent.futures import ThreadPoolExecutor
- from typing import Dict, Iterator, List, Optional, Union
-
- import torch
-
- from ....types import ChatCompletion, ChatCompletionChunk, CompletionChunk
- from ...utils import select_device
- from ..llm_family import LLMFamilyV1, LLMSpecV1
- from ..utils import (
-     _decode_image,
-     generate_chat_completion,
-     generate_completion_chunk,
-     parse_messages,
- )
- from .core import PytorchChatModel, PytorchGenerateConfig
- from .utils import cache_clean
-
- logger = logging.getLogger(__name__)
-
-
- class MiniCPMV25Model(PytorchChatModel):
-     def __init__(self, *args, **kwargs):
-         super().__init__(*args, **kwargs)
-         self._device = None
-         self._tokenizer = None
-         self._model = None
-
-     @classmethod
-     def match_json(
-         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
-     ) -> bool:
-         family = model_family.model_family or model_family.model_name
-         if "MiniCPM-Llama3-V-2_5".lower() in family.lower():
-             return True
-         return False
-
-     def _get_model_class(self):
-         from transformers import AutoModel
-
-         return AutoModel
-
-     def load(self):
-         from transformers import AutoModel, AutoTokenizer
-         from transformers.generation import GenerationConfig
-
-         device = self._pytorch_model_config.get("device", "auto")
-         self._device = select_device(device)
-         self._device = "auto" if self._device == "cuda" else self._device
-
-         if "int4" in self.model_path and device == "mps":
-             logger.error(
-                 "Error: running int4 model with bitsandbytes on Mac is not supported right now."
-             )
-             exit()
-
-         if self._check_tensorizer_integrity():
-             self._model, self._tokenizer = self._load_tensorizer()
-             return
-
-         if "int4" in self.model_path:
-             model = AutoModel.from_pretrained(self.model_path, trust_remote_code=True)
-         else:
-             kwargs = self.apply_bnb_quantization()
-             model = AutoModel.from_pretrained(
-                 self.model_path,
-                 trust_remote_code=True,
-                 torch_dtype=torch.float16,
-                 device_map=self._device,
-                 **kwargs
-             )
-         tokenizer = AutoTokenizer.from_pretrained(
-             self.model_path, trust_remote_code=True
-         )
-         self._model = model.eval()
-         self._tokenizer = tokenizer
-
-         # Specify hyperparameters for generation
-         self._model.generation_config = GenerationConfig.from_pretrained(
-             self.model_path,
-             trust_remote_code=True,
-         )
-         self._save_tensorizer()
-
-     def _message_content_to_chat(self, content):
-         if not isinstance(content, str):
-             texts = []
-             image_urls = []
-             for c in content:
-                 c_type = c.get("type")
-                 if c_type == "text":
-                     texts.append(c["text"])
-                 elif c_type == "image_url":
-                     image_urls.append(c["image_url"]["url"])
-             image_futures = []
-             with ThreadPoolExecutor() as executor:
-                 for image_url in image_urls:
-                     fut = executor.submit(_decode_image, image_url)
-                     image_futures.append(fut)
-             images = [fut.result() for fut in image_futures]
-             text = " ".join(texts)
-             if len(images) == 0:
-                 return text, []
-             elif len(images) == 1:
-                 return text, images
-             else:
-                 raise RuntimeError("Only one image per message is supported")
-         return content, []
-
-     @cache_clean
-     def chat(
-         self,
-         messages: List[Dict],
-         generate_config: Optional[PytorchGenerateConfig] = None,
-     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
-         stream = generate_config.get("stream", False) if generate_config else False
-         prompt, _, chat_history = parse_messages(messages)
-         content, images_chat = self._message_content_to_chat(prompt)
-
-         msgs = []
-         query_to_response: List[Dict] = []
-         images_history = []
-         for h in chat_history or []:
-             role = h["role"]
-             content_h, images_tmp = self._message_content_to_chat(h["content"])
-             if images_tmp != []:
-                 images_history = images_tmp
-             if len(query_to_response) == 0 and role == "user":
-                 query_to_response.append({"role": "user", "content": content_h})
-             if len(query_to_response) == 1 and role == "assistant":
-                 query_to_response.append({"role": "assistant", "content": content_h})
-             if len(query_to_response) == 2:
-                 msgs.extend(query_to_response)
-                 query_to_response = []
-         image = None
-         if len(images_chat) > 0:
-             image = images_chat[0]
-         elif len(images_history) > 0:
-             image = images_history[0]
-         msgs.append({"role": "user", "content": content})
-
-         chat = self._model.chat(
-             image=image,
-             msgs=json.dumps(msgs, ensure_ascii=True),
-             tokenizer=self._tokenizer,
-             sampling=True,
-             **generate_config
-         )
-         if stream:
-             it = self.chat_stream(chat)
-             return self._to_chat_completion_chunks(it)
-         else:
-             return generate_chat_completion(self.model_uid, chat)
-
-     def chat_stream(self, chat) -> Iterator[CompletionChunk]:
-         completion_id = str(uuid.uuid1())
-         for new_text in chat:
-             yield generate_completion_chunk(
-                 chunk_text=new_text,
-                 finish_reason=None,
-                 chunk_id=completion_id,
-                 model_uid=self.model_uid,
-                 prompt_tokens=-1,
-                 completion_tokens=-1,
-                 total_tokens=-1,
-             )
-
-         yield generate_completion_chunk(
-             chunk_text=None,
-             finish_reason="stop",
-             chunk_id=completion_id,
-             model_uid=self.model_uid,
-             prompt_tokens=-1,
-             completion_tokens=-1,
-             total_tokens=-1,
-             has_choice=True,
-             has_content=False,
-         )
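For context, the removed MiniCPMV25Model parsed OpenAI-style multimodal messages: text parts were joined into a single string and at most one image_url part was accepted per message, otherwise _message_content_to_chat raised a RuntimeError. A minimal sketch of the message shape it handled (the URL is a placeholder):

    # Hypothetical chat payload matching what _message_content_to_chat expected:
    # text items are concatenated, and only a single image_url item is allowed.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is in this picture?"},
                {
                    "type": "image_url",
                    "image_url": {"url": "https://example.com/cat.jpg"},
                },
            ],
        }
    ]
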
xinference/model/llm/transformers/omnilmm.py
@@ -1,132 +0,0 @@
- # Copyright 2022-2023 XProbe Inc.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import base64
- import json
- import logging
- import operator
- import tempfile
- from typing import Dict, Iterator, List, Optional, Tuple, Union
-
- from ....thirdparty.omnilmm.chat import OmniLMMChat, img2base64
- from ....types import ChatCompletion, ChatCompletionChunk
- from ...utils import select_device
- from ..llm_family import LLMFamilyV1, LLMSpecV1
- from ..utils import generate_chat_completion, parse_messages
- from .core import PytorchChatModel, PytorchGenerateConfig
- from .utils import cache_clean
-
- logger = logging.getLogger(__name__)
-
-
- class OmniLMMModel(PytorchChatModel):
-     def __init__(self, *args, **kwargs):
-         super().__init__(*args, **kwargs)
-         self._model = None
-
-     @classmethod
-     def match_json(
-         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
-     ) -> bool:
-         llm_family = model_family.model_family or model_family.model_name
-         if "OmniLMM" in llm_family:
-             return True
-         return False
-
-     def load(self):
-         device = self._pytorch_model_config.get("device", "auto")
-         device = select_device(device)
-         self._model = OmniLMMChat(self.model_path, device_map=device)
-
-     def _message_content_to_OmniLMM(
-         self, content
-     ) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]:
-         def _ensure_url(_url):
-             if _url.startswith("data:"):
-                 logging.info("Parse url by base64 decoder.")
-                 # https://platform.openai.com/docs/guides/vision/uploading-base-64-encoded-images
-                 # e.g. f"data:image/jpeg;base64,{base64_image}"
-                 _type, data = _url.split(";")
-                 _, ext = _type.split("/")
-                 data = data[len("base64,") :]
-                 data = base64.b64decode(data.encode("utf-8"))
-
-                 with tempfile.NamedTemporaryFile(suffix=f".{ext}", delete=False) as f:
-                     f.write(data)
-                 logging.info("Dump base64 data to %s", f.name)
-                 return f.name
-             else:
-                 if len(_url) > 2048:
-                     raise Exception(f"Image url is too long, {len(_url)} > 2048.")
-                 return _url
-
-         if not isinstance(content, str):
-             images = []
-             other_content = []
-
-             for c in content:
-                 if c.get("type") == "image_url":
-                     images.append(
-                         {"image": _ensure_url(c["image_url"]["url"]), "type": "image"}
-                     )
-                 else:
-                     other_content.append(c)
-
-             images = sorted(images, key=operator.itemgetter("type"))
-             other_content = sorted(other_content, key=operator.itemgetter("type"))
-
-             return images, other_content
-         return [], [{"type": "text", "text": content}]
-
-     @cache_clean
-     def chat(
-         self,
-         messages: List[Dict],
-         generate_config: Optional[PytorchGenerateConfig] = None,
-     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
-         if generate_config and generate_config.get("stream"):
-             raise Exception(
-                 f"Chat with model {self.model_family.model_name} does not support stream."
-             )
-         prompt, _, chat_history = parse_messages(messages)
-         image_first, prompt = self._message_content_to_OmniLMM(prompt)
-
-         msgs = []
-         query_to_response: List[Dict] = []
-         image_another = []
-         for h in chat_history or []:
-             role = h["role"]
-             image_tmp, content = self._message_content_to_OmniLMM(h["content"])
-             if image_tmp != []:
-                 image_another = image_tmp
-             if len(query_to_response) == 0 and role == "user":
-                 query_to_response.append(
-                     {"role": "user", "content": content[0]["text"]}
-                 )
-             if len(query_to_response) == 1 and role == "assistant":
-                 query_to_response.append(
-                     {"role": "assistant", "content": content[0]["text"]}
-                 )
-             if len(query_to_response) == 2:
-                 msgs.extend(query_to_response)
-                 query_to_response = []
-         if image_first != []:
-             image = image_first
-         if image_another != []:
-             image = image_another
-         im_64 = img2base64(image[0]["image"])
-         msgs.append({"role": "user", "content": prompt[0]["text"]})
-         input = {"image": im_64, "question": json.dumps(msgs, ensure_ascii=True)}
-         answer = self._model.chat(input=input)
-
-         return generate_chat_completion(self.model_uid, answer)
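As a side note, the removed OmniLMM handler also accepted base64 data URLs for images; _ensure_url decoded them to a temporary file before chatting. A rough sketch of producing such a URL (the local file name is hypothetical):

    import base64

    # Build the "data:image/jpeg;base64,..." form that _ensure_url split on ";",
    # stripped of its "base64," prefix, decoded, and dumped to a NamedTemporaryFile.
    with open("cat.jpg", "rb") as f:
        image_b64 = base64.b64encode(f.read()).decode("utf-8")
    image_url = f"data:image/jpeg;base64,{image_b64}"
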
xinference/model/llm/transformers/qwen2_audio.py
@@ -1,179 +0,0 @@
- # Copyright 2022-2023 XProbe Inc.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import logging
- import uuid
- from io import BytesIO
- from typing import Iterator, List, Optional, Union
- from urllib.request import urlopen
-
- import numpy as np
-
- from ....model.utils import select_device
- from ....types import (
-     ChatCompletion,
-     ChatCompletionChunk,
-     ChatCompletionMessage,
-     CompletionChunk,
- )
- from ..llm_family import LLMFamilyV1, LLMSpecV1
- from ..utils import generate_chat_completion, generate_completion_chunk
- from .core import PytorchChatModel, PytorchGenerateConfig
- from .utils import cache_clean
-
- logger = logging.getLogger(__name__)
-
-
- class Qwen2AudioChatModel(PytorchChatModel):
-     def __init__(self, *args, **kwargs):
-         super().__init__(*args, **kwargs)
-         self._processor = None
-         self._model = None
-         self._device = None
-
-     @classmethod
-     def match_json(
-         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
-     ) -> bool:
-         llm_family = model_family.model_family or model_family.model_name
-         if "qwen2-audio".lower() in llm_family.lower():
-             return True
-         return False
-
-     def load(self):
-         from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration
-
-         device = self._pytorch_model_config.get("device", "auto")
-         device = select_device(device)
-         # for multiple GPU, set back to auto to make multiple devices work
-         device = "auto" if device == "cuda" else device
-         self._device = device
-         kwargs = self.apply_bnb_quantization()
-
-         self._processor = AutoProcessor.from_pretrained(
-             self.model_path,
-             device_map=device,
-             # trust_remote_code=True,
-             code_revision=self.model_spec.model_revision,
-         )
-         self._model = Qwen2AudioForConditionalGeneration.from_pretrained(
-             self.model_path,
-             device_map=device,
-             # trust_remote_code=True,
-             revision=self.model_spec.model_revision,
-             **kwargs,
-         )
-
-     def _transform_messages(
-         self,
-         messages: Union[List[ChatCompletionMessage], List[dict]],
-     ):
-         import librosa
-
-         text = self._processor.apply_chat_template(
-             messages, add_generation_prompt=True, tokenize=False
-         )
-         audios: List[np.ndarray] = []
-         for msg in messages:
-             content = msg["content"]
-             if isinstance(content, List):
-                 for item in content:  # type: ignore
-                     if item.get("type") == "audio" and "audio_url" in item:
-                         audio = librosa.load(
-                             BytesIO(urlopen(item["audio_url"]).read()),
-                             sr=self._processor.feature_extractor.sampling_rate,
-                         )[0]
-                         audios.append(audio)
-
-         return text, audios
-
-     @cache_clean
-     def chat(
-         self,
-         messages: List[ChatCompletionMessage],
-         generate_config: Optional[PytorchGenerateConfig] = None,
-     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
-         text, audios = self._transform_messages(messages)
-         inputs = self._processor(
-             text=text, audios=audios, return_tensors="pt", padding=True
-         )
-         # Make sure that the inputs and the model are on the same device.
-         inputs.data = {k: v.to(self._device) for k, v in inputs.data.items()}
-         inputs.input_ids = inputs.input_ids.to(self._device)
-         generate_config = generate_config if generate_config else {}
-         stream = generate_config.get("stream", False) if generate_config else False
-
-         if stream:
-             it = self._generate_stream(inputs, generate_config)
-             return self._to_chat_completion_chunks(it)
-         else:
-             c = self._generate(inputs, generate_config)
-             return c
-
-     def _generate(self, inputs, config: PytorchGenerateConfig = {}) -> ChatCompletion:
-         generate_ids = self._model.generate(
-             **inputs,
-             max_length=config.get("max_tokens", 512),
-         )
-         generate_ids = generate_ids[:, inputs.input_ids.size(1) :]
-         response = self._processor.batch_decode(
-             generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
-         )[0]
-         return generate_chat_completion(self.model_uid, response)
-
-     def _generate_stream(
-         self, inputs, config: PytorchGenerateConfig = {}
-     ) -> Iterator[CompletionChunk]:
-         from threading import Thread
-
-         from transformers import TextIteratorStreamer
-
-         tokenizer = self._processor.tokenizer
-         streamer = TextIteratorStreamer(
-             tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True
-         )
-
-         gen_kwargs = {
-             "max_new_tokens": config.get("max_tokens", 512),
-             "streamer": streamer,
-             **inputs,
-         }
-
-         thread = Thread(target=self._model.generate, kwargs=gen_kwargs)
-         thread.start()
-
-         completion_id = str(uuid.uuid1())
-         for new_text in streamer:
-             yield generate_completion_chunk(
-                 chunk_text=new_text,
-                 finish_reason=None,
-                 chunk_id=completion_id,
-                 model_uid=self.model_uid,
-                 prompt_tokens=-1,
-                 completion_tokens=-1,
-                 total_tokens=-1,
-                 has_choice=True,
-                 has_content=True,
-             )
-
-         yield generate_completion_chunk(
-             chunk_text=None,
-             finish_reason="stop",
-             chunk_id=completion_id,
-             model_uid=self.model_uid,
-             prompt_tokens=-1,
-             completion_tokens=-1,
-             total_tokens=-1,
-             has_choice=True,
-             has_content=False,
-         )
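For context, the removed Qwen2AudioChatModel pulled audio from "audio" content items: each audio_url was fetched with urlopen and resampled by librosa to the processor's sampling rate. A minimal sketch of the message shape _transform_messages handled (the URL is a placeholder):

    # Hypothetical chat payload: audio items carry an audio_url that the model
    # downloads and feeds to the processor alongside the chat-templated text.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "audio", "audio_url": "https://example.com/question.wav"},
                {"type": "text", "text": "What is said in this clip?"},
            ],
        }
    ]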