xinference 1.8.1rc1__py3-none-any.whl → 1.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.
Files changed (108)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +2 -1
  3. xinference/core/model.py +8 -4
  4. xinference/core/supervisor.py +2 -3
  5. xinference/core/worker.py +7 -5
  6. xinference/deploy/cmdline.py +2 -0
  7. xinference/deploy/local.py +5 -0
  8. xinference/deploy/test/test_cmdline.py +1 -1
  9. xinference/deploy/worker.py +6 -0
  10. xinference/model/audio/cosyvoice.py +0 -1
  11. xinference/model/audio/model_spec.json +44 -20
  12. xinference/model/core.py +3 -0
  13. xinference/model/embedding/flag/core.py +5 -0
  14. xinference/model/embedding/llama_cpp/core.py +22 -19
  15. xinference/model/embedding/sentence_transformers/core.py +18 -4
  16. xinference/model/embedding/vllm/core.py +36 -9
  17. xinference/model/image/cache_manager.py +56 -0
  18. xinference/model/image/core.py +9 -0
  19. xinference/model/image/model_spec.json +178 -1
  20. xinference/model/image/stable_diffusion/core.py +155 -23
  21. xinference/model/llm/cache_manager.py +17 -3
  22. xinference/model/llm/harmony.py +245 -0
  23. xinference/model/llm/llama_cpp/core.py +41 -40
  24. xinference/model/llm/llm_family.json +688 -11
  25. xinference/model/llm/llm_family.py +1 -1
  26. xinference/model/llm/sglang/core.py +108 -5
  27. xinference/model/llm/transformers/core.py +20 -18
  28. xinference/model/llm/transformers/gemma3.py +1 -1
  29. xinference/model/llm/transformers/gpt_oss.py +91 -0
  30. xinference/model/llm/transformers/multimodal/core.py +1 -1
  31. xinference/model/llm/transformers/multimodal/gemma3.py +1 -1
  32. xinference/model/llm/transformers/multimodal/glm4_1v.py +2 -2
  33. xinference/model/llm/transformers/multimodal/ovis2.py +1 -1
  34. xinference/model/llm/transformers/multimodal/qwen-omni.py +7 -8
  35. xinference/model/llm/transformers/multimodal/qwen2_vl.py +9 -6
  36. xinference/model/llm/transformers/utils.py +1 -33
  37. xinference/model/llm/utils.py +61 -7
  38. xinference/model/llm/vllm/core.py +44 -8
  39. xinference/model/rerank/__init__.py +66 -23
  40. xinference/model/rerank/cache_manager.py +35 -0
  41. xinference/model/rerank/core.py +87 -339
  42. xinference/model/rerank/custom.py +33 -8
  43. xinference/model/rerank/model_spec.json +251 -212
  44. xinference/model/rerank/rerank_family.py +137 -0
  45. xinference/model/rerank/sentence_transformers/__init__.py +13 -0
  46. xinference/model/rerank/sentence_transformers/core.py +337 -0
  47. xinference/model/rerank/vllm/__init__.py +13 -0
  48. xinference/model/rerank/vllm/core.py +156 -0
  49. xinference/model/utils.py +108 -0
  50. xinference/model/video/model_spec.json +95 -1
  51. xinference/thirdparty/cosyvoice/bin/export_jit.py +3 -4
  52. xinference/thirdparty/cosyvoice/bin/export_onnx.py +49 -126
  53. xinference/thirdparty/cosyvoice/bin/{inference.py → inference_deprecated.py} +1 -0
  54. xinference/thirdparty/cosyvoice/bin/train.py +23 -3
  55. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +8 -4
  56. xinference/thirdparty/cosyvoice/cli/frontend.py +4 -4
  57. xinference/thirdparty/cosyvoice/cli/model.py +53 -75
  58. xinference/thirdparty/cosyvoice/dataset/dataset.py +5 -18
  59. xinference/thirdparty/cosyvoice/dataset/processor.py +24 -25
  60. xinference/thirdparty/cosyvoice/flow/decoder.py +24 -433
  61. xinference/thirdparty/cosyvoice/flow/flow.py +6 -14
  62. xinference/thirdparty/cosyvoice/flow/flow_matching.py +33 -145
  63. xinference/thirdparty/cosyvoice/hifigan/generator.py +169 -1
  64. xinference/thirdparty/cosyvoice/llm/llm.py +108 -17
  65. xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +14 -115
  66. xinference/thirdparty/cosyvoice/utils/common.py +20 -0
  67. xinference/thirdparty/cosyvoice/utils/executor.py +8 -4
  68. xinference/thirdparty/cosyvoice/utils/file_utils.py +45 -1
  69. xinference/thirdparty/cosyvoice/utils/losses.py +37 -0
  70. xinference/thirdparty/cosyvoice/utils/mask.py +35 -1
  71. xinference/thirdparty/cosyvoice/utils/train_utils.py +24 -6
  72. xinference/thirdparty/cosyvoice/vllm/cosyvoice2.py +103 -0
  73. xinference/types.py +2 -0
  74. xinference/ui/gradio/chat_interface.py +2 -0
  75. xinference/ui/gradio/media_interface.py +353 -7
  76. xinference/ui/web/ui/build/asset-manifest.json +3 -3
  77. xinference/ui/web/ui/build/index.html +1 -1
  78. xinference/ui/web/ui/build/static/js/main.1086c759.js +3 -0
  79. xinference/ui/web/ui/build/static/js/main.1086c759.js.map +1 -0
  80. xinference/ui/web/ui/node_modules/.cache/babel-loader/28012da921a51f1082549956d3ae82acd769a754b22afda9acddd98a4daf9ea4.json +1 -0
  81. xinference/ui/web/ui/node_modules/.cache/babel-loader/3c5758bd12fa334294b1de0ff6b1a4bac8d963c45472eab9dc3e530d82aa6b3f.json +1 -0
  82. xinference/ui/web/ui/node_modules/.cache/babel-loader/475936ebe725eca62a6f52ce182c06a19b2cef4df9545a05ed0591ee0c539d43.json +1 -0
  83. xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +1 -0
  84. xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +1 -0
  85. xinference/ui/web/ui/node_modules/.cache/babel-loader/aee5aaba26f2b1e816a3ea9efa68bad8b95695a3d80adcfd8dd57a7bb17ac71a.json +1 -0
  86. xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +1 -0
  87. xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +1 -0
  88. xinference/ui/web/ui/src/locales/en.json +2 -0
  89. xinference/ui/web/ui/src/locales/ja.json +2 -0
  90. xinference/ui/web/ui/src/locales/ko.json +2 -0
  91. xinference/ui/web/ui/src/locales/zh.json +2 -0
  92. {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/METADATA +15 -10
  93. {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/RECORD +98 -89
  94. xinference/ui/web/ui/build/static/js/main.b969199a.js +0 -3
  95. xinference/ui/web/ui/build/static/js/main.b969199a.js.map +0 -1
  96. xinference/ui/web/ui/node_modules/.cache/babel-loader/1409a96b9f9f9f5de99a89ab0f738f6da62b449521b0a8d3e4efcf7f5c23534d.json +0 -1
  97. xinference/ui/web/ui/node_modules/.cache/babel-loader/3d2a89f0eccc1f90fc5036c9a1d587c2120e6a6b128aae31d1db7d6bad52722b.json +0 -1
  98. xinference/ui/web/ui/node_modules/.cache/babel-loader/43b889c3a8e2634092ade463d52481c7c5581c72ded8f23bc5f012ea0ef8cea5.json +0 -1
  99. xinference/ui/web/ui/node_modules/.cache/babel-loader/5d47532fb42128280d87f57c8a0b02bc1930f7ef764aa7e90579247df18bba83.json +0 -1
  100. xinference/ui/web/ui/node_modules/.cache/babel-loader/830882bb275468a969614824a9ab8983f874b4581f2eb625e9c66426cdc65e5b.json +0 -1
  101. xinference/ui/web/ui/node_modules/.cache/babel-loader/8e5cb82c2ff3299c6a44563fe6b1c5515c9750613c51bb63abee0b1d70fc5019.json +0 -1
  102. xinference/ui/web/ui/node_modules/.cache/babel-loader/9df08abcb5a7c1e48a4eb25c5d5f5d7253ea6854a4397e6d74d1fd75a14acda1.json +0 -1
  103. xinference/ui/web/ui/node_modules/.cache/babel-loader/b99034986a06445701accc7a4914bb9320947435e8d4e15793392ca4f679316c.json +0 -1
  104. /xinference/ui/web/ui/build/static/js/{main.b969199a.js.LICENSE.txt → main.1086c759.js.LICENSE.txt} +0 -0
  105. {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/WHEEL +0 -0
  106. {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/entry_points.txt +0 -0
  107. {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/licenses/LICENSE +0 -0
  108. {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/top_level.txt +0 -0
xinference/model/rerank/core.py

@@ -10,28 +10,18 @@
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
-# limitations under the License.
 
-import gc
-import importlib
-import importlib.util
 import logging
 import os
-import threading
-import uuid
+from abc import abstractmethod
 from collections import defaultdict
-from collections.abc import Sequence
 from typing import Dict, List, Literal, Optional
 
-import numpy as np
-import torch
-import torch.nn as nn
-
-from ...device_utils import empty_cache, is_device_available
-from ...types import Document, DocumentObj, Rerank, RerankTokens
-from ..core import CacheableModelSpec, VirtualEnvSettings
+from ..._compat import BaseModel
+from ...types import Rerank
+from ..core import VirtualEnvSettings
 from ..utils import ModelInstanceInfoMixin
-from .utils import preprocess_sentence
+from .rerank_family import check_engine_by_model_name_and_engine, match_rerank
 
 logger = logging.getLogger(__name__)
 
@@ -48,21 +38,29 @@ def get_rerank_model_descriptions():
     return copy.deepcopy(RERANK_MODEL_DESCRIPTIONS)
 
 
-class RerankModelFamilyV2(CacheableModelSpec, ModelInstanceInfoMixin):
+class RerankSpecV1(BaseModel):
+    model_format: Literal["pytorch"]
+    model_hub: str = "huggingface"
+    model_id: Optional[str] = None
+    model_revision: Optional[str] = None
+    model_uri: Optional[str] = None
+    quantization: str = "none"
+
+
+class RerankModelFamilyV2(BaseModel, ModelInstanceInfoMixin):
     version: Literal[2]
     model_name: str
+    model_specs: List[RerankSpecV1]
     language: List[str]
     type: Optional[str] = "unknown"
     max_tokens: Optional[int]
-    model_id: str
-    model_revision: Optional[str]
-    model_hub: str = "huggingface"
     virtualenv: Optional[VirtualEnvSettings]
 
     class Config:
         extra = "allow"
 
     def to_description(self):
+        spec = self.model_specs[0]
         return {
             "model_type": "rerank",
             "address": getattr(self, "address", None),
@@ -70,13 +68,13 @@ class RerankModelFamilyV2(CacheableModelSpec, ModelInstanceInfoMixin):
             "type": self.type,
             "model_name": self.model_name,
             "language": self.language,
-            "model_revision": self.model_revision,
+            "model_revision": spec.model_revision,
         }
 
     def to_version_info(self):
-        from ..cache_manager import CacheManager
+        from .cache_manager import RerankCacheManager
 
-        cache_manager = CacheManager(self)
+        cache_manager = RerankCacheManager(self)
         return {
             "model_version": self.model_name,
             "model_file_location": cache_manager.get_cache_dir(),
@@ -93,56 +91,59 @@ def generate_rerank_description(
     return res
 
 
-class _ModelWrapper(nn.Module):
-    def __init__(self, module: nn.Module):
-        super().__init__()
-        self.model = module
-        self._local_data = threading.local()
-
-    @property
-    def n_tokens(self):
-        return getattr(self._local_data, "n_tokens", 0)
-
-    @n_tokens.setter
-    def n_tokens(self, value):
-        self._local_data.n_tokens = value
-
-    def forward(self, **kwargs):
-        attention_mask = kwargs.get("attention_mask")
-        # when batching, the attention mask 1 means there is a token
-        # thus we just sum up it to get the total number of tokens
-        if attention_mask is not None:
-            self.n_tokens += attention_mask.sum().item()
-        return self.model(**kwargs)
-
-    def __getattr__(self, attr):
-        try:
-            return super().__getattr__(attr)
-        except AttributeError:
-            return getattr(self.model, attr)
-
-
 class RerankModel:
     def __init__(
         self,
-        model_spec: RerankModelFamilyV2,
         model_uid: str,
-        model_path: Optional[str] = None,
+        model_path: str,
+        model_family: RerankModelFamilyV2,
+        quantization: Optional[str],
+        *,
         device: Optional[str] = None,
         use_fp16: bool = False,
-        model_config: Optional[Dict] = None,
+        **kwargs,
     ):
-        self.model_family = model_spec
-        self._model_spec = model_spec
+        self.model_family = model_family
+        self._model_spec = model_family.model_specs[0]
         self._model_uid = model_uid
         self._model_path = model_path
+        self._quantization = quantization
         self._device = device
-        self._model_config = model_config or dict()
         self._use_fp16 = use_fp16
         self._model = None
         self._counter = 0
-        if model_spec.type == "unknown":
-            model_spec.type = self._auto_detect_type(model_path)
+        self._kwargs = kwargs
+        if model_family.type == "unknown":
+            model_family.type = self._auto_detect_type(model_path)
+
+    @classmethod
+    @abstractmethod
+    def check_lib(cls) -> bool:
+        pass
+
+    @classmethod
+    @abstractmethod
+    def match_json(
+        cls,
+        model_family: RerankModelFamilyV2,
+        model_spec: RerankSpecV1,
+        quantization: str,
+    ) -> bool:
+        pass
+
+    @classmethod
+    def match(
+        cls,
+        model_family: RerankModelFamilyV2,
+        model_spec: RerankSpecV1,
+        quantization: str,
+    ):
+        """
+        Return if the model_spec can be matched.
+        """
+        if not cls.check_lib():
+            return False
+        return cls.match_json(model_family, model_spec, quantization)
 
     @staticmethod
     def _get_tokenizer(model_path):
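
RerankModel is now an abstract engine base: a subclass advertises that its backing library is importable via check_lib() and claims (family, spec, quantization) combinations via match_json(). The real engines live in the new rerank/sentence_transformers/core.py and rerank/vllm/core.py, which are not shown in this excerpt; the following is only a hypothetical sketch of the contract (load()/rerank() omitted for brevity):

import importlib.util

from xinference.model.rerank.core import (
    RerankModel,
    RerankModelFamilyV2,
    RerankSpecV1,
)


class MyRerankEngine(RerankModel):
    @classmethod
    def check_lib(cls) -> bool:
        # match() short-circuits to False when the library is missing
        return importlib.util.find_spec("sentence_transformers") is not None

    @classmethod
    def match_json(
        cls,
        model_family: RerankModelFamilyV2,
        model_spec: RerankSpecV1,
        quantization: str,
    ) -> bool:
        # claim only plain pytorch checkpoints in this sketch
        return model_spec.model_format == "pytorch"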
@@ -171,145 +172,10 @@ class RerankModel:
             return "normal"
         return rerank_type
 
-    def load(self):
-        logger.info("Loading rerank model: %s", self._model_path)
-        flash_attn_installed = importlib.util.find_spec("flash_attn") is not None
-        if (
-            self._auto_detect_type(self._model_path) != "normal"
-            and flash_attn_installed
-        ):
-            logger.warning(
-                "flash_attn can only support fp16 and bf16, "
-                "will force set `use_fp16` to True"
-            )
-            self._use_fp16 = True
-
-        if (
-            self._model_spec.type == "normal"
-            and "qwen3" not in self._model_spec.model_name.lower()
-        ):
-            try:
-                import sentence_transformers
-                from sentence_transformers.cross_encoder import CrossEncoder
-
-                if sentence_transformers.__version__ < "3.1.0":
-                    raise ValueError(
-                        "The sentence_transformers version must be greater than 3.1.0. "
-                        "Please upgrade your version via `pip install -U sentence_transformers` or refer to "
-                        "https://github.com/UKPLab/sentence-transformers"
-                    )
-            except ImportError:
-                error_message = "Failed to import module 'sentence-transformers'"
-                installation_guide = [
-                    "Please make sure 'sentence-transformers' is installed. ",
-                    "You can install it by `pip install sentence-transformers`\n",
-                ]
-
-                raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
-            self._model = CrossEncoder(
-                self._model_path,
-                device=self._device,
-                trust_remote_code=True,
-                max_length=getattr(self._model_spec, "max_tokens"),
-                **self._model_config,
-            )
-            if self._use_fp16:
-                self._model.model.half()
-        elif "qwen3" in self._model_spec.model_name.lower():
-            # qwen3-reranker
-            # now we use transformers
-            # TODO: support engines for rerank models
-            try:
-                from transformers import AutoModelForCausalLM, AutoTokenizer
-            except ImportError:
-                error_message = "Failed to import module 'transformers'"
-                installation_guide = [
-                    "Please make sure 'transformers' is installed. ",
-                    "You can install it by `pip install transformers`\n",
-                ]
-
-                raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
-
-            tokenizer = AutoTokenizer.from_pretrained(
-                self._model_path, padding_side="left"
-            )
-            enable_flash_attn = self._model_config.pop(
-                "enable_flash_attn", is_device_available("cuda")
-            )
-            model_kwargs = {"device_map": "auto"}
-            if flash_attn_installed and enable_flash_attn:
-                model_kwargs["attn_implementation"] = "flash_attention_2"
-                model_kwargs["torch_dtype"] = torch.float16
-            model_kwargs.update(self._model_config)
-            logger.debug("Loading qwen3 rerank with kwargs %s", model_kwargs)
-            model = self._model = AutoModelForCausalLM.from_pretrained(
-                self._model_path, **model_kwargs
-            ).eval()
-            max_length = getattr(self._model_spec, "max_tokens")
-
-            prefix = (
-                "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query "
-                'and the Instruct provided. Note that the answer can only be "yes" or "no".'
-                "<|im_end|>\n<|im_start|>user\n"
-            )
-            suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
-            prefix_tokens = tokenizer.encode(prefix, add_special_tokens=False)
-            suffix_tokens = tokenizer.encode(suffix, add_special_tokens=False)
-
-            def process_inputs(pairs):
-                inputs = tokenizer(
-                    pairs,
-                    padding=False,
-                    truncation="longest_first",
-                    return_attention_mask=False,
-                    max_length=max_length - len(prefix_tokens) - len(suffix_tokens),
-                )
-                for i, ele in enumerate(inputs["input_ids"]):
-                    inputs["input_ids"][i] = prefix_tokens + ele + suffix_tokens
-                inputs = tokenizer.pad(
-                    inputs, padding=True, return_tensors="pt", max_length=max_length
-                )
-                for key in inputs:
-                    inputs[key] = inputs[key].to(model.device)
-                return inputs
-
-            token_false_id = tokenizer.convert_tokens_to_ids("no")
-            token_true_id = tokenizer.convert_tokens_to_ids("yes")
-
-            @torch.no_grad()
-            def compute_logits(inputs, **kwargs):
-                batch_scores = model(**inputs).logits[:, -1, :]
-                true_vector = batch_scores[:, token_true_id]
-                false_vector = batch_scores[:, token_false_id]
-                batch_scores = torch.stack([false_vector, true_vector], dim=1)
-                batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
-                scores = batch_scores[:, 1].exp().tolist()
-                return scores
-
-            self.process_inputs = process_inputs
-            self.compute_logits = compute_logits
-        else:
-            try:
-                if self._model_spec.type == "LLM-based":
-                    from FlagEmbedding import FlagLLMReranker as FlagReranker
-                elif self._model_spec.type == "LLM-based layerwise":
-                    from FlagEmbedding import LayerWiseFlagLLMReranker as FlagReranker
-                else:
-                    raise RuntimeError(
-                        f"Unsupported Rank model type: {self._model_spec.type}"
-                    )
-            except ImportError:
-                error_message = "Failed to import module 'FlagEmbedding'"
-                installation_guide = [
-                    "Please make sure 'FlagEmbedding' is installed. ",
-                    "You can install it by `pip install FlagEmbedding`\n",
-                ]
-
-                raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
-            self._model = FlagReranker(self._model_path, use_fp16=self._use_fp16)
-            # Wrap transformers model to record number of tokens
-            self._model.model = _ModelWrapper(self._model.model)
+    @abstractmethod
+    def load(self): ...
 
+    @abstractmethod
     def rerank(
         self,
         documents: List[str],
@@ -319,159 +185,41 @@ class RerankModel:
         return_documents: Optional[bool],
         return_len: Optional[bool],
         **kwargs,
-    ) -> Rerank:
-        assert self._model is not None
-        if max_chunks_per_doc is not None:
-            raise ValueError("rerank hasn't support `max_chunks_per_doc` parameter.")
-        logger.info("Rerank with kwargs: %s, model: %s", kwargs, self._model)
-
-        pre_query = preprocess_sentence(
-            query, kwargs.get("instruction", None), self._model_spec.model_name
-        )
-        sentence_combinations = [[pre_query, doc] for doc in documents]
-        # reset n tokens
-        self._model.model.n_tokens = 0
-        if (
-            self._model_spec.type == "normal"
-            and "qwen3" not in self._model_spec.model_name.lower()
-        ):
-            logger.debug("Passing processed sentences: %s", sentence_combinations)
-            similarity_scores = self._model.predict(
-                sentence_combinations,
-                convert_to_numpy=False,
-                convert_to_tensor=True,
-                **kwargs,
-            ).cpu()
-            if similarity_scores.dtype == torch.bfloat16:
-                similarity_scores = similarity_scores.float()
-        elif "qwen3" in self._model_spec.model_name.lower():
-
-            def format_instruction(instruction, query, doc):
-                if instruction is None:
-                    instruction = "Given a web search query, retrieve relevant passages that answer the query"
-                output = "<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}".format(
-                    instruction=instruction, query=query, doc=doc
-                )
-                return output
-
-            # reduce memory usage.
-            micro_bs = 4
-            similarity_scores = []
-            for i in range(0, len(documents), micro_bs):
-                sub_docs = documents[i : i + micro_bs]
-                pairs = [
-                    format_instruction(kwargs.get("instruction", None), query, doc)
-                    for doc in sub_docs
-                ]
-                # Tokenize the input texts
-                inputs = self.process_inputs(pairs)
-                similarity_scores.extend(self.compute_logits(inputs))
-        else:
-            # Related issue: https://github.com/xorbitsai/inference/issues/1775
-            similarity_scores = self._model.compute_score(
-                sentence_combinations, **kwargs
-            )
-
-        if not isinstance(similarity_scores, Sequence):
-            similarity_scores = [similarity_scores]
-        elif (
-            isinstance(similarity_scores, list)
-            and len(similarity_scores) > 0
-            and isinstance(similarity_scores[0], Sequence)
-        ):
-            similarity_scores = similarity_scores[0]
-
-        sim_scores_argsort = list(reversed(np.argsort(similarity_scores)))
-        if top_n is not None:
-            sim_scores_argsort = sim_scores_argsort[:top_n]
-        if return_documents:
-            docs = [
-                DocumentObj(
-                    index=int(arg),
-                    relevance_score=float(similarity_scores[arg]),
-                    document=Document(text=documents[arg]),
-                )
-                for arg in sim_scores_argsort
-            ]
-        else:
-            docs = [
-                DocumentObj(
-                    index=int(arg),
-                    relevance_score=float(similarity_scores[arg]),
-                    document=None,
-                )
-                for arg in sim_scores_argsort
-            ]
-        if return_len:
-            input_len = self._model.model.n_tokens
-            # Rerank Model output is just score or documents
-            # while return_documents = True
-            output_len = input_len
-
-        # api_version, billed_units, warnings
-        # is for Cohere API compatibility, set to None
-        metadata = {
-            "api_version": None,
-            "billed_units": None,
-            "tokens": (
-                RerankTokens(input_tokens=input_len, output_tokens=output_len)
-                if return_len
-                else None
-            ),
-            "warnings": None,
-        }
-
-        del similarity_scores
-        # clear cache if possible
-        self._counter += 1
-        if self._counter % RERANK_EMPTY_CACHE_COUNT == 0:
-            logger.debug("Empty rerank cache.")
-            gc.collect()
-            empty_cache()
-
-        return Rerank(id=str(uuid.uuid1()), results=docs, meta=metadata)
+    ) -> Rerank: ...
 
 
 def create_rerank_model_instance(
     model_uid: str,
     model_name: str,
+    model_engine: Optional[str],
+    model_format: Optional[str] = None,
+    quantization: Optional[str] = None,
     download_hub: Optional[
         Literal["huggingface", "modelscope", "openmind_hub", "csghub"]
     ] = None,
    model_path: Optional[str] = None,
    **kwargs,
 ) -> RerankModel:
-    from ..cache_manager import CacheManager
-    from ..utils import download_from_modelscope
-    from . import BUILTIN_RERANK_MODELS
-    from .custom import get_user_defined_reranks
-
-    model_spec = None
-    for ud_spec in get_user_defined_reranks():
-        if ud_spec.model_name == model_name:
-            model_spec = ud_spec
-            break
-
-    if model_spec is None:
-        if model_name in BUILTIN_RERANK_MODELS:
-            model_specs = BUILTIN_RERANK_MODELS[model_name]
-            if download_hub == "modelscope" or download_from_modelscope():
-                model_spec = (
-                    [x for x in model_specs if x.model_hub == "modelscope"]
-                    + [x for x in model_specs if x.model_hub == "huggingface"]
-                )[0]
-            else:
-                model_spec = [x for x in model_specs if x.model_hub == "huggingface"][0]
-        else:
-            raise ValueError(
-                f"Rerank model {model_name} not found, available "
-                f"model: {BUILTIN_RERANK_MODELS.keys()}"
-            )
-    if not model_path:
-        cache_manager = CacheManager(model_spec)
+    from .cache_manager import RerankCacheManager
+
+    model_family = match_rerank(model_name, model_format, quantization, download_hub)
+    if model_path is None:
+        cache_manager = RerankCacheManager(model_family)
         model_path = cache_manager.cache()
-    use_fp16 = kwargs.pop("use_fp16", False)
-    model = RerankModel(
-        model_spec, model_uid, model_path, use_fp16=use_fp16, model_config=kwargs
+
+    if model_engine is None:
+        # unlike LLM and for compatibility,
+        # we use sentence_transformers as the default engine for all models
+        model_engine = "sentence_transformers"
+
+    rerank_cls = check_engine_by_model_name_and_engine(
+        model_engine, model_name, model_format, quantization
+    )
+    model = rerank_cls(
+        model_uid,
+        model_path,
+        model_family,
+        quantization,
+        **kwargs,
     )
     return model
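
With the engine indirection in place, callers select an engine at creation time and the factory falls back to sentence_transformers when none is given. A usage sketch, assuming the bundled bge-reranker-base spec and assuming the middle rerank() parameters (query, top_n, max_chunks_per_doc), which only appear in the deleted body above, carry over unchanged:

from xinference.model.rerank.core import create_rerank_model_instance

model = create_rerank_model_instance(
    model_uid="rerank-0",      # illustrative uid
    model_name="bge-reranker-base",
    model_engine=None,         # resolved to "sentence_transformers"
    model_format="pytorch",
    quantization="none",
)
model.load()
result = model.rerank(
    documents=["Paris is in France.", "Berlin is in Germany."],
    query="Where is Paris?",
    top_n=1,
    max_chunks_per_doc=None,
    return_documents=True,
    return_len=False,
)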
xinference/model/rerank/custom.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import List, Literal, Optional
+from typing import List, Literal
 
 from ..custom import ModelRegistry
 from .core import RerankModelFamilyV2
@@ -22,10 +22,6 @@ logger = logging.getLogger(__name__)
 
 class CustomRerankModelFamilyV2(RerankModelFamilyV2):
     version: Literal[2] = 2
-    model_id: Optional[str]  # type: ignore
-    model_revision: Optional[str]  # type: ignore
-    model_uri: Optional[str]
-    model_type: Literal["rerank"] = "rerank"  # for frontend
 
 
 UD_RERANKS: List[CustomRerankModelFamilyV2] = []
@@ -35,12 +31,41 @@ class RerankModelRegistry(ModelRegistry):
     model_type = "rerank"
 
     def __init__(self):
-        from . import BUILTIN_RERANK_MODELS
+        from .rerank_family import BUILTIN_RERANK_MODELS
 
         super().__init__()
         self.models = UD_RERANKS
         self.builtin_models = list(BUILTIN_RERANK_MODELS.keys())
 
+    def add_ud_model(self, model_spec):
+        from . import generate_engine_config_by_model_name
+
+        UD_RERANKS.append(model_spec)
+        generate_engine_config_by_model_name(model_spec)
+
+    def check_model_uri(self, model_family: "RerankModelFamilyV2"):
+        from ..utils import is_valid_model_uri
+
+        for spec in model_family.model_specs:
+            model_uri = spec.model_uri
+            if model_uri and not is_valid_model_uri(model_uri):
+                raise ValueError(f"Invalid model URI {model_uri}.")
+
+    def remove_ud_model(self, model_family: "CustomRerankModelFamilyV2"):
+        from .rerank_family import RERANK_ENGINES
+
+        UD_RERANKS.remove(model_family)
+        del RERANK_ENGINES[model_family.model_name]
+
+    def remove_ud_model_files(self, model_family: "CustomRerankModelFamilyV2"):
+        from .cache_manager import RerankCacheManager
+
+        _model_family = model_family.copy()
+        for spec in model_family.model_specs:
+            _model_family.model_specs = [spec]
+            cache_manager = RerankCacheManager(_model_family)
+            cache_manager.unregister_custom_model(self.model_type)
+
 
 def get_user_defined_reranks() -> List[CustomRerankModelFamilyV2]:
     from ..custom import RegistryManager
@@ -49,11 +74,11 @@ def get_user_defined_reranks() -> List[CustomRerankModelFamilyV2]:
     return registry.get_custom_models()
 
 
-def register_rerank(model_spec: CustomRerankModelFamilyV2, persist: bool):
+def register_rerank(model_family: CustomRerankModelFamilyV2, persist: bool):
     from ..custom import RegistryManager
 
     registry = RegistryManager.get_registry("rerank")
-    registry.register(model_spec, persist)
+    registry.register(model_family, persist)
 
 
 def unregister_rerank(model_name: str, raise_error: bool = True):
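
The registry changes above mean user-defined rerank models now carry model_specs as well, with per-spec URIs validated by check_model_uri(). A registration sketch under the new schema — the model name, URI, and max_tokens value are placeholders, not values from this release:

from xinference.model.rerank.core import RerankSpecV1
from xinference.model.rerank.custom import (
    CustomRerankModelFamilyV2,
    register_rerank,
)

family = CustomRerankModelFamilyV2(
    model_name="my-custom-reranker",
    model_specs=[
        RerankSpecV1(
            model_format="pytorch",
            model_uri="file:///path/to/checkpoint",  # placeholder URI
        )
    ],
    language=["en"],
    max_tokens=512,
)
# persist=True would also write the definition to disk
register_rerank(family, persist=False)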