xinference 1.8.1rc1__py3-none-any.whl → 1.9.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference has been flagged as potentially problematic; see the package's registry page for more details.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +2 -1
- xinference/core/model.py +8 -4
- xinference/core/supervisor.py +2 -3
- xinference/core/worker.py +7 -5
- xinference/deploy/cmdline.py +2 -0
- xinference/deploy/local.py +5 -0
- xinference/deploy/test/test_cmdline.py +1 -1
- xinference/deploy/worker.py +6 -0
- xinference/model/audio/cosyvoice.py +0 -1
- xinference/model/audio/model_spec.json +44 -20
- xinference/model/core.py +3 -0
- xinference/model/embedding/flag/core.py +5 -0
- xinference/model/embedding/llama_cpp/core.py +22 -19
- xinference/model/embedding/sentence_transformers/core.py +18 -4
- xinference/model/embedding/vllm/core.py +36 -9
- xinference/model/image/cache_manager.py +56 -0
- xinference/model/image/core.py +9 -0
- xinference/model/image/model_spec.json +178 -1
- xinference/model/image/stable_diffusion/core.py +155 -23
- xinference/model/llm/cache_manager.py +17 -3
- xinference/model/llm/harmony.py +245 -0
- xinference/model/llm/llama_cpp/core.py +41 -40
- xinference/model/llm/llm_family.json +688 -11
- xinference/model/llm/llm_family.py +1 -1
- xinference/model/llm/sglang/core.py +108 -5
- xinference/model/llm/transformers/core.py +20 -18
- xinference/model/llm/transformers/gemma3.py +1 -1
- xinference/model/llm/transformers/gpt_oss.py +91 -0
- xinference/model/llm/transformers/multimodal/core.py +1 -1
- xinference/model/llm/transformers/multimodal/gemma3.py +1 -1
- xinference/model/llm/transformers/multimodal/glm4_1v.py +2 -2
- xinference/model/llm/transformers/multimodal/ovis2.py +1 -1
- xinference/model/llm/transformers/multimodal/qwen-omni.py +7 -8
- xinference/model/llm/transformers/multimodal/qwen2_vl.py +9 -6
- xinference/model/llm/transformers/utils.py +1 -33
- xinference/model/llm/utils.py +61 -7
- xinference/model/llm/vllm/core.py +44 -8
- xinference/model/rerank/__init__.py +66 -23
- xinference/model/rerank/cache_manager.py +35 -0
- xinference/model/rerank/core.py +87 -339
- xinference/model/rerank/custom.py +33 -8
- xinference/model/rerank/model_spec.json +251 -212
- xinference/model/rerank/rerank_family.py +137 -0
- xinference/model/rerank/sentence_transformers/__init__.py +13 -0
- xinference/model/rerank/sentence_transformers/core.py +337 -0
- xinference/model/rerank/vllm/__init__.py +13 -0
- xinference/model/rerank/vllm/core.py +156 -0
- xinference/model/utils.py +108 -0
- xinference/model/video/model_spec.json +95 -1
- xinference/thirdparty/cosyvoice/bin/export_jit.py +3 -4
- xinference/thirdparty/cosyvoice/bin/export_onnx.py +49 -126
- xinference/thirdparty/cosyvoice/bin/{inference.py → inference_deprecated.py} +1 -0
- xinference/thirdparty/cosyvoice/bin/train.py +23 -3
- xinference/thirdparty/cosyvoice/cli/cosyvoice.py +8 -4
- xinference/thirdparty/cosyvoice/cli/frontend.py +4 -4
- xinference/thirdparty/cosyvoice/cli/model.py +53 -75
- xinference/thirdparty/cosyvoice/dataset/dataset.py +5 -18
- xinference/thirdparty/cosyvoice/dataset/processor.py +24 -25
- xinference/thirdparty/cosyvoice/flow/decoder.py +24 -433
- xinference/thirdparty/cosyvoice/flow/flow.py +6 -14
- xinference/thirdparty/cosyvoice/flow/flow_matching.py +33 -145
- xinference/thirdparty/cosyvoice/hifigan/generator.py +169 -1
- xinference/thirdparty/cosyvoice/llm/llm.py +108 -17
- xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +14 -115
- xinference/thirdparty/cosyvoice/utils/common.py +20 -0
- xinference/thirdparty/cosyvoice/utils/executor.py +8 -4
- xinference/thirdparty/cosyvoice/utils/file_utils.py +45 -1
- xinference/thirdparty/cosyvoice/utils/losses.py +37 -0
- xinference/thirdparty/cosyvoice/utils/mask.py +35 -1
- xinference/thirdparty/cosyvoice/utils/train_utils.py +24 -6
- xinference/thirdparty/cosyvoice/vllm/cosyvoice2.py +103 -0
- xinference/types.py +2 -0
- xinference/ui/gradio/chat_interface.py +2 -0
- xinference/ui/gradio/media_interface.py +353 -7
- xinference/ui/web/ui/build/asset-manifest.json +3 -3
- xinference/ui/web/ui/build/index.html +1 -1
- xinference/ui/web/ui/build/static/js/main.1086c759.js +3 -0
- xinference/ui/web/ui/build/static/js/main.1086c759.js.map +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/28012da921a51f1082549956d3ae82acd769a754b22afda9acddd98a4daf9ea4.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/3c5758bd12fa334294b1de0ff6b1a4bac8d963c45472eab9dc3e530d82aa6b3f.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/475936ebe725eca62a6f52ce182c06a19b2cef4df9545a05ed0591ee0c539d43.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/aee5aaba26f2b1e816a3ea9efa68bad8b95695a3d80adcfd8dd57a7bb17ac71a.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +1 -0
- xinference/ui/web/ui/src/locales/en.json +2 -0
- xinference/ui/web/ui/src/locales/ja.json +2 -0
- xinference/ui/web/ui/src/locales/ko.json +2 -0
- xinference/ui/web/ui/src/locales/zh.json +2 -0
- {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/METADATA +15 -10
- {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/RECORD +98 -89
- xinference/ui/web/ui/build/static/js/main.b969199a.js +0 -3
- xinference/ui/web/ui/build/static/js/main.b969199a.js.map +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/1409a96b9f9f9f5de99a89ab0f738f6da62b449521b0a8d3e4efcf7f5c23534d.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/3d2a89f0eccc1f90fc5036c9a1d587c2120e6a6b128aae31d1db7d6bad52722b.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/43b889c3a8e2634092ade463d52481c7c5581c72ded8f23bc5f012ea0ef8cea5.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/5d47532fb42128280d87f57c8a0b02bc1930f7ef764aa7e90579247df18bba83.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/830882bb275468a969614824a9ab8983f874b4581f2eb625e9c66426cdc65e5b.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/8e5cb82c2ff3299c6a44563fe6b1c5515c9750613c51bb63abee0b1d70fc5019.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/9df08abcb5a7c1e48a4eb25c5d5f5d7253ea6854a4397e6d74d1fd75a14acda1.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/b99034986a06445701accc7a4914bb9320947435e8d4e15793392ca4f679316c.json +0 -1
- /xinference/ui/web/ui/build/static/js/{main.b969199a.js.LICENSE.txt → main.1086c759.js.LICENSE.txt} +0 -0
- {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/WHEEL +0 -0
- {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/entry_points.txt +0 -0
- {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/licenses/LICENSE +0 -0
- {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/top_level.txt +0 -0
xinference/model/rerank/core.py
CHANGED
|
@@ -10,28 +10,18 @@
|
|
|
10
10
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
13
|
|
|
15
|
-
import gc
|
|
16
|
-
import importlib
|
|
17
|
-
import importlib.util
|
|
18
14
|
import logging
|
|
19
15
|
import os
|
|
20
|
-
import
|
|
21
|
-
import uuid
|
|
16
|
+
from abc import abstractmethod
|
|
22
17
|
from collections import defaultdict
|
|
23
|
-
from collections.abc import Sequence
|
|
24
18
|
from typing import Dict, List, Literal, Optional
|
|
25
19
|
|
|
26
|
-
|
|
27
|
-
import
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
from ...device_utils import empty_cache, is_device_available
|
|
31
|
-
from ...types import Document, DocumentObj, Rerank, RerankTokens
|
|
32
|
-
from ..core import CacheableModelSpec, VirtualEnvSettings
|
|
20
|
+
from ..._compat import BaseModel
|
|
21
|
+
from ...types import Rerank
|
|
22
|
+
from ..core import VirtualEnvSettings
|
|
33
23
|
from ..utils import ModelInstanceInfoMixin
|
|
34
|
-
from .
|
|
24
|
+
from .rerank_family import check_engine_by_model_name_and_engine, match_rerank
|
|
35
25
|
|
|
36
26
|
logger = logging.getLogger(__name__)
|
|
37
27
|
|
|
@@ -48,21 +38,29 @@ def get_rerank_model_descriptions():
|
|
|
48
38
|
return copy.deepcopy(RERANK_MODEL_DESCRIPTIONS)
|
|
49
39
|
|
|
50
40
|
|
|
51
|
-
class
|
|
41
|
+
class RerankSpecV1(BaseModel):
|
|
42
|
+
model_format: Literal["pytorch"]
|
|
43
|
+
model_hub: str = "huggingface"
|
|
44
|
+
model_id: Optional[str] = None
|
|
45
|
+
model_revision: Optional[str] = None
|
|
46
|
+
model_uri: Optional[str] = None
|
|
47
|
+
quantization: str = "none"
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class RerankModelFamilyV2(BaseModel, ModelInstanceInfoMixin):
|
|
52
51
|
version: Literal[2]
|
|
53
52
|
model_name: str
|
|
53
|
+
model_specs: List[RerankSpecV1]
|
|
54
54
|
language: List[str]
|
|
55
55
|
type: Optional[str] = "unknown"
|
|
56
56
|
max_tokens: Optional[int]
|
|
57
|
-
model_id: str
|
|
58
|
-
model_revision: Optional[str]
|
|
59
|
-
model_hub: str = "huggingface"
|
|
60
57
|
virtualenv: Optional[VirtualEnvSettings]
|
|
61
58
|
|
|
62
59
|
class Config:
|
|
63
60
|
extra = "allow"
|
|
64
61
|
|
|
65
62
|
def to_description(self):
|
|
63
|
+
spec = self.model_specs[0]
|
|
66
64
|
return {
|
|
67
65
|
"model_type": "rerank",
|
|
68
66
|
"address": getattr(self, "address", None),
|
|
@@ -70,13 +68,13 @@ class RerankModelFamilyV2(CacheableModelSpec, ModelInstanceInfoMixin):
|
|
|
70
68
|
"type": self.type,
|
|
71
69
|
"model_name": self.model_name,
|
|
72
70
|
"language": self.language,
|
|
73
|
-
"model_revision":
|
|
71
|
+
"model_revision": spec.model_revision,
|
|
74
72
|
}
|
|
75
73
|
|
|
76
74
|
def to_version_info(self):
|
|
77
|
-
from
|
|
75
|
+
from .cache_manager import RerankCacheManager
|
|
78
76
|
|
|
79
|
-
cache_manager =
|
|
77
|
+
cache_manager = RerankCacheManager(self)
|
|
80
78
|
return {
|
|
81
79
|
"model_version": self.model_name,
|
|
82
80
|
"model_file_location": cache_manager.get_cache_dir(),
|
|
@@ -93,56 +91,59 @@ def generate_rerank_description(
|
|
|
93
91
|
return res
|
|
94
92
|
|
|
95
93
|
|
|
96
|
-
class _ModelWrapper(nn.Module):
|
|
97
|
-
def __init__(self, module: nn.Module):
|
|
98
|
-
super().__init__()
|
|
99
|
-
self.model = module
|
|
100
|
-
self._local_data = threading.local()
|
|
101
|
-
|
|
102
|
-
@property
|
|
103
|
-
def n_tokens(self):
|
|
104
|
-
return getattr(self._local_data, "n_tokens", 0)
|
|
105
|
-
|
|
106
|
-
@n_tokens.setter
|
|
107
|
-
def n_tokens(self, value):
|
|
108
|
-
self._local_data.n_tokens = value
|
|
109
|
-
|
|
110
|
-
def forward(self, **kwargs):
|
|
111
|
-
attention_mask = kwargs.get("attention_mask")
|
|
112
|
-
# when batching, the attention mask 1 means there is a token
|
|
113
|
-
# thus we just sum up it to get the total number of tokens
|
|
114
|
-
if attention_mask is not None:
|
|
115
|
-
self.n_tokens += attention_mask.sum().item()
|
|
116
|
-
return self.model(**kwargs)
|
|
117
|
-
|
|
118
|
-
def __getattr__(self, attr):
|
|
119
|
-
try:
|
|
120
|
-
return super().__getattr__(attr)
|
|
121
|
-
except AttributeError:
|
|
122
|
-
return getattr(self.model, attr)
|
|
123
|
-
|
|
124
|
-
|
|
125
94
|
class RerankModel:
|
|
126
95
|
def __init__(
|
|
127
96
|
self,
|
|
128
|
-
model_spec: RerankModelFamilyV2,
|
|
129
97
|
model_uid: str,
|
|
130
|
-
model_path:
|
|
98
|
+
model_path: str,
|
|
99
|
+
model_family: RerankModelFamilyV2,
|
|
100
|
+
quantization: Optional[str],
|
|
101
|
+
*,
|
|
131
102
|
device: Optional[str] = None,
|
|
132
103
|
use_fp16: bool = False,
|
|
133
|
-
|
|
104
|
+
**kwargs,
|
|
134
105
|
):
|
|
135
|
-
self.model_family =
|
|
136
|
-
self._model_spec =
|
|
106
|
+
self.model_family = model_family
|
|
107
|
+
self._model_spec = model_family.model_specs[0]
|
|
137
108
|
self._model_uid = model_uid
|
|
138
109
|
self._model_path = model_path
|
|
110
|
+
self._quantization = quantization
|
|
139
111
|
self._device = device
|
|
140
|
-
self._model_config = model_config or dict()
|
|
141
112
|
self._use_fp16 = use_fp16
|
|
142
113
|
self._model = None
|
|
143
114
|
self._counter = 0
|
|
144
|
-
|
|
145
|
-
|
|
115
|
+
self._kwargs = kwargs
|
|
116
|
+
if model_family.type == "unknown":
|
|
117
|
+
model_family.type = self._auto_detect_type(model_path)
|
|
118
|
+
|
|
119
|
+
@classmethod
|
|
120
|
+
@abstractmethod
|
|
121
|
+
def check_lib(cls) -> bool:
|
|
122
|
+
pass
|
|
123
|
+
|
|
124
|
+
@classmethod
|
|
125
|
+
@abstractmethod
|
|
126
|
+
def match_json(
|
|
127
|
+
cls,
|
|
128
|
+
model_family: RerankModelFamilyV2,
|
|
129
|
+
model_spec: RerankSpecV1,
|
|
130
|
+
quantization: str,
|
|
131
|
+
) -> bool:
|
|
132
|
+
pass
|
|
133
|
+
|
|
134
|
+
@classmethod
|
|
135
|
+
def match(
|
|
136
|
+
cls,
|
|
137
|
+
model_family: RerankModelFamilyV2,
|
|
138
|
+
model_spec: RerankSpecV1,
|
|
139
|
+
quantization: str,
|
|
140
|
+
):
|
|
141
|
+
"""
|
|
142
|
+
Return if the model_spec can be matched.
|
|
143
|
+
"""
|
|
144
|
+
if not cls.check_lib():
|
|
145
|
+
return False
|
|
146
|
+
return cls.match_json(model_family, model_spec, quantization)
|
|
146
147
|
|
|
147
148
|
@staticmethod
|
|
148
149
|
def _get_tokenizer(model_path):
|
|
@@ -171,145 +172,10 @@ class RerankModel:
|
|
|
171
172
|
return "normal"
|
|
172
173
|
return rerank_type
|
|
173
174
|
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
flash_attn_installed = importlib.util.find_spec("flash_attn") is not None
|
|
177
|
-
if (
|
|
178
|
-
self._auto_detect_type(self._model_path) != "normal"
|
|
179
|
-
and flash_attn_installed
|
|
180
|
-
):
|
|
181
|
-
logger.warning(
|
|
182
|
-
"flash_attn can only support fp16 and bf16, "
|
|
183
|
-
"will force set `use_fp16` to True"
|
|
184
|
-
)
|
|
185
|
-
self._use_fp16 = True
|
|
186
|
-
|
|
187
|
-
if (
|
|
188
|
-
self._model_spec.type == "normal"
|
|
189
|
-
and "qwen3" not in self._model_spec.model_name.lower()
|
|
190
|
-
):
|
|
191
|
-
try:
|
|
192
|
-
import sentence_transformers
|
|
193
|
-
from sentence_transformers.cross_encoder import CrossEncoder
|
|
194
|
-
|
|
195
|
-
if sentence_transformers.__version__ < "3.1.0":
|
|
196
|
-
raise ValueError(
|
|
197
|
-
"The sentence_transformers version must be greater than 3.1.0. "
|
|
198
|
-
"Please upgrade your version via `pip install -U sentence_transformers` or refer to "
|
|
199
|
-
"https://github.com/UKPLab/sentence-transformers"
|
|
200
|
-
)
|
|
201
|
-
except ImportError:
|
|
202
|
-
error_message = "Failed to import module 'sentence-transformers'"
|
|
203
|
-
installation_guide = [
|
|
204
|
-
"Please make sure 'sentence-transformers' is installed. ",
|
|
205
|
-
"You can install it by `pip install sentence-transformers`\n",
|
|
206
|
-
]
|
|
207
|
-
|
|
208
|
-
raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
|
|
209
|
-
self._model = CrossEncoder(
|
|
210
|
-
self._model_path,
|
|
211
|
-
device=self._device,
|
|
212
|
-
trust_remote_code=True,
|
|
213
|
-
max_length=getattr(self._model_spec, "max_tokens"),
|
|
214
|
-
**self._model_config,
|
|
215
|
-
)
|
|
216
|
-
if self._use_fp16:
|
|
217
|
-
self._model.model.half()
|
|
218
|
-
elif "qwen3" in self._model_spec.model_name.lower():
|
|
219
|
-
# qwen3-reranker
|
|
220
|
-
# now we use transformers
|
|
221
|
-
# TODO: support engines for rerank models
|
|
222
|
-
try:
|
|
223
|
-
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
224
|
-
except ImportError:
|
|
225
|
-
error_message = "Failed to import module 'transformers'"
|
|
226
|
-
installation_guide = [
|
|
227
|
-
"Please make sure 'transformers' is installed. ",
|
|
228
|
-
"You can install it by `pip install transformers`\n",
|
|
229
|
-
]
|
|
230
|
-
|
|
231
|
-
raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
|
|
232
|
-
|
|
233
|
-
tokenizer = AutoTokenizer.from_pretrained(
|
|
234
|
-
self._model_path, padding_side="left"
|
|
235
|
-
)
|
|
236
|
-
enable_flash_attn = self._model_config.pop(
|
|
237
|
-
"enable_flash_attn", is_device_available("cuda")
|
|
238
|
-
)
|
|
239
|
-
model_kwargs = {"device_map": "auto"}
|
|
240
|
-
if flash_attn_installed and enable_flash_attn:
|
|
241
|
-
model_kwargs["attn_implementation"] = "flash_attention_2"
|
|
242
|
-
model_kwargs["torch_dtype"] = torch.float16
|
|
243
|
-
model_kwargs.update(self._model_config)
|
|
244
|
-
logger.debug("Loading qwen3 rerank with kwargs %s", model_kwargs)
|
|
245
|
-
model = self._model = AutoModelForCausalLM.from_pretrained(
|
|
246
|
-
self._model_path, **model_kwargs
|
|
247
|
-
).eval()
|
|
248
|
-
max_length = getattr(self._model_spec, "max_tokens")
|
|
249
|
-
|
|
250
|
-
prefix = (
|
|
251
|
-
"<|im_start|>system\nJudge whether the Document meets the requirements based on the Query "
|
|
252
|
-
'and the Instruct provided. Note that the answer can only be "yes" or "no".'
|
|
253
|
-
"<|im_end|>\n<|im_start|>user\n"
|
|
254
|
-
)
|
|
255
|
-
suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
|
|
256
|
-
prefix_tokens = tokenizer.encode(prefix, add_special_tokens=False)
|
|
257
|
-
suffix_tokens = tokenizer.encode(suffix, add_special_tokens=False)
|
|
258
|
-
|
|
259
|
-
def process_inputs(pairs):
|
|
260
|
-
inputs = tokenizer(
|
|
261
|
-
pairs,
|
|
262
|
-
padding=False,
|
|
263
|
-
truncation="longest_first",
|
|
264
|
-
return_attention_mask=False,
|
|
265
|
-
max_length=max_length - len(prefix_tokens) - len(suffix_tokens),
|
|
266
|
-
)
|
|
267
|
-
for i, ele in enumerate(inputs["input_ids"]):
|
|
268
|
-
inputs["input_ids"][i] = prefix_tokens + ele + suffix_tokens
|
|
269
|
-
inputs = tokenizer.pad(
|
|
270
|
-
inputs, padding=True, return_tensors="pt", max_length=max_length
|
|
271
|
-
)
|
|
272
|
-
for key in inputs:
|
|
273
|
-
inputs[key] = inputs[key].to(model.device)
|
|
274
|
-
return inputs
|
|
275
|
-
|
|
276
|
-
token_false_id = tokenizer.convert_tokens_to_ids("no")
|
|
277
|
-
token_true_id = tokenizer.convert_tokens_to_ids("yes")
|
|
278
|
-
|
|
279
|
-
@torch.no_grad()
|
|
280
|
-
def compute_logits(inputs, **kwargs):
|
|
281
|
-
batch_scores = model(**inputs).logits[:, -1, :]
|
|
282
|
-
true_vector = batch_scores[:, token_true_id]
|
|
283
|
-
false_vector = batch_scores[:, token_false_id]
|
|
284
|
-
batch_scores = torch.stack([false_vector, true_vector], dim=1)
|
|
285
|
-
batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
|
|
286
|
-
scores = batch_scores[:, 1].exp().tolist()
|
|
287
|
-
return scores
|
|
288
|
-
|
|
289
|
-
self.process_inputs = process_inputs
|
|
290
|
-
self.compute_logits = compute_logits
|
|
291
|
-
else:
|
|
292
|
-
try:
|
|
293
|
-
if self._model_spec.type == "LLM-based":
|
|
294
|
-
from FlagEmbedding import FlagLLMReranker as FlagReranker
|
|
295
|
-
elif self._model_spec.type == "LLM-based layerwise":
|
|
296
|
-
from FlagEmbedding import LayerWiseFlagLLMReranker as FlagReranker
|
|
297
|
-
else:
|
|
298
|
-
raise RuntimeError(
|
|
299
|
-
f"Unsupported Rank model type: {self._model_spec.type}"
|
|
300
|
-
)
|
|
301
|
-
except ImportError:
|
|
302
|
-
error_message = "Failed to import module 'FlagEmbedding'"
|
|
303
|
-
installation_guide = [
|
|
304
|
-
"Please make sure 'FlagEmbedding' is installed. ",
|
|
305
|
-
"You can install it by `pip install FlagEmbedding`\n",
|
|
306
|
-
]
|
|
307
|
-
|
|
308
|
-
raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
|
|
309
|
-
self._model = FlagReranker(self._model_path, use_fp16=self._use_fp16)
|
|
310
|
-
# Wrap transformers model to record number of tokens
|
|
311
|
-
self._model.model = _ModelWrapper(self._model.model)
|
|
175
|
+
@abstractmethod
|
|
176
|
+
def load(self): ...
|
|
312
177
|
|
|
178
|
+
@abstractmethod
|
|
313
179
|
def rerank(
|
|
314
180
|
self,
|
|
315
181
|
documents: List[str],
|
|
@@ -319,159 +185,41 @@ class RerankModel:
|
|
|
319
185
|
return_documents: Optional[bool],
|
|
320
186
|
return_len: Optional[bool],
|
|
321
187
|
**kwargs,
|
|
322
|
-
) -> Rerank:
|
|
323
|
-
assert self._model is not None
|
|
324
|
-
if max_chunks_per_doc is not None:
|
|
325
|
-
raise ValueError("rerank hasn't support `max_chunks_per_doc` parameter.")
|
|
326
|
-
logger.info("Rerank with kwargs: %s, model: %s", kwargs, self._model)
|
|
327
|
-
|
|
328
|
-
pre_query = preprocess_sentence(
|
|
329
|
-
query, kwargs.get("instruction", None), self._model_spec.model_name
|
|
330
|
-
)
|
|
331
|
-
sentence_combinations = [[pre_query, doc] for doc in documents]
|
|
332
|
-
# reset n tokens
|
|
333
|
-
self._model.model.n_tokens = 0
|
|
334
|
-
if (
|
|
335
|
-
self._model_spec.type == "normal"
|
|
336
|
-
and "qwen3" not in self._model_spec.model_name.lower()
|
|
337
|
-
):
|
|
338
|
-
logger.debug("Passing processed sentences: %s", sentence_combinations)
|
|
339
|
-
similarity_scores = self._model.predict(
|
|
340
|
-
sentence_combinations,
|
|
341
|
-
convert_to_numpy=False,
|
|
342
|
-
convert_to_tensor=True,
|
|
343
|
-
**kwargs,
|
|
344
|
-
).cpu()
|
|
345
|
-
if similarity_scores.dtype == torch.bfloat16:
|
|
346
|
-
similarity_scores = similarity_scores.float()
|
|
347
|
-
elif "qwen3" in self._model_spec.model_name.lower():
|
|
348
|
-
|
|
349
|
-
def format_instruction(instruction, query, doc):
|
|
350
|
-
if instruction is None:
|
|
351
|
-
instruction = "Given a web search query, retrieve relevant passages that answer the query"
|
|
352
|
-
output = "<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}".format(
|
|
353
|
-
instruction=instruction, query=query, doc=doc
|
|
354
|
-
)
|
|
355
|
-
return output
|
|
356
|
-
|
|
357
|
-
# reduce memory usage.
|
|
358
|
-
micro_bs = 4
|
|
359
|
-
similarity_scores = []
|
|
360
|
-
for i in range(0, len(documents), micro_bs):
|
|
361
|
-
sub_docs = documents[i : i + micro_bs]
|
|
362
|
-
pairs = [
|
|
363
|
-
format_instruction(kwargs.get("instruction", None), query, doc)
|
|
364
|
-
for doc in sub_docs
|
|
365
|
-
]
|
|
366
|
-
# Tokenize the input texts
|
|
367
|
-
inputs = self.process_inputs(pairs)
|
|
368
|
-
similarity_scores.extend(self.compute_logits(inputs))
|
|
369
|
-
else:
|
|
370
|
-
# Related issue: https://github.com/xorbitsai/inference/issues/1775
|
|
371
|
-
similarity_scores = self._model.compute_score(
|
|
372
|
-
sentence_combinations, **kwargs
|
|
373
|
-
)
|
|
374
|
-
|
|
375
|
-
if not isinstance(similarity_scores, Sequence):
|
|
376
|
-
similarity_scores = [similarity_scores]
|
|
377
|
-
elif (
|
|
378
|
-
isinstance(similarity_scores, list)
|
|
379
|
-
and len(similarity_scores) > 0
|
|
380
|
-
and isinstance(similarity_scores[0], Sequence)
|
|
381
|
-
):
|
|
382
|
-
similarity_scores = similarity_scores[0]
|
|
383
|
-
|
|
384
|
-
sim_scores_argsort = list(reversed(np.argsort(similarity_scores)))
|
|
385
|
-
if top_n is not None:
|
|
386
|
-
sim_scores_argsort = sim_scores_argsort[:top_n]
|
|
387
|
-
if return_documents:
|
|
388
|
-
docs = [
|
|
389
|
-
DocumentObj(
|
|
390
|
-
index=int(arg),
|
|
391
|
-
relevance_score=float(similarity_scores[arg]),
|
|
392
|
-
document=Document(text=documents[arg]),
|
|
393
|
-
)
|
|
394
|
-
for arg in sim_scores_argsort
|
|
395
|
-
]
|
|
396
|
-
else:
|
|
397
|
-
docs = [
|
|
398
|
-
DocumentObj(
|
|
399
|
-
index=int(arg),
|
|
400
|
-
relevance_score=float(similarity_scores[arg]),
|
|
401
|
-
document=None,
|
|
402
|
-
)
|
|
403
|
-
for arg in sim_scores_argsort
|
|
404
|
-
]
|
|
405
|
-
if return_len:
|
|
406
|
-
input_len = self._model.model.n_tokens
|
|
407
|
-
# Rerank Model output is just score or documents
|
|
408
|
-
# while return_documents = True
|
|
409
|
-
output_len = input_len
|
|
410
|
-
|
|
411
|
-
# api_version, billed_units, warnings
|
|
412
|
-
# is for Cohere API compatibility, set to None
|
|
413
|
-
metadata = {
|
|
414
|
-
"api_version": None,
|
|
415
|
-
"billed_units": None,
|
|
416
|
-
"tokens": (
|
|
417
|
-
RerankTokens(input_tokens=input_len, output_tokens=output_len)
|
|
418
|
-
if return_len
|
|
419
|
-
else None
|
|
420
|
-
),
|
|
421
|
-
"warnings": None,
|
|
422
|
-
}
|
|
423
|
-
|
|
424
|
-
del similarity_scores
|
|
425
|
-
# clear cache if possible
|
|
426
|
-
self._counter += 1
|
|
427
|
-
if self._counter % RERANK_EMPTY_CACHE_COUNT == 0:
|
|
428
|
-
logger.debug("Empty rerank cache.")
|
|
429
|
-
gc.collect()
|
|
430
|
-
empty_cache()
|
|
431
|
-
|
|
432
|
-
return Rerank(id=str(uuid.uuid1()), results=docs, meta=metadata)
|
|
188
|
+
) -> Rerank: ...
|
|
433
189
|
|
|
434
190
|
|
|
435
191
|
def create_rerank_model_instance(
|
|
436
192
|
model_uid: str,
|
|
437
193
|
model_name: str,
|
|
194
|
+
model_engine: Optional[str],
|
|
195
|
+
model_format: Optional[str] = None,
|
|
196
|
+
quantization: Optional[str] = None,
|
|
438
197
|
download_hub: Optional[
|
|
439
198
|
Literal["huggingface", "modelscope", "openmind_hub", "csghub"]
|
|
440
199
|
] = None,
|
|
441
200
|
model_path: Optional[str] = None,
|
|
442
201
|
**kwargs,
|
|
443
202
|
) -> RerankModel:
|
|
444
|
-
from
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
model_spec = None
|
|
450
|
-
for ud_spec in get_user_defined_reranks():
|
|
451
|
-
if ud_spec.model_name == model_name:
|
|
452
|
-
model_spec = ud_spec
|
|
453
|
-
break
|
|
454
|
-
|
|
455
|
-
if model_spec is None:
|
|
456
|
-
if model_name in BUILTIN_RERANK_MODELS:
|
|
457
|
-
model_specs = BUILTIN_RERANK_MODELS[model_name]
|
|
458
|
-
if download_hub == "modelscope" or download_from_modelscope():
|
|
459
|
-
model_spec = (
|
|
460
|
-
[x for x in model_specs if x.model_hub == "modelscope"]
|
|
461
|
-
+ [x for x in model_specs if x.model_hub == "huggingface"]
|
|
462
|
-
)[0]
|
|
463
|
-
else:
|
|
464
|
-
model_spec = [x for x in model_specs if x.model_hub == "huggingface"][0]
|
|
465
|
-
else:
|
|
466
|
-
raise ValueError(
|
|
467
|
-
f"Rerank model {model_name} not found, available "
|
|
468
|
-
f"model: {BUILTIN_RERANK_MODELS.keys()}"
|
|
469
|
-
)
|
|
470
|
-
if not model_path:
|
|
471
|
-
cache_manager = CacheManager(model_spec)
|
|
203
|
+
from .cache_manager import RerankCacheManager
|
|
204
|
+
|
|
205
|
+
model_family = match_rerank(model_name, model_format, quantization, download_hub)
|
|
206
|
+
if model_path is None:
|
|
207
|
+
cache_manager = RerankCacheManager(model_family)
|
|
472
208
|
model_path = cache_manager.cache()
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
209
|
+
|
|
210
|
+
if model_engine is None:
|
|
211
|
+
# unlike LLM and for compatibility,
|
|
212
|
+
# we use sentence_transformers as the default engine for all models
|
|
213
|
+
model_engine = "sentence_transformers"
|
|
214
|
+
|
|
215
|
+
rerank_cls = check_engine_by_model_name_and_engine(
|
|
216
|
+
model_engine, model_name, model_format, quantization
|
|
217
|
+
)
|
|
218
|
+
model = rerank_cls(
|
|
219
|
+
model_uid,
|
|
220
|
+
model_path,
|
|
221
|
+
model_family,
|
|
222
|
+
quantization,
|
|
223
|
+
**kwargs,
|
|
476
224
|
)
|
|
477
225
|
return model
|
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
import logging
|
|
15
|
-
from typing import List, Literal
|
|
15
|
+
from typing import List, Literal
|
|
16
16
|
|
|
17
17
|
from ..custom import ModelRegistry
|
|
18
18
|
from .core import RerankModelFamilyV2
|
|
@@ -22,10 +22,6 @@ logger = logging.getLogger(__name__)
|
|
|
22
22
|
|
|
23
23
|
class CustomRerankModelFamilyV2(RerankModelFamilyV2):
|
|
24
24
|
version: Literal[2] = 2
|
|
25
|
-
model_id: Optional[str] # type: ignore
|
|
26
|
-
model_revision: Optional[str] # type: ignore
|
|
27
|
-
model_uri: Optional[str]
|
|
28
|
-
model_type: Literal["rerank"] = "rerank" # for frontend
|
|
29
25
|
|
|
30
26
|
|
|
31
27
|
UD_RERANKS: List[CustomRerankModelFamilyV2] = []
|
|
@@ -35,12 +31,41 @@ class RerankModelRegistry(ModelRegistry):
|
|
|
35
31
|
model_type = "rerank"
|
|
36
32
|
|
|
37
33
|
def __init__(self):
|
|
38
|
-
from . import BUILTIN_RERANK_MODELS
|
|
34
|
+
from .rerank_family import BUILTIN_RERANK_MODELS
|
|
39
35
|
|
|
40
36
|
super().__init__()
|
|
41
37
|
self.models = UD_RERANKS
|
|
42
38
|
self.builtin_models = list(BUILTIN_RERANK_MODELS.keys())
|
|
43
39
|
|
|
40
|
+
def add_ud_model(self, model_spec):
|
|
41
|
+
from . import generate_engine_config_by_model_name
|
|
42
|
+
|
|
43
|
+
UD_RERANKS.append(model_spec)
|
|
44
|
+
generate_engine_config_by_model_name(model_spec)
|
|
45
|
+
|
|
46
|
+
def check_model_uri(self, model_family: "RerankModelFamilyV2"):
|
|
47
|
+
from ..utils import is_valid_model_uri
|
|
48
|
+
|
|
49
|
+
for spec in model_family.model_specs:
|
|
50
|
+
model_uri = spec.model_uri
|
|
51
|
+
if model_uri and not is_valid_model_uri(model_uri):
|
|
52
|
+
raise ValueError(f"Invalid model URI {model_uri}.")
|
|
53
|
+
|
|
54
|
+
def remove_ud_model(self, model_family: "CustomRerankModelFamilyV2"):
|
|
55
|
+
from .rerank_family import RERANK_ENGINES
|
|
56
|
+
|
|
57
|
+
UD_RERANKS.remove(model_family)
|
|
58
|
+
del RERANK_ENGINES[model_family.model_name]
|
|
59
|
+
|
|
60
|
+
def remove_ud_model_files(self, model_family: "CustomRerankModelFamilyV2"):
|
|
61
|
+
from .cache_manager import RerankCacheManager
|
|
62
|
+
|
|
63
|
+
_model_family = model_family.copy()
|
|
64
|
+
for spec in model_family.model_specs:
|
|
65
|
+
_model_family.model_specs = [spec]
|
|
66
|
+
cache_manager = RerankCacheManager(_model_family)
|
|
67
|
+
cache_manager.unregister_custom_model(self.model_type)
|
|
68
|
+
|
|
44
69
|
|
|
45
70
|
def get_user_defined_reranks() -> List[CustomRerankModelFamilyV2]:
|
|
46
71
|
from ..custom import RegistryManager
|
|
@@ -49,11 +74,11 @@ def get_user_defined_reranks() -> List[CustomRerankModelFamilyV2]:
|
|
|
49
74
|
return registry.get_custom_models()
|
|
50
75
|
|
|
51
76
|
|
|
52
|
-
def register_rerank(
|
|
77
|
+
def register_rerank(model_family: CustomRerankModelFamilyV2, persist: bool):
|
|
53
78
|
from ..custom import RegistryManager
|
|
54
79
|
|
|
55
80
|
registry = RegistryManager.get_registry("rerank")
|
|
56
|
-
registry.register(
|
|
81
|
+
registry.register(model_family, persist)
|
|
57
82
|
|
|
58
83
|
|
|
59
84
|
def unregister_rerank(model_name: str, raise_error: bool = True):
|