xinference 0.15.3__py3-none-any.whl → 0.16.0__py3-none-any.whl
- xinference/__init__.py +0 -4
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +29 -2
- xinference/client/restful/restful_client.py +10 -0
- xinference/constants.py +7 -3
- xinference/core/image_interface.py +76 -23
- xinference/core/model.py +158 -46
- xinference/core/progress_tracker.py +187 -0
- xinference/core/scheduler.py +10 -7
- xinference/core/supervisor.py +11 -0
- xinference/core/utils.py +9 -0
- xinference/core/worker.py +1 -0
- xinference/deploy/supervisor.py +4 -0
- xinference/model/__init__.py +4 -0
- xinference/model/audio/chattts.py +2 -1
- xinference/model/audio/core.py +0 -2
- xinference/model/audio/model_spec.json +8 -0
- xinference/model/audio/model_spec_modelscope.json +9 -0
- xinference/model/image/core.py +6 -7
- xinference/model/image/scheduler/__init__.py +13 -0
- xinference/model/image/scheduler/flux.py +533 -0
- xinference/model/image/sdapi.py +35 -4
- xinference/model/image/stable_diffusion/core.py +215 -110
- xinference/model/image/utils.py +39 -3
- xinference/model/llm/__init__.py +2 -0
- xinference/model/llm/llm_family.json +185 -17
- xinference/model/llm/llm_family_modelscope.json +124 -12
- xinference/model/llm/transformers/chatglm.py +104 -0
- xinference/model/llm/transformers/cogvlm2.py +2 -1
- xinference/model/llm/transformers/cogvlm2_video.py +2 -0
- xinference/model/llm/transformers/core.py +43 -113
- xinference/model/llm/transformers/deepseek_v2.py +0 -226
- xinference/model/llm/transformers/deepseek_vl.py +2 -0
- xinference/model/llm/transformers/glm4v.py +2 -1
- xinference/model/llm/transformers/intern_vl.py +2 -0
- xinference/model/llm/transformers/internlm2.py +3 -95
- xinference/model/llm/transformers/minicpmv25.py +2 -0
- xinference/model/llm/transformers/minicpmv26.py +2 -0
- xinference/model/llm/transformers/omnilmm.py +2 -0
- xinference/model/llm/transformers/opt.py +68 -0
- xinference/model/llm/transformers/qwen2_audio.py +11 -4
- xinference/model/llm/transformers/qwen2_vl.py +2 -28
- xinference/model/llm/transformers/qwen_vl.py +2 -1
- xinference/model/llm/transformers/utils.py +36 -283
- xinference/model/llm/transformers/yi_vl.py +2 -0
- xinference/model/llm/utils.py +60 -16
- xinference/model/llm/vllm/core.py +68 -9
- xinference/model/llm/vllm/utils.py +0 -1
- xinference/model/utils.py +7 -4
- xinference/model/video/core.py +0 -2
- xinference/utils.py +2 -3
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.e51a356d.js → main.f7da0140.js} +3 -3
- xinference/web/ui/build/static/js/main.f7da0140.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/331312668fa8bd3d7401818f4a25fa98135d7f61371cd6bfff78b18cf4fbdd92.json +1 -0
- {xinference-0.15.3.dist-info → xinference-0.16.0.dist-info}/METADATA +38 -6
- {xinference-0.15.3.dist-info → xinference-0.16.0.dist-info}/RECORD +63 -59
- xinference/web/ui/build/static/js/main.e51a356d.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/4385c1095eefbff0a8ec3b2964ba6e5a66a05ab31be721483ca2f43e2a91f6ff.json +0 -1
- /xinference/web/ui/build/static/js/{main.e51a356d.js.LICENSE.txt → main.f7da0140.js.LICENSE.txt} +0 -0
- {xinference-0.15.3.dist-info → xinference-0.16.0.dist-info}/LICENSE +0 -0
- {xinference-0.15.3.dist-info → xinference-0.16.0.dist-info}/WHEEL +0 -0
- {xinference-0.15.3.dist-info → xinference-0.16.0.dist-info}/entry_points.txt +0 -0
- {xinference-0.15.3.dist-info → xinference-0.16.0.dist-info}/top_level.txt +0 -0
xinference/model/llm/transformers/opt.py (new file)

@@ -0,0 +1,68 @@
+# Copyright 2022-2024 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from builtins import classmethod
+from typing import List, Optional
+
+from ....core.scheduler import InferenceRequest
+from ....types import LoRA
+from ..llm_family import LLMFamilyV1, LLMSpecV1
+from .core import PytorchModel, PytorchModelConfig
+
+
+class OptPytorchModel(PytorchModel):
+    def __init__(
+        self,
+        model_uid: str,
+        model_family: "LLMFamilyV1",
+        model_spec: "LLMSpecV1",
+        quantization: str,
+        model_path: str,
+        pytorch_model_config: Optional[PytorchModelConfig] = None,
+        peft_model: Optional[List[LoRA]] = None,
+    ):
+        super().__init__(
+            model_uid,
+            model_family,
+            model_spec,
+            quantization,
+            model_path,
+            pytorch_model_config=pytorch_model_config,
+            peft_model=peft_model,
+        )
+
+    @classmethod
+    def match(
+        cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
+    ) -> bool:
+        if llm_spec.model_format != "pytorch":
+            return False
+        model_family = llm_family.model_family or llm_family.model_name
+        if model_family != "opt":
+            return False
+        return True
+
+    def build_prefill_position_ids(
+        self, batch_size: int, seq_length: int, reqs: List[InferenceRequest]
+    ):
+        """
+        Mainly for UT.
+        Transformers code in `main` branch supports `position_ids` parameter (https://github.com/huggingface/transformers/blob/main/src/transformers/models/opt/modeling_opt.py#L1076),
+        while in release branch, it doesn't (https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/opt/modeling_opt.py#L886).
+        """
+        return None
+
+    def build_decode_position_ids(
+        self, batch_size: int, seq_length: int, reqs: List[InferenceRequest]
+    ):
+        return None
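The docstring on `build_prefill_position_ids` explains why both hooks return `None`: released transformers builds of OPT do not accept a `position_ids` argument, while the `main` branch does. The sketch below is illustrative only — the `run_prefill` helper and its arguments are not part of xinference — and shows how a batching caller can consume such hooks safely, forwarding `position_ids` only when a subclass actually provides one:

```python
import torch


def run_prefill(pytorch_model, hf_model, input_ids, attention_mask, reqs):
    """Hypothetical prefill step: forward position_ids only when provided.

    OptPytorchModel returns None from build_prefill_position_ids, so the
    kwarg is omitted and transformers releases whose OPTForCausalLM does not
    accept position_ids keep working.
    """
    kwargs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "use_cache": True,
    }
    position_ids = pytorch_model.build_prefill_position_ids(
        input_ids.shape[0], input_ids.shape[1], reqs
    )
    if position_ids is not None:
        kwargs["position_ids"] = torch.as_tensor(position_ids, device=input_ids.device)
    return hf_model(**kwargs)
```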
xinference/model/llm/transformers/qwen2_audio.py

@@ -14,16 +14,22 @@
 import logging
 import uuid
 from io import BytesIO
-from typing import
+from typing import Iterator, List, Optional, Union
 from urllib.request import urlopen

 import numpy as np

 from ....model.utils import select_device
-from ....types import
+from ....types import (
+    ChatCompletion,
+    ChatCompletionChunk,
+    ChatCompletionMessage,
+    CompletionChunk,
+)
 from ..llm_family import LLMFamilyV1, LLMSpecV1
 from ..utils import generate_chat_completion, generate_completion_chunk
 from .core import PytorchChatModel, PytorchGenerateConfig
+from .utils import cache_clean

 logger = logging.getLogger(__name__)

@@ -68,7 +74,7 @@ class Qwen2AudioChatModel(PytorchChatModel):

     def _transform_messages(
         self,
-        messages: List[
+        messages: List[ChatCompletionMessage],
     ):
         import librosa

@@ -89,9 +95,10 @@ class Qwen2AudioChatModel(PytorchChatModel):

         return text, audios

+    @cache_clean
     def chat(
         self,
-        messages: List[
+        messages: List[ChatCompletionMessage],
         generate_config: Optional[PytorchGenerateConfig] = None,
     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
         text, audios = self._transform_messages(messages)
xinference/model/llm/transformers/qwen2_vl.py

@@ -27,6 +27,7 @@ from ....types import (
 from ..llm_family import LLMFamilyV1, LLMSpecV1
 from ..utils import generate_chat_completion, generate_completion_chunk
 from .core import PytorchChatModel, PytorchGenerateConfig
+from .utils import cache_clean

 logger = logging.getLogger(__name__)

@@ -75,34 +76,7 @@ class Qwen2VLChatModel(PytorchChatModel):
             self.model_path, device_map=device, trust_remote_code=True
         ).eval()

-    def _transform_messages(
-        self,
-        messages: List[ChatCompletionMessage],
-    ):
-        transformed_messages = []
-        for msg in messages:
-            new_content = []
-            role = msg["role"]
-            content = msg["content"]
-            if isinstance(content, str):
-                new_content.append({"type": "text", "text": content})
-            elif isinstance(content, List):
-                for item in content:  # type: ignore
-                    if "text" in item:
-                        new_content.append({"type": "text", "text": item["text"]})
-                    elif "image_url" in item:
-                        new_content.append(
-                            {"type": "image", "image": item["image_url"]["url"]}
-                        )
-                    elif "video_url" in item:
-                        new_content.append(
-                            {"type": "video", "video": item["video_url"]["url"]}
-                        )
-            new_message = {"role": role, "content": new_content}
-            transformed_messages.append(new_message)
-
-        return transformed_messages
-
+    @cache_clean
     def chat(
         self,
         messages: List[ChatCompletionMessage],  # type: ignore
xinference/model/llm/transformers/qwen_vl.py

@@ -28,7 +28,7 @@ from ....types import ChatCompletion, ChatCompletionChunk, CompletionChunk
 from ..llm_family import LLMFamilyV1, LLMSpecV1
 from ..utils import generate_chat_completion, generate_completion_chunk
 from .core import PytorchChatModel, PytorchGenerateConfig
-from .utils import pad_prefill_tokens
+from .utils import cache_clean, pad_prefill_tokens

 logger = logging.getLogger(__name__)

@@ -137,6 +137,7 @@ class QwenVLChatModel(PytorchChatModel):
         prompt = self._message_content_to_qwen(messages[-1]["content"])
         return prompt, qwen_history

+    @cache_clean
     def chat(
         self,
         messages: List[Dict],
xinference/model/llm/transformers/utils.py

@@ -12,12 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import
+import asyncio
+import functools
 import logging
 import os
 import time
-import
-from typing import TYPE_CHECKING, Dict, Iterable, Iterator, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple

 import torch
 from transformers.cache_utils import DynamicCache
@@ -45,20 +45,6 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)


-def is_sentence_complete(output: str):
-    """Check whether the output is a complete sentence."""
-    end_symbols = (".", "?", "!", "...", "。", "?", "!", "…", '"', "'", "”")
-    return output.endswith(end_symbols)
-
-
-def is_partial_stop(output: str, stop_str: str):
-    """Check whether the output contains a partial stop str."""
-    for i in range(0, min(len(output), len(stop_str))):
-        if stop_str.startswith(output[-i:]):
-            return True
-    return False
-
-
 def get_context_length(config) -> int:
     """Get the context length of a model from a huggingface model config."""
     if (
@@ -98,272 +84,6 @@ def prepare_logits_processor(
     return processor_list


-@torch.inference_mode()
-def generate_stream(
-    model_uid,
-    model,
-    tokenizer,
-    prompt,
-    device,
-    generate_config,
-    judge_sent_end=False,
-) -> Iterator[Tuple[CompletionChunk, CompletionUsage]]:
-    context_len = get_context_length(model.config)
-    stream_interval = generate_config.get("stream_interval", 2)
-    stream = generate_config.get("stream", False)
-    stream_options = generate_config.pop("stream_options", None)
-    include_usage = (
-        stream_options["include_usage"] if isinstance(stream_options, dict) else False
-    )
-
-    len_prompt = len(prompt)
-
-    temperature = float(generate_config.get("temperature", 1.0))
-    repetition_penalty = float(generate_config.get("repetition_penalty", 1.0))
-    top_p = float(generate_config.get("top_p", 1.0))
-    top_k = int(generate_config.get("top_k", -1))  # -1 means disable
-    max_new_tokens = int(generate_config.get("max_tokens", max_tokens_field.default))
-    echo = bool(generate_config.get("echo", False))
-    stop_str = generate_config.get("stop", None)
-    stop_token_ids = generate_config.get("stop_token_ids", None) or []
-    stop_token_ids.append(tokenizer.eos_token_id)
-    chunk_id = str(uuid.uuid4())
-
-    logits_processor = prepare_logits_processor(
-        temperature, repetition_penalty, top_p, top_k
-    )
-
-    if ".modeling_qwen." in str(type(model)).lower():
-        # TODO: hacky
-        input_ids = tokenizer(prompt, allowed_special="all").input_ids
-    else:
-        input_ids = tokenizer(prompt).input_ids
-    output_ids = list(input_ids)
-
-    if model.config.is_encoder_decoder:
-        max_src_len = context_len
-    else:
-        max_src_len = context_len - max_new_tokens - 8
-        if max_src_len < 0:
-            raise ValueError("Max tokens exceeds model's max length")
-
-    input_ids = input_ids[-max_src_len:]
-    input_echo_len = len(input_ids)
-
-    if model.config.is_encoder_decoder:
-        encoder_output = model.encoder(
-            input_ids=torch.as_tensor([input_ids], device=device)
-        )[0]
-        start_ids = torch.as_tensor(
-            [[model.generation_config.decoder_start_token_id]],
-            dtype=torch.int64,
-            device=device,
-        )
-
-    start = time.time()
-    past_key_values = out = None
-    sent_interrupt = False
-    token = None
-    last_output_length = 0
-    for i in range(max_new_tokens):
-        if i == 0:
-            if model.config.is_encoder_decoder:
-                out = model.decoder(
-                    input_ids=start_ids,
-                    encoder_hidden_states=encoder_output,
-                    use_cache=True,
-                )
-                logits = model.lm_head(out[0])
-            else:
-                out = model(torch.as_tensor([input_ids], device=device), use_cache=True)
-                logits = out.logits
-            past_key_values = out.past_key_values
-        else:
-            if model.config.is_encoder_decoder:
-                out = model.decoder(
-                    input_ids=torch.as_tensor(
-                        [[token] if not sent_interrupt else output_ids], device=device
-                    ),
-                    encoder_hidden_states=encoder_output,
-                    use_cache=True,
-                    past_key_values=past_key_values if not sent_interrupt else None,
-                )
-                sent_interrupt = False
-
-                logits = model.lm_head(out[0])
-            else:
-                out = model(
-                    input_ids=torch.as_tensor(
-                        [[token] if not sent_interrupt else output_ids], device=device
-                    ),
-                    use_cache=True,
-                    past_key_values=past_key_values if not sent_interrupt else None,
-                )
-                sent_interrupt = False
-                logits = out.logits
-            past_key_values = out.past_key_values
-
-        if logits_processor:
-            if repetition_penalty > 1.0:
-                tmp_output_ids = torch.as_tensor([output_ids], device=logits.device)
-            else:
-                tmp_output_ids = None
-            last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0]
-        else:
-            last_token_logits = logits[0, -1, :]
-
-        if device == "mps":
-            # Switch to CPU by avoiding some bugs in mps backend.
-            last_token_logits = last_token_logits.float().to("cpu")
-
-        if temperature < 1e-5 or top_p < 1e-8:  # greedy
-            _, indices = torch.topk(last_token_logits, 2)
-            tokens = [int(index) for index in indices.tolist()]
-        else:
-            probs = torch.softmax(last_token_logits, dim=-1)
-            indices = torch.multinomial(probs, num_samples=2)
-            tokens = [int(token) for token in indices.tolist()]
-        token = tokens[0]
-        output_ids.append(token)
-
-        if token in stop_token_ids:
-            stopped = True
-        else:
-            stopped = False
-
-        if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
-            if echo:
-                tmp_output_ids = output_ids
-                rfind_start = len_prompt
-            else:
-                tmp_output_ids = output_ids[input_echo_len:]
-                rfind_start = 0
-
-            output = tokenizer.decode(
-                tmp_output_ids,
-                skip_special_tokens=True,
-                spaces_between_special_tokens=False,
-                clean_up_tokenization_spaces=True,
-            )
-
-            # TODO: For the issue of incomplete sentences interrupting output, apply a patch and others can also modify it to a more elegant way
-            if judge_sent_end and stopped and not is_sentence_complete(output):
-                if len(tokens) > 1:
-                    token = tokens[1]
-                    output_ids[-1] = token
-                else:
-                    output_ids.pop()
-                stopped = False
-                sent_interrupt = True
-
-            partially_stopped = False
-            if stop_str:
-                if isinstance(stop_str, str):
-                    pos = output.rfind(stop_str, rfind_start)
-                    if pos != -1:
-                        output = output[:pos]
-                        stopped = True
-                    else:
-                        partially_stopped = is_partial_stop(output, stop_str)
-                elif isinstance(stop_str, Iterable):
-                    for each_stop in stop_str:
-                        pos = output.rfind(each_stop, rfind_start)
-                        if pos != -1:
-                            output = output[:pos]
-                            stopped = True
-                            break
-                        else:
-                            partially_stopped = is_partial_stop(output, each_stop)
-                            if partially_stopped:
-                                break
-                else:
-                    raise ValueError("Invalid stop field type.")
-
-            if stream:
-                output = output.strip("�")
-                tmp_output_length = len(output)
-                output = output[last_output_length:]
-                last_output_length = tmp_output_length
-
-            # prevent yielding partial stop sequence
-            if not partially_stopped:
-                completion_choice = CompletionChoice(
-                    text=output, index=0, logprobs=None, finish_reason=None
-                )
-                completion_chunk = CompletionChunk(
-                    id=chunk_id,
-                    object="text_completion",
-                    created=int(time.time()),
-                    model=model_uid,
-                    choices=[completion_choice],
-                )
-                completion_usage = CompletionUsage(
-                    prompt_tokens=input_echo_len,
-                    completion_tokens=i,
-                    total_tokens=(input_echo_len + i),
-                )
-
-                yield completion_chunk, completion_usage
-
-        if stopped:
-            break
-
-    elapsed_time = time.time() - start
-    logger.info(f"Average generation speed: {i / elapsed_time:.2f} tokens/s.")
-
-    # finish stream event, which contains finish reason
-    if stopped:
-        finish_reason = "stop"
-    elif i == max_new_tokens - 1:
-        finish_reason = "length"
-    else:
-        finish_reason = None
-
-    if stream:
-        completion_choice = CompletionChoice(
-            text=output, index=0, logprobs=None, finish_reason=finish_reason
-        )
-    else:
-        completion_choice = CompletionChoice(
-            text=output, index=0, logprobs=None, finish_reason=finish_reason
-        )
-
-    completion_chunk = CompletionChunk(
-        id=chunk_id,
-        object="text_completion",
-        created=int(time.time()),
-        model=model_uid,
-        choices=[completion_choice],
-    )
-    completion_usage = CompletionUsage(
-        prompt_tokens=input_echo_len,
-        completion_tokens=i,
-        total_tokens=(input_echo_len + i),
-    )
-
-    yield completion_chunk, completion_usage
-
-    if include_usage:
-        completion_chunk = CompletionChunk(
-            id=chunk_id,
-            object="text_completion",
-            created=int(time.time()),
-            model=model_uid,
-            choices=[],
-        )
-        completion_usage = CompletionUsage(
-            prompt_tokens=input_echo_len,
-            completion_tokens=i,
-            total_tokens=(input_echo_len + i),
-        )
-        yield completion_chunk, completion_usage
-
-    # clean
-    del past_key_values, out
-    gc.collect()
-    empty_cache()
-
-
 def _get_token_from_logits(
     req: InferenceRequest, i: int, logits, temperature, repetition_penalty, top_p, top_k
 ):
@@ -678,6 +398,7 @@ def _batch_inference_one_step_internal(
                 output = output.strip("�")
                 output = output[r.last_output_length :]
                 r.last_output_length += len(output)
+                r.outputs.append(output)

                 completion_chunk = generate_completion_chunk(
                     chunk_text=output,
@@ -702,6 +423,7 @@ def _batch_inference_one_step_internal(
                 )
                 r.completion.append(completion_chunk)
                 r.completion.append(eos_flag)
+                r.outputs.append(eos_flag)

             # last round, handle stream result
             # append usage information when enable `include_usage` for OPENAI API compatibility
@@ -776,3 +498,34 @@ def batch_inference_one_step(
         for r in req_list:
             r.stopped = True
             r.error_msg = str(e)
+
+
+def cache_clean(fn):
+    @functools.wraps(fn)
+    async def _async_wrapper(self, *args, **kwargs):
+        import gc
+
+        from ....device_utils import empty_cache
+
+        result = await fn(self, *args, **kwargs)
+
+        gc.collect()
+        empty_cache()
+        return result
+
+    @functools.wraps(fn)
+    def _wrapper(self, *args, **kwargs):
+        import gc
+
+        from ....device_utils import empty_cache
+
+        result = fn(self, *args, **kwargs)
+
+        gc.collect()
+        empty_cache()
+        return result
+
+    if asyncio.iscoroutinefunction(fn):
+        return _async_wrapper
+    else:
+        return _wrapper
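`cache_clean` wraps a model method, runs it, then triggers `gc.collect()` plus a device cache flush; `asyncio.iscoroutinefunction` decides whether the async or sync wrapper is returned, which is why the same decorator works on both coroutine and plain `chat` implementations elsewhere in this diff. Below is a minimal, self-contained sketch of that behavior, with a no-op stand-in for `xinference.device_utils.empty_cache` and an illustrative `DemoChatModel` (neither is part of xinference):

```python
import asyncio
import functools
import gc


def empty_cache():
    # Stand-in for xinference.device_utils.empty_cache, which releases cached
    # accelerator memory (e.g. torch.cuda.empty_cache()) on the active device.
    pass


def cache_clean(fn):
    # Same shape as the decorator in the diff: pick the wrapper that matches
    # the wrapped callable, run it, then collect garbage and empty the cache.
    @functools.wraps(fn)
    async def _async_wrapper(self, *args, **kwargs):
        result = await fn(self, *args, **kwargs)
        gc.collect()
        empty_cache()
        return result

    @functools.wraps(fn)
    def _wrapper(self, *args, **kwargs):
        result = fn(self, *args, **kwargs)
        gc.collect()
        empty_cache()
        return result

    return _async_wrapper if asyncio.iscoroutinefunction(fn) else _wrapper


class DemoChatModel:
    @cache_clean
    def chat(self, messages):
        return {"role": "assistant", "content": f"echo: {messages[-1]['content']}"}


print(DemoChatModel().chat([{"role": "user", "content": "hi"}]))
# -> {'role': 'assistant', 'content': 'echo: hi'}; gc and cache flush ran afterwards
```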
xinference/model/llm/transformers/yi_vl.py

@@ -29,6 +29,7 @@ from ..utils import (
     parse_messages,
 )
 from .core import PytorchChatModel, PytorchGenerateConfig
+from .utils import cache_clean

 logger = logging.getLogger(__name__)

@@ -99,6 +100,7 @@ class YiVLChatModel(PytorchChatModel):
             raise RuntimeError("Only one image per message is supported by Yi VL.")
         return content

+    @cache_clean
     def chat(
         self,
         messages: List[Dict],