sglang 0.2.15__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +10 -6
- sglang/bench_serving.py +33 -38
- sglang/global_config.py +0 -4
- sglang/lang/backend/runtime_endpoint.py +13 -6
- sglang/lang/interpreter.py +1 -1
- sglang/launch_server.py +3 -6
- sglang/launch_server_llavavid.py +7 -8
- sglang/srt/{model_config.py → configs/model_config.py} +5 -0
- sglang/srt/constrained/__init__.py +2 -0
- sglang/srt/constrained/fsm_cache.py +29 -38
- sglang/srt/constrained/jump_forward.py +0 -1
- sglang/srt/conversation.py +4 -1
- sglang/srt/hf_transformers_utils.py +2 -4
- sglang/srt/layers/attention_backend.py +480 -0
- sglang/srt/layers/flashinfer_utils.py +235 -0
- sglang/srt/layers/logits_processor.py +64 -77
- sglang/srt/layers/radix_attention.py +11 -161
- sglang/srt/layers/sampler.py +40 -35
- sglang/srt/layers/torchao_utils.py +75 -0
- sglang/srt/layers/{decode_attention.py → triton_attention/decode_attention.py} +67 -63
- sglang/srt/layers/{extend_attention.py → triton_attention/extend_attention.py} +40 -132
- sglang/srt/layers/{prefill_attention.py → triton_attention/prefill_attention.py} +13 -7
- sglang/srt/lora/lora.py +403 -0
- sglang/srt/lora/lora_config.py +43 -0
- sglang/srt/lora/lora_manager.py +256 -0
- sglang/srt/managers/controller_multi.py +1 -5
- sglang/srt/managers/controller_single.py +0 -5
- sglang/srt/managers/io_struct.py +16 -1
- sglang/srt/managers/policy_scheduler.py +122 -5
- sglang/srt/managers/schedule_batch.py +110 -74
- sglang/srt/managers/tokenizer_manager.py +24 -15
- sglang/srt/managers/tp_worker.py +181 -115
- sglang/srt/model_executor/cuda_graph_runner.py +60 -133
- sglang/srt/model_executor/forward_batch_info.py +35 -312
- sglang/srt/model_executor/model_runner.py +118 -141
- sglang/srt/models/baichuan.py +416 -0
- sglang/srt/models/chatglm.py +6 -8
- sglang/srt/models/commandr.py +1 -5
- sglang/srt/models/dbrx.py +1 -5
- sglang/srt/models/deepseek.py +1 -5
- sglang/srt/models/deepseek_v2.py +1 -5
- sglang/srt/models/exaone.py +8 -43
- sglang/srt/models/gemma.py +1 -5
- sglang/srt/models/gemma2.py +1 -5
- sglang/srt/models/gpt_bigcode.py +1 -5
- sglang/srt/models/grok.py +1 -5
- sglang/srt/models/internlm2.py +1 -5
- sglang/srt/models/{llama2.py → llama.py} +48 -26
- sglang/srt/models/llama_classification.py +14 -40
- sglang/srt/models/llama_embedding.py +7 -6
- sglang/srt/models/llava.py +38 -16
- sglang/srt/models/llavavid.py +7 -8
- sglang/srt/models/minicpm.py +1 -5
- sglang/srt/models/minicpm3.py +665 -0
- sglang/srt/models/mistral.py +2 -3
- sglang/srt/models/mixtral.py +6 -5
- sglang/srt/models/mixtral_quant.py +1 -5
- sglang/srt/models/qwen.py +1 -5
- sglang/srt/models/qwen2.py +1 -5
- sglang/srt/models/qwen2_moe.py +6 -5
- sglang/srt/models/stablelm.py +1 -5
- sglang/srt/models/xverse.py +375 -0
- sglang/srt/models/xverse_moe.py +445 -0
- sglang/srt/openai_api/adapter.py +65 -46
- sglang/srt/openai_api/protocol.py +11 -3
- sglang/srt/sampling/sampling_batch_info.py +67 -58
- sglang/srt/server.py +24 -14
- sglang/srt/server_args.py +130 -28
- sglang/srt/utils.py +12 -0
- sglang/test/few_shot_gsm8k.py +132 -0
- sglang/test/runners.py +114 -22
- sglang/test/test_programs.py +70 -0
- sglang/test/test_utils.py +89 -1
- sglang/utils.py +38 -4
- sglang/version.py +1 -1
- {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/METADATA +31 -18
- sglang-0.3.1.dist-info/RECORD +129 -0
- {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/WHEEL +1 -1
- sglang-0.2.15.dist-info/RECORD +0 -118
- {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/LICENSE +0 -0
- {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/sampler.py
CHANGED
@@ -1,6 +1,6 @@
 import dataclasses
 import logging
-from typing import Union
+from typing import Tuple, Union
 
 import torch
 from flashinfer.sampling import (
@@ -9,6 +9,7 @@ from flashinfer.sampling import (
     top_k_top_p_sampling_from_probs,
     top_p_renorm_prob,
 )
+from torch.library import custom_op as torch_custom_op
 from vllm.model_executor.custom_op import CustomOp
 
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
@@ -30,43 +31,18 @@ class SampleOutput:
 class Sampler(CustomOp):
     def __init__(self):
         super().__init__()
+        # FIXME: torch.multinomial has too many bugs
+        self.forward_native = self.forward_cuda
+        self.is_torch_compile = False
 
-    def _apply_penalties(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
-        # min-token, presence, frequency
-        if sampling_info.linear_penalties is not None:
-            logits += sampling_info.linear_penalties
-
-        # repetition
-        if sampling_info.scaling_penalties is not None:
-            logits = torch.where(
-                logits > 0,
-                logits / sampling_info.scaling_penalties,
-                logits * sampling_info.scaling_penalties,
-            )
-
-        return logits
-
-    def _get_probs(
-        self,
-        logits: torch.Tensor,
-        sampling_info: SamplingBatchInfo,
-        is_torch_compile: bool = False,
-    ):
+    def _get_probs(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
         # Post process logits
         logits = logits.contiguous()
         logits.div_(sampling_info.temperatures)
-        if is_torch_compile:
+        if self.is_torch_compile:
             # FIXME: Temporary workaround for unknown bugs in torch.compile
             logits.add_(0)
 
-        if sampling_info.logit_bias is not None:
-            logits.add_(sampling_info.logit_bias)
-
-        if sampling_info.vocab_mask is not None:
-            logits = logits.masked_fill(sampling_info.vocab_mask, float("-inf"))
-
-        logits = self._apply_penalties(logits, sampling_info)
-
         return torch.softmax(logits, dim=-1)
 
     def forward_cuda(
@@ -79,7 +55,7 @@ class Sampler(CustomOp):
 
         probs = self._get_probs(logits, sampling_info)
 
-        if not global_server_args_dict["disable_flashinfer_sampling"]:
+        if global_server_args_dict["sampling_backend"] == "flashinfer":
             max_top_k_round, batch_size = 32, probs.shape[0]
             uniform_samples = torch.rand(
                 (max_top_k_round, batch_size), device=probs.device
@@ -91,14 +67,18 @@
                     probs, uniform_samples, sampling_info.min_ps
                 )
             else:
-                batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
+                batch_next_token_ids, success = flashinfer_top_k_top_p(
                     probs, uniform_samples, sampling_info.top_ks, sampling_info.top_ps
                 )
-        else:
+        elif global_server_args_dict["sampling_backend"] == "pytorch":
             # Here we provide a slower fallback implementation.
             batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
                 probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
             )
+        else:
+            raise ValueError(
+                f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
+            )
 
         return SampleOutput(success, probs, batch_next_token_ids)
 
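Note: the body of `top_k_top_p_min_p_sampling_from_probs_torch` is not part of this diff. For readers following the "pytorch" fallback branch above, a minimal, illustrative sketch of that style of filtering (function and variable names here are examples, not the package's code; the real helper also returns a success flag):

```python
# Illustrative top-k / top-p / min-p filtering over a [batch, vocab] probability
# tensor. probs, top_ks, top_ps, min_ps mirror the tensors named in the diff.
import torch


def sample_top_k_top_p_min_p(probs, top_ks, top_ps, min_ps):
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # min-p: drop tokens whose probability is below min_p * p_max
    min_p_thresholds = probs_sort[:, 0] * min_ps
    probs_sort[probs_sort < min_p_thresholds.unsqueeze(-1)] = 0.0
    # top-p: drop the tail once the cumulative mass before a token exceeds top_p
    probs_sort[(probs_sum - probs_sort) > top_ps.unsqueeze(-1)] = 0.0
    # top-k: keep only the first top_k entries of each sorted row
    ranks = torch.arange(probs.shape[-1], device=probs.device)
    probs_sort[ranks.unsqueeze(0) >= top_ks.unsqueeze(-1)] = 0.0
    # renormalize, sample in sorted order, then map back to vocabulary ids
    probs_sort = probs_sort / probs_sort.sum(dim=-1, keepdim=True)
    sampled = torch.multinomial(probs_sort, num_samples=1)
    return torch.gather(probs_idx, 1, sampled).squeeze(-1)
```

Because the masks are applied in sorted order, the highest-probability token is never zeroed, so the renormalization step cannot divide by zero even with aggressive top_k/top_p/min_p settings.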
@@ -110,7 +90,7 @@ class Sampler(CustomOp):
         if isinstance(logits, LogitsProcessorOutput):
             logits = logits.next_token_logits
 
-        probs = self._get_probs(logits, sampling_info, is_torch_compile=True)
+        probs = self._get_probs(logits, sampling_info)
 
         batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
             probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
@@ -119,6 +99,31 @@
         return SampleOutput(success, probs, batch_next_token_ids)
 
 
+@torch_custom_op("my_lib::flashinfer_top_k_top_p", mutates_args={})
+def flashinfer_top_k_top_p(
+    probs: torch.Tensor,
+    uniform_samples: torch.Tensor,
+    top_ks: torch.Tensor,
+    top_ps: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    # NOTE: we do not use min_p neither in CUDA nor in torch.compile
+    return top_k_top_p_sampling_from_probs(probs, uniform_samples, top_ks, top_ps)
+
+
+@flashinfer_top_k_top_p.register_fake
+def _(
+    probs: torch.Tensor,
+    uniform_samples: torch.Tensor,
+    top_ks: torch.Tensor,
+    top_ps: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    bs = probs.shape[0]
+    return (
+        torch.ones(bs, dtype=torch.bool, device=probs.device),
+        torch.zeros(bs, dtype=torch.int32, device=probs.device),
+    )
+
+
 def top_k_top_p_min_p_sampling_from_probs_torch(
     probs: torch.Tensor,
     top_ks: torch.Tensor,
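The added `flashinfer_top_k_top_p` wrapper uses `torch.library.custom_op` plus `register_fake` so that `torch.compile` can trace the FlashInfer sampling call by shape alone instead of tracing into the kernel. A self-contained sketch of the same pattern on a toy op (assumes PyTorch 2.4+; `demo_lib::scale_add` is an illustrative name, not part of sglang):

```python
import torch
from torch.library import custom_op


@custom_op("demo_lib::scale_add", mutates_args=())
def scale_add(x: torch.Tensor, y: torch.Tensor, alpha: float) -> torch.Tensor:
    # Real implementation: what actually runs in eager mode and inside the graph.
    return x + alpha * y


@scale_add.register_fake
def _(x: torch.Tensor, y: torch.Tensor, alpha: float) -> torch.Tensor:
    # Fake implementation: only shapes, dtypes, and devices matter here.
    return torch.empty_like(x)


@torch.compile
def fused(x, y):
    return scale_add(x, y, 2.0).relu()


out = fused(torch.randn(4), torch.randn(4))
```

The fake implementation only has to produce outputs with the right shapes and dtypes; the real body still runs whenever the op is actually executed.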
sglang/srt/layers/torchao_utils.py
ADDED
@@ -0,0 +1,75 @@
+"""
+Common utilities for torchao.
+"""
+
+from typing import Dict, Set
+
+import torch
+
+
+def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
+    """Quantize a Tensor with torchao quantization specified by torchao_config
+
+    Args:
+       `param`: weight parameter of the linear module
+       `torchao_config`: type of quantization and their arguments we want to use to
+        quantize the Tensor, e.g. int4wo-128 means int4 weight only quantization with group_size
+        128
+    """
+    # Lazy import to suppress some warnings
+    from torchao.quantization import (
+        int4_weight_only,
+        int8_dynamic_activation_int8_weight,
+        int8_weight_only,
+        quantize_,
+    )
+
+    dummy_linear = torch.nn.Linear(param.shape[1], param.shape[0], bias=False)
+    dummy_linear.weight = param
+    if "int8wo" in torchao_config:
+        quantize_(dummy_linear, int8_weight_only())
+    elif "int8dq" in torchao_config:
+        quantize_(dummy_linear, int8_dynamic_activation_int8_weight())
+    elif "int4wo" in torchao_config:
+        group_size = int(torchao_config.split("-")[-1])
+        assert group_size in [
+            32,
+            64,
+            128,
+            256,
+        ], f"int4wo groupsize needs to be one of [32, 64, 128, 256] but got {group_size}"
+        quantize_(dummy_linear, int4_weight_only(group_size=group_size))
+    elif "fp8wo" in torchao_config:
+        from torchao.quantization import float8_weight_only
+
+        # this requires newer hardware
+        # [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
+        quantize_(dummy_linear, float8_weight_only())
+    return dummy_linear.weight
+
+
+def apply_torchao_config_(
+    self: torch.nn.Module,
+    params_dict: Dict[str, torch.Tensor],
+    param_suffixes: Set[str],
+) -> None:
+    """A util function used for quantizing the weight parameters after they are loaded if
+       self.torchao_config is specified
+
+    Args:
+      `self`: the model we want to quantize
+      `params_dict`: dictionary mapping from param_name to the parameter Tensor
+      `param_suffixes`: a set of suffixes, we'll quantize the Tensor matching these suffixes
+
+    Returns:
+       None, the `params_dict` is modified inplace and the weights of `self` model are quantized
+    """
+    if self.torchao_config:
+        for param_suffix in param_suffixes:
+            for name in params_dict:
+                param = params_dict[name]
+                if param_suffix in name and param.ndim == 2:
+                    params_dict[name] = torchao_quantize_param_data(
+                        param, self.torchao_config
+                    )
+        self.load_state_dict(params_dict, assign=True)
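A hedged usage sketch for the new torchao helpers, applied to a toy module rather than an sglang model (requires the torchao package; the module, suffix, and config string here are examples, and in sglang the `torchao_config` attribute is populated from a server option rather than hard-coded):

```python
import torch

from sglang.srt.layers.torchao_utils import apply_torchao_config_


class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(256, 256, bias=False)
        # apply_torchao_config_ reads this attribute from the module.
        self.torchao_config = "int8wo"

    def forward(self, x):
        return self.proj(x)


model = ToyModel()
params_dict = dict(model.named_parameters())
# Quantize every 2-D parameter whose name contains one of the given suffixes,
# then rebind the module to the quantized tensors via load_state_dict(assign=True).
apply_torchao_config_(model, params_dict, param_suffixes={"proj.weight"})
```

Only 2-D weights are touched, so biases, norms, and embeddings pass through unchanged.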
sglang/srt/layers/{decode_attention.py → triton_attention/decode_attention.py}
CHANGED
@@ -15,24 +15,15 @@ limitations under the License.
 
 """
 Memory-efficient attention for decoding.
+It supports page size = 1.
 """
 
 # Adapted from
 # https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
 # https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py
-import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.managers.schedule_batch import global_server_args_dict
-
-if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
-    REDUCE_TRITON_TYPE = tl.float32
-    REDUCE_TORCH_TYPE = torch.float32
-else:
-    REDUCE_TRITON_TYPE = tl.float16
-    REDUCE_TORCH_TYPE = torch.float16
-
 
 @triton.jit
 def tanh(x):
@@ -60,11 +51,13 @@ def _fwd_kernel_stage1(
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
     logit_cap: tl.constexpr,
+    Lk: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
     cur_head = tl.program_id(1)
     start_n = tl.program_id(2)
 
+    reduce_dtype = Att_Out.dtype.element_ty
     cur_kv_head = cur_head // kv_group_num
 
     offs_d = tl.arange(0, BLOCK_DMODEL)
@@ -83,7 +76,7 @@
     block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0)
 
     for start_mark in range(0, block_mask, 1):
-        q = tl.load(Q + off_q + start_mark).to(REDUCE_TRITON_TYPE)
+        q = tl.load(Q + off_q + start_mark).to(reduce_dtype)
        offs_n_new = cur_batch_start_index + offs_n
        k_loc = tl.load(
            Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new,
@@ -97,9 +90,9 @@
        )
        k = tl.load(
            K_Buffer + offs_buf_k,
-            mask=offs_n_new[:, None] < cur_batch_end_index,
+            mask=(offs_n_new[:, None] < cur_batch_end_index) & (offs_d[None, :] < Lk),
            other=0.0,
-        ).to(REDUCE_TRITON_TYPE)
+        ).to(reduce_dtype)
        att_value = tl.sum(q[None, :] * k, 1)
        att_value *= sm_scale
 
@@ -112,7 +105,7 @@
 
 @triton.jit
 def _fwd_kernel_stage2(
-    logics,
+    logits,
     V_Buffer,
     Out,
     Req_to_tokens,
@@ -128,6 +121,7 @@ def _fwd_kernel_stage2(
     kv_group_num: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
+    Lv: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
     cur_head = tl.program_id(1)
@@ -159,7 +153,7 @@
        )
 
        qk = tl.load(
-            logics
+            logits
            + cur_head * stride_logic_h
            + (cur_batch_start_loc + start_n + offs_n),
            mask=start_n + offs_n < cur_batch_seq_len,
@@ -170,14 +164,16 @@
        old_scale = tl.exp(e_max - n_e_max)
        p = tl.exp(qk - n_e_max)
        e_sum = e_sum * old_scale + tl.sum(p, 0)
-        v = tl.load(v_ptrs + v_index[:, None] * stride_buf_vbs)
+        v = tl.load(
+            v_ptrs + v_index[:, None] * stride_buf_vbs, mask=(offs_d[None, :] < Lv)
+        )
        acc = acc * old_scale + tl.sum(p[:, None] * v, 0)
        e_max = n_e_max
 
     acc = acc / e_sum
     off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d
     out_ptrs = Out + off_o
-    tl.store(out_ptrs, acc)
+    tl.store(out_ptrs, acc, mask=(offs_d < Lv))
 
 
 def _decode_att_m_fwd(
@@ -193,10 +189,7 @@ def _decode_att_m_fwd(
     logit_cap,
 ):
     BLOCK = 32
-
-    Lq, Lk = q.shape[-1], k_buffer.shape[-1]
-    assert Lq == Lk
-    assert Lk in {16, 32, 64, 128, 256}
+    Lk = k_buffer.shape[-1]
 
     batch, head_num = B_req_idx.shape[0], q.shape[1]
 
@@ -208,6 +201,8 @@
     else:
        num_warps = 2
 
+    BLOCK_DMODEL = triton.next_power_of_2(Lk)
+
     _fwd_kernel_stage1[grid](
        q,
        k_buffer,
@@ -224,16 +219,17 @@
        k_buffer.stride(1),
        att_out.stride(0),
        kv_group_num=kv_group_num,
-        BLOCK_DMODEL=Lk,
+        BLOCK_DMODEL=BLOCK_DMODEL,
        BLOCK_N=BLOCK,
        logit_cap=logit_cap,
        num_warps=num_warps,
        num_stages=1,
+        Lk=Lk,
     )
 
 
 def _decode_softmax_reducev_fwd(
-    logics,
+    logits,
     v_buffer,
     o,
     req_to_tokens,
@@ -242,31 +238,35 @@ def _decode_softmax_reducev_fwd(
     b_seq_len,
 ):
     BLOCK = 64
-    batch, head = b_seq_len.shape[0], logics.shape[0]
+    batch, head = b_seq_len.shape[0], logits.shape[0]
     grid = (batch, head, 1)
-    kv_group_num = logics.shape[0] // v_buffer.shape[1]
+    kv_group_num = logits.shape[0] // v_buffer.shape[1]
 
     num_warps = 1
 
+    Lv = v_buffer.shape[-1]
+    BLOCK_DMODEL = triton.next_power_of_2(Lv)
+
     _fwd_kernel_stage2[grid](
-        logics,
+        logits,
        v_buffer,
        o,
        req_to_tokens,
        b_req_idx,
        b_start_loc,
        b_seq_len,
-        logics.stride(0),
+        logits.stride(0),
        v_buffer.stride(0),
        v_buffer.stride(1),
        o.stride(0),
        o.stride(1),
        req_to_tokens.stride(0),
        kv_group_num=kv_group_num,
-        BLOCK_DMODEL=v_buffer.shape[-1],
+        BLOCK_DMODEL=BLOCK_DMODEL,
        BLOCK_N=BLOCK,
        num_warps=num_warps,
        num_stages=3,
+        Lv=Lv,
     )
 
 
@@ -293,11 +293,13 @@ def _fwd_grouped_kernel_stage1(
     BLOCK_N: tl.constexpr,
     BLOCK_H: tl.constexpr,
     logit_cap: tl.constexpr,
+    Lk: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
     cur_kv_head = tl.program_id(1)
     start_n = tl.program_id(2)
 
+    reduce_dtype = Att_Out.dtype.element_ty
     cur_head = cur_kv_head * kv_group_num + tl.arange(0, BLOCK_H)
     mask_h = cur_head < (cur_kv_head + 1) * kv_group_num
     mask_h = mask_h & (cur_head < q_head_num)
@@ -324,9 +326,9 @@
     block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0)
 
     for start_mark in range(0, block_mask, 1):
-        q = tl.load(Q + offs_q + start_mark, mask=mask_h[:, None]).to(
-            REDUCE_TRITON_TYPE
-        )
+        q = tl.load(
+            Q + offs_q + start_mark, mask=(mask_h[:, None]) & (offs_d[None, :] < Lk)
+        ).to(reduce_dtype)
        offs_n_new = cur_batch_start_index + offs_n
        k_loc = tl.load(
            Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new,
@@ -340,13 +342,13 @@ def _fwd_grouped_kernel_stage1(
        )
        k = tl.load(
            K_Buffer + offs_buf_k,
-            mask=offs_n_new[None, :] < cur_batch_end_index,
+            mask=(offs_n_new[None, :] < cur_batch_end_index) & (offs_d[:, None] < Lk),
            other=0.0,
-        ).to(REDUCE_TRITON_TYPE)
+        ).to(reduce_dtype)
        qk = tl.dot(q, k)
        if BLOCK_DPE > 0:
            qpe = tl.load(Q + off_qpe + start_mark, mask=mask_h[:, None]).to(
-                REDUCE_TRITON_TYPE
+                reduce_dtype
            )
            offs_buf_kpe = (
                k_loc[None, :] * stride_buf_kbs
@@ -357,7 +359,7 @@ def _fwd_grouped_kernel_stage1(
                K_Buffer + offs_buf_kpe,
                mask=offs_n_new[None, :] < cur_batch_end_index,
                other=0.0,
-            ).to(REDUCE_TRITON_TYPE)
+            ).to(reduce_dtype)
            qk += tl.dot(qpe, kpe)
        qk *= sm_scale
 
@@ -377,7 +379,7 @@
 
 @triton.jit
 def _fwd_grouped_kernel_stage2(
-    logics,
+    logits,
     V_Buffer,
     Out,
     Req_to_tokens,
@@ -395,6 +397,7 @@ def _fwd_grouped_kernel_stage2(
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
     BLOCK_H: tl.constexpr,
+    Lv: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
     cur_kv_head = tl.program_id(1)
@@ -432,7 +435,7 @@ def _fwd_grouped_kernel_stage2(
        )
 
        qk = tl.load(
-            logics + offs_qk,
+            logits + offs_qk,
            mask=mask_h[:, None] & (start_n + offs_n[None, :] < cur_batch_seq_len),
            other=float("-inf"),
        )
@@ -441,7 +444,9 @@ def _fwd_grouped_kernel_stage2(
        old_scale = tl.exp(e_max - n_e_max)
        p = tl.exp(qk - n_e_max[:, None])
        e_sum = e_sum * old_scale + tl.sum(p, 1)
-        v = tl.load(v_ptrs + v_index[:, None] * stride_buf_vbs)
+        v = tl.load(
+            v_ptrs + v_index[:, None] * stride_buf_vbs, mask=(offs_d[None, :] < Lv)
+        )
        p = p.to(v.dtype)
        acc = acc * old_scale[:, None] + tl.dot(p, v)
        e_max = n_e_max
@@ -449,7 +454,7 @@ def _fwd_grouped_kernel_stage2(
     acc = acc / e_sum[:, None]
     off_o = cur_batch * stride_obs + cur_head[:, None] * stride_oh + offs_d[None, :]
     out_ptrs = Out + off_o
-    tl.store(out_ptrs, acc, mask=mask_h[:, None])
+    tl.store(out_ptrs, acc, mask=(mask_h[:, None]) & (offs_d[None, :] < Lv))
 
 
 def _decode_grouped_att_m_fwd(
@@ -464,17 +469,17 @@ def _decode_grouped_att_m_fwd(
     sm_scale,
     logit_cap,
 ):
-    BLOCK = 32
-
-    Lq, Lk = q.shape[-1], k_buffer.shape[-1]
-    assert Lq == Lk
-    assert Lk in {16, 32, 64, 128, 256, 576}
+    BLOCK = 64
+    Lk = k_buffer.shape[-1]
 
     if Lk == 576:
        BLOCK_DMODEL = 512
        BLOCK_DPE = 64
+    elif Lk == 288:
+        BLOCK_DMODEL = 256
+        BLOCK_DPE = 32
     else:
-        BLOCK_DMODEL = Lk
+        BLOCK_DMODEL = triton.next_power_of_2(Lk)
        BLOCK_DPE = 0
 
     batch, head_num = B_req_idx.shape[0], q.shape[1]
@@ -513,11 +518,12 @@ def _decode_grouped_att_m_fwd(
        logit_cap=logit_cap,
        num_warps=num_warps,
        num_stages=1,
+        Lk=Lk,
     )
 
 
 def _decode_grouped_softmax_reducev_fwd(
-    logics,
+    logits,
     v_buffer,
     o,
     req_to_tokens,
@@ -526,22 +532,25 @@ def _decode_grouped_softmax_reducev_fwd(
     b_seq_len,
 ):
     BLOCK = 128
-    batch, head_num = b_seq_len.shape[0], logics.shape[0]
-    kv_group_num = logics.shape[0] // v_buffer.shape[1]
+    batch, head_num = b_seq_len.shape[0], logits.shape[0]
+    kv_group_num = logits.shape[0] // v_buffer.shape[1]
     BLOCK_H = max(16, triton.next_power_of_2(kv_group_num))
     grid = (batch, triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), 1)
 
     num_warps = 8
 
+    Lv = v_buffer.shape[-1]
+    BLOCK_DMODEL = triton.next_power_of_2(Lv)
+
     _fwd_grouped_kernel_stage2[grid](
-        logics,
+        logits,
        v_buffer,
        o,
        req_to_tokens,
        b_req_idx,
        b_start_loc,
        b_seq_len,
-        logics.stride(0),
+        logits.stride(0),
        v_buffer.stride(0),
        v_buffer.stride(1),
        o.stride(0),
@@ -549,9 +558,10 @@ def _decode_grouped_softmax_reducev_fwd(
        req_to_tokens.stride(0),
        kv_group_num=kv_group_num,
        q_head_num=head_num,
-        BLOCK_DMODEL=v_buffer.shape[-1],
+        BLOCK_DMODEL=BLOCK_DMODEL,
        BLOCK_N=BLOCK,
        BLOCK_H=BLOCK_H,
+        Lv=Lv,
        num_warps=num_warps,
        num_stages=1,
     )
@@ -566,17 +576,11 @@ def decode_attention_fwd(
     b_req_idx,
     b_start_loc,
     b_seq_len,
+    attn_logits,
     max_len_in_batch,
-    total_num_tokens,
     sm_scale,
-    logit_cap=-1,
-    att_m=None,
+    logit_cap=0.0,
 ):
-    if att_m is None:
-        att_m = torch.empty(
-            (q.shape[-2], total_num_tokens), dtype=REDUCE_TORCH_TYPE, device="cuda"
-        )
-
     kv_group_num = q.shape[1] // v_buffer.shape[1]
 
     if kv_group_num == 1:
@@ -584,7 +588,7 @@ def decode_attention_fwd(
        _decode_att_m_fwd(
            q,
            k_buffer,
-            att_m,
+            attn_logits,
            req_to_token,
            b_req_idx,
            b_start_loc,
@@ -594,7 +598,7 @@ def decode_attention_fwd(
            logit_cap,
        )
        _decode_softmax_reducev_fwd(
-            att_m,
+            attn_logits,
            v_buffer,
            o,
            req_to_token,
@@ -607,7 +611,7 @@ def decode_attention_fwd(
        _decode_grouped_att_m_fwd(
            q,
            k_buffer,
-            att_m,
+            attn_logits,
            req_to_token,
            b_req_idx,
            b_start_loc,
@@ -617,7 +621,7 @@ def decode_attention_fwd(
            logit_cap,
        )
        _decode_grouped_softmax_reducev_fwd(
-            att_m,
+            attn_logits,
            v_buffer,
            o,
            req_to_token,
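As the last hunks show, `decode_attention_fwd` no longer allocates its own `att_m` scratch tensor; callers now pass in an `attn_logits` buffer, and `total_num_tokens` disappears from the signature. A hedged sketch of what a caller-side allocation might look like, with shape and dtype inferred from the removed `att_m` allocation rather than taken from the new attention-backend code (which is not part of this diff):

```python
import torch

num_q_heads = 32          # corresponds to q.shape[-2] in the removed allocation
max_total_tokens = 16384  # size of the KV token pool; an example value
attn_logits = torch.empty(
    (num_q_heads, max_total_tokens),
    dtype=torch.float32,  # or float16; the kernels now read the dtype from this buffer
    device="cuda",
)

# decode_attention_fwd(
#     q, k_buffer, v_buffer, o, req_to_token,
#     b_req_idx, b_start_loc, b_seq_len,
#     attn_logits, max_len_in_batch, sm_scale, logit_cap=logit_cap,
# )
```

Allocating the buffer once in the caller lets it be reused across decode steps and makes the reduction precision a caller-side choice.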