xinference 0.14.4.post1__py3-none-any.whl → 0.15.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of xinference might be problematic.
- xinference/_compat.py +51 -0
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +209 -40
- xinference/client/restful/restful_client.py +7 -26
- xinference/conftest.py +1 -1
- xinference/constants.py +5 -0
- xinference/core/cache_tracker.py +1 -1
- xinference/core/chat_interface.py +8 -14
- xinference/core/event.py +1 -1
- xinference/core/image_interface.py +28 -0
- xinference/core/model.py +110 -31
- xinference/core/scheduler.py +37 -37
- xinference/core/status_guard.py +1 -1
- xinference/core/supervisor.py +17 -10
- xinference/core/utils.py +80 -22
- xinference/core/worker.py +17 -16
- xinference/deploy/cmdline.py +8 -16
- xinference/deploy/local.py +1 -1
- xinference/deploy/supervisor.py +1 -1
- xinference/deploy/utils.py +1 -1
- xinference/deploy/worker.py +1 -1
- xinference/model/audio/cosyvoice.py +86 -41
- xinference/model/audio/fish_speech.py +9 -9
- xinference/model/audio/model_spec.json +9 -9
- xinference/model/audio/whisper.py +4 -1
- xinference/model/embedding/core.py +52 -31
- xinference/model/image/core.py +2 -1
- xinference/model/image/model_spec.json +16 -4
- xinference/model/image/model_spec_modelscope.json +16 -4
- xinference/model/image/sdapi.py +136 -0
- xinference/model/image/stable_diffusion/core.py +164 -19
- xinference/model/llm/__init__.py +29 -11
- xinference/model/llm/llama_cpp/core.py +16 -33
- xinference/model/llm/llm_family.json +1011 -1296
- xinference/model/llm/llm_family.py +34 -53
- xinference/model/llm/llm_family_csghub.json +18 -35
- xinference/model/llm/llm_family_modelscope.json +981 -1122
- xinference/model/llm/lmdeploy/core.py +56 -88
- xinference/model/llm/mlx/core.py +46 -69
- xinference/model/llm/sglang/core.py +36 -18
- xinference/model/llm/transformers/chatglm.py +168 -306
- xinference/model/llm/transformers/cogvlm2.py +36 -63
- xinference/model/llm/transformers/cogvlm2_video.py +33 -223
- xinference/model/llm/transformers/core.py +55 -50
- xinference/model/llm/transformers/deepseek_v2.py +340 -0
- xinference/model/llm/transformers/deepseek_vl.py +53 -96
- xinference/model/llm/transformers/glm4v.py +55 -111
- xinference/model/llm/transformers/intern_vl.py +39 -70
- xinference/model/llm/transformers/internlm2.py +32 -54
- xinference/model/llm/transformers/minicpmv25.py +22 -55
- xinference/model/llm/transformers/minicpmv26.py +158 -68
- xinference/model/llm/transformers/omnilmm.py +5 -28
- xinference/model/llm/transformers/qwen2_audio.py +168 -0
- xinference/model/llm/transformers/qwen2_vl.py +234 -0
- xinference/model/llm/transformers/qwen_vl.py +34 -86
- xinference/model/llm/transformers/utils.py +32 -38
- xinference/model/llm/transformers/yi_vl.py +32 -72
- xinference/model/llm/utils.py +280 -554
- xinference/model/llm/vllm/core.py +161 -100
- xinference/model/rerank/core.py +41 -8
- xinference/model/rerank/model_spec.json +7 -0
- xinference/model/rerank/model_spec_modelscope.json +7 -1
- xinference/model/utils.py +1 -31
- xinference/thirdparty/cosyvoice/bin/export_jit.py +64 -0
- xinference/thirdparty/cosyvoice/bin/export_trt.py +8 -0
- xinference/thirdparty/cosyvoice/bin/inference.py +5 -2
- xinference/thirdparty/cosyvoice/cli/cosyvoice.py +38 -22
- xinference/thirdparty/cosyvoice/cli/model.py +139 -26
- xinference/thirdparty/cosyvoice/flow/flow.py +15 -9
- xinference/thirdparty/cosyvoice/flow/length_regulator.py +20 -1
- xinference/thirdparty/cosyvoice/hifigan/generator.py +8 -4
- xinference/thirdparty/cosyvoice/llm/llm.py +14 -13
- xinference/thirdparty/cosyvoice/transformer/attention.py +7 -3
- xinference/thirdparty/cosyvoice/transformer/decoder.py +1 -1
- xinference/thirdparty/cosyvoice/transformer/embedding.py +4 -3
- xinference/thirdparty/cosyvoice/transformer/encoder.py +4 -2
- xinference/thirdparty/cosyvoice/utils/common.py +36 -0
- xinference/thirdparty/cosyvoice/utils/file_utils.py +16 -0
- xinference/thirdparty/deepseek_vl/serve/assets/Kelpy-Codos.js +100 -0
- xinference/thirdparty/deepseek_vl/serve/assets/avatar.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/assets/custom.css +355 -0
- xinference/thirdparty/deepseek_vl/serve/assets/custom.js +22 -0
- xinference/thirdparty/deepseek_vl/serve/assets/favicon.ico +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/app.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/chart.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/mirror.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/pipeline.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/puzzle.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/rap.jpeg +0 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/base.yaml +87 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/firefly_gan_vq.yaml +33 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/lora/r_8_alpha_16.yaml +4 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/text2semantic_finetune.yaml +83 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text-data.proto +24 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/README.md +27 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +2 -2
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +0 -3
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +169 -198
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +4 -27
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/.gitignore +114 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/README.md +36 -0
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +9 -47
- xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +2 -2
- xinference/thirdparty/fish_speech/fish_speech/train.py +2 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/css/style.css +161 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/html/footer.html +11 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/js/animate.js +69 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +12 -10
- xinference/thirdparty/fish_speech/tools/api.py +79 -134
- xinference/thirdparty/fish_speech/tools/commons.py +35 -0
- xinference/thirdparty/fish_speech/tools/download_models.py +3 -3
- xinference/thirdparty/fish_speech/tools/file.py +17 -0
- xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +1 -1
- xinference/thirdparty/fish_speech/tools/llama/generate.py +29 -24
- xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +1 -1
- xinference/thirdparty/fish_speech/tools/llama/quantize.py +2 -2
- xinference/thirdparty/fish_speech/tools/msgpack_api.py +34 -0
- xinference/thirdparty/fish_speech/tools/post_api.py +85 -44
- xinference/thirdparty/fish_speech/tools/sensevoice/README.md +59 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +1 -1
- xinference/thirdparty/fish_speech/tools/smart_pad.py +16 -3
- xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +2 -2
- xinference/thirdparty/fish_speech/tools/vqgan/inference.py +4 -2
- xinference/thirdparty/fish_speech/tools/webui.py +12 -146
- xinference/thirdparty/matcha/VERSION +1 -0
- xinference/thirdparty/matcha/hifigan/LICENSE +21 -0
- xinference/thirdparty/matcha/hifigan/README.md +101 -0
- xinference/thirdparty/omnilmm/LICENSE +201 -0
- xinference/thirdparty/whisper/__init__.py +156 -0
- xinference/thirdparty/whisper/__main__.py +3 -0
- xinference/thirdparty/whisper/assets/gpt2.tiktoken +50256 -0
- xinference/thirdparty/whisper/assets/mel_filters.npz +0 -0
- xinference/thirdparty/whisper/assets/multilingual.tiktoken +50257 -0
- xinference/thirdparty/whisper/audio.py +157 -0
- xinference/thirdparty/whisper/decoding.py +826 -0
- xinference/thirdparty/whisper/model.py +314 -0
- xinference/thirdparty/whisper/normalizers/__init__.py +2 -0
- xinference/thirdparty/whisper/normalizers/basic.py +76 -0
- xinference/thirdparty/whisper/normalizers/english.json +1741 -0
- xinference/thirdparty/whisper/normalizers/english.py +550 -0
- xinference/thirdparty/whisper/timing.py +386 -0
- xinference/thirdparty/whisper/tokenizer.py +395 -0
- xinference/thirdparty/whisper/transcribe.py +605 -0
- xinference/thirdparty/whisper/triton_ops.py +109 -0
- xinference/thirdparty/whisper/utils.py +316 -0
- xinference/thirdparty/whisper/version.py +1 -0
- xinference/types.py +14 -53
- xinference/web/ui/build/asset-manifest.json +6 -6
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/css/{main.4bafd904.css → main.5061c4c3.css} +2 -2
- xinference/web/ui/build/static/css/main.5061c4c3.css.map +1 -0
- xinference/web/ui/build/static/js/main.754740c0.js +3 -0
- xinference/web/ui/build/static/js/{main.eb13fe95.js.LICENSE.txt → main.754740c0.js.LICENSE.txt} +2 -0
- xinference/web/ui/build/static/js/main.754740c0.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/10c69dc7a296779fcffedeff9393d832dfcb0013c36824adf623d3c518b801ff.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/68bede6d95bb5ef0b35bbb3ec5b8c937eaf6862c6cdbddb5ef222a7776aaf336.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/77d50223f3e734d4485cca538cb098a8c3a7a0a1a9f01f58cdda3af42fe1adf5.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/a56d5a642409a84988891089c98ca28ad0546432dfbae8aaa51bc5a280e1cdd2.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/cd90b08d177025dfe84209596fc51878f8a86bcaa6a240848a3d2e5fd4c7ff24.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/d9ff696a3e3471f01b46c63d18af32e491eb5dc0e43cb30202c96871466df57f.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/e42b72d4cc1ea412ebecbb8d040dc6c6bfee462c33903c2f1f3facb602ad742e.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f5039ddbeb815c51491a1989532006b96fc3ae49c6c60e3c097f875b4ae915ae.json +1 -0
- xinference/web/ui/node_modules/.package-lock.json +37 -0
- xinference/web/ui/node_modules/a-sync-waterfall/package.json +21 -0
- xinference/web/ui/node_modules/nunjucks/node_modules/commander/package.json +48 -0
- xinference/web/ui/node_modules/nunjucks/package.json +112 -0
- xinference/web/ui/package-lock.json +38 -0
- xinference/web/ui/package.json +1 -0
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/METADATA +16 -10
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/RECORD +179 -127
- xinference/model/llm/transformers/llama_2.py +0 -108
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +0 -442
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +0 -44
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +0 -115
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +0 -225
- xinference/thirdparty/fish_speech/tools/auto_rerank.py +0 -159
- xinference/thirdparty/fish_speech/tools/gen_ref.py +0 -36
- xinference/thirdparty/fish_speech/tools/merge_asr_files.py +0 -55
- xinference/web/ui/build/static/css/main.4bafd904.css.map +0 -1
- xinference/web/ui/build/static/js/main.eb13fe95.js +0 -3
- xinference/web/ui/build/static/js/main.eb13fe95.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/0b11a5339468c13b2d31ac085e7effe4303259b2071abd46a0a8eb8529233a5e.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/213b5913e164773c2b0567455377765715f5f07225fbac77ad8e1e9dc9648a47.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/5c26a23b5eacf5b752a08531577ae3840bb247745ef9a39583dc2d05ba93a82a.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/978b57d1a04a701bc3fcfebc511f5f274eed6ed7eade67f6fb76c27d5fd9ecc8.json +0 -1
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/LICENSE +0 -0
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/WHEEL +0 -0
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/entry_points.txt +0 -0
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/top_level.txt +0 -0
xinference/model/llm/vllm/core.py
CHANGED

@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import asyncio
-import json
 import logging
 import multiprocessing
 import os
@@ -24,9 +23,9 @@ from typing import (
     Any,
     AsyncGenerator,
     Dict,
-    Iterable,
     List,
     Optional,
+    Tuple,
     TypedDict,
     Union,
 )
@@ -34,18 +33,20 @@ from typing import (
 from ....types import (
     ChatCompletion,
     ChatCompletionChunk,
-    ChatCompletionMessage,
     Completion,
     CompletionChoice,
     CompletionChunk,
     CompletionUsage,
     LoRA,
-    ToolCallFunction,
-    ToolCalls,
 )
 from .. import LLM, LLMFamilyV1, LLMSpecV1
 from ..llm_family import CustomLLMFamilyV1
-from ..utils import QWEN_TOOL_CALL_FAMILY, ChatModelMixin
+from ..utils import (
+    QWEN_TOOL_CALL_FAMILY,
+    QWEN_TOOL_CALL_SYMBOLS,
+    ChatModelMixin,
+    generate_completion_chunk,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -103,6 +104,7 @@ VLLM_SUPPORTED_MODELS = [
     "code-llama-python",
     "deepseek",
     "deepseek-coder",
+    "yi-coder",
 ]
 VLLM_SUPPORTED_CHAT_MODELS = [
     "llama-2-chat",
@@ -129,6 +131,7 @@ VLLM_SUPPORTED_CHAT_MODELS = [
     "codegeex4",
     "deepseek-chat",
     "deepseek-coder-instruct",
+    "yi-coder-chat",
 ]
 if VLLM_INSTALLED and vllm.__version__ >= "0.3.0":
     VLLM_SUPPORTED_CHAT_MODELS.append("qwen1.5-chat")
@@ -148,6 +151,12 @@ if VLLM_INSTALLED and vllm.__version__ >= "0.4.0":
     VLLM_SUPPORTED_CHAT_MODELS.append("qwen2-moe-instruct")
     VLLM_SUPPORTED_CHAT_MODELS.append("c4ai-command-r-v01")
 
+if VLLM_INSTALLED and vllm.__version__ >= "0.5.1":
+    VLLM_SUPPORTED_CHAT_MODELS.append("deepseek-v2-chat")
+    VLLM_SUPPORTED_CHAT_MODELS.append("deepseek-v2-chat-0628")
+    VLLM_SUPPORTED_CHAT_MODELS.append("deepseek-v2.5")
+
+
 if VLLM_INSTALLED and vllm.__version__ >= "0.5.3":
     VLLM_SUPPORTED_CHAT_MODELS.append("gemma-2-it")
     VLLM_SUPPORTED_CHAT_MODELS.append("mistral-nemo-instruct")
@@ -363,23 +372,28 @@ class VLLMModel(LLM):
     @staticmethod
     def _convert_request_output_to_completion_chunk(
         request_id: str, model: str, request_output: "RequestOutput"
-    ) -> CompletionChunk:
+    ) -> Tuple[CompletionChunk, Optional[str]]:
         choices: List[CompletionChoice] = []
+        finish_reason = None
         for output in request_output.outputs:
             choices.append(
                 CompletionChoice(
                     text=output.text,
                     index=output.index,
                     logprobs=None,  # TODO: support logprobs.
-                    finish_reason=output.finish_reason,
+                    finish_reason=None,
                 )
             )
-        return CompletionChunk(
-            id=request_id,
-            object="text_completion",
-            created=int(time.time()),
-            model=model,
-            choices=choices,
+            finish_reason = output.finish_reason
+        return (
+            CompletionChunk(
+                id=request_id,
+                object="text_completion",
+                created=int(time.time()),
+                model=model,
+                choices=choices,
+            ),
+            finish_reason,
         )
 
     @staticmethod
@@ -420,6 +434,7 @@ class VLLMModel(LLM):
         prompt: Union[str, Dict[str, Any]],
         generate_config: Optional[Dict] = None,
         tools: object = False,
+        request_id: Optional[str] = None,
     ) -> Union[Completion, AsyncGenerator[CompletionChunk, None]]:
         try:
             from vllm.sampling_params import SamplingParams
@@ -454,7 +469,8 @@ class VLLMModel(LLM):
             else False
         )
         sampling_params = SamplingParams(**sanitized_generate_config)
-        request_id = str(uuid.uuid1())
+        if not request_id:
+            request_id = str(uuid.uuid1())
 
         assert self._engine is not None
         results_generator = self._engine.generate(
@@ -463,10 +479,14 @@ class VLLMModel(LLM):
 
         async def stream_results() -> AsyncGenerator[CompletionChunk, None]:
             previous_texts = [""] * sanitized_generate_config["n"]
-            tools_token_filter = ChatModelMixin._tools_token_filter(self.model_family)
             prompt_tokens, completion_tokens, total_tokens = 0, 0, 0
+            complete_response = ""
+            match_tool_call_tmp_results = []
+            is_match_tool_call = False
+            chunk = None
+            finish_reason = None
             async for _request_output in results_generator:
-                chunk = self._convert_request_output_to_completion_chunk(
+                chunk, finish_reason = self._convert_request_output_to_completion_chunk(
                     request_id=request_id,
                     model=self.model_uid,
                     request_output=_request_output,
@@ -476,40 +496,8 @@ class VLLMModel(LLM):
                     delta = choice["text"][len(previous_texts[i]) :]
                     previous_texts[i] = choice["text"]
                     choice["text"] = delta
+                    complete_response += delta
 
-                if tools:
-                    # only handle the first choice
-                    choice = chunk["choices"][0]
-                    if choice["finish_reason"] is not None:
-                        # use previous text for evaluation temporarily
-                        choice_delta = choice["text"]
-                        choice["text"] = previous_texts[0]
-                        _content, func, args = ChatModelMixin._eval_tool_arguments(
-                            self.model_family, chunk, tools
-                        )
-                        choice["text"] = tools_token_filter(
-                            tokens=previous_texts[0], delta=choice_delta
-                        )
-                        if func is not None:
-                            choice["text"] = None
-                            choice["finish_reason"] = "tool_calls"
-                            choice["tool_calls"] = [
-                                ToolCalls(
-                                    id=str(uuid.uuid4()),
-                                    type="function",
-                                    function=ToolCallFunction(
-                                        name=func,
-                                        arguments=json.dumps(args, ensure_ascii=False),
-                                    ),
-                                )
-                            ]
-                        else:
-                            # use a filter function to skip Qwen's react thought process
-                            choice["text"] = tools_token_filter(
-                                tokens=previous_texts[0], delta=choice["text"]
-                            )
-                        if not choice["text"]:
-                            continue
                 prompt_tokens = len(_request_output.prompt_token_ids)
                 completion_tokens = sum(
                     len(output.token_ids) for output in _request_output.outputs
@@ -520,7 +508,59 @@ class VLLMModel(LLM):
                     completion_tokens=completion_tokens,
                     total_tokens=total_tokens,
                 )
+
+                if tools:
+                    """
+                    The qwen2 tool call returns format like this:
+                    <tool_call>
+                    {...}
+                    </tool_call>
+                    Here is to match this.
+                    """
+                    if (len(QWEN_TOOL_CALL_SYMBOLS[0]) > len(complete_response)) and (
+                        not QWEN_TOOL_CALL_SYMBOLS[0].startswith(complete_response)
+                    ):
+                        for c in match_tool_call_tmp_results:
+                            yield c
+                        match_tool_call_tmp_results.clear()
+                        yield chunk
+                    elif (len(QWEN_TOOL_CALL_SYMBOLS[0]) > len(complete_response)) and (
+                        QWEN_TOOL_CALL_SYMBOLS[0].startswith(complete_response)
+                    ):
+                        match_tool_call_tmp_results.append(chunk)
+                    else:
+                        assert len(QWEN_TOOL_CALL_SYMBOLS[0]) <= len(complete_response)
+                        if not is_match_tool_call and complete_response.startswith(
+                            QWEN_TOOL_CALL_SYMBOLS[0]
+                        ):
+                            is_match_tool_call = True
+                            match_tool_call_tmp_results.clear()
+
+                        if not is_match_tool_call:
+                            for c in match_tool_call_tmp_results:
+                                yield c
+                            match_tool_call_tmp_results.clear()
+                            yield chunk
+                        else:
+                            chunk["choices"][0]["text"] = complete_response
+                else:
+                    yield chunk
+
+            if is_match_tool_call:
+                assert chunk is not None
                 yield chunk
+
+            # match OpenAI API stream
+            yield generate_completion_chunk(
+                chunk_text="",
+                finish_reason=finish_reason,
+                chunk_id=request_id,
+                model_uid=self.model_uid,
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=total_tokens,
+            )
+
             if include_usage:
                 chunk = CompletionChunk(
                     id=request_id,
@@ -586,59 +626,74 @@ class VLLMChatModel(VLLMModel, ChatModelMixin):
     ) -> Dict:
         if not generate_config:
             generate_config = {}
-        if self.model_family.
-
-
-
-
-
-
-                    "stop_token_ids",
-                    self.model_family.prompt_style.stop_token_ids.copy(),
-                )
+        if not generate_config.get("stop") and self.model_family.stop:
+            generate_config["stop"] = self.model_family.stop.copy()
+        if (
+            not generate_config.get("stop_token_ids")
+            and self.model_family.stop_token_ids
+        ):
+            generate_config["stop_token_ids"] = self.model_family.stop_token_ids.copy()
         return generate_config
 
+    @staticmethod
+    def is_tool_call_chunk(chunk):
+        return chunk["choices"][0]["text"].startswith(QWEN_TOOL_CALL_SYMBOLS[0])
+
+    async def _async_to_tool_completion_chunks(
+        self,
+        chunks: AsyncGenerator[CompletionChunk, None],
+    ) -> AsyncGenerator[ChatCompletionChunk, None]:
+        i = 0
+        async for chunk in chunks:
+            if i == 0:
+                yield self._get_first_chat_completion_chunk(chunk)
+            # usage
+            choices = chunk.get("choices")
+            if not choices:
+                yield self._get_final_chat_completion_chunk(chunk)
+            else:
+                if self.is_tool_call_chunk(chunk):
+                    yield self._tool_calls_completion_chunk(
+                        self.model_family, self.model_uid, chunk
+                    )
+                else:
+                    yield self._to_chat_completion_chunk(chunk)
+            i += 1
+
     async def async_chat(
         self,
-        prompt: str,
-        system_prompt: Optional[str] = None,
-        chat_history: Optional[List[ChatCompletionMessage]] = None,
+        messages: List[Dict],
         generate_config: Optional[Dict] = None,
+        request_id: Optional[str] = None,
     ) -> Union[ChatCompletion, AsyncGenerator[ChatCompletionChunk, None]]:
-        assert self.model_family.prompt_style is not None
-        prompt_style = self.model_family.prompt_style.copy()
-        if system_prompt:
-            prompt_style.system_prompt = system_prompt
-        chat_history = chat_history or []
         tools = generate_config.pop("tools", []) if generate_config else None
-        full_prompt = self.get_prompt(prompt, chat_history, prompt_style, tools=tools)
-
-        generate_config = self._sanitize_chat_config(generate_config)
-        # TODO(codingl2k1): qwen hacky to set stop for function call.
         model_family = self.model_family.model_family or self.model_family.model_name
+        full_context_kwargs = {}
         if tools and model_family in QWEN_TOOL_CALL_FAMILY:
-
-
-
-
-
-                generate_config["stop"] = list(stop) + ["Observation:"]
-            else:
-                generate_config["stop"] = "Observation:"
+            full_context_kwargs["tools"] = tools
+        assert self.model_family.chat_template is not None
+        full_prompt = self.get_full_context(
+            messages, self.model_family.chat_template, **full_context_kwargs
+        )
 
+        generate_config = self._sanitize_chat_config(generate_config)
         stream = generate_config.get("stream", None)
 
         if stream:
-            agen = await self.async_generate(full_prompt, generate_config, tools)
+            agen = await self.async_generate(
+                full_prompt, generate_config, tools, request_id=request_id
+            )
             assert isinstance(agen, AsyncGenerator)
+            if tools:
+                return self._async_to_tool_completion_chunks(agen)
             return self._async_to_chat_completion_chunks(agen)
         else:
-            c = await self.async_generate(full_prompt, generate_config)
+            c = await self.async_generate(
+                full_prompt, generate_config, request_id=request_id
+            )
             assert not isinstance(c, AsyncGenerator)
             if tools:
-                return self._tool_calls_completion(
-                    self.model_family, self.model_uid, c, tools
-                )
+                return self._tool_calls_completion(self.model_family, self.model_uid, c)
             return self._to_chat_completion(c)
 
 
@@ -666,28 +721,30 @@ class VLLMVisionModel(VLLMModel, ChatModelMixin):
         self,
         generate_config: Optional[Dict] = None,
     ) -> Dict:
+        from ..utils import get_stop_token_ids_from_config_file
+
         if not generate_config:
             generate_config = {}
-        if
-
-
-
-
-
+        if generate_config.get("stop_token_ids", None) is None:
+            stop_token_ids = get_stop_token_ids_from_config_file(self.model_path)
+            if stop_token_ids is not None:
+                generate_config.setdefault("stop_token_ids", stop_token_ids)
+            else:
+                if self.model_family.stop_token_ids:
+                    generate_config.setdefault(
+                        "stop_token_ids", self.model_family.stop_token_ids.copy()
+                    )
         return generate_config
 
     async def async_chat(
         self,
-        prompt: str,
-        system_prompt: Optional[str] = None,
-        chat_history: Optional[List[ChatCompletionMessage]] = None,
+        messages: List[Dict],
         generate_config: Optional[Dict] = None,
+        request_id: Optional[str] = None,
     ) -> Union[ChatCompletion, AsyncGenerator[ChatCompletionChunk, None]]:
         # only support single image, waiting vllm support multi images
-
-
-        chat_history = chat_history or []
-        prompt, images = self.get_prompt(prompt, chat_history, prompt_style)
+        model_family = self.model_family.model_family or self.model_family.model_name
+        prompt, images = self.get_specific_prompt(model_family, messages)
 
         if len(images) == 0:
             inputs = {
@@ -703,10 +760,14 @@ class VLLMVisionModel(VLLMModel, ChatModelMixin):
         stream = generate_config.get("stream", None)
 
        if stream:
-            agen = await self.async_generate(inputs, generate_config)
+            agen = await self.async_generate(
+                inputs, generate_config, request_id=request_id
+            )
             assert isinstance(agen, AsyncGenerator)
             return self._async_to_chat_completion_chunks(agen)
         else:
-            c = await self.async_generate(inputs, generate_config)
+            c = await self.async_generate(
+                inputs, generate_config, request_id=request_id
+            )
             assert not isinstance(c, AsyncGenerator)
             return self._to_chat_completion(c)
xinference/model/rerank/core.py
CHANGED
@@ -15,6 +15,7 @@
 import gc
 import logging
 import os
+import threading
 import uuid
 from collections import defaultdict
 from collections.abc import Sequence
@@ -22,6 +23,7 @@ from typing import Dict, List, Literal, Optional, Tuple
 
 import numpy as np
 import torch
+import torch.nn as nn
 
 from ...constants import XINFERENCE_CACHE_DIR
 from ...device_utils import empty_cache
@@ -49,6 +51,7 @@ class RerankModelSpec(CacheableModelSpec):
     model_name: str
     language: List[str]
     type: Optional[str] = "unknown"
+    max_tokens: Optional[int]
     model_id: str
     model_revision: Optional[str]
     model_hub: str = "huggingface"
@@ -102,6 +105,30 @@ def generate_rerank_description(model_spec: RerankModelSpec) -> Dict[str, List[D
     return res
 
 
+class _ModelWrapper:
+    def __init__(self, module: nn.Module):
+        self._module = module
+        self._local_data = threading.local()
+
+    @property
+    def n_tokens(self):
+        return getattr(self._local_data, "n_tokens", 0)
+
+    @n_tokens.setter
+    def n_tokens(self, new_n_tokens):
+        self._local_data.n_tokens = new_n_tokens
+
+    def __getattr__(self, attr):
+        return getattr(self._module, attr)
+
+    def __call__(self, **kwargs):
+        attention_mask = kwargs["attention_mask"]
+        # when batching, the attention mask 1 means there is a token
+        # thus we just sum up it to get the total number of tokens
+        self.n_tokens += attention_mask.sum().item()
+        return self._module(**kwargs)
+
+
 class RerankModel:
     def __init__(
         self,
@@ -166,6 +193,7 @@ class RerankModel:
             self._model_path,
             device=self._device,
             trust_remote_code=True,
+            max_length=getattr(self._model_spec, "max_tokens"),
             **self._model_config,
         )
         if self._use_fp16:
@@ -189,6 +217,8 @@
 
                 raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
             self._model = FlagReranker(self._model_path, use_fp16=self._use_fp16)
+            # Wrap transformers model to record number of tokens
+            self._model.model = _ModelWrapper(self._model.model)
 
     def rerank(
         self,
@@ -200,17 +230,14 @@
         return_len: Optional[bool],
         **kwargs,
     ) -> Rerank:
-        self._counter += 1
-        if self._counter % RERANK_EMPTY_CACHE_COUNT == 0:
-            logger.debug("Empty rerank cache.")
-            gc.collect()
-            empty_cache()
         assert self._model is not None
         if kwargs:
             raise ValueError("rerank hasn't support extra parameter.")
         if max_chunks_per_doc is not None:
             raise ValueError("rerank hasn't support `max_chunks_per_doc` parameter.")
         sentence_combinations = [[query, doc] for doc in documents]
+        # reset n tokens
+        self._model.model.n_tokens = 0
         if self._model_spec.type == "normal":
             similarity_scores = self._model.predict(
                 sentence_combinations, convert_to_numpy=False, convert_to_tensor=True
@@ -245,9 +272,7 @@
             for arg in sim_scores_argsort
         ]
         if return_len:
-
-            input_len = sum([len(tokenizer.tokenize(t)) for t in documents])
-
+            input_len = self._model.model.n_tokens
             # Rerank Model output is just score or documents
             # while return_documents = True
             output_len = input_len
@@ -265,6 +290,14 @@
             "warnings": None,
         }
 
+        del similarity_scores
+        # clear cache if possible
+        self._counter += 1
+        if self._counter % RERANK_EMPTY_CACHE_COUNT == 0:
+            logger.debug("Empty rerank cache.")
+            gc.collect()
+            empty_cache()
+
         return Rerank(id=str(uuid.uuid1()), results=docs, meta=metadata)
 
 
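For context on the rerank change above: `_ModelWrapper` counts processed tokens by summing the attention mask of every forward call, since each 1 in the mask marks a real (non-padding) token, and it stores the counter in `threading.local()` so concurrent rerank calls in different threads do not mix up their totals. A tiny sketch of the counting idea (illustrative tensors, not the xinference API):

```python
# Illustrative only: summing an attention mask counts non-padding tokens in a batch.
import torch

attention_mask = torch.tensor(
    [
        [1, 1, 1, 1, 0, 0],  # 4 real tokens, 2 padding positions
        [1, 1, 1, 1, 1, 1],  # 6 real tokens
    ]
)
n_tokens = attention_mask.sum().item()
assert n_tokens == 10
```

This is why `rerank(..., return_len=True)` can now report lengths without re-tokenizing the documents.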
xinference/model/rerank/model_spec.json
CHANGED

@@ -3,6 +3,7 @@
     "model_name": "bge-reranker-large",
     "type": "normal",
     "language": ["en", "zh"],
+    "max_tokens": 512,
     "model_id": "BAAI/bge-reranker-large",
     "model_revision": "27c9168d479987529781de8474dff94d69beca11"
   },
@@ -10,6 +11,7 @@
     "model_name": "bge-reranker-base",
     "type": "normal",
     "language": ["en", "zh"],
+    "max_tokens": 512,
     "model_id": "BAAI/bge-reranker-base",
     "model_revision": "465b4b7ddf2be0a020c8ad6e525b9bb1dbb708ae"
   },
@@ -17,6 +19,7 @@
     "model_name": "bce-reranker-base_v1",
     "type": "normal",
     "language": ["en", "zh"],
+    "max_tokens": 512,
     "model_id": "maidalun1020/bce-reranker-base_v1",
     "model_revision": "eaa31a577a0574e87a08959bd229ca14ce1b5496"
   },
@@ -24,6 +27,7 @@
     "model_name": "bge-reranker-v2-m3",
     "type": "normal",
     "language": ["en", "zh", "multilingual"],
+    "max_tokens": 8192,
     "model_id": "BAAI/bge-reranker-v2-m3",
     "model_revision": "12e974610ba9083ed95f3edf08d7e899581f4de4"
   },
@@ -31,6 +35,7 @@
     "model_name": "bge-reranker-v2-gemma",
     "type": "LLM-based",
     "language": ["en", "zh", "multilingual"],
+    "max_tokens": 8192,
     "model_id": "BAAI/bge-reranker-v2-gemma",
     "model_revision": "1787044f8b6fb740a9de4557c3a12377f84d9e17"
   },
@@ -38,6 +43,7 @@
     "model_name": "bge-reranker-v2-minicpm-layerwise",
     "type": "LLM-based layerwise",
     "language": ["en", "zh", "multilingual"],
+    "max_tokens": 2048,
     "model_id": "BAAI/bge-reranker-v2-minicpm-layerwise",
     "model_revision": "47b5332b296c4d8cb6ee2c60502cc62a0d708881"
   },
@@ -45,6 +51,7 @@
     "model_name": "jina-reranker-v2",
     "type": "normal",
     "language": ["en", "zh", "multilingual"],
+    "max_tokens": 1024,
     "model_id": "jinaai/jina-reranker-v2-base-multilingual",
     "model_revision": "298e48cada4a9318650d7fbd795f63827f884087"
   }
xinference/model/rerank/model_spec_modelscope.json
CHANGED

@@ -3,6 +3,7 @@
     "model_name": "bge-reranker-base",
     "type": "normal",
     "language": ["en", "zh"],
+    "max_tokens": 512,
     "model_id": "Xorbits/bge-reranker-base",
     "model_revision": "v0.0.1",
     "model_hub": "modelscope"
@@ -11,6 +12,7 @@
     "model_name": "bge-reranker-large",
     "type": "normal",
     "language": ["en", "zh"],
+    "max_tokens": 512,
     "model_id": "Xorbits/bge-reranker-large",
     "model_revision": "v0.0.1",
     "model_hub": "modelscope"
@@ -19,6 +21,7 @@
     "model_name": "bce-reranker-base_v1",
     "type": "normal",
     "language": ["en", "zh"],
+    "max_tokens": 512,
     "model_id": "maidalun/bce-reranker-base_v1",
     "model_revision": "v0.0.1",
     "model_hub": "modelscope"
@@ -26,6 +29,7 @@
   {
     "model_name": "bge-reranker-v2-m3",
     "type": "normal",
+    "max_tokens": 8192,
     "language": ["en", "zh", "multilingual"],
     "model_id": "AI-ModelScope/bge-reranker-v2-m3",
     "model_hub": "modelscope"
@@ -34,6 +38,7 @@
     "model_name": "bge-reranker-v2-gemma",
     "type": "LLM-based",
     "language": ["en", "zh", "multilingual"],
+    "max_tokens": 8192,
     "model_id": "AI-ModelScope/bge-reranker-v2-gemma",
     "model_hub": "modelscope"
   },
@@ -41,7 +46,8 @@
     "model_name": "bge-reranker-v2-minicpm-layerwise",
     "type": "LLM-based layerwise",
     "language": ["en", "zh", "multilingual"],
-    "
+    "max_tokens": 2048,
+    "model_id": "mirror013/bge-reranker-v2-minicpm-layerwise",
     "model_hub": "modelscope"
   }
 ]
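The `max_tokens` values added to both spec files above are read back in `RerankModel.load()` (see the `core.py` hunk earlier) and forwarded to the underlying reranker as `max_length`. A rough sketch of that mapping, using an illustrative spec object rather than the real `RerankModelSpec` class:

```python
# Illustrative sketch of how a spec's max_tokens becomes the model's max_length;
# the real path is RerankModel.load() shown in the core.py diff above.
from dataclasses import dataclass
from typing import Optional


@dataclass
class RerankSpecSketch:  # hypothetical stand-in for RerankModelSpec
    model_name: str
    max_tokens: Optional[int] = None


spec = RerankSpecSketch(model_name="bge-reranker-large", max_tokens=512)
# Mirrors the constructor call in the diff: Model(path, max_length=getattr(spec, "max_tokens"), ...)
model_kwargs = {"max_length": getattr(spec, "max_tokens")}
assert model_kwargs["max_length"] == 512
```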
xinference/model/utils.py
CHANGED
@@ -11,10 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-import functools
-import gc
-import inspect
 import json
 import logging
 import os
@@ -28,7 +24,7 @@ import numpy as np
 import torch
 
 from ..constants import XINFERENCE_CACHE_DIR, XINFERENCE_ENV_MODEL_SRC
-from ..device_utils import empty_cache, get_available_device, is_device_available
+from ..device_utils import get_available_device, is_device_available
 from .core import CacheableModelSpec
 
 logger = logging.getLogger(__name__)
@@ -357,32 +353,6 @@ def convert_float_to_int_or_str(model_size: float) -> Union[int, str]:
     return str(model_size)
 
 
-def ensure_cache_cleared(func: Callable):
-    assert not inspect.iscoroutinefunction(func) and not inspect.isasyncgenfunction(
-        func
-    )
-    if inspect.isgeneratorfunction(func):
-
-        @functools.wraps(func)
-        def inner(*args, **kwargs):
-            for obj in func(*args, **kwargs):
-                yield obj
-            gc.collect()
-            empty_cache()
-
-    else:
-
-        @functools.wraps(func)
-        def inner(*args, **kwargs):
-            try:
-                return func(*args, **kwargs)
-            finally:
-                gc.collect()
-                empty_cache()
-
-    return inner
-
-
 def set_all_random_seed(seed: int):
     random.seed(seed)
     np.random.seed(seed)
|