xinference 0.14.4.post1__py3-none-any.whl → 0.15.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of xinference has been flagged as potentially problematic by the registry.
- xinference/_compat.py +51 -0
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +5 -39
- xinference/client/restful/restful_client.py +3 -24
- xinference/conftest.py +1 -1
- xinference/constants.py +5 -0
- xinference/core/cache_tracker.py +1 -1
- xinference/core/chat_interface.py +8 -14
- xinference/core/event.py +1 -1
- xinference/core/model.py +82 -31
- xinference/core/scheduler.py +37 -37
- xinference/core/status_guard.py +1 -1
- xinference/core/supervisor.py +11 -10
- xinference/core/utils.py +80 -22
- xinference/core/worker.py +17 -16
- xinference/deploy/cmdline.py +8 -16
- xinference/deploy/local.py +1 -1
- xinference/deploy/supervisor.py +1 -1
- xinference/deploy/utils.py +1 -1
- xinference/deploy/worker.py +1 -1
- xinference/model/audio/cosyvoice.py +86 -41
- xinference/model/embedding/core.py +52 -31
- xinference/model/image/stable_diffusion/core.py +18 -1
- xinference/model/llm/__init__.py +21 -11
- xinference/model/llm/llama_cpp/core.py +16 -33
- xinference/model/llm/llm_family.json +619 -1297
- xinference/model/llm/llm_family.py +31 -52
- xinference/model/llm/llm_family_csghub.json +18 -35
- xinference/model/llm/llm_family_modelscope.json +573 -1119
- xinference/model/llm/lmdeploy/core.py +56 -88
- xinference/model/llm/mlx/core.py +46 -69
- xinference/model/llm/sglang/core.py +33 -18
- xinference/model/llm/transformers/chatglm.py +167 -305
- xinference/model/llm/transformers/cogvlm2.py +36 -63
- xinference/model/llm/transformers/cogvlm2_video.py +33 -223
- xinference/model/llm/transformers/core.py +49 -50
- xinference/model/llm/transformers/deepseek_vl.py +53 -96
- xinference/model/llm/transformers/glm4v.py +55 -111
- xinference/model/llm/transformers/intern_vl.py +39 -70
- xinference/model/llm/transformers/internlm2.py +32 -54
- xinference/model/llm/transformers/minicpmv25.py +22 -55
- xinference/model/llm/transformers/minicpmv26.py +158 -68
- xinference/model/llm/transformers/omnilmm.py +5 -28
- xinference/model/llm/transformers/qwen2_vl.py +208 -0
- xinference/model/llm/transformers/qwen_vl.py +34 -86
- xinference/model/llm/transformers/utils.py +32 -38
- xinference/model/llm/transformers/yi_vl.py +32 -72
- xinference/model/llm/utils.py +195 -489
- xinference/model/llm/vllm/core.py +153 -100
- xinference/model/rerank/core.py +41 -8
- xinference/model/rerank/model_spec.json +7 -0
- xinference/model/rerank/model_spec_modelscope.json +7 -1
- xinference/model/utils.py +1 -31
- xinference/thirdparty/cosyvoice/bin/export_jit.py +64 -0
- xinference/thirdparty/cosyvoice/bin/export_trt.py +8 -0
- xinference/thirdparty/cosyvoice/bin/inference.py +5 -2
- xinference/thirdparty/cosyvoice/cli/cosyvoice.py +38 -22
- xinference/thirdparty/cosyvoice/cli/model.py +139 -26
- xinference/thirdparty/cosyvoice/flow/flow.py +15 -9
- xinference/thirdparty/cosyvoice/flow/length_regulator.py +20 -1
- xinference/thirdparty/cosyvoice/hifigan/generator.py +8 -4
- xinference/thirdparty/cosyvoice/llm/llm.py +14 -13
- xinference/thirdparty/cosyvoice/transformer/attention.py +7 -3
- xinference/thirdparty/cosyvoice/transformer/decoder.py +1 -1
- xinference/thirdparty/cosyvoice/transformer/embedding.py +4 -3
- xinference/thirdparty/cosyvoice/transformer/encoder.py +4 -2
- xinference/thirdparty/cosyvoice/utils/common.py +36 -0
- xinference/thirdparty/cosyvoice/utils/file_utils.py +16 -0
- xinference/thirdparty/deepseek_vl/serve/assets/Kelpy-Codos.js +100 -0
- xinference/thirdparty/deepseek_vl/serve/assets/avatar.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/assets/custom.css +355 -0
- xinference/thirdparty/deepseek_vl/serve/assets/custom.js +22 -0
- xinference/thirdparty/deepseek_vl/serve/assets/favicon.ico +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/app.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/chart.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/mirror.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/pipeline.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/puzzle.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/rap.jpeg +0 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/base.yaml +87 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/firefly_gan_vq.yaml +34 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/lora/r_8_alpha_16.yaml +4 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/text2semantic_finetune.yaml +83 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text-data.proto +24 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/README.md +27 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/.gitignore +114 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/README.md +36 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/css/style.css +161 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/html/footer.html +11 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/js/animate.js +69 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/README.md +59 -0
- xinference/thirdparty/matcha/VERSION +1 -0
- xinference/thirdparty/matcha/hifigan/LICENSE +21 -0
- xinference/thirdparty/matcha/hifigan/README.md +101 -0
- xinference/thirdparty/omnilmm/LICENSE +201 -0
- xinference/thirdparty/whisper/__init__.py +156 -0
- xinference/thirdparty/whisper/__main__.py +3 -0
- xinference/thirdparty/whisper/assets/gpt2.tiktoken +50256 -0
- xinference/thirdparty/whisper/assets/mel_filters.npz +0 -0
- xinference/thirdparty/whisper/assets/multilingual.tiktoken +50257 -0
- xinference/thirdparty/whisper/audio.py +157 -0
- xinference/thirdparty/whisper/decoding.py +826 -0
- xinference/thirdparty/whisper/model.py +314 -0
- xinference/thirdparty/whisper/normalizers/__init__.py +2 -0
- xinference/thirdparty/whisper/normalizers/basic.py +76 -0
- xinference/thirdparty/whisper/normalizers/english.json +1741 -0
- xinference/thirdparty/whisper/normalizers/english.py +550 -0
- xinference/thirdparty/whisper/timing.py +386 -0
- xinference/thirdparty/whisper/tokenizer.py +395 -0
- xinference/thirdparty/whisper/transcribe.py +605 -0
- xinference/thirdparty/whisper/triton_ops.py +109 -0
- xinference/thirdparty/whisper/utils.py +316 -0
- xinference/thirdparty/whisper/version.py +1 -0
- xinference/types.py +7 -49
- xinference/web/ui/build/asset-manifest.json +6 -6
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/css/{main.4bafd904.css → main.632e9148.css} +2 -2
- xinference/web/ui/build/static/css/main.632e9148.css.map +1 -0
- xinference/web/ui/build/static/js/main.9cfafbd6.js +3 -0
- xinference/web/ui/build/static/js/{main.eb13fe95.js.LICENSE.txt → main.9cfafbd6.js.LICENSE.txt} +2 -0
- xinference/web/ui/build/static/js/main.9cfafbd6.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/01d6d198156bacbd436c51435edbd4b2cacd47a79db929105eba30f74b67d48d.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/10c69dc7a296779fcffedeff9393d832dfcb0013c36824adf623d3c518b801ff.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/59eb25f514afcc4fefd1b309d192b2455f1e0aec68a9de598ca4b2333fe2c774.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/68bede6d95bb5ef0b35bbb3ec5b8c937eaf6862c6cdbddb5ef222a7776aaf336.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/77d50223f3e734d4485cca538cb098a8c3a7a0a1a9f01f58cdda3af42fe1adf5.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/a56d5a642409a84988891089c98ca28ad0546432dfbae8aaa51bc5a280e1cdd2.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/d9ff696a3e3471f01b46c63d18af32e491eb5dc0e43cb30202c96871466df57f.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f5039ddbeb815c51491a1989532006b96fc3ae49c6c60e3c097f875b4ae915ae.json +1 -0
- xinference/web/ui/node_modules/.package-lock.json +37 -0
- xinference/web/ui/node_modules/a-sync-waterfall/package.json +21 -0
- xinference/web/ui/node_modules/nunjucks/node_modules/commander/package.json +48 -0
- xinference/web/ui/node_modules/nunjucks/package.json +112 -0
- xinference/web/ui/package-lock.json +38 -0
- xinference/web/ui/package.json +1 -0
- {xinference-0.14.4.post1.dist-info → xinference-0.15.0.dist-info}/METADATA +8 -8
- {xinference-0.14.4.post1.dist-info → xinference-0.15.0.dist-info}/RECORD +141 -87
- xinference/model/llm/transformers/llama_2.py +0 -108
- xinference/web/ui/build/static/css/main.4bafd904.css.map +0 -1
- xinference/web/ui/build/static/js/main.eb13fe95.js +0 -3
- xinference/web/ui/build/static/js/main.eb13fe95.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/0b11a5339468c13b2d31ac085e7effe4303259b2071abd46a0a8eb8529233a5e.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/213b5913e164773c2b0567455377765715f5f07225fbac77ad8e1e9dc9648a47.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/5c26a23b5eacf5b752a08531577ae3840bb247745ef9a39583dc2d05ba93a82a.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/978b57d1a04a701bc3fcfebc511f5f274eed6ed7eade67f6fb76c27d5fd9ecc8.json +0 -1
- {xinference-0.14.4.post1.dist-info → xinference-0.15.0.dist-info}/LICENSE +0 -0
- {xinference-0.14.4.post1.dist-info → xinference-0.15.0.dist-info}/WHEEL +0 -0
- {xinference-0.14.4.post1.dist-info → xinference-0.15.0.dist-info}/entry_points.txt +0 -0
- {xinference-0.14.4.post1.dist-info → xinference-0.15.0.dist-info}/top_level.txt +0 -0
xinference/model/llm/vllm/core.py
CHANGED
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import asyncio
-import json
 import logging
 import multiprocessing
 import os
@@ -24,9 +23,9 @@ from typing import (
     Any,
     AsyncGenerator,
     Dict,
-    Iterable,
     List,
     Optional,
+    Tuple,
     TypedDict,
     Union,
 )
@@ -34,18 +33,20 @@ from typing import (
 from ....types import (
     ChatCompletion,
     ChatCompletionChunk,
-    ChatCompletionMessage,
     Completion,
     CompletionChoice,
     CompletionChunk,
     CompletionUsage,
     LoRA,
-    ToolCallFunction,
-    ToolCalls,
 )
 from .. import LLM, LLMFamilyV1, LLMSpecV1
 from ..llm_family import CustomLLMFamilyV1
-from ..utils import
+from ..utils import (
+    QWEN_TOOL_CALL_FAMILY,
+    QWEN_TOOL_CALL_SYMBOLS,
+    ChatModelMixin,
+    generate_completion_chunk,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -363,23 +364,28 @@ class VLLMModel(LLM):
     @staticmethod
     def _convert_request_output_to_completion_chunk(
         request_id: str, model: str, request_output: "RequestOutput"
-    ) -> CompletionChunk:
+    ) -> Tuple[CompletionChunk, Optional[str]]:
         choices: List[CompletionChoice] = []
+        finish_reason = None
         for output in request_output.outputs:
             choices.append(
                 CompletionChoice(
                     text=output.text,
                     index=output.index,
                     logprobs=None,  # TODO: support logprobs.
-                    finish_reason=
+                    finish_reason=None,
                 )
             )
-        return CompletionChunk(
-            id=request_id,
-            object="text_completion",
-            created=int(time.time()),
-            model=model,
-            choices=choices,
+            finish_reason = output.finish_reason
+        return (
+            CompletionChunk(
+                id=request_id,
+                object="text_completion",
+                created=int(time.time()),
+                model=model,
+                choices=choices,
+            ),
+            finish_reason,
         )
 
     @staticmethod
@@ -420,6 +426,7 @@ class VLLMModel(LLM):
         prompt: Union[str, Dict[str, Any]],
         generate_config: Optional[Dict] = None,
         tools: object = False,
+        request_id: Optional[str] = None,
     ) -> Union[Completion, AsyncGenerator[CompletionChunk, None]]:
         try:
             from vllm.sampling_params import SamplingParams
@@ -454,7 +461,8 @@ class VLLMModel(LLM):
             else False
         )
         sampling_params = SamplingParams(**sanitized_generate_config)
-        request_id = str(uuid.uuid1())
+        if not request_id:
+            request_id = str(uuid.uuid1())
 
         assert self._engine is not None
         results_generator = self._engine.generate(
@@ -463,10 +471,14 @@ class VLLMModel(LLM):
 
         async def stream_results() -> AsyncGenerator[CompletionChunk, None]:
             previous_texts = [""] * sanitized_generate_config["n"]
-            tools_token_filter = ChatModelMixin._tools_token_filter(self.model_family)
             prompt_tokens, completion_tokens, total_tokens = 0, 0, 0
+            complete_response = ""
+            match_tool_call_tmp_results = []
+            is_match_tool_call = False
+            chunk = None
+            finish_reason = None
             async for _request_output in results_generator:
-                chunk = self._convert_request_output_to_completion_chunk(
+                chunk, finish_reason = self._convert_request_output_to_completion_chunk(
                     request_id=request_id,
                     model=self.model_uid,
                     request_output=_request_output,
@@ -476,40 +488,8 @@ class VLLMModel(LLM):
                     delta = choice["text"][len(previous_texts[i]) :]
                     previous_texts[i] = choice["text"]
                     choice["text"] = delta
+                    complete_response += delta
 
-                if tools:
-                    # only handle the first choice
-                    choice = chunk["choices"][0]
-                    if choice["finish_reason"] is not None:
-                        # use previous text for evaluation temporarily
-                        choice_delta = choice["text"]
-                        choice["text"] = previous_texts[0]
-                        _content, func, args = ChatModelMixin._eval_tool_arguments(
-                            self.model_family, chunk, tools
-                        )
-                        choice["text"] = tools_token_filter(
-                            tokens=previous_texts[0], delta=choice_delta
-                        )
-                        if func is not None:
-                            choice["text"] = None
-                            choice["finish_reason"] = "tool_calls"
-                            choice["tool_calls"] = [
-                                ToolCalls(
-                                    id=str(uuid.uuid4()),
-                                    type="function",
-                                    function=ToolCallFunction(
-                                        name=func,
-                                        arguments=json.dumps(args, ensure_ascii=False),
-                                    ),
-                                )
-                            ]
-                        else:
-                            # use a filter function to skip Qwen's react thought process
-                            choice["text"] = tools_token_filter(
-                                tokens=previous_texts[0], delta=choice["text"]
-                            )
-                            if not choice["text"]:
-                                continue
                 prompt_tokens = len(_request_output.prompt_token_ids)
                 completion_tokens = sum(
                     len(output.token_ids) for output in _request_output.outputs
@@ -520,7 +500,59 @@ class VLLMModel(LLM):
                     completion_tokens=completion_tokens,
                     total_tokens=total_tokens,
                 )
+
+                if tools:
+                    """
+                    The qwen2 tool call returns format like this:
+                    <tool_call>
+                    {...}
+                    </tool_call>
+                    Here is to match this.
+                    """
+                    if (len(QWEN_TOOL_CALL_SYMBOLS[0]) > len(complete_response)) and (
+                        not QWEN_TOOL_CALL_SYMBOLS[0].startswith(complete_response)
+                    ):
+                        for c in match_tool_call_tmp_results:
+                            yield c
+                        match_tool_call_tmp_results.clear()
+                        yield chunk
+                    elif (len(QWEN_TOOL_CALL_SYMBOLS[0]) > len(complete_response)) and (
+                        QWEN_TOOL_CALL_SYMBOLS[0].startswith(complete_response)
+                    ):
+                        match_tool_call_tmp_results.append(chunk)
+                    else:
+                        assert len(QWEN_TOOL_CALL_SYMBOLS[0]) <= len(complete_response)
+                        if not is_match_tool_call and complete_response.startswith(
+                            QWEN_TOOL_CALL_SYMBOLS[0]
+                        ):
+                            is_match_tool_call = True
+                            match_tool_call_tmp_results.clear()
+
+                        if not is_match_tool_call:
+                            for c in match_tool_call_tmp_results:
+                                yield c
+                            match_tool_call_tmp_results.clear()
+                            yield chunk
+                        else:
+                            chunk["choices"][0]["text"] = complete_response
+                else:
+                    yield chunk
+
+            if is_match_tool_call:
+                assert chunk is not None
                 yield chunk
+
+            # match OpenAI API stream
+            yield generate_completion_chunk(
+                chunk_text="",
+                finish_reason=finish_reason,
+                chunk_id=request_id,
+                model_uid=self.model_uid,
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=total_tokens,
+            )
+
             if include_usage:
                 chunk = CompletionChunk(
                     id=request_id,
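Note: the added streaming branch above buffers chunks while the accumulated text could still be a prefix of the Qwen "<tool_call>" opening tag, and only flushes them once a tool call is ruled out. A minimal standalone sketch of the same buffering idea (illustrative names, not the actual xinference API):

TOOL_CALL_START = "<tool_call>"  # stand-in for QWEN_TOOL_CALL_SYMBOLS[0]

def buffer_tool_call(deltas):
    """Yield text deltas immediately unless they may open a <tool_call> block."""
    buffered, complete, in_tool_call = [], "", False
    for delta in deltas:
        complete += delta
        if len(complete) < len(TOOL_CALL_START):
            if TOOL_CALL_START.startswith(complete):
                buffered.append(delta)  # still a possible prefix: hold it back
            else:
                yield from buffered     # definitely not a tool call: flush and stream
                buffered.clear()
                yield delta
        else:
            if complete.startswith(TOOL_CALL_START):
                in_tool_call = True
                buffered.clear()        # swallow deltas; the whole call is emitted at the end
            else:
                yield from buffered
                buffered.clear()
                yield delta
    if in_tool_call:
        yield complete                  # emit the accumulated <tool_call>...</tool_call> once
    else:
        yield from buffered             # stream ended while still only a possible prefix

# list(buffer_tool_call(["<tool", "_call>", '{"name": "f"}', "</tool_call>"]))
# -> ['<tool_call>{"name": "f"}</tool_call>']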
@@ -586,59 +618,74 @@ class VLLMChatModel(VLLMModel, ChatModelMixin):
     ) -> Dict:
         if not generate_config:
             generate_config = {}
-        if self.model_family.
-
-
-
-
-
-                "stop_token_ids",
-                self.model_family.prompt_style.stop_token_ids.copy(),
-            )
+        if not generate_config.get("stop") and self.model_family.stop:
+            generate_config["stop"] = self.model_family.stop.copy()
+        if (
+            not generate_config.get("stop_token_ids")
+            and self.model_family.stop_token_ids
+        ):
+            generate_config["stop_token_ids"] = self.model_family.stop_token_ids.copy()
         return generate_config
 
+    @staticmethod
+    def is_tool_call_chunk(chunk):
+        return chunk["choices"][0]["text"].startswith(QWEN_TOOL_CALL_SYMBOLS[0])
+
+    async def _async_to_tool_completion_chunks(
+        self,
+        chunks: AsyncGenerator[CompletionChunk, None],
+    ) -> AsyncGenerator[ChatCompletionChunk, None]:
+        i = 0
+        async for chunk in chunks:
+            if i == 0:
+                yield self._get_first_chat_completion_chunk(chunk)
+            # usage
+            choices = chunk.get("choices")
+            if not choices:
+                yield self._get_final_chat_completion_chunk(chunk)
+            else:
+                if self.is_tool_call_chunk(chunk):
+                    yield self._tool_calls_completion_chunk(
+                        self.model_family, self.model_uid, chunk
+                    )
+                else:
+                    yield self._to_chat_completion_chunk(chunk)
+            i += 1
+
     async def async_chat(
         self,
-
-        system_prompt: Optional[str] = None,
-        chat_history: Optional[List[ChatCompletionMessage]] = None,
+        messages: List[Dict],
         generate_config: Optional[Dict] = None,
+        request_id: Optional[str] = None,
     ) -> Union[ChatCompletion, AsyncGenerator[ChatCompletionChunk, None]]:
-        assert self.model_family.prompt_style is not None
-        prompt_style = self.model_family.prompt_style.copy()
-        if system_prompt:
-            prompt_style.system_prompt = system_prompt
-        chat_history = chat_history or []
         tools = generate_config.pop("tools", []) if generate_config else None
-        full_prompt = self.get_prompt(prompt, chat_history, prompt_style, tools=tools)
-
-        generate_config = self._sanitize_chat_config(generate_config)
-        # TODO(codingl2k1): qwen hacky to set stop for function call.
         model_family = self.model_family.model_family or self.model_family.model_name
+        full_context_kwargs = {}
         if tools and model_family in QWEN_TOOL_CALL_FAMILY:
-
-
-
-
-
-                generate_config["stop"] = list(stop) + ["Observation:"]
-            else:
-                generate_config["stop"] = "Observation:"
+            full_context_kwargs["tools"] = tools
+        assert self.model_family.chat_template is not None
+        full_prompt = self.get_full_context(
+            messages, self.model_family.chat_template, **full_context_kwargs
+        )
 
+        generate_config = self._sanitize_chat_config(generate_config)
         stream = generate_config.get("stream", None)
 
         if stream:
-            agen = await self.async_generate(
+            agen = await self.async_generate(
+                full_prompt, generate_config, tools, request_id=request_id
+            )
             assert isinstance(agen, AsyncGenerator)
+            if tools:
+                return self._async_to_tool_completion_chunks(agen)
             return self._async_to_chat_completion_chunks(agen)
         else:
-            c = await self.async_generate(
+            c = await self.async_generate(
+                full_prompt, generate_config, request_id=request_id
+            )
             assert not isinstance(c, AsyncGenerator)
             if tools:
-                return self._tool_calls_completion(
-                    self.model_family, self.model_uid, c, tools
-                )
+                return self._tool_calls_completion(self.model_family, self.model_uid, c)
             return self._to_chat_completion(c)
 
 
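Note: async_chat now takes an OpenAI-style messages list, with tools passed inside generate_config, instead of the old prompt/system_prompt/chat_history triple. A rough sketch of what a call looks like after this change (the tool schema is illustrative and the surrounding model setup is omitted):

async def ask(model):
    # `model` is assumed to be an already-launched VLLMChatModel instance
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is the weather in Paris?"},
    ]
    generate_config = {
        "max_tokens": 512,
        # tools ride inside generate_config and are popped off before generation;
        # this particular schema is illustrative, not taken from the diff
        "tools": [
            {
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "parameters": {
                        "type": "object",
                        "properties": {"city": {"type": "string"}},
                    },
                },
            }
        ],
    }
    return await model.async_chat(messages, generate_config=generate_config)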
@@ -666,28 +713,30 @@ class VLLMVisionModel(VLLMModel, ChatModelMixin):
         self,
         generate_config: Optional[Dict] = None,
     ) -> Dict:
+        from ..utils import get_stop_token_ids_from_config_file
+
         if not generate_config:
             generate_config = {}
-        if
-
-
-
-
-
+        if generate_config.get("stop_token_ids", None) is None:
+            stop_token_ids = get_stop_token_ids_from_config_file(self.model_path)
+            if stop_token_ids is not None:
+                generate_config.setdefault("stop_token_ids", stop_token_ids)
+            else:
+                if self.model_family.stop_token_ids:
+                    generate_config.setdefault(
+                        "stop_token_ids", self.model_family.stop_token_ids.copy()
+                    )
         return generate_config
 
     async def async_chat(
         self,
-
-        system_prompt: Optional[str] = None,
-        chat_history: Optional[List[ChatCompletionMessage]] = None,
+        messages: List[Dict],
         generate_config: Optional[Dict] = None,
+        request_id: Optional[str] = None,
     ) -> Union[ChatCompletion, AsyncGenerator[ChatCompletionChunk, None]]:
         # only support single image, waiting vllm support multi images
-
-
-        chat_history = chat_history or []
-        prompt, images = self.get_prompt(prompt, chat_history, prompt_style)
+        model_family = self.model_family.model_family or self.model_family.model_name
+        prompt, images = self.get_specific_prompt(model_family, messages)
 
         if len(images) == 0:
             inputs = {
@@ -703,10 +752,14 @@ class VLLMVisionModel(VLLMModel, ChatModelMixin):
         stream = generate_config.get("stream", None)
 
         if stream:
-            agen = await self.async_generate(
+            agen = await self.async_generate(
+                inputs, generate_config, request_id=request_id
+            )
             assert isinstance(agen, AsyncGenerator)
             return self._async_to_chat_completion_chunks(agen)
         else:
-            c = await self.async_generate(
+            c = await self.async_generate(
+                inputs, generate_config, request_id=request_id
+            )
             assert not isinstance(c, AsyncGenerator)
             return self._to_chat_completion(c)
xinference/model/rerank/core.py
CHANGED
@@ -15,6 +15,7 @@
 import gc
 import logging
 import os
+import threading
 import uuid
 from collections import defaultdict
 from collections.abc import Sequence
@@ -22,6 +23,7 @@ from typing import Dict, List, Literal, Optional, Tuple
 
 import numpy as np
 import torch
+import torch.nn as nn
 
 from ...constants import XINFERENCE_CACHE_DIR
 from ...device_utils import empty_cache
@@ -49,6 +51,7 @@ class RerankModelSpec(CacheableModelSpec):
     model_name: str
     language: List[str]
     type: Optional[str] = "unknown"
+    max_tokens: Optional[int]
     model_id: str
     model_revision: Optional[str]
     model_hub: str = "huggingface"
@@ -102,6 +105,30 @@ def generate_rerank_description(model_spec: RerankModelSpec) -> Dict[str, List[D
     return res
 
 
+class _ModelWrapper:
+    def __init__(self, module: nn.Module):
+        self._module = module
+        self._local_data = threading.local()
+
+    @property
+    def n_tokens(self):
+        return getattr(self._local_data, "n_tokens", 0)
+
+    @n_tokens.setter
+    def n_tokens(self, new_n_tokens):
+        self._local_data.n_tokens = new_n_tokens
+
+    def __getattr__(self, attr):
+        return getattr(self._module, attr)
+
+    def __call__(self, **kwargs):
+        attention_mask = kwargs["attention_mask"]
+        # when batching, the attention mask 1 means there is a token
+        # thus we just sum up it to get the total number of tokens
+        self.n_tokens += attention_mask.sum().item()
+        return self._module(**kwargs)
+
+
 class RerankModel:
     def __init__(
         self,
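Note: _ModelWrapper proxies every attribute to the wrapped module and only intercepts __call__, summing the attention mask so the token count accumulates on each forward pass. The same pattern in a self-contained sketch, using a stand-in module rather than the real reranker:

import threading

import torch
import torch.nn as nn


class TokenCountingWrapper:
    """Proxy a module and count tokens seen via attention_mask (sketch of _ModelWrapper)."""

    def __init__(self, module: nn.Module):
        self._module = module
        self._local = threading.local()

    @property
    def n_tokens(self):
        return getattr(self._local, "n_tokens", 0)

    @n_tokens.setter
    def n_tokens(self, value):
        self._local.n_tokens = value

    def __getattr__(self, name):
        return getattr(self._module, name)  # everything else falls through to the module

    def __call__(self, **kwargs):
        # attention_mask is 1 wherever a real (non-padding) token sits,
        # so its sum is the number of tokens in this batch
        self.n_tokens += kwargs["attention_mask"].sum().item()
        return self._module(**kwargs)


class _DummyEncoder(nn.Module):
    def forward(self, input_ids, attention_mask):
        return input_ids


wrapped = TokenCountingWrapper(_DummyEncoder())
wrapped(
    input_ids=torch.zeros(2, 4, dtype=torch.long),
    attention_mask=torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]]),
)
print(wrapped.n_tokens)  # 5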
@@ -166,6 +193,7 @@ class RerankModel:
                 self._model_path,
                 device=self._device,
                 trust_remote_code=True,
+                max_length=getattr(self._model_spec, "max_tokens"),
                 **self._model_config,
             )
             if self._use_fp16:
@@ -189,6 +217,8 @@
 
                 raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
             self._model = FlagReranker(self._model_path, use_fp16=self._use_fp16)
+            # Wrap transformers model to record number of tokens
+            self._model.model = _ModelWrapper(self._model.model)
 
     def rerank(
         self,
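Note: the spec's new max_tokens value is handed to the underlying cross-encoder as max_length, so query/document pairs are truncated to the window the model actually supports. A rough equivalent using sentence-transformers directly (this assumes the "normal" rerank type wraps a CrossEncoder, as the predict call in the next hunk suggests):

from sentence_transformers import CrossEncoder

# bge-reranker-base is one of the specs that now declares "max_tokens": 512
model = CrossEncoder("BAAI/bge-reranker-base", max_length=512)
scores = model.predict([["capital of France", "Paris is the capital of France."]])
print(scores)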
@@ -200,17 +230,14 @@
         return_len: Optional[bool],
         **kwargs,
     ) -> Rerank:
-        self._counter += 1
-        if self._counter % RERANK_EMPTY_CACHE_COUNT == 0:
-            logger.debug("Empty rerank cache.")
-            gc.collect()
-            empty_cache()
         assert self._model is not None
         if kwargs:
             raise ValueError("rerank hasn't support extra parameter.")
         if max_chunks_per_doc is not None:
             raise ValueError("rerank hasn't support `max_chunks_per_doc` parameter.")
         sentence_combinations = [[query, doc] for doc in documents]
+        # reset n tokens
+        self._model.model.n_tokens = 0
         if self._model_spec.type == "normal":
             similarity_scores = self._model.predict(
                 sentence_combinations, convert_to_numpy=False, convert_to_tensor=True
@@ -245,9 +272,7 @@
                 for arg in sim_scores_argsort
             ]
         if return_len:
-
-            input_len = sum([len(tokenizer.tokenize(t)) for t in documents])
-
+            input_len = self._model.model.n_tokens
             # Rerank Model output is just score or documents
             # while return_documents = True
             output_len = input_len
@@ -265,6 +290,14 @@
             "warnings": None,
         }
 
+        del similarity_scores
+        # clear cache if possible
+        self._counter += 1
+        if self._counter % RERANK_EMPTY_CACHE_COUNT == 0:
+            logger.debug("Empty rerank cache.")
+            gc.collect()
+            empty_cache()
+
         return Rerank(id=str(uuid.uuid1()), results=docs, meta=metadata)
 
 
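Note: with the wrapper in place, return_len reporting no longer re-tokenizes every document; it reads the counter accumulated during scoring. A hedged sketch of how this surfaces through the RESTful client (endpoint and model UID are placeholders, and this assumes the client passes return_len through unchanged):

from xinference.client import RESTfulClient

client = RESTfulClient("http://127.0.0.1:9997")   # placeholder endpoint
model = client.get_model("my-rerank-model-uid")   # placeholder model UID

result = model.rerank(
    documents=["Paris is the capital of France.", "Berlin is in Germany."],
    query="capital of France",
    return_len=True,
)
# the token counts in the returned metadata now come from the attention-mask counter
print(result["meta"])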
xinference/model/rerank/model_spec.json
CHANGED
@@ -3,6 +3,7 @@
     "model_name": "bge-reranker-large",
     "type": "normal",
     "language": ["en", "zh"],
+    "max_tokens": 512,
     "model_id": "BAAI/bge-reranker-large",
     "model_revision": "27c9168d479987529781de8474dff94d69beca11"
   },
@@ -10,6 +11,7 @@
     "model_name": "bge-reranker-base",
     "type": "normal",
     "language": ["en", "zh"],
+    "max_tokens": 512,
     "model_id": "BAAI/bge-reranker-base",
     "model_revision": "465b4b7ddf2be0a020c8ad6e525b9bb1dbb708ae"
   },
@@ -17,6 +19,7 @@
     "model_name": "bce-reranker-base_v1",
     "type": "normal",
     "language": ["en", "zh"],
+    "max_tokens": 512,
     "model_id": "maidalun1020/bce-reranker-base_v1",
     "model_revision": "eaa31a577a0574e87a08959bd229ca14ce1b5496"
   },
@@ -24,6 +27,7 @@
     "model_name": "bge-reranker-v2-m3",
     "type": "normal",
     "language": ["en", "zh", "multilingual"],
+    "max_tokens": 8192,
     "model_id": "BAAI/bge-reranker-v2-m3",
     "model_revision": "12e974610ba9083ed95f3edf08d7e899581f4de4"
   },
@@ -31,6 +35,7 @@
     "model_name": "bge-reranker-v2-gemma",
     "type": "LLM-based",
     "language": ["en", "zh", "multilingual"],
+    "max_tokens": 8192,
     "model_id": "BAAI/bge-reranker-v2-gemma",
     "model_revision": "1787044f8b6fb740a9de4557c3a12377f84d9e17"
   },
@@ -38,6 +43,7 @@
     "model_name": "bge-reranker-v2-minicpm-layerwise",
     "type": "LLM-based layerwise",
     "language": ["en", "zh", "multilingual"],
+    "max_tokens": 2048,
     "model_id": "BAAI/bge-reranker-v2-minicpm-layerwise",
     "model_revision": "47b5332b296c4d8cb6ee2c60502cc62a0d708881"
   },
@@ -45,6 +51,7 @@
     "model_name": "jina-reranker-v2",
     "type": "normal",
     "language": ["en", "zh", "multilingual"],
+    "max_tokens": 1024,
     "model_id": "jinaai/jina-reranker-v2-base-multilingual",
     "model_revision": "298e48cada4a9318650d7fbd795f63827f884087"
   }
xinference/model/rerank/model_spec_modelscope.json
CHANGED
@@ -3,6 +3,7 @@
     "model_name": "bge-reranker-base",
     "type": "normal",
     "language": ["en", "zh"],
+    "max_tokens": 512,
     "model_id": "Xorbits/bge-reranker-base",
     "model_revision": "v0.0.1",
     "model_hub": "modelscope"
@@ -11,6 +12,7 @@
     "model_name": "bge-reranker-large",
     "type": "normal",
     "language": ["en", "zh"],
+    "max_tokens": 512,
     "model_id": "Xorbits/bge-reranker-large",
     "model_revision": "v0.0.1",
     "model_hub": "modelscope"
@@ -19,6 +21,7 @@
     "model_name": "bce-reranker-base_v1",
     "type": "normal",
     "language": ["en", "zh"],
+    "max_tokens": 512,
     "model_id": "maidalun/bce-reranker-base_v1",
     "model_revision": "v0.0.1",
     "model_hub": "modelscope"
@@ -26,6 +29,7 @@
   {
     "model_name": "bge-reranker-v2-m3",
     "type": "normal",
+    "max_tokens": 8192,
     "language": ["en", "zh", "multilingual"],
     "model_id": "AI-ModelScope/bge-reranker-v2-m3",
     "model_hub": "modelscope"
@@ -34,6 +38,7 @@
     "model_name": "bge-reranker-v2-gemma",
     "type": "LLM-based",
     "language": ["en", "zh", "multilingual"],
+    "max_tokens": 8192,
     "model_id": "AI-ModelScope/bge-reranker-v2-gemma",
     "model_hub": "modelscope"
   },
@@ -41,7 +46,8 @@
     "model_name": "bge-reranker-v2-minicpm-layerwise",
     "type": "LLM-based layerwise",
     "language": ["en", "zh", "multilingual"],
-    "
+    "max_tokens": 2048,
+    "model_id": "mirror013/bge-reranker-v2-minicpm-layerwise",
     "model_hub": "modelscope"
   }
 ]
xinference/model/utils.py
CHANGED
@@ -11,10 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-import functools
-import gc
-import inspect
 import json
 import logging
 import os
@@ -28,7 +24,7 @@ import numpy as np
 import torch
 
 from ..constants import XINFERENCE_CACHE_DIR, XINFERENCE_ENV_MODEL_SRC
-from ..device_utils import
+from ..device_utils import get_available_device, is_device_available
 from .core import CacheableModelSpec
 
 logger = logging.getLogger(__name__)
@@ -357,32 +353,6 @@ def convert_float_to_int_or_str(model_size: float) -> Union[int, str]:
     return str(model_size)
 
 
-def ensure_cache_cleared(func: Callable):
-    assert not inspect.iscoroutinefunction(func) and not inspect.isasyncgenfunction(
-        func
-    )
-    if inspect.isgeneratorfunction(func):
-
-        @functools.wraps(func)
-        def inner(*args, **kwargs):
-            for obj in func(*args, **kwargs):
-                yield obj
-            gc.collect()
-            empty_cache()
-
-    else:
-
-        @functools.wraps(func)
-        def inner(*args, **kwargs):
-            try:
-                return func(*args, **kwargs)
-            finally:
-                gc.collect()
-                empty_cache()
-
-    return inner
-
-
 def set_all_random_seed(seed: int):
     random.seed(seed)
     np.random.seed(seed)