xinference 0.8.1__py3-none-any.whl → 0.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: this version of xinference has been flagged as a potentially problematic release.

Files changed (95)
  1. xinference/_version.py +3 -3
  2. xinference/api/oauth2/auth_service.py +132 -0
  3. xinference/api/restful_api.py +282 -78
  4. xinference/client/handlers.py +3 -0
  5. xinference/client/restful/restful_client.py +108 -75
  6. xinference/constants.py +14 -4
  7. xinference/core/cache_tracker.py +102 -0
  8. xinference/core/chat_interface.py +10 -4
  9. xinference/core/event.py +56 -0
  10. xinference/core/model.py +44 -0
  11. xinference/core/resource.py +19 -12
  12. xinference/core/status_guard.py +4 -0
  13. xinference/core/supervisor.py +278 -87
  14. xinference/core/utils.py +68 -3
  15. xinference/core/worker.py +98 -8
  16. xinference/deploy/cmdline.py +6 -3
  17. xinference/deploy/local.py +2 -2
  18. xinference/deploy/supervisor.py +2 -2
  19. xinference/model/audio/__init__.py +27 -0
  20. xinference/model/audio/core.py +161 -0
  21. xinference/model/audio/model_spec.json +79 -0
  22. xinference/model/audio/utils.py +18 -0
  23. xinference/model/audio/whisper.py +132 -0
  24. xinference/model/core.py +18 -13
  25. xinference/model/embedding/__init__.py +27 -2
  26. xinference/model/embedding/core.py +43 -3
  27. xinference/model/embedding/model_spec.json +24 -0
  28. xinference/model/embedding/model_spec_modelscope.json +24 -0
  29. xinference/model/embedding/utils.py +18 -0
  30. xinference/model/image/__init__.py +12 -1
  31. xinference/model/image/core.py +63 -9
  32. xinference/model/image/utils.py +26 -0
  33. xinference/model/llm/__init__.py +20 -1
  34. xinference/model/llm/core.py +43 -2
  35. xinference/model/llm/ggml/chatglm.py +15 -6
  36. xinference/model/llm/llm_family.json +197 -6
  37. xinference/model/llm/llm_family.py +9 -7
  38. xinference/model/llm/llm_family_modelscope.json +189 -4
  39. xinference/model/llm/pytorch/chatglm.py +3 -3
  40. xinference/model/llm/pytorch/core.py +4 -2
  41. xinference/model/{multimodal → llm/pytorch}/qwen_vl.py +10 -8
  42. xinference/model/llm/pytorch/utils.py +21 -9
  43. xinference/model/llm/pytorch/yi_vl.py +246 -0
  44. xinference/model/llm/utils.py +57 -4
  45. xinference/model/llm/vllm/core.py +5 -4
  46. xinference/model/rerank/__init__.py +25 -2
  47. xinference/model/rerank/core.py +51 -9
  48. xinference/model/rerank/model_spec.json +6 -0
  49. xinference/model/rerank/model_spec_modelscope.json +7 -0
  50. xinference/{api/oauth2/common.py → model/rerank/utils.py} +6 -2
  51. xinference/model/utils.py +5 -3
  52. xinference/thirdparty/__init__.py +0 -0
  53. xinference/thirdparty/llava/__init__.py +1 -0
  54. xinference/thirdparty/llava/conversation.py +205 -0
  55. xinference/thirdparty/llava/mm_utils.py +122 -0
  56. xinference/thirdparty/llava/model/__init__.py +1 -0
  57. xinference/thirdparty/llava/model/clip_encoder/__init__.py +0 -0
  58. xinference/thirdparty/llava/model/clip_encoder/builder.py +11 -0
  59. xinference/thirdparty/llava/model/clip_encoder/clip_encoder.py +86 -0
  60. xinference/thirdparty/llava/model/constants.py +6 -0
  61. xinference/thirdparty/llava/model/llava_arch.py +385 -0
  62. xinference/thirdparty/llava/model/llava_llama.py +163 -0
  63. xinference/thirdparty/llava/model/multimodal_projector/__init__.py +0 -0
  64. xinference/thirdparty/llava/model/multimodal_projector/builder.py +64 -0
  65. xinference/types.py +1 -1
  66. xinference/web/ui/build/asset-manifest.json +3 -3
  67. xinference/web/ui/build/index.html +1 -1
  68. xinference/web/ui/build/static/js/main.15822aeb.js +3 -0
  69. xinference/web/ui/build/static/js/main.15822aeb.js.map +1 -0
  70. xinference/web/ui/node_modules/.cache/babel-loader/139e5e4adf436923107d2b02994c7ff6dba2aac1989e9b6638984f0dfe782c4a.json +1 -0
  71. xinference/web/ui/node_modules/.cache/babel-loader/52aa27272b4b9968f62666262b47661cb1992336a2aff3b13994cc36877b3ec3.json +1 -0
  72. xinference/web/ui/node_modules/.cache/babel-loader/64accc515dc6cd584a2873796cd7da6f93de57f7e465eb5423cca9a2f3fe3eff.json +1 -0
  73. xinference/web/ui/node_modules/.cache/babel-loader/65ca3ba225b8c8dac907210545b51f2fcdb2591f0feeb7195f1c037f2bc956a0.json +1 -0
  74. xinference/web/ui/node_modules/.cache/babel-loader/b80db1012318b97c329c4e3e72454f7512fb107e57c444b437dbe4ba1a3faa5a.json +1 -0
  75. {xinference-0.8.1.dist-info → xinference-0.8.3.dist-info}/METADATA +33 -23
  76. {xinference-0.8.1.dist-info → xinference-0.8.3.dist-info}/RECORD +81 -64
  77. xinference/api/oauth2/core.py +0 -93
  78. xinference/model/multimodal/__init__.py +0 -52
  79. xinference/model/multimodal/core.py +0 -467
  80. xinference/model/multimodal/model_spec.json +0 -43
  81. xinference/model/multimodal/model_spec_modelscope.json +0 -45
  82. xinference/web/ui/build/static/js/main.b83095c2.js +0 -3
  83. xinference/web/ui/build/static/js/main.b83095c2.js.map +0 -1
  84. xinference/web/ui/node_modules/.cache/babel-loader/101923c539819f26ad11fbcbd6f6e56436b285efbb090dcc7dd648c6e924c4a8.json +0 -1
  85. xinference/web/ui/node_modules/.cache/babel-loader/4942da6bc03bf7373af068e22f916341aabc5b5df855d73c1d348c696724ce37.json +0 -1
  86. xinference/web/ui/node_modules/.cache/babel-loader/52a6136cb2dbbf9c51d461724d9b283ebe74a73fb19d5df7ba8e13c42bd7174d.json +0 -1
  87. xinference/web/ui/node_modules/.cache/babel-loader/71493aadd34d568fbe605cacaba220aa69bd09273251ee4ba27930f8d01fccd8.json +0 -1
  88. xinference/web/ui/node_modules/.cache/babel-loader/8b071db2a5a9ef68dc14d5f606540bd23d9785e365a11997c510656764d2dccf.json +0 -1
  89. xinference/web/ui/node_modules/.cache/babel-loader/a4d72d3b806ba061919115f0c513738726872e3c79cf258f007519d3f91d1a16.json +0 -1
  90. xinference/web/ui/node_modules/.cache/babel-loader/f037ffef5992af0892d6d991053c1dace364cd39a3f11f1a41f92776e8a59459.json +0 -1
  91. /xinference/web/ui/build/static/js/{main.b83095c2.js.LICENSE.txt → main.15822aeb.js.LICENSE.txt} +0 -0
  92. {xinference-0.8.1.dist-info → xinference-0.8.3.dist-info}/LICENSE +0 -0
  93. {xinference-0.8.1.dist-info → xinference-0.8.3.dist-info}/WHEEL +0 -0
  94. {xinference-0.8.1.dist-info → xinference-0.8.3.dist-info}/entry_points.txt +0 -0
  95. {xinference-0.8.1.dist-info → xinference-0.8.3.dist-info}/top_level.txt +0 -0
xinference/model/{multimodal → llm/pytorch}/qwen_vl.py

@@ -19,19 +19,21 @@ import time
 import uuid
 from typing import Dict, Iterator, List, Optional, Union
 
-from ...types import (
+from ....model.utils import select_device
+from ....types import (
     ChatCompletion,
     ChatCompletionChoice,
     ChatCompletionChunk,
+    ChatCompletionMessage,
     CompletionUsage,
 )
-from ..utils import select_device
-from .core import LVLM, LVLMFamilyV1, LVLMSpecV1
+from ..llm_family import LLMFamilyV1, LLMSpecV1
+from .core import PytorchChatModel, PytorchGenerateConfig
 
 logger = logging.getLogger(__name__)
 
 
-class QwenVLChat(LVLM):
+class QwenVLChatModel(PytorchChatModel):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self._tokenizer = None
@@ -39,7 +41,7 @@ class QwenVLChat(LVLM):
 
     @classmethod
     def match(
-        cls, model_family: "LVLMFamilyV1", model_spec: "LVLMSpecV1", quantization: str
+        cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         if "qwen" in model_family.model_name:
             return True
@@ -49,7 +51,7 @@ class QwenVLChat(LVLM):
         from transformers import AutoModelForCausalLM, AutoTokenizer
         from transformers.generation import GenerationConfig
 
-        device = self.kwargs.get("device", "auto")
+        device = self._pytorch_model_config.get("device", "auto")
         device = select_device(device)
 
         self._tokenizer = AutoTokenizer.from_pretrained(
@@ -106,8 +108,8 @@ class QwenVLChat(LVLM):
         self,
         prompt: Union[str, List[Dict]],
         system_prompt: Optional[str] = None,
-        chat_history: Optional[List[Dict]] = None,
-        generate_config: Optional[Dict] = None,
+        chat_history: Optional[List[ChatCompletionMessage]] = None,
+        generate_config: Optional[PytorchGenerateConfig] = None,
     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
         if generate_config and generate_config.get("stream"):
             raise Exception(
xinference/model/llm/pytorch/utils.py

@@ -29,7 +29,12 @@ from transformers.generation.logits_process import (
     TopPLogitsWarper,
 )
 
-from ....types import CompletionChoice, CompletionChunk, CompletionUsage
+from ....types import (
+    CompletionChoice,
+    CompletionChunk,
+    CompletionUsage,
+    max_tokens_field,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -54,16 +59,21 @@ def get_context_length(config):
         hasattr(config, "max_sequence_length")
         and config.max_sequence_length is not None
     ):
-        return config.max_sequence_length
-    elif hasattr(config, "seq_length") and config.seq_length is not None:
-        return config.seq_length
-    elif (
+        max_sequence_length = config.max_sequence_length
+    else:
+        max_sequence_length = 2048
+    if hasattr(config, "seq_length") and config.seq_length is not None:
+        seq_length = config.seq_length
+    else:
+        seq_length = 2048
+    if (
         hasattr(config, "max_position_embeddings")
         and config.max_position_embeddings is not None
     ):
-        return config.max_position_embeddings
+        max_position_embeddings = config.max_position_embeddings
     else:
-        return 2048
+        max_position_embeddings = 2048
+    return max(max_sequence_length, seq_length, max_position_embeddings)
 
 
 def prepare_logits_processor(
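Net effect of the get_context_length change: instead of returning the first length attribute it finds, the helper now takes the maximum of max_sequence_length, seq_length and max_position_embeddings, each falling back to 2048. A quick behavioral check, sketched under the assumption that the helper is importable from xinference.model.llm.pytorch.utils (transformers-style configs are stubbed with SimpleNamespace):

from types import SimpleNamespace

from xinference.model.llm.pytorch.utils import get_context_length

# ChatGLM-style config: only seq_length is declared.
assert get_context_length(SimpleNamespace(seq_length=32768)) == 32768
# LLaMA-style config: only max_position_embeddings is declared.
assert get_context_length(SimpleNamespace(max_position_embeddings=4096)) == 4096
# Nothing declared: every candidate falls back to 2048.
assert get_context_length(SimpleNamespace()) == 2048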
xinference/model/llm/pytorch/utils.py (continued)

@@ -102,7 +112,7 @@ def generate_stream(
     repetition_penalty = float(generate_config.get("repetition_penalty", 1.0))
     top_p = float(generate_config.get("top_p", 1.0))
     top_k = int(generate_config.get("top_k", -1))  # -1 means disable
-    max_new_tokens = int(generate_config.get("max_tokens", 256))
+    max_new_tokens = int(generate_config.get("max_tokens", max_tokens_field.default))
     echo = bool(generate_config.get("echo", False))
     stop_str = generate_config.get("stop", None)
     stop_token_ids = generate_config.get("stop_token_ids", None) or []
@@ -123,6 +133,8 @@ def generate_stream(
         max_src_len = context_len
     else:
         max_src_len = context_len - max_new_tokens - 8
+        if max_src_len < 0:
+            raise ValueError("Max tokens exceeds model's max length")
 
     input_ids = input_ids[-max_src_len:]
     input_echo_len = len(input_ids)
@@ -346,7 +358,7 @@ def generate_stream_falcon(
     repetition_penalty = float(generate_config.get("repetition_penalty", 1.0))
     top_p = float(generate_config.get("top_p", 1.0))
     top_k = int(generate_config.get("top_k", 50))  # -1 means disable
-    max_new_tokens = int(generate_config.get("max_tokens", 256))
+    max_new_tokens = int(generate_config.get("max_tokens", max_tokens_field.default))
     echo = bool(generate_config.get("echo", False))
     stop_str = generate_config.get("stop", None)
     stop_token_ids = generate_config.get("stop_token_ids", None) or []
xinference/model/llm/pytorch/yi_vl.py (new file)

@@ -0,0 +1,246 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import base64
+import logging
+import time
+import uuid
+from concurrent.futures import ThreadPoolExecutor
+from io import BytesIO
+from threading import Thread
+from typing import Dict, Iterator, List, Optional, Union
+
+import requests
+import torch
+from PIL import Image
+
+from ....model.utils import select_device
+from ....types import (
+    ChatCompletion,
+    ChatCompletionChoice,
+    ChatCompletionChunk,
+    ChatCompletionMessage,
+    CompletionUsage,
+)
+from ..llm_family import LLMFamilyV1, LLMSpecV1
+from .core import PytorchChatModel, PytorchGenerateConfig
+
+logger = logging.getLogger(__name__)
+
+
+class YiVLChatModel(PytorchChatModel):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._tokenizer = None
+        self._model = None
+        self._image_processor = None
+
+    @classmethod
+    def match(
+        cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
+    ) -> bool:
+        if "yi" in model_family.model_name:
+            return True
+        return False
+
+    def load(self):
+        from ....thirdparty.llava.mm_utils import load_pretrained_model
+        from ....thirdparty.llava.model.constants import key_info
+
+        device = self._pytorch_model_config.get("device", "auto")
+        device = select_device(device)
+
+        key_info["model_path"] = self.model_path
+        (
+            self._tokenizer,
+            self._model,
+            self._image_processor,
+            _,
+        ) = load_pretrained_model(self.model_path, device_map=device)
+
+    @staticmethod
+    def _message_content_to_yi(content) -> Union[str, tuple]:
+        def _load_image(_url):
+            if _url.startswith("data:"):
+                logging.info("Parse url by base64 decoder.")
+                # https://platform.openai.com/docs/guides/vision/uploading-base-64-encoded-images
+                # e.g. f"data:image/jpeg;base64,{base64_image}"
+                _type, data = _url.split(";")
+                _, ext = _type.split("/")
+                data = data[len("base64,") :]
+                data = base64.b64decode(data.encode("utf-8"))
+
+                return Image.open(BytesIO(data))
+            else:
+                try:
+                    response = requests.get(_url)
+                except requests.exceptions.MissingSchema:
+                    return Image.open(_url)
+                else:
+                    return Image.open(BytesIO(response.content))
+
+        if not isinstance(content, str):
+            from ....thirdparty.llava.model.constants import DEFAULT_IMAGE_TOKEN
+
+            texts = []
+            image_urls = []
+            for c in content:
+                c_type = c.get("type")
+                if c_type == "text":
+                    texts.append(c["text"])
+                elif c_type == "image_url":
+                    image_urls.append(c["image_url"]["url"])
+            image_futures = []
+            with ThreadPoolExecutor() as executor:
+                for image_url in image_urls:
+                    fut = executor.submit(_load_image, image_url)
+                    image_futures.append(fut)
+            images = [fut.result() for fut in image_futures]
+            text = " ".join(texts)
+            if DEFAULT_IMAGE_TOKEN not in text:
+                text = DEFAULT_IMAGE_TOKEN + "\n" + text
+            if len(images) == 0:
+                return text
+            elif len(images) == 1:
+                return text, images[0], "Pad"
+            else:
+                raise RuntimeError("Only one image per message is supported by Yi VL.")
+        return content
+
+    @staticmethod
+    def _parse_text(text):
+        lines = text.split("\n")
+        lines = [line for line in lines if line != ""]
+        count = 0
+        for i, line in enumerate(lines):
+            if "```" in line:
+                count += 1
+                items = line.split("`")
+                if count % 2 == 1:
+                    lines[i] = f'<pre><code class="language-{items[-1]}">'
+                else:
+                    lines[i] = f"<br></code></pre>"
+            else:
+                if i > 0:
+                    if count % 2 == 1:
+                        line = line.replace("`", r"\`")
+                        line = line.replace("<", "&lt;")
+                        line = line.replace(">", "&gt;")
+                        line = line.replace(" ", "&nbsp;")
+                        line = line.replace("*", "&ast;")
+                        line = line.replace("_", "&lowbar;")
+                        line = line.replace("-", "&#45;")
+                        line = line.replace(".", "&#46;")
+                        line = line.replace("!", "&#33;")
+                        line = line.replace("(", "&#40;")
+                        line = line.replace(")", "&#41;")
+                        line = line.replace("$", "&#36;")
+                    lines[i] = "<br>" + line
+        text = "".join(lines)
+        return text
+
+    def chat(
+        self,
+        prompt: Union[str, List[Dict]],
+        system_prompt: Optional[str] = None,
+        chat_history: Optional[List[ChatCompletionMessage]] = None,
+        generate_config: Optional[PytorchGenerateConfig] = None,
+    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
+        from transformers import TextIteratorStreamer
+
+        # TODO(codingl2k1): implement stream mode.
+        if generate_config and generate_config.get("stream"):
+            raise Exception(
+                f"Chat with model {self.model_family.model_name} does not support stream."
+            )
+        if not generate_config:
+            generate_config = {}
+        from ....thirdparty.llava.conversation import conv_templates
+        from ....thirdparty.llava.mm_utils import (
+            KeywordsStoppingCriteria,
+            tokenizer_image_token,
+        )
+        from ....thirdparty.llava.model.constants import IMAGE_TOKEN_INDEX
+
+        # Convert chat history to llava state
+        state = conv_templates["mm_default"].copy()
+        for message in chat_history or []:
+            content = self._message_content_to_yi(message["content"])
+            state.append_message(message["role"], content)
+        state.append_message(state.roles[0], self._message_content_to_yi(prompt))
+        state.append_message(state.roles[1], None)
+
+        prompt = state.get_prompt()
+
+        input_ids = (
+            tokenizer_image_token(
+                prompt, self._tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
+            )
+            .unsqueeze(0)
+            .cuda()
+        )
+
+        images = state.get_images(return_pil=True)
+        image = images[0]
+
+        image_tensor = self._image_processor.preprocess(image, return_tensors="pt")[
+            "pixel_values"
+        ][0]
+
+        stop_str = state.sep
+        keywords = [stop_str]
+        stopping_criteria = KeywordsStoppingCriteria(
+            keywords, self._tokenizer, input_ids
+        )
+        streamer = TextIteratorStreamer(
+            self._tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True
+        )
+        top_p = generate_config.get("top_p", 0.7)
+        temperature = generate_config.get("temperature", 0.2)
+        max_new_tokens = generate_config.get("max_tokens", 512)
+        generate_kwargs = {
+            "input_ids": input_ids,
+            "images": image_tensor.unsqueeze(0).to(dtype=torch.bfloat16).cuda(),
+            "streamer": streamer,
+            "do_sample": True,
+            "top_p": float(top_p),
+            "temperature": float(temperature),
+            "stopping_criteria": [stopping_criteria],
+            "use_cache": True,
+            "max_new_tokens": min(int(max_new_tokens), 1536),
+        }
+        t = Thread(target=self._model.generate, kwargs=generate_kwargs)
+        t.start()
+
+        generated_text = ""
+        for new_text in streamer:
+            generated_text += new_text
+        if generated_text.endswith(stop_str):
+            generated_text = generated_text[: -len(stop_str)]
+        r = self._parse_text(generated_text)
+        return ChatCompletion(
+            id="chat" + str(uuid.uuid1()),
+            object="chat.completion",
+            created=int(time.time()),
+            model=self.model_uid,
+            choices=[
+                ChatCompletionChoice(
+                    index=0,
+                    message={"role": "assistant", "content": r},
+                    finish_reason="stop",
+                )
+            ],
+            usage=CompletionUsage(
+                prompt_tokens=-1, completion_tokens=-1, total_tokens=-1
+            ),
+        )
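The new YiVLChatModel accepts OpenAI vision-style content: a prompt (or a history message's "content") can be a plain string or a list of text and image_url parts, where the URL may be an ordinary link or a base64 data URL, and at most one image per message is supported. A minimal sketch of such a prompt, assuming a local example.jpg:

import base64

with open("example.jpg", "rb") as f:
    b64_image = base64.b64encode(f.read()).decode("utf-8")

# This list is what YiVLChatModel.chat() receives as `prompt`
# (or as the "content" field of a chat_history message).
prompt = [
    {"type": "text", "text": "What is shown in this image?"},
    # A plain URL works too; Yi VL allows only one image per message.
    {
        "type": "image_url",
        "image_url": {"url": f"data:image/jpeg;base64,{b64_image}"},
    },
]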
xinference/model/llm/utils.py

@@ -14,11 +14,10 @@
 import functools
 import json
 import logging
+import os
 import time
 import uuid
-from typing import AsyncGenerator, Dict, Iterator, List, Optional, cast
-
-from xinference.model.llm.llm_family import PromptStyleV1
+from typing import AsyncGenerator, Dict, Iterator, List, Optional, Tuple, cast
 
 from ...types import (
     SPECIAL_TOOL_PROMPT,
@@ -28,6 +27,14 @@ from ...types import (
     Completion,
     CompletionChunk,
 )
+from .llm_family import (
+    GgmlLLMSpecV1,
+    LLMFamilyV1,
+    LLMSpecV1,
+    PromptStyleV1,
+    _get_cache_dir,
+    get_cache_status,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -303,7 +310,7 @@ Begin!"""
             ret = (
                 "<s>"
                 if prompt_style.system_prompt == ""
-                else "<s>[UNUSED_TOKEN_146]system\n"
+                else "<s><|im_start|>system\n"
                 + prompt_style.system_prompt
                 + prompt_style.intra_message_sep
                 + "\n"
@@ -373,6 +380,20 @@ Begin!"""
                 return f"USER: <<question>> {prompt} <<function>> {tools_string}\nASSISTANT: "
             else:
                 return f"USER: <<question>> {prompt}\nASSISTANT: "
+        elif prompt_style.style_name == "orion":
+            ret = "<s>"
+            for i, message in enumerate(chat_history):
+                content = message["content"]
+                role = message["role"]
+                if i % 2 == 0:  # Human
+                    assert content is not None
+                    ret += role + ": " + content + "\n\n"
+                else:  # Assistant
+                    if content:
+                        ret += role + ": </s>" + content + "</s>"
+                    else:
+                        ret += role + ": </s>"
+            return ret
         else:
             raise ValueError(f"Invalid prompt style: {prompt_style.style_name}")
 
@@ -573,3 +594,35 @@
                 "total_tokens": -1,
             },
         }
+
+
+def get_file_location(
+    llm_family: LLMFamilyV1, spec: LLMSpecV1, quantization: str
+) -> Tuple[str, bool]:
+    cache_dir = _get_cache_dir(llm_family, spec, create_if_not_exist=False)
+    cache_status = get_cache_status(llm_family, spec)
+    if isinstance(cache_status, list):
+        is_cached = None
+        for q, cs in zip(spec.quantizations, cache_status):
+            if q == quantization:
+                is_cached = cs
+                break
+    else:
+        is_cached = cache_status
+    assert isinstance(is_cached, bool)
+
+    if spec.model_format in ["pytorch", "gptq", "awq"]:
+        return cache_dir, is_cached
+    elif spec.model_format in ["ggmlv3", "ggufv2"]:
+        assert isinstance(spec, GgmlLLMSpecV1)
+        filename = spec.model_file_name_template.format(quantization=quantization)
+        model_path = os.path.join(cache_dir, filename)
+        return model_path, is_cached
+    else:
+        raise ValueError(f"Not supported model format {spec.model_format}")
+
+
+def get_model_version(
+    llm_family: LLMFamilyV1, llm_spec: LLMSpecV1, quantization: str
+) -> str:
+    return f"{llm_family.model_name}--{llm_spec.model_size_in_billions}B--{llm_spec.model_format}--{quantization}"
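Two worked illustrations of what this file now produces. First, the new "orion" prompt style for a one-turn history; this sketch mirrors the loop above, and the role names "Human" and "Assistant" are assumptions, not values taken from this diff:

# Single Human turn plus a pending (empty) Assistant turn.
chat_history = [
    {"role": "Human", "content": "What is Orion-14B?"},
    {"role": "Assistant", "content": ""},
]
ret = "<s>"
for i, message in enumerate(chat_history):
    content, role = message["content"], message["role"]
    if i % 2 == 0:  # Human
        ret += role + ": " + content + "\n\n"
    else:  # Assistant; empty content just opens the reply
        ret += role + ": </s>" + (content + "</s>" if content else "")
print(repr(ret))  # '<s>Human: What is Orion-14B?\n\nAssistant: </s>'

Second, the helpers appended at the bottom back the new model-version bookkeeping: get_file_location resolves either the cache directory (pytorch/gptq/awq) or the concrete GGML/GGUF file path, while get_model_version flattens a spec into an identifier of the shape shown below (the spec values here are made up for illustration):

# "<model_name>--<size_in_billions>B--<model_format>--<quantization>"
model_name, size, fmt, quantization = "llama-2-chat", 13, "ggufv2", "Q4_K_M"
print(f"{model_name}--{size}B--{fmt}--{quantization}")
# llama-2-chat--13B--ggufv2--Q4_K_M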
xinference/model/llm/vllm/core.py

@@ -95,6 +95,7 @@ VLLM_SUPPORTED_CHAT_MODELS = [
     "code-llama-instruct",
     "mistral-instruct-v0.1",
     "mistral-instruct-v0.2",
+    "mixtral-instruct-v0.1",
     "chatglm3",
 ]
 
@@ -190,12 +191,12 @@ class VLLMModel(LLM):
             return False
         if not cls._is_linux():
             return False
-        if llm_spec.model_format not in ["pytorch", "gptq"]:
+        if llm_spec.model_format not in ["pytorch", "gptq", "awq"]:
             return False
         if llm_spec.model_format == "pytorch":
             if quantization != "none" and not (quantization is None):
                 return False
-        if llm_spec.model_format == "gptq":
+        if llm_spec.model_format in ["gptq", "awq"]:
             # Currently, only 4-bit weight quantization is supported for GPTQ, but got 8 bits.
             if "4" not in quantization:
                 return False
@@ -336,12 +337,12 @@ class VLLMChatModel(VLLMModel, ChatModelMixin):
     ) -> bool:
         if XINFERENCE_DISABLE_VLLM:
             return False
-        if llm_spec.model_format not in ["pytorch", "gptq"]:
+        if llm_spec.model_format not in ["pytorch", "gptq", "awq"]:
             return False
         if llm_spec.model_format == "pytorch":
             if quantization != "none" and not (quantization is None):
                 return False
-        if llm_spec.model_format == "gptq":
+        if llm_spec.model_format in ["gptq", "awq"]:
             # Currently, only 4-bit weight quantization is supported for GPTQ, but got 8 bits.
             if "4" not in quantization:
                 return False
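With "awq" accepted by both match() checks (4-bit only, same as GPTQ) and mixtral-instruct-v0.1 added to VLLM_SUPPORTED_CHAT_MODELS, AWQ checkpoints can now be routed to the vLLM backend on Linux. A hedged launch sketch via the Python client; the endpoint and the quantization label are illustrative assumptions, not values taken from this diff:

from xinference.client import Client

client = Client("http://127.0.0.1:9997")  # assumed local endpoint
model_uid = client.launch_model(
    model_name="mixtral-instruct-v0.1",
    model_format="awq",   # now passes the vLLM match() checks on Linux
    quantization="Int4",  # AWQ, like GPTQ, must be 4-bit per the check above
)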
xinference/model/rerank/__init__.py

@@ -17,8 +17,20 @@ import json
 import os
 
 from ...constants import XINFERENCE_MODEL_DIR
-from .core import MODEL_NAME_TO_REVISION, RerankModelSpec, get_cache_status
-from .custom import CustomRerankModelSpec, register_rerank
+from .core import (
+    MODEL_NAME_TO_REVISION,
+    RERANK_MODEL_DESCRIPTIONS,
+    RerankModelSpec,
+    generate_rerank_description,
+    get_cache_status,
+    get_rerank_model_descriptions,
+)
+from .custom import (
+    CustomRerankModelSpec,
+    get_user_defined_reranks,
+    register_rerank,
+    unregister_rerank,
+)
 
 _model_spec_json = os.path.join(os.path.dirname(__file__), "model_spec.json")
 _model_spec_modelscope_json = os.path.join(
@@ -30,6 +42,7 @@ BUILTIN_RERANK_MODELS = dict(
 )
 for model_name, model_spec in BUILTIN_RERANK_MODELS.items():
     MODEL_NAME_TO_REVISION[model_name].append(model_spec.model_revision)
+
 MODELSCOPE_RERANK_MODELS = dict(
     (spec["model_name"], RerankModelSpec(**spec))
     for spec in json.load(
@@ -39,6 +52,12 @@ MODELSCOPE_RERANK_MODELS = dict(
 for model_name, model_spec in MODELSCOPE_RERANK_MODELS.items():
     MODEL_NAME_TO_REVISION[model_name].append(model_spec.model_revision)
 
+# register model description after recording model revision
+for model_spec_info in [BUILTIN_RERANK_MODELS, MODELSCOPE_RERANK_MODELS]:
+    for model_name, model_spec in model_spec_info.items():
+        if model_spec.model_name not in RERANK_MODEL_DESCRIPTIONS:
+            RERANK_MODEL_DESCRIPTIONS.update(generate_rerank_description(model_spec))
+
 # if persist=True, load them when init
 user_defined_rerank_dir = os.path.join(XINFERENCE_MODEL_DIR, "rerank")
 if os.path.isdir(user_defined_rerank_dir):
@@ -49,5 +68,9 @@ if os.path.isdir(user_defined_rerank_dir):
             user_defined_rerank_spec = CustomRerankModelSpec.parse_obj(json.load(fd))
             register_rerank(user_defined_rerank_spec, persist=False)
 
+# register model description
+for ud_rerank in get_user_defined_reranks():
+    RERANK_MODEL_DESCRIPTIONS.update(generate_rerank_description(ud_rerank))
+
 del _model_spec_json
 del _model_spec_modelscope_json
xinference/model/rerank/core.py

@@ -36,6 +36,15 @@ MODEL_NAME_TO_REVISION: Dict[str, List[str]] = defaultdict(list)
 SUPPORTED_SCHEMES = ["s3"]
 
 
+RERANK_MODEL_DESCRIPTIONS: Dict[str, List[Dict]] = defaultdict(list)
+
+
+def get_rerank_model_descriptions():
+    import copy
+
+    return copy.deepcopy(RERANK_MODEL_DESCRIPTIONS)
+
+
 class RerankModelSpec(BaseModel):
     model_name: str
     language: List[str]
@@ -50,8 +59,9 @@ class RerankModelDescription(ModelDescription):
         address: Optional[str],
         devices: Optional[List[str]],
         model_spec: RerankModelSpec,
+        model_path: Optional[str] = None,
     ):
-        super().__init__(address, devices)
+        super().__init__(address, devices, model_path=model_path)
         self._model_spec = model_spec
 
     def to_dict(self):
@@ -64,6 +74,31 @@ class RerankModelDescription(ModelDescription):
             "model_revision": self._model_spec.model_revision,
         }
 
+    def to_version_info(self):
+        from .utils import get_model_version
+
+        if self._model_path is None:
+            is_cached = get_cache_status(self._model_spec)
+            file_location = get_cache_dir(self._model_spec)
+        else:
+            is_cached = True
+            file_location = self._model_path
+
+        return {
+            "model_version": get_model_version(self._model_spec),
+            "model_file_location": file_location,
+            "cache_status": is_cached,
+            "language": self._model_spec.language,
+        }
+
+
+def generate_rerank_description(model_spec: RerankModelSpec) -> Dict[str, List[Dict]]:
+    res = defaultdict(list)
+    res[model_spec.model_name].append(
+        RerankModelDescription(None, None, model_spec).to_version_info()
+    )
+    return res
+
 
 class RerankModel:
     def __init__(
@@ -71,12 +106,14 @@ class RerankModel:
         model_uid: str,
         model_path: str,
         device: Optional[str] = None,
+        use_fp16: bool = False,
         model_config: Optional[Dict] = None,
     ):
         self._model_uid = model_uid
         self._model_path = model_path
         self._device = device
         self._model_config = model_config or dict()
+        self._use_fp16 = use_fp16
         self._model = None
 
     def load(self):
@@ -93,6 +130,8 @@ class RerankModel:
         self._model = CrossEncoder(
             self._model_path, device=self._device, **self._model_config
         )
+        if self._use_fp16:
+            self._model.model.half()
 
     def rerank(
         self,
@@ -131,6 +170,10 @@ class RerankModel:
         return Rerank(id=str(uuid.uuid1()), results=docs)
 
 
+def get_cache_dir(model_spec: RerankModelSpec):
+    return os.path.realpath(os.path.join(XINFERENCE_CACHE_DIR, model_spec.model_name))
+
+
 def get_cache_status(
     model_spec: RerankModelSpec,
 ) -> bool:
@@ -145,9 +188,7 @@ def cache_from_uri(
 
     from ..utils import copy_from_src_to_dst, parse_uri
 
-    cache_dir = os.path.realpath(
-        os.path.join(XINFERENCE_CACHE_DIR, model_spec.model_name)
-    )
+    cache_dir = get_cache_dir(model_spec)
     if os.path.exists(cache_dir):
         logger.info(f"Rerank cache {cache_dir} exists")
         return cache_dir
@@ -227,9 +268,7 @@ def cache(model_spec: RerankModelSpec):
         logger.info(f"Rerank model caching from URI: {model_spec.model_uri}")
         return cache_from_uri(model_spec=model_spec)
 
-    cache_dir = os.path.realpath(
-        os.path.join(XINFERENCE_CACHE_DIR, model_spec.model_name)
-    )
+    cache_dir = get_cache_dir(model_spec)
    if not os.path.exists(cache_dir):
         os.makedirs(cache_dir, exist_ok=True)
     meta_path = os.path.join(cache_dir, "__valid_download")
@@ -312,6 +351,9 @@ def create_rerank_model_instance(
         )
 
     model_path = cache(model_spec)
-    model = RerankModel(model_uid, model_path, **kwargs)
-    model_description = RerankModelDescription(subpool_addr, devices, model_spec)
+    use_fp16 = kwargs.pop("use_fp16", False)
+    model = RerankModel(model_uid, model_path, use_fp16=use_fp16, model_config=kwargs)
+    model_description = RerankModelDescription(
+        subpool_addr, devices, model_spec, model_path=model_path
+    )
     return model, model_description
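The rerank path now threads a use_fp16 flag from the launch kwargs into RerankModel (halving the CrossEncoder weights after load) and passes model_path into the description so the new to_version_info() can report a concrete file location. A minimal direct-use sketch, assuming an already downloaded reranker checkpoint directory:

from xinference.model.rerank.core import RerankModel

model = RerankModel(
    model_uid="my-reranker",
    model_path="/path/to/bge-reranker-base",  # assumed local cache directory
    use_fp16=True,  # new flag: calls self._model.model.half() after loading
)
model.load()  # builds the sentence-transformers CrossEncoder, then halves it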
xinference/model/rerank/model_spec.json

@@ -10,5 +10,11 @@
     "language": ["en", "zh"],
     "model_id": "BAAI/bge-reranker-base",
     "model_revision": "465b4b7ddf2be0a020c8ad6e525b9bb1dbb708ae"
+  },
+  {
+    "model_name": "bce-reranker-base_v1",
+    "language": ["en", "zh"],
+    "model_id": "maidalun1020/bce-reranker-base_v1",
+    "model_revision": "eaa31a577a0574e87a08959bd229ca14ce1b5496"
   }
 ]