xinference 1.6.0.post1__py3-none-any.whl → 1.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of xinference has been flagged as potentially problematic.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +79 -2
- xinference/client/restful/restful_client.py +65 -3
- xinference/conftest.py +0 -7
- xinference/core/media_interface.py +132 -8
- xinference/core/model.py +44 -6
- xinference/core/scheduler.py +1 -10
- xinference/core/supervisor.py +8 -17
- xinference/core/worker.py +5 -27
- xinference/deploy/cmdline.py +6 -2
- xinference/model/audio/chattts.py +24 -39
- xinference/model/audio/cosyvoice.py +18 -30
- xinference/model/audio/funasr.py +42 -0
- xinference/model/audio/model_spec.json +71 -1
- xinference/model/audio/model_spec_modelscope.json +76 -2
- xinference/model/audio/utils.py +75 -0
- xinference/model/core.py +1 -0
- xinference/model/embedding/__init__.py +74 -18
- xinference/model/embedding/core.py +98 -589
- xinference/model/embedding/embed_family.py +133 -0
- xinference/{thirdparty/omnilmm/train → model/embedding/flag}/__init__.py +1 -1
- xinference/model/embedding/flag/core.py +282 -0
- xinference/model/embedding/model_spec.json +24 -0
- xinference/model/embedding/model_spec_modelscope.json +24 -0
- xinference/model/embedding/sentence_transformers/__init__.py +13 -0
- xinference/model/embedding/sentence_transformers/core.py +399 -0
- xinference/model/embedding/vllm/core.py +95 -0
- xinference/model/image/model_spec.json +30 -3
- xinference/model/image/model_spec_modelscope.json +41 -2
- xinference/model/image/stable_diffusion/core.py +144 -53
- xinference/model/llm/__init__.py +6 -54
- xinference/model/llm/core.py +19 -5
- xinference/model/llm/llama_cpp/core.py +59 -3
- xinference/model/llm/llama_cpp/memory.py +457 -0
- xinference/model/llm/llm_family.json +247 -402
- xinference/model/llm/llm_family.py +88 -16
- xinference/model/llm/llm_family_modelscope.json +260 -421
- xinference/model/llm/llm_family_openmind_hub.json +0 -34
- xinference/model/llm/sglang/core.py +8 -0
- xinference/model/llm/transformers/__init__.py +27 -6
- xinference/model/llm/transformers/chatglm.py +4 -2
- xinference/model/llm/transformers/core.py +49 -28
- xinference/model/llm/transformers/deepseek_v2.py +6 -49
- xinference/model/llm/transformers/gemma3.py +119 -164
- xinference/model/llm/transformers/multimodal/__init__.py +13 -0
- xinference/model/llm/transformers/{cogagent.py → multimodal/cogagent.py} +58 -95
- xinference/model/llm/transformers/multimodal/core.py +205 -0
- xinference/model/llm/transformers/{deepseek_vl2.py → multimodal/deepseek_vl2.py} +59 -120
- xinference/model/llm/transformers/multimodal/gemma3.py +117 -0
- xinference/model/llm/transformers/{glm4v.py → multimodal/glm4v.py} +57 -93
- xinference/model/llm/transformers/multimodal/intern_vl.py +412 -0
- xinference/model/llm/transformers/{minicpmv26.py → multimodal/minicpmv26.py} +55 -102
- xinference/model/llm/transformers/{ovis2.py → multimodal/ovis2.py} +114 -175
- xinference/model/llm/transformers/{qwen-omni.py → multimodal/qwen-omni.py} +82 -167
- xinference/model/llm/transformers/multimodal/qwen2_audio.py +131 -0
- xinference/model/llm/transformers/{qwen2_vl.py → multimodal/qwen2_vl.py} +224 -256
- xinference/model/llm/transformers/opt.py +4 -2
- xinference/model/llm/transformers/utils.py +6 -37
- xinference/model/llm/utils.py +11 -0
- xinference/model/llm/vllm/core.py +7 -0
- xinference/model/rerank/core.py +91 -3
- xinference/model/rerank/model_spec.json +24 -0
- xinference/model/rerank/model_spec_modelscope.json +24 -0
- xinference/model/rerank/utils.py +20 -2
- xinference/model/utils.py +38 -1
- xinference/model/video/diffusers.py +65 -3
- xinference/model/video/model_spec.json +31 -4
- xinference/model/video/model_spec_modelscope.json +32 -4
- xinference/web/ui/build/asset-manifest.json +6 -6
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/css/main.013f296b.css +2 -0
- xinference/web/ui/build/static/css/main.013f296b.css.map +1 -0
- xinference/web/ui/build/static/js/main.8a9e3ba0.js +3 -0
- xinference/web/ui/build/static/js/main.8a9e3ba0.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/34cfbfb7836e136ba3261cfd411cc554bf99ba24b35dcceebeaa4f008cb3c9dc.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/55b9fb40b57fa926e8f05f31c2f96467e76e5ad62f033dca97c03f9e8c4eb4fe.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/567e49df411efb24425d289bb484758cb57067ca54f8b5c67fe4505f698deb96.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/6595880facebca7ceace6f17cf21c3a5a9219a2f52fb0ba9f3cf1131eddbcf6b.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/aa998bc2d9c11853add6b8a2e08f50327f56d8824ccaaec92d6dde1b305f0d85.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/c748246b1d7bcebc16153be69f37e955bb2145526c47dd425aeeff70d3004dbc.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/e31234e95d60a5a7883fbcd70de2475dc1c88c90705df1a530abb68f86f80a51.json +1 -0
- xinference/web/ui/src/locales/en.json +21 -8
- xinference/web/ui/src/locales/ja.json +224 -0
- xinference/web/ui/src/locales/ko.json +224 -0
- xinference/web/ui/src/locales/zh.json +21 -8
- {xinference-1.6.0.post1.dist-info → xinference-1.7.0.dist-info}/METADATA +14 -11
- {xinference-1.6.0.post1.dist-info → xinference-1.7.0.dist-info}/RECORD +93 -100
- {xinference-1.6.0.post1.dist-info → xinference-1.7.0.dist-info}/WHEEL +1 -1
- xinference/model/llm/transformers/cogvlm2.py +0 -442
- xinference/model/llm/transformers/cogvlm2_video.py +0 -333
- xinference/model/llm/transformers/deepseek_vl.py +0 -280
- xinference/model/llm/transformers/glm_edge_v.py +0 -213
- xinference/model/llm/transformers/intern_vl.py +0 -526
- xinference/model/llm/transformers/internlm2.py +0 -94
- xinference/model/llm/transformers/minicpmv25.py +0 -193
- xinference/model/llm/transformers/omnilmm.py +0 -132
- xinference/model/llm/transformers/qwen2_audio.py +0 -179
- xinference/model/llm/transformers/qwen_vl.py +0 -360
- xinference/thirdparty/omnilmm/LICENSE +0 -201
- xinference/thirdparty/omnilmm/chat.py +0 -218
- xinference/thirdparty/omnilmm/constants.py +0 -4
- xinference/thirdparty/omnilmm/conversation.py +0 -332
- xinference/thirdparty/omnilmm/model/__init__.py +0 -1
- xinference/thirdparty/omnilmm/model/omnilmm.py +0 -595
- xinference/thirdparty/omnilmm/model/resampler.py +0 -166
- xinference/thirdparty/omnilmm/model/utils.py +0 -578
- xinference/thirdparty/omnilmm/train/train_utils.py +0 -150
- xinference/thirdparty/omnilmm/utils.py +0 -134
- xinference/web/ui/build/static/css/main.337afe76.css +0 -2
- xinference/web/ui/build/static/css/main.337afe76.css.map +0 -1
- xinference/web/ui/build/static/js/main.ae579a97.js +0 -3
- xinference/web/ui/build/static/js/main.ae579a97.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/12e02ee790dbf57ead09a241a93bb5f893393aa36628ca741d44390e836a103f.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/2fdc61dcb6a9d1fbcb44be592d0e87d8c3f21297a7327559ef5345665f8343f7.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/3d596a3e8dd6430d7ce81d164e32c31f8d47cfa5f725c328a298754d78563e14.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/5c08e2cd07809ed3e41486b16652253404cbb63a3ff8d0366ee50f57e2413cea.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/8472e58a31720892d534f3febda31f746b25ec4aa60787eef34217b074e67965.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/dc249829767b8abcbc3677e0b07b6d3ecbfdfe6d08cfe23a665eb33373a9aa9d.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/f91af913d7f91c410719ab13136aaed3aaf0f8dda06652f25c42cb5231587398.json +0 -1
- /xinference/{thirdparty/omnilmm → model/embedding/vllm}/__init__.py +0 -0
- /xinference/web/ui/build/static/js/{main.ae579a97.js.LICENSE.txt → main.8a9e3ba0.js.LICENSE.txt} +0 -0
- {xinference-1.6.0.post1.dist-info → xinference-1.7.0.dist-info}/entry_points.txt +0 -0
- {xinference-1.6.0.post1.dist-info → xinference-1.7.0.dist-info}/licenses/LICENSE +0 -0
- {xinference-1.6.0.post1.dist-info → xinference-1.7.0.dist-info}/top_level.txt +0 -0
xinference/model/llm/transformers/multimodal/gemma3.py (new file)

@@ -0,0 +1,117 @@
+# Copyright 2022-2025 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from threading import Thread
+from typing import Any, Dict, Iterator, List, Optional, Tuple
+
+from .....model.utils import select_device
+from .....types import PytorchModelConfig
+from ...llm_family import LLMFamilyV1, LLMSpecV1, register_transformer
+from ..core import register_non_default_model
+from .core import PytorchMultiModalModel
+
+logger = logging.getLogger(__name__)
+
+
+@register_transformer
+@register_non_default_model("gemma-3-it")
+class Gemma3ChatModel(PytorchMultiModalModel):
+    @classmethod
+    def match_json(
+        cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
+    ) -> bool:
+        if model_spec.model_format not in ["pytorch", "gptq", "awq"]:
+            return False
+        llm_family = model_family.model_family or model_family.model_name
+        if "gemma-3-it".lower() in llm_family.lower():
+            return True
+        return False
+
+    def _sanitize_model_config(
+        self, pytorch_model_config: Optional[PytorchModelConfig]
+    ) -> PytorchModelConfig:
+        pytorch_model_config = super()._sanitize_model_config(pytorch_model_config)
+        assert pytorch_model_config is not None
+        pytorch_model_config.setdefault("min_pixels", 256 * 28 * 28)
+        pytorch_model_config.setdefault("max_pixels", 1280 * 28 * 28)
+        return pytorch_model_config
+
+    def decide_device(self):
+        device = self._pytorch_model_config.get("device", "auto")
+        device = select_device(device)
+        self._device = device
+
+    def load_processor(self):
+        from transformers import AutoProcessor
+
+        min_pixels = self._pytorch_model_config.get("min_pixels")
+        max_pixels = self._pytorch_model_config.get("max_pixels")
+        self._processor = AutoProcessor.from_pretrained(
+            self.model_path,
+            min_pixels=min_pixels,
+            max_pixels=max_pixels,
+        )
+        self._tokenizer = self._processor.tokenizer
+
+    def load_multimodal_model(self):
+        from transformers import Gemma3ForConditionalGeneration
+
+        kwargs = self.apply_bnb_quantization()
+        self._model = Gemma3ForConditionalGeneration.from_pretrained(
+            self.model_path, device_map="auto", torch_dtype="bfloat16", **kwargs
+        )
+
+    def build_inputs_from_messages(
+        self,
+        messages: List[Dict],
+        generate_config: Dict,
+    ):
+        messages = self._transform_messages(messages)
+        inputs = self._processor.apply_chat_template(
+            messages,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_dict=True,
+            return_tensors="pt",
+        ).to(self._device)
+        return inputs
+
+    def build_generate_kwargs(
+        self,
+        generate_config: Dict,
+    ) -> Dict[str, Any]:
+        return dict(
+            max_new_tokens=generate_config.get("max_tokens", 512),
+            temperature=generate_config.get("temperature", 1),
+        )
+
+    def build_streaming_iter(
+        self,
+        messages: List[Dict],
+        generate_config: Dict,
+    ) -> Tuple[Iterator, int]:
+        from transformers import TextIteratorStreamer
+
+        inputs = self.build_inputs_from_messages(messages, generate_config)
+        configs = self.build_generate_kwargs(generate_config)
+
+        tokenizer = self._tokenizer
+        streamer = TextIteratorStreamer(
+            tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True
+        )
+
+        gen_kwargs = {"streamer": streamer, **inputs, **configs}
+        t = Thread(target=self._model.generate, kwargs=gen_kwargs)
+        t.start()
+        return streamer, len(inputs.input_ids[0])
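The new Gemma3ChatModel only fills in the model-specific hooks (processor and model loading, input building, generate kwargs); the shared PytorchMultiModalModel base class from multimodal/core.py drives the actual chat loop. For orientation, here is a minimal client-side sketch of launching and querying a gemma-3-it deployment through the xinference RESTful client; the server address, image URL, and parameter values are illustrative and not taken from this diff.

# Minimal sketch: chat with the newly supported gemma-3-it model via the
# RESTful client. Assumes a local xinference server is already running.
from xinference.client import Client

client = Client("http://localhost:9997")  # illustrative address

# The Transformers engine selects Gemma3ChatModel through its match_json() /
# @register_non_default_model("gemma-3-it") registration shown above.
model_uid = client.launch_model(
    model_name="gemma-3-it",
    model_engine="transformers",
    model_format="pytorch",
)

model = client.get_model(model_uid)
response = model.chat(
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this image."},
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
            ],
        }
    ],
    generate_config={"max_tokens": 512, "temperature": 1.0},
)
print(response["choices"][0]["message"]["content"])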
xinference/model/llm/transformers/{glm4v.py → multimodal/glm4v.py}

@@ -1,4 +1,4 @@
-# Copyright 2022-
+# Copyright 2022-2025 XProbe Inc.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,31 +13,28 @@
 # limitations under the License.
 import logging
 import typing
-import uuid
 from concurrent.futures import ThreadPoolExecutor
 from threading import Thread
-from typing import Dict, Iterator, List, Optional,
+from typing import Any, Dict, Iterator, List, Optional, Tuple
 
 import torch
 
-from
-from
-from
-from
-from
-from
-from
+from .....core.model import register_batching_multimodal_models
+from .....core.scheduler import InferenceRequest
+from .....model.utils import select_device
+from ...llm_family import LLMFamilyV1, LLMSpecV1, register_transformer
+from ...utils import _decode_image
+from ..core import register_non_default_model
+from ..utils import get_max_src_len
+from .core import PytorchMultiModalModel
 
 logger = logging.getLogger(__name__)
 
 
-
-
-
-
-        self._tokenizer = None
-        self._model = None
-
+@register_batching_multimodal_models("glm-4v")
+@register_transformer
+@register_non_default_model("glm-4v")
+class Glm4VModel(PytorchMultiModalModel):
     @classmethod
     def match_json(
         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
@@ -47,19 +44,23 @@ class Glm4VModel(PytorchChatModel):
             return True
         return False
 
-    def
-        from transformers import AutoModelForCausalLM, AutoTokenizer
-
+    def decide_device(self):
         device = self._pytorch_model_config.get("device", "auto")
         self._device = select_device(device)
 
+    def load_processor(self):
+        from transformers import AutoTokenizer
+
+        self._tokenizer = AutoTokenizer.from_pretrained(
+            self.model_path, trust_remote_code=True
+        )
+
+    def load_multimodal_model(self):
+        from transformers import AutoModelForCausalLM
+
         kwargs = {"device_map": self._device}
         kwargs = self.apply_bnb_quantization(kwargs)
 
-        if self._check_tensorizer_integrity():
-            self._model, self._tokenizer = self._load_tensorizer()
-            return
-
         model = AutoModelForCausalLM.from_pretrained(
             self.model_path,
             low_cpu_mem_usage=True,
@@ -69,12 +70,6 @@ class Glm4VModel(PytorchChatModel):
         )
         self._model = model.eval()
 
-        tokenizer = AutoTokenizer.from_pretrained(
-            self.model_path, trust_remote_code=True
-        )
-        self._tokenizer = tokenizer
-        self._save_tensorizer()
-
     @staticmethod
     def _get_processed_msgs(messages: List[Dict]) -> List[Dict]:
         res = []
@@ -111,20 +106,12 @@ class Glm4VModel(PytorchChatModel):
             res.append({"role": role, "content": text})
         return res
 
-
-    def chat(
+    def build_inputs_from_messages(
         self,
         messages: List[Dict],
-        generate_config:
-    )
-        from transformers import TextIteratorStreamer
-
-        if not generate_config:
-            generate_config = {}
-
-        stream = generate_config.get("stream", False)
+        generate_config: Dict,
+    ):
         msgs = self._get_processed_msgs(messages)
-
         inputs = self._tokenizer.apply_chat_template(
             msgs,
             add_generation_prompt=True,
@@ -133,68 +120,45 @@ class Glm4VModel(PytorchChatModel):
             return_dict=True,
         )  # chat mode
         inputs = inputs.to(self._model.device)
+        return inputs
 
-
-
+    def build_generate_kwargs(
+        self,
+        generate_config: Dict,
+    ) -> Dict[str, Any]:
+        return {
             "eos_token_id": [151329, 151336, 151338],
             "do_sample": True,
             "max_length": generate_config.get("max_tokens", 2048),
             "temperature": generate_config.get("temperature", 0.7),
         }
-        stop_str = "<|endoftext|>"
-
-        if stream:
-            streamer = TextIteratorStreamer(
-                tokenizer=self._tokenizer,
-                timeout=60,
-                skip_prompt=True,
-                skip_special_tokens=True,
-            )
-            generate_kwargs = {
-                **generate_kwargs,
-                "streamer": streamer,
-            }
-            t = Thread(target=self._model.generate, kwargs=generate_kwargs)
-            t.start()
 
-
-
-        else:
-            with torch.no_grad():
-                outputs = self._model.generate(**generate_kwargs)
-            outputs = outputs[:, inputs["input_ids"].shape[1] :]
-            response = self._tokenizer.decode(outputs[0])
-            if response.endswith(stop_str):
-                response = response[: -len(stop_str)]
-            return generate_chat_completion(self.model_uid, response)
+    def get_stop_strs(self) -> List[str]:
+        return ["<|endoftext|>"]
 
-    def
-
-
-
-
-
-            finish_reason=None,
-            chunk_id=completion_id,
-            model_uid=self.model_uid,
-            prompt_tokens=-1,
-            completion_tokens=-1,
-            total_tokens=-1,
-            has_choice=True,
-            has_content=True,
-        )
+    def build_streaming_iter(
+        self,
+        messages: List[Dict],
+        generate_config: Dict,
+    ) -> Tuple[Iterator, int]:
+        from transformers import TextIteratorStreamer
 
-
-
-
-
-
-
-
-            total_tokens=-1,
-            has_choice=True,
-            has_content=False,
+        generate_kwargs = self.build_generate_kwargs(generate_config)
+        inputs = self.build_inputs_from_messages(messages, generate_config)
+        streamer = TextIteratorStreamer(
+            tokenizer=self._tokenizer,
+            timeout=60,
+            skip_prompt=True,
+            skip_special_tokens=True,
         )
+        kwargs = {
+            **inputs,
+            **generate_kwargs,
+            "streamer": streamer,
+        }
+        t = Thread(target=self._model.generate, kwargs=kwargs)
+        t.start()
+        return streamer, len(inputs.input_ids[0])
 
     def _get_full_prompt(self, messages, tools, generate_config: dict):
         msgs = self._get_processed_msgs(messages)