xinference 1.8.1rc1__py3-none-any.whl → 1.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +2 -1
- xinference/core/model.py +5 -0
- xinference/core/supervisor.py +2 -3
- xinference/core/worker.py +3 -4
- xinference/deploy/local.py +5 -0
- xinference/deploy/worker.py +6 -0
- xinference/model/core.py +3 -0
- xinference/model/embedding/sentence_transformers/core.py +3 -4
- xinference/model/embedding/vllm/core.py +4 -3
- xinference/model/image/model_spec.json +69 -0
- xinference/model/image/stable_diffusion/core.py +22 -0
- xinference/model/llm/cache_manager.py +17 -3
- xinference/model/llm/harmony.py +245 -0
- xinference/model/llm/llm_family.json +293 -8
- xinference/model/llm/llm_family.py +1 -1
- xinference/model/llm/sglang/core.py +108 -5
- xinference/model/llm/transformers/core.py +15 -7
- xinference/model/llm/transformers/gemma3.py +1 -1
- xinference/model/llm/transformers/gpt_oss.py +91 -0
- xinference/model/llm/transformers/multimodal/core.py +1 -1
- xinference/model/llm/transformers/multimodal/gemma3.py +1 -1
- xinference/model/llm/transformers/multimodal/glm4_1v.py +2 -2
- xinference/model/llm/transformers/multimodal/ovis2.py +1 -1
- xinference/model/llm/transformers/multimodal/qwen-omni.py +7 -8
- xinference/model/llm/transformers/multimodal/qwen2_vl.py +9 -6
- xinference/model/llm/transformers/utils.py +1 -33
- xinference/model/llm/utils.py +61 -7
- xinference/model/llm/vllm/core.py +38 -8
- xinference/model/rerank/__init__.py +66 -23
- xinference/model/rerank/cache_manager.py +35 -0
- xinference/model/rerank/core.py +84 -339
- xinference/model/rerank/custom.py +33 -8
- xinference/model/rerank/model_spec.json +251 -212
- xinference/model/rerank/rerank_family.py +137 -0
- xinference/model/rerank/sentence_transformers/__init__.py +13 -0
- xinference/model/rerank/sentence_transformers/core.py +337 -0
- xinference/model/rerank/vllm/__init__.py +13 -0
- xinference/model/rerank/vllm/core.py +106 -0
- xinference/model/utils.py +109 -0
- xinference/types.py +2 -0
- xinference/ui/web/ui/build/asset-manifest.json +3 -3
- xinference/ui/web/ui/build/index.html +1 -1
- xinference/ui/web/ui/build/static/js/{main.b969199a.js → main.4918643a.js} +3 -3
- xinference/ui/web/ui/build/static/js/{main.b969199a.js.map → main.4918643a.js.map} +1 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/28012da921a51f1082549956d3ae82acd769a754b22afda9acddd98a4daf9ea4.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/475936ebe725eca62a6f52ce182c06a19b2cef4df9545a05ed0591ee0c539d43.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/89179f8f51887b9167721860a12412549ff04f78162e921a7b6aa6532646deb2.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/9dc5cfc67dd0617b0272aeef8651f1589b2155a4ff1fd72ad3166b217089b619.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/aee5aaba26f2b1e816a3ea9efa68bad8b95695a3d80adcfd8dd57a7bb17ac71a.json +1 -0
- {xinference-1.8.1rc1.dist-info → xinference-1.9.0.dist-info}/METADATA +6 -1
- {xinference-1.8.1rc1.dist-info → xinference-1.9.0.dist-info}/RECORD +58 -50
- xinference/ui/web/ui/node_modules/.cache/babel-loader/1409a96b9f9f9f5de99a89ab0f738f6da62b449521b0a8d3e4efcf7f5c23534d.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/43b889c3a8e2634092ade463d52481c7c5581c72ded8f23bc5f012ea0ef8cea5.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/5d47532fb42128280d87f57c8a0b02bc1930f7ef764aa7e90579247df18bba83.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/830882bb275468a969614824a9ab8983f874b4581f2eb625e9c66426cdc65e5b.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/9df08abcb5a7c1e48a4eb25c5d5f5d7253ea6854a4397e6d74d1fd75a14acda1.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/b99034986a06445701accc7a4914bb9320947435e8d4e15793392ca4f679316c.json +0 -1
- /xinference/ui/web/ui/build/static/js/{main.b969199a.js.LICENSE.txt → main.4918643a.js.LICENSE.txt} +0 -0
- {xinference-1.8.1rc1.dist-info → xinference-1.9.0.dist-info}/WHEEL +0 -0
- {xinference-1.8.1rc1.dist-info → xinference-1.9.0.dist-info}/entry_points.txt +0 -0
- {xinference-1.8.1rc1.dist-info → xinference-1.9.0.dist-info}/licenses/LICENSE +0 -0
- {xinference-1.8.1rc1.dist-info → xinference-1.9.0.dist-info}/top_level.txt +0 -0
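The headline change in this release is a refactor of the rerank module onto pluggable engines: spec parsing moves into rerank_family.py, loading splits into sentence_transformers/ and vllm/ backends, and caching gets its own RerankCacheManager. For orientation before the diffs, here is a hedged sketch of launching a rerank model through the Python client; launch_model, get_model, and rerank are documented xinference client calls, while passing model_engine for rerank models is an assumption inferred from the new create_rerank_model_instance signature further down.

# Hedged sketch; "bge-reranker-base" is an illustrative built-in model name.
from xinference.client import Client

client = Client("http://127.0.0.1:9997")
model_uid = client.launch_model(
    model_name="bge-reranker-base",
    model_type="rerank",
    model_engine="sentence_transformers",  # assumed: rerank engine selection is new in 1.9.0
)
model = client.get_model(model_uid)
print(model.rerank(["Paris is in France.", "Berlin is in Germany."], "Where is Paris?", top_n=1))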
xinference/model/rerank/__init__.py
CHANGED
@@ -16,10 +16,10 @@ import codecs
 import json
 import os
 import warnings
-from typing import Dict, List
+from typing import Any, Dict, List

 from ...constants import XINFERENCE_MODEL_DIR
-from ..utils import
+from ..utils import flatten_quantizations
 from .core import (
     RERANK_MODEL_DESCRIPTIONS,
     RerankModelFamilyV2,
@@ -32,8 +32,13 @@ from .custom import (
     register_rerank,
     unregister_rerank,
 )
-
-BUILTIN_RERANK_MODELS
+from .rerank_family import (
+    BUILTIN_RERANK_MODELS,
+    RERANK_ENGINES,
+    SENTENCE_TRANSFORMER_CLASSES,
+    SUPPORTED_ENGINES,
+    VLLM_CLASSES,
+)


 def register_custom_model():
@@ -58,31 +63,69 @@ def register_custom_model():
             warnings.warn(f"{user_defined_rerank_dir}/{f} has error, {e}")


-def
-
+def generate_engine_config_by_model_name(model_family: "RerankModelFamilyV2"):
+    model_name = model_family.model_name
+    engines: Dict[str, List[Dict[str, Any]]] = RERANK_ENGINES.get(
+        model_name, {}
+    )  # structure for engine query
+    for spec in [x for x in model_family.model_specs if x.model_hub == "huggingface"]:
+        model_format = spec.model_format
+        quantization = spec.quantization
+        for engine in SUPPORTED_ENGINES:
+            CLASSES = SUPPORTED_ENGINES[engine]
+            for cls in CLASSES:
+                # Every engine needs to implement match method
+                if cls.match(model_family, spec, quantization):
+                    # we only match the first class for an engine
+                    if engine not in engines:
+                        engines[engine] = [
+                            {
+                                "model_name": model_name,
+                                "model_format": model_format,
+                                "quantization": quantization,
+                                "rerank_class": cls,
+                            }
+                        ]
+                    else:
+                        engines[engine].append(
+                            {
+                                "model_name": model_name,
+                                "model_format": model_format,
+                                "quantization": quantization,
+                                "rerank_class": cls,
+                            }
+                        )
+                    break
+    RERANK_ENGINES[model_name] = engines
+

-
-
+def _install():
+    _model_spec_json = os.path.join(os.path.dirname(__file__), "model_spec.json")
+    for json_obj in json.load(codecs.open(_model_spec_json, "r", encoding="utf-8")):
+        flattened = []
+        for spec in json_obj["model_specs"]:
+            flattened.extend(flatten_quantizations(spec))
+        json_obj["model_specs"] = flattened
+        BUILTIN_RERANK_MODELS[json_obj["model_name"]] = RerankModelFamilyV2(**json_obj)
+
+    for model_name, model_spec in BUILTIN_RERANK_MODELS.items():
         if model_spec.model_name not in RERANK_MODEL_DESCRIPTIONS:
             RERANK_MODEL_DESCRIPTIONS.update(generate_rerank_description(model_spec))

-
+    from .sentence_transformers.core import SentenceTransformerRerankModel
+    from .vllm.core import VLLMRerankModel

-
-
-        RERANK_MODEL_DESCRIPTIONS.update(generate_rerank_description(ud_rerank))
+    SENTENCE_TRANSFORMER_CLASSES.extend([SentenceTransformerRerankModel])
+    VLLM_CLASSES.extend([VLLMRerankModel])

+    SUPPORTED_ENGINES["sentence_transformers"] = SENTENCE_TRANSFORMER_CLASSES
+    SUPPORTED_ENGINES["vllm"] = VLLM_CLASSES

-
-
-    flattened_model_specs = []
-    for spec in json.load(codecs.open(_model_spec_json, "r", encoding="utf-8")):
-        flattened_model_specs.extend(flatten_model_src(spec))
+    for model_spec in BUILTIN_RERANK_MODELS.values():
+        generate_engine_config_by_model_name(model_spec)

-
-        if spec["model_name"] not in target_families:
-            target_families[spec["model_name"]] = [RerankModelFamilyV2(**spec)]
-        else:
-            target_families[spec["model_name"]].append(RerankModelFamilyV2(**spec))
+    register_custom_model()

-
+    # register model description
+    for ud_rerank in get_user_defined_reranks():
+        RERANK_MODEL_DESCRIPTIONS.update(generate_rerank_description(ud_rerank))
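To make the registry that generate_engine_config_by_model_name builds above easier to picture, here is a minimal, self-contained sketch of the same first-match-per-engine logic with stand-in names; only the dict shape and the break-on-first-match rule come from the hunk, everything else is illustrative.

from typing import Any, Dict, List

SUPPORTED_ENGINES: Dict[str, List[type]] = {}
RERANK_ENGINES: Dict[str, Dict[str, List[Dict[str, Any]]]] = {}


class DummyRerankModel:
    @classmethod
    def match(cls, family: dict, spec: dict, quantization: str) -> bool:
        # real engine classes gate on an importable backend via check_lib()
        return spec["model_format"] == "pytorch"


SUPPORTED_ENGINES["sentence_transformers"] = [DummyRerankModel]

family = {
    "model_name": "demo-reranker",
    "model_specs": [
        {"model_format": "pytorch", "quantization": "none", "model_hub": "huggingface"}
    ],
}

engines: Dict[str, List[Dict[str, Any]]] = {}
for spec in family["model_specs"]:
    for engine, classes in SUPPORTED_ENGINES.items():
        for cls in classes:
            if cls.match(family, spec, spec["quantization"]):
                engines.setdefault(engine, []).append(
                    {
                        "model_name": family["model_name"],
                        "model_format": spec["model_format"],
                        "quantization": spec["quantization"],
                        "rerank_class": cls,
                    }
                )
                break  # only the first matching class per engine is kept
RERANK_ENGINES[family["model_name"]] = engines
print(RERANK_ENGINES)

A query such as "which engines can serve demo-reranker in pytorch format" then becomes a plain dict lookup, which is what check_engine_by_model_name_and_engine in core.py relies on.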
xinference/model/rerank/cache_manager.py
ADDED
@@ -0,0 +1,35 @@
+import os
+from typing import TYPE_CHECKING
+
+from ..cache_manager import CacheManager
+
+if TYPE_CHECKING:
+    from .core import RerankModelFamilyV2
+
+
+class RerankCacheManager(CacheManager):
+    def __init__(self, model_family: "RerankModelFamilyV2"):
+        from ..llm.cache_manager import LLMCacheManager
+
+        super().__init__(model_family)
+        # Composition design mode for avoiding duplicate code
+        self.cache_helper = LLMCacheManager(model_family)
+
+        spec = self._model_family.model_specs[0]
+        model_dir_name = (
+            f"{self._model_family.model_name}-{spec.model_format}-{spec.quantization}"
+        )
+        self._cache_dir = os.path.join(self._v2_cache_dir_prefix, model_dir_name)
+        self.cache_helper._cache_dir = self._cache_dir
+
+    def cache(self) -> str:
+        spec = self._model_family.model_specs[0]
+        if spec.model_uri is not None:
+            return self.cache_helper.cache_uri()
+        else:
+            if spec.model_hub == "huggingface":
+                return self.cache_helper.cache_from_huggingface()
+            elif spec.model_hub == "modelscope":
+                return self.cache_helper.cache_from_modelscope()
+            else:
+                raise ValueError(f"Unknown model hub: {spec.model_hub}")
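RerankCacheManager reuses LLMCacheManager by composition rather than inheritance: the helper does the actual downloading, while the rerank manager only redirects where files land. A small sketch of the resulting v2 cache-directory naming taken from the f-string above; the prefix value is an assumption for illustration, the real one comes from CacheManager's _v2_cache_dir_prefix.

import os

# assumed prefix; the real value is derived inside CacheManager
v2_cache_dir_prefix = os.path.expanduser("~/.xinference/cache/v2")
model_name, model_format, quantization = "demo-reranker", "pytorch", "none"
print(os.path.join(v2_cache_dir_prefix, f"{model_name}-{model_format}-{quantization}"))
# e.g. /home/user/.xinference/cache/v2/demo-reranker-pytorch-none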
xinference/model/rerank/core.py
CHANGED
@@ -10,28 +10,18 @@
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
-# limitations under the License.

-import gc
-import importlib
-import importlib.util
 import logging
 import os
-import
-import uuid
+from abc import abstractmethod
 from collections import defaultdict
-from collections.abc import Sequence
 from typing import Dict, List, Literal, Optional

-
-import
-
-
-from ...device_utils import empty_cache, is_device_available
-from ...types import Document, DocumentObj, Rerank, RerankTokens
-from ..core import CacheableModelSpec, VirtualEnvSettings
+from ..._compat import BaseModel
+from ...types import Rerank
+from ..core import VirtualEnvSettings
 from ..utils import ModelInstanceInfoMixin
-from .
+from .rerank_family import check_engine_by_model_name_and_engine, match_rerank

 logger = logging.getLogger(__name__)

@@ -48,21 +38,29 @@ def get_rerank_model_descriptions():
     return copy.deepcopy(RERANK_MODEL_DESCRIPTIONS)


-class
+class RerankSpecV1(BaseModel):
+    model_format: Literal["pytorch"]
+    model_hub: str = "huggingface"
+    model_id: Optional[str] = None
+    model_revision: Optional[str] = None
+    model_uri: Optional[str] = None
+    quantization: str = "none"
+
+
+class RerankModelFamilyV2(BaseModel, ModelInstanceInfoMixin):
     version: Literal[2]
     model_name: str
+    model_specs: List[RerankSpecV1]
     language: List[str]
     type: Optional[str] = "unknown"
     max_tokens: Optional[int]
-    model_id: str
-    model_revision: Optional[str]
-    model_hub: str = "huggingface"
     virtualenv: Optional[VirtualEnvSettings]

     class Config:
         extra = "allow"

     def to_description(self):
+        spec = self.model_specs[0]
         return {
             "model_type": "rerank",
             "address": getattr(self, "address", None),
@@ -70,13 +68,13 @@ class RerankModelFamilyV2(CacheableModelSpec, ModelInstanceInfoMixin):
             "type": self.type,
             "model_name": self.model_name,
             "language": self.language,
-            "model_revision":
+            "model_revision": spec.model_revision,
         }

     def to_version_info(self):
-        from
+        from .cache_manager import RerankCacheManager

-        cache_manager =
+        cache_manager = RerankCacheManager(self)
         return {
             "model_version": self.model_name,
             "model_file_location": cache_manager.get_cache_dir(),
@@ -93,56 +91,56 @@ def generate_rerank_description(
     return res


-class _ModelWrapper(nn.Module):
-    def __init__(self, module: nn.Module):
-        super().__init__()
-        self.model = module
-        self._local_data = threading.local()
-
-    @property
-    def n_tokens(self):
-        return getattr(self._local_data, "n_tokens", 0)
-
-    @n_tokens.setter
-    def n_tokens(self, value):
-        self._local_data.n_tokens = value
-
-    def forward(self, **kwargs):
-        attention_mask = kwargs.get("attention_mask")
-        # when batching, the attention mask 1 means there is a token
-        # thus we just sum up it to get the total number of tokens
-        if attention_mask is not None:
-            self.n_tokens += attention_mask.sum().item()
-        return self.model(**kwargs)
-
-    def __getattr__(self, attr):
-        try:
-            return super().__getattr__(attr)
-        except AttributeError:
-            return getattr(self.model, attr)
-
-
 class RerankModel:
     def __init__(
         self,
-        model_spec: RerankModelFamilyV2,
         model_uid: str,
-        model_path:
+        model_path: str,
+        model_family: RerankModelFamilyV2,
         device: Optional[str] = None,
         use_fp16: bool = False,
-
+        **kwargs,
     ):
-        self.model_family =
-        self._model_spec =
+        self.model_family = model_family
+        self._model_spec = model_family.model_specs[0]
         self._model_uid = model_uid
         self._model_path = model_path
         self._device = device
-        self._model_config = model_config or dict()
         self._use_fp16 = use_fp16
         self._model = None
         self._counter = 0
-
-
+        self._kwargs = kwargs
+        if model_family.type == "unknown":
+            model_family.type = self._auto_detect_type(model_path)
+
+    @classmethod
+    @abstractmethod
+    def check_lib(cls) -> bool:
+        pass
+
+    @classmethod
+    @abstractmethod
+    def match_json(
+        cls,
+        model_family: RerankModelFamilyV2,
+        model_spec: RerankSpecV1,
+        quantization: str,
+    ) -> bool:
+        pass
+
+    @classmethod
+    def match(
+        cls,
+        model_family: RerankModelFamilyV2,
+        model_spec: RerankSpecV1,
+        quantization: str,
+    ):
+        """
+        Return if the model_spec can be matched.
+        """
+        if not cls.check_lib():
+            return False
+        return cls.match_json(model_family, model_spec, quantization)

     @staticmethod
     def _get_tokenizer(model_path):
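With the hunk above, RerankModel turns into an abstract engine base: check_lib() reports whether the backing library is importable, match_json() decides whether a (family, spec, quantization) triple is servable, and match() short-circuits the two. A hypothetical engine against that contract follows; the class is a stand-in and does not subclass the real base.

import importlib.util


class MyRerankEngine:  # stand-in for xinference.model.rerank.core.RerankModel
    @classmethod
    def check_lib(cls) -> bool:
        # gate on the backing library being importable
        return importlib.util.find_spec("sentence_transformers") is not None

    @classmethod
    def match_json(cls, model_family, model_spec, quantization) -> bool:
        # inspect the spec; duck-typed here, RerankSpecV1 in the real code
        return model_spec.model_format == "pytorch"

    @classmethod
    def match(cls, model_family, model_spec, quantization) -> bool:
        # same short-circuit as the hunk above: no library, no match
        if not cls.check_lib():
            return False
        return cls.match_json(model_family, model_spec, quantization)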
@@ -171,145 +169,10 @@ class RerankModel:
             return "normal"
         return rerank_type

-
-
-        flash_attn_installed = importlib.util.find_spec("flash_attn") is not None
-        if (
-            self._auto_detect_type(self._model_path) != "normal"
-            and flash_attn_installed
-        ):
-            logger.warning(
-                "flash_attn can only support fp16 and bf16, "
-                "will force set `use_fp16` to True"
-            )
-            self._use_fp16 = True
-
-        if (
-            self._model_spec.type == "normal"
-            and "qwen3" not in self._model_spec.model_name.lower()
-        ):
-            try:
-                import sentence_transformers
-                from sentence_transformers.cross_encoder import CrossEncoder
-
-                if sentence_transformers.__version__ < "3.1.0":
-                    raise ValueError(
-                        "The sentence_transformers version must be greater than 3.1.0. "
-                        "Please upgrade your version via `pip install -U sentence_transformers` or refer to "
-                        "https://github.com/UKPLab/sentence-transformers"
-                    )
-            except ImportError:
-                error_message = "Failed to import module 'sentence-transformers'"
-                installation_guide = [
-                    "Please make sure 'sentence-transformers' is installed. ",
-                    "You can install it by `pip install sentence-transformers`\n",
-                ]
-
-                raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
-            self._model = CrossEncoder(
-                self._model_path,
-                device=self._device,
-                trust_remote_code=True,
-                max_length=getattr(self._model_spec, "max_tokens"),
-                **self._model_config,
-            )
-            if self._use_fp16:
-                self._model.model.half()
-        elif "qwen3" in self._model_spec.model_name.lower():
-            # qwen3-reranker
-            # now we use transformers
-            # TODO: support engines for rerank models
-            try:
-                from transformers import AutoModelForCausalLM, AutoTokenizer
-            except ImportError:
-                error_message = "Failed to import module 'transformers'"
-                installation_guide = [
-                    "Please make sure 'transformers' is installed. ",
-                    "You can install it by `pip install transformers`\n",
-                ]
-
-                raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
-
-            tokenizer = AutoTokenizer.from_pretrained(
-                self._model_path, padding_side="left"
-            )
-            enable_flash_attn = self._model_config.pop(
-                "enable_flash_attn", is_device_available("cuda")
-            )
-            model_kwargs = {"device_map": "auto"}
-            if flash_attn_installed and enable_flash_attn:
-                model_kwargs["attn_implementation"] = "flash_attention_2"
-                model_kwargs["torch_dtype"] = torch.float16
-            model_kwargs.update(self._model_config)
-            logger.debug("Loading qwen3 rerank with kwargs %s", model_kwargs)
-            model = self._model = AutoModelForCausalLM.from_pretrained(
-                self._model_path, **model_kwargs
-            ).eval()
-            max_length = getattr(self._model_spec, "max_tokens")
-
-            prefix = (
-                "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query "
-                'and the Instruct provided. Note that the answer can only be "yes" or "no".'
-                "<|im_end|>\n<|im_start|>user\n"
-            )
-            suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
-            prefix_tokens = tokenizer.encode(prefix, add_special_tokens=False)
-            suffix_tokens = tokenizer.encode(suffix, add_special_tokens=False)
-
-            def process_inputs(pairs):
-                inputs = tokenizer(
-                    pairs,
-                    padding=False,
-                    truncation="longest_first",
-                    return_attention_mask=False,
-                    max_length=max_length - len(prefix_tokens) - len(suffix_tokens),
-                )
-                for i, ele in enumerate(inputs["input_ids"]):
-                    inputs["input_ids"][i] = prefix_tokens + ele + suffix_tokens
-                inputs = tokenizer.pad(
-                    inputs, padding=True, return_tensors="pt", max_length=max_length
-                )
-                for key in inputs:
-                    inputs[key] = inputs[key].to(model.device)
-                return inputs
-
-            token_false_id = tokenizer.convert_tokens_to_ids("no")
-            token_true_id = tokenizer.convert_tokens_to_ids("yes")
-
-            @torch.no_grad()
-            def compute_logits(inputs, **kwargs):
-                batch_scores = model(**inputs).logits[:, -1, :]
-                true_vector = batch_scores[:, token_true_id]
-                false_vector = batch_scores[:, token_false_id]
-                batch_scores = torch.stack([false_vector, true_vector], dim=1)
-                batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
-                scores = batch_scores[:, 1].exp().tolist()
-                return scores
-
-            self.process_inputs = process_inputs
-            self.compute_logits = compute_logits
-        else:
-            try:
-                if self._model_spec.type == "LLM-based":
-                    from FlagEmbedding import FlagLLMReranker as FlagReranker
-                elif self._model_spec.type == "LLM-based layerwise":
-                    from FlagEmbedding import LayerWiseFlagLLMReranker as FlagReranker
-                else:
-                    raise RuntimeError(
-                        f"Unsupported Rank model type: {self._model_spec.type}"
-                    )
-            except ImportError:
-                error_message = "Failed to import module 'FlagEmbedding'"
-                installation_guide = [
-                    "Please make sure 'FlagEmbedding' is installed. ",
-                    "You can install it by `pip install FlagEmbedding`\n",
-                ]
-
-                raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
-            self._model = FlagReranker(self._model_path, use_fp16=self._use_fp16)
-            # Wrap transformers model to record number of tokens
-            self._model.model = _ModelWrapper(self._model.model)
+    @abstractmethod
+    def load(self): ...

+    @abstractmethod
     def rerank(
         self,
         documents: List[str],
@@ -319,159 +182,41 @@ class RerankModel:
         return_documents: Optional[bool],
         return_len: Optional[bool],
         **kwargs,
-    ) -> Rerank:
-        assert self._model is not None
-        if max_chunks_per_doc is not None:
-            raise ValueError("rerank hasn't support `max_chunks_per_doc` parameter.")
-        logger.info("Rerank with kwargs: %s, model: %s", kwargs, self._model)
-
-        pre_query = preprocess_sentence(
-            query, kwargs.get("instruction", None), self._model_spec.model_name
-        )
-        sentence_combinations = [[pre_query, doc] for doc in documents]
-        # reset n tokens
-        self._model.model.n_tokens = 0
-        if (
-            self._model_spec.type == "normal"
-            and "qwen3" not in self._model_spec.model_name.lower()
-        ):
-            logger.debug("Passing processed sentences: %s", sentence_combinations)
-            similarity_scores = self._model.predict(
-                sentence_combinations,
-                convert_to_numpy=False,
-                convert_to_tensor=True,
-                **kwargs,
-            ).cpu()
-            if similarity_scores.dtype == torch.bfloat16:
-                similarity_scores = similarity_scores.float()
-        elif "qwen3" in self._model_spec.model_name.lower():
-
-            def format_instruction(instruction, query, doc):
-                if instruction is None:
-                    instruction = "Given a web search query, retrieve relevant passages that answer the query"
-                output = "<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}".format(
-                    instruction=instruction, query=query, doc=doc
-                )
-                return output
-
-            # reduce memory usage.
-            micro_bs = 4
-            similarity_scores = []
-            for i in range(0, len(documents), micro_bs):
-                sub_docs = documents[i : i + micro_bs]
-                pairs = [
-                    format_instruction(kwargs.get("instruction", None), query, doc)
-                    for doc in sub_docs
-                ]
-                # Tokenize the input texts
-                inputs = self.process_inputs(pairs)
-                similarity_scores.extend(self.compute_logits(inputs))
-        else:
-            # Related issue: https://github.com/xorbitsai/inference/issues/1775
-            similarity_scores = self._model.compute_score(
-                sentence_combinations, **kwargs
-            )
-
-        if not isinstance(similarity_scores, Sequence):
-            similarity_scores = [similarity_scores]
-        elif (
-            isinstance(similarity_scores, list)
-            and len(similarity_scores) > 0
-            and isinstance(similarity_scores[0], Sequence)
-        ):
-            similarity_scores = similarity_scores[0]
-
-        sim_scores_argsort = list(reversed(np.argsort(similarity_scores)))
-        if top_n is not None:
-            sim_scores_argsort = sim_scores_argsort[:top_n]
-        if return_documents:
-            docs = [
-                DocumentObj(
-                    index=int(arg),
-                    relevance_score=float(similarity_scores[arg]),
-                    document=Document(text=documents[arg]),
-                )
-                for arg in sim_scores_argsort
-            ]
-        else:
-            docs = [
-                DocumentObj(
-                    index=int(arg),
-                    relevance_score=float(similarity_scores[arg]),
-                    document=None,
-                )
-                for arg in sim_scores_argsort
-            ]
-        if return_len:
-            input_len = self._model.model.n_tokens
-            # Rerank Model output is just score or documents
-            # while return_documents = True
-            output_len = input_len
-
-        # api_version, billed_units, warnings
-        # is for Cohere API compatibility, set to None
-        metadata = {
-            "api_version": None,
-            "billed_units": None,
-            "tokens": (
-                RerankTokens(input_tokens=input_len, output_tokens=output_len)
-                if return_len
-                else None
-            ),
-            "warnings": None,
-        }
-
-        del similarity_scores
-        # clear cache if possible
-        self._counter += 1
-        if self._counter % RERANK_EMPTY_CACHE_COUNT == 0:
-            logger.debug("Empty rerank cache.")
-            gc.collect()
-            empty_cache()
-
-        return Rerank(id=str(uuid.uuid1()), results=docs, meta=metadata)
+    ) -> Rerank: ...


 def create_rerank_model_instance(
     model_uid: str,
     model_name: str,
+    model_engine: Optional[str],
+    model_format: Optional[str] = None,
+    quantization: Optional[str] = None,
     download_hub: Optional[
         Literal["huggingface", "modelscope", "openmind_hub", "csghub"]
    ] = None,
     model_path: Optional[str] = None,
     **kwargs,
 ) -> RerankModel:
-    from
-
-
-
-
-    model_spec = None
-    for ud_spec in get_user_defined_reranks():
-        if ud_spec.model_name == model_name:
-            model_spec = ud_spec
-            break
-
-    if model_spec is None:
-        if model_name in BUILTIN_RERANK_MODELS:
-            model_specs = BUILTIN_RERANK_MODELS[model_name]
-            if download_hub == "modelscope" or download_from_modelscope():
-                model_spec = (
-                    [x for x in model_specs if x.model_hub == "modelscope"]
-                    + [x for x in model_specs if x.model_hub == "huggingface"]
-                )[0]
-            else:
-                model_spec = [x for x in model_specs if x.model_hub == "huggingface"][0]
-        else:
-            raise ValueError(
-                f"Rerank model {model_name} not found, available "
-                f"model: {BUILTIN_RERANK_MODELS.keys()}"
-            )
-    if not model_path:
-        cache_manager = CacheManager(model_spec)
+    from .cache_manager import RerankCacheManager
+
+    model_family = match_rerank(model_name, model_format, quantization, download_hub)
+    if model_path is None:
+        cache_manager = RerankCacheManager(model_family)
         model_path = cache_manager.cache()
-
-
-
+
+    if model_engine is None:
+        # unlike LLM and for compatibility,
+        # we use sentence_transformers as the default engine for all models
+        model_engine = "sentence_transformers"
+
+    rerank_cls = check_engine_by_model_name_and_engine(
+        model_engine, model_name, model_format, quantization
+    )
+    model = rerank_cls(
+        model_uid,
+        model_path,
+        model_family,
+        quantization,
+        **kwargs,
     )
     return model