sglang 0.2.13__py3-none-any.whl → 0.2.14.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +6 -0
- sglang/bench_latency.py +7 -3
- sglang/bench_serving.py +50 -26
- sglang/check_env.py +15 -0
- sglang/lang/chat_template.py +10 -5
- sglang/lang/compiler.py +4 -0
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +9 -0
- sglang/launch_server.py +8 -1
- sglang/srt/constrained/fsm_cache.py +11 -2
- sglang/srt/constrained/jump_forward.py +1 -0
- sglang/srt/conversation.py +50 -1
- sglang/srt/hf_transformers_utils.py +22 -23
- sglang/srt/layers/activation.py +100 -1
- sglang/srt/layers/decode_attention.py +338 -50
- sglang/srt/layers/fused_moe/layer.py +2 -2
- sglang/srt/layers/logits_processor.py +56 -19
- sglang/srt/layers/radix_attention.py +3 -4
- sglang/srt/layers/sampler.py +101 -0
- sglang/srt/managers/controller_multi.py +2 -8
- sglang/srt/managers/controller_single.py +7 -10
- sglang/srt/managers/detokenizer_manager.py +20 -9
- sglang/srt/managers/io_struct.py +44 -11
- sglang/srt/managers/policy_scheduler.py +5 -2
- sglang/srt/managers/schedule_batch.py +46 -166
- sglang/srt/managers/tokenizer_manager.py +192 -83
- sglang/srt/managers/tp_worker.py +118 -24
- sglang/srt/mem_cache/memory_pool.py +82 -8
- sglang/srt/mm_utils.py +79 -7
- sglang/srt/model_executor/cuda_graph_runner.py +32 -8
- sglang/srt/model_executor/forward_batch_info.py +51 -26
- sglang/srt/model_executor/model_runner.py +201 -58
- sglang/srt/models/gemma2.py +10 -6
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/grok.py +11 -1
- sglang/srt/models/llama_embedding.py +4 -0
- sglang/srt/models/llava.py +176 -59
- sglang/srt/models/qwen2.py +9 -3
- sglang/srt/openai_api/adapter.py +200 -39
- sglang/srt/openai_api/protocol.py +2 -0
- sglang/srt/sampling/sampling_batch_info.py +136 -0
- sglang/srt/{sampling_params.py → sampling/sampling_params.py} +22 -0
- sglang/srt/server.py +92 -57
- sglang/srt/server_args.py +43 -15
- sglang/srt/utils.py +26 -16
- sglang/test/runners.py +22 -30
- sglang/test/simple_eval_common.py +9 -10
- sglang/test/simple_eval_gpqa.py +2 -1
- sglang/test/simple_eval_humaneval.py +2 -2
- sglang/test/simple_eval_math.py +2 -1
- sglang/test/simple_eval_mmlu.py +2 -1
- sglang/test/test_activation.py +55 -0
- sglang/test/test_utils.py +36 -53
- sglang/version.py +1 -1
- {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/METADATA +100 -27
- sglang-0.2.14.post1.dist-info/RECORD +114 -0
- {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/WHEEL +1 -1
- sglang/launch_server_llavavid.py +0 -29
- sglang-0.2.13.dist-info/RECORD +0 -112
- {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/LICENSE +0 -0
- {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/top_level.txt +0 -0
sglang/srt/hf_transformers_utils.py
CHANGED
@@ -30,14 +30,19 @@ from transformers import (
     PreTrainedTokenizer,
     PreTrainedTokenizerFast,
 )
-from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig
 
-
+try:
+    from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig
+
+    _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
+        ChatGLMConfig.model_type: ChatGLMConfig,
+        DbrxConfig.model_type: DbrxConfig,
+    }
+except ImportError:
+    # We want this file to run without vllm dependency
+    _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {}
 
-_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
-    ChatGLMConfig.model_type: ChatGLMConfig,
-    DbrxConfig.model_type: DbrxConfig,
-}
+from sglang.srt.utils import is_multimodal_model
 
 
 def download_from_hf(model_path: str):
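The `try`/`except ImportError` above makes the vllm-provided config classes optional. A minimal sketch (assumptions: `load_config` is a hypothetical helper, not sglang's API) of how such a registry is typically consulted, falling back to `AutoConfig` when the extra dependency is absent:

```python
from typing import Dict, Optional, Type

from transformers import AutoConfig, PretrainedConfig

try:
    from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig

    _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
        ChatGLMConfig.model_type: ChatGLMConfig,
        DbrxConfig.model_type: DbrxConfig,
    }
except ImportError:
    # Module still imports when vllm is not installed; the registry is simply empty.
    _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {}


def load_config(model_path: str) -> PretrainedConfig:
    """Hypothetical helper: prefer a registered config class, else plain AutoConfig."""
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    config_class: Optional[Type[PretrainedConfig]] = _CONFIG_REGISTRY.get(
        getattr(config, "model_type", None)
    )
    if config_class is not None:
        config = config_class.from_pretrained(model_path)
    return config
```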
@@ -137,18 +142,6 @@ def get_tokenizer(
             raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
         kwargs["use_fast"] = False
 
-    if (
-        "llama" in tokenizer_name.lower()
-        and kwargs.get("use_fast", True)
-        and tokenizer_name != _FAST_LLAMA_TOKENIZER
-    ):
-        pass
-        # warnings.warn(
-        #     "For some LLaMA V1 models, initializing the fast tokenizer may "
-        #     "take a long time. To reduce the initialization time, consider "
-        #     f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
-        #     "tokenizer."
-        # )
     try:
         tokenizer = AutoTokenizer.from_pretrained(
             tokenizer_name,
@@ -229,6 +222,8 @@ class TiktokenTokenizer:
         }
         assert tok_dict["word_split"] == "V1"
 
+        default_allowed_special = None
+
         kwargs = {
             "name": name,
             "pat_str": tok_dict.get("pat_str", PAT_STR_B),
@@ -242,14 +237,18 @@ class TiktokenTokenizer:
                     for bytes_list in tok_dict["default_allowed_special"]
                 ]
             )
-        else:
-            default_allowed_special = None
         if "vocab_size" in tok_dict:
             kwargs["explicit_n_vocab"] = tok_dict["vocab_size"]
 
+        PAD = "<|pad|>"
+        EOS = "<|eos|>"
+        SEP = "<|separator|>"
+
+        DEFAULT_CONTROL_TOKENS = {"pad": PAD, "sep": EOS, "eos": SEP}
+
         tokenizer = tiktoken.Encoding(**kwargs)
         tokenizer._default_allowed_special = default_allowed_special or set()
-        tokenizer.
+        tokenizer._control_tokens = DEFAULT_CONTROL_TOKENS
 
         def encode_patched(
             self,
@@ -266,14 +265,14 @@ class TiktokenTokenizer:
                 self,
                 text,
                 allowed_special=allowed_special,
-                disallowed_special=
+                disallowed_special=(),
             )
 
         tokenizer.encode = functools.partial(encode_patched, tokenizer)
 
         # Convert to HF interface
         self.tokenizer = tokenizer
-        self.eos_token_id = tokenizer._special_tokens[
+        self.eos_token_id = tokenizer._special_tokens[EOS]
         self.vocab_size = tokenizer.n_vocab
         self.chat_template = Template(
             "{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
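For reference, the Jinja2 template assigned to `self.chat_template` above renders an OpenAI-style message list into a plain prompt. An illustrative, self-contained rendering (the messages are invented for the example):

```python
from jinja2 import Template

# Same template string as assigned to self.chat_template in the diff above, split for readability.
chat_template = Template(
    "{% for message in messages %}"
    "{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}"
    "{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}"
    "{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}"
    "{% endif %}{% endfor %}"
    "{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
)

prompt = chat_template.render(
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is the capital of France?"},
    ],
    add_generation_prompt=True,
)
# -> "System: You are a helpful assistant.<|separator|>\n\nHuman: What is the capital of France?<|separator|>\n\nAssistant:"
```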
sglang/srt/layers/activation.py
CHANGED
@@ -13,10 +13,20 @@ limitations under the License.
 
 """Fused operators for activation layers."""
 
+from typing import Optional
+
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
-from flashinfer.activation import silu_and_mul
+from flashinfer.activation import gelu_tanh_and_mul, silu_and_mul
+from vllm.distributed import (
+    divide,
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from vllm.model_executor.custom_op import CustomOp
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.utils import set_weight_attrs
 
 
 class SiluAndMul(CustomOp):
@@ -30,3 +40,92 @@ class SiluAndMul(CustomOp):
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
         silu_and_mul(x, out)
         return out
+
+
+class GeluAndMul(CustomOp):
+    def __init__(self, **kwargs):
+        super().__init__()
+
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        return F.gelu(x[..., :d], approximate="tanh") * x[..., d:]
+
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        output_shape = x.shape[:-1] + (d,)
+        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+        gelu_tanh_and_mul(x, out)
+        return out
+
+
+class ScaledActivation(nn.Module):
+    """An activation function with post-scale parameters.
+
+    This is used for some quantization methods like AWQ.
+    """
+
+    def __init__(
+        self,
+        act_module: nn.Module,
+        intermediate_size: int,
+        input_is_parallel: bool = True,
+        params_dtype: Optional[torch.dtype] = None,
+    ):
+        super().__init__()
+        self.act = act_module
+        self.input_is_parallel = input_is_parallel
+        if input_is_parallel:
+            tp_size = get_tensor_model_parallel_world_size()
+            intermediate_size_per_partition = divide(intermediate_size, tp_size)
+        else:
+            intermediate_size_per_partition = intermediate_size
+        if params_dtype is None:
+            params_dtype = torch.get_default_dtype()
+        self.scales = nn.Parameter(
+            torch.empty(intermediate_size_per_partition, dtype=params_dtype)
+        )
+        set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.act(x) / self.scales
+
+    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
+        param_data = param.data
+        if self.input_is_parallel:
+            tp_rank = get_tensor_model_parallel_rank()
+            shard_size = param_data.shape[0]
+            start_idx = tp_rank * shard_size
+            loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
+        assert param_data.shape == loaded_weight.shape
+        param_data.copy_(loaded_weight)
+
+
+_ACTIVATION_REGISTRY = {
+    "gelu": nn.GELU(),
+    "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
+}
+
+
+def get_act_fn(
+    act_fn_name: str,
+    quant_config: Optional[QuantizationConfig] = None,
+    intermediate_size: Optional[int] = None,
+    input_is_parallel: bool = True,
+    params_dtype: Optional[torch.dtype] = None,
+) -> nn.Module:
+    """Get an activation function by name."""
+    act_fn_name = act_fn_name.lower()
+    if act_fn_name not in _ACTIVATION_REGISTRY:
+        raise ValueError(f"Activation function {act_fn_name!r} is not supported.")
+
+    act_fn = _ACTIVATION_REGISTRY[act_fn_name]
+    if quant_config is not None and act_fn_name in quant_config.get_scaled_act_names():
+        if intermediate_size is None:
+            raise ValueError(
+                "intermediate_size must be specified for scaled "
+                "activation functions."
+            )
+        return ScaledActivation(
+            act_fn, intermediate_size, input_is_parallel, params_dtype
+        )
+    return act_fn
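A quick sanity check of the `GeluAndMul` op added above: its native path splits the last dimension into gate/up halves and multiplies `gelu_tanh(gate)` by `up`. A small reference that mirrors `forward_native` (shapes are illustrative; the flashinfer kernel is not used here):

```python
import torch
import torch.nn.functional as F


def gelu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    # Mirrors GeluAndMul.forward_native: the last dim holds [gate, up] halves.
    d = x.shape[-1] // 2
    return F.gelu(x[..., :d], approximate="tanh") * x[..., d:]


x = torch.randn(4, 2 * 11008)        # e.g. a fused gate_up_proj output
out = gelu_and_mul_reference(x)      # -> shape (4, 11008)
assert out.shape == (4, 11008)
```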
sglang/srt/layers/decode_attention.py
CHANGED
@@ -26,7 +26,7 @@ import triton.language as tl
 
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 
-if global_server_args_dict.get("
+if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
     REDUCE_TRITON_TYPE = tl.float32
     REDUCE_TORCH_TYPE = torch.float32
 else:
@@ -58,7 +58,6 @@ def _fwd_kernel_stage1(
     att_stride_h,
     kv_group_num: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
-    BLOCK_DPE: tl.constexpr,
     BLOCK_N: tl.constexpr,
     logit_cap: tl.constexpr,
 ):
@@ -78,10 +77,6 @@ def _fwd_kernel_stage1(
 
     off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
 
-    if BLOCK_DPE > 0:
-        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
-        off_qpe = cur_batch * stride_qbs + cur_head * stride_qh + offs_dpe
-
     offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
 
     block_stard_index = start_n * BLOCK_N
@@ -106,19 +101,6 @@ def _fwd_kernel_stage1(
             other=0.0,
         ).to(REDUCE_TRITON_TYPE)
         att_value = tl.sum(q[None, :] * k, 1)
-        if BLOCK_DPE > 0:
-            qpe = tl.load(Q + off_qpe + start_mark).to(REDUCE_TRITON_TYPE)
-            offs_buf_kpe = (
-                k_loc[:, None] * stride_buf_kbs
-                + cur_kv_head * stride_buf_kh
-                + offs_dpe[None, :]
-            )
-            kpe = tl.load(
-                K_Buffer + offs_buf_kpe,
-                mask=offs_n_new[:, None] < cur_batch_end_index,
-                other=0.0,
-            ).to(REDUCE_TRITON_TYPE)
-            att_value += tl.sum(qpe[None, :] * kpe, 1)
         att_value *= sm_scale
 
         if logit_cap > 0:
@@ -214,14 +196,7 @@ def _decode_att_m_fwd(
     # shape constraints
     Lq, Lk = q.shape[-1], k_buffer.shape[-1]
    assert Lq == Lk
-    assert Lk in {16, 32, 64, 128, 256, 576}
-
-    if Lk == 576:
-        BLOCK_DMODEL = 512
-        BLOCK_DPE = 64
-    else:
-        BLOCK_DMODEL = Lk
-        BLOCK_DPE = 0
+    assert Lk in {16, 32, 64, 128, 256}
 
     batch, head_num = B_req_idx.shape[0], q.shape[1]
 
@@ -249,8 +224,7 @@
         k_buffer.stride(1),
         att_out.stride(0),
         kv_group_num=kv_group_num,
-        BLOCK_DMODEL=BLOCK_DMODEL,
-        BLOCK_DPE=BLOCK_DPE,
+        BLOCK_DMODEL=Lk,
         BLOCK_N=BLOCK,
         logit_cap=logit_cap,
         num_warps=num_warps,
@@ -296,6 +270,293 @@ def _decode_softmax_reducev_fwd(
     )
 
 
+@triton.jit
+def _fwd_grouped_kernel_stage1(
+    Q,
+    K_Buffer,
+    sm_scale,
+    Req_to_tokens,
+    B_req_idx,
+    B_Start_Loc,
+    B_Seqlen,
+    Att_Out,
+    stride_req_to_tokens_b,
+    stride_qbs,
+    stride_qh,
+    stride_buf_kbs,
+    stride_buf_kh,
+    att_stride_h,
+    kv_group_num: tl.constexpr,
+    q_head_num: tl.constexpr,
+    BLOCK_DMODEL: tl.constexpr,
+    BLOCK_DPE: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    BLOCK_H: tl.constexpr,
+    logit_cap: tl.constexpr,
+):
+    cur_batch = tl.program_id(0)
+    cur_kv_head = tl.program_id(1)
+    start_n = tl.program_id(2)
+
+    cur_head = cur_kv_head * kv_group_num + tl.arange(0, BLOCK_H)
+    mask_h = cur_head < (cur_kv_head + 1) * kv_group_num
+    mask_h = mask_h & (cur_head < q_head_num)
+
+    offs_d = tl.arange(0, BLOCK_DMODEL)
+    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
+    cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
+    cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
+
+    cur_batch_start_index = 0
+    cur_batch_end_index = cur_batch_seq_len
+
+    offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
+
+    if BLOCK_DPE > 0:
+        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
+        off_qpe = (
+            cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
+        )
+
+    offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+    block_stard_index = start_n * BLOCK_N
+    block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0)
+
+    for start_mark in range(0, block_mask, 1):
+        q = tl.load(Q + offs_q + start_mark, mask=mask_h[:, None]).to(
+            REDUCE_TRITON_TYPE
+        )
+        offs_n_new = cur_batch_start_index + offs_n
+        k_loc = tl.load(
+            Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new,
+            mask=offs_n_new < cur_batch_end_index,
+            other=0,
+        )
+        offs_buf_k = (
+            k_loc[None, :] * stride_buf_kbs
+            + cur_kv_head * stride_buf_kh
+            + offs_d[:, None]
+        )
+        k = tl.load(
+            K_Buffer + offs_buf_k,
+            mask=offs_n_new[None, :] < cur_batch_end_index,
+            other=0.0,
+        ).to(REDUCE_TRITON_TYPE)
+        qk = tl.dot(q, k)
+        if BLOCK_DPE > 0:
+            qpe = tl.load(Q + off_qpe + start_mark, mask=mask_h[:, None]).to(
+                REDUCE_TRITON_TYPE
+            )
+            offs_buf_kpe = (
+                k_loc[None, :] * stride_buf_kbs
+                + cur_kv_head * stride_buf_kh
+                + offs_dpe[:, None]
+            )
+            kpe = tl.load(
+                K_Buffer + offs_buf_kpe,
+                mask=offs_n_new[None, :] < cur_batch_end_index,
+                other=0.0,
+            ).to(REDUCE_TRITON_TYPE)
+            qk += tl.dot(qpe, kpe)
+        qk *= sm_scale
+
+        if logit_cap > 0:
+            qk = logit_cap * tanh(qk / logit_cap)
+
+        offs_o = cur_head[:, None] * att_stride_h + (
+            cur_batch_in_all_start_index + offs_n[None, :]
+        )
+
+        tl.store(
+            Att_Out + offs_o,
+            qk,
+            mask=mask_h[:, None] & (offs_n_new[None, :] < cur_batch_end_index),
+        )
+
+
+@triton.jit
+def _fwd_grouped_kernel_stage2(
+    Logics,
+    V_Buffer,
+    Out,
+    Req_to_tokens,
+    B_req_idx,
+    B_Start_Loc,
+    B_Seqlen,
+    stride_logic_h,
+    stride_buf_vbs,
+    stride_buf_vh,
+    stride_obs,
+    stride_oh,
+    stride_req_to_token_b,
+    kv_group_num: tl.constexpr,
+    q_head_num: tl.constexpr,
+    BLOCK_DMODEL: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    BLOCK_H: tl.constexpr,
+):
+    cur_batch = tl.program_id(0)
+    cur_kv_head = tl.program_id(1)
+
+    cur_head = cur_kv_head * kv_group_num + tl.arange(0, BLOCK_H)
+    mask_h = cur_head < (cur_kv_head + 1) * kv_group_num
+    mask_h = mask_h & (cur_head < q_head_num)
+
+    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
+    cur_batch_start_loc = tl.load(B_Start_Loc + cur_batch)
+    cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
+
+    offs_n = tl.arange(0, BLOCK_N)
+    offs_d = tl.arange(0, BLOCK_DMODEL)
+
+    offs_buf_v = cur_kv_head * stride_buf_vh + offs_d[None, :]
+    v_ptrs = V_Buffer + offs_buf_v
+
+    e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf")
+    e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
+    acc = tl.zeros([BLOCK_H, BLOCK_DMODEL], dtype=tl.float32)
+
+    for start_n in range(0, cur_batch_seq_len, BLOCK_N):
+        start_n = tl.multiple_of(start_n, BLOCK_N)
+        v_index = tl.load(
+            Req_to_tokens
+            + cur_batch_req_idx * stride_req_to_token_b
+            + (start_n + offs_n),
+            mask=(start_n + offs_n) < cur_batch_seq_len,
+            other=0,
+        )
+
+        offs_qk = cur_head[:, None] * stride_logic_h + (
+            cur_batch_start_loc + start_n + offs_n[None, :]
+        )
+
+        qk = tl.load(
+            Logics + offs_qk,
+            mask=mask_h[:, None] & (start_n + offs_n[None, :] < cur_batch_seq_len),
+            other=float("-inf"),
+        )
+
+        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
+        old_scale = tl.exp(e_max - n_e_max)
+        p = tl.exp(qk - n_e_max[:, None])
+        e_sum = e_sum * old_scale + tl.sum(p, 1)
+        v = tl.load(v_ptrs + v_index[:, None] * stride_buf_vbs)
+        p = p.to(v.dtype)
+        acc = acc * old_scale[:, None] + tl.dot(p, v)
+        e_max = n_e_max
+
+    acc = acc / e_sum[:, None]
+    off_o = cur_batch * stride_obs + cur_head[:, None] * stride_oh + offs_d[None, :]
+    out_ptrs = Out + off_o
+    tl.store(out_ptrs, acc, mask=mask_h[:, None])
+
+
+def _decode_grouped_att_m_fwd(
+    q,
+    k_buffer,
+    att_out,
+    Req_to_tokens,
+    B_req_idx,
+    B_Start_Loc,
+    B_Seqlen,
+    max_len_in_batch,
+    sm_scale,
+    logit_cap,
+):
+    BLOCK = 32
+    # shape constraints
+    Lq, Lk = q.shape[-1], k_buffer.shape[-1]
+    assert Lq == Lk
+    assert Lk in {16, 32, 64, 128, 256, 576}
+
+    if Lk == 576:
+        BLOCK_DMODEL = 512
+        BLOCK_DPE = 64
+    else:
+        BLOCK_DMODEL = Lk
+        BLOCK_DPE = 0
+
+    batch, head_num = B_req_idx.shape[0], q.shape[1]
+    kv_group_num = q.shape[1] // k_buffer.shape[1]
+
+    BLOCK_H = max(16, triton.next_power_of_2(kv_group_num))
+    grid = (
+        batch,
+        triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
+        triton.cdiv(max_len_in_batch, BLOCK),
+    )
+
+    num_warps = 4
+
+    _fwd_grouped_kernel_stage1[grid](
+        q,
+        k_buffer,
+        sm_scale,
+        Req_to_tokens,
+        B_req_idx,
+        B_Start_Loc,
+        B_Seqlen,
+        att_out,
+        Req_to_tokens.stride(0),
+        q.stride(0),
+        q.stride(1),
+        k_buffer.stride(0),
+        k_buffer.stride(1),
+        att_out.stride(0),
+        kv_group_num=kv_group_num,
+        q_head_num=head_num,
+        BLOCK_DMODEL=BLOCK_DMODEL,
+        BLOCK_DPE=BLOCK_DPE,
+        BLOCK_N=BLOCK,
+        BLOCK_H=BLOCK_H,
+        logit_cap=logit_cap,
+        num_warps=num_warps,
+        num_stages=1,
+    )
+
+
+def _decode_grouped_softmax_reducev_fwd(
+    logics,
+    v_buffer,
+    o,
+    req_to_tokens,
+    b_req_idx,
+    b_start_loc,
+    b_seq_len,
+):
+    BLOCK = 128
+    batch, head_num = b_seq_len.shape[0], logics.shape[0]
+    kv_group_num = logics.shape[0] // v_buffer.shape[1]
+    BLOCK_H = max(16, triton.next_power_of_2(kv_group_num))
+    grid = (batch, triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), 1)
+
+    num_warps = 8
+
+    _fwd_grouped_kernel_stage2[grid](
+        logics,
+        v_buffer,
+        o,
+        req_to_tokens,
+        b_req_idx,
+        b_start_loc,
+        b_seq_len,
+        logics.stride(0),
+        v_buffer.stride(0),
+        v_buffer.stride(1),
+        o.stride(0),
+        o.stride(1),
+        req_to_tokens.stride(0),
+        kv_group_num=kv_group_num,
+        q_head_num=head_num,
+        BLOCK_DMODEL=v_buffer.shape[-1],
+        BLOCK_N=BLOCK,
+        BLOCK_H=BLOCK_H,
+        num_warps=num_warps,
+        num_stages=1,
+    )
+
+
 def decode_attention_fwd(
     q,
     k_buffer,
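The accumulation loop in `_fwd_grouped_kernel_stage2` above is the standard online-softmax update: a running max and running sum let the kernel consume the sequence block by block without materializing the full softmax. A small PyTorch reference of the same update rule (shapes and the final check are illustrative only):

```python
import torch


def blockwise_softmax_weighted_sum(logits: torch.Tensor, v: torch.Tensor, block: int = 128):
    """Online-softmax reference: logits [heads, seq_len], v [seq_len, head_dim]."""
    num_heads, seq_len = logits.shape
    e_max = torch.full((num_heads,), float("-inf"))
    e_sum = torch.zeros(num_heads)
    acc = torch.zeros(num_heads, v.shape[-1])
    for start in range(0, seq_len, block):
        qk = logits[:, start : start + block]
        n_e_max = torch.maximum(qk.max(dim=1).values, e_max)  # new running max
        old_scale = torch.exp(e_max - n_e_max)                 # rescale previous state
        p = torch.exp(qk - n_e_max[:, None])
        e_sum = e_sum * old_scale + p.sum(dim=1)
        acc = acc * old_scale[:, None] + p @ v[start : start + block]
        e_max = n_e_max
    return acc / e_sum[:, None]                                # normalize once at the end


logits, v = torch.randn(16, 1000), torch.randn(1000, 64)
out = blockwise_softmax_weighted_sum(logits, v)
assert torch.allclose(out, torch.softmax(logits, dim=-1) @ v, atol=1e-4)
```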
@@ -316,24 +577,51 @@ def decode_attention_fwd(
         (q.shape[-2], total_num_tokens), dtype=REDUCE_TORCH_TYPE, device="cuda"
     )
 
-    _decode_att_m_fwd(
-        q,
-        k_buffer,
-        att_m,
-        req_to_token,
-        b_req_idx,
-        b_start_loc,
-        b_seq_len,
-        max_len_in_batch,
-        sm_scale,
-        logit_cap,
-    )
-    _decode_softmax_reducev_fwd(
-        att_m,
-        v_buffer,
-        o,
-        req_to_token,
-        b_req_idx,
-        b_start_loc,
-        b_seq_len,
-    )
+    kv_group_num = q.shape[1] // v_buffer.shape[1]
+
+    if kv_group_num == 1:
+        # MHA
+        _decode_att_m_fwd(
+            q,
+            k_buffer,
+            att_m,
+            req_to_token,
+            b_req_idx,
+            b_start_loc,
+            b_seq_len,
+            max_len_in_batch,
+            sm_scale,
+            logit_cap,
+        )
+        _decode_softmax_reducev_fwd(
+            att_m,
+            v_buffer,
+            o,
+            req_to_token,
+            b_req_idx,
+            b_start_loc,
+            b_seq_len,
+        )
+    else:
+        # GQA/MQA/MLA
+        _decode_grouped_att_m_fwd(
+            q,
+            k_buffer,
+            att_m,
+            req_to_token,
+            b_req_idx,
+            b_start_loc,
+            b_seq_len,
+            max_len_in_batch,
+            sm_scale,
+            logit_cap,
+        )
+        _decode_grouped_softmax_reducev_fwd(
+            att_m,
+            v_buffer,
+            o,
+            req_to_token,
+            b_req_idx,
+            b_start_loc,
+            b_seq_len,
+        )
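The new dispatch in `decode_attention_fwd` keys off `kv_group_num`, the ratio of query heads to KV heads. A rough illustration, with invented head counts, of how the grouped path sizes its launch grid (mirroring `_decode_grouped_att_m_fwd`; `triton.next_power_of_2` and `triton.cdiv` are spelled out with `math`):

```python
import math

# Assumed GQA configuration for the example: 32 query heads sharing 8 KV heads.
q_head_num, kv_head_num = 32, 8
kv_group_num = q_head_num // kv_head_num            # 4 -> the grouped (GQA) path is taken

# triton.next_power_of_2(4) == 4, so BLOCK_H stays at its floor of 16.
BLOCK_H = max(16, 2 ** math.ceil(math.log2(kv_group_num)))

batch, max_len_in_batch, BLOCK = 2, 1000, 32
grid = (
    batch,                                               # one program per request
    math.ceil(q_head_num / min(BLOCK_H, kv_group_num)),  # one program per KV-head group
    math.ceil(max_len_in_batch / BLOCK),                 # one program per token block
)
print(kv_group_num, BLOCK_H, grid)   # 4 16 (2, 8, 32)
```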
sglang/srt/layers/fused_moe/layer.py
CHANGED
@@ -239,7 +239,7 @@ class FusedMoE(torch.nn.Module):
         weight_name: str,
         shard_id: int,
         expert_id: int,
-
+        use_presharded_weights: bool = False,
     ):
         param_data = param.data
 
@@ -273,7 +273,7 @@ class FusedMoE(torch.nn.Module):
         else:
             tp_rank = get_tensor_model_parallel_rank()
             shard_size = self.intermediate_size_per_partition
-            if
+            if use_presharded_weights:
                 shard = slice(None)
             else:
                 shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
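The `use_presharded_weights` switch above decides whether the loader slices this rank's shard out of a full expert weight or copies an already-split checkpoint verbatim. A toy sketch of that choice (`select_shard` is a hypothetical helper, not sglang's loader):

```python
import torch


def select_shard(
    loaded_weight: torch.Tensor,
    tp_rank: int,
    shard_size: int,
    use_presharded_weights: bool,
) -> torch.Tensor:
    if use_presharded_weights:
        # The checkpoint already holds only this rank's slice of the expert weight.
        shard = slice(None)
    else:
        # Full weight on disk: cut out the slice owned by this tensor-parallel rank.
        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
    return loaded_weight[shard]


full_weight = torch.arange(8.0)   # toy intermediate size of 8, tp_size = 2
print(select_shard(full_weight, tp_rank=1, shard_size=4, use_presharded_weights=False))
# tensor([4., 5., 6., 7.])
print(select_shard(full_weight[4:], tp_rank=1, shard_size=4, use_presharded_weights=True))
# tensor([4., 5., 6., 7.])  -- an already-sharded file is used verbatim
```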