sglang 0.3.0__py3-none-any.whl → 0.3.1.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +17 -8
- sglang/bench_serving.py +33 -38
- sglang/global_config.py +5 -17
- sglang/lang/backend/runtime_endpoint.py +5 -2
- sglang/lang/interpreter.py +1 -4
- sglang/launch_server.py +3 -6
- sglang/launch_server_llavavid.py +7 -8
- sglang/srt/{model_config.py → configs/model_config.py} +5 -0
- sglang/srt/constrained/__init__.py +2 -0
- sglang/srt/constrained/fsm_cache.py +33 -38
- sglang/srt/constrained/jump_forward.py +0 -1
- sglang/srt/conversation.py +4 -1
- sglang/srt/hf_transformers_utils.py +1 -3
- sglang/srt/layers/activation.py +12 -0
- sglang/srt/layers/attention_backend.py +480 -0
- sglang/srt/layers/flashinfer_utils.py +235 -0
- sglang/srt/layers/fused_moe/layer.py +27 -7
- sglang/srt/layers/layernorm.py +12 -0
- sglang/srt/layers/logits_processor.py +64 -77
- sglang/srt/layers/radix_attention.py +11 -161
- sglang/srt/layers/sampler.py +38 -122
- sglang/srt/layers/torchao_utils.py +75 -0
- sglang/srt/layers/{decode_attention.py → triton_attention/decode_attention.py} +67 -63
- sglang/srt/layers/{extend_attention.py → triton_attention/extend_attention.py} +40 -132
- sglang/srt/layers/{prefill_attention.py → triton_attention/prefill_attention.py} +13 -7
- sglang/srt/lora/lora.py +403 -0
- sglang/srt/lora/lora_config.py +43 -0
- sglang/srt/lora/lora_manager.py +259 -0
- sglang/srt/managers/controller_multi.py +1 -5
- sglang/srt/managers/controller_single.py +0 -5
- sglang/srt/managers/io_struct.py +16 -1
- sglang/srt/managers/policy_scheduler.py +122 -5
- sglang/srt/managers/schedule_batch.py +105 -71
- sglang/srt/managers/tokenizer_manager.py +17 -8
- sglang/srt/managers/tp_worker.py +188 -121
- sglang/srt/model_executor/cuda_graph_runner.py +69 -133
- sglang/srt/model_executor/forward_batch_info.py +35 -312
- sglang/srt/model_executor/model_runner.py +123 -154
- sglang/srt/models/baichuan.py +416 -0
- sglang/srt/models/chatglm.py +1 -5
- sglang/srt/models/commandr.py +1 -5
- sglang/srt/models/dbrx.py +1 -5
- sglang/srt/models/deepseek.py +1 -5
- sglang/srt/models/deepseek_v2.py +7 -6
- sglang/srt/models/exaone.py +1 -5
- sglang/srt/models/gemma.py +1 -5
- sglang/srt/models/gemma2.py +1 -5
- sglang/srt/models/gpt_bigcode.py +1 -5
- sglang/srt/models/grok.py +1 -5
- sglang/srt/models/internlm2.py +1 -5
- sglang/srt/models/llama.py +51 -5
- sglang/srt/models/llama_classification.py +1 -20
- sglang/srt/models/llava.py +30 -5
- sglang/srt/models/llavavid.py +2 -2
- sglang/srt/models/minicpm.py +1 -5
- sglang/srt/models/minicpm3.py +669 -0
- sglang/srt/models/mixtral.py +6 -5
- sglang/srt/models/mixtral_quant.py +1 -5
- sglang/srt/models/olmoe.py +415 -0
- sglang/srt/models/qwen.py +1 -5
- sglang/srt/models/qwen2.py +1 -5
- sglang/srt/models/qwen2_moe.py +6 -5
- sglang/srt/models/stablelm.py +1 -5
- sglang/srt/models/xverse.py +375 -0
- sglang/srt/models/xverse_moe.py +445 -0
- sglang/srt/openai_api/adapter.py +65 -46
- sglang/srt/openai_api/protocol.py +11 -3
- sglang/srt/sampling/sampling_batch_info.py +46 -80
- sglang/srt/server.py +30 -15
- sglang/srt/server_args.py +163 -28
- sglang/srt/utils.py +19 -51
- sglang/test/few_shot_gsm8k.py +132 -0
- sglang/test/runners.py +114 -22
- sglang/test/test_programs.py +7 -5
- sglang/test/test_utils.py +85 -2
- sglang/utils.py +32 -37
- sglang/version.py +1 -1
- {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/METADATA +30 -18
- sglang-0.3.1.post1.dist-info/RECORD +130 -0
- {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/WHEEL +1 -1
- sglang-0.3.0.dist-info/RECORD +0 -118
- {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/LICENSE +0 -0
- {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/flashinfer_utils.py
ADDED
@@ -0,0 +1,235 @@
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def create_flashinfer_kv_indices_triton(
+    req_to_token_ptr,  # [max_batch, max_context_len]
+    req_pool_indices_ptr,
+    page_kernel_lens_ptr,
+    kv_indptr,
+    kv_start_idx,
+    kv_indices_ptr,
+    max_context_len: tl.constexpr,
+):
+    BLOCK_SIZE: tl.constexpr = 512
+    pid = tl.program_id(axis=0)
+    req_pool_index = tl.load(req_pool_indices_ptr + pid)
+    kv_indices_offset = tl.load(kv_indptr + pid)
+
+    kv_start = 0
+    kv_end = 0
+    if kv_start_idx:
+        kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
+        kv_end = kv_start
+    kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
+
+    req_to_token_ptr += req_pool_index * max_context_len
+    kv_indices_ptr += kv_indices_offset
+
+    ld_offset = kv_start + tl.arange(0, BLOCK_SIZE)
+    st_offset = tl.arange(0, BLOCK_SIZE)
+    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
+    for _ in range(num_loop):
+        mask = ld_offset < kv_end
+        data = tl.load(req_to_token_ptr + ld_offset, mask=mask)
+        tl.store(kv_indices_ptr + st_offset, data, mask=mask)
+        ld_offset += BLOCK_SIZE
+        st_offset += BLOCK_SIZE
+
+
+class FlashinferUpdater:
+    def __init__(
+        self,
+        forward_mode,
+        model_runner,
+        req_pool_indices,
+        seq_lens,
+        prefix_lens,
+        decode_wrapper=None,
+        use_ragged=False,
+    ):
+        self.forward_mode = forward_mode
+        self.model_runner = model_runner
+        self.req_pool_indices = req_pool_indices
+        self.seq_lens = seq_lens
+        self.prefix_lens = prefix_lens
+        self.use_ragged = use_ragged
+
+        self.num_qo_heads = (
+            model_runner.model_config.num_attention_heads // model_runner.tp_size
+        )
+        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
+            model_runner.tp_size
+        )
+        self.head_dim = model_runner.model_config.head_dim
+        self.batch_size = len(req_pool_indices)
+
+        self.decode_wrapper = (
+            decode_wrapper or self.model_runner.attn_backend.decode_wrapper
+        )
+        self.prefill_wrapper_ragged = (
+            self.model_runner.attn_backend.prefill_wrapper_ragged
+        )
+        self.prefill_wrapper_paged = (
+            self.model_runner.attn_backend.prefill_wrapper_paged
+        )
+
+        self.kv_last_page_len = torch.ones(
+            (self.batch_size,), dtype=torch.int32, device="cuda"
+        )
+
+    def _init_indices_no_sliding_window(self):
+        if self.use_ragged:
+            paged_kernel_lens = self.prefix_lens
+        else:
+            paged_kernel_lens = self.seq_lens
+
+        self.kv_indptr = torch.zeros(
+            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
+        )
+        self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+        self.kv_indices = torch.empty(
+            self.kv_indptr[-1], dtype=torch.int32, device="cuda"
+        )
+
+        create_flashinfer_kv_indices_triton[(self.batch_size,)](
+            self.model_runner.req_to_token_pool.req_to_token,
+            self.req_pool_indices,
+            paged_kernel_lens,
+            self.kv_indptr,
+            None,
+            self.kv_indices,
+            self.model_runner.req_to_token_pool.req_to_token.size(1),
+        )
+
+    def _init_indices_sliding_window(self, wrapper_id):
+        if wrapper_id == 0:
+            # window attention use paged only
+            if self.forward_mode.is_decode():
+                paged_kernel_lens = torch.minimum(
+                    self.seq_lens,
+                    torch.tensor(self.model_runner.sliding_window_size + 1),
+                )
+            else:
+                paged_kernel_lens = torch.minimum(
+                    self.seq_lens,
+                    torch.tensor(self.model_runner.sliding_window_size)
+                    + self.seq_lens
+                    - self.prefix_lens,
+                )
+        else:
+            # full attention
+            paged_kernel_lens = self.seq_lens
+
+        kv_start_idx = self.seq_lens - paged_kernel_lens
+        self.kv_indptr = torch.zeros(
+            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
+        )
+        self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+        self.kv_indices = torch.empty(
+            self.kv_indptr[-1], dtype=torch.int32, device="cuda"
+        )
+        create_flashinfer_kv_indices_triton[(self.batch_size,)](
+            self.model_runner.req_to_token_pool.req_to_token,
+            self.req_pool_indices,
+            paged_kernel_lens,
+            self.kv_indptr,
+            kv_start_idx,
+            self.kv_indices,
+            self.model_runner.req_to_token_pool.req_to_token.size(1),
+        )
+
+    def _update_decode_indices(self, decode_wrapper):
+        decode_wrapper.end_forward()
+        decode_wrapper.begin_forward(
+            self.kv_indptr,
+            self.kv_indices,
+            self.kv_last_page_len,
+            self.num_qo_heads,
+            self.num_kv_heads,
+            self.head_dim,
+            1,
+            data_type=self.model_runner.kv_cache_dtype,
+            q_data_type=self.model_runner.dtype,
+        )
+
+    def _update_extend_indices(self, ragged_wrapper, paged_wrapper):
+        # extend part
+        qo_indptr = torch.zeros(
+            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
+        )
+        qo_indptr[1:] = torch.cumsum(self.seq_lens - self.prefix_lens, dim=0)
+
+        if self.use_ragged:
+            ragged_wrapper.end_forward()
+            ragged_wrapper.begin_forward(
+                qo_indptr,
+                qo_indptr,
+                self.num_qo_heads,
+                self.num_kv_heads,
+                self.head_dim,
+            )
+
+        # cached part
+        paged_wrapper.end_forward()
+        paged_wrapper.begin_forward(
+            qo_indptr,
+            self.kv_indptr,
+            self.kv_indices,
+            self.kv_last_page_len,
+            self.num_qo_heads,
+            self.num_kv_heads,
+            self.head_dim,
+            1,
+        )
+
+    def update_indices_no_sliding_window(self):
+        self._init_indices_no_sliding_window()
+
+        if self.forward_mode.is_decode():
+            self._update_decode_indices(self.decode_wrapper)
+        else:
+            self._update_extend_indices(
+                self.prefill_wrapper_ragged,
+                self.prefill_wrapper_paged,
+            )
+
+    def update_indices_sliding_window(self):
+        assert self.use_ragged is False
+
+        for wrapper_id in range(2):
+            self._init_indices_sliding_window(wrapper_id)
+            if self.forward_mode.is_decode():
+                self._update_decode_indices(self.decode_wrapper[wrapper_id])
+            else:
+                self._update_extend_indices(
+                    None,
+                    self.prefill_wrapper_paged[wrapper_id],
+                )
+
+
+def update_flashinfer_indices(
+    forward_mode,
+    model_runner,
+    req_pool_indices,
+    seq_lens,
+    prefix_lens,
+    decode_wrapper=None,
+    use_ragged=False,
+):
+    updater = FlashinferUpdater(
+        forward_mode,
+        model_runner,
+        req_pool_indices,
+        seq_lens,
+        prefix_lens,
+        decode_wrapper,
+        use_ragged,
+    )
+
+    if model_runner.sliding_window_size is None:
+        updater.update_indices_no_sliding_window()
+    else:
+        updater.update_indices_sliding_window()
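For context on the new `flashinfer_utils.py` module: `create_flashinfer_kv_indices_triton` flattens each request's KV-cache slot ids from the `req_to_token` table into one flat `kv_indices` array, with `kv_indptr` marking per-request segment boundaries, and `FlashinferUpdater` hands those buffers to the FlashInfer decode/prefill wrappers. A rough pure-PyTorch sketch of the gather the kernel performs (illustrative only, not part of the package; argument names follow the kernel's):

```python
import torch

def create_kv_indices_reference(req_to_token, req_pool_indices, paged_kernel_lens,
                                kv_start_idx=None):
    """CPU reference for what create_flashinfer_kv_indices_triton writes."""
    batch_size = len(req_pool_indices)
    kv_indptr = torch.zeros(batch_size + 1, dtype=torch.int32)
    kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
    kv_indices = torch.empty(int(kv_indptr[-1]), dtype=torch.int32)
    for i in range(batch_size):
        start = int(kv_start_idx[i]) if kv_start_idx is not None else 0
        end = start + int(paged_kernel_lens[i])
        # Copy this request's KV-cache slot ids into its slice of the flat array.
        kv_indices[kv_indptr[i] : kv_indptr[i + 1]] = req_to_token[
            req_pool_indices[i], start:end
        ]
    return kv_indptr, kv_indices
```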
sglang/srt/layers/fused_moe/layer.py
CHANGED
@@ -18,6 +18,8 @@ from vllm.model_executor.layers.quantization.base_config import (
 from vllm.model_executor.layers.quantization.fp8 import Fp8Config
 from vllm.model_executor.utils import set_weight_attrs
 
+from sglang.srt.utils import is_hip
+
 logger = init_logger(__name__)
 
 
@@ -381,6 +383,7 @@ from torch.nn import Module
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d,
+    normalize_e4m3fn_to_e4m3fnuz,
     per_tensor_dequantize,
 )
 from vllm.utils import print_warning_once
@@ -479,14 +482,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
 
     def process_weights_after_loading(self, layer: Module) -> None:
 
-        # If checkpoint is fp16, quantize in place.
+        # If checkpoint is fp16 or bfloat16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
-            w13_weight = torch.empty_like(
-                layer.w13_weight.data, dtype=torch.float8_e4m3fn
-            )
-            w2_weight = torch.empty_like(
-                layer.w2_weight.data, dtype=torch.float8_e4m3fn
-            )
+            # If ROCm, use float8_e4m3fnuz instead (MI300x HW)
+            fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
+            w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
+            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
 
             # Re-initialize w13_scale because we directly quantize
             # merged w13 weights and generate a single scaling factor.
@@ -534,6 +535,25 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                     layer.a2_scale.max(), requires_grad=False
                 )
 
+            # If ROCm, normalize the weights and scales to e4m3fnuz
+            if is_hip():
+                # Normalize the weights and scales
+                w13_weight, w13_scale, a13_scale = normalize_e4m3fn_to_e4m3fnuz(
+                    layer.w13_weight, layer.w13_scale, layer.a13_scale
+                )
+                w2_weight, w2_scale, a2_scale = normalize_e4m3fn_to_e4m3fnuz(
+                    layer.w2_weight, layer.w2_scale, layer.a2_scale
+                )
+                # Reset the parameters
+                layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
+                layer.w13_scale = torch.nn.Parameter(w13_scale, requires_grad=False)
+                if a13_scale is not None:
+                    layer.a13_scale = torch.nn.Parameter(a13_scale, requires_grad=False)
+                layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+                layer.w2_scale = torch.nn.Parameter(w2_scale, requires_grad=False)
+                if a2_scale is not None:
+                    layer.a2_scale = torch.nn.Parameter(a2_scale, requires_grad=False)
+
             # Fp8 moe kernel needs single weight scale for w13 per expert.
             # We take the max then dequant and requant each expert.
             assert layer.w13_scale is not None
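On the ROCm path introduced here, FP8 MoE weights use `torch.float8_e4m3fnuz` (the flavor the diff's own comment ties to MI300x hardware) instead of `torch.float8_e4m3fn`, and already-serialized e4m3fn checkpoints are converted with vLLM's `normalize_e4m3fn_to_e4m3fnuz`. A simplified per-tensor sketch of the dtype choice (the real code quantizes through vLLM's custom ops and keeps one scale per expert; this is illustrative only):

```python
import torch

def quantize_fp8_per_tensor(weight: torch.Tensor, rocm: bool):
    # Same dtype choice as the diff: e4m3fnuz when is_hip() is true, e4m3fn otherwise.
    fp8_dtype = torch.float8_e4m3fnuz if rocm else torch.float8_e4m3fn
    finfo = torch.finfo(fp8_dtype)
    # One scale for the whole tensor, chosen so the largest value maps to the fp8 max.
    scale = weight.abs().max().clamp(min=1e-12) / finfo.max
    qweight = (weight / scale).clamp(min=finfo.min, max=finfo.max).to(fp8_dtype)
    return qweight, scale
```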
sglang/srt/layers/layernorm.py
CHANGED
@@ -15,6 +15,7 @@ limitations under the License.
 
 """Fused operators for normalization layers."""
 
+import logging
 from typing import Optional, Tuple, Union
 
 import torch
@@ -27,6 +28,10 @@ from flashinfer.norm import (
 )
 from vllm.model_executor.custom_op import CustomOp
 
+from sglang.srt.utils import is_hip
+
+logger = logging.getLogger(__name__)
+
 
 class RMSNorm(CustomOp):
     def __init__(
@@ -109,3 +114,10 @@ class GemmaRMSNorm(CustomOp):
             return x, residual
         out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon)
         return out
+
+
+if is_hip():
+    logger.info(
+        "FlashInfer is not available on AMD GPUs. Fallback to other kernel libraries."
+    )
+    from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
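This change keeps the FlashInfer-backed `RMSNorm`/`GemmaRMSNorm` on CUDA and, via the new `is_hip()` guard, rebinds both names to vLLM's implementations at import time on AMD GPUs. For reference, the math these fused kernels implement, under the conventional RMSNorm definition (sketch only, not part of the package):

```python
import torch

def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
    # Normalize x by its root-mean-square over the last dimension,
    # then scale by the learned per-channel weight.
    inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x * inv_rms * weight
```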
sglang/srt/layers/logits_processor.py
CHANGED
@@ -37,7 +37,7 @@ class LogitsProcessorOutput:
 
     # The normlaized logprobs of prompts. shape: [#seq]
     normalized_prompt_logprobs: torch.Tensor
-    # The logprobs of input tokens.
+    # The logprobs of input tokens. shape: [#token, vocab_size]
     input_token_logprobs: torch.Tensor
 
     # The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
@@ -49,25 +49,39 @@ class LogitsProcessorOutput:
 @dataclasses.dataclass
 class LogitsMetadata:
     forward_mode: ForwardMode
+    top_logprobs_nums: Optional[List[int]]
+
     return_logprob: bool = False
+    return_top_logprob: bool = False
 
     extend_seq_lens: Optional[torch.Tensor] = None
-
-    top_logprobs_nums: Optional[List[int]] = None
+    extend_seq_lens_cpu: Optional[List[int]] = None
 
-
-
+    extend_logprob_start_lens_cpu: Optional[List[int]] = None
+    extend_logprob_pruned_lens_cpu: Optional[List[int]] = None
 
     @classmethod
     def from_input_metadata(cls, input_metadata: InputMetadata):
+        return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums)
+        if input_metadata.forward_mode.is_extend():
+            extend_logprob_pruned_lens_cpu = [
+                extend_len - start_len
+                for extend_len, start_len in zip(
+                    input_metadata.extend_seq_lens,
+                    input_metadata.extend_logprob_start_lens_cpu,
+                )
+            ]
+        else:
+            extend_logprob_pruned_lens_cpu = None
         return cls(
             forward_mode=input_metadata.forward_mode,
-            extend_seq_lens=input_metadata.extend_seq_lens,
-            extend_start_loc=input_metadata.extend_start_loc,
-            return_logprob=input_metadata.return_logprob,
             top_logprobs_nums=input_metadata.top_logprobs_nums,
+            return_logprob=input_metadata.return_logprob,
+            return_top_logprob=return_top_logprob,
+            extend_seq_lens=input_metadata.extend_seq_lens,
             extend_seq_lens_cpu=input_metadata.extend_seq_lens_cpu,
-
+            extend_logprob_start_lens_cpu=input_metadata.extend_logprob_start_lens_cpu,
+            extend_logprob_pruned_lens_cpu=extend_logprob_pruned_lens_cpu,
         )
 
 
@@ -82,57 +96,49 @@ class LogitsProcessor(nn.Module):
     def _get_normalized_prompt_logprobs(
         self,
         input_token_logprobs: torch.Tensor,
-        cum_start_len0: torch.Tensor,
-        cum_start_len1: torch.Tensor,
         logits_metadata: LogitsMetadata,
     ):
         logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)
+        pruned_lens = torch.tensor(
+            logits_metadata.extend_logprob_pruned_lens_cpu, device="cuda"
+        )
 
-        start =
-
-
-
+        start = torch.zeros_like(pruned_lens)
+        start[1:] = torch.cumsum(pruned_lens[:-1], dim=0)
+        end = torch.clamp(
+            start + pruned_lens - 2, min=0, max=logprobs_cumsum.shape[0] - 1
+        )
         sum_logp = (
             logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start]
         )
-        normalized_prompt_logprobs = sum_logp / (
-            (logits_metadata.extend_seq_lens - 1).clamp(min=1)
-        )
-
+        normalized_prompt_logprobs = sum_logp / (pruned_lens - 1).clamp(min=1)
         return normalized_prompt_logprobs
 
     @staticmethod
     def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata):
-
+        max_k = max(logits_metadata.top_logprobs_nums)
+        ret = all_logprobs.topk(max_k, dim=1)
+        values = ret.values.tolist()
+        indices = ret.indices.tolist()
+
+        if logits_metadata.forward_mode.is_decode():
            output_top_logprobs = []
-            max_k = max(logits_metadata.top_logprobs_nums)
-            ret = all_logprobs.topk(max_k, dim=1)
-            values = ret.values.tolist()
-            indices = ret.indices.tolist()
            for i, k in enumerate(logits_metadata.top_logprobs_nums):
                output_top_logprobs.append(list(zip(values[i][:k], indices[i][:k])))
            return None, output_top_logprobs
        else:
-            # TODO: vectorize the code below
            input_top_logprobs, output_top_logprobs = [], []
-            pt = 0
-            extend_seq_lens_cpu = logits_metadata.extend_seq_lens_cpu
 
-
-
-
-
-
-
-                start_len = logits_metadata.logprob_start_lens_cpu[i]
-                pruned_len = extend_seq_len - start_len
-
-                if extend_seq_len == 0:
+            pt = 0
+            for k, pruned_len in zip(
+                logits_metadata.top_logprobs_nums,
+                logits_metadata.extend_logprob_pruned_lens_cpu,
+            ):
+                if pruned_len <= 0:
                    input_top_logprobs.append([])
                    output_top_logprobs.append([])
                    continue
 
-                k = logits_metadata.top_logprobs_nums[i]
                input_top_logprobs.append(
                    [
                        list(zip(values[pt + j][:k], indices[pt + j][:k]))
@@ -163,14 +169,11 @@ class LogitsProcessor(nn.Module):
         assert isinstance(logits_metadata, LogitsMetadata)
 
         # Get the last hidden states and last logits for the next token prediction
-        if logits_metadata.forward_mode == ForwardMode.DECODE:
+        if logits_metadata.forward_mode.is_decode():
             last_index = None
             last_hidden = hidden_states
         else:
-            last_index = (
-                torch.cumsum(logits_metadata.extend_seq_lens, dim=0, dtype=torch.long)
-                - 1
-            )
+            last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
             last_hidden = hidden_states[last_index]
 
         last_logits = torch.matmul(last_hidden, weight.T)
@@ -194,21 +197,15 @@ class LogitsProcessor(nn.Module):
                 output_top_logprobs=None,
             )
         else:
-
-            if logits_metadata.forward_mode == ForwardMode.DECODE:
-                last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1)
+            last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1)
 
-
-                return_top_logprob = any(
-                    x > 0 for x in logits_metadata.top_logprobs_nums
-                )
-                if return_top_logprob:
+            if logits_metadata.forward_mode.is_decode():
+                if logits_metadata.return_top_logprob:
                     output_top_logprobs = self.get_top_logprobs(
                         last_logprobs, logits_metadata
                     )[1]
                 else:
                     output_top_logprobs = None
-
            return LogitsProcessorOutput(
                next_token_logits=last_logits,
                next_token_logprobs=last_logprobs,
@@ -218,22 +215,18 @@ class LogitsProcessor(nn.Module):
                 output_top_logprobs=output_top_logprobs,
             )
         else:
+            # Slice the requested tokens to compute logprob
             pt, states, pruned_input_ids = 0, [], []
-            for
-
+            for start_len, extend_len in zip(
+                logits_metadata.extend_logprob_start_lens_cpu,
+                logits_metadata.extend_seq_lens_cpu,
+            ):
                 states.append(hidden_states[pt + start_len : pt + extend_len])
                 pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len])
                 pt += extend_len
 
+            # Compute the logits and logprobs for all required tokens
             states = torch.cat(states, dim=0)
-            pruned_input_ids = torch.cat(pruned_input_ids, dim=0)
-
-            cum_start_len1 = torch.tensor(
-                logits_metadata.logprob_start_lens_cpu, device="cuda"
-            ).cumsum(0)
-            cum_start_len0 = torch.zeros_like(cum_start_len1)
-            cum_start_len0[1:] = cum_start_len1[:-1]
-
             all_logits = torch.matmul(states, weight.T)
             if self.do_tensor_parallel_all_gather:
                 all_logits = tensor_model_parallel_all_gather(all_logits)
@@ -249,35 +242,29 @@ class LogitsProcessor(nn.Module):
             all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
 
             # Get the logprob of top-k tokens
-            return_top_logprob = any(
-                x > 0 for x in logits_metadata.top_logprobs_nums
-            )
-            if return_top_logprob:
+            if logits_metadata.return_top_logprob:
                 input_top_logprobs, output_top_logprobs = self.get_top_logprobs(
                     all_logprobs, logits_metadata
                 )
            else:
                 input_top_logprobs = output_top_logprobs = None
 
-
-
-            # Compute the logprobs and normalized logprobs for the prefill tokens.
-            # Note that we pad a zero at the end of each sequence for easy computation.
+            # Compute the normalized logprobs for the requested tokens.
+            # Note that we pad a zero at the end for easy batching.
             input_token_logprobs = all_logprobs[
                 torch.arange(all_logprobs.shape[0], device="cuda"),
-                torch.cat(
+                torch.cat(
+                    [
+                        torch.cat(pruned_input_ids)[1:],
+                        torch.tensor([0], device="cuda"),
+                    ]
+                ),
             ]
-
            normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
                input_token_logprobs,
-                cum_start_len0,
-                cum_start_len1,
                logits_metadata,
            )
 
-            # Remove the last token logprob for the prefill tokens.
-            input_token_logprobs = input_token_logprobs[:-1]
-
            return LogitsProcessorOutput(
                next_token_logits=last_logits,
                next_token_logprobs=last_logprobs,
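The logits-processor refactor drops the `cum_start_len0`/`cum_start_len1` bookkeeping and works from per-request pruned lengths (extend length minus the logprob start offset), which `LogitsMetadata.from_input_metadata` now precomputes. A loop-based sketch of what the vectorized `_get_normalized_prompt_logprobs` computes (single-token edge cases glossed over; illustrative only, not part of the package):

```python
import torch

def normalized_prompt_logprobs_reference(input_token_logprobs, pruned_lens):
    # input_token_logprobs: flat tensor with one segment per request,
    # segment i holding pruned_lens[i] entries (pruned_lens: list of ints).
    # The last position of each segment predicts a token past the prompt,
    # so it is excluded from the per-request average.
    out, offset = [], 0
    for n in pruned_lens:
        segment = input_token_logprobs[offset : offset + n]
        out.append(segment[: n - 1].sum() / max(n - 1, 1))
        offset += n
    return torch.stack(out)
```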