xinference 0.8.2__py3-none-any.whl → 0.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (53)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +22 -7
  3. xinference/client/restful/restful_client.py +10 -0
  4. xinference/constants.py +14 -4
  5. xinference/core/chat_interface.py +8 -1
  6. xinference/core/resource.py +19 -12
  7. xinference/core/supervisor.py +94 -30
  8. xinference/core/utils.py +29 -1
  9. xinference/core/worker.py +18 -3
  10. xinference/deploy/local.py +2 -2
  11. xinference/deploy/supervisor.py +2 -2
  12. xinference/model/audio/model_spec.json +29 -1
  13. xinference/model/embedding/model_spec.json +24 -0
  14. xinference/model/embedding/model_spec_modelscope.json +24 -0
  15. xinference/model/llm/__init__.py +2 -0
  16. xinference/model/llm/core.py +2 -0
  17. xinference/model/llm/ggml/chatglm.py +15 -6
  18. xinference/model/llm/llm_family.json +56 -0
  19. xinference/model/llm/llm_family_modelscope.json +56 -0
  20. xinference/model/llm/pytorch/chatglm.py +3 -3
  21. xinference/model/llm/pytorch/core.py +1 -0
  22. xinference/model/llm/pytorch/utils.py +21 -9
  23. xinference/model/llm/pytorch/yi_vl.py +246 -0
  24. xinference/model/rerank/core.py +1 -1
  25. xinference/model/rerank/model_spec.json +6 -0
  26. xinference/model/rerank/model_spec_modelscope.json +7 -0
  27. xinference/thirdparty/__init__.py +0 -0
  28. xinference/thirdparty/llava/__init__.py +1 -0
  29. xinference/thirdparty/llava/conversation.py +205 -0
  30. xinference/thirdparty/llava/mm_utils.py +122 -0
  31. xinference/thirdparty/llava/model/__init__.py +1 -0
  32. xinference/thirdparty/llava/model/clip_encoder/__init__.py +0 -0
  33. xinference/thirdparty/llava/model/clip_encoder/builder.py +11 -0
  34. xinference/thirdparty/llava/model/clip_encoder/clip_encoder.py +86 -0
  35. xinference/thirdparty/llava/model/constants.py +6 -0
  36. xinference/thirdparty/llava/model/llava_arch.py +385 -0
  37. xinference/thirdparty/llava/model/llava_llama.py +163 -0
  38. xinference/thirdparty/llava/model/multimodal_projector/__init__.py +0 -0
  39. xinference/thirdparty/llava/model/multimodal_projector/builder.py +64 -0
  40. xinference/types.py +1 -1
  41. xinference/web/ui/build/asset-manifest.json +3 -3
  42. xinference/web/ui/build/index.html +1 -1
  43. xinference/web/ui/build/static/js/{main.abedc3c9.js → main.15822aeb.js} +3 -3
  44. xinference/web/ui/build/static/js/{main.abedc3c9.js.map → main.15822aeb.js.map} +1 -1
  45. xinference/web/ui/node_modules/.cache/babel-loader/65ca3ba225b8c8dac907210545b51f2fcdb2591f0feeb7195f1c037f2bc956a0.json +1 -0
  46. {xinference-0.8.2.dist-info → xinference-0.8.3.dist-info}/METADATA +21 -18
  47. {xinference-0.8.2.dist-info → xinference-0.8.3.dist-info}/RECORD +52 -38
  48. xinference/web/ui/node_modules/.cache/babel-loader/c157e34990b23834b7ad4c13c42962209942c60f8130978c1514f3d085cfaea0.json +0 -1
  49. /xinference/web/ui/build/static/js/{main.abedc3c9.js.LICENSE.txt → main.15822aeb.js.LICENSE.txt} +0 -0
  50. {xinference-0.8.2.dist-info → xinference-0.8.3.dist-info}/LICENSE +0 -0
  51. {xinference-0.8.2.dist-info → xinference-0.8.3.dist-info}/WHEEL +0 -0
  52. {xinference-0.8.2.dist-info → xinference-0.8.3.dist-info}/entry_points.txt +0 -0
  53. {xinference-0.8.2.dist-info → xinference-0.8.3.dist-info}/top_level.txt +0 -0
@@ -58,6 +58,7 @@ def _install():
     from .pytorch.llama_2 import LlamaPytorchChatModel, LlamaPytorchModel
     from .pytorch.qwen_vl import QwenVLChatModel
     from .pytorch.vicuna import VicunaPytorchChatModel
+    from .pytorch.yi_vl import YiVLChatModel
     from .vllm.core import VLLMChatModel, VLLMModel

     # register llm classes.
@@ -90,6 +91,7 @@ def _install():
             FalconPytorchModel,
             Internlm2PytorchChatModel,
             QwenVLChatModel,
+            YiVLChatModel,
             PytorchModel,
         ]
     )
@@ -135,6 +135,8 @@ class LLMDescription(ModelDescription):
             "model_description": self._llm_family.model_description,
             "model_format": self._llm_spec.model_format,
             "model_size_in_billions": self._llm_spec.model_size_in_billions,
+            "model_family": self._llm_family.model_family
+            or self._llm_family.model_name,
             "quantization": self._quantization,
             "model_hub": self._llm_spec.model_hub,
             "revision": self._llm_spec.model_revision,
@@ -230,20 +230,28 @@ class ChatglmCppChatModel(LLM):
             ),
         }

+    @staticmethod
+    def _to_chatglm_chat_messages(history_list: List[Any]):
+        from chatglm_cpp import ChatMessage
+
+        return [ChatMessage(role=v["role"], content=v["content"]) for v in history_list]
+
     def chat(
         self,
         prompt: str,
+        system_prompt: Optional[str] = None,
         chat_history: Optional[List[ChatCompletionMessage]] = None,
         generate_config: Optional[ChatglmCppGenerateConfig] = None,
     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
+        chat_history_list = []
+        if system_prompt is not None:
+            chat_history_list.append({"role": "system", "content": system_prompt})
         if chat_history is not None:
-            chat_history_list = chat_history
-        else:
-            chat_history_list = []
+            chat_history_list.extend(chat_history)  # type: ignore

         tool_message = self._handle_tools(generate_config)
         if tool_message is not None:
-            chat_history_list.insert(0, tool_message)
+            chat_history_list.insert(0, tool_message)  # type: ignore

         # We drop the message which contains tool calls to walkaround the issue:
         # https://github.com/li-plus/chatglm.cpp/issues/231
@@ -276,17 +284,18 @@ class ChatglmCppChatModel(LLM):
         params = {k: v for k, v in params.items() if v is not None}

         assert self._llm is not None
+        chat_history_messages = self._to_chatglm_chat_messages(chat_history_list)

         if generate_config["stream"]:
             it = self._llm.chat(
-                chat_history_list,
+                chat_history_messages,
                 **params,
             )
             assert not isinstance(it, str)
             return self._convert_raw_text_chunks_to_chat(it, self.model_uid)
         else:
             c = self._llm.chat(
-                chat_history_list,
+                chat_history_messages,
                 **params,
             )
             assert not isinstance(c, Iterator)
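
The chat() change above separates message preparation from the chatglm.cpp call: the optional system prompt and the OpenAI-style history dicts are collected first, then converted into chatglm_cpp.ChatMessage objects by the new _to_chatglm_chat_messages helper. A minimal sketch of that conversion, assuming chatglm_cpp is installed (the sample prompts are illustrative only):

    from chatglm_cpp import ChatMessage

    system_prompt = "You are a helpful assistant."
    chat_history = [{"role": "user", "content": "What is ChatGLM?"}]

    history_list = []
    if system_prompt is not None:
        history_list.append({"role": "system", "content": system_prompt})
    history_list.extend(chat_history)

    # Equivalent of ChatglmCppChatModel._to_chatglm_chat_messages
    messages = [ChatMessage(role=v["role"], content=v["content"]) for v in history_list]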
@@ -3346,5 +3346,61 @@
                 "<unk>"
             ]
         }
+    },
+    {
+        "version": 1,
+        "context_length": 204800,
+        "model_name": "yi-vl-chat",
+        "model_lang": [
+            "en",
+            "zh"
+        ],
+        "model_ability": [
+            "chat",
+            "vision"
+        ],
+        "model_description": "Yi Vision Language (Yi-VL) model is the open-source, multimodal version of the Yi Large Language Model (LLM) series, enabling content comprehension, recognition, and multi-round conversations about images.",
+        "model_specs": [
+            {
+                "model_format": "pytorch",
+                "model_size_in_billions": 6,
+                "quantizations": [
+                    "none"
+                ],
+                "model_id": "01-ai/Yi-VL-6B",
+                "model_revision": "897c938da1ec860330e2ba2d425ab3004495ba38"
+            },
+            {
+                "model_format": "pytorch",
+                "model_size_in_billions": 34,
+                "quantizations": [
+                    "none"
+                ],
+                "model_id": "01-ai/Yi-VL-34B",
+                "model_revision": "ea29a9a430f27893e780366dae81d4ca5ebab561"
+            }
+        ],
+        "prompt_style": {
+            "style_name": "CHATML",
+            "system_prompt": "",
+            "roles": [
+                "<|im_start|>user",
+                "<|im_start|>assistant"
+            ],
+            "intra_message_sep": "<|im_end|>",
+            "inter_message_sep": "",
+            "stop_token_ids": [
+                2,
+                6,
+                7,
+                8
+            ],
+            "stop": [
+                "<|endoftext|>",
+                "<|im_start|>",
+                "<|im_end|>",
+                "<|im_sep|>"
+            ]
+        }
     }
 ]
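
The entry above registers yi-vl-chat as a built-in chat+vision model with 6B and 34B PyTorch specs. A hedged sketch of launching it through the Python client follows; the Client endpoint and the launch_model/get_model calls are assumptions based on the existing xinference RESTful client, while the model name, format, size, and quantization values come from the spec above:

    from xinference.client import Client

    client = Client("http://127.0.0.1:9997")  # assumed local endpoint
    model_uid = client.launch_model(
        model_name="yi-vl-chat",
        model_format="pytorch",
        model_size_in_billions=6,
        quantization="none",
    )
    model = client.get_model(model_uid)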
@@ -1957,5 +1957,61 @@
                 "<unk>"
             ]
         }
+    },
+    {
+        "version": 1,
+        "context_length": 204800,
+        "model_name": "yi-vl-chat",
+        "model_lang": [
+            "en",
+            "zh"
+        ],
+        "model_ability": [
+            "chat",
+            "vision"
+        ],
+        "model_description": "Yi Vision Language (Yi-VL) model is the open-source, multimodal version of the Yi Large Language Model (LLM) series, enabling content comprehension, recognition, and multi-round conversations about images.",
+        "model_specs": [
+            {
+                "model_format": "pytorch",
+                "model_size_in_billions": 6,
+                "quantizations": [
+                    "none"
+                ],
+                "model_hub": "modelscope",
+                "model_id": "01ai/Yi-VL-6B"
+            },
+            {
+                "model_format": "pytorch",
+                "model_size_in_billions": 34,
+                "quantizations": [
+                    "none"
+                ],
+                "model_hub": "modelscope",
+                "model_id": "01ai/Yi-VL-34B"
+            }
+        ],
+        "prompt_style": {
+            "style_name": "CHATML",
+            "system_prompt": "",
+            "roles": [
+                "<|im_start|>user",
+                "<|im_start|>assistant"
+            ],
+            "intra_message_sep": "<|im_end|>",
+            "inter_message_sep": "",
+            "stop_token_ids": [
+                2,
+                6,
+                7,
+                8
+            ],
+            "stop": [
+                "<|endoftext|>",
+                "<|im_start|>",
+                "<|im_end|>",
+                "<|im_sep|>"
+            ]
+        }
     }
 ]
@@ -120,9 +120,9 @@ class ChatglmPytorchChatModel(PytorchChatModel):
         top_p = generate_config.get("top_p")
         if top_p is not None:
             kwargs["top_p"] = float(top_p)
-        max_length = generate_config.get("max_tokens")
-        if max_length is not None:
-            kwargs["max_length"] = int(max_length)
+        max_new_tokens = generate_config.get("max_tokens")
+        if max_new_tokens is not None:
+            kwargs["max_new_tokens"] = int(max_new_tokens)
         # Tool calls only works for non stream, so we call chat directly.
         if prompt == SPECIAL_TOOL_PROMPT and chat_history:
             tool_message = chat_history.pop()
@@ -423,6 +423,7 @@ class PytorchChatModel(PytorchModel, ChatModelMixin):
             "llama-2-chat",
             "internlm2-chat",
             "qwen-vl-chat",
+            "yi-vl-chat",
         ]:
             return False
         if "chat" not in llm_family.model_ability:
@@ -29,7 +29,12 @@ from transformers.generation.logits_process import (
     TopPLogitsWarper,
 )

-from ....types import CompletionChoice, CompletionChunk, CompletionUsage
+from ....types import (
+    CompletionChoice,
+    CompletionChunk,
+    CompletionUsage,
+    max_tokens_field,
+)

 logger = logging.getLogger(__name__)

@@ -54,16 +59,21 @@ def get_context_length(config):
         hasattr(config, "max_sequence_length")
         and config.max_sequence_length is not None
     ):
-        return config.max_sequence_length
-    elif hasattr(config, "seq_length") and config.seq_length is not None:
-        return config.seq_length
-    elif (
+        max_sequence_length = config.max_sequence_length
+    else:
+        max_sequence_length = 2048
+    if hasattr(config, "seq_length") and config.seq_length is not None:
+        seq_length = config.seq_length
+    else:
+        seq_length = 2048
+    if (
         hasattr(config, "max_position_embeddings")
         and config.max_position_embeddings is not None
     ):
-        return config.max_position_embeddings
+        max_position_embeddings = config.max_position_embeddings
     else:
-        return 2048
+        max_position_embeddings = 2048
+    return max(max_sequence_length, seq_length, max_position_embeddings)


 def prepare_logits_processor(
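
Condensed restatement of the revised get_context_length behaviour: rather than returning the first attribute found, it now evaluates all three candidate fields (each defaulting to 2048 when absent) and returns their maximum. The sketch below is standalone and uses SimpleNamespace as a stand-in for a transformers config object:

    from types import SimpleNamespace

    def context_length(config) -> int:
        # Each candidate falls back to 2048 when missing; the largest one wins.
        candidates = []
        for attr in ("max_sequence_length", "seq_length", "max_position_embeddings"):
            value = getattr(config, attr, None)
            candidates.append(value if value is not None else 2048)
        return max(candidates)

    cfg = SimpleNamespace(seq_length=2048, max_position_embeddings=32768)
    assert context_length(cfg) == 32768
    # The old early-return logic stopped at seq_length and reported 2048 here.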
@@ -102,7 +112,7 @@ def generate_stream(
     repetition_penalty = float(generate_config.get("repetition_penalty", 1.0))
     top_p = float(generate_config.get("top_p", 1.0))
     top_k = int(generate_config.get("top_k", -1))  # -1 means disable
-    max_new_tokens = int(generate_config.get("max_tokens", 256))
+    max_new_tokens = int(generate_config.get("max_tokens", max_tokens_field.default))
     echo = bool(generate_config.get("echo", False))
     stop_str = generate_config.get("stop", None)
     stop_token_ids = generate_config.get("stop_token_ids", None) or []
@@ -123,6 +133,8 @@ def generate_stream(
         max_src_len = context_len
     else:
         max_src_len = context_len - max_new_tokens - 8
+        if max_src_len < 0:
+            raise ValueError("Max tokens exceeds model's max length")

     input_ids = input_ids[-max_src_len:]
     input_echo_len = len(input_ids)
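
A worked instance of the new guard, with illustrative numbers: when the requested max_tokens leaves no room for the prompt inside the context window, max_src_len goes negative and the request is now rejected instead of silently slicing input_ids with a negative index.

    context_len = 4096      # model context window
    max_new_tokens = 5000   # requested max_tokens from generate_config
    max_src_len = context_len - max_new_tokens - 8  # -912
    if max_src_len < 0:
        raise ValueError("Max tokens exceeds model's max length")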
@@ -346,7 +358,7 @@ def generate_stream_falcon(
     repetition_penalty = float(generate_config.get("repetition_penalty", 1.0))
     top_p = float(generate_config.get("top_p", 1.0))
     top_k = int(generate_config.get("top_k", 50))  # -1 means disable
-    max_new_tokens = int(generate_config.get("max_tokens", 256))
+    max_new_tokens = int(generate_config.get("max_tokens", max_tokens_field.default))
     echo = bool(generate_config.get("echo", False))
     stop_str = generate_config.get("stop", None)
     stop_token_ids = generate_config.get("stop_token_ids", None) or []
@@ -0,0 +1,246 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import base64
+import logging
+import time
+import uuid
+from concurrent.futures import ThreadPoolExecutor
+from io import BytesIO
+from threading import Thread
+from typing import Dict, Iterator, List, Optional, Union
+
+import requests
+import torch
+from PIL import Image
+
+from ....model.utils import select_device
+from ....types import (
+    ChatCompletion,
+    ChatCompletionChoice,
+    ChatCompletionChunk,
+    ChatCompletionMessage,
+    CompletionUsage,
+)
+from ..llm_family import LLMFamilyV1, LLMSpecV1
+from .core import PytorchChatModel, PytorchGenerateConfig
+
+logger = logging.getLogger(__name__)
+
+
+class YiVLChatModel(PytorchChatModel):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._tokenizer = None
+        self._model = None
+        self._image_processor = None
+
+    @classmethod
+    def match(
+        cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
+    ) -> bool:
+        if "yi" in model_family.model_name:
+            return True
+        return False
+
+    def load(self):
+        from ....thirdparty.llava.mm_utils import load_pretrained_model
+        from ....thirdparty.llava.model.constants import key_info
+
+        device = self._pytorch_model_config.get("device", "auto")
+        device = select_device(device)
+
+        key_info["model_path"] = self.model_path
+        (
+            self._tokenizer,
+            self._model,
+            self._image_processor,
+            _,
+        ) = load_pretrained_model(self.model_path, device_map=device)
+
+    @staticmethod
+    def _message_content_to_yi(content) -> Union[str, tuple]:
+        def _load_image(_url):
+            if _url.startswith("data:"):
+                logging.info("Parse url by base64 decoder.")
+                # https://platform.openai.com/docs/guides/vision/uploading-base-64-encoded-images
+                # e.g. f"data:image/jpeg;base64,{base64_image}"
+                _type, data = _url.split(";")
+                _, ext = _type.split("/")
+                data = data[len("base64,") :]
+                data = base64.b64decode(data.encode("utf-8"))
+
+                return Image.open(BytesIO(data))
+            else:
+                try:
+                    response = requests.get(_url)
+                except requests.exceptions.MissingSchema:
+                    return Image.open(_url)
+                else:
+                    return Image.open(BytesIO(response.content))
+
+        if not isinstance(content, str):
+            from ....thirdparty.llava.model.constants import DEFAULT_IMAGE_TOKEN
+
+            texts = []
+            image_urls = []
+            for c in content:
+                c_type = c.get("type")
+                if c_type == "text":
+                    texts.append(c["text"])
+                elif c_type == "image_url":
+                    image_urls.append(c["image_url"]["url"])
+            image_futures = []
+            with ThreadPoolExecutor() as executor:
+                for image_url in image_urls:
+                    fut = executor.submit(_load_image, image_url)
+                    image_futures.append(fut)
+            images = [fut.result() for fut in image_futures]
+            text = " ".join(texts)
+            if DEFAULT_IMAGE_TOKEN not in text:
+                text = DEFAULT_IMAGE_TOKEN + "\n" + text
+            if len(images) == 0:
+                return text
+            elif len(images) == 1:
+                return text, images[0], "Pad"
+            else:
+                raise RuntimeError("Only one image per message is supported by Yi VL.")
+        return content
+
+    @staticmethod
+    def _parse_text(text):
+        lines = text.split("\n")
+        lines = [line for line in lines if line != ""]
+        count = 0
+        for i, line in enumerate(lines):
+            if "```" in line:
+                count += 1
+                items = line.split("`")
+                if count % 2 == 1:
+                    lines[i] = f'<pre><code class="language-{items[-1]}">'
+                else:
+                    lines[i] = f"<br></code></pre>"
+            else:
+                if i > 0:
+                    if count % 2 == 1:
+                        line = line.replace("`", r"\`")
+                        line = line.replace("<", "&lt;")
+                        line = line.replace(">", "&gt;")
+                        line = line.replace(" ", "&nbsp;")
+                        line = line.replace("*", "&ast;")
+                        line = line.replace("_", "&lowbar;")
+                        line = line.replace("-", "&#45;")
+                        line = line.replace(".", "&#46;")
+                        line = line.replace("!", "&#33;")
+                        line = line.replace("(", "&#40;")
+                        line = line.replace(")", "&#41;")
+                        line = line.replace("$", "&#36;")
+                    lines[i] = "<br>" + line
+        text = "".join(lines)
+        return text
+
+    def chat(
+        self,
+        prompt: Union[str, List[Dict]],
+        system_prompt: Optional[str] = None,
+        chat_history: Optional[List[ChatCompletionMessage]] = None,
+        generate_config: Optional[PytorchGenerateConfig] = None,
+    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
+        from transformers import TextIteratorStreamer
+
+        # TODO(codingl2k1): implement stream mode.
+        if generate_config and generate_config.get("stream"):
+            raise Exception(
+                f"Chat with model {self.model_family.model_name} does not support stream."
+            )
+        if not generate_config:
+            generate_config = {}
+        from ....thirdparty.llava.conversation import conv_templates
+        from ....thirdparty.llava.mm_utils import (
+            KeywordsStoppingCriteria,
+            tokenizer_image_token,
+        )
+        from ....thirdparty.llava.model.constants import IMAGE_TOKEN_INDEX
+
+        # Convert chat history to llava state
+        state = conv_templates["mm_default"].copy()
+        for message in chat_history or []:
+            content = self._message_content_to_yi(message["content"])
+            state.append_message(message["role"], content)
+        state.append_message(state.roles[0], self._message_content_to_yi(prompt))
+        state.append_message(state.roles[1], None)
+
+        prompt = state.get_prompt()
+
+        input_ids = (
+            tokenizer_image_token(
+                prompt, self._tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
+            )
+            .unsqueeze(0)
+            .cuda()
+        )
+
+        images = state.get_images(return_pil=True)
+        image = images[0]
+
+        image_tensor = self._image_processor.preprocess(image, return_tensors="pt")[
+            "pixel_values"
+        ][0]
+
+        stop_str = state.sep
+        keywords = [stop_str]
+        stopping_criteria = KeywordsStoppingCriteria(
+            keywords, self._tokenizer, input_ids
+        )
+        streamer = TextIteratorStreamer(
+            self._tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True
+        )
+        top_p = generate_config.get("top_p", 0.7)
+        temperature = generate_config.get("temperature", 0.2)
+        max_new_tokens = generate_config.get("max_tokens", 512)
+        generate_kwargs = {
+            "input_ids": input_ids,
+            "images": image_tensor.unsqueeze(0).to(dtype=torch.bfloat16).cuda(),
+            "streamer": streamer,
+            "do_sample": True,
+            "top_p": float(top_p),
+            "temperature": float(temperature),
+            "stopping_criteria": [stopping_criteria],
+            "use_cache": True,
+            "max_new_tokens": min(int(max_new_tokens), 1536),
+        }
+        t = Thread(target=self._model.generate, kwargs=generate_kwargs)
+        t.start()
+
+        generated_text = ""
+        for new_text in streamer:
+            generated_text += new_text
+            if generated_text.endswith(stop_str):
+                generated_text = generated_text[: -len(stop_str)]
+        r = self._parse_text(generated_text)
+        return ChatCompletion(
+            id="chat" + str(uuid.uuid1()),
+            object="chat.completion",
+            created=int(time.time()),
+            model=self.model_uid,
+            choices=[
+                ChatCompletionChoice(
+                    index=0,
+                    message={"role": "assistant", "content": r},
+                    finish_reason="stop",
+                )
+            ],
+            usage=CompletionUsage(
+                prompt_tokens=-1, completion_tokens=-1, total_tokens=-1
+            ),
+        )
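
Hedged example of the multimodal prompt that YiVLChatModel.chat accepts and _message_content_to_yi parses: OpenAI-style content parts mixing "text" and "image_url" entries, where the URL may be an HTTP(S) URL, a local path, or a base64 data URL. The model handle and image URL below are placeholders; model is assumed to be the handle returned by the client-side launch sketched earlier.

    prompt = [
        {"type": "text", "text": "What is shown in this picture?"},
        {"type": "image_url", "image_url": {"url": "https://example.com/cat.jpg"}},
        # Base64 data URLs are also handled by _load_image:
        # {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64_image}"}},
    ]
    completion = model.chat(prompt=prompt)
    print(completion["choices"][0]["message"]["content"])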
@@ -128,7 +128,7 @@ class RerankModel:

             raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
         self._model = CrossEncoder(
-            self._model_path, device=self._device, automodel_args=self._model_config
+            self._model_path, device=self._device, **self._model_config
         )
         if self._use_fp16:
             self._model.model.half()
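
The rerank change forwards model_config entries to CrossEncoder as top-level keyword arguments instead of nesting them under automodel_args, so keys must now match the CrossEncoder constructor. A hedged sketch, assuming sentence-transformers is installed; max_length is one example of such a kwarg, and which keys are accepted depends on the installed version:

    from sentence_transformers import CrossEncoder

    model_config = {"max_length": 512}  # forwarded directly, no longer wrapped in automodel_args
    model = CrossEncoder("BAAI/bge-reranker-base", device="cpu", **model_config)
    scores = model.predict([["what is a panda?", "The giant panda is a bear species endemic to China."]])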
@@ -10,5 +10,11 @@
         "language": ["en", "zh"],
         "model_id": "BAAI/bge-reranker-base",
         "model_revision": "465b4b7ddf2be0a020c8ad6e525b9bb1dbb708ae"
+    },
+    {
+        "model_name": "bce-reranker-base_v1",
+        "language": ["en", "zh"],
+        "model_id": "maidalun1020/bce-reranker-base_v1",
+        "model_revision": "eaa31a577a0574e87a08959bd229ca14ce1b5496"
     }
 ]
@@ -12,5 +12,12 @@
         "model_id": "Xorbits/bge-reranker-large",
         "model_revision": "v0.0.1",
         "model_hub": "modelscope"
+    },
+    {
+        "model_name": "bce-reranker-base_v1",
+        "language": ["en", "zh"],
+        "model_id": "maidalun/bce-reranker-base_v1",
+        "model_revision": "v0.0.1",
+        "model_hub": "modelscope"
     }
 ]
File without changes
@@ -0,0 +1 @@
+from .model import LlavaLlamaForCausalLM