xinference 1.8.1rc1__py3-none-any.whl → 1.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (64)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +2 -1
  3. xinference/core/model.py +5 -0
  4. xinference/core/supervisor.py +2 -3
  5. xinference/core/worker.py +3 -4
  6. xinference/deploy/local.py +5 -0
  7. xinference/deploy/worker.py +6 -0
  8. xinference/model/core.py +3 -0
  9. xinference/model/embedding/sentence_transformers/core.py +3 -4
  10. xinference/model/embedding/vllm/core.py +4 -3
  11. xinference/model/image/model_spec.json +69 -0
  12. xinference/model/image/stable_diffusion/core.py +22 -0
  13. xinference/model/llm/cache_manager.py +17 -3
  14. xinference/model/llm/harmony.py +245 -0
  15. xinference/model/llm/llm_family.json +293 -8
  16. xinference/model/llm/llm_family.py +1 -1
  17. xinference/model/llm/sglang/core.py +108 -5
  18. xinference/model/llm/transformers/core.py +15 -7
  19. xinference/model/llm/transformers/gemma3.py +1 -1
  20. xinference/model/llm/transformers/gpt_oss.py +91 -0
  21. xinference/model/llm/transformers/multimodal/core.py +1 -1
  22. xinference/model/llm/transformers/multimodal/gemma3.py +1 -1
  23. xinference/model/llm/transformers/multimodal/glm4_1v.py +2 -2
  24. xinference/model/llm/transformers/multimodal/ovis2.py +1 -1
  25. xinference/model/llm/transformers/multimodal/qwen-omni.py +7 -8
  26. xinference/model/llm/transformers/multimodal/qwen2_vl.py +9 -6
  27. xinference/model/llm/transformers/utils.py +1 -33
  28. xinference/model/llm/utils.py +61 -7
  29. xinference/model/llm/vllm/core.py +38 -8
  30. xinference/model/rerank/__init__.py +66 -23
  31. xinference/model/rerank/cache_manager.py +35 -0
  32. xinference/model/rerank/core.py +84 -339
  33. xinference/model/rerank/custom.py +33 -8
  34. xinference/model/rerank/model_spec.json +251 -212
  35. xinference/model/rerank/rerank_family.py +137 -0
  36. xinference/model/rerank/sentence_transformers/__init__.py +13 -0
  37. xinference/model/rerank/sentence_transformers/core.py +337 -0
  38. xinference/model/rerank/vllm/__init__.py +13 -0
  39. xinference/model/rerank/vllm/core.py +106 -0
  40. xinference/model/utils.py +109 -0
  41. xinference/types.py +2 -0
  42. xinference/ui/web/ui/build/asset-manifest.json +3 -3
  43. xinference/ui/web/ui/build/index.html +1 -1
  44. xinference/ui/web/ui/build/static/js/{main.b969199a.js → main.4918643a.js} +3 -3
  45. xinference/ui/web/ui/build/static/js/{main.b969199a.js.map → main.4918643a.js.map} +1 -1
  46. xinference/ui/web/ui/node_modules/.cache/babel-loader/28012da921a51f1082549956d3ae82acd769a754b22afda9acddd98a4daf9ea4.json +1 -0
  47. xinference/ui/web/ui/node_modules/.cache/babel-loader/475936ebe725eca62a6f52ce182c06a19b2cef4df9545a05ed0591ee0c539d43.json +1 -0
  48. xinference/ui/web/ui/node_modules/.cache/babel-loader/89179f8f51887b9167721860a12412549ff04f78162e921a7b6aa6532646deb2.json +1 -0
  49. xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +1 -0
  50. xinference/ui/web/ui/node_modules/.cache/babel-loader/9dc5cfc67dd0617b0272aeef8651f1589b2155a4ff1fd72ad3166b217089b619.json +1 -0
  51. xinference/ui/web/ui/node_modules/.cache/babel-loader/aee5aaba26f2b1e816a3ea9efa68bad8b95695a3d80adcfd8dd57a7bb17ac71a.json +1 -0
  52. {xinference-1.8.1rc1.dist-info → xinference-1.9.0.dist-info}/METADATA +6 -1
  53. {xinference-1.8.1rc1.dist-info → xinference-1.9.0.dist-info}/RECORD +58 -50
  54. xinference/ui/web/ui/node_modules/.cache/babel-loader/1409a96b9f9f9f5de99a89ab0f738f6da62b449521b0a8d3e4efcf7f5c23534d.json +0 -1
  55. xinference/ui/web/ui/node_modules/.cache/babel-loader/43b889c3a8e2634092ade463d52481c7c5581c72ded8f23bc5f012ea0ef8cea5.json +0 -1
  56. xinference/ui/web/ui/node_modules/.cache/babel-loader/5d47532fb42128280d87f57c8a0b02bc1930f7ef764aa7e90579247df18bba83.json +0 -1
  57. xinference/ui/web/ui/node_modules/.cache/babel-loader/830882bb275468a969614824a9ab8983f874b4581f2eb625e9c66426cdc65e5b.json +0 -1
  58. xinference/ui/web/ui/node_modules/.cache/babel-loader/9df08abcb5a7c1e48a4eb25c5d5f5d7253ea6854a4397e6d74d1fd75a14acda1.json +0 -1
  59. xinference/ui/web/ui/node_modules/.cache/babel-loader/b99034986a06445701accc7a4914bb9320947435e8d4e15793392ca4f679316c.json +0 -1
  60. /xinference/ui/web/ui/build/static/js/{main.b969199a.js.LICENSE.txt → main.4918643a.js.LICENSE.txt} +0 -0
  61. {xinference-1.8.1rc1.dist-info → xinference-1.9.0.dist-info}/WHEEL +0 -0
  62. {xinference-1.8.1rc1.dist-info → xinference-1.9.0.dist-info}/entry_points.txt +0 -0
  63. {xinference-1.8.1rc1.dist-info → xinference-1.9.0.dist-info}/licenses/LICENSE +0 -0
  64. {xinference-1.8.1rc1.dist-info → xinference-1.9.0.dist-info}/top_level.txt +0 -0
xinference/model/rerank/__init__.py
@@ -16,10 +16,10 @@ import codecs
 import json
 import os
 import warnings
-from typing import Dict, List
+from typing import Any, Dict, List

 from ...constants import XINFERENCE_MODEL_DIR
-from ..utils import flatten_model_src
+from ..utils import flatten_quantizations
 from .core import (
     RERANK_MODEL_DESCRIPTIONS,
     RerankModelFamilyV2,
@@ -32,8 +32,13 @@ from .custom import (
     register_rerank,
     unregister_rerank,
 )
-
-BUILTIN_RERANK_MODELS: Dict[str, List["RerankModelFamilyV2"]] = {}
+from .rerank_family import (
+    BUILTIN_RERANK_MODELS,
+    RERANK_ENGINES,
+    SENTENCE_TRANSFORMER_CLASSES,
+    SUPPORTED_ENGINES,
+    VLLM_CLASSES,
+)


 def register_custom_model():
@@ -58,31 +63,69 @@ def register_custom_model():
             warnings.warn(f"{user_defined_rerank_dir}/{f} has error, {e}")


-def _install():
-    load_model_family_from_json("model_spec.json", BUILTIN_RERANK_MODELS)
+def generate_engine_config_by_model_name(model_family: "RerankModelFamilyV2"):
+    model_name = model_family.model_name
+    engines: Dict[str, List[Dict[str, Any]]] = RERANK_ENGINES.get(
+        model_name, {}
+    )  # structure for engine query
+    for spec in [x for x in model_family.model_specs if x.model_hub == "huggingface"]:
+        model_format = spec.model_format
+        quantization = spec.quantization
+        for engine in SUPPORTED_ENGINES:
+            CLASSES = SUPPORTED_ENGINES[engine]
+            for cls in CLASSES:
+                # Every engine needs to implement match method
+                if cls.match(model_family, spec, quantization):
+                    # we only match the first class for an engine
+                    if engine not in engines:
+                        engines[engine] = [
+                            {
+                                "model_name": model_name,
+                                "model_format": model_format,
+                                "quantization": quantization,
+                                "rerank_class": cls,
+                            }
+                        ]
+                    else:
+                        engines[engine].append(
+                            {
+                                "model_name": model_name,
+                                "model_format": model_format,
+                                "quantization": quantization,
+                                "rerank_class": cls,
+                            }
+                        )
+                    break
+    RERANK_ENGINES[model_name] = engines
+

-    for model_name, model_specs in BUILTIN_RERANK_MODELS.items():
-        model_spec = [x for x in model_specs if x.model_hub == "huggingface"][0]
+def _install():
+    _model_spec_json = os.path.join(os.path.dirname(__file__), "model_spec.json")
+    for json_obj in json.load(codecs.open(_model_spec_json, "r", encoding="utf-8")):
+        flattened = []
+        for spec in json_obj["model_specs"]:
+            flattened.extend(flatten_quantizations(spec))
+        json_obj["model_specs"] = flattened
+        BUILTIN_RERANK_MODELS[json_obj["model_name"]] = RerankModelFamilyV2(**json_obj)
+
+    for model_name, model_spec in BUILTIN_RERANK_MODELS.items():
         if model_spec.model_name not in RERANK_MODEL_DESCRIPTIONS:
             RERANK_MODEL_DESCRIPTIONS.update(generate_rerank_description(model_spec))

-    register_custom_model()
+    from .sentence_transformers.core import SentenceTransformerRerankModel
+    from .vllm.core import VLLMRerankModel

-    # register model description
-    for ud_rerank in get_user_defined_reranks():
-        RERANK_MODEL_DESCRIPTIONS.update(generate_rerank_description(ud_rerank))
+    SENTENCE_TRANSFORMER_CLASSES.extend([SentenceTransformerRerankModel])
+    VLLM_CLASSES.extend([VLLMRerankModel])

+    SUPPORTED_ENGINES["sentence_transformers"] = SENTENCE_TRANSFORMER_CLASSES
+    SUPPORTED_ENGINES["vllm"] = VLLM_CLASSES

-def load_model_family_from_json(json_filename, target_families):
-    _model_spec_json = os.path.join(os.path.dirname(__file__), json_filename)
-    flattened_model_specs = []
-    for spec in json.load(codecs.open(_model_spec_json, "r", encoding="utf-8")):
-        flattened_model_specs.extend(flatten_model_src(spec))
+    for model_spec in BUILTIN_RERANK_MODELS.values():
+        generate_engine_config_by_model_name(model_spec)

-    for spec in flattened_model_specs:
-        if spec["model_name"] not in target_families:
-            target_families[spec["model_name"]] = [RerankModelFamilyV2(**spec)]
-        else:
-            target_families[spec["model_name"]].append(RerankModelFamilyV2(**spec))
+    register_custom_model()

-    del _model_spec_json
+    # register model description
+    for ud_rerank in get_user_defined_reranks():
+        RERANK_MODEL_DESCRIPTIONS.update(generate_rerank_description(ud_rerank))
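
After `_install()` runs, every built-in rerank family has an entry in `RERANK_ENGINES` describing which engines can serve it. A minimal sketch of querying that table (assuming xinference 1.9.0 is installed and the model registry's `_install()` hooks have run; the model name `bge-reranker-base` is a placeholder for illustration, not confirmed by this diff):

# Sketch: list the engines able to serve a built-in rerank model.
import xinference.model.rerank  # assumes the registry _install() hooks have run
from xinference.model.rerank.rerank_family import RERANK_ENGINES

for engine, params in RERANK_ENGINES.get("bge-reranker-base", {}).items():
    for p in params:
        # each entry was produced by the first engine class whose match() passed
        print(engine, p["model_format"], p["quantization"], p["rerank_class"].__name__)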

xinference/model/rerank/cache_manager.py (new file)
@@ -0,0 +1,35 @@
+import os
+from typing import TYPE_CHECKING
+
+from ..cache_manager import CacheManager
+
+if TYPE_CHECKING:
+    from .core import RerankModelFamilyV2
+
+
+class RerankCacheManager(CacheManager):
+    def __init__(self, model_family: "RerankModelFamilyV2"):
+        from ..llm.cache_manager import LLMCacheManager
+
+        super().__init__(model_family)
+        # Composition design mode for avoiding duplicate code
+        self.cache_helper = LLMCacheManager(model_family)
+
+        spec = self._model_family.model_specs[0]
+        model_dir_name = (
+            f"{self._model_family.model_name}-{spec.model_format}-{spec.quantization}"
+        )
+        self._cache_dir = os.path.join(self._v2_cache_dir_prefix, model_dir_name)
+        self.cache_helper._cache_dir = self._cache_dir
+
+    def cache(self) -> str:
+        spec = self._model_family.model_specs[0]
+        if spec.model_uri is not None:
+            return self.cache_helper.cache_uri()
+        else:
+            if spec.model_hub == "huggingface":
+                return self.cache_helper.cache_from_huggingface()
+            elif spec.model_hub == "modelscope":
+                return self.cache_helper.cache_from_modelscope()
+            else:
+                raise ValueError(f"Unknown model hub: {spec.model_hub}")
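
A hypothetical usage sketch for the new manager (the model name is a placeholder; assumes the built-in registry has been populated by the `_install()` shown above):

# Sketch: resolve the local cache directory for a rerank family,
# downloading the weights on first use via the composed LLMCacheManager.
from xinference.model.rerank.rerank_family import BUILTIN_RERANK_MODELS
from xinference.model.rerank.cache_manager import RerankCacheManager

family = BUILTIN_RERANK_MODELS["bge-reranker-base"]  # placeholder model name
print(RerankCacheManager(family).cache())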

xinference/model/rerank/core.py
@@ -10,28 +10,18 @@
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
-# limitations under the License.

-import gc
-import importlib
-import importlib.util
 import logging
 import os
-import threading
-import uuid
+from abc import abstractmethod
 from collections import defaultdict
-from collections.abc import Sequence
 from typing import Dict, List, Literal, Optional

-import numpy as np
-import torch
-import torch.nn as nn
-
-from ...device_utils import empty_cache, is_device_available
-from ...types import Document, DocumentObj, Rerank, RerankTokens
-from ..core import CacheableModelSpec, VirtualEnvSettings
+from ..._compat import BaseModel
+from ...types import Rerank
+from ..core import VirtualEnvSettings
 from ..utils import ModelInstanceInfoMixin
-from .utils import preprocess_sentence
+from .rerank_family import check_engine_by_model_name_and_engine, match_rerank

 logger = logging.getLogger(__name__)

@@ -48,21 +38,29 @@ def get_rerank_model_descriptions():
     return copy.deepcopy(RERANK_MODEL_DESCRIPTIONS)


-class RerankModelFamilyV2(CacheableModelSpec, ModelInstanceInfoMixin):
+class RerankSpecV1(BaseModel):
+    model_format: Literal["pytorch"]
+    model_hub: str = "huggingface"
+    model_id: Optional[str] = None
+    model_revision: Optional[str] = None
+    model_uri: Optional[str] = None
+    quantization: str = "none"
+
+
+class RerankModelFamilyV2(BaseModel, ModelInstanceInfoMixin):
     version: Literal[2]
     model_name: str
+    model_specs: List[RerankSpecV1]
     language: List[str]
     type: Optional[str] = "unknown"
     max_tokens: Optional[int]
-    model_id: str
-    model_revision: Optional[str]
-    model_hub: str = "huggingface"
     virtualenv: Optional[VirtualEnvSettings]

     class Config:
         extra = "allow"

     def to_description(self):
+        spec = self.model_specs[0]
         return {
             "model_type": "rerank",
             "address": getattr(self, "address", None),
@@ -70,13 +68,13 @@ class RerankModelFamilyV2(CacheableModelSpec, ModelInstanceInfoMixin):
             "type": self.type,
             "model_name": self.model_name,
             "language": self.language,
-            "model_revision": self.model_revision,
+            "model_revision": spec.model_revision,
         }

     def to_version_info(self):
-        from ..cache_manager import CacheManager
+        from .cache_manager import RerankCacheManager

-        cache_manager = CacheManager(self)
+        cache_manager = RerankCacheManager(self)
         return {
             "model_version": self.model_name,
             "model_file_location": cache_manager.get_cache_dir(),
@@ -93,56 +91,56 @@ def generate_rerank_description(
     return res


-class _ModelWrapper(nn.Module):
-    def __init__(self, module: nn.Module):
-        super().__init__()
-        self.model = module
-        self._local_data = threading.local()
-
-    @property
-    def n_tokens(self):
-        return getattr(self._local_data, "n_tokens", 0)
-
-    @n_tokens.setter
-    def n_tokens(self, value):
-        self._local_data.n_tokens = value
-
-    def forward(self, **kwargs):
-        attention_mask = kwargs.get("attention_mask")
-        # when batching, the attention mask 1 means there is a token
-        # thus we just sum up it to get the total number of tokens
-        if attention_mask is not None:
-            self.n_tokens += attention_mask.sum().item()
-        return self.model(**kwargs)
-
-    def __getattr__(self, attr):
-        try:
-            return super().__getattr__(attr)
-        except AttributeError:
-            return getattr(self.model, attr)
-
-
 class RerankModel:
     def __init__(
         self,
-        model_spec: RerankModelFamilyV2,
         model_uid: str,
-        model_path: Optional[str] = None,
+        model_path: str,
+        model_family: RerankModelFamilyV2,
         device: Optional[str] = None,
         use_fp16: bool = False,
-        model_config: Optional[Dict] = None,
+        **kwargs,
     ):
-        self.model_family = model_spec
-        self._model_spec = model_spec
+        self.model_family = model_family
+        self._model_spec = model_family.model_specs[0]
         self._model_uid = model_uid
         self._model_path = model_path
         self._device = device
-        self._model_config = model_config or dict()
         self._use_fp16 = use_fp16
        self._model = None
         self._counter = 0
-        if model_spec.type == "unknown":
-            model_spec.type = self._auto_detect_type(model_path)
+        self._kwargs = kwargs
+        if model_family.type == "unknown":
+            model_family.type = self._auto_detect_type(model_path)
+
+    @classmethod
+    @abstractmethod
+    def check_lib(cls) -> bool:
+        pass
+
+    @classmethod
+    @abstractmethod
+    def match_json(
+        cls,
+        model_family: RerankModelFamilyV2,
+        model_spec: RerankSpecV1,
+        quantization: str,
+    ) -> bool:
+        pass
+
+    @classmethod
+    def match(
+        cls,
+        model_family: RerankModelFamilyV2,
+        model_spec: RerankSpecV1,
+        quantization: str,
+    ):
+        """
+        Return if the model_spec can be matched.
+        """
+        if not cls.check_lib():
+            return False
+        return cls.match_json(model_family, model_spec, quantization)

     @staticmethod
     def _get_tokenizer(model_path):
@@ -171,145 +169,10 @@ class RerankModel:
             return "normal"
         return rerank_type

-    def load(self):
-        logger.info("Loading rerank model: %s", self._model_path)
-        flash_attn_installed = importlib.util.find_spec("flash_attn") is not None
-        if (
-            self._auto_detect_type(self._model_path) != "normal"
-            and flash_attn_installed
-        ):
-            logger.warning(
-                "flash_attn can only support fp16 and bf16, "
-                "will force set `use_fp16` to True"
-            )
-            self._use_fp16 = True
-
-        if (
-            self._model_spec.type == "normal"
-            and "qwen3" not in self._model_spec.model_name.lower()
-        ):
-            try:
-                import sentence_transformers
-                from sentence_transformers.cross_encoder import CrossEncoder
-
-                if sentence_transformers.__version__ < "3.1.0":
-                    raise ValueError(
-                        "The sentence_transformers version must be greater than 3.1.0. "
-                        "Please upgrade your version via `pip install -U sentence_transformers` or refer to "
-                        "https://github.com/UKPLab/sentence-transformers"
-                    )
-            except ImportError:
-                error_message = "Failed to import module 'sentence-transformers'"
-                installation_guide = [
-                    "Please make sure 'sentence-transformers' is installed. ",
-                    "You can install it by `pip install sentence-transformers`\n",
-                ]
-
-                raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
-            self._model = CrossEncoder(
-                self._model_path,
-                device=self._device,
-                trust_remote_code=True,
-                max_length=getattr(self._model_spec, "max_tokens"),
-                **self._model_config,
-            )
-            if self._use_fp16:
-                self._model.model.half()
-        elif "qwen3" in self._model_spec.model_name.lower():
-            # qwen3-reranker
-            # now we use transformers
-            # TODO: support engines for rerank models
-            try:
-                from transformers import AutoModelForCausalLM, AutoTokenizer
-            except ImportError:
-                error_message = "Failed to import module 'transformers'"
-                installation_guide = [
-                    "Please make sure 'transformers' is installed. ",
-                    "You can install it by `pip install transformers`\n",
-                ]
-
-                raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
-
-            tokenizer = AutoTokenizer.from_pretrained(
-                self._model_path, padding_side="left"
-            )
-            enable_flash_attn = self._model_config.pop(
-                "enable_flash_attn", is_device_available("cuda")
-            )
-            model_kwargs = {"device_map": "auto"}
-            if flash_attn_installed and enable_flash_attn:
-                model_kwargs["attn_implementation"] = "flash_attention_2"
-                model_kwargs["torch_dtype"] = torch.float16
-            model_kwargs.update(self._model_config)
-            logger.debug("Loading qwen3 rerank with kwargs %s", model_kwargs)
-            model = self._model = AutoModelForCausalLM.from_pretrained(
-                self._model_path, **model_kwargs
-            ).eval()
-            max_length = getattr(self._model_spec, "max_tokens")
-
-            prefix = (
-                "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query "
-                'and the Instruct provided. Note that the answer can only be "yes" or "no".'
-                "<|im_end|>\n<|im_start|>user\n"
-            )
-            suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
-            prefix_tokens = tokenizer.encode(prefix, add_special_tokens=False)
-            suffix_tokens = tokenizer.encode(suffix, add_special_tokens=False)
-
-            def process_inputs(pairs):
-                inputs = tokenizer(
-                    pairs,
-                    padding=False,
-                    truncation="longest_first",
-                    return_attention_mask=False,
-                    max_length=max_length - len(prefix_tokens) - len(suffix_tokens),
-                )
-                for i, ele in enumerate(inputs["input_ids"]):
-                    inputs["input_ids"][i] = prefix_tokens + ele + suffix_tokens
-                inputs = tokenizer.pad(
-                    inputs, padding=True, return_tensors="pt", max_length=max_length
-                )
-                for key in inputs:
-                    inputs[key] = inputs[key].to(model.device)
-                return inputs
-
-            token_false_id = tokenizer.convert_tokens_to_ids("no")
-            token_true_id = tokenizer.convert_tokens_to_ids("yes")
-
-            @torch.no_grad()
-            def compute_logits(inputs, **kwargs):
-                batch_scores = model(**inputs).logits[:, -1, :]
-                true_vector = batch_scores[:, token_true_id]
-                false_vector = batch_scores[:, token_false_id]
-                batch_scores = torch.stack([false_vector, true_vector], dim=1)
-                batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
-                scores = batch_scores[:, 1].exp().tolist()
-                return scores
-
-            self.process_inputs = process_inputs
-            self.compute_logits = compute_logits
-        else:
-            try:
-                if self._model_spec.type == "LLM-based":
-                    from FlagEmbedding import FlagLLMReranker as FlagReranker
-                elif self._model_spec.type == "LLM-based layerwise":
-                    from FlagEmbedding import LayerWiseFlagLLMReranker as FlagReranker
-                else:
-                    raise RuntimeError(
-                        f"Unsupported Rank model type: {self._model_spec.type}"
-                    )
-            except ImportError:
-                error_message = "Failed to import module 'FlagEmbedding'"
-                installation_guide = [
-                    "Please make sure 'FlagEmbedding' is installed. ",
-                    "You can install it by `pip install FlagEmbedding`\n",
-                ]
-
-                raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
-            self._model = FlagReranker(self._model_path, use_fp16=self._use_fp16)
-            # Wrap transformers model to record number of tokens
-            self._model.model = _ModelWrapper(self._model.model)
+    @abstractmethod
+    def load(self): ...

+    @abstractmethod
     def rerank(
         self,
         documents: List[str],
@@ -319,159 +182,41 @@ class RerankModel:
         return_documents: Optional[bool],
         return_len: Optional[bool],
         **kwargs,
-    ) -> Rerank:
-        assert self._model is not None
-        if max_chunks_per_doc is not None:
-            raise ValueError("rerank hasn't support `max_chunks_per_doc` parameter.")
-        logger.info("Rerank with kwargs: %s, model: %s", kwargs, self._model)
-
-        pre_query = preprocess_sentence(
-            query, kwargs.get("instruction", None), self._model_spec.model_name
-        )
-        sentence_combinations = [[pre_query, doc] for doc in documents]
-        # reset n tokens
-        self._model.model.n_tokens = 0
-        if (
-            self._model_spec.type == "normal"
-            and "qwen3" not in self._model_spec.model_name.lower()
-        ):
-            logger.debug("Passing processed sentences: %s", sentence_combinations)
-            similarity_scores = self._model.predict(
-                sentence_combinations,
-                convert_to_numpy=False,
-                convert_to_tensor=True,
-                **kwargs,
-            ).cpu()
-            if similarity_scores.dtype == torch.bfloat16:
-                similarity_scores = similarity_scores.float()
-        elif "qwen3" in self._model_spec.model_name.lower():
-
-            def format_instruction(instruction, query, doc):
-                if instruction is None:
-                    instruction = "Given a web search query, retrieve relevant passages that answer the query"
-                output = "<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}".format(
-                    instruction=instruction, query=query, doc=doc
-                )
-                return output
-
-            # reduce memory usage.
-            micro_bs = 4
-            similarity_scores = []
-            for i in range(0, len(documents), micro_bs):
-                sub_docs = documents[i : i + micro_bs]
-                pairs = [
-                    format_instruction(kwargs.get("instruction", None), query, doc)
-                    for doc in sub_docs
-                ]
-                # Tokenize the input texts
-                inputs = self.process_inputs(pairs)
-                similarity_scores.extend(self.compute_logits(inputs))
-        else:
-            # Related issue: https://github.com/xorbitsai/inference/issues/1775
-            similarity_scores = self._model.compute_score(
-                sentence_combinations, **kwargs
-            )
-
-        if not isinstance(similarity_scores, Sequence):
-            similarity_scores = [similarity_scores]
-        elif (
-            isinstance(similarity_scores, list)
-            and len(similarity_scores) > 0
-            and isinstance(similarity_scores[0], Sequence)
-        ):
-            similarity_scores = similarity_scores[0]
-
-        sim_scores_argsort = list(reversed(np.argsort(similarity_scores)))
-        if top_n is not None:
-            sim_scores_argsort = sim_scores_argsort[:top_n]
-        if return_documents:
-            docs = [
-                DocumentObj(
-                    index=int(arg),
-                    relevance_score=float(similarity_scores[arg]),
-                    document=Document(text=documents[arg]),
-                )
-                for arg in sim_scores_argsort
-            ]
-        else:
-            docs = [
-                DocumentObj(
-                    index=int(arg),
-                    relevance_score=float(similarity_scores[arg]),
-                    document=None,
-                )
-                for arg in sim_scores_argsort
-            ]
-        if return_len:
-            input_len = self._model.model.n_tokens
-            # Rerank Model output is just score or documents
-            # while return_documents = True
-            output_len = input_len
-
-        # api_version, billed_units, warnings
-        # is for Cohere API compatibility, set to None
-        metadata = {
-            "api_version": None,
-            "billed_units": None,
-            "tokens": (
-                RerankTokens(input_tokens=input_len, output_tokens=output_len)
-                if return_len
-                else None
-            ),
-            "warnings": None,
-        }
-
-        del similarity_scores
-        # clear cache if possible
-        self._counter += 1
-        if self._counter % RERANK_EMPTY_CACHE_COUNT == 0:
-            logger.debug("Empty rerank cache.")
-            gc.collect()
-            empty_cache()
-
-        return Rerank(id=str(uuid.uuid1()), results=docs, meta=metadata)
+    ) -> Rerank: ...


 def create_rerank_model_instance(
     model_uid: str,
     model_name: str,
+    model_engine: Optional[str],
+    model_format: Optional[str] = None,
+    quantization: Optional[str] = None,
     download_hub: Optional[
         Literal["huggingface", "modelscope", "openmind_hub", "csghub"]
     ] = None,
     model_path: Optional[str] = None,
     **kwargs,
 ) -> RerankModel:
-    from ..cache_manager import CacheManager
-    from ..utils import download_from_modelscope
-    from . import BUILTIN_RERANK_MODELS
-    from .custom import get_user_defined_reranks
-
-    model_spec = None
-    for ud_spec in get_user_defined_reranks():
-        if ud_spec.model_name == model_name:
-            model_spec = ud_spec
-            break
-
-    if model_spec is None:
-        if model_name in BUILTIN_RERANK_MODELS:
-            model_specs = BUILTIN_RERANK_MODELS[model_name]
-            if download_hub == "modelscope" or download_from_modelscope():
-                model_spec = (
-                    [x for x in model_specs if x.model_hub == "modelscope"]
-                    + [x for x in model_specs if x.model_hub == "huggingface"]
-                )[0]
-            else:
-                model_spec = [x for x in model_specs if x.model_hub == "huggingface"][0]
-        else:
-            raise ValueError(
-                f"Rerank model {model_name} not found, available "
-                f"model: {BUILTIN_RERANK_MODELS.keys()}"
-            )
-    if not model_path:
-        cache_manager = CacheManager(model_spec)
+    from .cache_manager import RerankCacheManager
+
+    model_family = match_rerank(model_name, model_format, quantization, download_hub)
+    if model_path is None:
+        cache_manager = RerankCacheManager(model_family)
         model_path = cache_manager.cache()
-    use_fp16 = kwargs.pop("use_fp16", False)
-    model = RerankModel(
-        model_spec, model_uid, model_path, use_fp16=use_fp16, model_config=kwargs
+
+    if model_engine is None:
+        # unlike LLM and for compatibility,
+        # we use sentence_transformers as the default engine for all models
+        model_engine = "sentence_transformers"
+
+    rerank_cls = check_engine_by_model_name_and_engine(
+        model_engine, model_name, model_format, quantization
+    )
+    model = rerank_cls(
+        model_uid,
+        model_path,
+        model_family,
+        quantization,
+        **kwargs,
     )
     return model
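
Taken together, loading is now engine-driven: `match_rerank` picks the family, `RerankCacheManager` materializes the weights, and `check_engine_by_model_name_and_engine` selects the engine class. A sketch of calling the new entry point directly (the uid, model name, and documents are placeholders; keyword names follow the signatures in the diff above):

# Sketch: load and query a rerank model through the 1.9.0 code path.
from xinference.model.rerank.core import create_rerank_model_instance

model = create_rerank_model_instance(
    model_uid="rerank-demo",               # placeholder uid
    model_name="bge-reranker-base",        # placeholder model name
    model_engine="sentence_transformers",  # the default engine; "vllm" if it matches
)
model.load()  # implemented by the engine subclass, abstract on RerankModel
result = model.rerank(
    documents=["Paris is the capital of France.", "Berlin is a German city."],
    query="What is the capital of France?",
    top_n=1,
    max_chunks_per_doc=None,
    return_documents=True,
    return_len=False,
)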