xinference 1.7.1__py3-none-any.whl → 1.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of xinference might be problematic.
- xinference/_version.py +3 -3
- xinference/client/restful/async_restful_client.py +8 -13
- xinference/client/restful/restful_client.py +6 -2
- xinference/core/chat_interface.py +6 -4
- xinference/core/media_interface.py +5 -0
- xinference/core/model.py +1 -5
- xinference/core/supervisor.py +117 -68
- xinference/core/worker.py +49 -37
- xinference/deploy/test/test_cmdline.py +2 -6
- xinference/model/audio/__init__.py +26 -23
- xinference/model/audio/chattts.py +3 -2
- xinference/model/audio/core.py +49 -98
- xinference/model/audio/cosyvoice.py +3 -2
- xinference/model/audio/custom.py +28 -73
- xinference/model/audio/f5tts.py +3 -2
- xinference/model/audio/f5tts_mlx.py +3 -2
- xinference/model/audio/fish_speech.py +3 -2
- xinference/model/audio/funasr.py +17 -4
- xinference/model/audio/kokoro.py +3 -2
- xinference/model/audio/megatts.py +3 -2
- xinference/model/audio/melotts.py +3 -2
- xinference/model/audio/model_spec.json +572 -171
- xinference/model/audio/utils.py +0 -6
- xinference/model/audio/whisper.py +3 -2
- xinference/model/audio/whisper_mlx.py +3 -2
- xinference/model/cache_manager.py +141 -0
- xinference/model/core.py +6 -49
- xinference/model/custom.py +174 -0
- xinference/model/embedding/__init__.py +67 -56
- xinference/model/embedding/cache_manager.py +35 -0
- xinference/model/embedding/core.py +104 -84
- xinference/model/embedding/custom.py +55 -78
- xinference/model/embedding/embed_family.py +80 -31
- xinference/model/embedding/flag/core.py +21 -5
- xinference/model/embedding/llama_cpp/__init__.py +0 -0
- xinference/model/embedding/llama_cpp/core.py +234 -0
- xinference/model/embedding/model_spec.json +968 -103
- xinference/model/embedding/sentence_transformers/core.py +30 -20
- xinference/model/embedding/vllm/core.py +11 -5
- xinference/model/flexible/__init__.py +8 -2
- xinference/model/flexible/core.py +26 -119
- xinference/model/flexible/custom.py +69 -0
- xinference/model/flexible/launchers/image_process_launcher.py +1 -0
- xinference/model/flexible/launchers/modelscope_launcher.py +5 -1
- xinference/model/flexible/launchers/transformers_launcher.py +15 -3
- xinference/model/flexible/launchers/yolo_launcher.py +5 -1
- xinference/model/image/__init__.py +20 -20
- xinference/model/image/cache_manager.py +62 -0
- xinference/model/image/core.py +70 -182
- xinference/model/image/custom.py +28 -72
- xinference/model/image/model_spec.json +402 -119
- xinference/model/image/ocr/got_ocr2.py +3 -2
- xinference/model/image/stable_diffusion/core.py +22 -7
- xinference/model/image/stable_diffusion/mlx.py +6 -6
- xinference/model/image/utils.py +2 -2
- xinference/model/llm/__init__.py +71 -94
- xinference/model/llm/cache_manager.py +292 -0
- xinference/model/llm/core.py +37 -111
- xinference/model/llm/custom.py +88 -0
- xinference/model/llm/llama_cpp/core.py +5 -7
- xinference/model/llm/llm_family.json +16260 -8151
- xinference/model/llm/llm_family.py +138 -839
- xinference/model/llm/lmdeploy/core.py +5 -7
- xinference/model/llm/memory.py +3 -4
- xinference/model/llm/mlx/core.py +6 -8
- xinference/model/llm/reasoning_parser.py +3 -1
- xinference/model/llm/sglang/core.py +32 -14
- xinference/model/llm/transformers/chatglm.py +3 -7
- xinference/model/llm/transformers/core.py +49 -27
- xinference/model/llm/transformers/deepseek_v2.py +2 -2
- xinference/model/llm/transformers/gemma3.py +2 -2
- xinference/model/llm/transformers/multimodal/cogagent.py +2 -2
- xinference/model/llm/transformers/multimodal/deepseek_vl2.py +2 -2
- xinference/model/llm/transformers/multimodal/gemma3.py +2 -2
- xinference/model/llm/transformers/multimodal/glm4_1v.py +167 -0
- xinference/model/llm/transformers/multimodal/glm4v.py +2 -2
- xinference/model/llm/transformers/multimodal/intern_vl.py +2 -2
- xinference/model/llm/transformers/multimodal/minicpmv26.py +3 -3
- xinference/model/llm/transformers/multimodal/ovis2.py +2 -2
- xinference/model/llm/transformers/multimodal/qwen-omni.py +2 -2
- xinference/model/llm/transformers/multimodal/qwen2_audio.py +2 -2
- xinference/model/llm/transformers/multimodal/qwen2_vl.py +2 -2
- xinference/model/llm/transformers/opt.py +3 -7
- xinference/model/llm/utils.py +34 -49
- xinference/model/llm/vllm/core.py +77 -27
- xinference/model/llm/vllm/xavier/engine.py +5 -3
- xinference/model/llm/vllm/xavier/scheduler.py +10 -6
- xinference/model/llm/vllm/xavier/transfer.py +1 -1
- xinference/model/rerank/__init__.py +26 -25
- xinference/model/rerank/core.py +47 -87
- xinference/model/rerank/custom.py +25 -71
- xinference/model/rerank/model_spec.json +158 -33
- xinference/model/rerank/utils.py +2 -2
- xinference/model/utils.py +115 -54
- xinference/model/video/__init__.py +13 -17
- xinference/model/video/core.py +44 -102
- xinference/model/video/diffusers.py +4 -3
- xinference/model/video/model_spec.json +90 -21
- xinference/types.py +5 -3
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/main.7d24df53.js +3 -0
- xinference/web/ui/build/static/js/main.7d24df53.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/2704ff66a5f73ca78b341eb3edec60154369df9d87fbc8c6dd60121abc5e1b0a.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/607dfef23d33e6b594518c0c6434567639f24f356b877c80c60575184ec50ed0.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/9be3d56173aacc3efd0b497bcb13c4f6365de30069176ee9403b40e717542326.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/9f9dd6c32c78a222d07da5987ae902effe16bcf20aac00774acdccc4de3c9ff2.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/b2ab5ee972c60d15eb9abf5845705f8ab7e1d125d324d9a9b1bcae5d6fd7ffb2.json +1 -0
- xinference/web/ui/src/locales/en.json +0 -1
- xinference/web/ui/src/locales/ja.json +0 -1
- xinference/web/ui/src/locales/ko.json +0 -1
- xinference/web/ui/src/locales/zh.json +0 -1
- {xinference-1.7.1.dist-info → xinference-1.8.0.dist-info}/METADATA +9 -11
- {xinference-1.7.1.dist-info → xinference-1.8.0.dist-info}/RECORD +119 -119
- xinference/model/audio/model_spec_modelscope.json +0 -231
- xinference/model/embedding/model_spec_modelscope.json +0 -293
- xinference/model/embedding/utils.py +0 -18
- xinference/model/image/model_spec_modelscope.json +0 -375
- xinference/model/llm/llama_cpp/memory.py +0 -457
- xinference/model/llm/llm_family_csghub.json +0 -56
- xinference/model/llm/llm_family_modelscope.json +0 -8700
- xinference/model/llm/llm_family_openmind_hub.json +0 -1019
- xinference/model/rerank/model_spec_modelscope.json +0 -85
- xinference/model/video/model_spec_modelscope.json +0 -184
- xinference/web/ui/build/static/js/main.9b12b7f9.js +0 -3
- xinference/web/ui/build/static/js/main.9b12b7f9.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/1460361af6975e63576708039f1cb732faf9c672d97c494d4055fc6331460be0.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/4efd8dda58fda83ed9546bf2f587df67f8d98e639117bee2d9326a9a1d9bebb2.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/55b9fb40b57fa926e8f05f31c2f96467e76e5ad62f033dca97c03f9e8c4eb4fe.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/5b2dafe5aa9e1105e0244a2b6751807342fa86aa0144b4e84d947a1686102715.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/611fa2c6c53b66039991d06dfb0473b5ab37fc63b4564e0f6e1718523768a045.json +0 -1
- /xinference/web/ui/build/static/js/{main.9b12b7f9.js.LICENSE.txt → main.7d24df53.js.LICENSE.txt} +0 -0
- {xinference-1.7.1.dist-info → xinference-1.8.0.dist-info}/WHEEL +0 -0
- {xinference-1.7.1.dist-info → xinference-1.8.0.dist-info}/entry_points.txt +0 -0
- {xinference-1.7.1.dist-info → xinference-1.8.0.dist-info}/licenses/LICENSE +0 -0
- {xinference-1.7.1.dist-info → xinference-1.8.0.dist-info}/top_level.txt +0 -0
xinference/model/llm/vllm/core.py
CHANGED

@@ -50,9 +50,9 @@ from ....types import (
     CompletionUsage,
     LoRA,
 )
-from .. import LLM,
+from .. import BUILTIN_LLM_FAMILIES, LLM, LLMFamilyV2, LLMSpecV1
 from ..core import chat_context_var
-from ..llm_family import
+from ..llm_family import CustomLLMFamilyV2, cache_model_tokenizer_and_config
 from ..utils import (
     DEEPSEEK_TOOL_CALL_FAMILY,
     QWEN_TOOL_CALL_FAMILY,

@@ -117,6 +117,11 @@ class VLLMGenerateConfig(TypedDict, total=False):
 try:
     import vllm  # noqa: F401

+    if not getattr(vllm, "__version__", None):
+        raise ImportError(
+            "vllm not installed properly, or wrongly be found in sys.path"
+        )
+
     VLLM_INSTALLED = True
 except ImportError:
     VLLM_INSTALLED = False
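Note: the guard added above covers the case where `import vllm` succeeds but resolves to something that is not a real installation (for example, a stray `vllm/` directory on `sys.path` picked up as a namespace package); without it, the later `vllm.__version__ >= ...` checks would raise AttributeError. A minimal standalone sketch of the same idea (the `has_usable_vllm` helper is illustrative, not part of xinference):

def has_usable_vllm() -> bool:
    # Importable is not enough: a genuine install must also expose a
    # non-empty __version__, mirroring the check added in this release.
    try:
        import vllm  # noqa: F401

        if not getattr(vllm, "__version__", None):
            raise ImportError(
                "vllm not installed properly, or wrongly be found in sys.path"
            )
        return True
    except ImportError:
        return False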
@@ -257,14 +262,16 @@ if VLLM_INSTALLED and vllm.__version__ >= "0.8.5":
 if VLLM_INSTALLED and vllm.__version__ >= "0.9.1":
     VLLM_SUPPORTED_CHAT_MODELS.append("minicpm4")

+if VLLM_INSTALLED and vllm.__version__ >= "0.9.2":
+    VLLM_SUPPORTED_CHAT_MODELS.append("Ernie4.5")
+    VLLM_SUPPORTED_VISION_MODEL_LIST.append("glm-4.1v-thinking")
+

 class VLLMModel(LLM):
     def __init__(
         self,
         model_uid: str,
-        model_family: "
-        model_spec: "LLMSpecV1",
-        quantization: str,
+        model_family: "LLMFamilyV2",
         model_path: str,
         model_config: Optional[VLLMModelConfig],
         peft_model: Optional[List[LoRA]] = None,

@@ -279,7 +286,7 @@ class VLLMModel(LLM):
             ]

             raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
-        super().__init__(model_uid, model_family,
+        super().__init__(model_uid, model_family, model_path)
         self._model_config = model_config
         self._engine = None
         self.lora_modules = peft_model

@@ -349,7 +356,7 @@ class VLLMModel(LLM):

             raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")

-        from ..llm_family import
+        from ..llm_family import LlamaCppLLMSpecV2

         if "0.3.1" <= vllm.__version__ <= "0.3.3":
             # from vllm v0.3.1 to v0.3.3, it uses cupy as NCCL backend

@@ -368,7 +375,7 @@ class VLLMModel(LLM):
         )

         if (
-            isinstance(self.model_spec,
+            isinstance(self.model_spec, LlamaCppLLMSpecV2)
             and self.model_spec.model_format == "ggufv2"
         ):
             # gguf

@@ -592,20 +599,25 @@ class VLLMModel(LLM):

             if "tokenizer" not in self._model_config:
                 # find pytorch format without quantization
+                family = next(
+                    family
+                    for family in BUILTIN_LLM_FAMILIES
+                    if family.model_name == self.model_family.model_name
+                ).copy()
                 non_quant_spec = next(
                     spec
-                    for spec in
-                    if spec.
-                    and "none" in spec.quantizations
+                    for spec in family.model_specs
+                    if spec.quantization == "none"
                     and spec.model_size_in_billions
                     == self.model_spec.model_size_in_billions
+                    and spec.model_hub == self.model_spec.model_hub
                 )
-
-                path = cache_model_tokenizer_and_config(
+                family.model_specs = [non_quant_spec]
+                path = cache_model_tokenizer_and_config(family)
                 # other than gguf file, vllm requires to provide tokenizer and hf_config_path
-                self._model_config["tokenizer"] = self._model_config[
-
-
+                self._model_config["tokenizer"] = self._model_config["hf_config_path"] = (
+                    path
+                )

             if not os.path.isfile(self.model_path):
                 self.model_path = os.path.realpath(

@@ -791,7 +803,7 @@ class VLLMModel(LLM):

     @classmethod
     def match_json(
-        cls, llm_family: "
+        cls, llm_family: "LLMFamilyV2", llm_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         if not cls._has_cuda_device():
             return False

@@ -813,7 +825,7 @@ class VLLMModel(LLM):
         else:
             if "4" not in quantization:
                 return False
-        if isinstance(llm_family,
+        if isinstance(llm_family, CustomLLMFamilyV2):
             if llm_family.model_family not in VLLM_SUPPORTED_MODELS:
                 return False
         else:

@@ -1090,7 +1102,7 @@ class VLLMModel(LLM):
 class VLLMChatModel(VLLMModel, ChatModelMixin):
     @classmethod
     def match_json(
-        cls, llm_family: "
+        cls, llm_family: "LLMFamilyV2", llm_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         if llm_spec.model_format not in ["pytorch", "gptq", "awq", "fp8", "ggufv2"]:
             return False

@@ -1111,7 +1123,7 @@ class VLLMChatModel(VLLMModel, ChatModelMixin):
         if llm_spec.model_format == "ggufv2":
             if not (VLLM_INSTALLED and vllm.__version__ >= "0.8.2"):
                 return False
-        if isinstance(llm_family,
+        if isinstance(llm_family, CustomLLMFamilyV2):
             if llm_family.model_family not in VLLM_SUPPORTED_CHAT_MODELS:
                 return False
         else:

@@ -1137,9 +1149,9 @@ class VLLMChatModel(VLLMModel, ChatModelMixin):
             not generate_config.get("stop_token_ids")
             and self.model_family.stop_token_ids
         ):
-            generate_config[
-
-
+            generate_config["stop_token_ids"] = (
+                self.model_family.stop_token_ids.copy()
+            )
         return generate_config

     @staticmethod

@@ -1150,17 +1162,50 @@ class VLLMChatModel(VLLMModel, ChatModelMixin):
         def is_tool_call_chunk_end(chunk):
             return chunk["choices"][0]["text"].endswith(QWEN_TOOL_CALL_SYMBOLS[1])

+    @staticmethod
+    def prefill_messages(messages: List[Dict]) -> List[Dict]:
+        """
+        Preprocess messages to ensure content is not None
+
+        Args:
+            messages: Original message list
+
+        Returns:
+            Processed message list, where content is not None
+        """
+        processed_messages = []
+
+        for msg in messages:
+            if isinstance(msg, dict):
+                if msg.get("content") is None:
+                    msg_copy = msg.copy()
+                    msg_copy["content"] = ""  # Replace None with empty string
+                    processed_messages.append(msg_copy)
+                else:
+                    processed_messages.append(msg)
+            else:
+                processed_messages.append(msg)
+
+        return processed_messages
+
     async def _async_to_tool_completion_chunks(
         self,
         chunks: AsyncGenerator[CompletionChunk, None],
+        ctx: Optional[Dict[str, Any]] = {},
     ) -> AsyncGenerator[ChatCompletionChunk, None]:
+        def set_context():
+            if ctx:
+                chat_context_var.set(ctx)
+
         i = 0
         previous_texts = [""]
         tool_call = False
         tool_call_texts = [""]
         if self.reasoning_parser:
+            set_context()
             chunks = self.reasoning_parser.prepare_reasoning_content_streaming(chunks)
         async for chunk in chunks:
+            set_context()
             if i == 0:
                 for first_chunk in self._get_first_chat_completion_chunk(
                     chunk, self.reasoning_parser
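The new `prefill_messages` helper normalizes OpenAI-style message lists before chat templating, so assistant turns whose `content` is `None` (typical for pure tool-call turns) no longer trip up template rendering. A rough usage sketch with made-up message data (the import path is the module this diff modifies):

from xinference.model.llm.vllm.core import VLLMChatModel

messages = [
    {"role": "user", "content": "What's the weather in Paris?"},
    {
        "role": "assistant",
        "content": None,  # pure tool-call turns often carry no text
        "tool_calls": [
            {
                "id": "call_1",
                "type": "function",
                "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'},
            }
        ],
    },
    {"role": "tool", "content": "18C and sunny"},
]

normalized = VLLMChatModel.prefill_messages(messages)
assert normalized[1]["content"] == ""  # None replaced with an empty string
assert messages[1]["content"] is None  # the original list is not mutated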
@@ -1200,6 +1245,9 @@ class VLLMChatModel(VLLMModel, ChatModelMixin):
         generate_config: Optional[Dict] = None,
         request_id: Optional[str] = None,
     ) -> Union[ChatCompletion, AsyncGenerator[ChatCompletionChunk, None]]:
+        # Preprocess messages to ensure content is not None
+        messages = self.prefill_messages(messages)
+
         tools = generate_config.pop("tools", []) if generate_config else None
         model_family = self.model_family.model_family or self.model_family.model_name
         chat_template_kwargs = (

@@ -1230,8 +1278,10 @@ class VLLMChatModel(VLLMModel, ChatModelMixin):
             )
             assert isinstance(agen, AsyncGenerator)
             if tools:
-                return self._async_to_tool_completion_chunks(agen)
-            return self._async_to_chat_completion_chunks(
+                return self._async_to_tool_completion_chunks(agen, chat_template_kwargs)
+            return self._async_to_chat_completion_chunks(
+                agen, self.reasoning_parser, chat_template_kwargs
+            )
         else:
             c = await self.async_generate(
                 full_prompt, generate_config, request_id=request_id

@@ -1247,7 +1297,7 @@ class VLLMChatModel(VLLMModel, ChatModelMixin):
 class VLLMVisionModel(VLLMModel, ChatModelMixin):
     @classmethod
     def match_json(
-        cls, llm_family: "
+        cls, llm_family: "LLMFamilyV2", llm_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         if not cls._has_cuda_device():
             return False

@@ -1269,7 +1319,7 @@ class VLLMVisionModel(VLLMModel, ChatModelMixin):
         else:
             if "4" not in quantization:
                 return False
-        if isinstance(llm_family,
+        if isinstance(llm_family, CustomLLMFamilyV2):
             if llm_family.model_family not in VLLM_SUPPORTED_VISION_MODEL_LIST:
                 return False
         else:
xinference/model/llm/vllm/xavier/engine.py
CHANGED

@@ -39,9 +39,11 @@ class XavierInternalEngine(_AsyncLLMEngine):
             self.cache_config,
             self.lora_config,
             self.parallel_config.pipeline_parallel_size,
-
-
-
+            (
+                self.async_callbacks[v_id]
+                if self.model_config.use_async_output_proc
+                else None
+            ),
             xavier_config=self._xavier_config,
             virtual_engine=v_id,
         )
xinference/model/llm/vllm/xavier/scheduler.py
CHANGED

@@ -352,12 +352,16 @@ class XavierScheduler(Scheduler):
                 # between engine and worker.
                 # the subsequent comms can still use delta, but
                 # `multi_modal_data` will be None.
-                multi_modal_data=
-
-
-
-
+                multi_modal_data=(
+                    seq_group.multi_modal_data
+                    if scheduler_outputs.num_prefill_groups > 0
+                    else None
+                ),
+                multi_modal_placeholders=(
+                    seq_group.multi_modal_placeholders
+                    if scheduler_outputs.num_prefill_groups > 0
+                    else None
+                ),
                 mm_processor_kwargs=seq_group.mm_processor_kwargs,
                 prompt_adapter_request=seq_group.prompt_adapter_request,
             )
xinference/model/llm/vllm/xavier/transfer.py
CHANGED

@@ -30,7 +30,7 @@ logger = logging.getLogger(__name__)

 class BufferTransferMixin:
     def __init__(self):
-        self.num_buffer: int = 0
+        self.num_buffer: int = 0  # type: ignore
         self.buffers: List[torch.Tensor] = []  # type: ignore
         self.buffer_queue: Optional[Queue] = None  # type: ignore
         self.transfer_block_num = 0
xinference/model/rerank/__init__.py
CHANGED

@@ -16,38 +16,41 @@ import codecs
 import json
 import os
 import warnings
-from typing import
+from typing import Dict, List

 from ...constants import XINFERENCE_MODEL_DIR
+from ..utils import flatten_model_src
 from .core import (
-    MODEL_NAME_TO_REVISION,
     RERANK_MODEL_DESCRIPTIONS,
-
+    RerankModelFamilyV2,
     generate_rerank_description,
-    get_cache_status,
     get_rerank_model_descriptions,
 )
 from .custom import (
-
+    CustomRerankModelFamilyV2,
     get_user_defined_reranks,
     register_rerank,
     unregister_rerank,
 )

-BUILTIN_RERANK_MODELS: Dict[str,
-MODELSCOPE_RERANK_MODELS: Dict[str, Any] = {}
+BUILTIN_RERANK_MODELS: Dict[str, List["RerankModelFamilyV2"]] = {}


 def register_custom_model():
+    from ..custom import migrate_from_v1_to_v2
+
+    # migrate from v1 to v2 first
+    migrate_from_v1_to_v2("rerank", CustomRerankModelFamilyV2)
+
     # if persist=True, load them when init
-    user_defined_rerank_dir = os.path.join(XINFERENCE_MODEL_DIR, "rerank")
+    user_defined_rerank_dir = os.path.join(XINFERENCE_MODEL_DIR, "v2", "rerank")
     if os.path.isdir(user_defined_rerank_dir):
         for f in os.listdir(user_defined_rerank_dir):
             try:
                 with codecs.open(
                     os.path.join(user_defined_rerank_dir, f), encoding="utf-8"
                 ) as fd:
-                    user_defined_rerank_spec =
+                    user_defined_rerank_spec = CustomRerankModelFamilyV2.parse_obj(
                         json.load(fd)
                     )
                     register_rerank(user_defined_rerank_spec, persist=False)

@@ -57,15 +60,11 @@ def register_custom_model():

 def _install():
     load_model_family_from_json("model_spec.json", BUILTIN_RERANK_MODELS)
-    load_model_family_from_json("model_spec_modelscope.json", MODELSCOPE_RERANK_MODELS)

-
-
-
-
-    RERANK_MODEL_DESCRIPTIONS.update(
-        generate_rerank_description(model_spec)
-    )
+    for model_name, model_specs in BUILTIN_RERANK_MODELS.items():
+        model_spec = [x for x in model_specs if x.model_hub == "huggingface"][0]
+        if model_spec.model_name not in RERANK_MODEL_DESCRIPTIONS:
+            RERANK_MODEL_DESCRIPTIONS.update(generate_rerank_description(model_spec))

     register_custom_model()


@@ -76,12 +75,14 @@ def _install():

 def load_model_family_from_json(json_filename, target_families):
     _model_spec_json = os.path.join(os.path.dirname(__file__), json_filename)
-
-
-
-
-
-
-
-
+    flattened_model_specs = []
+    for spec in json.load(codecs.open(_model_spec_json, "r", encoding="utf-8")):
+        flattened_model_specs.extend(flatten_model_src(spec))
+
+    for spec in flattened_model_specs:
+        if spec["model_name"] not in target_families:
+            target_families[spec["model_name"]] = [RerankModelFamilyV2(**spec)]
+        else:
+            target_families[spec["model_name"]].append(RerankModelFamilyV2(**spec))
+
     del _model_spec_json
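With the separate model_spec_modelscope.json removed, a single model_spec.json now lists per-hub sources that `flatten_model_src` expands, and `BUILTIN_RERANK_MODELS` maps each model name to a list of per-hub `RerankModelFamilyV2` specs. A self-contained sketch of the grouping and hub preference this enables (the dict entries below are placeholders, not real spec records):

from collections import defaultdict
from typing import Dict, List

# Placeholder stand-ins for flattened spec entries; the real ones come from
# model_spec.json after flatten_model_src() emits one record per model hub.
flattened_specs: List[Dict] = [
    {"model_name": "demo-reranker", "model_hub": "huggingface"},
    {"model_name": "demo-reranker", "model_hub": "modelscope"},
]

target_families: Dict[str, List[Dict]] = defaultdict(list)
for spec in flattened_specs:
    target_families[spec["model_name"]].append(spec)

# Hub preference as in create_rerank_model_instance(): ModelScope first when
# configured, otherwise fall back to the Hugging Face entry.
candidates = target_families["demo-reranker"]
prefer_modelscope = True
if prefer_modelscope:
    chosen = (
        [s for s in candidates if s["model_hub"] == "modelscope"]
        + [s for s in candidates if s["model_hub"] == "huggingface"]
    )[0]
else:
    chosen = [s for s in candidates if s["model_hub"] == "huggingface"][0]
print(chosen["model_hub"])  # -> modelscope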
xinference/model/rerank/core.py
CHANGED
@@ -21,24 +21,22 @@ import threading
 import uuid
 from collections import defaultdict
 from collections.abc import Sequence
-from typing import Dict, List, Literal, Optional
+from typing import Dict, List, Literal, Optional

 import numpy as np
 import torch
 import torch.nn as nn

-from ...
-from ...device_utils import empty_cache
+from ...device_utils import empty_cache, is_device_available
 from ...types import Document, DocumentObj, Rerank, RerankTokens
-from ..core import CacheableModelSpec,
-from ..utils import
+from ..core import CacheableModelSpec, VirtualEnvSettings
+from ..utils import ModelInstanceInfoMixin
 from .utils import preprocess_sentence

 logger = logging.getLogger(__name__)

 # Used for check whether the model is cached.
 # Init when registering all the builtin models.
-MODEL_NAME_TO_REVISION: Dict[str, List[str]] = defaultdict(list)
 RERANK_MODEL_DESCRIPTIONS: Dict[str, List[Dict]] = defaultdict(list)
 RERANK_EMPTY_CACHE_COUNT = int(os.getenv("XINFERENCE_RERANK_EMPTY_CACHE_COUNT", "10"))
 assert RERANK_EMPTY_CACHE_COUNT > 0

@@ -50,7 +48,8 @@ def get_rerank_model_descriptions():
     return copy.deepcopy(RERANK_MODEL_DESCRIPTIONS)


-class
+class RerankModelFamilyV2(CacheableModelSpec, ModelInstanceInfoMixin):
+    version: Literal[2]
     model_name: str
     language: List[str]
     type: Optional[str] = "unknown"

@@ -60,56 +59,37 @@ class RerankModelSpec(CacheableModelSpec):
     model_hub: str = "huggingface"
     virtualenv: Optional[VirtualEnvSettings]

+    class Config:
+        extra = "allow"

-
-    def __init__(
-        self,
-        address: Optional[str],
-        devices: Optional[List[str]],
-        model_spec: RerankModelSpec,
-        model_path: Optional[str] = None,
-    ):
-        super().__init__(address, devices, model_path=model_path)
-        self._model_spec = model_spec
-
-    @property
-    def spec(self):
-        return self._model_spec
-
-    def to_dict(self):
+    def to_description(self):
         return {
             "model_type": "rerank",
-            "address": self
-            "accelerators": self
-            "type": self.
-            "model_name": self.
-            "language": self.
-            "model_revision": self.
+            "address": getattr(self, "address", None),
+            "accelerators": getattr(self, "accelerators", None),
+            "type": self.type,
+            "model_name": self.model_name,
+            "language": self.language,
+            "model_revision": self.model_revision,
         }

     def to_version_info(self):
-        from
-
-        if self._model_path is None:
-            is_cached = get_cache_status(self._model_spec)
-            file_location = get_cache_dir(self._model_spec)
-        else:
-            is_cached = True
-            file_location = self._model_path
+        from ..cache_manager import CacheManager

+        cache_manager = CacheManager(self)
         return {
-            "model_version":
-            "model_file_location":
-            "cache_status":
-            "language": self.
+            "model_version": self.model_name,
+            "model_file_location": cache_manager.get_cache_dir(),
+            "cache_status": cache_manager.get_cache_status(),
+            "language": self.language,
         }


-def generate_rerank_description(
+def generate_rerank_description(
+    model_spec: RerankModelFamilyV2,
+) -> Dict[str, List[Dict]]:
     res = defaultdict(list)
-    res[model_spec.model_name].append(
-        RerankModelDescription(None, None, model_spec).to_version_info()
-    )
+    res[model_spec.model_name].append(model_spec.to_version_info())
     return res

@@ -145,13 +125,14 @@ class _ModelWrapper(nn.Module):
 class RerankModel:
     def __init__(
         self,
-        model_spec:
+        model_spec: RerankModelFamilyV2,
         model_uid: str,
         model_path: Optional[str] = None,
         device: Optional[str] = None,
         use_fp16: bool = False,
         model_config: Optional[Dict] = None,
     ):
+        self.model_family = model_spec
         self._model_spec = model_spec
         self._model_uid = model_uid
         self._model_path = model_path

@@ -252,7 +233,9 @@ class RerankModel:
         tokenizer = AutoTokenizer.from_pretrained(
             self._model_path, padding_side="left"
         )
-        enable_flash_attn = self._model_config.
+        enable_flash_attn = self._model_config.pop(
+            "enable_flash_attn", is_device_available("cuda")
+        )
         model_kwargs = {"device_map": "auto"}
         if flash_attn_installed and enable_flash_attn:
             model_kwargs["attn_implementation"] = "flash_attention_2"

@@ -448,25 +431,7 @@ class RerankModel:
         return Rerank(id=str(uuid.uuid1()), results=docs, meta=metadata)


-def get_cache_dir(model_spec: RerankModelSpec):
-    return os.path.realpath(os.path.join(XINFERENCE_CACHE_DIR, model_spec.model_name))
-
-
-def get_cache_status(
-    model_spec: RerankModelSpec,
-) -> bool:
-    return is_model_cached(model_spec, MODEL_NAME_TO_REVISION)
-
-
-def cache(model_spec: RerankModelSpec):
-    from ..utils import cache
-
-    return cache(model_spec, RerankModelDescription)
-
-
 def create_rerank_model_instance(
-    subpool_addr: str,
-    devices: List[str],
     model_uid: str,
     model_name: str,
     download_hub: Optional[

@@ -474,9 +439,10 @@ def create_rerank_model_instance(
     ] = None,
     model_path: Optional[str] = None,
     **kwargs,
-) ->
+) -> RerankModel:
+    from ..cache_manager import CacheManager
     from ..utils import download_from_modelscope
-    from . import BUILTIN_RERANK_MODELS
+    from . import BUILTIN_RERANK_MODELS
     from .custom import get_user_defined_reranks

     model_spec = None

@@ -486,31 +452,25 @@ def create_rerank_model_instance(
             break

     if model_spec is None:
-        if
-
-
-
-
-
-
-
-
-        elif model_name in BUILTIN_RERANK_MODELS:
-            logger.debug(f"Rerank model {model_name} found in Huggingface.")
-            model_spec = BUILTIN_RERANK_MODELS[model_name]
+        if model_name in BUILTIN_RERANK_MODELS:
+            model_specs = BUILTIN_RERANK_MODELS[model_name]
+            if download_hub == "modelscope" or download_from_modelscope():
+                model_spec = (
+                    [x for x in model_specs if x.model_hub == "modelscope"]
+                    + [x for x in model_specs if x.model_hub == "huggingface"]
+                )[0]
+            else:
+                model_spec = [x for x in model_specs if x.model_hub == "huggingface"][0]
         else:
             raise ValueError(
-                f"Rerank model {model_name} not found, available"
-                f"
-                f"ModelScope: {MODELSCOPE_RERANK_MODELS.keys()}"
+                f"Rerank model {model_name} not found, available "
+                f"model: {BUILTIN_RERANK_MODELS.keys()}"
             )
     if not model_path:
-
+        cache_manager = CacheManager(model_spec)
+        model_path = cache_manager.cache()
     use_fp16 = kwargs.pop("use_fp16", False)
    model = RerankModel(
         model_spec, model_uid, model_path, use_fp16=use_fp16, model_config=kwargs
     )
-
-        subpool_addr, devices, model_spec, model_path=model_path
-    )
-    return model, model_description
+    return model