xinference 0.16.2__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (60)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +62 -11
  3. xinference/client/restful/restful_client.py +8 -2
  4. xinference/conftest.py +0 -8
  5. xinference/constants.py +2 -0
  6. xinference/core/model.py +44 -5
  7. xinference/core/supervisor.py +13 -7
  8. xinference/core/utils.py +76 -12
  9. xinference/core/worker.py +5 -4
  10. xinference/deploy/cmdline.py +5 -0
  11. xinference/deploy/utils.py +7 -4
  12. xinference/model/audio/model_spec.json +2 -2
  13. xinference/model/image/stable_diffusion/core.py +5 -2
  14. xinference/model/llm/core.py +1 -3
  15. xinference/model/llm/llm_family.json +263 -4
  16. xinference/model/llm/llm_family_modelscope.json +302 -0
  17. xinference/model/llm/mlx/core.py +45 -2
  18. xinference/model/llm/vllm/core.py +2 -1
  19. xinference/model/rerank/core.py +11 -4
  20. xinference/thirdparty/fish_speech/fish_speech/conversation.py +254 -0
  21. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +2 -1
  22. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +2 -1
  23. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +2 -2
  24. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ko_KR.json +123 -0
  25. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +2 -1
  26. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +76 -11
  27. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +9 -9
  28. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +1 -1
  29. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +32 -1
  30. xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +2 -1
  31. xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +22 -0
  32. xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +1 -1
  33. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1 -1
  34. xinference/thirdparty/fish_speech/tools/api.py +578 -75
  35. xinference/thirdparty/fish_speech/tools/e2e_webui.py +232 -0
  36. xinference/thirdparty/fish_speech/tools/fish_e2e.py +298 -0
  37. xinference/thirdparty/fish_speech/tools/llama/generate.py +393 -9
  38. xinference/thirdparty/fish_speech/tools/msgpack_api.py +90 -29
  39. xinference/thirdparty/fish_speech/tools/post_api.py +37 -15
  40. xinference/thirdparty/fish_speech/tools/schema.py +187 -0
  41. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +7 -1
  42. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +2 -3
  43. xinference/thirdparty/fish_speech/tools/webui.py +138 -75
  44. {xinference-0.16.2.dist-info → xinference-1.0.0.dist-info}/METADATA +26 -3
  45. {xinference-0.16.2.dist-info → xinference-1.0.0.dist-info}/RECORD +49 -56
  46. {xinference-0.16.2.dist-info → xinference-1.0.0.dist-info}/WHEEL +1 -1
  47. xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
  48. xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
  49. xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
  50. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
  51. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
  52. xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
  53. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
  54. xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
  55. xinference/thirdparty/fish_speech/tools/commons.py +0 -35
  56. xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
  57. xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
  58. {xinference-0.16.2.dist-info → xinference-1.0.0.dist-info}/LICENSE +0 -0
  59. {xinference-0.16.2.dist-info → xinference-1.0.0.dist-info}/entry_points.txt +0 -0
  60. {xinference-0.16.2.dist-info → xinference-1.0.0.dist-info}/top_level.txt +0 -0
xinference/model/llm/llm_family_modelscope.json

@@ -363,6 +363,97 @@
  "<|eom_id|>"
  ]
  },
+ {
+ "version": 1,
+ "context_length": 131072,
+ "model_name": "llama-3.2-vision-instruct",
+ "model_lang": [
+ "en",
+ "de",
+ "fr",
+ "it",
+ "pt",
+ "hi",
+ "es",
+ "th"
+ ],
+ "model_ability": [
+ "chat",
+ "vision"
+ ],
+ "model_description": "Llama 3.2-Vision instruction-tuned models are optimized for visual recognition, image reasoning, captioning, and answering general questions about an image...",
+ "model_specs": [
+ {
+ "model_format": "pytorch",
+ "model_size_in_billions": 11,
+ "quantizations": [
+ "none"
+ ],
+ "model_id": "LLM-Research/Llama-3.2-11B-Vision-Instruct",
+ "model_hub": "modelscope"
+ },
+ {
+ "model_format": "pytorch",
+ "model_size_in_billions": 90,
+ "quantizations": [
+ "none"
+ ],
+ "model_id": "LLM-Research/Llama-3.2-90B-Vision-Instruct",
+ "model_hub": "modelscope"
+ }
+ ],
+ "chat_template": "{% for message in messages %}{% if loop.index0 == 0 %}{{ bos_token }}{% endif %}{{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' }}{% if message['content'] is string %}{{ message['content'] }}{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' %}{{ '<|image|>' }}{% elif content['type'] == 'text' %}{{ content['text'] }}{% endif %}{% endfor %}{% endif %}{{ '<|eot_id|>' }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}",
+ "stop_token_ids": [
+ 128001,
+ 128008,
+ 128009
+ ],
+ "stop": [
+ "<|end_of_text|>",
+ "<|eot_id|>",
+ "<|eom_id|>"
+ ]
+ },
+ {
+ "version": 1,
+ "context_length": 131072,
+ "model_name": "llama-3.2-vision",
+ "model_lang": [
+ "en",
+ "de",
+ "fr",
+ "it",
+ "pt",
+ "hi",
+ "es",
+ "th"
+ ],
+ "model_ability": [
+ "generate",
+ "vision"
+ ],
+ "model_description": "The Llama 3.2-Vision instruction-tuned models are optimized for visual recognition, image reasoning, captioning, and answering general questions about an image...",
+ "model_specs": [
+ {
+ "model_format": "pytorch",
+ "model_size_in_billions": 11,
+ "quantizations": [
+ "none"
+ ],
+ "model_id": "LLM-Research/Llama-3.2-11B-Vision",
+ "model_hub": "modelscope"
+ },
+ {
+ "model_format": "pytorch",
+ "model_size_in_billions": 90,
+ "quantizations": [
+ "none"
+ ],
+ "model_id": "LLM-Research/Llama-3.2-90B-Vision",
+ "model_hub": "modelscope"
+ }
+ ]
+ },
  {
  "version": 1,
  "context_length": 2048,
xinference/model/llm/llm_family_modelscope.json (continued)

@@ -5816,6 +5907,18 @@
  ],
  "model_description": "Qwen2.5-Coder is the latest series of Code-Specific Qwen large language models (formerly known as CodeQwen).",
  "model_specs": [
+ {
+ "model_format": "pytorch",
+ "model_size_in_billions": "0_5",
+ "quantizations": [
+ "4-bit",
+ "8-bit",
+ "none"
+ ],
+ "model_id": "qwen/Qwen2.5-Coder-0.5B",
+ "model_revision": "master",
+ "model_hub": "modelscope"
+ },
  {
  "model_format": "pytorch",
  "model_size_in_billions": "1_5",
@@ -5828,6 +5931,18 @@
  "model_revision": "master",
  "model_hub": "modelscope"
  },
+ {
+ "model_format": "pytorch",
+ "model_size_in_billions": "3",
+ "quantizations": [
+ "4-bit",
+ "8-bit",
+ "none"
+ ],
+ "model_id": "qwen/Qwen2.5-Coder-3B",
+ "model_revision": "master",
+ "model_hub": "modelscope"
+ },
  {
  "model_format": "pytorch",
  "model_size_in_billions": 7,
@@ -5839,6 +5954,30 @@
  "model_id": "qwen/Qwen2.5-Coder-7B",
  "model_revision": "master",
  "model_hub": "modelscope"
+ },
+ {
+ "model_format": "pytorch",
+ "model_size_in_billions": 14,
+ "quantizations": [
+ "4-bit",
+ "8-bit",
+ "none"
+ ],
+ "model_id": "qwen/Qwen2.5-Coder-14B",
+ "model_revision": "master",
+ "model_hub": "modelscope"
+ },
+ {
+ "model_format": "pytorch",
+ "model_size_in_billions": 32,
+ "quantizations": [
+ "4-bit",
+ "8-bit",
+ "none"
+ ],
+ "model_id": "qwen/Qwen2.5-Coder-32B",
+ "model_revision": "master",
+ "model_hub": "modelscope"
  }
  ]
  },
@@ -5856,6 +5995,18 @@
  ],
  "model_description": "Qwen2.5-Coder is the latest series of Code-Specific Qwen large language models (formerly known as CodeQwen).",
  "model_specs": [
+ {
+ "model_format": "pytorch",
+ "model_size_in_billions": "0_5",
+ "quantizations": [
+ "4-bit",
+ "8-bit",
+ "none"
+ ],
+ "model_id": "qwen/Qwen2.5-Coder-0.5B-Instruct",
+ "model_revision": "master",
+ "model_hub": "modelscope"
+ },
  {
  "model_format": "pytorch",
  "model_size_in_billions": "1_5",
@@ -5867,6 +6018,17 @@
  "model_id": "qwen/Qwen2.5-Coder-1.5B-Instruct",
  "model_revision": "master",
  "model_hub": "modelscope"
+ }, {
+ "model_format": "pytorch",
+ "model_size_in_billions": "3",
+ "quantizations": [
+ "4-bit",
+ "8-bit",
+ "none"
+ ],
+ "model_id": "qwen/Qwen2.5-Coder-3B-Instruct",
+ "model_revision": "master",
+ "model_hub": "modelscope"
  },
  {
  "model_format": "pytorch",
@@ -5880,6 +6042,63 @@
  "model_revision": "master",
  "model_hub": "modelscope"
  },
+ {
+ "model_format": "pytorch",
+ "model_size_in_billions": 14,
+ "quantizations": [
+ "4-bit",
+ "8-bit",
+ "none"
+ ],
+ "model_id": "qwen/Qwen2.5-Coder-14B-Instruct",
+ "model_revision": "master",
+ "model_hub": "modelscope"
+ },
+ {
+ "model_format": "pytorch",
+ "model_size_in_billions": 32,
+ "quantizations": [
+ "4-bit",
+ "8-bit",
+ "none"
+ ],
+ "model_id": "qwen/Qwen2.5-Coder-32B-Instruct",
+ "model_revision": "master",
+ "model_hub": "modelscope"
+ },
+ {
+ "model_format": "gptq",
+ "model_size_in_billions": "0_5",
+ "quantizations": [
+ "Int4",
+ "Int8"
+ ],
+ "model_id": "qwen/Qwen2.5-Coder-0.5B-Instruct-GPTQ-{quantization}",
+ "model_revision": "master",
+ "model_hub": "modelscope"
+ },
+ {
+ "model_format": "gptq",
+ "model_size_in_billions": "1_5",
+ "quantizations": [
+ "Int4",
+ "Int8"
+ ],
+ "model_id": "qwen/Qwen2.5-Coder-1.5B-Instruct-GPTQ-{quantization}",
+ "model_revision": "master",
+ "model_hub": "modelscope"
+ },
+ {
+ "model_format": "gptq",
+ "model_size_in_billions": 3,
+ "quantizations": [
+ "Int4",
+ "Int8"
+ ],
+ "model_id": "qwen/Qwen2.5-Coder-3B-Instruct-GPTQ-{quantization}",
+ "model_revision": "master",
+ "model_hub": "modelscope"
+ },
  {
  "model_format": "gptq",
  "model_size_in_billions": 7,
@@ -5891,6 +6110,89 @@
  "model_revision": "master",
  "model_hub": "modelscope"
  },
+ {
+ "model_format": "gptq",
+ "model_size_in_billions": 14,
+ "quantizations": [
+ "Int4",
+ "Int8"
+ ],
+ "model_id": "qwen/Qwen2.5-Coder-14B-Instruct-GPTQ-{quantization}",
+ "model_revision": "master",
+ "model_hub": "modelscope"
+ },
+ {
+ "model_format": "gptq",
+ "model_size_in_billions": 32,
+ "quantizations": [
+ "Int4",
+ "Int8"
+ ],
+ "model_id": "qwen/Qwen2.5-Coder-32B-Instruct-GPTQ-{quantization}",
+ "model_revision": "master",
+ "model_hub": "modelscope"
+ },
+ {
+ "model_format": "awq",
+ "model_size_in_billions": "0_5",
+ "quantizations": [
+ "Int4"
+ ],
+ "model_id": "qwen/Qwen2.5-Coder-0.5B-Instruct-AWQ",
+ "model_revision": "master",
+ "model_hub": "modelscope"
+ },
+ {
+ "model_format": "awq",
+ "model_size_in_billions": "1_5",
+ "quantizations": [
+ "Int4"
+ ],
+ "model_id": "qwen/Qwen2.5-Coder-1.5B-Instruct-AWQ",
+ "model_revision": "master",
+ "model_hub": "modelscope"
+ },
+ {
+ "model_format": "awq",
+ "model_size_in_billions": 3,
+ "quantizations": [
+ "Int4"
+ ],
+ "model_id": "qwen/Qwen2.5-Coder-3B-Instruct-AWQ",
+ "model_revision": "master",
+ "model_hub": "modelscope"
+ },
+ {
+ "model_format": "awq",
+ "model_size_in_billions": 7,
+ "quantizations": [
+ "Int4"
+ ],
+ "model_id": "qwen/Qwen2.5-Coder-7B-Instruct-AWQ",
+ "model_revision": "master",
+ "model_hub": "modelscope"
+ },
+ {
+ "model_format": "awq",
+ "model_size_in_billions": 14,
+ "quantizations": [
+ "Int4"
+ ],
+ "model_id": "qwen/Qwen2.5-Coder-14B-Instruct-AWQ",
+ "model_revision": "master",
+ "model_hub": "modelscope"
+ },
+ {
+ "model_format": "awq",
+ "model_size_in_billions": 32,
+ "quantizations": [
+ "Int4"
+ ],
+ "model_id": "qwen/Qwen2.5-Coder-32B-Instruct-AWQ",
+ "model_revision": "master",
+ "model_hub": "modelscope"
+ },
+
  {
  "model_format": "ggufv2",
  "model_size_in_billions": "1_5",
xinference/model/llm/mlx/core.py

@@ -17,7 +17,8 @@ import platform
  import sys
  import time
  import uuid
- from typing import Dict, Iterator, List, Optional, TypedDict, Union
+ from dataclasses import dataclass, field
+ from typing import Any, Dict, Iterator, List, Optional, Tuple, TypedDict, Union

  from ....fields import max_tokens_field
  from ....types import (
@@ -53,6 +54,14 @@ class MLXGenerateConfig(TypedDict, total=False):
  stream: bool
  stream_options: Optional[Union[dict, None]]
  tools: Optional[List[Dict]]
+ lora_name: Optional[str]
+
+
+ @dataclass
+ class PromptCache:
+ cache: List[Any] = field(default_factory=list)
+ model_key: Tuple[str, Optional[str]] = ("", None)
+ tokens: List[int] = field(default_factory=list)


  class MLXModel(LLM):
@@ -69,6 +78,8 @@ class MLXModel(LLM):
  super().__init__(model_uid, model_family, model_spec, quantization, model_path)
  self._use_fast_tokenizer = True
  self._model_config: MLXModelConfig = self._sanitize_model_config(model_config)
+ self._max_kv_size = None
+ self._prompt_cache = None
  if peft_model is not None:
  raise ValueError("MLX engine has not supported lora yet")

@@ -127,6 +138,9 @@
  logger.debug(f"Setting cache limit to {cache_limit_gb} GB")
  mx.metal.set_cache_limit(cache_limit_gb * 1024 * 1024 * 1024)

+ self._max_kv_size = kwargs.get("max_kv_size", None)
+ self._prompt_cache = PromptCache()
+
  return load(
  self.model_path,
  tokenizer_config=tokenizer_config,
@@ -156,6 +170,27 @@
  return False
  return True

+ def _get_prompt_cache(self, prompt, lora_name: Optional[str] = None):
+ from mlx_lm.models.cache import make_prompt_cache
+
+ assert self._prompt_cache is not None
+ cache_len = len(self._prompt_cache.tokens)
+ model_key = (self.model_path, lora_name)
+ if (
+ self._prompt_cache.model_key != model_key
+ or cache_len >= len(prompt)
+ or self._prompt_cache.tokens != prompt[:cache_len]
+ ):
+ self._prompt_cache.model_key = model_key
+ self._prompt_cache.cache = make_prompt_cache(self._model, self._max_kv_size)
+ self._prompt_cache.tokens = []
+ logger.debug("Making new prompt cache for %s", self.model_uid)
+ else:
+ prompt = prompt[cache_len:]
+ logger.debug("Cache hit for %s", self.model_uid)
+ self._prompt_cache.tokens.extend(prompt)
+ return prompt
+
  def _generate_stream(self, prompt: str, kwargs: MLXGenerateConfig):
  import mlx.core as mx
  from mlx_lm.utils import generate_step
@@ -167,6 +202,7 @@
  chunk_id = str(uuid.uuid4())
  stop_token_ids = kwargs.get("stop_token_ids", [])
  stream = kwargs.get("stream", False)
+ lora_name = kwargs.get("lora_name")
  stream_options = kwargs.pop("stream_options", None)
  include_usage = (
  stream_options["include_usage"]
@@ -174,12 +210,15 @@
  else False
  )

- prompt_tokens = mx.array(tokenizer.encode(prompt))
+ prompt_token_ids = tokenizer.encode(prompt)
+ prompt_token_ids = self._get_prompt_cache(prompt_token_ids, lora_name)
+ prompt_tokens = mx.array(prompt_token_ids)
  input_echo_len = len(prompt_tokens)

  i = 0
  start = time.time()
  output = ""
+ tokens = []
  for (token, _), i in zip(
  generate_step(
  prompt_tokens,
@@ -189,9 +228,11 @@
  repetition_context_size=kwargs["repetition_context_size"],
  top_p=kwargs["top_p"],
  logit_bias=kwargs["logit_bias"],
+ prompt_cache=self._prompt_cache.cache, # type: ignore
  ),
  range(max_tokens),
  ):
+ tokens.append(token)
  if token == tokenizer.eos_token_id or token in stop_token_ids: # type: ignore
  break

@@ -230,6 +271,8 @@
  f"Average generation speed: {i / (time.time() - start):.2f} tokens/s."
  )

+ self._prompt_cache.tokens.extend(tokens) # type: ignore
+
  if i == max_tokens - 1:
  finish_reason = "length"
  else:
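The MLX changes above add a per-model prompt cache: if the token IDs of a new request extend the previously cached sequence, only the new suffix is fed to generate_step and the cached KV state is reused; otherwise the cache is rebuilt from scratch. A standalone sketch of that reuse decision (simplified and outside the MLXModel class; make_new_cache stands in for mlx_lm.models.cache.make_prompt_cache):

# Simplified sketch of the prompt-cache reuse logic shown in the diff above.
from dataclasses import dataclass, field
from typing import Any, Callable, List, Optional, Tuple

@dataclass
class PromptCache:
    cache: List[Any] = field(default_factory=list)
    model_key: Tuple[str, Optional[str]] = ("", None)
    tokens: List[int] = field(default_factory=list)

def prompt_suffix_to_run(
    pc: PromptCache,
    model_key: Tuple[str, Optional[str]],
    prompt: List[int],
    make_new_cache: Callable[[], List[Any]],
) -> List[int]:
    cache_len = len(pc.tokens)
    if (
        pc.model_key != model_key            # different model path or LoRA
        or cache_len >= len(prompt)          # prompt does not extend the cache
        or pc.tokens != prompt[:cache_len]   # cached prefix no longer matches
    ):
        pc.model_key = model_key
        pc.cache = make_new_cache()          # rebuild the KV cache
        pc.tokens = []
    else:
        prompt = prompt[cache_len:]          # cache hit: run only the new suffix
    pc.tokens.extend(prompt)
    return prompt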
xinference/model/llm/vllm/core.py

@@ -163,7 +163,6 @@ if VLLM_INSTALLED and vllm.__version__ >= "0.5.1":
  VLLM_SUPPORTED_CHAT_MODELS.append("deepseek-v2-chat-0628")
  VLLM_SUPPORTED_CHAT_MODELS.append("deepseek-v2.5")

-
  if VLLM_INSTALLED and vllm.__version__ >= "0.5.3":
  VLLM_SUPPORTED_CHAT_MODELS.append("gemma-2-it")
  VLLM_SUPPORTED_CHAT_MODELS.append("mistral-nemo-instruct")
@@ -177,6 +176,8 @@ if VLLM_INSTALLED and vllm.__version__ >= "0.6.1":
  VLLM_SUPPORTED_VISION_MODEL_LIST.append("internvl2")

  if VLLM_INSTALLED and vllm.__version__ >= "0.6.3":
+ VLLM_SUPPORTED_MODELS.append("llama-3.2-vision")
+ VLLM_SUPPORTED_VISION_MODEL_LIST.append("llama-3.2-vision-instruct")
  VLLM_SUPPORTED_VISION_MODEL_LIST.append("qwen2-vl-instruct")

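The new Llama 3.2 vision entries are routed to the vLLM engine only when the installed vLLM is at least 0.6.3; the gate is the same version-string comparison used elsewhere in the module. A minimal standalone check along the same lines (a sketch, not part of xinference):

# Sketch: reproduce the version gate outside xinference.
import importlib.util

if importlib.util.find_spec("vllm") is not None:
    import vllm

    if vllm.__version__ >= "0.6.3":  # same string comparison the module uses
        print("llama-3.2-vision / llama-3.2-vision-instruct can use the vLLM engine")
    else:
        print("installed vLLM is too old for the new Llama 3.2 vision entries")
else:
    print("vLLM is not installed")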
xinference/model/rerank/core.py

@@ -179,6 +179,7 @@ class RerankModel:
  return rerank_type

  def load(self):
+ logger.info("Loading rerank model: %s", self._model_path)
  flash_attn_installed = importlib.util.find_spec("flash_attn") is not None
  if (
  self._auto_detect_type(self._model_path) != "normal"
@@ -189,6 +190,7 @@ class RerankModel:
  "will force set `use_fp16` to True"
  )
  self._use_fp16 = True
+
  if self._model_spec.type == "normal":
  try:
  import sentence_transformers
@@ -250,22 +252,27 @@ class RerankModel:
  **kwargs,
  ) -> Rerank:
  assert self._model is not None
- if kwargs:
- raise ValueError("rerank hasn't support extra parameter.")
  if max_chunks_per_doc is not None:
  raise ValueError("rerank hasn't support `max_chunks_per_doc` parameter.")
+ logger.info("Rerank with kwargs: %s, model: %s", kwargs, self._model)
  sentence_combinations = [[query, doc] for doc in documents]
  # reset n tokens
  self._model.model.n_tokens = 0
  if self._model_spec.type == "normal":
  similarity_scores = self._model.predict(
- sentence_combinations, convert_to_numpy=False, convert_to_tensor=True
+ sentence_combinations,
+ convert_to_numpy=False,
+ convert_to_tensor=True,
+ **kwargs,
  ).cpu()
  if similarity_scores.dtype == torch.bfloat16:
  similarity_scores = similarity_scores.float()
  else:
  # Related issue: https://github.com/xorbitsai/inference/issues/1775
- similarity_scores = self._model.compute_score(sentence_combinations)
+ similarity_scores = self._model.compute_score(
+ sentence_combinations, **kwargs
+ )
+
  if not isinstance(similarity_scores, Sequence):
  similarity_scores = [similarity_scores]
  elif (
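With the kwargs guard removed, extra keyword arguments passed to rerank now flow through to the backend model: sentence-transformers' predict for "normal" rerank models, or compute_score otherwise. A hedged usage sketch over the RESTful client, assuming a rerank model is already launched and that the client forwards extra keyword arguments (the endpoint, model name, and batch_size value are illustrative, not taken from the diff):

# Sketch only: endpoint, model name, and the forwarded batch_size are assumptions.
from xinference.client import Client

client = Client("http://localhost:9997")
model = client.get_model("bge-reranker-v2-m3")  # assumes this rerank model is running
result = model.rerank(
    documents=[
        "Gestation in elephants lasts around 22 months.",
        "Elephants are the largest living land animals.",
    ],
    query="How long is an elephant pregnant?",
    top_n=1,
    batch_size=16,  # extra kwarg, forwarded to the backend after this change
)
print(result["results"][0])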