xinference 1.8.1rc1__py3-none-any.whl → 1.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (108)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +2 -1
  3. xinference/core/model.py +8 -4
  4. xinference/core/supervisor.py +2 -3
  5. xinference/core/worker.py +7 -5
  6. xinference/deploy/cmdline.py +2 -0
  7. xinference/deploy/local.py +5 -0
  8. xinference/deploy/test/test_cmdline.py +1 -1
  9. xinference/deploy/worker.py +6 -0
  10. xinference/model/audio/cosyvoice.py +0 -1
  11. xinference/model/audio/model_spec.json +44 -20
  12. xinference/model/core.py +3 -0
  13. xinference/model/embedding/flag/core.py +5 -0
  14. xinference/model/embedding/llama_cpp/core.py +22 -19
  15. xinference/model/embedding/sentence_transformers/core.py +18 -4
  16. xinference/model/embedding/vllm/core.py +36 -9
  17. xinference/model/image/cache_manager.py +56 -0
  18. xinference/model/image/core.py +9 -0
  19. xinference/model/image/model_spec.json +178 -1
  20. xinference/model/image/stable_diffusion/core.py +155 -23
  21. xinference/model/llm/cache_manager.py +17 -3
  22. xinference/model/llm/harmony.py +245 -0
  23. xinference/model/llm/llama_cpp/core.py +41 -40
  24. xinference/model/llm/llm_family.json +688 -11
  25. xinference/model/llm/llm_family.py +1 -1
  26. xinference/model/llm/sglang/core.py +108 -5
  27. xinference/model/llm/transformers/core.py +20 -18
  28. xinference/model/llm/transformers/gemma3.py +1 -1
  29. xinference/model/llm/transformers/gpt_oss.py +91 -0
  30. xinference/model/llm/transformers/multimodal/core.py +1 -1
  31. xinference/model/llm/transformers/multimodal/gemma3.py +1 -1
  32. xinference/model/llm/transformers/multimodal/glm4_1v.py +2 -2
  33. xinference/model/llm/transformers/multimodal/ovis2.py +1 -1
  34. xinference/model/llm/transformers/multimodal/qwen-omni.py +7 -8
  35. xinference/model/llm/transformers/multimodal/qwen2_vl.py +9 -6
  36. xinference/model/llm/transformers/utils.py +1 -33
  37. xinference/model/llm/utils.py +61 -7
  38. xinference/model/llm/vllm/core.py +44 -8
  39. xinference/model/rerank/__init__.py +66 -23
  40. xinference/model/rerank/cache_manager.py +35 -0
  41. xinference/model/rerank/core.py +87 -339
  42. xinference/model/rerank/custom.py +33 -8
  43. xinference/model/rerank/model_spec.json +251 -212
  44. xinference/model/rerank/rerank_family.py +137 -0
  45. xinference/model/rerank/sentence_transformers/__init__.py +13 -0
  46. xinference/model/rerank/sentence_transformers/core.py +337 -0
  47. xinference/model/rerank/vllm/__init__.py +13 -0
  48. xinference/model/rerank/vllm/core.py +156 -0
  49. xinference/model/utils.py +108 -0
  50. xinference/model/video/model_spec.json +95 -1
  51. xinference/thirdparty/cosyvoice/bin/export_jit.py +3 -4
  52. xinference/thirdparty/cosyvoice/bin/export_onnx.py +49 -126
  53. xinference/thirdparty/cosyvoice/bin/{inference.py → inference_deprecated.py} +1 -0
  54. xinference/thirdparty/cosyvoice/bin/train.py +23 -3
  55. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +8 -4
  56. xinference/thirdparty/cosyvoice/cli/frontend.py +4 -4
  57. xinference/thirdparty/cosyvoice/cli/model.py +53 -75
  58. xinference/thirdparty/cosyvoice/dataset/dataset.py +5 -18
  59. xinference/thirdparty/cosyvoice/dataset/processor.py +24 -25
  60. xinference/thirdparty/cosyvoice/flow/decoder.py +24 -433
  61. xinference/thirdparty/cosyvoice/flow/flow.py +6 -14
  62. xinference/thirdparty/cosyvoice/flow/flow_matching.py +33 -145
  63. xinference/thirdparty/cosyvoice/hifigan/generator.py +169 -1
  64. xinference/thirdparty/cosyvoice/llm/llm.py +108 -17
  65. xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +14 -115
  66. xinference/thirdparty/cosyvoice/utils/common.py +20 -0
  67. xinference/thirdparty/cosyvoice/utils/executor.py +8 -4
  68. xinference/thirdparty/cosyvoice/utils/file_utils.py +45 -1
  69. xinference/thirdparty/cosyvoice/utils/losses.py +37 -0
  70. xinference/thirdparty/cosyvoice/utils/mask.py +35 -1
  71. xinference/thirdparty/cosyvoice/utils/train_utils.py +24 -6
  72. xinference/thirdparty/cosyvoice/vllm/cosyvoice2.py +103 -0
  73. xinference/types.py +2 -0
  74. xinference/ui/gradio/chat_interface.py +2 -0
  75. xinference/ui/gradio/media_interface.py +353 -7
  76. xinference/ui/web/ui/build/asset-manifest.json +3 -3
  77. xinference/ui/web/ui/build/index.html +1 -1
  78. xinference/ui/web/ui/build/static/js/main.1086c759.js +3 -0
  79. xinference/ui/web/ui/build/static/js/main.1086c759.js.map +1 -0
  80. xinference/ui/web/ui/node_modules/.cache/babel-loader/28012da921a51f1082549956d3ae82acd769a754b22afda9acddd98a4daf9ea4.json +1 -0
  81. xinference/ui/web/ui/node_modules/.cache/babel-loader/3c5758bd12fa334294b1de0ff6b1a4bac8d963c45472eab9dc3e530d82aa6b3f.json +1 -0
  82. xinference/ui/web/ui/node_modules/.cache/babel-loader/475936ebe725eca62a6f52ce182c06a19b2cef4df9545a05ed0591ee0c539d43.json +1 -0
  83. xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +1 -0
  84. xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +1 -0
  85. xinference/ui/web/ui/node_modules/.cache/babel-loader/aee5aaba26f2b1e816a3ea9efa68bad8b95695a3d80adcfd8dd57a7bb17ac71a.json +1 -0
  86. xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +1 -0
  87. xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +1 -0
  88. xinference/ui/web/ui/src/locales/en.json +2 -0
  89. xinference/ui/web/ui/src/locales/ja.json +2 -0
  90. xinference/ui/web/ui/src/locales/ko.json +2 -0
  91. xinference/ui/web/ui/src/locales/zh.json +2 -0
  92. {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/METADATA +15 -10
  93. {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/RECORD +98 -89
  94. xinference/ui/web/ui/build/static/js/main.b969199a.js +0 -3
  95. xinference/ui/web/ui/build/static/js/main.b969199a.js.map +0 -1
  96. xinference/ui/web/ui/node_modules/.cache/babel-loader/1409a96b9f9f9f5de99a89ab0f738f6da62b449521b0a8d3e4efcf7f5c23534d.json +0 -1
  97. xinference/ui/web/ui/node_modules/.cache/babel-loader/3d2a89f0eccc1f90fc5036c9a1d587c2120e6a6b128aae31d1db7d6bad52722b.json +0 -1
  98. xinference/ui/web/ui/node_modules/.cache/babel-loader/43b889c3a8e2634092ade463d52481c7c5581c72ded8f23bc5f012ea0ef8cea5.json +0 -1
  99. xinference/ui/web/ui/node_modules/.cache/babel-loader/5d47532fb42128280d87f57c8a0b02bc1930f7ef764aa7e90579247df18bba83.json +0 -1
  100. xinference/ui/web/ui/node_modules/.cache/babel-loader/830882bb275468a969614824a9ab8983f874b4581f2eb625e9c66426cdc65e5b.json +0 -1
  101. xinference/ui/web/ui/node_modules/.cache/babel-loader/8e5cb82c2ff3299c6a44563fe6b1c5515c9750613c51bb63abee0b1d70fc5019.json +0 -1
  102. xinference/ui/web/ui/node_modules/.cache/babel-loader/9df08abcb5a7c1e48a4eb25c5d5f5d7253ea6854a4397e6d74d1fd75a14acda1.json +0 -1
  103. xinference/ui/web/ui/node_modules/.cache/babel-loader/b99034986a06445701accc7a4914bb9320947435e8d4e15793392ca4f679316c.json +0 -1
  104. /xinference/ui/web/ui/build/static/js/{main.b969199a.js.LICENSE.txt → main.1086c759.js.LICENSE.txt} +0 -0
  105. {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/WHEEL +0 -0
  106. {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/entry_points.txt +0 -0
  107. {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/licenses/LICENSE +0 -0
  108. {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/top_level.txt +0 -0
xinference/model/rerank/sentence_transformers/core.py ADDED
@@ -0,0 +1,337 @@
+ # Copyright 2022-2025 XProbe Inc.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import gc
+ import importlib.util
+ import logging
+ import threading
+ import uuid
+ from typing import List, Optional, Sequence
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+
+ from ....device_utils import empty_cache
+ from ....types import Document, DocumentObj, Meta, Rerank, RerankTokens
+ from ...utils import is_flash_attn_available
+ from ..core import (
+     RERANK_EMPTY_CACHE_COUNT,
+     RerankModel,
+     RerankModelFamilyV2,
+     RerankSpecV1,
+ )
+ from ..utils import preprocess_sentence
+
+ logger = logging.getLogger(__name__)
+
+
+ class _ModelWrapper(nn.Module):
+     def __init__(self, module: nn.Module):
+         super().__init__()
+         self.model = module
+         self._local_data = threading.local()
+
+     @property
+     def n_tokens(self):
+         return getattr(self._local_data, "n_tokens", 0)
+
+     @n_tokens.setter
+     def n_tokens(self, value):
+         self._local_data.n_tokens = value
+
+     def forward(self, **kwargs):
+         attention_mask = kwargs.get("attention_mask")
+         # when batching, the attention mask 1 means there is a token
+         # thus we just sum up it to get the total number of tokens
+         if attention_mask is not None:
+             self.n_tokens += attention_mask.sum().item()
+         return self.model(**kwargs)
+
+     def __getattr__(self, attr):
+         try:
+             return super().__getattr__(attr)
+         except AttributeError:
+             return getattr(self.model, attr)
+
+
+ class SentenceTransformerRerankModel(RerankModel):
+     def load(self):
+         # TODO: Split FlagReranker and sentence_transformers into different model_engines like FlagRerankModel
+         logger.info("Loading rerank model: %s", self._model_path)
+         enable_flash_attn = self._kwargs.pop(
+             "enable_flash_attn", is_flash_attn_available()
+         )
+         if enable_flash_attn:
+             logger.warning(
+                 "flash_attn can only support fp16 and bf16, will force set `use_fp16` to True"
+             )
+             self._use_fp16 = True
+
+         if (
+             self.model_family.type == "normal"
+             and "qwen3" not in self.model_family.model_name.lower()
+         ):
+             try:
+                 import sentence_transformers
+                 from sentence_transformers.cross_encoder import CrossEncoder
+
+                 if sentence_transformers.__version__ < "3.1.0":
+                     raise ValueError(
+                         "The sentence_transformers version must be greater than 3.1.0. "
+                         "Please upgrade your version via `pip install -U sentence_transformers` or refer to "
+                         "https://github.com/UKPLab/sentence-transformers"
+                     )
+             except ImportError:
+                 error_message = "Failed to import module 'sentence-transformers'"
+                 installation_guide = [
+                     "Please make sure 'sentence-transformers' is installed. ",
+                     "You can install it by `pip install sentence-transformers`\n",
+                 ]
+
+                 raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
+             self._model = CrossEncoder(
+                 self._model_path,
+                 device=self._device,
+                 trust_remote_code=True,
+                 max_length=getattr(self.model_family, "max_tokens"),
+                 **self._kwargs,
+             )
+             if self._use_fp16:
+                 self._model.model.half()
+         elif "qwen3" in self.model_family.model_name.lower():
+             # qwen3-reranker
+             # now we use transformers
+             # TODO: support engines for rerank models
+             try:
+                 from transformers import AutoModelForCausalLM, AutoTokenizer
+             except ImportError:
+                 error_message = "Failed to import module 'transformers'"
+                 installation_guide = [
+                     "Please make sure 'transformers' is installed. ",
+                     "You can install it by `pip install transformers`\n",
+                 ]
+
+                 raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
+
+             tokenizer = AutoTokenizer.from_pretrained(
+                 self._model_path, padding_side="left"
+             )
+             model_kwargs = {"device_map": "auto"}
+             if enable_flash_attn:
+                 model_kwargs["attn_implementation"] = "flash_attention_2"
+                 model_kwargs["torch_dtype"] = torch.float16
+             model_kwargs.update(self._kwargs)
+             logger.debug("Loading qwen3 rerank with kwargs %s", model_kwargs)
+             model = self._model = AutoModelForCausalLM.from_pretrained(
+                 self._model_path, **model_kwargs
+             ).eval()
+             max_length = getattr(self.model_family, "max_tokens")
+
+             prefix = (
+                 "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query "
+                 'and the Instruct provided. Note that the answer can only be "yes" or "no".'
+                 "<|im_end|>\n<|im_start|>user\n"
+             )
+             suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
+             prefix_tokens = tokenizer.encode(prefix, add_special_tokens=False)
+             suffix_tokens = tokenizer.encode(suffix, add_special_tokens=False)
+
+             def process_inputs(pairs):
+                 inputs = tokenizer(
+                     pairs,
+                     padding=False,
+                     truncation="longest_first",
+                     return_attention_mask=False,
+                     max_length=max_length - len(prefix_tokens) - len(suffix_tokens),
+                 )
+                 for i, ele in enumerate(inputs["input_ids"]):
+                     inputs["input_ids"][i] = prefix_tokens + ele + suffix_tokens
+                 inputs = tokenizer.pad(
+                     inputs, padding=True, return_tensors="pt", max_length=max_length
+                 )
+                 for key in inputs:
+                     inputs[key] = inputs[key].to(model.device)
+                 return inputs
+
+             token_false_id = tokenizer.convert_tokens_to_ids("no")
+             token_true_id = tokenizer.convert_tokens_to_ids("yes")
+
+             @torch.inference_mode()
+             def compute_logits(inputs, **kwargs):
+                 batch_scores = model(**inputs).logits[:, -1, :]
+                 true_vector = batch_scores[:, token_true_id]
+                 false_vector = batch_scores[:, token_false_id]
+                 batch_scores = torch.stack([false_vector, true_vector], dim=1)
+                 batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
+                 scores = batch_scores[:, 1].exp().tolist()
+                 return scores
+
+             self.process_inputs = process_inputs
+             self.compute_logits = compute_logits
+         else:
+             try:
+                 if self.model_family.type == "LLM-based":
+                     from FlagEmbedding import FlagLLMReranker as FlagReranker
+                 elif self.model_family.type == "LLM-based layerwise":
+                     from FlagEmbedding import LayerWiseFlagLLMReranker as FlagReranker
+                 else:
+                     raise RuntimeError(
+                         f"Unsupported Rank model type: {self.model_family.type}"
+                     )
+             except ImportError:
+                 error_message = "Failed to import module 'FlagEmbedding'"
+                 installation_guide = [
+                     "Please make sure 'FlagEmbedding' is installed. ",
+                     "You can install it by `pip install FlagEmbedding`\n",
+                 ]
+
+                 raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
+             self._model = FlagReranker(self._model_path, use_fp16=self._use_fp16)
+         # Wrap transformers model to record number of tokens
+         self._model.model = _ModelWrapper(self._model.model)
+
+     def rerank(
+         self,
+         documents: List[str],
+         query: str,
+         top_n: Optional[int],
+         max_chunks_per_doc: Optional[int],
+         return_documents: Optional[bool],
+         return_len: Optional[bool],
+         **kwargs,
+     ) -> Rerank:
+         assert self._model is not None
+         if max_chunks_per_doc is not None:
+             raise ValueError("rerank hasn't support `max_chunks_per_doc` parameter.")
+         logger.info("Rerank with kwargs: %s, model: %s", kwargs, self._model)
+
+         pre_query = preprocess_sentence(
+             query, kwargs.get("instruction", None), self.model_family.model_name
+         )
+         sentence_combinations = [[pre_query, doc] for doc in documents]
+         # reset n tokens
+         self._model.model.n_tokens = 0
+         if (
+             self.model_family.type == "normal"
+             and "qwen3" not in self.model_family.model_name.lower()
+         ):
+             logger.debug("Passing processed sentences: %s", sentence_combinations)
+             similarity_scores = self._model.predict(
+                 sentence_combinations,
+                 convert_to_numpy=False,
+                 convert_to_tensor=True,
+                 **kwargs,
+             ).cpu()
+             if similarity_scores.dtype == torch.bfloat16:
+                 similarity_scores = similarity_scores.float()
+         elif "qwen3" in self.model_family.model_name.lower():
+
+             def format_instruction(instruction, query, doc):
+                 if instruction is None:
+                     instruction = "Given a web search query, retrieve relevant passages that answer the query"
+                 output = "<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}".format(
+                     instruction=instruction, query=query, doc=doc
+                 )
+                 return output
+
+             # reduce memory usage.
+             micro_bs = 4
+             similarity_scores = []
+             for i in range(0, len(documents), micro_bs):
+                 sub_docs = documents[i : i + micro_bs]
+                 pairs = [
+                     format_instruction(kwargs.get("instruction", None), query, doc)
+                     for doc in sub_docs
+                 ]
+                 # Tokenize the input texts
+                 inputs = self.process_inputs(pairs)
+                 similarity_scores.extend(self.compute_logits(inputs))
+         else:
+             # Related issue: https://github.com/xorbitsai/inference/issues/1775
+             similarity_scores = self._model.compute_score(
+                 sentence_combinations, **kwargs
+             )
+
+         if not isinstance(similarity_scores, Sequence):
+             similarity_scores = [similarity_scores]
+         elif (
+             isinstance(similarity_scores, list)
+             and len(similarity_scores) > 0
+             and isinstance(similarity_scores[0], Sequence)
+         ):
+             similarity_scores = similarity_scores[0]
+
+         sim_scores_argsort = list(reversed(np.argsort(similarity_scores)))
+         if top_n is not None:
+             sim_scores_argsort = sim_scores_argsort[:top_n]
+         if return_documents:
+             docs = [
+                 DocumentObj(
+                     index=int(arg),
+                     relevance_score=float(similarity_scores[arg]),
+                     document=Document(text=documents[arg]),
+                 )
+                 for arg in sim_scores_argsort
+             ]
+         else:
+             docs = [
+                 DocumentObj(
+                     index=int(arg),
+                     relevance_score=float(similarity_scores[arg]),
+                     document=None,
+                 )
+                 for arg in sim_scores_argsort
+             ]
+         if return_len:
+             input_len = self._model.model.n_tokens
+             # Rerank Model output is just score or documents
+             # while return_documents = True
+             output_len = input_len
+
+         # api_version, billed_units, warnings
+         # is for Cohere API compatibility, set to None
+         metadata = Meta(
+             api_version=None,
+             billed_units=None,
+             tokens=(
+                 RerankTokens(input_tokens=input_len, output_tokens=output_len)
+                 if return_len
+                 else None
+             ),
+             warnings=None,
+         )
+
+         del similarity_scores
+         # clear cache if possible
+         self._counter += 1
+         if self._counter % RERANK_EMPTY_CACHE_COUNT == 0:
+             logger.debug("Empty rerank cache.")
+             gc.collect()
+             empty_cache()
+
+         return Rerank(id=str(uuid.uuid1()), results=docs, meta=metadata)
+
+     @classmethod
+     def check_lib(cls) -> bool:
+         return importlib.util.find_spec("sentence_transformers") is not None
+
+     @classmethod
+     def match_json(
+         cls,
+         model_family: RerankModelFamilyV2,
+         model_spec: RerankSpecV1,
+         quantization: str,
+     ) -> bool:
+         # As default embedding engine, sentence-transformer support all models
+         return model_spec.model_format in ["pytorch"]
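
Note on the qwen3 branch above: it never uses a classification head. The causal LM is asked to answer "yes" or "no", and the two logits at the final position are turned into a relevance score via the softmax probability of "yes". A minimal, self-contained sketch of just that scoring step (torch only; the logits tensor and token ids below are invented for illustration, whereas in the real code they come from AutoModelForCausalLM(**inputs).logits[:, -1, :] and the tokenizer):

    import torch

    # Invented final-position logits for two query/document pairs over a toy
    # 5-token vocabulary; compute_logits above obtains them from the model.
    last_token_logits = torch.tensor(
        [[0.1, 2.5, -1.0, 0.3, 0.0],
         [1.2, -0.4, 0.8, 2.0, -0.5]]
    )
    token_false_id, token_true_id = 2, 1  # toy ids standing in for "no" and "yes"

    # Stack the [no, yes] logits, normalize, and keep P("yes") as the relevance
    # score, mirroring compute_logits in the diff above.
    pair_logits = torch.stack(
        [last_token_logits[:, token_false_id], last_token_logits[:, token_true_id]],
        dim=1,
    )
    scores = torch.nn.functional.log_softmax(pair_logits, dim=1)[:, 1].exp().tolist()
    print(scores)  # roughly [0.97, 0.23]: the first pair is judged far more relevant
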
xinference/model/rerank/vllm/__init__.py ADDED
@@ -0,0 +1,13 @@
+ # Copyright 2022-2025 XProbe Inc.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
xinference/model/rerank/vllm/core.py ADDED
@@ -0,0 +1,156 @@
+ import importlib.util
+ import uuid
+ from typing import List, Optional
+
+ from ....types import Document, DocumentObj, Meta, Rerank, RerankTokens
+ from ...utils import cache_clean
+ from ..core import RerankModel, RerankModelFamilyV2, RerankSpecV1
+
+ SUPPORTED_MODELS_PREFIXES = ["bge", "gte", "text2vec", "m3e", "gte", "Qwen3"]
+
+
+ class VLLMRerankModel(RerankModel):
+     def load(self):
+         try:
+             from vllm import LLM
+
+         except ImportError:
+             error_message = "Failed to import module 'vllm'"
+             installation_guide = [
+                 "Please make sure 'vllm' is installed. ",
+                 "You can install it by `pip install vllm`\n",
+             ]
+
+             raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
+
+         if self.model_family.model_name in {
+             "Qwen3-Reranker-0.6B",
+             "Qwen3-Reranker-4B",
+             "Qwen3-Reranker-8B",
+         }:
+             if "hf_overrides" not in self._kwargs:
+                 self._kwargs["hf_overrides"] = {
+                     "architectures": ["Qwen3ForSequenceClassification"],
+                     "classifier_from_token": ["no", "yes"],
+                     "is_original_qwen3_reranker": True,
+                 }
+             elif isinstance(self._kwargs["hf_overrides"], dict):
+                 self._kwargs["hf_overrides"].update(
+                     architectures=["Qwen3ForSequenceClassification"],
+                     classifier_from_token=["no", "yes"],
+                     is_original_qwen3_reranker=True,
+                 )
+         self._model = LLM(model=self._model_path, task="score", **self._kwargs)
+         self._tokenizer = self._model.get_tokenizer()
+
+     @cache_clean
+     def rerank(
+         self,
+         documents: List[str],
+         query: str,
+         top_n: Optional[int],
+         max_chunks_per_doc: Optional[int],
+         return_documents: Optional[bool],
+         return_len: Optional[bool],
+         **kwargs,
+     ) -> Rerank:
+         """
+         Rerank the documents based on the query using the VLLM model.
+
+         Args:
+             documents (List[str]): List of documents to be reranked.
+             query (str): The query string to rank the documents against.
+             top_n (Optional[int]): The number of top documents to return.
+             max_chunks_per_doc (Optional[int]): Maximum chunks per document.
+             return_documents (Optional[bool]): Whether to return the documents.
+             return_len (Optional[bool]): Whether to return the length of the documents.
+
+         Returns:
+             Rerank: The reranked results.
+         """
+         if kwargs:
+             raise RuntimeError("Unexpected keyword arguments: {}".format(kwargs))
+         assert self._model is not None
+         documents_size = len(documents)
+         query_list = [query] * documents_size
+
+         if self.model_family.model_name in {
+             "Qwen3-Reranker-0.6B",
+             "Qwen3-Reranker-4B",
+             "Qwen3-Reranker-8B",
+         }:
+             instruction = "Given a web search query, retrieve relevant passages that answer the query"
+             prefix = (
+                 "<|im_start|>system\nJudge whether the Document meets the requirements based on"
+                 " the Query and the Instruct provided. "
+                 'Note that the answer can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
+             )
+             suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
+             query_template = "{prefix}<Instruct>: {instruction}\n<Query>: {query}\n"
+             document_template = "<Document>: {doc}{suffix}"
+             processed_queries = [
+                 query_template.format(
+                     prefix=prefix, instruction=instruction, query=query
+                 )
+                 for query in query_list
+             ]
+             processed_documents = [
+                 document_template.format(doc=doc, suffix=suffix) for doc in documents
+             ]
+             outputs = self._model.score(
+                 processed_documents,
+                 processed_queries,
+                 use_tqdm=False,
+             )
+
+         else:
+             outputs = self._model.score(
+                 documents,
+                 query_list,
+                 use_tqdm=False,
+             )
+         scores = map(lambda scoreoutput: scoreoutput.outputs.score, outputs)
+         documents = list(map(lambda doc: Document(text=doc), documents))
+         document_parts = list(zip(range(documents_size), scores, documents))
+         document_parts.sort(key=lambda x: x[1], reverse=True)
+         if top_n is not None:
+             document_parts = document_parts[:top_n]
+         reranked_docs = list(
+             map(
+                 lambda doc: DocumentObj(
+                     index=doc[0],
+                     relevance_score=doc[1],
+                     document=doc[2] if return_documents else None,
+                 ),
+                 document_parts,
+             )
+         )
+         tokens = sum(map(lambda x: len(x.prompt_token_ids), outputs))
+         metadata = Meta(
+             api_version=None,
+             billed_units=None,
+             tokens=(
+                 RerankTokens(input_tokens=tokens, output_tokens=tokens)
+                 if return_len
+                 else None
+             ),
+             warnings=None,
+         )
+         return Rerank(id=str(uuid.uuid4()), results=reranked_docs, meta=metadata)
+
+     @classmethod
+     def check_lib(cls) -> bool:
+         return importlib.util.find_spec("vllm") is not None
+
+     @classmethod
+     def match_json(
+         cls,
+         model_family: RerankModelFamilyV2,
+         model_spec: RerankSpecV1,
+         quantization: str,
+     ) -> bool:
+         if model_spec.model_format in ["pytorch"]:
+             prefix = model_family.model_name.split("-", 1)[0]
+             if prefix in SUPPORTED_MODELS_PREFIXES:
+                 return True
+         return False
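
Note on the Qwen3 path above: VLLMRerankModel wraps every query/document pair in the chat template shown in the diff and lets vLLM's score task do the actual comparison. The prompt assembly itself is plain string formatting; a standalone sketch of only that part (standard library only, with an invented query and documents, while the real scoring would run through vllm.LLM(model=..., task="score").score(...)):

    # Mirrors the prompt assembly in VLLMRerankModel.rerank for the Qwen3 rerankers.
    instruction = "Given a web search query, retrieve relevant passages that answer the query"
    prefix = (
        "<|im_start|>system\nJudge whether the Document meets the requirements based on"
        " the Query and the Instruct provided. "
        'Note that the answer can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
    )
    suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
    query_template = "{prefix}<Instruct>: {instruction}\n<Query>: {query}\n"
    document_template = "<Document>: {doc}{suffix}"

    query = "What is the capital of France?"  # invented example input
    documents = ["Paris is the capital of France.", "Berlin is the capital of Germany."]

    processed_queries = [
        query_template.format(prefix=prefix, instruction=instruction, query=query)
        for _ in documents
    ]
    processed_documents = [
        document_template.format(doc=doc, suffix=suffix) for doc in documents
    ]

    # One (query, document) prompt pair per document; vLLM scores each pair.
    for q, d in zip(processed_queries, processed_documents):
        print(q + d)
        print("-" * 40)
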
xinference/model/utils.py CHANGED
@@ -13,6 +13,7 @@
  # limitations under the License.
 
  import asyncio
+ import functools
  import json
  import logging
  import os
@@ -454,6 +455,19 @@ def get_engine_params_by_name(
              for param in params:
                  del param["embedding_class"]
 
+         return engine_params
+     elif model_type == "rerank":
+         from .rerank.rerank_family import RERANK_ENGINES
+
+         if model_name not in RERANK_ENGINES:
+             return None
+
+         # filter rerank_class
+         engine_params = deepcopy(RERANK_ENGINES[model_name])
+         for engine, params in engine_params.items():
+             for param in params:
+                 del param["rerank_class"]
+
          return engine_params
      else:
          raise ValueError(
@@ -558,3 +572,97 @@ class ModelInstanceInfoMixin(ABC):
      @abstractmethod
      def to_version_info(self):
          """"""
+
+
+ def is_flash_attn_available() -> bool:
+     """
+     Check if flash_attention can be enabled in the current environment.
+
+     Checks the following conditions:
+     1. Whether the flash_attn package is installed
+     2. Whether CUDA GPU is available
+     3. Whether PyTorch supports CUDA
+     4. Whether GPU compute capability meets requirements (>= 8.0)
+
+     Returns:
+         bool: True if flash_attention can be enabled, False otherwise
+     """
+     import importlib.util
+
+     # Check if flash_attn is installed
+     if importlib.util.find_spec("flash_attn") is None:
+         logger.debug("flash_attn package not found")
+         return False
+
+     try:
+         import torch
+
+         # Check if CUDA is available
+         if not torch.cuda.is_available():
+             logger.debug("CUDA not available")
+             return False
+
+         # Check GPU count
+         if torch.cuda.device_count() == 0:
+             logger.debug("No CUDA devices found")
+             return False
+
+         # Check current GPU compute capability
+         # Flash Attention typically requires compute capability >= 8.0 (A100, H100, etc.)
+         current_device = torch.cuda.current_device()
+         capability = torch.cuda.get_device_capability(current_device)
+         major, minor = capability
+         compute_capability = major + minor * 0.1
+
+         if compute_capability < 8.0:
+             logger.debug(
+                 f"GPU compute capability {compute_capability} < 8.0, "
+                 "flash_attn may not work optimally"
+             )
+             return False
+
+         # Try to import flash_attn core module to verify correct installation
+         try:
+             import flash_attn
+
+             logger.debug(
+                 f"flash_attn version: {getattr(flash_attn, '__version__', 'unknown')}"
+             )
+             return True
+         except ImportError as e:
+             logger.debug(f"Failed to import flash_attn: {e}")
+             return False
+     except Exception as e:
+         logger.debug(f"Error checking flash_attn availability: {e}")
+         return False
+
+
+ def cache_clean(fn):
+     @functools.wraps(fn)
+     async def _async_wrapper(self, *args, **kwargs):
+         import gc
+
+         from ..device_utils import empty_cache
+
+         result = await fn(self, *args, **kwargs)
+
+         gc.collect()
+         empty_cache()
+         return result
+
+     @functools.wraps(fn)
+     def _wrapper(self, *args, **kwargs):
+         import gc
+
+         from ..device_utils import empty_cache
+
+         result = fn(self, *args, **kwargs)
+
+         gc.collect()
+         empty_cache()
+         return result
+
+     if asyncio.iscoroutinefunction(fn):
+         return _async_wrapper
+     else:
+         return _wrapper
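
Note on cache_clean above: it is a decorator that picks an async or sync wrapper via asyncio.iscoroutinefunction, so one decorator covers both kinds of model methods. A simplified, self-contained sketch of that dispatch pattern (the ToyRerankModel class is invented for demonstration, and xinference's device_utils.empty_cache() is replaced by a plain gc.collect()):

    import asyncio
    import functools
    import gc


    def cache_clean(fn):
        """Run the wrapped call, then trigger garbage collection; a simplified
        stand-in for xinference's decorator, which also empties the device cache."""

        @functools.wraps(fn)
        async def _async_wrapper(self, *args, **kwargs):
            result = await fn(self, *args, **kwargs)
            gc.collect()
            return result

        @functools.wraps(fn)
        def _wrapper(self, *args, **kwargs):
            result = fn(self, *args, **kwargs)
            gc.collect()
            return result

        # Pick the wrapper matching the wrapped function's calling convention.
        return _async_wrapper if asyncio.iscoroutinefunction(fn) else _wrapper


    class ToyRerankModel:  # hypothetical class, for demonstration only
        @cache_clean
        def rerank(self, docs, query):
            return sorted(docs, key=len)

        @cache_clean
        async def arerank(self, docs, query):
            return sorted(docs, key=len)


    model = ToyRerankModel()
    print(model.rerank(["bb", "a"], "q"))                # ['a', 'bb']
    print(asyncio.run(model.arerank(["ccc", "a"], "q")))  # ['a', 'ccc']
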