xinference 1.3.1.post1__py3-none-any.whl → 1.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic. Click here for more details.
- xinference/_compat.py +1 -0
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +4 -0
- xinference/core/chat_interface.py +1 -1
- xinference/core/model.py +23 -3
- xinference/core/supervisor.py +6 -0
- xinference/core/worker.py +54 -11
- xinference/model/llm/__init__.py +7 -2
- xinference/model/llm/core.py +1 -0
- xinference/model/llm/llama_cpp/core.py +50 -15
- xinference/model/llm/llm_family.json +388 -13
- xinference/model/llm/llm_family_modelscope.json +373 -14
- xinference/model/llm/mlx/core.py +15 -11
- xinference/model/llm/reasoning_parser.py +17 -9
- xinference/model/llm/sglang/core.py +112 -12
- xinference/model/llm/transformers/core.py +4 -2
- xinference/model/llm/transformers/deepseek_vl.py +1 -1
- xinference/model/llm/transformers/deepseek_vl2.py +287 -0
- xinference/model/llm/transformers/gemma3.py +185 -0
- xinference/model/llm/transformers/intern_vl.py +0 -2
- xinference/model/llm/utils.py +62 -42
- xinference/model/llm/vllm/core.py +157 -11
- xinference/model/llm/vllm/distributed_executor.py +314 -0
- xinference/model/rerank/core.py +16 -11
- xinference/thirdparty/deepseek_vl2/__init__.py +31 -0
- xinference/thirdparty/deepseek_vl2/models/__init__.py +26 -0
- xinference/thirdparty/deepseek_vl2/models/configuration_deepseek.py +210 -0
- xinference/thirdparty/deepseek_vl2/models/conversation.py +310 -0
- xinference/thirdparty/deepseek_vl2/models/modeling_deepseek.py +1975 -0
- xinference/thirdparty/deepseek_vl2/models/modeling_deepseek_vl_v2.py +697 -0
- xinference/thirdparty/deepseek_vl2/models/processing_deepseek_vl_v2.py +675 -0
- xinference/thirdparty/deepseek_vl2/models/siglip_vit.py +661 -0
- xinference/thirdparty/deepseek_vl2/serve/__init__.py +0 -0
- xinference/thirdparty/deepseek_vl2/serve/app_modules/__init__.py +0 -0
- xinference/thirdparty/deepseek_vl2/serve/app_modules/gradio_utils.py +83 -0
- xinference/thirdparty/deepseek_vl2/serve/app_modules/overwrites.py +81 -0
- xinference/thirdparty/deepseek_vl2/serve/app_modules/presets.py +115 -0
- xinference/thirdparty/deepseek_vl2/serve/app_modules/utils.py +333 -0
- xinference/thirdparty/deepseek_vl2/serve/assets/Kelpy-Codos.js +100 -0
- xinference/thirdparty/deepseek_vl2/serve/assets/avatar.png +0 -0
- xinference/thirdparty/deepseek_vl2/serve/assets/custom.css +355 -0
- xinference/thirdparty/deepseek_vl2/serve/assets/custom.js +22 -0
- xinference/thirdparty/deepseek_vl2/serve/assets/favicon.ico +0 -0
- xinference/thirdparty/deepseek_vl2/serve/assets/simsun.ttc +0 -0
- xinference/thirdparty/deepseek_vl2/serve/inference.py +197 -0
- xinference/thirdparty/deepseek_vl2/utils/__init__.py +18 -0
- xinference/thirdparty/deepseek_vl2/utils/io.py +80 -0
- xinference/types.py +2 -2
- xinference/web/ui/build/asset-manifest.json +6 -6
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/css/main.b494ae7e.css +2 -0
- xinference/web/ui/build/static/css/main.b494ae7e.css.map +1 -0
- xinference/web/ui/build/static/js/main.5ca4eea1.js +3 -0
- xinference/web/ui/build/static/js/main.5ca4eea1.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/0f0967acaec5df1d45b80010949c258d64297ebbb0f44b8bb3afcbd45c6f0ec4.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/68249645124f37d01eef83b1d897e751f895bea919b6fb466f907c1f87cebc84.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/cc97b49285d7717c63374766c789141a4329a04582ab32756d7e0e614d4c5c7f.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f199e8173f6409a5802ed44acb95f218388131136504b2e9132129e150c92f9a.json +1 -0
- xinference/web/ui/src/locales/en.json +2 -2
- xinference/web/ui/src/locales/zh.json +1 -1
- {xinference-1.3.1.post1.dist-info → xinference-1.4.1.dist-info}/METADATA +4 -4
- {xinference-1.3.1.post1.dist-info → xinference-1.4.1.dist-info}/RECORD +67 -41
- xinference/web/ui/build/static/css/main.f8177338.css +0 -2
- xinference/web/ui/build/static/css/main.f8177338.css.map +0 -1
- xinference/web/ui/build/static/js/main.55b70cb7.js +0 -3
- xinference/web/ui/build/static/js/main.55b70cb7.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/2deac8d5636974533e3714f34e94fc754f9153a07c6ee11e72846cb8eae47e4b.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/db16a983bc08a05f0439cc61ca0840e49e1d8400eef678909f16c032a418a3d6.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/e23d476fcbf6fd69c8986bf82133d257d28aa8fc9a5cab231d81c1c75c58cd99.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/e7a8c37fda8725cab69c7ef8c627060bd7fc806adc67e00fe628ba148cb86d7f.json +0 -1
- /xinference/web/ui/build/static/js/{main.55b70cb7.js.LICENSE.txt → main.5ca4eea1.js.LICENSE.txt} +0 -0
- {xinference-1.3.1.post1.dist-info → xinference-1.4.1.dist-info}/LICENSE +0 -0
- {xinference-1.3.1.post1.dist-info → xinference-1.4.1.dist-info}/WHEEL +0 -0
- {xinference-1.3.1.post1.dist-info → xinference-1.4.1.dist-info}/entry_points.txt +0 -0
- {xinference-1.3.1.post1.dist-info → xinference-1.4.1.dist-info}/top_level.txt +0 -0
|
@@ -25,6 +25,7 @@ from xoscar.utils import get_next_port
|
|
|
25
25
|
from ....types import (
|
|
26
26
|
ChatCompletion,
|
|
27
27
|
ChatCompletionChunk,
|
|
28
|
+
ChatCompletionMessage,
|
|
28
29
|
Completion,
|
|
29
30
|
CompletionChoice,
|
|
30
31
|
CompletionChunk,
|
|
@@ -94,6 +95,7 @@ SGLANG_SUPPORTED_CHAT_MODELS = [
|
|
|
94
95
|
"mixtral-instruct-v0.1",
|
|
95
96
|
"gemma-it",
|
|
96
97
|
"gemma-2-it",
|
|
98
|
+
"gemma-3-1b-it",
|
|
97
99
|
"deepseek-v2.5",
|
|
98
100
|
"deepseek-v2-chat",
|
|
99
101
|
"deepseek-v2-chat-0628",
|
|
@@ -106,6 +108,12 @@ SGLANG_SUPPORTED_CHAT_MODELS = [
|
|
|
106
108
|
"deepseek-v3",
|
|
107
109
|
"deepseek-r1",
|
|
108
110
|
]
|
|
111
|
+
# Vision-language models SGLANGVisionModel.match() accepts: a built-in family's
# model_name (or a custom family's model_family) must appear in this list.
SGLANG_SUPPORTED_VISION_MODEL_LIST = [
    "qwen2.5-vl-instruct", "gemma-3-it", "MiniCPM-V", "llama-3.2-vision-instruct"
]
|
|
109
117
|
|
|
110
118
|
|
|
111
119
|
class SGLANGModel(LLM):
|
|
@@ -301,10 +309,6 @@ class SGLANGModel(LLM):
|
|
|
301
309
|
if llm_spec.model_format == "pytorch":
|
|
302
310
|
if quantization != "none" and not (quantization is None):
|
|
303
311
|
return False
|
|
304
|
-
if llm_spec.model_format in ["gptq", "awq"]:
|
|
305
|
-
# Currently, only 4-bit weight quantization is supported for GPTQ, but got 8 bits.
|
|
306
|
-
if "4" not in quantization:
|
|
307
|
-
return False
|
|
308
312
|
if isinstance(llm_family, CustomLLMFamilyV1):
|
|
309
313
|
if llm_family.model_family not in SGLANG_SUPPORTED_MODELS:
|
|
310
314
|
return False
|
|
@@ -369,12 +373,18 @@ class SGLANGModel(LLM):
|
|
|
369
373
|
sampling_params.pop("lora_name", None)
|
|
370
374
|
return sampling_params
|
|
371
375
|
|
|
372
|
-
async def _stream_generate(
|
|
376
|
+
async def _stream_generate(
|
|
377
|
+
self,
|
|
378
|
+
prompt: str,
|
|
379
|
+
image_data: Optional[Union[List[str], str]] = None,
|
|
380
|
+
**sampling_params,
|
|
381
|
+
):
|
|
373
382
|
import aiohttp
|
|
374
383
|
|
|
375
384
|
sampling_params = self._filter_sampling_params(sampling_params)
|
|
376
385
|
json_data = {
|
|
377
386
|
"text": prompt,
|
|
387
|
+
"image_data": image_data,
|
|
378
388
|
"sampling_params": sampling_params,
|
|
379
389
|
"stream": True,
|
|
380
390
|
}
|
|
@@ -402,12 +412,18 @@ class SGLANGModel(LLM):
|
|
|
402
412
|
if need_stop:
|
|
403
413
|
break
|
|
404
414
|
|
|
405
|
-
async def _non_stream_generate(
|
|
415
|
+
async def _non_stream_generate(
|
|
416
|
+
self,
|
|
417
|
+
prompt: str,
|
|
418
|
+
image_data: Optional[Union[List[str], str]] = None,
|
|
419
|
+
**sampling_params,
|
|
420
|
+
) -> dict:
|
|
406
421
|
import aiohttp
|
|
407
422
|
|
|
408
423
|
sampling_params = self._filter_sampling_params(sampling_params)
|
|
409
424
|
json_data = {
|
|
410
425
|
"text": prompt,
|
|
426
|
+
"image_data": image_data,
|
|
411
427
|
"sampling_params": sampling_params,
|
|
412
428
|
}
|
|
413
429
|
async with aiohttp.ClientSession(trust_env=True) as session:
|
|
@@ -419,6 +435,7 @@ class SGLANGModel(LLM):
|
|
|
419
435
|
async def async_generate(
|
|
420
436
|
self,
|
|
421
437
|
prompt: str,
|
|
438
|
+
image_data: Optional[Union[List[str], str]] = None,
|
|
422
439
|
generate_config: Optional[SGLANGGenerateConfig] = None,
|
|
423
440
|
request_id: Optional[str] = None,
|
|
424
441
|
) -> Union[Completion, AsyncGenerator[CompletionChunk, None]]:
|
|
@@ -437,7 +454,9 @@ class SGLANGModel(LLM):
|
|
|
437
454
|
if not request_id:
|
|
438
455
|
request_id = str(uuid.uuid1())
|
|
439
456
|
if not stream:
|
|
440
|
-
state = await self._non_stream_generate(
|
|
457
|
+
state = await self._non_stream_generate(
|
|
458
|
+
prompt, image_data, **sanitized_generate_config
|
|
459
|
+
)
|
|
441
460
|
return self._convert_state_to_completion(
|
|
442
461
|
request_id,
|
|
443
462
|
model=self.model_uid,
|
|
@@ -450,7 +469,7 @@ class SGLANGModel(LLM):
|
|
|
450
469
|
prompt_tokens, completion_tokens, total_tokens = 0, 0, 0
|
|
451
470
|
finish_reason = None
|
|
452
471
|
async for meta_info, out in self._stream_generate(
|
|
453
|
-
prompt, **sanitized_generate_config
|
|
472
|
+
prompt, image_data, **sanitized_generate_config
|
|
454
473
|
):
|
|
455
474
|
chunk = self._convert_state_to_completion_chunk(
|
|
456
475
|
request_id, self.model_uid, output_text=out
|
|
@@ -513,10 +532,6 @@ class SGLANGChatModel(SGLANGModel, ChatModelMixin):
|
|
|
513
532
|
if llm_spec.model_format == "pytorch":
|
|
514
533
|
if quantization != "none" and not (quantization is None):
|
|
515
534
|
return False
|
|
516
|
-
if llm_spec.model_format in ["gptq", "awq"]:
|
|
517
|
-
# Currently, only 4-bit weight quantization is supported for GPTQ, but got 8 bits.
|
|
518
|
-
if "4" not in quantization:
|
|
519
|
-
return False
|
|
520
535
|
if isinstance(llm_family, CustomLLMFamilyV1):
|
|
521
536
|
if llm_family.model_family not in SGLANG_SUPPORTED_CHAT_MODELS:
|
|
522
537
|
return False
|
|
@@ -557,3 +572,88 @@ class SGLANGChatModel(SGLANGModel, ChatModelMixin):
|
|
|
557
572
|
c = await self.async_generate(full_prompt, generate_config) # type: ignore
|
|
558
573
|
assert not isinstance(c, AsyncGenerator)
|
|
559
574
|
return self._to_chat_completion(c, self.reasoning_parser)
|
|
575
|
+
|
|
576
|
+
|
|
577
|
+
class SGLANGVisionModel(SGLANGModel, ChatModelMixin):
    """SGLang-backed chat model for vision-language (image + text) models."""

    @classmethod
    def match(
        cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
    ) -> bool:
        """Return whether this backend can serve the given family/spec.

        Requires a CUDA device on Linux, a pytorch/gptq/awq/fp8 format (pytorch
        only unquantized), a model listed in SGLANG_SUPPORTED_VISION_MODEL_LIST,
        the "vision" ability, and SGLang being installed.
        """
        if not cls._has_cuda_device():
            return False
        if not cls._is_linux():
            return False
        if llm_spec.model_format not in ["pytorch", "gptq", "awq", "fp8"]:
            return False
        if llm_spec.model_format == "pytorch":
            if quantization != "none" and not (quantization is None):
                return False
        if isinstance(llm_family, CustomLLMFamilyV1):
            if llm_family.model_family not in SGLANG_SUPPORTED_VISION_MODEL_LIST:
                return False
        else:
            if llm_family.model_name not in SGLANG_SUPPORTED_VISION_MODEL_LIST:
                return False
        if "vision" not in llm_family.model_ability:
            return False
        return SGLANG_INSTALLED

    def _sanitize_chat_config(
        self,
        generate_config: Optional[Dict] = None,
    ) -> Dict:
        """Return a generate config with the family's default stop words filled in.

        The given dict is copied so the caller's config is not mutated; an
        explicit non-empty ``stop`` in the config is kept as-is.
        """
        generate_config = dict(generate_config) if generate_config else {}
        if not generate_config.get("stop") and self.model_family.stop:
            generate_config["stop"] = self.model_family.stop.copy()
        return generate_config

    async def async_chat(
        self,
        messages: List[ChatCompletionMessage],  # type: ignore
        generate_config: Optional[Dict] = None,
        request_id: Optional[str] = None,
    ) -> Union[ChatCompletion, AsyncGenerator[ChatCompletionChunk, None]]:
        """Render the chat prompt, encode images as base64, and generate.

        Images are extracted with ``qwen_vl_utils.process_vision_info``; PIL
        images are re-encoded as base64 JPEG, base64 strings pass through.

        Raises:
            ValueError: if the messages contain video input, or an image of an
                unsupported type.
        """
        import base64
        from io import BytesIO

        from PIL import Image
        from qwen_vl_utils import process_vision_info

        messages = self._transform_messages(messages)

        chat_template: str = (
            self.model_family.chat_template if self.model_family.chat_template else ""
        )

        prompt = self.get_full_context(messages, chat_template)
        images, video_inputs = process_vision_info(messages)
        if video_inputs:
            raise ValueError("Not support video input now.")

        base64_images: Optional[List[str]] = None
        if images:
            base64_images = []
            for image in images:
                if isinstance(image, Image.Image):
                    # JPEG cannot store alpha or palette modes (RGBA, P, LA, ...);
                    # PIL raises on save unless the image is converted first.
                    if image.mode not in ("RGB", "L"):
                        image = image.convert("RGB")
                    buffered = BytesIO()
                    image.save(buffered, format="JPEG", quality=100)
                    base64_images.append(base64.b64encode(buffered.getvalue()).decode())
                elif isinstance(image, str):
                    base64_images.append(image)
                else:
                    raise ValueError(
                        f"Unsupported image type: {type(image)}, only support PIL.Image and base64 string"
                    )

        generate_config = self._sanitize_chat_config(generate_config)
        stream = generate_config.get("stream", None)
        if stream:
            agen = await self.async_generate(prompt, base64_images, generate_config)  # type: ignore
            assert isinstance(agen, AsyncGenerator)
            return self._async_to_chat_completion_chunks(agen, self.reasoning_parser)
        else:
            c = await self.async_generate(prompt, base64_images, generate_config)  # type: ignore
            assert not isinstance(c, AsyncGenerator)
            return self._to_chat_completion(c, self.reasoning_parser)
|
|
@@ -79,6 +79,9 @@ NON_DEFAULT_MODEL_LIST: List[str] = [
|
|
|
79
79
|
"glm-edge-v",
|
|
80
80
|
"QvQ-72B-Preview",
|
|
81
81
|
"cogagent",
|
|
82
|
+
"gemma-3-1b-it",
|
|
83
|
+
"gemma-3-it",
|
|
84
|
+
"deepseek-vl2",
|
|
82
85
|
]
|
|
83
86
|
|
|
84
87
|
|
|
@@ -691,10 +694,9 @@ class PytorchChatModel(PytorchModel, ChatModelMixin):
|
|
|
691
694
|
tools
|
|
692
695
|
and model_family in QWEN_TOOL_CALL_FAMILY
|
|
693
696
|
or model_family in LLAMA3_TOOL_CALL_FAMILY
|
|
697
|
+
or model_family in DEEPSEEK_TOOL_CALL_FAMILY
|
|
694
698
|
):
|
|
695
699
|
full_context_kwargs["tools"] = tools
|
|
696
|
-
elif tools and model_family in DEEPSEEK_TOOL_CALL_FAMILY:
|
|
697
|
-
self._tools_to_messages_for_deepseek(messages, tools)
|
|
698
700
|
assert self.model_family.chat_template is not None
|
|
699
701
|
full_prompt = self.get_full_context(
|
|
700
702
|
messages,
|
|
@@ -46,7 +46,7 @@ class DeepSeekVLChatModel(PytorchChatModel):
|
|
|
46
46
|
cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
|
|
47
47
|
) -> bool:
|
|
48
48
|
llm_family = model_family.model_family or model_family.model_name
|
|
49
|
-
if "deepseek-vl"
|
|
49
|
+
if "deepseek-vl" == llm_family.lower():
|
|
50
50
|
return True
|
|
51
51
|
return False
|
|
52
52
|
|
|
@@ -0,0 +1,287 @@
|
|
|
1
|
+
# Copyright 2022-2023 XProbe Inc.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
import base64
|
|
15
|
+
import logging
|
|
16
|
+
import os.path
|
|
17
|
+
import tempfile
|
|
18
|
+
import uuid
|
|
19
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
20
|
+
from io import BytesIO
|
|
21
|
+
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
|
|
22
|
+
|
|
23
|
+
import requests
|
|
24
|
+
import torch
|
|
25
|
+
|
|
26
|
+
from ....model.utils import select_device
|
|
27
|
+
from ....types import ChatCompletion, ChatCompletionChunk, CompletionChunk
|
|
28
|
+
from ..llm_family import LLMFamilyV1, LLMSpecV1
|
|
29
|
+
from ..utils import generate_chat_completion, generate_completion_chunk
|
|
30
|
+
from .core import PytorchChatModel, PytorchGenerateConfig
|
|
31
|
+
from .utils import cache_clean
|
|
32
|
+
|
|
33
|
+
# Module-level logger used by DeepSeekVL2ChatModel for download/debug/error messages.
logger = logging.getLogger(__name__)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class DeepSeekVL2ChatModel(PytorchChatModel):
    """Transformers-backed chat model for the DeepSeek-VL2 vision-language family.

    OpenAI-style chat messages are converted to the DeepSeek-VL2 conversation
    format, images are loaded and embedded by the model, and text is generated
    by the model's ``language`` submodule from those input embeddings.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._tokenizer = None
        self._model = None
        self._vl_chat_processor = None
        self._type = None

    @classmethod
    def match(
        cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
    ) -> bool:
        """Return True when the family (or model name) is exactly ``deepseek-vl2``."""
        llm_family = model_family.model_family or model_family.model_name
        return llm_family.lower() == "deepseek-vl2"

    def load(self):
        """Load the DeepSeek-VL2 processor, tokenizer and model onto the selected device."""
        from transformers import AutoModelForCausalLM

        from ....thirdparty.deepseek_vl2.models import (
            DeepseekVLV2ForCausalLM,
            DeepseekVLV2Processor,
        )

        self._device = self._pytorch_model_config.get("device", "auto")
        self._device = select_device(self._device)
        # bfloat16 is used everywhere except MPS, which gets float16.
        self._type = torch.float16 if self._device == "mps" else torch.bfloat16

        # specify the path to the model
        self._vl_chat_processor: DeepseekVLV2Processor = DeepseekVLV2Processor.from_pretrained(  # type: ignore
            self.model_path
        )
        self._tokenizer = self._vl_chat_processor.tokenizer

        vl_gpt: DeepseekVLV2ForCausalLM = AutoModelForCausalLM.from_pretrained(  # type: ignore
            self.model_path, trust_remote_code=True, device_map=self._device
        )
        # Honor the dtype chosen above. The weights are already placed by
        # `device_map`, so do not force `.cuda()` (that breaks MPS/CPU hosts).
        self._model = vl_gpt.to(self._type).eval()

    @staticmethod
    def _message_content_to_deepseek(content) -> Tuple[str, List[Any]]:
        """Split OpenAI-style message content into text and image sources.

        Args:
            content: a plain string, or a list of content parts shaped like
                ``{"type": "text", "text": ...}`` or
                ``{"type": "image_url", "image_url": {"url": ...}}``.

        Returns:
            ``(text, images)``. For string content ``images`` is empty. For
            list content the text parts are concatenated (prefixed with
            ``"<image_placeholder>"`` when there are images), and ``images``
            holds, in input order, a local file path for each image that exists
            on disk (including decoded data URLs) or a ``BytesIO`` with the
            downloaded bytes of a remote URL.

        Raises:
            Exception: if a non-data image URL is longer than 2048 characters.
            requests.RequestException: if downloading a remote image fails.
        """

        def _ensure_url(_url):
            if _url.startswith("data:"):
                logger.info("Parse url by base64 decoder.")
                # https://platform.openai.com/docs/guides/vision/uploading-base-64-encoded-images
                # e.g. f"data:image/jpeg;base64,{base64_image}"
                _type, data = _url.split(";")
                _, ext = _type.split("/")
                data = data[len("base64,") :]
                data = base64.b64decode(data.encode("utf-8"))

                # NOTE(review): delete=False and nothing here removes the file,
                # so every data URL leaves a temporary file behind.
                with tempfile.NamedTemporaryFile(suffix=f".{ext}", delete=False) as f:
                    f.write(data)
                    logger.info("Dump base64 data to %s", f.name)
                    return f.name
            if len(_url) > 2048:
                raise Exception(f"Image url is too long, {len(_url)} > 2048.")
            return _url

        def _download(_images):
            local_images: List[Any] = []

            # To make requests.get works
            headers = {
                "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.36"
            }

            def _fill_placeholder(_url, _index):
                # Use the bound argument, not the loop variable: the task may
                # run after the loop has already moved on to another URL.
                response = requests.get(_url, headers=headers)
                response.raise_for_status()
                local_images[_index] = BytesIO(response.content)

            futures = []
            with ThreadPoolExecutor() as executor:
                for url in _images:
                    try:
                        if os.path.exists(url):
                            local_images.append(url)
                            continue
                    except Exception as e:
                        logger.debug("Image is remote: %s, e: %s", url, e)
                    # Reserve the slot now so the output order matches the input.
                    local_images.append(None)
                    futures.append(
                        executor.submit(_fill_placeholder, url, len(local_images) - 1)
                    )
            # Surface download errors instead of leaving a None placeholder.
            for future in futures:
                future.result()
            return local_images

        if isinstance(content, str):
            return content, []

        # TODO(codingl2k1): Optimize _ensure_url
        images = []
        new_content = []
        for c in content:
            c_type = c.get("type")
            if c_type == "image_url":
                images.append(_ensure_url(c["image_url"]["url"]))
            elif c_type == "text":
                new_content.append(c["text"])
        if images:
            new_content.insert(0, "<image_placeholder>")
            images = _download(images)
        return "".join(new_content), images

    @cache_clean
    def chat(
        self,
        messages: List[Dict],
        generate_config: Optional[PytorchGenerateConfig] = None,
    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
        """Generate a reply to ``messages``, streaming when ``stream`` is set.

        Only ``user`` and ``assistant`` roles are used; other roles are logged
        and skipped. ``max_tokens`` defaults to 512 and decoding is greedy.
        """
        if not generate_config:
            generate_config = {}

        stream = generate_config.get("stream", False)
        stream_options = generate_config.pop("stream_options", None)
        include_usage = (
            stream_options["include_usage"]
            if isinstance(stream_options, dict)
            else False
        )

        prompt = ""
        deepseek_messages: List[Dict[str, Any]] = []
        for i, message in enumerate(messages):
            role = message["role"]
            content = message["content"]
            if role == "user":
                content, images = self._message_content_to_deepseek(content)
                msg: Dict[str, Any] = {
                    "role": "<|User|>",
                    "content": "<image>\n<|ref|>" + content + "<|/ref|>",
                }
                if images:
                    msg["images"] = images
                deepseek_messages.append(msg)
                if i == len(messages) - 1:
                    prompt = msg["content"]
            elif role == "assistant":
                deepseek_messages.append({"role": "<|Assistant|>", "content": content})
            else:
                logger.error(
                    f"Unexpected message in messages: role: {role}, message: {message}"
                )

        # Open one empty assistant turn for the model to fill, only after a
        # trailing user turn. Adding one after *every* user turn would insert a
        # spurious empty reply before each real assistant message in the history.
        if deepseek_messages and deepseek_messages[-1]["role"] == "<|User|>":
            deepseek_messages.append({"role": "<|Assistant|>", "content": ""})

        from ....thirdparty.deepseek_vl2.utils.io import load_pil_images

        # load images and prepare for inputs
        pil_images = load_pil_images(deepseek_messages)
        prepare_inputs = self._vl_chat_processor(
            conversations=deepseek_messages,
            images=pil_images,
            force_batchify=True,
            system_prompt="",
        ).to(self._model.device, self._model.dtype)

        # run image encoder to get the image embeddings
        inputs_embeds = self._model.prepare_inputs_embeds(**prepare_inputs)

        # `max_tokens` may be present but None; fall back to the default then.
        max_new_tokens = generate_config.get("max_tokens") or 512
        conversation = self._vl_chat_processor.new_chat_template()
        stop_str = conversation.sep2

        # No streamer is passed, so `generate` presumably returns token-id
        # tensors; both helpers below decode tensors (and accept plain strings).
        outputs = self._model.language.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=prepare_inputs.attention_mask,
            pad_token_id=self._tokenizer.eos_token_id,
            bos_token_id=self._tokenizer.bos_token_id,
            eos_token_id=self._tokenizer.eos_token_id,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            use_cache=True,
        )

        if stream:
            it = self._generate_stream(outputs, stop_str, include_usage, prompt)
            return self._to_chat_completion_chunks(it)
        else:
            return self._generate(outputs, stop_str)

    def _decode_output(self, chunk) -> Tuple[str, int]:
        """Return ``(text, token_count)`` for one item produced by generation.

        Tensors of token ids are decoded with special tokens skipped; strings
        are passed through and counted as a single token.
        """
        if isinstance(chunk, torch.Tensor):
            token_ids = chunk.cpu().tolist()
            if isinstance(token_ids, int):  # 0-d tensor
                token_ids = [token_ids]
            text = self._tokenizer.decode(token_ids, skip_special_tokens=True)
            return text, len(token_ids)
        return chunk, 1

    @staticmethod
    def _strip_stop(text: str, stop_str: Optional[str]) -> str:
        """Remove a trailing ``stop_str`` from ``text``; an empty/None stop is ignored."""
        # Guard against "" : text[:-0] would otherwise wipe the whole string.
        if stop_str and text.endswith(stop_str):
            return text[: -len(stop_str)]
        return text

    def _generate(self, streamer, stop_str) -> ChatCompletion:
        """Concatenate every generated chunk into a single ChatCompletion."""
        pieces = []
        for chunk in streamer:
            text, _ = self._decode_output(chunk)
            pieces.append(self._strip_stop(text, stop_str))
        return generate_chat_completion(self.model_uid, "".join(pieces))

    def _generate_stream(
        self, streamer, stop_str, include_usage, prompt
    ) -> Iterator[CompletionChunk]:
        """Yield one completion chunk per generated item, then a stop chunk.

        When ``include_usage`` is true, a final choice-less chunk carrying the
        token usage is also yielded.
        """
        completion_id = str(uuid.uuid1())
        prompt_tokens = len(self._tokenizer(prompt).input_ids)
        completion_tokens = 0
        total_tokens = prompt_tokens
        for chunk in streamer:
            new_text, n_tokens = self._decode_output(chunk)
            new_text = self._strip_stop(new_text, stop_str)
            completion_tokens += n_tokens
            total_tokens = prompt_tokens + completion_tokens
            yield generate_completion_chunk(
                chunk_text=new_text,
                finish_reason=None,
                chunk_id=completion_id,
                model_uid=self.model_uid,
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=total_tokens,
                has_choice=True,
                has_content=True,
            )
        yield generate_completion_chunk(
            chunk_text=None,
            finish_reason="stop",
            chunk_id=completion_id,
            model_uid=self.model_uid,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
            has_choice=True,
            has_content=False,
        )

        if include_usage:
            yield generate_completion_chunk(
                chunk_text=None,
                finish_reason=None,
                chunk_id=completion_id,
                model_uid=self.model_uid,
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=total_tokens,
                has_choice=False,
                has_content=False,
            )