xinference 0.16.3__py3-none-any.whl → 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic.
- xinference/_compat.py +22 -2
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +148 -12
- xinference/client/restful/restful_client.py +47 -2
- xinference/constants.py +1 -0
- xinference/core/model.py +45 -15
- xinference/core/supervisor.py +8 -2
- xinference/core/utils.py +67 -2
- xinference/model/audio/__init__.py +12 -0
- xinference/model/audio/core.py +21 -4
- xinference/model/audio/fish_speech.py +70 -35
- xinference/model/audio/model_spec.json +81 -1
- xinference/model/audio/whisper_mlx.py +208 -0
- xinference/model/embedding/core.py +259 -4
- xinference/model/embedding/model_spec.json +1 -1
- xinference/model/embedding/model_spec_modelscope.json +1 -1
- xinference/model/image/stable_diffusion/core.py +5 -2
- xinference/model/llm/__init__.py +2 -0
- xinference/model/llm/llm_family.json +485 -6
- xinference/model/llm/llm_family_modelscope.json +519 -0
- xinference/model/llm/mlx/core.py +45 -3
- xinference/model/llm/sglang/core.py +1 -0
- xinference/model/llm/transformers/core.py +1 -0
- xinference/model/llm/transformers/glm_edge_v.py +230 -0
- xinference/model/llm/utils.py +19 -0
- xinference/model/llm/vllm/core.py +84 -2
- xinference/model/rerank/core.py +11 -4
- xinference/thirdparty/fish_speech/fish_speech/conversation.py +254 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +2 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +2 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +2 -2
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ko_KR.json +123 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +2 -1
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +76 -11
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +9 -9
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +1 -1
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +32 -1
- xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +2 -1
- xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +22 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +1 -1
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1 -1
- xinference/thirdparty/fish_speech/tools/api.py +578 -75
- xinference/thirdparty/fish_speech/tools/e2e_webui.py +232 -0
- xinference/thirdparty/fish_speech/tools/fish_e2e.py +298 -0
- xinference/thirdparty/fish_speech/tools/llama/generate.py +393 -9
- xinference/thirdparty/fish_speech/tools/msgpack_api.py +90 -29
- xinference/thirdparty/fish_speech/tools/post_api.py +37 -15
- xinference/thirdparty/fish_speech/tools/schema.py +187 -0
- xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +7 -1
- xinference/thirdparty/fish_speech/tools/vqgan/inference.py +2 -3
- xinference/thirdparty/fish_speech/tools/webui.py +138 -75
- xinference/types.py +2 -1
- {xinference-0.16.3.dist-info → xinference-1.0.1.dist-info}/METADATA +30 -6
- {xinference-0.16.3.dist-info → xinference-1.0.1.dist-info}/RECORD +58 -63
- {xinference-0.16.3.dist-info → xinference-1.0.1.dist-info}/WHEEL +1 -1
- xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/commons.py +0 -35
- xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
- {xinference-0.16.3.dist-info → xinference-1.0.1.dist-info}/LICENSE +0 -0
- {xinference-0.16.3.dist-info → xinference-1.0.1.dist-info}/entry_points.txt +0 -0
- {xinference-0.16.3.dist-info → xinference-1.0.1.dist-info}/top_level.txt +0 -0
xinference/model/llm/transformers/glm_edge_v.py
ADDED
@@ -0,0 +1,230 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import uuid
+from concurrent.futures import ThreadPoolExecutor
+from threading import Thread
+from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
+
+import torch
+
+from ....types import ChatCompletion, ChatCompletionChunk, CompletionChunk
+from ...utils import select_device
+from ..llm_family import LLMFamilyV1, LLMSpecV1
+from ..utils import (
+    _decode_image_without_rgb,
+    generate_chat_completion,
+    generate_completion_chunk,
+)
+from .core import PytorchChatModel, PytorchGenerateConfig
+from .utils import cache_clean
+
+logger = logging.getLogger(__name__)
+
+
+class GlmEdgeVModel(PytorchChatModel):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._device = None
+        self._tokenizer = None
+        self._model = None
+        self._processor = None
+
+    @classmethod
+    def match(
+        cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
+    ) -> bool:
+        family = model_family.model_family or model_family.model_name
+        if "glm-edge-v" in family.lower():
+            return True
+        return False
+
+    def load(self):
+        from transformers import AutoImageProcessor, AutoModelForCausalLM, AutoTokenizer
+
+        device = self._pytorch_model_config.get("device", "auto")
+        self._device = select_device(device)
+
+        kwargs = {"device_map": self._device}
+        quantization = self.quantization
+
+        # referenced from PytorchModel.load
+        if quantization != "none":
+            if self._device == "cuda" and self._is_linux():
+                kwargs["device_map"] = "auto"
+                if quantization == "4-bit":
+                    kwargs["load_in_4bit"] = True
+                elif quantization == "8-bit":
+                    kwargs["load_in_8bit"] = True
+                else:
+                    raise ValueError(
+                        f"Quantization {quantization} is not supported in temporary"
+                    )
+            else:
+                if quantization != "8-bit":
+                    raise ValueError(
+                        f"Only 8-bit quantization is supported if it is not linux system or cuda device"
+                    )
+
+        processor = AutoImageProcessor.from_pretrained(
+            self.model_path, trust_remote_code=True
+        )
+        self._processor = processor
+
+        model = AutoModelForCausalLM.from_pretrained(
+            self.model_path,
+            trust_remote_code=True,
+            torch_dtype=torch.bfloat16,
+            device_map="auto",
+        )
+
+        self._model = model
+
+        tokenizer = AutoTokenizer.from_pretrained(
+            self.model_path, trust_remote_code=True
+        )
+        self._tokenizer = tokenizer
+
+    @staticmethod
+    def _get_processed_msgs(
+        messages: List[Dict],
+    ) -> Tuple[List[Dict[str, Any]], List[Any]]:
+        res = []
+        img = []
+        for message in messages:
+            role = message["role"]
+            content = message["content"]
+            if isinstance(content, str):
+                res.append({"role": role, "content": content})
+            else:
+                texts = []
+                image_urls = []
+                for c in content:
+                    c_type = c.get("type")
+                    if c_type == "text":
+                        texts.append(c["text"])
+                    else:
+                        assert (
+                            c_type == "image_url"
+                        ), "Please follow the image input of the OpenAI API."
+                        image_urls.append(c["image_url"]["url"])
+                if len(image_urls) > 1:
+                    raise RuntimeError("Only one image per message is supported")
+                image_futures = []
+                with ThreadPoolExecutor() as executor:
+                    for image_url in image_urls:
+                        fut = executor.submit(_decode_image_without_rgb, image_url)
+                        image_futures.append(fut)
+                images = [fut.result() for fut in image_futures]
+                assert len(images) <= 1
+                text = " ".join(texts)
+                img.extend(images)
+                if images:
+                    res.append(
+                        {
+                            "role": role,
+                            "content": [
+                                {"type": "image"},
+                                {"type": "text", "text": text},
+                            ],
+                        }
+                    )
+                else:
+                    res.append({"role": role, "content": text})
+        return res, img
+
+    @cache_clean
+    def chat(
+        self,
+        messages: List[Dict],
+        generate_config: Optional[PytorchGenerateConfig] = None,
+    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
+        from transformers import TextIteratorStreamer
+
+        if not generate_config:
+            generate_config = {}
+
+        stream = generate_config.get("stream", False)
+        msgs, imgs = self._get_processed_msgs(messages)
+
+        inputs = self._tokenizer.apply_chat_template(
+            msgs,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_tensors="pt",
+            return_dict=True,
+        )  # chat mode
+        inputs = inputs.to(self._model.device)
+
+        generate_kwargs = {
+            **inputs,
+        }
+        if len(imgs) > 0:
+            generate_kwargs["pixel_values"] = torch.tensor(
+                self._processor(imgs[-1]).pixel_values
+            ).to(self._model.device)
+        stop_str = "<|endoftext|>"
+
+        if stream:
+            streamer = TextIteratorStreamer(
+                tokenizer=self._tokenizer,
+                timeout=60,
+                skip_prompt=True,
+                skip_special_tokens=True,
+            )
+            generate_kwargs = {
+                **generate_kwargs,
+                "streamer": streamer,
+            }
+            t = Thread(target=self._model.generate, kwargs=generate_kwargs)
+            t.start()
+
+            it = self.chat_stream(streamer, stop_str)
+            return self._to_chat_completion_chunks(it)
+        else:
+            with torch.no_grad():
+                outputs = self._model.generate(**generate_kwargs)
+            outputs = outputs[0][len(inputs["input_ids"][0]) :]
+            response = self._tokenizer.decode(outputs)
+            if response.endswith(stop_str):
+                response = response[: -len(stop_str)]
+            return generate_chat_completion(self.model_uid, response)
+
+    def chat_stream(self, streamer, stop_str) -> Iterator[CompletionChunk]:
+        completion_id = str(uuid.uuid1())
+        for new_text in streamer:
+            if not new_text.endswith(stop_str):
+                yield generate_completion_chunk(
+                    chunk_text=new_text,
+                    finish_reason=None,
+                    chunk_id=completion_id,
+                    model_uid=self.model_uid,
+                    prompt_tokens=-1,
+                    completion_tokens=-1,
+                    total_tokens=-1,
+                    has_choice=True,
+                    has_content=True,
+                )
+
+        yield generate_completion_chunk(
+            chunk_text=None,
+            finish_reason="stop",
+            chunk_id=completion_id,
+            model_uid=self.model_uid,
+            prompt_tokens=-1,
+            completion_tokens=-1,
+            total_tokens=-1,
+            has_choice=True,
+            has_content=False,
+        )
xinference/model/llm/utils.py
CHANGED
@@ -569,6 +569,25 @@ def _decode_image(_url):
             return Image.open(BytesIO(response.content)).convert("RGB")
 
 
+def _decode_image_without_rgb(_url):
+    if _url.startswith("data:"):
+        logging.info("Parse url by base64 decoder.")
+        # https://platform.openai.com/docs/guides/vision/uploading-base-64-encoded-images
+        # e.g. f"data:image/jpeg;base64,{base64_image}"
+        _type, data = _url.split(";")
+        _, ext = _type.split("/")
+        data = data[len("base64,") :]
+        data = base64.b64decode(data.encode("utf-8"))
+        return Image.open(BytesIO(data))
+    else:
+        try:
+            response = requests.get(_url)
+        except requests.exceptions.MissingSchema:
+            return Image.open(_url)
+        else:
+            return Image.open(BytesIO(response.content))
+
+
 @typing.no_type_check
 def generate_completion_chunk(
     chunk_text: Optional[str],
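
_decode_image_without_rgb mirrors the existing _decode_image helper but skips the .convert("RGB") call, so decoded images keep their original mode (e.g. RGBA or palette). A standalone sketch of the data-URL branch, written here for illustration rather than copied from xinference:

# Illustrative re-implementation of the data-URL decoding shown above.
import base64
from io import BytesIO

from PIL import Image


def decode_data_url(url: str) -> Image.Image:
    # e.g. "data:image/png;base64,iVBORw0KGgo..."
    header, payload = url.split(";", 1)  # "data:image/png", "base64,..."
    assert payload.startswith("base64,")
    raw = base64.b64decode(payload[len("base64,"):])
    return Image.open(BytesIO(raw))  # no RGB conversion, unlike _decode_image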
xinference/model/llm/vllm/core.py
CHANGED
@@ -69,6 +69,7 @@ class VLLMModelConfig(TypedDict, total=False):
     quantization: Optional[str]
     max_model_len: Optional[int]
     limit_mm_per_prompt: Optional[Dict[str, int]]
+    guided_decoding_backend: Optional[str]
 
 
 class VLLMGenerateConfig(TypedDict, total=False):
@@ -85,6 +86,14 @@ class VLLMGenerateConfig(TypedDict, total=False):
     stop: Optional[Union[str, List[str]]]
     stream: bool  # non-sampling param, should not be passed to the engine.
     stream_options: Optional[Union[dict, None]]
+    response_format: Optional[dict]
+    guided_json: Optional[Union[str, dict]]
+    guided_regex: Optional[str]
+    guided_choice: Optional[List[str]]
+    guided_grammar: Optional[str]
+    guided_json_object: Optional[bool]
+    guided_decoding_backend: Optional[str]
+    guided_whitespace_pattern: Optional[str]
 
 
 try:
@@ -144,6 +153,7 @@ if VLLM_INSTALLED and vllm.__version__ >= "0.3.0":
     VLLM_SUPPORTED_CHAT_MODELS.append("qwen2.5-instruct")
     VLLM_SUPPORTED_MODELS.append("qwen2.5-coder")
     VLLM_SUPPORTED_CHAT_MODELS.append("qwen2.5-coder-instruct")
+    VLLM_SUPPORTED_CHAT_MODELS.append("QwQ-32B-Preview")
 
 
 if VLLM_INSTALLED and vllm.__version__ >= "0.3.2":
@@ -314,6 +324,7 @@ class VLLMModel(LLM):
         model_config.setdefault("max_num_seqs", 256)
         model_config.setdefault("quantization", None)
         model_config.setdefault("max_model_len", None)
+        model_config.setdefault("guided_decoding_backend", "outlines")
 
         return model_config
 
@@ -325,6 +336,22 @@ class VLLMModel(LLM):
             generate_config = {}
 
         sanitized = VLLMGenerateConfig()
+
+        response_format = generate_config.pop("response_format", None)
+        guided_decoding_backend = generate_config.get("guided_decoding_backend", None)
+        guided_json_object = None
+        guided_json = None
+
+        if response_format is not None:
+            if response_format.get("type") == "json_object":
+                guided_json_object = True
+            elif response_format.get("type") == "json_schema":
+                json_schema = response_format.get("json_schema")
+                assert json_schema is not None
+                guided_json = json_schema.get("json_schema")
+                if guided_decoding_backend is None:
+                    guided_decoding_backend = "outlines"
+
         sanitized.setdefault("lora_name", generate_config.get("lora_name", None))
         sanitized.setdefault("n", generate_config.get("n", 1))
         sanitized.setdefault("best_of", generate_config.get("best_of", None))
@@ -346,6 +373,28 @@ class VLLMModel(LLM):
         sanitized.setdefault(
             "stream_options", generate_config.get("stream_options", None)
         )
+        sanitized.setdefault(
+            "guided_json", generate_config.get("guided_json", guided_json)
+        )
+        sanitized.setdefault("guided_regex", generate_config.get("guided_regex", None))
+        sanitized.setdefault(
+            "guided_choice", generate_config.get("guided_choice", None)
+        )
+        sanitized.setdefault(
+            "guided_grammar", generate_config.get("guided_grammar", None)
+        )
+        sanitized.setdefault(
+            "guided_whitespace_pattern",
+            generate_config.get("guided_whitespace_pattern", None),
+        )
+        sanitized.setdefault(
+            "guided_json_object",
+            generate_config.get("guided_json_object", guided_json_object),
+        )
+        sanitized.setdefault(
+            "guided_decoding_backend",
+            generate_config.get("guided_decoding_backend", guided_decoding_backend),
+        )
 
         return sanitized
 
@@ -483,13 +532,46 @@ class VLLMModel(LLM):
             if isinstance(stream_options, dict)
             else False
         )
-
+
+        if VLLM_INSTALLED and vllm.__version__ >= "0.6.3":
+            # guided decoding only available for vllm >= 0.6.3
+            from vllm.sampling_params import GuidedDecodingParams
+
+            guided_options = GuidedDecodingParams.from_optional(
+                json=sanitized_generate_config.pop("guided_json", None),
+                regex=sanitized_generate_config.pop("guided_regex", None),
+                choice=sanitized_generate_config.pop("guided_choice", None),
+                grammar=sanitized_generate_config.pop("guided_grammar", None),
+                json_object=sanitized_generate_config.pop("guided_json_object", None),
+                backend=sanitized_generate_config.pop("guided_decoding_backend", None),
+                whitespace_pattern=sanitized_generate_config.pop(
+                    "guided_whitespace_pattern", None
+                ),
+            )
+
+            sampling_params = SamplingParams(
+                guided_decoding=guided_options, **sanitized_generate_config
+            )
+        else:
+            # ignore generate configs
+            sanitized_generate_config.pop("guided_json", None)
+            sanitized_generate_config.pop("guided_regex", None)
+            sanitized_generate_config.pop("guided_choice", None)
+            sanitized_generate_config.pop("guided_grammar", None)
+            sanitized_generate_config.pop("guided_json_object", None)
+            sanitized_generate_config.pop("guided_decoding_backend", None)
+            sanitized_generate_config.pop("guided_whitespace_pattern", None)
+            sampling_params = SamplingParams(**sanitized_generate_config)
+
         if not request_id:
             request_id = str(uuid.uuid1())
 
         assert self._engine is not None
         results_generator = self._engine.generate(
-            prompt,
+            prompt,
+            sampling_params,
+            request_id,
+            lora_request,
         )
 
         async def stream_results() -> AsyncGenerator[CompletionChunk, None]:
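
Taken together, the vllm/core.py changes expose vLLM guided decoding through the generate config: response_format and the guided_* options are sanitized in _sanitize_generate_config and, on vllm >= 0.6.3, converted into GuidedDecodingParams on the SamplingParams. A hedged usage sketch with the xinference client; the endpoint and model UID are placeholders and assume a vLLM-backed chat model has already been launched:

from xinference.client import Client

client = Client("http://127.0.0.1:9997")        # assumed endpoint
model = client.get_model("my-vllm-chat-model")  # assumed model UID

# Ask for JSON-only output; on vllm >= 0.6.3 this maps to guided_json_object=True.
completion = model.chat(
    messages=[{"role": "user", "content": "Describe a cat as a JSON object."}],
    generate_config={
        "response_format": {"type": "json_object"},
        "guided_decoding_backend": "outlines",
        "max_tokens": 256,
    },
)
print(completion["choices"][0]["message"]["content"])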
xinference/model/rerank/core.py
CHANGED
@@ -179,6 +179,7 @@ class RerankModel:
         return rerank_type
 
     def load(self):
+        logger.info("Loading rerank model: %s", self._model_path)
         flash_attn_installed = importlib.util.find_spec("flash_attn") is not None
         if (
             self._auto_detect_type(self._model_path) != "normal"
@@ -189,6 +190,7 @@ class RerankModel:
                 "will force set `use_fp16` to True"
             )
             self._use_fp16 = True
+
         if self._model_spec.type == "normal":
             try:
                 import sentence_transformers
@@ -250,22 +252,27 @@ class RerankModel:
         **kwargs,
     ) -> Rerank:
         assert self._model is not None
-        if kwargs:
-            raise ValueError("rerank hasn't support extra parameter.")
         if max_chunks_per_doc is not None:
             raise ValueError("rerank hasn't support `max_chunks_per_doc` parameter.")
+        logger.info("Rerank with kwargs: %s, model: %s", kwargs, self._model)
         sentence_combinations = [[query, doc] for doc in documents]
         # reset n tokens
         self._model.model.n_tokens = 0
         if self._model_spec.type == "normal":
             similarity_scores = self._model.predict(
-                sentence_combinations,
+                sentence_combinations,
+                convert_to_numpy=False,
+                convert_to_tensor=True,
+                **kwargs,
             ).cpu()
             if similarity_scores.dtype == torch.bfloat16:
                 similarity_scores = similarity_scores.float()
         else:
             # Related issue: https://github.com/xorbitsai/inference/issues/1775
-            similarity_scores = self._model.compute_score(
+            similarity_scores = self._model.compute_score(
+                sentence_combinations, **kwargs
+            )
+
         if not isinstance(similarity_scores, Sequence):
             similarity_scores = [similarity_scores]
         elif (