xinference 1.7.1-py3-none-any.whl → 1.8.0-py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release.

Files changed (136)
  1. xinference/_version.py +3 -3
  2. xinference/client/restful/async_restful_client.py +8 -13
  3. xinference/client/restful/restful_client.py +6 -2
  4. xinference/core/chat_interface.py +6 -4
  5. xinference/core/media_interface.py +5 -0
  6. xinference/core/model.py +1 -5
  7. xinference/core/supervisor.py +117 -68
  8. xinference/core/worker.py +49 -37
  9. xinference/deploy/test/test_cmdline.py +2 -6
  10. xinference/model/audio/__init__.py +26 -23
  11. xinference/model/audio/chattts.py +3 -2
  12. xinference/model/audio/core.py +49 -98
  13. xinference/model/audio/cosyvoice.py +3 -2
  14. xinference/model/audio/custom.py +28 -73
  15. xinference/model/audio/f5tts.py +3 -2
  16. xinference/model/audio/f5tts_mlx.py +3 -2
  17. xinference/model/audio/fish_speech.py +3 -2
  18. xinference/model/audio/funasr.py +17 -4
  19. xinference/model/audio/kokoro.py +3 -2
  20. xinference/model/audio/megatts.py +3 -2
  21. xinference/model/audio/melotts.py +3 -2
  22. xinference/model/audio/model_spec.json +572 -171
  23. xinference/model/audio/utils.py +0 -6
  24. xinference/model/audio/whisper.py +3 -2
  25. xinference/model/audio/whisper_mlx.py +3 -2
  26. xinference/model/cache_manager.py +141 -0
  27. xinference/model/core.py +6 -49
  28. xinference/model/custom.py +174 -0
  29. xinference/model/embedding/__init__.py +67 -56
  30. xinference/model/embedding/cache_manager.py +35 -0
  31. xinference/model/embedding/core.py +104 -84
  32. xinference/model/embedding/custom.py +55 -78
  33. xinference/model/embedding/embed_family.py +80 -31
  34. xinference/model/embedding/flag/core.py +21 -5
  35. xinference/model/embedding/llama_cpp/__init__.py +0 -0
  36. xinference/model/embedding/llama_cpp/core.py +234 -0
  37. xinference/model/embedding/model_spec.json +968 -103
  38. xinference/model/embedding/sentence_transformers/core.py +30 -20
  39. xinference/model/embedding/vllm/core.py +11 -5
  40. xinference/model/flexible/__init__.py +8 -2
  41. xinference/model/flexible/core.py +26 -119
  42. xinference/model/flexible/custom.py +69 -0
  43. xinference/model/flexible/launchers/image_process_launcher.py +1 -0
  44. xinference/model/flexible/launchers/modelscope_launcher.py +5 -1
  45. xinference/model/flexible/launchers/transformers_launcher.py +15 -3
  46. xinference/model/flexible/launchers/yolo_launcher.py +5 -1
  47. xinference/model/image/__init__.py +20 -20
  48. xinference/model/image/cache_manager.py +62 -0
  49. xinference/model/image/core.py +70 -182
  50. xinference/model/image/custom.py +28 -72
  51. xinference/model/image/model_spec.json +402 -119
  52. xinference/model/image/ocr/got_ocr2.py +3 -2
  53. xinference/model/image/stable_diffusion/core.py +22 -7
  54. xinference/model/image/stable_diffusion/mlx.py +6 -6
  55. xinference/model/image/utils.py +2 -2
  56. xinference/model/llm/__init__.py +71 -94
  57. xinference/model/llm/cache_manager.py +292 -0
  58. xinference/model/llm/core.py +37 -111
  59. xinference/model/llm/custom.py +88 -0
  60. xinference/model/llm/llama_cpp/core.py +5 -7
  61. xinference/model/llm/llm_family.json +16260 -8151
  62. xinference/model/llm/llm_family.py +138 -839
  63. xinference/model/llm/lmdeploy/core.py +5 -7
  64. xinference/model/llm/memory.py +3 -4
  65. xinference/model/llm/mlx/core.py +6 -8
  66. xinference/model/llm/reasoning_parser.py +3 -1
  67. xinference/model/llm/sglang/core.py +32 -14
  68. xinference/model/llm/transformers/chatglm.py +3 -7
  69. xinference/model/llm/transformers/core.py +49 -27
  70. xinference/model/llm/transformers/deepseek_v2.py +2 -2
  71. xinference/model/llm/transformers/gemma3.py +2 -2
  72. xinference/model/llm/transformers/multimodal/cogagent.py +2 -2
  73. xinference/model/llm/transformers/multimodal/deepseek_vl2.py +2 -2
  74. xinference/model/llm/transformers/multimodal/gemma3.py +2 -2
  75. xinference/model/llm/transformers/multimodal/glm4_1v.py +167 -0
  76. xinference/model/llm/transformers/multimodal/glm4v.py +2 -2
  77. xinference/model/llm/transformers/multimodal/intern_vl.py +2 -2
  78. xinference/model/llm/transformers/multimodal/minicpmv26.py +3 -3
  79. xinference/model/llm/transformers/multimodal/ovis2.py +2 -2
  80. xinference/model/llm/transformers/multimodal/qwen-omni.py +2 -2
  81. xinference/model/llm/transformers/multimodal/qwen2_audio.py +2 -2
  82. xinference/model/llm/transformers/multimodal/qwen2_vl.py +2 -2
  83. xinference/model/llm/transformers/opt.py +3 -7
  84. xinference/model/llm/utils.py +34 -49
  85. xinference/model/llm/vllm/core.py +77 -27
  86. xinference/model/llm/vllm/xavier/engine.py +5 -3
  87. xinference/model/llm/vllm/xavier/scheduler.py +10 -6
  88. xinference/model/llm/vllm/xavier/transfer.py +1 -1
  89. xinference/model/rerank/__init__.py +26 -25
  90. xinference/model/rerank/core.py +47 -87
  91. xinference/model/rerank/custom.py +25 -71
  92. xinference/model/rerank/model_spec.json +158 -33
  93. xinference/model/rerank/utils.py +2 -2
  94. xinference/model/utils.py +115 -54
  95. xinference/model/video/__init__.py +13 -17
  96. xinference/model/video/core.py +44 -102
  97. xinference/model/video/diffusers.py +4 -3
  98. xinference/model/video/model_spec.json +90 -21
  99. xinference/types.py +5 -3
  100. xinference/web/ui/build/asset-manifest.json +3 -3
  101. xinference/web/ui/build/index.html +1 -1
  102. xinference/web/ui/build/static/js/main.7d24df53.js +3 -0
  103. xinference/web/ui/build/static/js/main.7d24df53.js.map +1 -0
  104. xinference/web/ui/node_modules/.cache/babel-loader/2704ff66a5f73ca78b341eb3edec60154369df9d87fbc8c6dd60121abc5e1b0a.json +1 -0
  105. xinference/web/ui/node_modules/.cache/babel-loader/607dfef23d33e6b594518c0c6434567639f24f356b877c80c60575184ec50ed0.json +1 -0
  106. xinference/web/ui/node_modules/.cache/babel-loader/9be3d56173aacc3efd0b497bcb13c4f6365de30069176ee9403b40e717542326.json +1 -0
  107. xinference/web/ui/node_modules/.cache/babel-loader/9f9dd6c32c78a222d07da5987ae902effe16bcf20aac00774acdccc4de3c9ff2.json +1 -0
  108. xinference/web/ui/node_modules/.cache/babel-loader/b2ab5ee972c60d15eb9abf5845705f8ab7e1d125d324d9a9b1bcae5d6fd7ffb2.json +1 -0
  109. xinference/web/ui/src/locales/en.json +0 -1
  110. xinference/web/ui/src/locales/ja.json +0 -1
  111. xinference/web/ui/src/locales/ko.json +0 -1
  112. xinference/web/ui/src/locales/zh.json +0 -1
  113. {xinference-1.7.1.dist-info → xinference-1.8.0.dist-info}/METADATA +9 -11
  114. {xinference-1.7.1.dist-info → xinference-1.8.0.dist-info}/RECORD +119 -119
  115. xinference/model/audio/model_spec_modelscope.json +0 -231
  116. xinference/model/embedding/model_spec_modelscope.json +0 -293
  117. xinference/model/embedding/utils.py +0 -18
  118. xinference/model/image/model_spec_modelscope.json +0 -375
  119. xinference/model/llm/llama_cpp/memory.py +0 -457
  120. xinference/model/llm/llm_family_csghub.json +0 -56
  121. xinference/model/llm/llm_family_modelscope.json +0 -8700
  122. xinference/model/llm/llm_family_openmind_hub.json +0 -1019
  123. xinference/model/rerank/model_spec_modelscope.json +0 -85
  124. xinference/model/video/model_spec_modelscope.json +0 -184
  125. xinference/web/ui/build/static/js/main.9b12b7f9.js +0 -3
  126. xinference/web/ui/build/static/js/main.9b12b7f9.js.map +0 -1
  127. xinference/web/ui/node_modules/.cache/babel-loader/1460361af6975e63576708039f1cb732faf9c672d97c494d4055fc6331460be0.json +0 -1
  128. xinference/web/ui/node_modules/.cache/babel-loader/4efd8dda58fda83ed9546bf2f587df67f8d98e639117bee2d9326a9a1d9bebb2.json +0 -1
  129. xinference/web/ui/node_modules/.cache/babel-loader/55b9fb40b57fa926e8f05f31c2f96467e76e5ad62f033dca97c03f9e8c4eb4fe.json +0 -1
  130. xinference/web/ui/node_modules/.cache/babel-loader/5b2dafe5aa9e1105e0244a2b6751807342fa86aa0144b4e84d947a1686102715.json +0 -1
  131. xinference/web/ui/node_modules/.cache/babel-loader/611fa2c6c53b66039991d06dfb0473b5ab37fc63b4564e0f6e1718523768a045.json +0 -1
  132. /xinference/web/ui/build/static/js/{main.9b12b7f9.js.LICENSE.txt → main.7d24df53.js.LICENSE.txt} +0 -0
  133. {xinference-1.7.1.dist-info → xinference-1.8.0.dist-info}/WHEEL +0 -0
  134. {xinference-1.7.1.dist-info → xinference-1.8.0.dist-info}/entry_points.txt +0 -0
  135. {xinference-1.7.1.dist-info → xinference-1.8.0.dist-info}/licenses/LICENSE +0 -0
  136. {xinference-1.7.1.dist-info → xinference-1.8.0.dist-info}/top_level.txt +0 -0
--- a/xinference/model/llm/vllm/core.py
+++ b/xinference/model/llm/vllm/core.py
@@ -50,9 +50,9 @@ from ....types import (
     CompletionUsage,
     LoRA,
 )
-from .. import LLM, LLMFamilyV1, LLMSpecV1
+from .. import BUILTIN_LLM_FAMILIES, LLM, LLMFamilyV2, LLMSpecV1
 from ..core import chat_context_var
-from ..llm_family import CustomLLMFamilyV1, cache_model_tokenizer_and_config
+from ..llm_family import CustomLLMFamilyV2, cache_model_tokenizer_and_config
 from ..utils import (
     DEEPSEEK_TOOL_CALL_FAMILY,
     QWEN_TOOL_CALL_FAMILY,
@@ -117,6 +117,11 @@ class VLLMGenerateConfig(TypedDict, total=False):
 try:
     import vllm  # noqa: F401

+    if not getattr(vllm, "__version__", None):
+        raise ImportError(
+            "vllm not installed properly, or wrongly be found in sys.path"
+        )
+
     VLLM_INSTALLED = True
 except ImportError:
     VLLM_INSTALLED = False
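For context on the guard above: a stray, empty vllm/ directory on sys.path imports as a namespace package with no __version__ attribute, so the later version comparisons would fail with a confusing AttributeError. A minimal illustration (not part of the diff) of the situation the getattr() check catches:

    import types

    # Stand-in for a shadowing "vllm" namespace package that defines no __version__.
    fake_vllm = types.ModuleType("vllm")
    assert getattr(fake_vllm, "__version__", None) is None  # the new check turns this into an explicit ImportError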
@@ -257,14 +262,16 @@ if VLLM_INSTALLED and vllm.__version__ >= "0.8.5":
 if VLLM_INSTALLED and vllm.__version__ >= "0.9.1":
     VLLM_SUPPORTED_CHAT_MODELS.append("minicpm4")

+if VLLM_INSTALLED and vllm.__version__ >= "0.9.2":
+    VLLM_SUPPORTED_CHAT_MODELS.append("Ernie4.5")
+    VLLM_SUPPORTED_VISION_MODEL_LIST.append("glm-4.1v-thinking")
+

 class VLLMModel(LLM):
     def __init__(
         self,
         model_uid: str,
-        model_family: "LLMFamilyV1",
-        model_spec: "LLMSpecV1",
-        quantization: str,
+        model_family: "LLMFamilyV2",
         model_path: str,
         model_config: Optional[VLLMModelConfig],
         peft_model: Optional[List[LoRA]] = None,
@@ -279,7 +286,7 @@ class VLLMModel(LLM):
             ]

             raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
-        super().__init__(model_uid, model_family, model_spec, quantization, model_path)
+        super().__init__(model_uid, model_family, model_path)
         self._model_config = model_config
         self._engine = None
         self.lora_modules = peft_model
@@ -349,7 +356,7 @@ class VLLMModel(LLM):

             raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")

-        from ..llm_family import LlamaCppLLMSpecV1
+        from ..llm_family import LlamaCppLLMSpecV2

         if "0.3.1" <= vllm.__version__ <= "0.3.3":
             # from vllm v0.3.1 to v0.3.3, it uses cupy as NCCL backend
@@ -368,7 +375,7 @@ class VLLMModel(LLM):
             )

         if (
-            isinstance(self.model_spec, LlamaCppLLMSpecV1)
+            isinstance(self.model_spec, LlamaCppLLMSpecV2)
             and self.model_spec.model_format == "ggufv2"
         ):
             # gguf
@@ -592,20 +599,25 @@ class VLLMModel(LLM):

             if "tokenizer" not in self._model_config:
                 # find pytorch format without quantization
+                family = next(
+                    family
+                    for family in BUILTIN_LLM_FAMILIES
+                    if family.model_name == self.model_family.model_name
+                ).copy()
                 non_quant_spec = next(
                     spec
-                    for spec in self.model_family.model_specs
-                    if spec.model_format == "pytorch"
-                    and "none" in spec.quantizations
+                    for spec in family.model_specs
+                    if spec.quantization == "none"
                     and spec.model_size_in_billions
                     == self.model_spec.model_size_in_billions
+                    and spec.model_hub == self.model_spec.model_hub
                 )
-
-                path = cache_model_tokenizer_and_config(self.model_family, non_quant_spec)
+                family.model_specs = [non_quant_spec]
+                path = cache_model_tokenizer_and_config(family)
                 # other than gguf file, vllm requires to provide tokenizer and hf_config_path
-                self._model_config["tokenizer"] = self._model_config[
-                    "hf_config_path"
-                ] = path
+                self._model_config["tokenizer"] = self._model_config["hf_config_path"] = (
+                    path
+                )

             if not os.path.isfile(self.model_path):
                 self.model_path = os.path.realpath(
@@ -791,7 +803,7 @@ class VLLMModel(LLM):

     @classmethod
     def match_json(
-        cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
+        cls, llm_family: "LLMFamilyV2", llm_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         if not cls._has_cuda_device():
             return False
@@ -813,7 +825,7 @@ class VLLMModel(LLM):
             else:
                 if "4" not in quantization:
                     return False
-        if isinstance(llm_family, CustomLLMFamilyV1):
+        if isinstance(llm_family, CustomLLMFamilyV2):
             if llm_family.model_family not in VLLM_SUPPORTED_MODELS:
                 return False
         else:
@@ -1090,7 +1102,7 @@ class VLLMModel(LLM):
 class VLLMChatModel(VLLMModel, ChatModelMixin):
     @classmethod
     def match_json(
-        cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
+        cls, llm_family: "LLMFamilyV2", llm_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         if llm_spec.model_format not in ["pytorch", "gptq", "awq", "fp8", "ggufv2"]:
             return False
@@ -1111,7 +1123,7 @@ class VLLMChatModel(VLLMModel, ChatModelMixin):
         if llm_spec.model_format == "ggufv2":
             if not (VLLM_INSTALLED and vllm.__version__ >= "0.8.2"):
                 return False
-        if isinstance(llm_family, CustomLLMFamilyV1):
+        if isinstance(llm_family, CustomLLMFamilyV2):
             if llm_family.model_family not in VLLM_SUPPORTED_CHAT_MODELS:
                 return False
         else:
@@ -1137,9 +1149,9 @@ class VLLMChatModel(VLLMModel, ChatModelMixin):
             not generate_config.get("stop_token_ids")
             and self.model_family.stop_token_ids
         ):
-            generate_config[
-                "stop_token_ids"
-            ] = self.model_family.stop_token_ids.copy()
+            generate_config["stop_token_ids"] = (
+                self.model_family.stop_token_ids.copy()
+            )
         return generate_config

     @staticmethod
@@ -1150,17 +1162,50 @@ class VLLMChatModel(VLLMModel, ChatModelMixin):
     def is_tool_call_chunk_end(chunk):
         return chunk["choices"][0]["text"].endswith(QWEN_TOOL_CALL_SYMBOLS[1])

+    @staticmethod
+    def prefill_messages(messages: List[Dict]) -> List[Dict]:
+        """
+        Preprocess messages to ensure content is not None
+
+        Args:
+            messages: Original message list
+
+        Returns:
+            Processed message list, where content is not None
+        """
+        processed_messages = []
+
+        for msg in messages:
+            if isinstance(msg, dict):
+                if msg.get("content") is None:
+                    msg_copy = msg.copy()
+                    msg_copy["content"] = ""  # Replace None with empty string
+                    processed_messages.append(msg_copy)
+                else:
+                    processed_messages.append(msg)
+            else:
+                processed_messages.append(msg)
+
+        return processed_messages
+
     async def _async_to_tool_completion_chunks(
         self,
         chunks: AsyncGenerator[CompletionChunk, None],
+        ctx: Optional[Dict[str, Any]] = {},
     ) -> AsyncGenerator[ChatCompletionChunk, None]:
+        def set_context():
+            if ctx:
+                chat_context_var.set(ctx)
+
         i = 0
         previous_texts = [""]
         tool_call = False
         tool_call_texts = [""]
         if self.reasoning_parser:
+            set_context()
             chunks = self.reasoning_parser.prepare_reasoning_content_streaming(chunks)
         async for chunk in chunks:
+            set_context()
             if i == 0:
                 for first_chunk in self._get_first_chat_completion_chunk(
                     chunk, self.reasoning_parser
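A rough usage sketch (not part of the diff) of the new prefill_messages helper, assuming xinference.model.llm.vllm.core imports cleanly; assistant tool-call turns often arrive with content set to None, which some chat templates cannot handle:

    from xinference.model.llm.vllm.core import VLLMChatModel

    messages = [
        {"role": "user", "content": "What's the weather in Paris?"},
        # assistant tool-call turns frequently carry content=None
        {"role": "assistant", "content": None, "tool_calls": [{"id": "call_0"}]},
    ]
    fixed = VLLMChatModel.prefill_messages(messages)
    assert fixed[1]["content"] == ""        # None replaced with an empty string
    assert messages[1]["content"] is None   # the offending message is copied, not mutated in place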
@@ -1200,6 +1245,9 @@ class VLLMChatModel(VLLMModel, ChatModelMixin):
         generate_config: Optional[Dict] = None,
         request_id: Optional[str] = None,
     ) -> Union[ChatCompletion, AsyncGenerator[ChatCompletionChunk, None]]:
+        # Preprocess messages to ensure content is not None
+        messages = self.prefill_messages(messages)
+
         tools = generate_config.pop("tools", []) if generate_config else None
         model_family = self.model_family.model_family or self.model_family.model_name
         chat_template_kwargs = (
@@ -1230,8 +1278,10 @@ class VLLMChatModel(VLLMModel, ChatModelMixin):
             )
             assert isinstance(agen, AsyncGenerator)
             if tools:
-                return self._async_to_tool_completion_chunks(agen)
-            return self._async_to_chat_completion_chunks(agen, self.reasoning_parser)
+                return self._async_to_tool_completion_chunks(agen, chat_template_kwargs)
+            return self._async_to_chat_completion_chunks(
+                agen, self.reasoning_parser, chat_template_kwargs
+            )
         else:
             c = await self.async_generate(
                 full_prompt, generate_config, request_id=request_id
@@ -1247,7 +1297,7 @@ class VLLMChatModel(VLLMModel, ChatModelMixin):
 class VLLMVisionModel(VLLMModel, ChatModelMixin):
     @classmethod
     def match_json(
-        cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
+        cls, llm_family: "LLMFamilyV2", llm_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         if not cls._has_cuda_device():
             return False
@@ -1269,7 +1319,7 @@ class VLLMVisionModel(VLLMModel, ChatModelMixin):
             else:
                 if "4" not in quantization:
                     return False
-        if isinstance(llm_family, CustomLLMFamilyV1):
+        if isinstance(llm_family, CustomLLMFamilyV2):
             if llm_family.model_family not in VLLM_SUPPORTED_VISION_MODEL_LIST:
                 return False
         else:
--- a/xinference/model/llm/vllm/xavier/engine.py
+++ b/xinference/model/llm/vllm/xavier/engine.py
@@ -39,9 +39,11 @@ class XavierInternalEngine(_AsyncLLMEngine):
                 self.cache_config,
                 self.lora_config,
                 self.parallel_config.pipeline_parallel_size,
-                self.async_callbacks[v_id]
-                if self.model_config.use_async_output_proc
-                else None,
+                (
+                    self.async_callbacks[v_id]
+                    if self.model_config.use_async_output_proc
+                    else None
+                ),
                 xavier_config=self._xavier_config,
                 virtual_engine=v_id,
             )
--- a/xinference/model/llm/vllm/xavier/scheduler.py
+++ b/xinference/model/llm/vllm/xavier/scheduler.py
@@ -352,12 +352,16 @@ class XavierScheduler(Scheduler):
                 # between engine and worker.
                 # the subsequent comms can still use delta, but
                 # `multi_modal_data` will be None.
-                multi_modal_data=seq_group.multi_modal_data
-                if scheduler_outputs.num_prefill_groups > 0
-                else None,
-                multi_modal_placeholders=seq_group.multi_modal_placeholders
-                if scheduler_outputs.num_prefill_groups > 0
-                else None,
+                multi_modal_data=(
+                    seq_group.multi_modal_data
+                    if scheduler_outputs.num_prefill_groups > 0
+                    else None
+                ),
+                multi_modal_placeholders=(
+                    seq_group.multi_modal_placeholders
+                    if scheduler_outputs.num_prefill_groups > 0
+                    else None
+                ),
                 mm_processor_kwargs=seq_group.mm_processor_kwargs,
                 prompt_adapter_request=seq_group.prompt_adapter_request,
             )
--- a/xinference/model/llm/vllm/xavier/transfer.py
+++ b/xinference/model/llm/vllm/xavier/transfer.py
@@ -30,7 +30,7 @@ logger = logging.getLogger(__name__)

 class BufferTransferMixin:
     def __init__(self):
-        self.num_buffer: int = 0
+        self.num_buffer: int = 0  # type: ignore
         self.buffers: List[torch.Tensor] = []  # type: ignore
         self.buffer_queue: Optional[Queue] = None  # type: ignore
         self.transfer_block_num = 0
--- a/xinference/model/rerank/__init__.py
+++ b/xinference/model/rerank/__init__.py
@@ -16,38 +16,41 @@ import codecs
 import json
 import os
 import warnings
-from typing import Any, Dict
+from typing import Dict, List

 from ...constants import XINFERENCE_MODEL_DIR
+from ..utils import flatten_model_src
 from .core import (
-    MODEL_NAME_TO_REVISION,
     RERANK_MODEL_DESCRIPTIONS,
-    RerankModelSpec,
+    RerankModelFamilyV2,
     generate_rerank_description,
-    get_cache_status,
     get_rerank_model_descriptions,
 )
 from .custom import (
-    CustomRerankModelSpec,
+    CustomRerankModelFamilyV2,
     get_user_defined_reranks,
     register_rerank,
     unregister_rerank,
 )

-BUILTIN_RERANK_MODELS: Dict[str, Any] = {}
-MODELSCOPE_RERANK_MODELS: Dict[str, Any] = {}
+BUILTIN_RERANK_MODELS: Dict[str, List["RerankModelFamilyV2"]] = {}


 def register_custom_model():
+    from ..custom import migrate_from_v1_to_v2
+
+    # migrate from v1 to v2 first
+    migrate_from_v1_to_v2("rerank", CustomRerankModelFamilyV2)
+
     # if persist=True, load them when init
-    user_defined_rerank_dir = os.path.join(XINFERENCE_MODEL_DIR, "rerank")
+    user_defined_rerank_dir = os.path.join(XINFERENCE_MODEL_DIR, "v2", "rerank")
     if os.path.isdir(user_defined_rerank_dir):
         for f in os.listdir(user_defined_rerank_dir):
             try:
                 with codecs.open(
                     os.path.join(user_defined_rerank_dir, f), encoding="utf-8"
                 ) as fd:
-                    user_defined_rerank_spec = CustomRerankModelSpec.parse_obj(
+                    user_defined_rerank_spec = CustomRerankModelFamilyV2.parse_obj(
                         json.load(fd)
                     )
                     register_rerank(user_defined_rerank_spec, persist=False)
@@ -57,15 +60,11 @@ def register_custom_model():

 def _install():
     load_model_family_from_json("model_spec.json", BUILTIN_RERANK_MODELS)
-    load_model_family_from_json("model_spec_modelscope.json", MODELSCOPE_RERANK_MODELS)

-    # register model description after recording model revision
-    for model_spec_info in [BUILTIN_RERANK_MODELS, MODELSCOPE_RERANK_MODELS]:
-        for model_name, model_spec in model_spec_info.items():
-            if model_spec.model_name not in RERANK_MODEL_DESCRIPTIONS:
-                RERANK_MODEL_DESCRIPTIONS.update(
-                    generate_rerank_description(model_spec)
-                )
+    for model_name, model_specs in BUILTIN_RERANK_MODELS.items():
+        model_spec = [x for x in model_specs if x.model_hub == "huggingface"][0]
+        if model_spec.model_name not in RERANK_MODEL_DESCRIPTIONS:
+            RERANK_MODEL_DESCRIPTIONS.update(generate_rerank_description(model_spec))

     register_custom_model()

@@ -76,12 +75,14 @@ def _install():

 def load_model_family_from_json(json_filename, target_families):
     _model_spec_json = os.path.join(os.path.dirname(__file__), json_filename)
-    target_families.update(
-        dict(
-            (spec["model_name"], RerankModelSpec(**spec))
-            for spec in json.load(codecs.open(_model_spec_json, "r", encoding="utf-8"))
-        )
-    )
-    for model_name, model_spec in target_families.items():
-        MODEL_NAME_TO_REVISION[model_name].append(model_spec.model_revision)
+    flattened_model_specs = []
+    for spec in json.load(codecs.open(_model_spec_json, "r", encoding="utf-8")):
+        flattened_model_specs.extend(flatten_model_src(spec))
+
+    for spec in flattened_model_specs:
+        if spec["model_name"] not in target_families:
+            target_families[spec["model_name"]] = [RerankModelFamilyV2(**spec)]
+        else:
+            target_families[spec["model_name"]].append(RerankModelFamilyV2(**spec))
+
     del _model_spec_json
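For orientation, a hypothetical sketch (not from the diff) of the registry shape this produces: each model name now maps to a list of RerankModelFamilyV2 specs, one per hub, replacing the former separate Hugging Face and ModelScope dictionaries. Plain dicts and an illustrative model name are used below purely for illustration:

    BUILTIN_RERANK_MODELS_SKETCH = {
        "bge-reranker-v2-m3": [  # illustrative model name
            {"version": 2, "model_name": "bge-reranker-v2-m3", "model_hub": "huggingface"},
            {"version": 2, "model_name": "bge-reranker-v2-m3", "model_hub": "modelscope"},
        ],
    }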
--- a/xinference/model/rerank/core.py
+++ b/xinference/model/rerank/core.py
@@ -21,24 +21,22 @@ import threading
 import uuid
 from collections import defaultdict
 from collections.abc import Sequence
-from typing import Dict, List, Literal, Optional, Tuple
+from typing import Dict, List, Literal, Optional

 import numpy as np
 import torch
 import torch.nn as nn

-from ...constants import XINFERENCE_CACHE_DIR
-from ...device_utils import empty_cache
+from ...device_utils import empty_cache, is_device_available
 from ...types import Document, DocumentObj, Rerank, RerankTokens
-from ..core import CacheableModelSpec, ModelDescription, VirtualEnvSettings
-from ..utils import is_model_cached
+from ..core import CacheableModelSpec, VirtualEnvSettings
+from ..utils import ModelInstanceInfoMixin
 from .utils import preprocess_sentence

 logger = logging.getLogger(__name__)

 # Used for check whether the model is cached.
 # Init when registering all the builtin models.
-MODEL_NAME_TO_REVISION: Dict[str, List[str]] = defaultdict(list)
 RERANK_MODEL_DESCRIPTIONS: Dict[str, List[Dict]] = defaultdict(list)
 RERANK_EMPTY_CACHE_COUNT = int(os.getenv("XINFERENCE_RERANK_EMPTY_CACHE_COUNT", "10"))
 assert RERANK_EMPTY_CACHE_COUNT > 0
@@ -50,7 +48,8 @@ def get_rerank_model_descriptions():
     return copy.deepcopy(RERANK_MODEL_DESCRIPTIONS)


-class RerankModelSpec(CacheableModelSpec):
+class RerankModelFamilyV2(CacheableModelSpec, ModelInstanceInfoMixin):
+    version: Literal[2]
     model_name: str
     language: List[str]
     type: Optional[str] = "unknown"
@@ -60,56 +59,37 @@ class RerankModelSpec(CacheableModelSpec):
     model_hub: str = "huggingface"
     virtualenv: Optional[VirtualEnvSettings]

+    class Config:
+        extra = "allow"

-class RerankModelDescription(ModelDescription):
-    def __init__(
-        self,
-        address: Optional[str],
-        devices: Optional[List[str]],
-        model_spec: RerankModelSpec,
-        model_path: Optional[str] = None,
-    ):
-        super().__init__(address, devices, model_path=model_path)
-        self._model_spec = model_spec
-
-    @property
-    def spec(self):
-        return self._model_spec
-
-    def to_dict(self):
+    def to_description(self):
         return {
             "model_type": "rerank",
-            "address": self.address,
-            "accelerators": self.devices,
-            "type": self._model_spec.type,
-            "model_name": self._model_spec.model_name,
-            "language": self._model_spec.language,
-            "model_revision": self._model_spec.model_revision,
+            "address": getattr(self, "address", None),
+            "accelerators": getattr(self, "accelerators", None),
+            "type": self.type,
+            "model_name": self.model_name,
+            "language": self.language,
+            "model_revision": self.model_revision,
         }

     def to_version_info(self):
-        from .utils import get_model_version
-
-        if self._model_path is None:
-            is_cached = get_cache_status(self._model_spec)
-            file_location = get_cache_dir(self._model_spec)
-        else:
-            is_cached = True
-            file_location = self._model_path
+        from ..cache_manager import CacheManager

+        cache_manager = CacheManager(self)
         return {
-            "model_version": get_model_version(self._model_spec),
-            "model_file_location": file_location,
-            "cache_status": is_cached,
-            "language": self._model_spec.language,
+            "model_version": self.model_name,
+            "model_file_location": cache_manager.get_cache_dir(),
+            "cache_status": cache_manager.get_cache_status(),
+            "language": self.language,
         }


-def generate_rerank_description(model_spec: RerankModelSpec) -> Dict[str, List[Dict]]:
+def generate_rerank_description(
+    model_spec: RerankModelFamilyV2,
+) -> Dict[str, List[Dict]]:
     res = defaultdict(list)
-    res[model_spec.model_name].append(
-        RerankModelDescription(None, None, model_spec).to_version_info()
-    )
+    res[model_spec.model_name].append(model_spec.to_version_info())
     return res

@@ -145,13 +125,14 @@ class _ModelWrapper(nn.Module):
 class RerankModel:
     def __init__(
         self,
-        model_spec: RerankModelSpec,
+        model_spec: RerankModelFamilyV2,
         model_uid: str,
         model_path: Optional[str] = None,
         device: Optional[str] = None,
         use_fp16: bool = False,
         model_config: Optional[Dict] = None,
     ):
+        self.model_family = model_spec
         self._model_spec = model_spec
         self._model_uid = model_uid
         self._model_path = model_path
@@ -252,7 +233,9 @@ class RerankModel:
             tokenizer = AutoTokenizer.from_pretrained(
                 self._model_path, padding_side="left"
             )
-            enable_flash_attn = self._model_config.get("enable_flash_attn", True)
+            enable_flash_attn = self._model_config.pop(
+                "enable_flash_attn", is_device_available("cuda")
+            )
             model_kwargs = {"device_map": "auto"}
             if flash_attn_installed and enable_flash_attn:
                 model_kwargs["attn_implementation"] = "flash_attention_2"
@@ -448,25 +431,7 @@ class RerankModel:
         return Rerank(id=str(uuid.uuid1()), results=docs, meta=metadata)


-def get_cache_dir(model_spec: RerankModelSpec):
-    return os.path.realpath(os.path.join(XINFERENCE_CACHE_DIR, model_spec.model_name))
-
-
-def get_cache_status(
-    model_spec: RerankModelSpec,
-) -> bool:
-    return is_model_cached(model_spec, MODEL_NAME_TO_REVISION)
-
-
-def cache(model_spec: RerankModelSpec):
-    from ..utils import cache
-
-    return cache(model_spec, RerankModelDescription)
-
-
 def create_rerank_model_instance(
-    subpool_addr: str,
-    devices: List[str],
     model_uid: str,
     model_name: str,
     download_hub: Optional[
@@ -474,9 +439,10 @@ def create_rerank_model_instance(
     ] = None,
     model_path: Optional[str] = None,
     **kwargs,
-) -> Tuple[RerankModel, RerankModelDescription]:
+) -> RerankModel:
+    from ..cache_manager import CacheManager
     from ..utils import download_from_modelscope
-    from . import BUILTIN_RERANK_MODELS, MODELSCOPE_RERANK_MODELS
+    from . import BUILTIN_RERANK_MODELS
     from .custom import get_user_defined_reranks

     model_spec = None
@@ -486,31 +452,25 @@ def create_rerank_model_instance(
             break

     if model_spec is None:
-        if download_hub == "huggingface" and model_name in BUILTIN_RERANK_MODELS:
-            logger.debug(f"Rerank model {model_name} found in Huggingface.")
-            model_spec = BUILTIN_RERANK_MODELS[model_name]
-        elif download_hub == "modelscope" and model_name in MODELSCOPE_RERANK_MODELS:
-            logger.debug(f"Rerank model {model_name} found in ModelScope.")
-            model_spec = MODELSCOPE_RERANK_MODELS[model_name]
-        elif download_from_modelscope() and model_name in MODELSCOPE_RERANK_MODELS:
-            logger.debug(f"Rerank model {model_name} found in ModelScope.")
-            model_spec = MODELSCOPE_RERANK_MODELS[model_name]
-        elif model_name in BUILTIN_RERANK_MODELS:
-            logger.debug(f"Rerank model {model_name} found in Huggingface.")
-            model_spec = BUILTIN_RERANK_MODELS[model_name]
+        if model_name in BUILTIN_RERANK_MODELS:
+            model_specs = BUILTIN_RERANK_MODELS[model_name]
+            if download_hub == "modelscope" or download_from_modelscope():
+                model_spec = (
+                    [x for x in model_specs if x.model_hub == "modelscope"]
+                    + [x for x in model_specs if x.model_hub == "huggingface"]
+                )[0]
+            else:
+                model_spec = [x for x in model_specs if x.model_hub == "huggingface"][0]
         else:
             raise ValueError(
-                f"Rerank model {model_name} not found, available"
-                f"Huggingface: {BUILTIN_RERANK_MODELS.keys()}"
-                f"ModelScope: {MODELSCOPE_RERANK_MODELS.keys()}"
+                f"Rerank model {model_name} not found, available "
+                f"model: {BUILTIN_RERANK_MODELS.keys()}"
             )
     if not model_path:
-        model_path = cache(model_spec)
+        cache_manager = CacheManager(model_spec)
+        model_path = cache_manager.cache()
     use_fp16 = kwargs.pop("use_fp16", False)
     model = RerankModel(
         model_spec, model_uid, model_path, use_fp16=use_fp16, model_config=kwargs
     )
-    model_description = RerankModelDescription(
-        subpool_addr, devices, model_spec, model_path=model_path
-    )
-    return model, model_description
+    return model
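A hypothetical call sketch against the new signature above (uid and arguments illustrative): the worker address and device list are no longer passed in, only the model object is returned, and the hub preference is resolved from download_hub or the ModelScope environment switch:

    from xinference.model.rerank.core import create_rerank_model_instance

    model = create_rerank_model_instance(
        model_uid="bge-reranker-v2-m3-0",   # illustrative uid
        model_name="bge-reranker-v2-m3",    # a built-in or user-registered rerank model
        download_hub="modelscope",          # prefer the ModelScope spec when both hubs are available
    )
    model.load()  # the model is then ready to serve rerank requests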