xinference 1.8.1rc1__py3-none-any.whl → 1.9.1__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.

This version of xinference has been flagged as potentially problematic.
Files changed (108)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +2 -1
  3. xinference/core/model.py +8 -4
  4. xinference/core/supervisor.py +2 -3
  5. xinference/core/worker.py +7 -5
  6. xinference/deploy/cmdline.py +2 -0
  7. xinference/deploy/local.py +5 -0
  8. xinference/deploy/test/test_cmdline.py +1 -1
  9. xinference/deploy/worker.py +6 -0
  10. xinference/model/audio/cosyvoice.py +0 -1
  11. xinference/model/audio/model_spec.json +44 -20
  12. xinference/model/core.py +3 -0
  13. xinference/model/embedding/flag/core.py +5 -0
  14. xinference/model/embedding/llama_cpp/core.py +22 -19
  15. xinference/model/embedding/sentence_transformers/core.py +18 -4
  16. xinference/model/embedding/vllm/core.py +36 -9
  17. xinference/model/image/cache_manager.py +56 -0
  18. xinference/model/image/core.py +9 -0
  19. xinference/model/image/model_spec.json +178 -1
  20. xinference/model/image/stable_diffusion/core.py +155 -23
  21. xinference/model/llm/cache_manager.py +17 -3
  22. xinference/model/llm/harmony.py +245 -0
  23. xinference/model/llm/llama_cpp/core.py +41 -40
  24. xinference/model/llm/llm_family.json +688 -11
  25. xinference/model/llm/llm_family.py +1 -1
  26. xinference/model/llm/sglang/core.py +108 -5
  27. xinference/model/llm/transformers/core.py +20 -18
  28. xinference/model/llm/transformers/gemma3.py +1 -1
  29. xinference/model/llm/transformers/gpt_oss.py +91 -0
  30. xinference/model/llm/transformers/multimodal/core.py +1 -1
  31. xinference/model/llm/transformers/multimodal/gemma3.py +1 -1
  32. xinference/model/llm/transformers/multimodal/glm4_1v.py +2 -2
  33. xinference/model/llm/transformers/multimodal/ovis2.py +1 -1
  34. xinference/model/llm/transformers/multimodal/qwen-omni.py +7 -8
  35. xinference/model/llm/transformers/multimodal/qwen2_vl.py +9 -6
  36. xinference/model/llm/transformers/utils.py +1 -33
  37. xinference/model/llm/utils.py +61 -7
  38. xinference/model/llm/vllm/core.py +44 -8
  39. xinference/model/rerank/__init__.py +66 -23
  40. xinference/model/rerank/cache_manager.py +35 -0
  41. xinference/model/rerank/core.py +87 -339
  42. xinference/model/rerank/custom.py +33 -8
  43. xinference/model/rerank/model_spec.json +251 -212
  44. xinference/model/rerank/rerank_family.py +137 -0
  45. xinference/model/rerank/sentence_transformers/__init__.py +13 -0
  46. xinference/model/rerank/sentence_transformers/core.py +337 -0
  47. xinference/model/rerank/vllm/__init__.py +13 -0
  48. xinference/model/rerank/vllm/core.py +156 -0
  49. xinference/model/utils.py +108 -0
  50. xinference/model/video/model_spec.json +95 -1
  51. xinference/thirdparty/cosyvoice/bin/export_jit.py +3 -4
  52. xinference/thirdparty/cosyvoice/bin/export_onnx.py +49 -126
  53. xinference/thirdparty/cosyvoice/bin/{inference.py → inference_deprecated.py} +1 -0
  54. xinference/thirdparty/cosyvoice/bin/train.py +23 -3
  55. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +8 -4
  56. xinference/thirdparty/cosyvoice/cli/frontend.py +4 -4
  57. xinference/thirdparty/cosyvoice/cli/model.py +53 -75
  58. xinference/thirdparty/cosyvoice/dataset/dataset.py +5 -18
  59. xinference/thirdparty/cosyvoice/dataset/processor.py +24 -25
  60. xinference/thirdparty/cosyvoice/flow/decoder.py +24 -433
  61. xinference/thirdparty/cosyvoice/flow/flow.py +6 -14
  62. xinference/thirdparty/cosyvoice/flow/flow_matching.py +33 -145
  63. xinference/thirdparty/cosyvoice/hifigan/generator.py +169 -1
  64. xinference/thirdparty/cosyvoice/llm/llm.py +108 -17
  65. xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +14 -115
  66. xinference/thirdparty/cosyvoice/utils/common.py +20 -0
  67. xinference/thirdparty/cosyvoice/utils/executor.py +8 -4
  68. xinference/thirdparty/cosyvoice/utils/file_utils.py +45 -1
  69. xinference/thirdparty/cosyvoice/utils/losses.py +37 -0
  70. xinference/thirdparty/cosyvoice/utils/mask.py +35 -1
  71. xinference/thirdparty/cosyvoice/utils/train_utils.py +24 -6
  72. xinference/thirdparty/cosyvoice/vllm/cosyvoice2.py +103 -0
  73. xinference/types.py +2 -0
  74. xinference/ui/gradio/chat_interface.py +2 -0
  75. xinference/ui/gradio/media_interface.py +353 -7
  76. xinference/ui/web/ui/build/asset-manifest.json +3 -3
  77. xinference/ui/web/ui/build/index.html +1 -1
  78. xinference/ui/web/ui/build/static/js/main.1086c759.js +3 -0
  79. xinference/ui/web/ui/build/static/js/main.1086c759.js.map +1 -0
  80. xinference/ui/web/ui/node_modules/.cache/babel-loader/28012da921a51f1082549956d3ae82acd769a754b22afda9acddd98a4daf9ea4.json +1 -0
  81. xinference/ui/web/ui/node_modules/.cache/babel-loader/3c5758bd12fa334294b1de0ff6b1a4bac8d963c45472eab9dc3e530d82aa6b3f.json +1 -0
  82. xinference/ui/web/ui/node_modules/.cache/babel-loader/475936ebe725eca62a6f52ce182c06a19b2cef4df9545a05ed0591ee0c539d43.json +1 -0
  83. xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +1 -0
  84. xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +1 -0
  85. xinference/ui/web/ui/node_modules/.cache/babel-loader/aee5aaba26f2b1e816a3ea9efa68bad8b95695a3d80adcfd8dd57a7bb17ac71a.json +1 -0
  86. xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +1 -0
  87. xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +1 -0
  88. xinference/ui/web/ui/src/locales/en.json +2 -0
  89. xinference/ui/web/ui/src/locales/ja.json +2 -0
  90. xinference/ui/web/ui/src/locales/ko.json +2 -0
  91. xinference/ui/web/ui/src/locales/zh.json +2 -0
  92. {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/METADATA +15 -10
  93. {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/RECORD +98 -89
  94. xinference/ui/web/ui/build/static/js/main.b969199a.js +0 -3
  95. xinference/ui/web/ui/build/static/js/main.b969199a.js.map +0 -1
  96. xinference/ui/web/ui/node_modules/.cache/babel-loader/1409a96b9f9f9f5de99a89ab0f738f6da62b449521b0a8d3e4efcf7f5c23534d.json +0 -1
  97. xinference/ui/web/ui/node_modules/.cache/babel-loader/3d2a89f0eccc1f90fc5036c9a1d587c2120e6a6b128aae31d1db7d6bad52722b.json +0 -1
  98. xinference/ui/web/ui/node_modules/.cache/babel-loader/43b889c3a8e2634092ade463d52481c7c5581c72ded8f23bc5f012ea0ef8cea5.json +0 -1
  99. xinference/ui/web/ui/node_modules/.cache/babel-loader/5d47532fb42128280d87f57c8a0b02bc1930f7ef764aa7e90579247df18bba83.json +0 -1
  100. xinference/ui/web/ui/node_modules/.cache/babel-loader/830882bb275468a969614824a9ab8983f874b4581f2eb625e9c66426cdc65e5b.json +0 -1
  101. xinference/ui/web/ui/node_modules/.cache/babel-loader/8e5cb82c2ff3299c6a44563fe6b1c5515c9750613c51bb63abee0b1d70fc5019.json +0 -1
  102. xinference/ui/web/ui/node_modules/.cache/babel-loader/9df08abcb5a7c1e48a4eb25c5d5f5d7253ea6854a4397e6d74d1fd75a14acda1.json +0 -1
  103. xinference/ui/web/ui/node_modules/.cache/babel-loader/b99034986a06445701accc7a4914bb9320947435e8d4e15793392ca4f679316c.json +0 -1
  104. /xinference/ui/web/ui/build/static/js/{main.b969199a.js.LICENSE.txt → main.1086c759.js.LICENSE.txt} +0 -0
  105. {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/WHEEL +0 -0
  106. {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/entry_points.txt +0 -0
  107. {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/licenses/LICENSE +0 -0
  108. {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/top_level.txt +0 -0
xinference/model/llm/utils.py:

@@ -67,6 +67,9 @@ QWEN_TOOL_CALL_FAMILY = [
     "qwen3",
     "HuatuoGPT-o1-Qwen2.5",
     "DianJin-R1",
+    "Qwen3-Thinking",
+    "Qwen3-Instruct",
+    "Qwen3-Coder",
 ]

 GLM4_TOOL_CALL_FAMILY = [
@@ -79,9 +82,7 @@ LLAMA3_TOOL_CALL_FAMILY = [
     "HuatuoGPT-o1-LLaMA-3.1",
 ]

-DEEPSEEK_TOOL_CALL_FAMILY = [
-    "deepseek-v3",
-]
+DEEPSEEK_TOOL_CALL_FAMILY = ["deepseek-v3", "deepseek-r1-0528", "Deepseek-V3.1"]

 TOOL_CALL_FAMILY = (
     QWEN_TOOL_CALL_FAMILY
@@ -167,8 +168,7 @@ class ChatModelMixin:
                 return json.loads(kwargs)
             except json.JSONDecodeError:
                 raise TypeError(
-                    f"`chat_template_kwargs` should be json parsable, "
-                    f"got: {kwargs}"
+                    f"`chat_template_kwargs` should be json parsable, got: {kwargs}"
                 )
         elif isinstance(kwargs, dict):
            return kwargs
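
For context, a minimal sketch of what this preprocessing accepts. Only the two value shapes (a JSON-parsable string or a dict) come from the code above; the surrounding variable names are illustrative:

    # Sketch only: shows the accepted shapes of `chat_template_kwargs`.
    import json

    chat_template_kwargs = {"enable_thinking": False}

    # Either form should survive _preprocess_chat_template_kwargs:
    as_dict = chat_template_kwargs
    as_string = json.dumps(chat_template_kwargs)  # '{"enable_thinking": false}'

    # A non-JSON string such as "enable_thinking=false" would raise the
    # TypeError shown in the diff above.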
@@ -254,7 +254,7 @@ class ChatModelMixin:
                 ret += role + "\n" + text + intra_message_sep + "\n"
             else:
                 placeholders = "\n".join(
-                    f"Image-{i+1}: <image>\n"
+                    f"Image-{i + 1}: <image>\n"
                     for i in range(
                         len(images) - len(image_futures), len(images)
                     )
@@ -463,6 +463,7 @@ class ChatModelMixin:
         chat_context_var.set(ctx)

         previous_texts = [""]
+        full_text = ""
         # Process chunks
         if reasoning_parser:
             set_context()
@@ -474,10 +475,14 @@ class ChatModelMixin:
                 # usage
                 chat_chunk = cls._get_final_chat_completion_chunk(chunk)
             else:
+                if choices[0].get("text"):
+                    full_text += choices[0]["text"]  # type: ignore
+
                 chat_chunk = cls._to_chat_completion_chunk(
                     chunk, reasoning_parser, previous_texts
                 )
             yield chat_chunk
+        logger.debug("Chat finished, output: %s", full_text)

     @staticmethod
     def _to_chat_completion(
@@ -683,6 +688,52 @@ class ChatModelMixin:

         return results

+    @classmethod
+    def _eval_deepseek_r1_arguments(cls, c) -> List[Tuple]:
+        """
+        Parses tool calls from deepseek-r1 (0528) chat template format.
+        Returns:
+            List of (None, function_name, arguments_dict)
+            or (raw_content, None, None) if parsing fails.
+        """
+        text = c["choices"][0]["text"]
+        pattern = (
+            r"<\|tool▁call▁begin\|>function<\|tool▁sep\|>([^\n]+)\n"
+            r"```json\n(.*?)\n```<\|tool▁call▁end\|>"
+        )
+
+        matches = re.findall(pattern, text, re.DOTALL)
+        if not matches:
+            return [(text, None, None)]
+
+        tool_calls = set()
+        results = []
+
+        for func_name, raw_json in matches:
+            func_and_args = None
+            try:
+                func_and_args = json.loads(raw_json)
+                arguments_hashable = frozenset(func_and_args.items())
+                tool_call_tuple = (
+                    None,
+                    func_name,
+                    func_and_args,
+                )
+            except Exception:
+                tool_call_tuple = (raw_json, None, None)
+                arguments_hashable = None
+
+            dedup_key = (
+                (func_name, arguments_hashable)
+                if func_and_args is not None
+                else raw_json
+            )
+            if dedup_key not in tool_calls:
+                tool_calls.add(dedup_key)
+                results.append(tool_call_tuple)
+
+        return results
+
     @classmethod
     def _eval_tool_arguments(
         cls, model_family, c, tool_call_text: Optional[str] = None
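
A minimal standalone sketch of the parsing behavior above. The sample model output is illustrative, not a real deepseek-r1 trace:

    import json
    import re

    # Hypothetical deepseek-r1 (0528) tool-call output in the template format
    # targeted by the regex above.
    sample = (
        "<|tool▁call▁begin|>function<|tool▁sep|>get_weather\n"
        "```json\n{\"city\": \"Paris\"}\n```<|tool▁call▁end|>"
    )
    pattern = (
        r"<\|tool▁call▁begin\|>function<\|tool▁sep\|>([^\n]+)\n"
        r"```json\n(.*?)\n```<\|tool▁call▁end\|>"
    )
    for name, raw in re.findall(pattern, sample, re.DOTALL):
        print(name, json.loads(raw))  # get_weather {'city': 'Paris'}

Note the dedup logic in the method: identical (function, arguments) pairs are hashed via frozenset and emitted only once, while unparsable JSON falls back to returning the raw text.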
@@ -695,7 +746,10 @@ class ChatModelMixin:
         elif family in LLAMA3_TOOL_CALL_FAMILY:
             result = cls._eval_llama3_chat_arguments(c)
         elif family in DEEPSEEK_TOOL_CALL_FAMILY:
-            result = cls._eval_deepseek_chat_arguments(c)
+            if family == "deepseek-r1-0528":
+                result = cls._eval_deepseek_r1_arguments(c)
+            else:
+                result = cls._eval_deepseek_chat_arguments(c)
         else:
             raise Exception(
                 f"Model {model_family.model_name} is not support tool calls."
xinference/model/llm/vllm/core.py:

@@ -89,6 +89,7 @@ class VLLMModelConfig(TypedDict, total=False):
     mm_processor_kwargs: NotRequired[dict[str, Any]]
     min_pixels: NotRequired[int]
     max_pixels: NotRequired[int]
+    enable_expert_parallel: bool


 class VLLMGenerateConfig(TypedDict, total=False):
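
A hedged usage sketch for the new config key: VLLMModelConfig entries are typically populated from launch-time kwargs, so enabling expert parallelism for an MoE model might look like the following. Treating enable_expert_parallel as a pass-through launch kwarg is an assumption based on the TypedDict above, and the model choice is illustrative:

    # Sketch only: assumes launch kwargs are forwarded into VLLMModelConfig.
    from xinference.client import Client

    client = Client("http://localhost:9997")
    model_uid = client.launch_model(
        model_name="Qwen3-Instruct",   # illustrative MoE model choice
        model_engine="vllm",
        enable_expert_parallel=True,   # new config key in this release
    )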
@@ -272,9 +273,19 @@ if VLLM_INSTALLED and VLLM_VERSION >= version.parse("0.9.2"):
     VLLM_SUPPORTED_CHAT_MODELS.append("Qwen3-Instruct")
     VLLM_SUPPORTED_CHAT_MODELS.append("Qwen3-Thinking")
     VLLM_SUPPORTED_CHAT_MODELS.append("Qwen3-Coder")
+    VLLM_SUPPORTED_CHAT_MODELS.append("Deepseek-V3.1")

-if VLLM_INSTALLED and VLLM_VERSION > version.parse("0.10.0"):
+if VLLM_INSTALLED and VLLM_VERSION >= version.parse("0.10.0"):
     VLLM_SUPPORTED_CHAT_MODELS.append("glm-4.5")
+    VLLM_SUPPORTED_VISION_MODEL_LIST.append("glm-4.5v")
+    VLLM_SUPPORTED_CHAT_MODELS.append("KAT-V1")
+
+if VLLM_INSTALLED and VLLM_VERSION > version.parse("0.10.0"):
+    VLLM_SUPPORTED_CHAT_MODELS.append("gpt-oss")
+    VLLM_SUPPORTED_CHAT_MODELS.append("seed-oss")
+
+if VLLM_INSTALLED and VLLM_VERSION > version.parse("0.10.1.1"):
+    VLLM_SUPPORTED_CHAT_MODELS.append("seed-oss")


 class VLLMModel(LLM):
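
For reference, a small sketch of the packaging.version semantics these gates rely on. The VLLM_VERSION assignment is an assumption about the surrounding module; the comparison results themselves are standard packaging behavior:

    # Sketch: how the gates above evaluate, assuming elsewhere in the module
    # VLLM_VERSION = version.parse(vllm.__version__).
    from packaging import version

    assert version.parse("0.10.0") >= version.parse("0.10.0")   # glm-4.5 gate is now inclusive
    assert version.parse("0.10.1.1") > version.parse("0.10.0")  # four-component versions compare fine
    assert not version.parse("0.10.1") > version.parse("0.10.1.1")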
@@ -557,7 +568,9 @@ class VLLMModel(LLM):
             raise err.with_traceback(tb)

         # set context length after engine inited
-        self._set_context_length()
+        # if shard > 0, the engine will be inited in another process
+        if self._engine:
+            self._set_context_length()

     def _set_context_length(self):
         from vllm import envs
@@ -839,7 +852,7 @@ class VLLMModel(LLM):
             return False
         if not cls._is_linux():
             return False
-        if llm_spec.model_format not in ["pytorch", "gptq", "awq", "fp8"]:
+        if llm_spec.model_format not in ["pytorch", "gptq", "awq", "fp8", "bnb"]:
             return False
         if llm_spec.model_format == "pytorch":
             if quantization != "none" and not (quantization is None):
@@ -1187,7 +1200,14 @@ class VLLMChatModel(VLLMModel, ChatModelMixin):
     def match_json(
         cls, llm_family: "LLMFamilyV2", llm_spec: "LLMSpecV1", quantization: str
     ) -> bool:
-        if llm_spec.model_format not in ["pytorch", "gptq", "awq", "fp8", "ggufv2"]:
+        if llm_spec.model_format not in [
+            "pytorch",
+            "gptq",
+            "awq",
+            "fp8",
+            "bnb",
+            "ggufv2",
+        ]:
             return False
         if llm_spec.model_format == "pytorch":
             if quantization != "none" and not (quantization is None):
@@ -1284,6 +1304,7 @@ class VLLMChatModel(VLLMModel, ChatModelMixin):
         previous_texts = [""]
         tool_call = False
         tool_call_texts = [""]
+        full_text = ""
         if self.reasoning_parser:
             set_context()
             chunks = self.reasoning_parser.prepare_reasoning_content_streaming(chunks)
@@ -1299,6 +1320,7 @@ class VLLMChatModel(VLLMModel, ChatModelMixin):
             if not choices:
                 yield self._get_final_chat_completion_chunk(chunk)
             else:
+                full_text += chunk["choices"][0]["text"]
                 if self.is_tool_call_chunk_start(chunk):
                     tool_call = True
                 if tool_call:
@@ -1320,6 +1342,7 @@ class VLLMChatModel(VLLMModel, ChatModelMixin):
                         chunk, self.reasoning_parser, previous_texts
                     )
                 i += 1
+        logger.debug("Chat finished, output: %s", full_text)

     @vllm_check
     async def async_chat(
@@ -1348,13 +1371,26 @@ class VLLMChatModel(VLLMModel, ChatModelMixin):
         ):
             full_context_kwargs["tools"] = tools
         assert self.model_family.chat_template is not None
-        full_prompt = self.get_full_context(
-            messages, self.model_family.chat_template, **full_context_kwargs
-        )

         generate_config = self._sanitize_chat_config(generate_config)
         stream = generate_config.get("stream", None)

+        lora_request = None
+        lora_model = generate_config.get("lora_name")
+        if lora_model is not None:
+            for lora in self.lora_requests:
+                if lora_model == lora.lora_name:
+                    lora_request = lora
+                    break
+        tokenizer = await self._get_tokenizer(lora_request)
+
+        full_prompt = self.get_full_context(
+            messages,
+            self.model_family.chat_template,
+            tokenizer=tokenizer,
+            **full_context_kwargs,
+        )
+
         if stream:
             agen = await self.async_generate(
                 full_prompt, generate_config, tools, request_id=request_id
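
A hedged usage sketch of the lora_name path above, via the standard xinference client API. The uid and adapter name are hypothetical; the adapter must match a lora_name registered when the model was launched:

    # Sketch only: "my-adapter" is a hypothetical LoRA name.
    from xinference.client import Client

    client = Client("http://localhost:9997")
    model = client.get_model("my-model-uid")  # hypothetical uid
    completion = model.chat(
        messages=[{"role": "user", "content": "Hello"}],
        generate_config={"lora_name": "my-adapter"},
    )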
@@ -1386,7 +1422,7 @@ class VLLMVisionModel(VLLMModel, ChatModelMixin):
             return False
         if not cls._is_linux():
             return False
-        if llm_spec.model_format not in ["pytorch", "gptq", "awq", "fp8"]:
+        if llm_spec.model_format not in ["pytorch", "gptq", "awq", "fp8", "bnb"]:
             return False
         if llm_spec.model_format == "pytorch":
             if quantization != "none" and not (quantization is None):
xinference/model/rerank/__init__.py:

@@ -16,10 +16,10 @@ import codecs
 import json
 import os
 import warnings
-from typing import Dict, List
+from typing import Any, Dict, List

 from ...constants import XINFERENCE_MODEL_DIR
-from ..utils import flatten_model_src
+from ..utils import flatten_quantizations
 from .core import (
     RERANK_MODEL_DESCRIPTIONS,
     RerankModelFamilyV2,
@@ -32,8 +32,13 @@ from .custom import (
     register_rerank,
     unregister_rerank,
 )
-
-BUILTIN_RERANK_MODELS: Dict[str, List["RerankModelFamilyV2"]] = {}
+from .rerank_family import (
+    BUILTIN_RERANK_MODELS,
+    RERANK_ENGINES,
+    SENTENCE_TRANSFORMER_CLASSES,
+    SUPPORTED_ENGINES,
+    VLLM_CLASSES,
+)


 def register_custom_model():
@@ -58,31 +63,69 @@ def register_custom_model():
             warnings.warn(f"{user_defined_rerank_dir}/{f} has error, {e}")


-def _install():
-    load_model_family_from_json("model_spec.json", BUILTIN_RERANK_MODELS)
+def generate_engine_config_by_model_name(model_family: "RerankModelFamilyV2"):
+    model_name = model_family.model_name
+    engines: Dict[str, List[Dict[str, Any]]] = RERANK_ENGINES.get(
+        model_name, {}
+    )  # structure for engine query
+    for spec in [x for x in model_family.model_specs if x.model_hub == "huggingface"]:
+        model_format = spec.model_format
+        quantization = spec.quantization
+        for engine in SUPPORTED_ENGINES:
+            CLASSES = SUPPORTED_ENGINES[engine]
+            for cls in CLASSES:
+                # Every engine needs to implement match method
+                if cls.match(model_family, spec, quantization):
+                    # we only match the first class for an engine
+                    if engine not in engines:
+                        engines[engine] = [
+                            {
+                                "model_name": model_name,
+                                "model_format": model_format,
+                                "quantization": quantization,
+                                "rerank_class": cls,
+                            }
+                        ]
+                    else:
+                        engines[engine].append(
+                            {
+                                "model_name": model_name,
+                                "model_format": model_format,
+                                "quantization": quantization,
+                                "rerank_class": cls,
+                            }
+                        )
+                    break
+    RERANK_ENGINES[model_name] = engines
+

-    for model_name, model_specs in BUILTIN_RERANK_MODELS.items():
-        model_spec = [x for x in model_specs if x.model_hub == "huggingface"][0]
+def _install():
+    _model_spec_json = os.path.join(os.path.dirname(__file__), "model_spec.json")
+    for json_obj in json.load(codecs.open(_model_spec_json, "r", encoding="utf-8")):
+        flattened = []
+        for spec in json_obj["model_specs"]:
+            flattened.extend(flatten_quantizations(spec))
+        json_obj["model_specs"] = flattened
+        BUILTIN_RERANK_MODELS[json_obj["model_name"]] = RerankModelFamilyV2(**json_obj)
+
+    for model_name, model_spec in BUILTIN_RERANK_MODELS.items():
         if model_spec.model_name not in RERANK_MODEL_DESCRIPTIONS:
             RERANK_MODEL_DESCRIPTIONS.update(generate_rerank_description(model_spec))

-    register_custom_model()
+    from .sentence_transformers.core import SentenceTransformerRerankModel
+    from .vllm.core import VLLMRerankModel

-    # register model description
-    for ud_rerank in get_user_defined_reranks():
-        RERANK_MODEL_DESCRIPTIONS.update(generate_rerank_description(ud_rerank))
+    SENTENCE_TRANSFORMER_CLASSES.extend([SentenceTransformerRerankModel])
+    VLLM_CLASSES.extend([VLLMRerankModel])

+    SUPPORTED_ENGINES["sentence_transformers"] = SENTENCE_TRANSFORMER_CLASSES
+    SUPPORTED_ENGINES["vllm"] = VLLM_CLASSES

-def load_model_family_from_json(json_filename, target_families):
-    _model_spec_json = os.path.join(os.path.dirname(__file__), json_filename)
-    flattened_model_specs = []
-    for spec in json.load(codecs.open(_model_spec_json, "r", encoding="utf-8")):
-        flattened_model_specs.extend(flatten_model_src(spec))
+    for model_spec in BUILTIN_RERANK_MODELS.values():
+        generate_engine_config_by_model_name(model_spec)

-    for spec in flattened_model_specs:
-        if spec["model_name"] not in target_families:
-            target_families[spec["model_name"]] = [RerankModelFamilyV2(**spec)]
-        else:
-            target_families[spec["model_name"]].append(RerankModelFamilyV2(**spec))
+    register_custom_model()

-    del _model_spec_json
+    # register model description
+    for ud_rerank in get_user_defined_reranks():
+        RERANK_MODEL_DESCRIPTIONS.update(generate_rerank_description(ud_rerank))
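
To make the registry shape concrete, an illustrative sketch of one RERANK_ENGINES entry after _install() runs. The model name ("bge-reranker-v2-m3") and which engines match are assumptions that depend on model_spec.json and the locally installed backends:

    # Illustrative shape of the engine registry built above.
    from xinference.model.rerank.sentence_transformers.core import (
        SentenceTransformerRerankModel,
    )

    example_rerank_engines_entry = {
        "bge-reranker-v2-m3": {
            "sentence_transformers": [
                {
                    "model_name": "bge-reranker-v2-m3",
                    "model_format": "pytorch",
                    "quantization": "none",
                    "rerank_class": SentenceTransformerRerankModel,
                }
            ],
            # a "vllm" key appears only when VLLMRerankModel.match(...) is True
        }
    }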
xinference/model/rerank/cache_manager.py (new file):

@@ -0,0 +1,35 @@
+import os
+from typing import TYPE_CHECKING
+
+from ..cache_manager import CacheManager
+
+if TYPE_CHECKING:
+    from .core import RerankModelFamilyV2
+
+
+class RerankCacheManager(CacheManager):
+    def __init__(self, model_family: "RerankModelFamilyV2"):
+        from ..llm.cache_manager import LLMCacheManager
+
+        super().__init__(model_family)
+        # Composition design mode for avoiding duplicate code
+        self.cache_helper = LLMCacheManager(model_family)
+
+        spec = self._model_family.model_specs[0]
+        model_dir_name = (
+            f"{self._model_family.model_name}-{spec.model_format}-{spec.quantization}"
+        )
+        self._cache_dir = os.path.join(self._v2_cache_dir_prefix, model_dir_name)
+        self.cache_helper._cache_dir = self._cache_dir
+
+    def cache(self) -> str:
+        spec = self._model_family.model_specs[0]
+        if spec.model_uri is not None:
+            return self.cache_helper.cache_uri()
+        else:
+            if spec.model_hub == "huggingface":
+                return self.cache_helper.cache_from_huggingface()
+            elif spec.model_hub == "modelscope":
+                return self.cache_helper.cache_from_modelscope()
+            else:
+                raise ValueError(f"Unknown model hub: {spec.model_hub}")
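
A hedged usage sketch of the new cache manager. The family lookup mirrors the rerank __init__ above; the model name is illustrative, and the exact cache path prefix depends on _v2_cache_dir_prefix:

    # Sketch only: assumes _install() has populated BUILTIN_RERANK_MODELS.
    from xinference.model.rerank import BUILTIN_RERANK_MODELS
    from xinference.model.rerank.cache_manager import RerankCacheManager

    family = BUILTIN_RERANK_MODELS["bge-reranker-v2-m3"]  # illustrative name
    cache_dir = RerankCacheManager(family).cache()  # downloads on first use
    print(cache_dir)  # e.g. .../bge-reranker-v2-m3-pytorch-none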