xinference 0.10.2.post1__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- xinference/_version.py +3 -3
- xinference/api/oauth2/auth_service.py +1 -1
- xinference/api/restful_api.py +53 -61
- xinference/client/restful/restful_client.py +52 -57
- xinference/conftest.py +1 -1
- xinference/core/cache_tracker.py +1 -1
- xinference/core/event.py +1 -1
- xinference/core/model.py +15 -4
- xinference/core/status_guard.py +1 -1
- xinference/core/supervisor.py +58 -72
- xinference/core/worker.py +73 -102
- xinference/deploy/cmdline.py +175 -6
- xinference/deploy/test/test_cmdline.py +2 -0
- xinference/deploy/utils.py +1 -1
- xinference/device_utils.py +29 -3
- xinference/fields.py +5 -1
- xinference/model/audio/model_spec.json +8 -1
- xinference/model/audio/whisper.py +88 -12
- xinference/model/core.py +2 -2
- xinference/model/embedding/core.py +13 -0
- xinference/model/image/__init__.py +29 -0
- xinference/model/image/core.py +6 -0
- xinference/model/image/custom.py +109 -0
- xinference/model/llm/__init__.py +92 -32
- xinference/model/llm/core.py +57 -102
- xinference/model/llm/ggml/tools/convert_ggml_to_gguf.py +2 -2
- xinference/model/llm/llm_family.json +446 -2
- xinference/model/llm/llm_family.py +45 -41
- xinference/model/llm/llm_family_modelscope.json +208 -1
- xinference/model/llm/pytorch/deepseek_vl.py +89 -33
- xinference/model/llm/pytorch/qwen_vl.py +67 -12
- xinference/model/llm/pytorch/yi_vl.py +62 -45
- xinference/model/llm/utils.py +45 -15
- xinference/model/llm/vllm/core.py +21 -4
- xinference/model/rerank/core.py +48 -20
- xinference/thirdparty/omnilmm/chat.py +2 -1
- xinference/thirdparty/omnilmm/model/omnilmm.py +2 -1
- xinference/types.py +2 -0
- xinference/web/ui/build/asset-manifest.json +6 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/css/main.54bca460.css +2 -0
- xinference/web/ui/build/static/css/main.54bca460.css.map +1 -0
- xinference/web/ui/build/static/js/main.8e44da4b.js +3 -0
- xinference/web/ui/build/static/js/{main.26fdbfbe.js.LICENSE.txt → main.8e44da4b.js.LICENSE.txt} +7 -0
- xinference/web/ui/build/static/js/main.8e44da4b.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/0b11a5339468c13b2d31ac085e7effe4303259b2071abd46a0a8eb8529233a5e.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/29dda700ab913cf7f2cfabe450ddabfb283e96adfa3ec9d315b2fa6c63cd375c.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/2c63e940b945fd5817157e08a42b889b30d668ea4c91332f48ef2b1b9d26f520.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/4135fe8745434cbce6438d1ebfa47422e0c77d884db4edc75c8bf32ea1d50621.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/46b6dd1f6d1109cd0e2455a0ea0be3e9bda1097cd4ebec9c4040070372671cfc.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/4de0a71074f9cbe1e7862750dcdd08cbc1bae7d9d9849a78b1783ca670017b3c.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/53f6c0c0afb51265cd8fb940daeb65523501879ac2a8c03a1ead22b9793c5041.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/8ccbb839002bc5bc03e0a0e7612362bf92f6ae64f87e094f8682d6a6fe4619bb.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/97ed30d6e22cf76f0733651e2c18364689a01665d0b5fe811c1b7ca3eb713c82.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/9c0c70f1838913aaa792a0d2260f17f90fd177b95698ed46b7bc3050eb712c1c.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/9cfd33238ca43e5bf9fc7e442690e8cc6027c73553db36de87e3597ed524ee4b.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/ada71518a429f821a9b1dea38bc951447f03c8db509887e0980b893acac938f3.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/b6c9558d28b5972bb8b2691c5a76a2c8814a815eb3443126da9f49f7d6a0c118.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/bb0f721c084a4d85c09201c984f02ee8437d3b6c5c38a57cb4a101f653daef1b.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/ddaec68b88e5eff792df1e39a4b4b8b737bfc832293c015660c3c69334e3cf5c.json +1 -0
- xinference/web/ui/node_modules/.package-lock.json +33 -0
- xinference/web/ui/node_modules/clipboard/.babelrc.json +11 -0
- xinference/web/ui/node_modules/clipboard/.eslintrc.json +24 -0
- xinference/web/ui/node_modules/clipboard/.prettierrc.json +9 -0
- xinference/web/ui/node_modules/clipboard/bower.json +18 -0
- xinference/web/ui/node_modules/clipboard/composer.json +25 -0
- xinference/web/ui/node_modules/clipboard/package.json +63 -0
- xinference/web/ui/node_modules/delegate/package.json +31 -0
- xinference/web/ui/node_modules/good-listener/bower.json +11 -0
- xinference/web/ui/node_modules/good-listener/package.json +35 -0
- xinference/web/ui/node_modules/select/bower.json +13 -0
- xinference/web/ui/node_modules/select/package.json +29 -0
- xinference/web/ui/node_modules/tiny-emitter/package.json +53 -0
- xinference/web/ui/package-lock.json +34 -0
- xinference/web/ui/package.json +1 -0
- {xinference-0.10.2.post1.dist-info → xinference-0.11.0.dist-info}/METADATA +14 -13
- {xinference-0.10.2.post1.dist-info → xinference-0.11.0.dist-info}/RECORD +81 -60
- xinference/client/oscar/__init__.py +0 -13
- xinference/client/oscar/actor_client.py +0 -611
- xinference/model/llm/pytorch/spec_decoding_utils.py +0 -531
- xinference/model/llm/pytorch/spec_model.py +0 -186
- xinference/web/ui/build/static/js/main.26fdbfbe.js +0 -3
- xinference/web/ui/build/static/js/main.26fdbfbe.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/63a4c48f0326d071c7772c46598215c006ae41fd3d4ff3577fe717de66ad6e89.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/de0299226173b0662b573f49e3992220f6611947073bd66ac079728a8bc8837d.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/e9b52d171223bb59fb918316297a051cdfd42dd453e8260fd918e90bc0a4ebdf.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/f4d5d1a41892a754c1ee0237450d804b20612d1b657945b59e564161ea47aa7a.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/fad4cd70de36ef6e6d5f8fd74a10ded58d964a8a91ef7681693fbb8376552da7.json +0 -1
- {xinference-0.10.2.post1.dist-info → xinference-0.11.0.dist-info}/LICENSE +0 -0
- {xinference-0.10.2.post1.dist-info → xinference-0.11.0.dist-info}/WHEEL +0 -0
- {xinference-0.10.2.post1.dist-info → xinference-0.11.0.dist-info}/entry_points.txt +0 -0
- {xinference-0.10.2.post1.dist-info → xinference-0.11.0.dist-info}/top_level.txt +0 -0
xinference/model/llm/pytorch/spec_decoding_utils.py
@@ -1,531 +0,0 @@
-# Copyright 2022-2023 XProbe Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import gc
-import logging
-import time
-import uuid
-from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple
-
-from ....device_utils import empty_cache
-
-try:
-    import torch
-    from torch.nn import functional as F
-except ImportError:
-    raise ImportError(
-        f"Failed to import module 'torch'. Please make sure 'torch' is installed.\n\n"
-    )
-
-try:
-    from transformers import PreTrainedModel, PreTrainedTokenizer
-    from transformers.generation.logits_process import (
-        LogitsProcessorList,
-        TemperatureLogitsWarper,
-        TopKLogitsWarper,
-        TopPLogitsWarper,
-    )
-except ImportError:
-    error_message = "Failed to import module 'transformers'"
-    installation_guide = [
-        "Please make sure 'transformers' is installed. ",
-        "You can install it by `pip install transformers`\n",
-    ]
-
-    raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
-
-
-from ....types import CompletionChoice, CompletionChunk, CompletionUsage
-
-logger = logging.getLogger(__name__)
-
-
-def prepare_logits_processor(
-    temperature: float, top_p: float, top_k: int
-) -> LogitsProcessorList:
-    processor_list = LogitsProcessorList()
-    # TemperatureLogitsWarper doesn't accept 0.0, 1.0 makes it a no-op, so we skip two cases.
-    if temperature >= 1e-5 and temperature != 1.0:
-        processor_list.append(TemperatureLogitsWarper(temperature))
-    if 1e-8 <= top_p < 1.0:
-        processor_list.append(TopPLogitsWarper(top_p))
-    if top_k > 0:
-        processor_list.append(TopKLogitsWarper(top_k))
-    return processor_list
-
-
-def get_context_length(config):
-    """Get the context length of a model from a huggingface model config."""
-    if (
-        hasattr(config, "max_sequence_length")
-        and config.max_sequence_length is not None
-    ):
-        return config.max_sequence_length
-    elif hasattr(config, "seq_length") and config.seq_length is not None:
-        return config.seq_length
-    elif (
-        hasattr(config, "max_position_embeddings")
-        and config.max_position_embeddings is not None
-    ):
-        return config.max_position_embeddings
-    else:
-        return 2048
-
-
-def normalize_logits(
-    logits_processor: LogitsProcessorList,
-    input_ids: List[int],
-    logits: torch.FloatTensor,  # [1, n_seq, n_vocab]
-) -> torch.Tensor:
-    """
-    Parameters
-    ----------
-    logits : torch.Tensor
-        Logits of shape `(n_batch, n_seq, n_vocab)`.
-
-    Returns
-    -------
-    torch.Tensor
-        Normalized logits of shape `(n_batch, n_seq, n_vocab)`.
-    """
-
-    def _helper(
-        _input_ids: torch.LongTensor, _logits: torch.FloatTensor  # [1, n_vocab]
-    ) -> torch.Tensor:
-        if logits_processor:
-            last_token_logits = logits_processor(
-                _input_ids,
-                _logits,
-            )[0]
-        else:
-            return _logits[0]
-
-        return last_token_logits  # [n_vocab,]
-
-    input_ids = torch.as_tensor([input_ids], device=logits.device).long()
-    for i in range(logits.shape[1]):
-        normalized = _helper(
-            input_ids[
-                : -logits.shape[1] + i
-            ],  # input_ids may not equal logits.shape[1]
-            logits[:, i, :],
-        )
-        logits[:, i, :] = normalized.clone()
-    return F.softmax(logits, dim=-1)
-
-
-def sample(
-    last_token_logits: torch.FloatTensor, temperature: float, top_p: float
-) -> int:
-    """
-    Parameters
-    ----------
-    last_token_logits : torch.FloatTensor
-        Last token logits of shape [n_vocab,]
-
-    Returns
-    -------
-    int
-        Token ID.
-    """
-    if temperature < 1e-5 or top_p < 1e-8:  # greedy
-        _, indices = torch.topk(last_token_logits, 2)
-        tokens = [int(index) for index in indices.tolist()]
-    else:
-        indices = torch.multinomial(last_token_logits, num_samples=2)
-        tokens = [int(token) for token in indices.tolist()]
-    return tokens[0]
-
-
-def rollback_kv_cache(
-    kv_cache: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], n: int
-) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
-    ret = []
-    for k_cache, v_cache in kv_cache:
-        k_cache = k_cache[:, :, :-n, :]  # [1, n_head, n_seq - n, n_dim]
-        v_cache = v_cache[:, :, :-n, :]
-
-        assert isinstance(k_cache, torch.Tensor)
-        assert isinstance(v_cache, torch.Tensor)
-        ret.append((k_cache, v_cache))
-
-    return tuple(ret)
-
-
-def rollback_logits(logits: torch.Tensor, n: int):
-    return logits[:, :-n, :]  # [1, n_seq, n_vocab]
-
-
-def is_partial_stop(output: str, stop_str: str):
-    """Check whether the output contains a partial stop str."""
-    for i in range(0, min(len(output), len(stop_str))):
-        if stop_str.startswith(output[-i:]):
-            return True
-    return False
-
-
-def draft(
-    input_ids: List[int],
-    kv_cache: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]],
-    logits: Optional[torch.FloatTensor],
-    draft_model: "PreTrainedModel",
-    gamma: int,
-    logits_processor: LogitsProcessorList,
-    temperature: float,
-    top_p: float,
-):
-    """
-    Parameters
-    ----------
-    input_ids : List[int]
-        On the prefill stage, `input_ids` are the prompt tokens.
-
-        On the decode stage. It includes the prompt tokens, the token generated by the original model
-        at the end of each full iteration, or the token generated by the draft model draft
-        iteration.
-
-    Returns
-    -------
-    int
-        The number of generated draft tokens.
-    List[int]
-        Outputs, including the draft tokens.
-    Tuple[Tuple[torch.Tensor, torch.Tensor], ...]
-        KV cache.
-    torch.FloatTensor
-        Logits.
-    """
-    draft_output_ids = input_ids.copy()
-
-    if kv_cache is not None:
-        input_ids = draft_output_ids[-2:]
-
-    num_draft_tokens = 0
-    while num_draft_tokens < gamma:
-        if kv_cache is None:
-            # prefill.
-            draft_model_out = draft_model(
-                torch.as_tensor([input_ids], device=draft_model.device),
-                use_cache=True,
-            )
-            logits = normalize_logits(
-                logits_processor, input_ids, draft_model_out.logits
-            )
-        else:
-            draft_model_out = draft_model(
-                torch.as_tensor([input_ids], device=draft_model.device),
-                use_cache=True,
-                past_key_values=kv_cache,
-            )
-            normalized = normalize_logits(
-                logits_processor, draft_output_ids, draft_model_out.logits
-            )
-            assert logits is not None
-            logits = torch.cat((logits, normalized), dim=1)
-        kv_cache = draft_model_out.past_key_values
-        draft_token = sample(
-            logits[0, -1, :],
-            temperature,
-            top_p,
-        )
-        draft_output_ids.append(draft_token)
-        input_ids = [draft_token]
-        num_draft_tokens += 1
-
-    assert kv_cache is not None
-    return num_draft_tokens, draft_output_ids, kv_cache, logits
-
-
-@torch.inference_mode()
-def speculative_generate_stream(
-    model_uid: str,
-    draft_model: "PreTrainedModel",
-    model: "PreTrainedModel",
-    tokenizer: "PreTrainedTokenizer",
-    prompt: str,
-    generate_config: Dict[str, Any],
-) -> Iterator[Tuple[CompletionChunk, CompletionUsage]]:
-    logger.debug(
-        f"Enter speculative_generate_stream, prompt: {prompt}, generate_config: {generate_config}"
-    )
-
-    # TODO: currently, repetition penalty leads to garbled outputs.
-    if float(generate_config.get("repetition_penalty", 1.0)) != 1.0:
-        raise ValueError(
-            "repetition penalty is not supported by speculative decoding yet"
-        )
-
-    gamma = generate_config.get("gamma", 4)
-    stream = generate_config.get("stream", False)
-    temperature = float(generate_config.get("temperature", 1.0))
-    top_p = float(generate_config.get("top_p", 1.0))
-    top_k = int(generate_config.get("top_k", -1))  # -1 means disable
-    max_new_tokens = int(generate_config.get("max_tokens", 256))
-    echo = bool(generate_config.get("echo", False))
-    stop_str = generate_config.get("stop", None)
-    stop_token_ids = generate_config.get("stop_token_ids", None) or []
-    stop_token_ids.append(tokenizer.eos_token_id)
-
-    logits_processor = prepare_logits_processor(temperature, top_p, top_k)
-    request_id = str(uuid.uuid1())
-
-    if "qwen" in str(type(model)).lower():
-        # TODO: hacky.
-        input_ids = tokenizer(prompt, allowed_special="all").input_ids
-    else:
-        input_ids = tokenizer(prompt).input_ids
-
-    num_prompt_tokens = len(input_ids)
-    output_ids = list(input_ids)
-
-    # internal states.
-    draft_kv_cache = None
-    draft_logits = None
-    kv_cache = None
-    logits = None
-    next_token = (
-        None  # the token generated by the original model at each full iteration.
-    )
-    last_output_length = 0
-    finish_reason = "stop"
-
-    # performance stats.
-    total_seconds_on_drafting = 0.0
-    total_seconds_on_eval = 0.0
-    total_seconds_on_accepting = 0.0
-    total_num_draft_tokens = 0
-    total_num_accepted_tokens = 0
-
-    while len(output_ids) < max_new_tokens + num_prompt_tokens:
-        # allow the draft model to generate more than max_tokens since some of the generated
-        # tokens could be rejected.
-        start = time.time()
-        num_draft_tokens, output_ids, draft_kv_cache, draft_logits = draft(
-            input_ids=output_ids,
-            kv_cache=draft_kv_cache,
-            logits=draft_logits,
-            draft_model=draft_model,
-            gamma=gamma,
-            logits_processor=logits_processor,
-            temperature=temperature
-            * 0.5,  # make the draft model outputs less random for better quality.
-            top_p=top_p,
-        )
-        total_seconds_on_drafting += time.time() - start
-        total_num_draft_tokens += num_draft_tokens
-
-        # eval stage.
-        start = time.time()
-        if kv_cache is None:
-            # prefill.
-            out = model(
-                torch.as_tensor([output_ids], device=model.device), use_cache=True
-            )
-            logits = normalize_logits(logits_processor, output_ids, out.logits)
-        else:
-            out = model(
-                torch.as_tensor(
-                    [[next_token] + output_ids[-num_draft_tokens:]], device=model.device
-                ),
-                use_cache=True,
-                past_key_values=kv_cache,
-            )
-            normalized = normalize_logits(logits_processor, output_ids, out.logits)
-            logits = torch.cat((logits, normalized), dim=1)
-        kv_cache = out.past_key_values
-        total_seconds_on_eval += time.time() - start
-
-        # accepting stage.
-        start = time.time()
-        assert draft_logits is not None
-        assert draft_kv_cache is not None
-        accepted = 0
-        stopped = False
-        for draft_token_idx in range(-num_draft_tokens, 0):
-            r = torch.rand(1, device=logits.device)
-            draft_token = output_ids[draft_token_idx]
-            token_logits = logits[:, draft_token_idx - 1, :]  # [1, n_vocab,]
-            draft_token_logits = draft_logits[:, draft_token_idx, :].to(
-                logits.device
-            )  # [1, n_vocab,]
-            if token_logits[0, draft_token] / draft_token_logits[0, draft_token] > r:
-                accepted += 1
-                total_num_accepted_tokens += 1
-                if draft_token in stop_token_ids:
-                    stopped = True
-            else:
-                if logger.getEffectiveLevel() <= logging.DEBUG:
-                    logger.debug(
-                        f"Accepted ({accepted}/{num_draft_tokens}): '{tokenizer.decode(output_ids[-num_draft_tokens: draft_token_idx])}'"
-                    )
-                    logger.debug(
-                        f"Rejected: '{tokenizer.decode(output_ids[draft_token_idx:])}'"
-                    )
-                # rollback.
-                output_ids = output_ids[:draft_token_idx]
-                draft_kv_cache = rollback_kv_cache(
-                    draft_kv_cache, num_draft_tokens - accepted
-                )
-                kv_cache = rollback_kv_cache(kv_cache, num_draft_tokens - accepted)
-                draft_logits = rollback_logits(
-                    draft_logits, num_draft_tokens - accepted
-                )
-                logits = rollback_logits(logits, num_draft_tokens - accepted)
-
-                # sample the next token according to the modified distribution of shape [1, n_vocab]
-                modified_dist = token_logits - draft_token_logits
-                modified_dist = torch.where(
-                    modified_dist > 0, modified_dist, torch.zeros_like(modified_dist)
-                )
-                normalized = normalize_logits(
-                    logits_processor,
-                    output_ids,
-                    modified_dist.unsqueeze(1),  # [1, 1, n_vocab]
-                )
-                next_token = sample(
-                    normalized[0, -1, :],
-                    0,  # must be 0, since the dist is quiet unified, higher temperature results in garbled text
-                    top_p,
-                )
-                output_ids.append(next_token)
-                if logger.getEffectiveLevel() <= logging.DEBUG:
-                    logger.debug(f"Generated: '{tokenizer.decode([next_token])}'")
-                if next_token in stop_token_ids:
-                    stopped = True
-                break
-
-        if accepted == num_draft_tokens:
-            if logger.getEffectiveLevel() <= logging.DEBUG:
-                logger.debug(
-                    f"Accepted ({accepted}/{num_draft_tokens}): '{tokenizer.decode(output_ids[-num_draft_tokens:])}'"
-                )
-            next_token = sample(
-                logits[0, -1, :],
-                temperature,
-                top_p,
-            )
-            output_ids.append(next_token)
-            if logger.getEffectiveLevel() <= logging.DEBUG:
-                logger.debug(f"Generated: '{tokenizer.decode([next_token])}'")
-            if next_token in stop_token_ids:
-                stopped = True
-
-        total_seconds_on_accepting += time.time() - start
-
-        if (
-            accepted > 0  # more than 2 tokens has been generated, flush.
-            or len(output_ids) >= max_new_tokens
-            or stopped
-        ):
-            output = tokenizer.decode(
-                output_ids if echo else output_ids[num_prompt_tokens:],
-                spaces_between_special_tokens=False,
-                clean_up_tokenization_spaces=True,
-            )
-            rfind_start = len(prompt) if echo else 0
-
-            partially_stopped = False
-            if stop_str:
-                if isinstance(stop_str, str):
-                    pos = output.rfind(stop_str, rfind_start)
-                    if pos != -1:
-                        output = output[:pos]
-                        stopped = True
-                    else:
-                        partially_stopped = is_partial_stop(output, stop_str)
-                elif isinstance(stop_str, Iterable):
-                    for each_stop in stop_str:
-                        pos = output.rfind(each_stop, rfind_start)
-                        if pos != -1:
-                            output = output[:pos]
-                            stopped = True
-                            break
-                        else:
-                            partially_stopped = is_partial_stop(output, each_stop)
-                            if partially_stopped:
-                                break
-                else:
-                    raise ValueError(f"Invalid stop field type {type(stop_str)}")
-
-            if stream:
-                # return the delta.
-                output_length = len(output)
-                output = output[last_output_length:]
-                last_output_length = output_length
-
-                # prevent yielding partial stop sequence.
-                if not partially_stopped:
-                    completion_choice = CompletionChoice(
-                        text=output, index=0, logprobs=None, finish_reason=None
-                    )
-                    completion_chunk = CompletionChunk(
-                        id=request_id,
-                        object="text_completion",
-                        created=int(time.time()),
-                        model=model_uid,
-                        choices=[completion_choice],
-                    )
-                    completion_usage = CompletionUsage(
-                        prompt_tokens=num_prompt_tokens,
-                        completion_tokens=len(output_ids) - num_prompt_tokens,
-                        total_tokens=len(output_ids),
-                    )
-
-                    yield completion_chunk, completion_usage
-        if stopped:
-            break
-    else:
-        finish_reason = "length"
-
-    logger.info(
-        f"In total, {total_num_accepted_tokens}/{total_num_draft_tokens} draft tokens are "
-        f"accepted, acceptance rate: {total_num_accepted_tokens / total_num_draft_tokens:.2f}"
-    )
-    total_seconds = (
-        total_seconds_on_drafting + total_seconds_on_eval + total_seconds_on_accepting
-    )
-    logger.info(
-        f"In total, {total_seconds_on_drafting:.2f}s, {total_seconds_on_eval:.2f}s and "
-        f"{total_seconds_on_accepting:.2f}s are spent on drafting, eval, and accepting "
-        f"respectively. Average generation speed: {(len(output_ids) - num_prompt_tokens) / total_seconds:.2f} tokens/s."
-    )
-
-    if stream:
-        completion_choice = CompletionChoice(
-            text="", index=0, logprobs=None, finish_reason=finish_reason
-        )
-    else:
-        completion_choice = CompletionChoice(
-            text=output, index=0, logprobs=None, finish_reason=finish_reason
-        )
-
-    completion_chunk = CompletionChunk(
-        id=request_id,
-        object="text_completion",
-        created=int(time.time()),
-        model=model_uid,
-        choices=[completion_choice],
-    )
-    completion_usage = CompletionUsage(
-        prompt_tokens=num_prompt_tokens,
-        completion_tokens=len(output_ids) - num_prompt_tokens,
-        total_tokens=len(output_ids),
-    )
-
-    yield completion_chunk, completion_usage
-
-    # clean up.
-    del kv_cache
-    del draft_kv_cache
-    gc.collect()
-    empty_cache()