xinference 1.8.1rc1__py3-none-any.whl → 1.9.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +2 -1
- xinference/core/model.py +8 -4
- xinference/core/supervisor.py +2 -3
- xinference/core/worker.py +7 -5
- xinference/deploy/cmdline.py +2 -0
- xinference/deploy/local.py +5 -0
- xinference/deploy/test/test_cmdline.py +1 -1
- xinference/deploy/worker.py +6 -0
- xinference/model/audio/cosyvoice.py +0 -1
- xinference/model/audio/model_spec.json +44 -20
- xinference/model/core.py +3 -0
- xinference/model/embedding/flag/core.py +5 -0
- xinference/model/embedding/llama_cpp/core.py +22 -19
- xinference/model/embedding/sentence_transformers/core.py +18 -4
- xinference/model/embedding/vllm/core.py +36 -9
- xinference/model/image/cache_manager.py +56 -0
- xinference/model/image/core.py +9 -0
- xinference/model/image/model_spec.json +178 -1
- xinference/model/image/stable_diffusion/core.py +155 -23
- xinference/model/llm/cache_manager.py +17 -3
- xinference/model/llm/harmony.py +245 -0
- xinference/model/llm/llama_cpp/core.py +41 -40
- xinference/model/llm/llm_family.json +688 -11
- xinference/model/llm/llm_family.py +1 -1
- xinference/model/llm/sglang/core.py +108 -5
- xinference/model/llm/transformers/core.py +20 -18
- xinference/model/llm/transformers/gemma3.py +1 -1
- xinference/model/llm/transformers/gpt_oss.py +91 -0
- xinference/model/llm/transformers/multimodal/core.py +1 -1
- xinference/model/llm/transformers/multimodal/gemma3.py +1 -1
- xinference/model/llm/transformers/multimodal/glm4_1v.py +2 -2
- xinference/model/llm/transformers/multimodal/ovis2.py +1 -1
- xinference/model/llm/transformers/multimodal/qwen-omni.py +7 -8
- xinference/model/llm/transformers/multimodal/qwen2_vl.py +9 -6
- xinference/model/llm/transformers/utils.py +1 -33
- xinference/model/llm/utils.py +61 -7
- xinference/model/llm/vllm/core.py +44 -8
- xinference/model/rerank/__init__.py +66 -23
- xinference/model/rerank/cache_manager.py +35 -0
- xinference/model/rerank/core.py +87 -339
- xinference/model/rerank/custom.py +33 -8
- xinference/model/rerank/model_spec.json +251 -212
- xinference/model/rerank/rerank_family.py +137 -0
- xinference/model/rerank/sentence_transformers/__init__.py +13 -0
- xinference/model/rerank/sentence_transformers/core.py +337 -0
- xinference/model/rerank/vllm/__init__.py +13 -0
- xinference/model/rerank/vllm/core.py +156 -0
- xinference/model/utils.py +108 -0
- xinference/model/video/model_spec.json +95 -1
- xinference/thirdparty/cosyvoice/bin/export_jit.py +3 -4
- xinference/thirdparty/cosyvoice/bin/export_onnx.py +49 -126
- xinference/thirdparty/cosyvoice/bin/{inference.py → inference_deprecated.py} +1 -0
- xinference/thirdparty/cosyvoice/bin/train.py +23 -3
- xinference/thirdparty/cosyvoice/cli/cosyvoice.py +8 -4
- xinference/thirdparty/cosyvoice/cli/frontend.py +4 -4
- xinference/thirdparty/cosyvoice/cli/model.py +53 -75
- xinference/thirdparty/cosyvoice/dataset/dataset.py +5 -18
- xinference/thirdparty/cosyvoice/dataset/processor.py +24 -25
- xinference/thirdparty/cosyvoice/flow/decoder.py +24 -433
- xinference/thirdparty/cosyvoice/flow/flow.py +6 -14
- xinference/thirdparty/cosyvoice/flow/flow_matching.py +33 -145
- xinference/thirdparty/cosyvoice/hifigan/generator.py +169 -1
- xinference/thirdparty/cosyvoice/llm/llm.py +108 -17
- xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +14 -115
- xinference/thirdparty/cosyvoice/utils/common.py +20 -0
- xinference/thirdparty/cosyvoice/utils/executor.py +8 -4
- xinference/thirdparty/cosyvoice/utils/file_utils.py +45 -1
- xinference/thirdparty/cosyvoice/utils/losses.py +37 -0
- xinference/thirdparty/cosyvoice/utils/mask.py +35 -1
- xinference/thirdparty/cosyvoice/utils/train_utils.py +24 -6
- xinference/thirdparty/cosyvoice/vllm/cosyvoice2.py +103 -0
- xinference/types.py +2 -0
- xinference/ui/gradio/chat_interface.py +2 -0
- xinference/ui/gradio/media_interface.py +353 -7
- xinference/ui/web/ui/build/asset-manifest.json +3 -3
- xinference/ui/web/ui/build/index.html +1 -1
- xinference/ui/web/ui/build/static/js/main.1086c759.js +3 -0
- xinference/ui/web/ui/build/static/js/main.1086c759.js.map +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/28012da921a51f1082549956d3ae82acd769a754b22afda9acddd98a4daf9ea4.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/3c5758bd12fa334294b1de0ff6b1a4bac8d963c45472eab9dc3e530d82aa6b3f.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/475936ebe725eca62a6f52ce182c06a19b2cef4df9545a05ed0591ee0c539d43.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/aee5aaba26f2b1e816a3ea9efa68bad8b95695a3d80adcfd8dd57a7bb17ac71a.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +1 -0
- xinference/ui/web/ui/src/locales/en.json +2 -0
- xinference/ui/web/ui/src/locales/ja.json +2 -0
- xinference/ui/web/ui/src/locales/ko.json +2 -0
- xinference/ui/web/ui/src/locales/zh.json +2 -0
- {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/METADATA +15 -10
- {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/RECORD +98 -89
- xinference/ui/web/ui/build/static/js/main.b969199a.js +0 -3
- xinference/ui/web/ui/build/static/js/main.b969199a.js.map +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/1409a96b9f9f9f5de99a89ab0f738f6da62b449521b0a8d3e4efcf7f5c23534d.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/3d2a89f0eccc1f90fc5036c9a1d587c2120e6a6b128aae31d1db7d6bad52722b.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/43b889c3a8e2634092ade463d52481c7c5581c72ded8f23bc5f012ea0ef8cea5.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/5d47532fb42128280d87f57c8a0b02bc1930f7ef764aa7e90579247df18bba83.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/830882bb275468a969614824a9ab8983f874b4581f2eb625e9c66426cdc65e5b.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/8e5cb82c2ff3299c6a44563fe6b1c5515c9750613c51bb63abee0b1d70fc5019.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/9df08abcb5a7c1e48a4eb25c5d5f5d7253ea6854a4397e6d74d1fd75a14acda1.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/b99034986a06445701accc7a4914bb9320947435e8d4e15793392ca4f679316c.json +0 -1
- /xinference/ui/web/ui/build/static/js/{main.b969199a.js.LICENSE.txt → main.1086c759.js.LICENSE.txt} +0 -0
- {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/WHEEL +0 -0
- {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/entry_points.txt +0 -0
- {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/licenses/LICENSE +0 -0
- {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/top_level.txt +0 -0
xinference/model/llm/utils.py
CHANGED
@@ -67,6 +67,9 @@ QWEN_TOOL_CALL_FAMILY = [
    "qwen3",
    "HuatuoGPT-o1-Qwen2.5",
    "DianJin-R1",
+    "Qwen3-Thinking",
+    "Qwen3-Instruct",
+    "Qwen3-Coder",
 ]
 
 GLM4_TOOL_CALL_FAMILY = [
@@ -79,9 +82,7 @@ LLAMA3_TOOL_CALL_FAMILY = [
    "HuatuoGPT-o1-LLaMA-3.1",
 ]
 
-DEEPSEEK_TOOL_CALL_FAMILY = [
-    "deepseek-v3",
-]
+DEEPSEEK_TOOL_CALL_FAMILY = ["deepseek-v3", "deepseek-r1-0528", "Deepseek-V3.1"]
 
 TOOL_CALL_FAMILY = (
    QWEN_TOOL_CALL_FAMILY
@@ -167,8 +168,7 @@ class ChatModelMixin:
                return json.loads(kwargs)
            except json.JSONDecodeError:
                raise TypeError(
-                    f"`chat_template_kwargs` should be json parsable, "
-                    f"got: {kwargs}"
+                    f"`chat_template_kwargs` should be json parsable, got: {kwargs}"
                )
        elif isinstance(kwargs, dict):
            return kwargs
@@ -254,7 +254,7 @@ class ChatModelMixin:
                ret += role + "\n" + text + intra_message_sep + "\n"
            else:
                placeholders = "\n".join(
-                    f"Image-{i+1}: <image>\n"
+                    f"Image-{i + 1}: <image>\n"
                    for i in range(
                        len(images) - len(image_futures), len(images)
                    )
@@ -463,6 +463,7 @@ class ChatModelMixin:
            chat_context_var.set(ctx)
 
        previous_texts = [""]
+        full_text = ""
        # Process chunks
        if reasoning_parser:
            set_context()
@@ -474,10 +475,14 @@ class ChatModelMixin:
                # usage
                chat_chunk = cls._get_final_chat_completion_chunk(chunk)
            else:
+                if choices[0].get("text"):
+                    full_text += choices[0]["text"]  # type: ignore
+
                chat_chunk = cls._to_chat_completion_chunk(
                    chunk, reasoning_parser, previous_texts
                )
            yield chat_chunk
+        logger.debug("Chat finished, output: %s", full_text)
 
    @staticmethod
    def _to_chat_completion(
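
The two hunks above thread a full_text accumulator through the streaming path so the complete output can be logged once when the stream ends. A minimal, self-contained sketch of the same idea over a hypothetical chunk stream (the chunk shape mirrors the mixin's; everything else is invented):

import logging

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

def stream_chunks():
    # hypothetical chunks shaped like the ones the mixin processes
    yield {"choices": [{"text": "Hello"}]}
    yield {"choices": [{"text": ", world"}]}

full_text = ""
for chunk in stream_chunks():
    choices = chunk["choices"]
    if choices and choices[0].get("text"):
        full_text += choices[0]["text"]
logger.debug("Chat finished, output: %s", full_text)  # Hello, world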
@@ -683,6 +688,52 @@ class ChatModelMixin:
 
        return results
 
+    @classmethod
+    def _eval_deepseek_r1_arguments(cls, c) -> List[Tuple]:
+        """
+        Parses tool calls from deepseek-r1 (0528) chat template format.
+        Returns:
+            List of (None, function_name, arguments_dict)
+            or (raw_content, None, None) if parsing fails.
+        """
+        text = c["choices"][0]["text"]
+        pattern = (
+            r"<\|tool▁call▁begin\|>function<\|tool▁sep\|>([^\n]+)\n"
+            r"```json\n(.*?)\n```<\|tool▁call▁end\|>"
+        )
+
+        matches = re.findall(pattern, text, re.DOTALL)
+        if not matches:
+            return [(text, None, None)]
+
+        tool_calls = set()
+        results = []
+
+        for func_name, raw_json in matches:
+            func_and_args = None
+            try:
+                func_and_args = json.loads(raw_json)
+                arguments_hashable = frozenset(func_and_args.items())
+                tool_call_tuple = (
+                    None,
+                    func_name,
+                    func_and_args,
+                )
+            except Exception:
+                tool_call_tuple = (raw_json, None, None)
+                arguments_hashable = None
+
+            dedup_key = (
+                (func_name, arguments_hashable)
+                if func_and_args is not None
+                else raw_json
+            )
+            if dedup_key not in tool_calls:
+                tool_calls.add(dedup_key)
+                results.append(tool_call_tuple)
+
+        return results
+
    @classmethod
    def _eval_tool_arguments(
        cls, model_family, c, tool_call_text: Optional[str] = None
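
To make the new deepseek-r1 parser concrete, here is a minimal, runnable sketch of the same regex applied to a hypothetical completion; the sample text and names are illustrative, not taken from the diff:

import json
import re

# Same pattern shape as _eval_deepseek_r1_arguments above.
PATTERN = (
    r"<\|tool▁call▁begin\|>function<\|tool▁sep\|>([^\n]+)\n"
    r"```json\n(.*?)\n```<\|tool▁call▁end\|>"
)

sample = (
    "<|tool▁call▁begin|>function<|tool▁sep|>get_weather\n"
    '```json\n{"city": "Beijing"}\n```<|tool▁call▁end|>'
)

for func_name, raw_json in re.findall(PATTERN, sample, re.DOTALL):
    print(func_name, json.loads(raw_json))  # get_weather {'city': 'Beijing'}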
@@ -695,7 +746,10 @@ class ChatModelMixin:
        elif family in LLAMA3_TOOL_CALL_FAMILY:
            result = cls._eval_llama3_chat_arguments(c)
        elif family in DEEPSEEK_TOOL_CALL_FAMILY:
-            result = cls._eval_deepseek_chat_arguments(c)
+            if family == "deepseek-r1-0528":
+                result = cls._eval_deepseek_r1_arguments(c)
+            else:
+                result = cls._eval_deepseek_chat_arguments(c)
        else:
            raise Exception(
                f"Model {model_family.model_name} is not support tool calls."

xinference/model/llm/vllm/core.py
CHANGED

@@ -89,6 +89,7 @@ class VLLMModelConfig(TypedDict, total=False):
    mm_processor_kwargs: NotRequired[dict[str, Any]]
    min_pixels: NotRequired[int]
    max_pixels: NotRequired[int]
+    enable_expert_parallel: bool
 
 
 class VLLMGenerateConfig(TypedDict, total=False):
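
enable_expert_parallel is a new field surfaced through VLLMModelConfig. Launch-time kwargs are forwarded into the engine's model config, so the flag should be reachable from the client; a hedged sketch (the endpoint, model name, and the forwarding behavior are assumptions, not confirmed by this diff):

from xinference.client import Client

client = Client("http://localhost:9997")
# Extra kwargs flow into the vLLM model config, where the new
# enable_expert_parallel field is declared.
uid = client.launch_model(
    model_name="qwen3",  # placeholder model
    model_engine="vllm",
    enable_expert_parallel=True,
)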
@@ -272,9 +273,19 @@ if VLLM_INSTALLED and VLLM_VERSION >= version.parse("0.9.2"):
|
|
|
272
273
|
VLLM_SUPPORTED_CHAT_MODELS.append("Qwen3-Instruct")
|
|
273
274
|
VLLM_SUPPORTED_CHAT_MODELS.append("Qwen3-Thinking")
|
|
274
275
|
VLLM_SUPPORTED_CHAT_MODELS.append("Qwen3-Coder")
|
|
276
|
+
VLLM_SUPPORTED_CHAT_MODELS.append("Deepseek-V3.1")
|
|
275
277
|
|
|
276
|
-
if VLLM_INSTALLED and VLLM_VERSION
|
|
278
|
+
if VLLM_INSTALLED and VLLM_VERSION >= version.parse("0.10.0"):
|
|
277
279
|
VLLM_SUPPORTED_CHAT_MODELS.append("glm-4.5")
|
|
280
|
+
VLLM_SUPPORTED_VISION_MODEL_LIST.append("glm-4.5v")
|
|
281
|
+
VLLM_SUPPORTED_CHAT_MODELS.append("KAT-V1")
|
|
282
|
+
|
|
283
|
+
if VLLM_INSTALLED and VLLM_VERSION > version.parse("0.10.0"):
|
|
284
|
+
VLLM_SUPPORTED_CHAT_MODELS.append("gpt-oss")
|
|
285
|
+
VLLM_SUPPORTED_CHAT_MODELS.append("seed-oss")
|
|
286
|
+
|
|
287
|
+
if VLLM_INSTALLED and VLLM_VERSION > version.parse("0.10.1.1"):
|
|
288
|
+
VLLM_SUPPORTED_CHAT_MODELS.append("seed-oss")
|
|
278
289
|
|
|
279
290
|
|
|
280
291
|
class VLLMModel(LLM):
|
|
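
The registration above is gated on the installed vLLM version. A minimal sketch of the same gating logic, with a stand-in version string and list replacing xinference's globals:

from packaging import version

INSTALLED = "0.10.1"  # stand-in for vllm.__version__
SUPPORTED = []

if version.parse(INSTALLED) >= version.parse("0.10.0"):
    SUPPORTED.append("glm-4.5")
if version.parse(INSTALLED) > version.parse("0.10.0"):
    SUPPORTED.append("gpt-oss")
if version.parse(INSTALLED) > version.parse("0.10.1.1"):
    SUPPORTED.append("seed-oss")  # skipped on 0.10.1

print(SUPPORTED)  # ['glm-4.5', 'gpt-oss']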
@@ -557,7 +568,9 @@ class VLLMModel(LLM):
            raise err.with_traceback(tb)
 
        # set context length after engine inited
-        self._set_context_length()
+        # if shard > 0, the engine will be inited in another process
+        if self._engine:
+            self._set_context_length()
 
    def _set_context_length(self):
        from vllm import envs
@@ -839,7 +852,7 @@ class VLLMModel(LLM):
            return False
        if not cls._is_linux():
            return False
-        if llm_spec.model_format not in ["pytorch", "gptq", "awq", "fp8"]:
+        if llm_spec.model_format not in ["pytorch", "gptq", "awq", "fp8", "bnb"]:
            return False
        if llm_spec.model_format == "pytorch":
            if quantization != "none" and not (quantization is None):
@@ -1187,7 +1200,14 @@ class VLLMChatModel(VLLMModel, ChatModelMixin):
    def match_json(
        cls, llm_family: "LLMFamilyV2", llm_spec: "LLMSpecV1", quantization: str
    ) -> bool:
-        if llm_spec.model_format not in [
+        if llm_spec.model_format not in [
+            "pytorch",
+            "gptq",
+            "awq",
+            "fp8",
+            "bnb",
+            "ggufv2",
+        ]:
            return False
        if llm_spec.model_format == "pytorch":
            if quantization != "none" and not (quantization is None):
@@ -1284,6 +1304,7 @@ class VLLMChatModel(VLLMModel, ChatModelMixin):
        previous_texts = [""]
        tool_call = False
        tool_call_texts = [""]
+        full_text = ""
        if self.reasoning_parser:
            set_context()
            chunks = self.reasoning_parser.prepare_reasoning_content_streaming(chunks)
@@ -1299,6 +1320,7 @@ class VLLMChatModel(VLLMModel, ChatModelMixin):
            if not choices:
                yield self._get_final_chat_completion_chunk(chunk)
            else:
+                full_text += chunk["choices"][0]["text"]
                if self.is_tool_call_chunk_start(chunk):
                    tool_call = True
                if tool_call:
@@ -1320,6 +1342,7 @@ class VLLMChatModel(VLLMModel, ChatModelMixin):
                        chunk, self.reasoning_parser, previous_texts
                    )
                i += 1
+        logger.debug("Chat finished, output: %s", full_text)
 
    @vllm_check
    async def async_chat(
@@ -1348,13 +1371,26 @@ class VLLMChatModel(VLLMModel, ChatModelMixin):
        ):
            full_context_kwargs["tools"] = tools
        assert self.model_family.chat_template is not None
-        full_prompt = self.get_full_context(
-            messages, self.model_family.chat_template, **full_context_kwargs
-        )
 
        generate_config = self._sanitize_chat_config(generate_config)
        stream = generate_config.get("stream", None)
 
+        lora_request = None
+        lora_model = generate_config.get("lora_name")
+        if lora_model is not None:
+            for lora in self.lora_requests:
+                if lora_model == lora.lora_name:
+                    lora_request = lora
+                    break
+        tokenizer = await self._get_tokenizer(lora_request)
+
+        full_prompt = self.get_full_context(
+            messages,
+            self.model_family.chat_template,
+            tokenizer=tokenizer,
+            **full_context_kwargs,
+        )
+
        if stream:
            agen = await self.async_generate(
                full_prompt, generate_config, tools, request_id=request_id
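
The hunk above resolves a LoRA adapter from generate_config["lora_name"] before building the prompt, so the tokenizer matches the adapter. A hedged sketch of how a caller might select that adapter (the uid, adapter name, and endpoint are placeholders):

from xinference.client import Client

client = Client("http://localhost:9997")
model = client.get_model("my-chat-model")  # placeholder uid
model.chat(
    messages=[{"role": "user", "content": "Hello"}],
    generate_config={"lora_name": "my-adapter"},  # matched against lora_requests
)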
@@ -1386,7 +1422,7 @@ class VLLMVisionModel(VLLMModel, ChatModelMixin):
            return False
        if not cls._is_linux():
            return False
-        if llm_spec.model_format not in ["pytorch", "gptq", "awq", "fp8"]:
+        if llm_spec.model_format not in ["pytorch", "gptq", "awq", "fp8", "bnb"]:
            return False
        if llm_spec.model_format == "pytorch":
            if quantization != "none" and not (quantization is None):

xinference/model/rerank/__init__.py
CHANGED

@@ -16,10 +16,10 @@ import codecs
 import json
 import os
 import warnings
-from typing import Dict, List
+from typing import Any, Dict, List
 
 from ...constants import XINFERENCE_MODEL_DIR
-from ..utils import flatten_model_src
+from ..utils import flatten_quantizations
 from .core import (
    RERANK_MODEL_DESCRIPTIONS,
    RerankModelFamilyV2,
@@ -32,8 +32,13 @@ from .custom import (
    register_rerank,
    unregister_rerank,
 )
-
-BUILTIN_RERANK_MODELS
+from .rerank_family import (
+    BUILTIN_RERANK_MODELS,
+    RERANK_ENGINES,
+    SENTENCE_TRANSFORMER_CLASSES,
+    SUPPORTED_ENGINES,
+    VLLM_CLASSES,
+)
 
 
 def register_custom_model():
@@ -58,31 +63,69 @@ def register_custom_model():
            warnings.warn(f"{user_defined_rerank_dir}/{f} has error, {e}")
 
 
-def
-
+def generate_engine_config_by_model_name(model_family: "RerankModelFamilyV2"):
+    model_name = model_family.model_name
+    engines: Dict[str, List[Dict[str, Any]]] = RERANK_ENGINES.get(
+        model_name, {}
+    )  # structure for engine query
+    for spec in [x for x in model_family.model_specs if x.model_hub == "huggingface"]:
+        model_format = spec.model_format
+        quantization = spec.quantization
+        for engine in SUPPORTED_ENGINES:
+            CLASSES = SUPPORTED_ENGINES[engine]
+            for cls in CLASSES:
+                # Every engine needs to implement match method
+                if cls.match(model_family, spec, quantization):
+                    # we only match the first class for an engine
+                    if engine not in engines:
+                        engines[engine] = [
+                            {
+                                "model_name": model_name,
+                                "model_format": model_format,
+                                "quantization": quantization,
+                                "rerank_class": cls,
+                            }
+                        ]
+                    else:
+                        engines[engine].append(
+                            {
+                                "model_name": model_name,
+                                "model_format": model_format,
+                                "quantization": quantization,
+                                "rerank_class": cls,
+                            }
+                        )
+                    break
+    RERANK_ENGINES[model_name] = engines
+
 
-
-
+def _install():
+    _model_spec_json = os.path.join(os.path.dirname(__file__), "model_spec.json")
+    for json_obj in json.load(codecs.open(_model_spec_json, "r", encoding="utf-8")):
+        flattened = []
+        for spec in json_obj["model_specs"]:
+            flattened.extend(flatten_quantizations(spec))
+        json_obj["model_specs"] = flattened
+        BUILTIN_RERANK_MODELS[json_obj["model_name"]] = RerankModelFamilyV2(**json_obj)
+
+    for model_name, model_spec in BUILTIN_RERANK_MODELS.items():
        if model_spec.model_name not in RERANK_MODEL_DESCRIPTIONS:
            RERANK_MODEL_DESCRIPTIONS.update(generate_rerank_description(model_spec))
 
-
+    from .sentence_transformers.core import SentenceTransformerRerankModel
+    from .vllm.core import VLLMRerankModel
 
-
-
-        RERANK_MODEL_DESCRIPTIONS.update(generate_rerank_description(ud_rerank))
+    SENTENCE_TRANSFORMER_CLASSES.extend([SentenceTransformerRerankModel])
+    VLLM_CLASSES.extend([VLLMRerankModel])
 
+    SUPPORTED_ENGINES["sentence_transformers"] = SENTENCE_TRANSFORMER_CLASSES
+    SUPPORTED_ENGINES["vllm"] = VLLM_CLASSES
 
-
-
-    flattened_model_specs = []
-    for spec in json.load(codecs.open(_model_spec_json, "r", encoding="utf-8")):
-        flattened_model_specs.extend(flatten_model_src(spec))
+    for model_spec in BUILTIN_RERANK_MODELS.values():
+        generate_engine_config_by_model_name(model_spec)
 
-
-        if spec["model_name"] not in target_families:
-            target_families[spec["model_name"]] = [RerankModelFamilyV2(**spec)]
-        else:
-            target_families[spec["model_name"]].append(RerankModelFamilyV2(**spec))
+    register_custom_model()
 
-
+    # register model description
+    for ud_rerank in get_user_defined_reranks():
+        RERANK_MODEL_DESCRIPTIONS.update(generate_rerank_description(ud_rerank))
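
After _install() runs, RERANK_ENGINES maps each model name to the per-engine entries that generate_engine_config_by_model_name collected. An illustrative, runnable sketch of the resulting shape for one model (the class below is a stand-in and the entries are invented):

class SentenceTransformerRerankModel:  # stand-in for the real engine class
    pass

RERANK_ENGINES = {
    "bge-reranker-v2-m3": {
        "sentence_transformers": [
            {
                "model_name": "bge-reranker-v2-m3",
                "model_format": "pytorch",
                "quantization": "none",
                "rerank_class": SentenceTransformerRerankModel,
            }
        ],
        # "vllm" appears here only when VLLMRerankModel.match(...) accepts a spec
    },
}

# Engine query: which engines can serve this model?
print(list(RERANK_ENGINES["bge-reranker-v2-m3"]))  # ['sentence_transformers']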

xinference/model/rerank/cache_manager.py
ADDED

@@ -0,0 +1,35 @@
+import os
+from typing import TYPE_CHECKING
+
+from ..cache_manager import CacheManager
+
+if TYPE_CHECKING:
+    from .core import RerankModelFamilyV2
+
+
+class RerankCacheManager(CacheManager):
+    def __init__(self, model_family: "RerankModelFamilyV2"):
+        from ..llm.cache_manager import LLMCacheManager
+
+        super().__init__(model_family)
+        # Composition design mode for avoiding duplicate code
+        self.cache_helper = LLMCacheManager(model_family)
+
+        spec = self._model_family.model_specs[0]
+        model_dir_name = (
+            f"{self._model_family.model_name}-{spec.model_format}-{spec.quantization}"
+        )
+        self._cache_dir = os.path.join(self._v2_cache_dir_prefix, model_dir_name)
+        self.cache_helper._cache_dir = self._cache_dir
+
+    def cache(self) -> str:
+        spec = self._model_family.model_specs[0]
+        if spec.model_uri is not None:
+            return self.cache_helper.cache_uri()
+        else:
+            if spec.model_hub == "huggingface":
+                return self.cache_helper.cache_from_huggingface()
+            elif spec.model_hub == "modelscope":
+                return self.cache_helper.cache_from_modelscope()
+            else:
+                raise ValueError(f"Unknown model hub: {spec.model_hub}")