xinference 1.8.1rc1__py3-none-any.whl → 1.9.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic. Click here for more details.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +2 -1
- xinference/core/model.py +8 -4
- xinference/core/supervisor.py +2 -3
- xinference/core/worker.py +7 -5
- xinference/deploy/cmdline.py +2 -0
- xinference/deploy/local.py +5 -0
- xinference/deploy/test/test_cmdline.py +1 -1
- xinference/deploy/worker.py +6 -0
- xinference/model/audio/cosyvoice.py +0 -1
- xinference/model/audio/model_spec.json +44 -20
- xinference/model/core.py +3 -0
- xinference/model/embedding/flag/core.py +5 -0
- xinference/model/embedding/llama_cpp/core.py +22 -19
- xinference/model/embedding/sentence_transformers/core.py +18 -4
- xinference/model/embedding/vllm/core.py +36 -9
- xinference/model/image/cache_manager.py +56 -0
- xinference/model/image/core.py +9 -0
- xinference/model/image/model_spec.json +178 -1
- xinference/model/image/stable_diffusion/core.py +155 -23
- xinference/model/llm/cache_manager.py +17 -3
- xinference/model/llm/harmony.py +245 -0
- xinference/model/llm/llama_cpp/core.py +41 -40
- xinference/model/llm/llm_family.json +688 -11
- xinference/model/llm/llm_family.py +1 -1
- xinference/model/llm/sglang/core.py +108 -5
- xinference/model/llm/transformers/core.py +20 -18
- xinference/model/llm/transformers/gemma3.py +1 -1
- xinference/model/llm/transformers/gpt_oss.py +91 -0
- xinference/model/llm/transformers/multimodal/core.py +1 -1
- xinference/model/llm/transformers/multimodal/gemma3.py +1 -1
- xinference/model/llm/transformers/multimodal/glm4_1v.py +2 -2
- xinference/model/llm/transformers/multimodal/ovis2.py +1 -1
- xinference/model/llm/transformers/multimodal/qwen-omni.py +7 -8
- xinference/model/llm/transformers/multimodal/qwen2_vl.py +9 -6
- xinference/model/llm/transformers/utils.py +1 -33
- xinference/model/llm/utils.py +61 -7
- xinference/model/llm/vllm/core.py +44 -8
- xinference/model/rerank/__init__.py +66 -23
- xinference/model/rerank/cache_manager.py +35 -0
- xinference/model/rerank/core.py +87 -339
- xinference/model/rerank/custom.py +33 -8
- xinference/model/rerank/model_spec.json +251 -212
- xinference/model/rerank/rerank_family.py +137 -0
- xinference/model/rerank/sentence_transformers/__init__.py +13 -0
- xinference/model/rerank/sentence_transformers/core.py +337 -0
- xinference/model/rerank/vllm/__init__.py +13 -0
- xinference/model/rerank/vllm/core.py +156 -0
- xinference/model/utils.py +108 -0
- xinference/model/video/model_spec.json +95 -1
- xinference/thirdparty/cosyvoice/bin/export_jit.py +3 -4
- xinference/thirdparty/cosyvoice/bin/export_onnx.py +49 -126
- xinference/thirdparty/cosyvoice/bin/{inference.py → inference_deprecated.py} +1 -0
- xinference/thirdparty/cosyvoice/bin/train.py +23 -3
- xinference/thirdparty/cosyvoice/cli/cosyvoice.py +8 -4
- xinference/thirdparty/cosyvoice/cli/frontend.py +4 -4
- xinference/thirdparty/cosyvoice/cli/model.py +53 -75
- xinference/thirdparty/cosyvoice/dataset/dataset.py +5 -18
- xinference/thirdparty/cosyvoice/dataset/processor.py +24 -25
- xinference/thirdparty/cosyvoice/flow/decoder.py +24 -433
- xinference/thirdparty/cosyvoice/flow/flow.py +6 -14
- xinference/thirdparty/cosyvoice/flow/flow_matching.py +33 -145
- xinference/thirdparty/cosyvoice/hifigan/generator.py +169 -1
- xinference/thirdparty/cosyvoice/llm/llm.py +108 -17
- xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +14 -115
- xinference/thirdparty/cosyvoice/utils/common.py +20 -0
- xinference/thirdparty/cosyvoice/utils/executor.py +8 -4
- xinference/thirdparty/cosyvoice/utils/file_utils.py +45 -1
- xinference/thirdparty/cosyvoice/utils/losses.py +37 -0
- xinference/thirdparty/cosyvoice/utils/mask.py +35 -1
- xinference/thirdparty/cosyvoice/utils/train_utils.py +24 -6
- xinference/thirdparty/cosyvoice/vllm/cosyvoice2.py +103 -0
- xinference/types.py +2 -0
- xinference/ui/gradio/chat_interface.py +2 -0
- xinference/ui/gradio/media_interface.py +353 -7
- xinference/ui/web/ui/build/asset-manifest.json +3 -3
- xinference/ui/web/ui/build/index.html +1 -1
- xinference/ui/web/ui/build/static/js/main.1086c759.js +3 -0
- xinference/ui/web/ui/build/static/js/main.1086c759.js.map +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/28012da921a51f1082549956d3ae82acd769a754b22afda9acddd98a4daf9ea4.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/3c5758bd12fa334294b1de0ff6b1a4bac8d963c45472eab9dc3e530d82aa6b3f.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/475936ebe725eca62a6f52ce182c06a19b2cef4df9545a05ed0591ee0c539d43.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/aee5aaba26f2b1e816a3ea9efa68bad8b95695a3d80adcfd8dd57a7bb17ac71a.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +1 -0
- xinference/ui/web/ui/src/locales/en.json +2 -0
- xinference/ui/web/ui/src/locales/ja.json +2 -0
- xinference/ui/web/ui/src/locales/ko.json +2 -0
- xinference/ui/web/ui/src/locales/zh.json +2 -0
- {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/METADATA +15 -10
- {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/RECORD +98 -89
- xinference/ui/web/ui/build/static/js/main.b969199a.js +0 -3
- xinference/ui/web/ui/build/static/js/main.b969199a.js.map +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/1409a96b9f9f9f5de99a89ab0f738f6da62b449521b0a8d3e4efcf7f5c23534d.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/3d2a89f0eccc1f90fc5036c9a1d587c2120e6a6b128aae31d1db7d6bad52722b.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/43b889c3a8e2634092ade463d52481c7c5581c72ded8f23bc5f012ea0ef8cea5.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/5d47532fb42128280d87f57c8a0b02bc1930f7ef764aa7e90579247df18bba83.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/830882bb275468a969614824a9ab8983f874b4581f2eb625e9c66426cdc65e5b.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/8e5cb82c2ff3299c6a44563fe6b1c5515c9750613c51bb63abee0b1d70fc5019.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/9df08abcb5a7c1e48a4eb25c5d5f5d7253ea6854a4397e6d74d1fd75a14acda1.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/b99034986a06445701accc7a4914bb9320947435e8d4e15793392ca4f679316c.json +0 -1
- /xinference/ui/web/ui/build/static/js/{main.b969199a.js.LICENSE.txt → main.1086c759.js.LICENSE.txt} +0 -0
- {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/WHEEL +0 -0
- {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/entry_points.txt +0 -0
- {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/licenses/LICENSE +0 -0
- {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,337 @@
|
|
|
1
|
+
# Copyright 2022-2025 XProbe Inc.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
import gc
|
|
15
|
+
import importlib.util
|
|
16
|
+
import logging
|
|
17
|
+
import threading
|
|
18
|
+
import uuid
|
|
19
|
+
from typing import List, Optional, Sequence
|
|
20
|
+
|
|
21
|
+
import numpy as np
|
|
22
|
+
import torch
|
|
23
|
+
import torch.nn as nn
|
|
24
|
+
|
|
25
|
+
from ....device_utils import empty_cache
|
|
26
|
+
from ....types import Document, DocumentObj, Meta, Rerank, RerankTokens
|
|
27
|
+
from ...utils import is_flash_attn_available
|
|
28
|
+
from ..core import (
|
|
29
|
+
RERANK_EMPTY_CACHE_COUNT,
|
|
30
|
+
RerankModel,
|
|
31
|
+
RerankModelFamilyV2,
|
|
32
|
+
RerankSpecV1,
|
|
33
|
+
)
|
|
34
|
+
from ..utils import preprocess_sentence
|
|
35
|
+
|
|
36
|
+
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class _ModelWrapper(nn.Module):
|
|
40
|
+
def __init__(self, module: nn.Module):
|
|
41
|
+
super().__init__()
|
|
42
|
+
self.model = module
|
|
43
|
+
self._local_data = threading.local()
|
|
44
|
+
|
|
45
|
+
@property
|
|
46
|
+
def n_tokens(self):
|
|
47
|
+
return getattr(self._local_data, "n_tokens", 0)
|
|
48
|
+
|
|
49
|
+
@n_tokens.setter
|
|
50
|
+
def n_tokens(self, value):
|
|
51
|
+
self._local_data.n_tokens = value
|
|
52
|
+
|
|
53
|
+
def forward(self, **kwargs):
|
|
54
|
+
attention_mask = kwargs.get("attention_mask")
|
|
55
|
+
# when batching, the attention mask 1 means there is a token
|
|
56
|
+
# thus we just sum up it to get the total number of tokens
|
|
57
|
+
if attention_mask is not None:
|
|
58
|
+
self.n_tokens += attention_mask.sum().item()
|
|
59
|
+
return self.model(**kwargs)
|
|
60
|
+
|
|
61
|
+
def __getattr__(self, attr):
|
|
62
|
+
try:
|
|
63
|
+
return super().__getattr__(attr)
|
|
64
|
+
except AttributeError:
|
|
65
|
+
return getattr(self.model, attr)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class SentenceTransformerRerankModel(RerankModel):
    """Rerank model running on sentence-transformers, transformers or FlagEmbedding.

    ``load`` picks the backend from the model family:

    * ``type == "normal"`` and the name does not contain ``"qwen3"``:
      ``sentence_transformers.CrossEncoder``;
    * the name contains ``"qwen3"``: ``transformers.AutoModelForCausalLM``, scored
      from the next-token logits of ``"yes"`` versus ``"no"``;
    * otherwise (``"LLM-based"`` / ``"LLM-based layerwise"``): the matching
      FlagEmbedding reranker.

    After loading, ``self._model.model`` is wrapped in ``_ModelWrapper`` so that
    ``rerank`` can report how many tokens were processed.
    """

    # NOTE(review): _model_path, _device, _kwargs, _use_fp16 and _counter are
    # presumably initialised by RerankModel.__init__ -- confirm against ..core.

    @staticmethod
    def _missing_package_error(package: str) -> ImportError:
        """Build the ImportError raised when the optional ``package`` is missing."""
        error_message = f"Failed to import module '{package}'"
        installation_guide = [
            f"Please make sure '{package}' is installed. ",
            f"You can install it by `pip install {package}`\n",
        ]
        return ImportError(f"{error_message}\n\n{''.join(installation_guide)}")

    @staticmethod
    def _version_tuple(version: str) -> tuple:
        """Return the leading numeric components of a dotted version string.

        ``"3.1.0"`` -> ``(3, 1, 0)``; ``"5.0.0.dev0"`` -> ``(5, 0, 0)``;
        ``"3.1.0rc1"`` -> ``(3, 1)`` (so a pre-release sorts below the release).
        """
        parts = []
        for piece in version.split("."):
            if not piece.isdecimal():
                break
            parts.append(int(piece))
        return tuple(parts)

    @staticmethod
    def _format_qwen3_input(instruction: Optional[str], query: str, doc: str) -> str:
        """Render one Qwen3 reranker user turn; ``None`` selects the default instruction."""
        if instruction is None:
            instruction = "Given a web search query, retrieve relevant passages that answer the query"
        return "<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}".format(
            instruction=instruction, query=query, doc=doc
        )

    def load(self):
        """Instantiate the backend model for ``self.model_family``.

        Pops ``enable_flash_attn`` from ``self._kwargs`` (default:
        ``is_flash_attn_available()``); when it is truthy ``self._use_fp16`` is
        forced to True. The remaining ``self._kwargs`` are forwarded to the
        ``CrossEncoder`` constructor or to ``AutoModelForCausalLM.from_pretrained``;
        the FlagEmbedding path does not use them.

        Raises:
            ImportError: If the package needed by the selected backend is missing.
            ValueError: If sentence-transformers is older than 3.1.0.
            RuntimeError: If the family type is not supported by FlagEmbedding.
        """
        # TODO: Split FlagReranker and sentence_transformers into different model_engines like FlagRerankModel
        logger.info("Loading rerank model: %s", self._model_path)
        enable_flash_attn = self._kwargs.pop(
            "enable_flash_attn", is_flash_attn_available()
        )
        if enable_flash_attn:
            logger.warning(
                "flash_attn can only support fp16 and bf16, will force set `use_fp16` to True"
            )
            self._use_fp16 = True

        is_qwen3 = "qwen3" in self.model_family.model_name.lower()
        if self.model_family.type == "normal" and not is_qwen3:
            try:
                import sentence_transformers
                from sentence_transformers.cross_encoder import CrossEncoder
            except ImportError:
                raise self._missing_package_error("sentence-transformers")
            # Compare numerically: a plain string comparison would wrongly
            # reject versions such as "10.0.0".
            if self._version_tuple(sentence_transformers.__version__) < (3, 1, 0):
                raise ValueError(
                    "The sentence_transformers version must be at least 3.1.0. "
                    "Please upgrade your version via `pip install -U sentence_transformers` or refer to "
                    "https://github.com/UKPLab/sentence-transformers"
                )
            self._model = CrossEncoder(
                self._model_path,
                device=self._device,
                trust_remote_code=True,
                max_length=self.model_family.max_tokens,
                **self._kwargs,
            )
            if self._use_fp16:
                self._model.model.half()
        elif is_qwen3:
            self._load_qwen3(enable_flash_attn)
        else:
            try:
                if self.model_family.type == "LLM-based":
                    from FlagEmbedding import FlagLLMReranker as FlagReranker
                elif self.model_family.type == "LLM-based layerwise":
                    from FlagEmbedding import LayerWiseFlagLLMReranker as FlagReranker
                else:
                    raise RuntimeError(
                        f"Unsupported Rank model type: {self.model_family.type}"
                    )
            except ImportError:
                raise self._missing_package_error("FlagEmbedding")
            self._model = FlagReranker(self._model_path, use_fp16=self._use_fp16)

        # Wrap transformers model to record number of tokens (read back in rerank()).
        self._model.model = _ModelWrapper(self._model.model)

    def _load_qwen3(self, enable_flash_attn: bool) -> None:
        """Load a Qwen3 reranker with transformers and set ``process_inputs`` / ``compute_logits``."""
        # qwen3-reranker
        # now we use transformers
        # TODO: support engines for rerank models
        try:
            from transformers import AutoModelForCausalLM, AutoTokenizer
        except ImportError:
            raise self._missing_package_error("transformers")

        tokenizer = AutoTokenizer.from_pretrained(self._model_path, padding_side="left")
        model_kwargs = {"device_map": "auto"}
        if enable_flash_attn:
            model_kwargs["attn_implementation"] = "flash_attention_2"
            model_kwargs["torch_dtype"] = torch.float16
        # Caller-supplied kwargs override the defaults above.
        model_kwargs.update(self._kwargs)
        logger.debug("Loading qwen3 rerank with kwargs %s", model_kwargs)
        model = self._model = AutoModelForCausalLM.from_pretrained(
            self._model_path, **model_kwargs
        ).eval()
        max_length = self.model_family.max_tokens

        prefix = (
            "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query "
            'and the Instruct provided. Note that the answer can only be "yes" or "no".'
            "<|im_end|>\n<|im_start|>user\n"
        )
        suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
        prefix_tokens = tokenizer.encode(prefix, add_special_tokens=False)
        suffix_tokens = tokenizer.encode(suffix, add_special_tokens=False)

        def process_inputs(pairs):
            # Truncate each text so prefix + text + suffix fits in max_length,
            # wrap it with the chat prefix/suffix, then pad the batch (the
            # tokenizer was created with padding_side="left") and move it to
            # the model's device.
            inputs = tokenizer(
                pairs,
                padding=False,
                truncation="longest_first",
                return_attention_mask=False,
                max_length=max_length - len(prefix_tokens) - len(suffix_tokens),
            )
            for i, ele in enumerate(inputs["input_ids"]):
                inputs["input_ids"][i] = prefix_tokens + ele + suffix_tokens
            inputs = tokenizer.pad(
                inputs, padding=True, return_tensors="pt", max_length=max_length
            )
            for key in inputs:
                inputs[key] = inputs[key].to(model.device)
            return inputs

        token_false_id = tokenizer.convert_tokens_to_ids("no")
        token_true_id = tokenizer.convert_tokens_to_ids("yes")

        @torch.inference_mode()
        def compute_logits(inputs, **kwargs):
            # Score = softmax over the ("no", "yes") logits at the last
            # position, taking the "yes" probability.
            batch_scores = model(**inputs).logits[:, -1, :]
            true_vector = batch_scores[:, token_true_id]
            false_vector = batch_scores[:, token_false_id]
            batch_scores = torch.stack([false_vector, true_vector], dim=1)
            batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
            return batch_scores[:, 1].exp().tolist()

        self.process_inputs = process_inputs
        self.compute_logits = compute_logits

    def rerank(
        self,
        documents: List[str],
        query: str,
        top_n: Optional[int],
        max_chunks_per_doc: Optional[int],
        return_documents: Optional[bool],
        return_len: Optional[bool],
        **kwargs,
    ) -> Rerank:
        """Score ``documents`` against ``query`` and return them best-first.

        Args:
            documents: Candidate texts to score.
            query: Query text.
            top_n: If given, keep only the ``top_n`` highest-scoring documents.
            max_chunks_per_doc: Not supported; must be ``None``.
            return_documents: Include the document text in each result.
            return_len: Report the counted input tokens in ``meta.tokens``.
            **kwargs: ``instruction`` (optional) is used to build the query.
                Every other keyword is forwarded to ``CrossEncoder.predict`` or
                ``compute_score``; the Qwen3 path ignores them.

        Raises:
            ValueError: If ``max_chunks_per_doc`` is set.
        """
        assert self._model is not None
        if max_chunks_per_doc is not None:
            raise ValueError("rerank hasn't support `max_chunks_per_doc` parameter.")
        # DEBUG rather than INFO: the model repr is large and this runs per request.
        logger.debug("Rerank with kwargs: %s, model: %s", kwargs, self._model)

        # ``instruction`` is consumed here (query preprocessing / Qwen3 prompt)
        # and must not be forwarded to CrossEncoder.predict() or compute_score().
        instruction = kwargs.pop("instruction", None)
        is_qwen3 = "qwen3" in self.model_family.model_name.lower()

        pre_query = preprocess_sentence(
            query, instruction, self.model_family.model_name
        )
        sentence_combinations = [[pre_query, doc] for doc in documents]
        # reset n tokens
        self._model.model.n_tokens = 0
        if not documents:
            # Nothing to score: skip the backends instead of handing them an
            # empty batch.
            similarity_scores = []
        elif self.model_family.type == "normal" and not is_qwen3:
            logger.debug("Passing processed sentences: %s", sentence_combinations)
            similarity_scores = self._model.predict(
                sentence_combinations,
                convert_to_numpy=False,
                convert_to_tensor=True,
                **kwargs,
            ).cpu()
            # Presumably because numpy (used by argsort below) cannot represent
            # bfloat16 tensors.
            if similarity_scores.dtype == torch.bfloat16:
                similarity_scores = similarity_scores.float()
        elif is_qwen3:
            # reduce memory usage.
            micro_bs = 4
            similarity_scores = []
            for start in range(0, len(documents), micro_bs):
                pairs = [
                    self._format_qwen3_input(instruction, query, doc)
                    for doc in documents[start : start + micro_bs]
                ]
                # Tokenize the input texts
                inputs = self.process_inputs(pairs)
                similarity_scores.extend(self.compute_logits(inputs))
        else:
            # Related issue: https://github.com/xorbitsai/inference/issues/1775
            similarity_scores = self._model.compute_score(
                sentence_combinations, **kwargs
            )
            # Normalise to a flat list: wrap a non-sequence result, and unwrap
            # a list whose first element is itself a sequence.
            if not isinstance(similarity_scores, Sequence):
                similarity_scores = [similarity_scores]
            elif (
                isinstance(similarity_scores, list)
                and len(similarity_scores) > 0
                and isinstance(similarity_scores[0], Sequence)
            ):
                similarity_scores = similarity_scores[0]

        # Document indices ordered by descending score.
        sim_scores_argsort = list(reversed(np.argsort(similarity_scores)))
        if top_n is not None:
            sim_scores_argsort = sim_scores_argsort[:top_n]
        docs = [
            DocumentObj(
                index=int(arg),
                relevance_score=float(similarity_scores[arg]),
                document=Document(text=documents[arg]) if return_documents else None,
            )
            for arg in sim_scores_argsort
        ]

        tokens = None
        if return_len:
            input_len = self._model.model.n_tokens
            # Rerank Model output is just score or documents
            # while return_documents = True
            tokens = RerankTokens(input_tokens=input_len, output_tokens=input_len)

        # api_version, billed_units, warnings
        # is for Cohere API compatibility, set to None
        metadata = Meta(
            api_version=None,
            billed_units=None,
            tokens=tokens,
            warnings=None,
        )

        del similarity_scores
        # clear cache if possible
        self._counter += 1
        if self._counter % RERANK_EMPTY_CACHE_COUNT == 0:
            logger.debug("Empty rerank cache.")
            gc.collect()
            empty_cache()

        return Rerank(id=str(uuid.uuid1()), results=docs, meta=metadata)

    @classmethod
    def check_lib(cls) -> bool:
        """Return True if the ``sentence_transformers`` package can be imported."""
        return importlib.util.find_spec("sentence_transformers") is not None

    @classmethod
    def match_json(
        cls,
        model_family: RerankModelFamilyV2,
        model_spec: RerankSpecV1,
        quantization: str,
    ) -> bool:
        """Return True for any ``pytorch``-format spec."""
        # As the default rerank engine, sentence-transformers accepts every
        # pytorch-format model.
        return model_spec.model_format in ["pytorch"]
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2022-2025 XProbe Inc.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
@@ -0,0 +1,156 @@
|
|
|
1
|
+
import importlib.util
|
|
2
|
+
import uuid
|
|
3
|
+
from typing import List, Optional
|
|
4
|
+
|
|
5
|
+
from ....types import Document, DocumentObj, Meta, Rerank, RerankTokens
|
|
6
|
+
from ...utils import cache_clean
|
|
7
|
+
from ..core import RerankModel, RerankModelFamilyV2, RerankSpecV1
|
|
8
|
+
|
|
9
|
+
# Model-name prefixes (text before the first "-") accepted by
# VLLMRerankModel.match_json.
SUPPORTED_MODELS_PREFIXES = ["bge", "gte", "text2vec", "m3e", "Qwen3"]
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class VLLMRerankModel(RerankModel):
    """Rerank model served by vLLM's scoring API (``LLM(..., task="score")``)."""

    # Qwen3 rerankers need hf_overrides at load time and a chat-style prompt
    # at rerank time; both are keyed on these exact model names.
    _QWEN3_RERANKER_NAMES = frozenset(
        {"Qwen3-Reranker-0.6B", "Qwen3-Reranker-4B", "Qwen3-Reranker-8B"}
    )

    def load(self):
        """Create the vLLM engine for ``self._model_path``.

        For the Qwen3 rerankers ``hf_overrides`` is set (or, when the caller
        already supplied a dict, updated in place) so vLLM loads the model as
        ``Qwen3ForSequenceClassification`` scored from the "no"/"yes" tokens.
        All of ``self._kwargs`` is forwarded to ``vllm.LLM``.

        Raises:
            ImportError: If vllm is not installed.
        """
        try:
            from vllm import LLM
        except ImportError:
            error_message = "Failed to import module 'vllm'"
            installation_guide = [
                "Please make sure 'vllm' is installed. ",
                "You can install it by `pip install vllm`\n",
            ]

            raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")

        if self.model_family.model_name in self._QWEN3_RERANKER_NAMES:
            qwen3_overrides = {
                "architectures": ["Qwen3ForSequenceClassification"],
                "classifier_from_token": ["no", "yes"],
                "is_original_qwen3_reranker": True,
            }
            if "hf_overrides" not in self._kwargs:
                self._kwargs["hf_overrides"] = qwen3_overrides
            elif isinstance(self._kwargs["hf_overrides"], dict):
                self._kwargs["hf_overrides"].update(qwen3_overrides)
            # Any non-dict hf_overrides value is left untouched.
        self._model = LLM(model=self._model_path, task="score", **self._kwargs)
        self._tokenizer = self._model.get_tokenizer()

    @cache_clean
    def rerank(
        self,
        documents: List[str],
        query: str,
        top_n: Optional[int],
        max_chunks_per_doc: Optional[int],
        return_documents: Optional[bool],
        return_len: Optional[bool],
        **kwargs,
    ) -> Rerank:
        """
        Rerank the documents based on the query using the VLLM model.

        Args:
            documents (List[str]): List of documents to be reranked.
            query (str): The query string to rank the documents against.
            top_n (Optional[int]): The number of top documents to return.
            max_chunks_per_doc (Optional[int]): Accepted for interface
                compatibility; not used by this engine.
            return_documents (Optional[bool]): Whether to include the document
                text in each result.
            return_len (Optional[bool]): Whether to report token usage (the sum
                of prompt tokens over all scored pairs) in ``meta.tokens``.

        Returns:
            Rerank: The reranked results, highest score first.

        Raises:
            RuntimeError: If any extra keyword argument is passed.
        """
        if kwargs:
            raise RuntimeError("Unexpected keyword arguments: {}".format(kwargs))
        assert self._model is not None
        documents_size = len(documents)

        if documents_size == 0:
            # Nothing to score; skip the engine call entirely.
            outputs = []
        elif self.model_family.model_name in self._QWEN3_RERANKER_NAMES:
            instruction = "Given a web search query, retrieve relevant passages that answer the query"
            prefix = (
                "<|im_start|>system\nJudge whether the Document meets the requirements based on"
                " the Query and the Instruct provided. "
                'Note that the answer can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
            )
            suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
            query_template = "{prefix}<Instruct>: {instruction}\n<Query>: {query}\n"
            document_template = "<Document>: {doc}{suffix}"
            processed_query = query_template.format(
                prefix=prefix, instruction=instruction, query=query
            )
            processed_queries = [processed_query] * documents_size
            processed_documents = [
                document_template.format(doc=doc, suffix=suffix) for doc in documents
            ]
            # LLM.score(text_1, text_2) takes the query side first. The
            # templates only form a valid prompt as prefix + query + document
            # + suffix, so the queries must be text_1 and the documents text_2.
            outputs = self._model.score(
                processed_queries,
                processed_documents,
                use_tqdm=False,
            )
        else:
            # Query first, the same (query, document) order used by the
            # other rerank engines.
            outputs = self._model.score(
                [query] * documents_size,
                documents,
                use_tqdm=False,
            )

        # One output per document, in input order.
        scored = [
            (index, output.outputs.score, Document(text=text))
            for index, (output, text) in enumerate(zip(outputs, documents))
        ]
        scored.sort(key=lambda part: part[1], reverse=True)
        if top_n is not None:
            scored = scored[:top_n]
        reranked_docs = [
            DocumentObj(
                index=index,
                relevance_score=score,
                document=doc if return_documents else None,
            )
            for index, score, doc in scored
        ]

        tokens = sum(len(output.prompt_token_ids) for output in outputs)
        metadata = Meta(
            api_version=None,
            billed_units=None,
            tokens=(
                RerankTokens(input_tokens=tokens, output_tokens=tokens)
                if return_len
                else None
            ),
            warnings=None,
        )
        return Rerank(id=str(uuid.uuid4()), results=reranked_docs, meta=metadata)

    @classmethod
    def check_lib(cls) -> bool:
        """Return True if the ``vllm`` package can be imported."""
        return importlib.util.find_spec("vllm") is not None

    @classmethod
    def match_json(
        cls,
        model_family: RerankModelFamilyV2,
        model_spec: RerankSpecV1,
        quantization: str,
    ) -> bool:
        """Return True for pytorch-format specs whose name prefix vLLM supports.

        The prefix is the part of the model name before the first ``"-"``,
        checked against ``SUPPORTED_MODELS_PREFIXES``.
        """
        if model_spec.model_format not in ["pytorch"]:
            return False
        prefix = model_family.model_name.split("-", 1)[0]
        return prefix in SUPPORTED_MODELS_PREFIXES
|
xinference/model/utils.py
CHANGED
|
@@ -13,6 +13,7 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
15
|
import asyncio
|
|
16
|
+
import functools
|
|
16
17
|
import json
|
|
17
18
|
import logging
|
|
18
19
|
import os
|
|
@@ -454,6 +455,19 @@ def get_engine_params_by_name(
|
|
|
454
455
|
for param in params:
|
|
455
456
|
del param["embedding_class"]
|
|
456
457
|
|
|
458
|
+
return engine_params
|
|
459
|
+
elif model_type == "rerank":
|
|
460
|
+
from .rerank.rerank_family import RERANK_ENGINES
|
|
461
|
+
|
|
462
|
+
if model_name not in RERANK_ENGINES:
|
|
463
|
+
return None
|
|
464
|
+
|
|
465
|
+
# filter rerank_class
|
|
466
|
+
engine_params = deepcopy(RERANK_ENGINES[model_name])
|
|
467
|
+
for engine, params in engine_params.items():
|
|
468
|
+
for param in params:
|
|
469
|
+
del param["rerank_class"]
|
|
470
|
+
|
|
457
471
|
return engine_params
|
|
458
472
|
else:
|
|
459
473
|
raise ValueError(
|
|
@@ -558,3 +572,97 @@ class ModelInstanceInfoMixin(ABC):
|
|
|
558
572
|
@abstractmethod
|
|
559
573
|
def to_version_info(self):
|
|
560
574
|
""""""
|
|
575
|
+
|
|
576
|
+
|
|
577
|
+
def is_flash_attn_available() -> bool:
    """
    Check if flash_attention can be enabled in the current environment.

    Checks the following conditions, in order:
    1. Whether the flash_attn package can be found
    2. Whether PyTorch reports CUDA as available with at least one device
    3. Whether the current GPU's compute capability is >= 8.0
    4. Whether flash_attn can actually be imported

    Any unexpected error during the checks is logged at debug level and
    treated as "not available".

    Returns:
        bool: True if flash_attention can be enabled, False otherwise
    """
    import importlib.util

    # Cheap check first: avoids importing torch when flash_attn is absent.
    if importlib.util.find_spec("flash_attn") is None:
        logger.debug("flash_attn package not found")
        return False

    try:
        import torch

        if not torch.cuda.is_available():
            logger.debug("CUDA not available")
            return False

        if torch.cuda.device_count() == 0:
            logger.debug("No CUDA devices found")
            return False

        # Flash Attention typically requires compute capability >= 8.0
        # (A100, H100, etc.). Compare as an integer tuple instead of building
        # a float such as ``major + minor * 0.1``.
        major, minor = torch.cuda.get_device_capability(torch.cuda.current_device())
        if (major, minor) < (8, 0):
            logger.debug(
                "GPU compute capability %d.%d < 8.0, flash_attn may not work optimally",
                major,
                minor,
            )
            return False

        # Try to import flash_attn core module to verify correct installation
        try:
            import flash_attn
        except ImportError as e:
            logger.debug("Failed to import flash_attn: %s", e)
            return False

        logger.debug(
            "flash_attn version: %s", getattr(flash_attn, "__version__", "unknown")
        )
        return True
    except Exception as e:
        logger.debug("Error checking flash_attn availability: %s", e)
        return False
|
|
638
|
+
|
|
639
|
+
|
|
640
|
+
def cache_clean(fn):
    """Decorate a method so garbage and the device cache are released after each call.

    Works for both plain and ``async`` methods: the returned wrapper has the
    same sync/async nature as ``fn`` and preserves its metadata via
    ``functools.wraps``. ``gc.collect()`` and ``device_utils.empty_cache()``
    run after ``fn`` finishes, including when it raises, so memory held by a
    failed call (for example one that ran out of memory) is released as well.
    """
    import inspect

    def _release_memory():
        # Imported lazily, as before; presumably to avoid import cycles or
        # heavy imports at module load time -- TODO confirm.
        import gc

        from ..device_utils import empty_cache

        gc.collect()
        empty_cache()

    @functools.wraps(fn)
    async def _async_wrapper(self, *args, **kwargs):
        try:
            return await fn(self, *args, **kwargs)
        finally:
            _release_memory()

    @functools.wraps(fn)
    def _wrapper(self, *args, **kwargs):
        try:
            return fn(self, *args, **kwargs)
        finally:
            _release_memory()

    # inspect.iscoroutinefunction replaces asyncio.iscoroutinefunction,
    # which is deprecated since Python 3.14.
    if inspect.iscoroutinefunction(fn):
        return _async_wrapper
    return _wrapper
|