sglang 0.2.13__py3-none-any.whl → 0.2.14.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. sglang/api.py +6 -0
  2. sglang/bench_latency.py +7 -3
  3. sglang/bench_serving.py +50 -26
  4. sglang/check_env.py +15 -0
  5. sglang/lang/chat_template.py +10 -5
  6. sglang/lang/compiler.py +4 -0
  7. sglang/lang/interpreter.py +1 -0
  8. sglang/lang/ir.py +9 -0
  9. sglang/launch_server.py +8 -1
  10. sglang/srt/constrained/fsm_cache.py +11 -2
  11. sglang/srt/constrained/jump_forward.py +1 -0
  12. sglang/srt/conversation.py +50 -1
  13. sglang/srt/hf_transformers_utils.py +22 -23
  14. sglang/srt/layers/activation.py +100 -1
  15. sglang/srt/layers/decode_attention.py +338 -50
  16. sglang/srt/layers/fused_moe/layer.py +2 -2
  17. sglang/srt/layers/logits_processor.py +56 -19
  18. sglang/srt/layers/radix_attention.py +3 -4
  19. sglang/srt/layers/sampler.py +101 -0
  20. sglang/srt/managers/controller_multi.py +2 -8
  21. sglang/srt/managers/controller_single.py +7 -10
  22. sglang/srt/managers/detokenizer_manager.py +20 -9
  23. sglang/srt/managers/io_struct.py +44 -11
  24. sglang/srt/managers/policy_scheduler.py +5 -2
  25. sglang/srt/managers/schedule_batch.py +46 -166
  26. sglang/srt/managers/tokenizer_manager.py +192 -83
  27. sglang/srt/managers/tp_worker.py +118 -24
  28. sglang/srt/mem_cache/memory_pool.py +82 -8
  29. sglang/srt/mm_utils.py +79 -7
  30. sglang/srt/model_executor/cuda_graph_runner.py +32 -8
  31. sglang/srt/model_executor/forward_batch_info.py +51 -26
  32. sglang/srt/model_executor/model_runner.py +201 -58
  33. sglang/srt/models/gemma2.py +10 -6
  34. sglang/srt/models/gpt_bigcode.py +1 -1
  35. sglang/srt/models/grok.py +11 -1
  36. sglang/srt/models/llama_embedding.py +4 -0
  37. sglang/srt/models/llava.py +176 -59
  38. sglang/srt/models/qwen2.py +9 -3
  39. sglang/srt/openai_api/adapter.py +200 -39
  40. sglang/srt/openai_api/protocol.py +2 -0
  41. sglang/srt/sampling/sampling_batch_info.py +136 -0
  42. sglang/srt/{sampling_params.py → sampling/sampling_params.py} +22 -0
  43. sglang/srt/server.py +92 -57
  44. sglang/srt/server_args.py +43 -15
  45. sglang/srt/utils.py +26 -16
  46. sglang/test/runners.py +22 -30
  47. sglang/test/simple_eval_common.py +9 -10
  48. sglang/test/simple_eval_gpqa.py +2 -1
  49. sglang/test/simple_eval_humaneval.py +2 -2
  50. sglang/test/simple_eval_math.py +2 -1
  51. sglang/test/simple_eval_mmlu.py +2 -1
  52. sglang/test/test_activation.py +55 -0
  53. sglang/test/test_utils.py +36 -53
  54. sglang/version.py +1 -1
  55. {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/METADATA +100 -27
  56. sglang-0.2.14.post1.dist-info/RECORD +114 -0
  57. {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/WHEEL +1 -1
  58. sglang/launch_server_llavavid.py +0 -29
  59. sglang-0.2.13.dist-info/RECORD +0 -112
  60. {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/LICENSE +0 -0
  61. {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/top_level.txt +0 -0
sglang/srt/hf_transformers_utils.py
@@ -30,14 +30,19 @@ from transformers import (
     PreTrainedTokenizer,
     PreTrainedTokenizerFast,
 )
-from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig

-from sglang.srt.utils import is_multimodal_model
+try:
+    from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig
+
+    _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
+        ChatGLMConfig.model_type: ChatGLMConfig,
+        DbrxConfig.model_type: DbrxConfig,
+    }
+except ImportError:
+    # We want this file to run without vllm dependency
+    _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {}

-_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
-    ChatGLMConfig.model_type: ChatGLMConfig,
-    DbrxConfig.model_type: DbrxConfig,
-}
+from sglang.srt.utils import is_multimodal_model


 def download_from_hf(model_path: str):
@@ -137,18 +142,6 @@ def get_tokenizer(
             raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
         kwargs["use_fast"] = False

-    if (
-        "llama" in tokenizer_name.lower()
-        and kwargs.get("use_fast", True)
-        and tokenizer_name != _FAST_LLAMA_TOKENIZER
-    ):
-        pass
-        # warnings.warn(
-        #     "For some LLaMA V1 models, initializing the fast tokenizer may "
-        #     "take a long time. To reduce the initialization time, consider "
-        #     f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
-        #     "tokenizer."
-        # )
     try:
         tokenizer = AutoTokenizer.from_pretrained(
             tokenizer_name,
@@ -229,6 +222,8 @@ class TiktokenTokenizer:
         }
         assert tok_dict["word_split"] == "V1"

+        default_allowed_special = None
+
         kwargs = {
             "name": name,
             "pat_str": tok_dict.get("pat_str", PAT_STR_B),
@@ -242,14 +237,18 @@ class TiktokenTokenizer:
                     for bytes_list in tok_dict["default_allowed_special"]
                 ]
             )
-        else:
-            default_allowed_special = None
         if "vocab_size" in tok_dict:
             kwargs["explicit_n_vocab"] = tok_dict["vocab_size"]

+        PAD = "<|pad|>"
+        EOS = "<|eos|>"
+        SEP = "<|separator|>"
+
+        DEFAULT_CONTROL_TOKENS = {"pad": PAD, "sep": EOS, "eos": SEP}
+
         tokenizer = tiktoken.Encoding(**kwargs)
         tokenizer._default_allowed_special = default_allowed_special or set()
-        tokenizer._default_allowed_special |= {"<|separator|>"}
+        tokenizer._control_tokens = DEFAULT_CONTROL_TOKENS

         def encode_patched(
             self,
@@ -266,14 +265,14 @@ class TiktokenTokenizer:
                 self,
                 text,
                 allowed_special=allowed_special,
-                disallowed_special=disallowed_special,
+                disallowed_special=(),
             )

         tokenizer.encode = functools.partial(encode_patched, tokenizer)

         # Convert to HF interface
         self.tokenizer = tokenizer
-        self.eos_token_id = tokenizer._special_tokens["<|eos|>"]
+        self.eos_token_id = tokenizer._special_tokens[EOS]
         self.vocab_size = tokenizer.n_vocab
         self.chat_template = Template(
             "{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
sglang/srt/layers/activation.py
@@ -13,10 +13,20 @@ limitations under the License.

 """Fused operators for activation layers."""

+from typing import Optional
+
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
-from flashinfer.activation import silu_and_mul
+from flashinfer.activation import gelu_tanh_and_mul, silu_and_mul
+from vllm.distributed import (
+    divide,
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from vllm.model_executor.custom_op import CustomOp
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.utils import set_weight_attrs


 class SiluAndMul(CustomOp):
@@ -30,3 +40,92 @@ class SiluAndMul(CustomOp):
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
         silu_and_mul(x, out)
         return out
+
+
+class GeluAndMul(CustomOp):
+    def __init__(self, **kwargs):
+        super().__init__()
+
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        return F.gelu(x[..., :d], approximate="tanh") * x[..., d:]
+
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        output_shape = x.shape[:-1] + (d,)
+        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+        gelu_tanh_and_mul(x, out)
+        return out
+
+
+class ScaledActivation(nn.Module):
+    """An activation function with post-scale parameters.
+
+    This is used for some quantization methods like AWQ.
+    """
+
+    def __init__(
+        self,
+        act_module: nn.Module,
+        intermediate_size: int,
+        input_is_parallel: bool = True,
+        params_dtype: Optional[torch.dtype] = None,
+    ):
+        super().__init__()
+        self.act = act_module
+        self.input_is_parallel = input_is_parallel
+        if input_is_parallel:
+            tp_size = get_tensor_model_parallel_world_size()
+            intermediate_size_per_partition = divide(intermediate_size, tp_size)
+        else:
+            intermediate_size_per_partition = intermediate_size
+        if params_dtype is None:
+            params_dtype = torch.get_default_dtype()
+        self.scales = nn.Parameter(
+            torch.empty(intermediate_size_per_partition, dtype=params_dtype)
+        )
+        set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.act(x) / self.scales
+
+    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
+        param_data = param.data
+        if self.input_is_parallel:
+            tp_rank = get_tensor_model_parallel_rank()
+            shard_size = param_data.shape[0]
+            start_idx = tp_rank * shard_size
+            loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
+        assert param_data.shape == loaded_weight.shape
+        param_data.copy_(loaded_weight)
+
+
+_ACTIVATION_REGISTRY = {
+    "gelu": nn.GELU(),
+    "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
+}
+
+
+def get_act_fn(
+    act_fn_name: str,
+    quant_config: Optional[QuantizationConfig] = None,
+    intermediate_size: Optional[int] = None,
+    input_is_parallel: bool = True,
+    params_dtype: Optional[torch.dtype] = None,
+) -> nn.Module:
+    """Get an activation function by name."""
+    act_fn_name = act_fn_name.lower()
+    if act_fn_name not in _ACTIVATION_REGISTRY:
+        raise ValueError(f"Activation function {act_fn_name!r} is not supported.")
+
+    act_fn = _ACTIVATION_REGISTRY[act_fn_name]
+    if quant_config is not None and act_fn_name in quant_config.get_scaled_act_names():
+        if intermediate_size is None:
+            raise ValueError(
+                "intermediate_size must be specified for scaled "
+                "activation functions."
+            )
+        return ScaledActivation(
+            act_fn, intermediate_size, input_is_parallel, params_dtype
+        )
+    return act_fn
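
For orientation only (not part of the diff): a minimal sketch of how the activation helpers added above might be exercised, assuming sglang's flashinfer/vllm dependencies are importable; the tensor shapes are illustrative.

import torch

from sglang.srt.layers.activation import GeluAndMul, get_act_fn

# Plain registry lookup: with no quant_config this returns nn.GELU(approximate="tanh").
act = get_act_fn("gelu_pytorch_tanh")
print(act(torch.randn(2, 128)).shape)  # torch.Size([2, 128])

# Fused gate/up activation: the last dim is split in half, gelu(x[..., :d]) * x[..., d:].
gelu_and_mul = GeluAndMul()
print(gelu_and_mul.forward_native(torch.randn(2, 256)).shape)  # torch.Size([2, 128])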
sglang/srt/layers/decode_attention.py
@@ -26,7 +26,7 @@ import triton.language as tl

 from sglang.srt.managers.schedule_batch import global_server_args_dict

-if global_server_args_dict.get("attention_reduce_in_fp32", False):
+if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
     REDUCE_TRITON_TYPE = tl.float32
     REDUCE_TORCH_TYPE = torch.float32
 else:
@@ -58,7 +58,6 @@ def _fwd_kernel_stage1(
     att_stride_h,
     kv_group_num: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
-    BLOCK_DPE: tl.constexpr,
     BLOCK_N: tl.constexpr,
     logit_cap: tl.constexpr,
 ):
@@ -78,10 +77,6 @@

     off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d

-    if BLOCK_DPE > 0:
-        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
-        off_qpe = cur_batch * stride_qbs + cur_head * stride_qh + offs_dpe
-
     offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)

     block_stard_index = start_n * BLOCK_N
@@ -106,19 +101,6 @@
             other=0.0,
         ).to(REDUCE_TRITON_TYPE)
         att_value = tl.sum(q[None, :] * k, 1)
-        if BLOCK_DPE > 0:
-            qpe = tl.load(Q + off_qpe + start_mark).to(REDUCE_TRITON_TYPE)
-            offs_buf_kpe = (
-                k_loc[:, None] * stride_buf_kbs
-                + cur_kv_head * stride_buf_kh
-                + offs_dpe[None, :]
-            )
-            kpe = tl.load(
-                K_Buffer + offs_buf_kpe,
-                mask=offs_n_new[:, None] < cur_batch_end_index,
-                other=0.0,
-            ).to(REDUCE_TRITON_TYPE)
-            att_value += tl.sum(qpe[None, :] * kpe, 1)
         att_value *= sm_scale

         if logit_cap > 0:
@@ -214,14 +196,7 @@ def _decode_att_m_fwd(
     # shape constraints
     Lq, Lk = q.shape[-1], k_buffer.shape[-1]
     assert Lq == Lk
-    assert Lk in {16, 32, 64, 128, 256, 576}
-
-    if Lk == 576:
-        BLOCK_DMODEL = 512
-        BLOCK_DPE = 64
-    else:
-        BLOCK_DMODEL = Lk
-        BLOCK_DPE = 0
+    assert Lk in {16, 32, 64, 128, 256}

     batch, head_num = B_req_idx.shape[0], q.shape[1]

@@ -249,8 +224,7 @@
         k_buffer.stride(1),
         att_out.stride(0),
         kv_group_num=kv_group_num,
-        BLOCK_DMODEL=BLOCK_DMODEL,
-        BLOCK_DPE=BLOCK_DPE,
+        BLOCK_DMODEL=Lk,
         BLOCK_N=BLOCK,
         logit_cap=logit_cap,
         num_warps=num_warps,
@@ -296,6 +270,293 @@ def _decode_softmax_reducev_fwd(
     )


+@triton.jit
+def _fwd_grouped_kernel_stage1(
+    Q,
+    K_Buffer,
+    sm_scale,
+    Req_to_tokens,
+    B_req_idx,
+    B_Start_Loc,
+    B_Seqlen,
+    Att_Out,
+    stride_req_to_tokens_b,
+    stride_qbs,
+    stride_qh,
+    stride_buf_kbs,
+    stride_buf_kh,
+    att_stride_h,
+    kv_group_num: tl.constexpr,
+    q_head_num: tl.constexpr,
+    BLOCK_DMODEL: tl.constexpr,
+    BLOCK_DPE: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    BLOCK_H: tl.constexpr,
+    logit_cap: tl.constexpr,
+):
+    cur_batch = tl.program_id(0)
+    cur_kv_head = tl.program_id(1)
+    start_n = tl.program_id(2)
+
+    cur_head = cur_kv_head * kv_group_num + tl.arange(0, BLOCK_H)
+    mask_h = cur_head < (cur_kv_head + 1) * kv_group_num
+    mask_h = mask_h & (cur_head < q_head_num)
+
+    offs_d = tl.arange(0, BLOCK_DMODEL)
+    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
+    cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
+    cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
+
+    cur_batch_start_index = 0
+    cur_batch_end_index = cur_batch_seq_len
+
+    offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
+
+    if BLOCK_DPE > 0:
+        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
+        off_qpe = (
+            cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
+        )
+
+    offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+    block_stard_index = start_n * BLOCK_N
+    block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0)
+
+    for start_mark in range(0, block_mask, 1):
+        q = tl.load(Q + offs_q + start_mark, mask=mask_h[:, None]).to(
+            REDUCE_TRITON_TYPE
+        )
+        offs_n_new = cur_batch_start_index + offs_n
+        k_loc = tl.load(
+            Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new,
+            mask=offs_n_new < cur_batch_end_index,
+            other=0,
+        )
+        offs_buf_k = (
+            k_loc[None, :] * stride_buf_kbs
+            + cur_kv_head * stride_buf_kh
+            + offs_d[:, None]
+        )
+        k = tl.load(
+            K_Buffer + offs_buf_k,
+            mask=offs_n_new[None, :] < cur_batch_end_index,
+            other=0.0,
+        ).to(REDUCE_TRITON_TYPE)
+        qk = tl.dot(q, k)
+        if BLOCK_DPE > 0:
+            qpe = tl.load(Q + off_qpe + start_mark, mask=mask_h[:, None]).to(
+                REDUCE_TRITON_TYPE
+            )
+            offs_buf_kpe = (
+                k_loc[None, :] * stride_buf_kbs
+                + cur_kv_head * stride_buf_kh
+                + offs_dpe[:, None]
+            )
+            kpe = tl.load(
+                K_Buffer + offs_buf_kpe,
+                mask=offs_n_new[None, :] < cur_batch_end_index,
+                other=0.0,
+            ).to(REDUCE_TRITON_TYPE)
+            qk += tl.dot(qpe, kpe)
+        qk *= sm_scale
+
+        if logit_cap > 0:
+            qk = logit_cap * tanh(qk / logit_cap)
+
+        offs_o = cur_head[:, None] * att_stride_h + (
+            cur_batch_in_all_start_index + offs_n[None, :]
+        )
+
+        tl.store(
+            Att_Out + offs_o,
+            qk,
+            mask=mask_h[:, None] & (offs_n_new[None, :] < cur_batch_end_index),
+        )
+
+
+@triton.jit
+def _fwd_grouped_kernel_stage2(
+    Logics,
+    V_Buffer,
+    Out,
+    Req_to_tokens,
+    B_req_idx,
+    B_Start_Loc,
+    B_Seqlen,
+    stride_logic_h,
+    stride_buf_vbs,
+    stride_buf_vh,
+    stride_obs,
+    stride_oh,
+    stride_req_to_token_b,
+    kv_group_num: tl.constexpr,
+    q_head_num: tl.constexpr,
+    BLOCK_DMODEL: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    BLOCK_H: tl.constexpr,
+):
+    cur_batch = tl.program_id(0)
+    cur_kv_head = tl.program_id(1)
+
+    cur_head = cur_kv_head * kv_group_num + tl.arange(0, BLOCK_H)
+    mask_h = cur_head < (cur_kv_head + 1) * kv_group_num
+    mask_h = mask_h & (cur_head < q_head_num)
+
+    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
+    cur_batch_start_loc = tl.load(B_Start_Loc + cur_batch)
+    cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
+
+    offs_n = tl.arange(0, BLOCK_N)
+    offs_d = tl.arange(0, BLOCK_DMODEL)
+
+    offs_buf_v = cur_kv_head * stride_buf_vh + offs_d[None, :]
+    v_ptrs = V_Buffer + offs_buf_v
+
+    e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf")
+    e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
+    acc = tl.zeros([BLOCK_H, BLOCK_DMODEL], dtype=tl.float32)
+
+    for start_n in range(0, cur_batch_seq_len, BLOCK_N):
+        start_n = tl.multiple_of(start_n, BLOCK_N)
+        v_index = tl.load(
+            Req_to_tokens
+            + cur_batch_req_idx * stride_req_to_token_b
+            + (start_n + offs_n),
+            mask=(start_n + offs_n) < cur_batch_seq_len,
+            other=0,
+        )
+
+        offs_qk = cur_head[:, None] * stride_logic_h + (
+            cur_batch_start_loc + start_n + offs_n[None, :]
+        )
+
+        qk = tl.load(
+            Logics + offs_qk,
+            mask=mask_h[:, None] & (start_n + offs_n[None, :] < cur_batch_seq_len),
+            other=float("-inf"),
+        )
+
+        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
+        old_scale = tl.exp(e_max - n_e_max)
+        p = tl.exp(qk - n_e_max[:, None])
+        e_sum = e_sum * old_scale + tl.sum(p, 1)
+        v = tl.load(v_ptrs + v_index[:, None] * stride_buf_vbs)
+        p = p.to(v.dtype)
+        acc = acc * old_scale[:, None] + tl.dot(p, v)
+        e_max = n_e_max
+
+    acc = acc / e_sum[:, None]
+    off_o = cur_batch * stride_obs + cur_head[:, None] * stride_oh + offs_d[None, :]
+    out_ptrs = Out + off_o
+    tl.store(out_ptrs, acc, mask=mask_h[:, None])
+
+
+def _decode_grouped_att_m_fwd(
+    q,
+    k_buffer,
+    att_out,
+    Req_to_tokens,
+    B_req_idx,
+    B_Start_Loc,
+    B_Seqlen,
+    max_len_in_batch,
+    sm_scale,
+    logit_cap,
+):
+    BLOCK = 32
+    # shape constraints
+    Lq, Lk = q.shape[-1], k_buffer.shape[-1]
+    assert Lq == Lk
+    assert Lk in {16, 32, 64, 128, 256, 576}
+
+    if Lk == 576:
+        BLOCK_DMODEL = 512
+        BLOCK_DPE = 64
+    else:
+        BLOCK_DMODEL = Lk
+        BLOCK_DPE = 0
+
+    batch, head_num = B_req_idx.shape[0], q.shape[1]
+    kv_group_num = q.shape[1] // k_buffer.shape[1]
+
+    BLOCK_H = max(16, triton.next_power_of_2(kv_group_num))
+    grid = (
+        batch,
+        triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
+        triton.cdiv(max_len_in_batch, BLOCK),
+    )
+
+    num_warps = 4
+
+    _fwd_grouped_kernel_stage1[grid](
+        q,
+        k_buffer,
+        sm_scale,
+        Req_to_tokens,
+        B_req_idx,
+        B_Start_Loc,
+        B_Seqlen,
+        att_out,
+        Req_to_tokens.stride(0),
+        q.stride(0),
+        q.stride(1),
+        k_buffer.stride(0),
+        k_buffer.stride(1),
+        att_out.stride(0),
+        kv_group_num=kv_group_num,
+        q_head_num=head_num,
+        BLOCK_DMODEL=BLOCK_DMODEL,
+        BLOCK_DPE=BLOCK_DPE,
+        BLOCK_N=BLOCK,
+        BLOCK_H=BLOCK_H,
+        logit_cap=logit_cap,
+        num_warps=num_warps,
+        num_stages=1,
+    )
+
+
+def _decode_grouped_softmax_reducev_fwd(
+    logics,
+    v_buffer,
+    o,
+    req_to_tokens,
+    b_req_idx,
+    b_start_loc,
+    b_seq_len,
+):
+    BLOCK = 128
+    batch, head_num = b_seq_len.shape[0], logics.shape[0]
+    kv_group_num = logics.shape[0] // v_buffer.shape[1]
+    BLOCK_H = max(16, triton.next_power_of_2(kv_group_num))
+    grid = (batch, triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), 1)
+
+    num_warps = 8
+
+    _fwd_grouped_kernel_stage2[grid](
+        logics,
+        v_buffer,
+        o,
+        req_to_tokens,
+        b_req_idx,
+        b_start_loc,
+        b_seq_len,
+        logics.stride(0),
+        v_buffer.stride(0),
+        v_buffer.stride(1),
+        o.stride(0),
+        o.stride(1),
+        req_to_tokens.stride(0),
+        kv_group_num=kv_group_num,
+        q_head_num=head_num,
+        BLOCK_DMODEL=v_buffer.shape[-1],
+        BLOCK_N=BLOCK,
+        BLOCK_H=BLOCK_H,
+        num_warps=num_warps,
+        num_stages=1,
+    )
+
+
 def decode_attention_fwd(
     q,
     k_buffer,
@@ -316,24 +577,51 @@ def decode_attention_fwd(
         (q.shape[-2], total_num_tokens), dtype=REDUCE_TORCH_TYPE, device="cuda"
     )

-    _decode_att_m_fwd(
-        q,
-        k_buffer,
-        att_m,
-        req_to_token,
-        b_req_idx,
-        b_start_loc,
-        b_seq_len,
-        max_len_in_batch,
-        sm_scale,
-        logit_cap,
-    )
-    _decode_softmax_reducev_fwd(
-        att_m,
-        v_buffer,
-        o,
-        req_to_token,
-        b_req_idx,
-        b_start_loc,
-        b_seq_len,
-    )
+    kv_group_num = q.shape[1] // v_buffer.shape[1]
+
+    if kv_group_num == 1:
+        # MHA
+        _decode_att_m_fwd(
+            q,
+            k_buffer,
+            att_m,
+            req_to_token,
+            b_req_idx,
+            b_start_loc,
+            b_seq_len,
+            max_len_in_batch,
+            sm_scale,
+            logit_cap,
+        )
+        _decode_softmax_reducev_fwd(
+            att_m,
+            v_buffer,
+            o,
+            req_to_token,
+            b_req_idx,
+            b_start_loc,
+            b_seq_len,
+        )
+    else:
+        # GQA/MQA/MLA
+        _decode_grouped_att_m_fwd(
+            q,
+            k_buffer,
+            att_m,
+            req_to_token,
+            b_req_idx,
+            b_start_loc,
+            b_seq_len,
+            max_len_in_batch,
+            sm_scale,
+            logit_cap,
+        )
+        _decode_grouped_softmax_reducev_fwd(
+            att_m,
+            v_buffer,
+            o,
+            req_to_token,
+            b_req_idx,
+            b_start_loc,
+            b_seq_len,
+        )
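
For orientation only (not part of the diff): the new dispatch in decode_attention_fwd is driven solely by the ratio of query heads to KV heads, so the original MHA kernels stay in use when the ratio is 1. A small illustration with made-up head counts:

# q: [num_tokens, num_q_heads, head_dim]; v_buffer: [pool_size, num_kv_heads, head_dim]
num_q_heads, num_kv_heads = 32, 8
kv_group_num = num_q_heads // num_kv_heads  # 4
if kv_group_num == 1:
    kernels = ("_decode_att_m_fwd", "_decode_softmax_reducev_fwd")  # MHA path
else:
    kernels = ("_decode_grouped_att_m_fwd", "_decode_grouped_softmax_reducev_fwd")  # GQA/MQA/MLA path
print(kernels)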
sglang/srt/layers/fused_moe/layer.py
@@ -239,7 +239,7 @@ class FusedMoE(torch.nn.Module):
         weight_name: str,
         shard_id: int,
         expert_id: int,
-        pre_sharded: bool,
+        use_presharded_weights: bool = False,
     ):
         param_data = param.data

@@ -273,7 +273,7 @@ class FusedMoE(torch.nn.Module):
         else:
             tp_rank = get_tensor_model_parallel_rank()
             shard_size = self.intermediate_size_per_partition
-            if pre_sharded:
+            if use_presharded_weights:
                 shard = slice(None)
             else:
                 shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)