sglang 0.4.0.post1__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. sglang/bench_offline_throughput.py +6 -6
  2. sglang/bench_one_batch.py +1 -0
  3. sglang/bench_serving.py +9 -1
  4. sglang/check_env.py +140 -48
  5. sglang/lang/backend/runtime_endpoint.py +1 -0
  6. sglang/lang/chat_template.py +32 -0
  7. sglang/llama3_eval.py +316 -0
  8. sglang/srt/aio_rwlock.py +100 -0
  9. sglang/srt/configs/model_config.py +8 -1
  10. sglang/srt/constrained/xgrammar_backend.py +4 -1
  11. sglang/srt/layers/attention/flashinfer_backend.py +51 -5
  12. sglang/srt/layers/attention/triton_backend.py +16 -25
  13. sglang/srt/layers/attention/triton_ops/decode_attention.py +305 -350
  14. sglang/srt/layers/linear.py +20 -2
  15. sglang/srt/layers/logits_processor.py +133 -95
  16. sglang/srt/layers/{ep_moe → moe/ep_moe}/layer.py +18 -39
  17. sglang/srt/layers/moe/fused_moe_native.py +46 -0
  18. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/__init__.py +3 -7
  19. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/fused_moe.py +174 -119
  20. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/layer.py +17 -49
  21. sglang/srt/layers/moe/topk.py +191 -0
  22. sglang/srt/layers/quantization/__init__.py +5 -50
  23. sglang/srt/layers/quantization/fp8.py +221 -36
  24. sglang/srt/layers/quantization/fp8_kernel.py +278 -0
  25. sglang/srt/layers/quantization/fp8_utils.py +90 -1
  26. sglang/srt/layers/radix_attention.py +8 -1
  27. sglang/srt/layers/sampler.py +27 -5
  28. sglang/srt/layers/torchao_utils.py +31 -0
  29. sglang/srt/managers/detokenizer_manager.py +37 -17
  30. sglang/srt/managers/io_struct.py +39 -10
  31. sglang/srt/managers/schedule_batch.py +54 -34
  32. sglang/srt/managers/schedule_policy.py +64 -5
  33. sglang/srt/managers/scheduler.py +171 -136
  34. sglang/srt/managers/tokenizer_manager.py +184 -133
  35. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  36. sglang/srt/mem_cache/chunk_cache.py +2 -2
  37. sglang/srt/mem_cache/memory_pool.py +15 -8
  38. sglang/srt/mem_cache/radix_cache.py +12 -2
  39. sglang/srt/model_executor/cuda_graph_runner.py +25 -11
  40. sglang/srt/model_executor/model_runner.py +28 -14
  41. sglang/srt/model_parallel.py +66 -5
  42. sglang/srt/models/dbrx.py +1 -1
  43. sglang/srt/models/deepseek.py +1 -1
  44. sglang/srt/models/deepseek_v2.py +67 -18
  45. sglang/srt/models/gemma2.py +34 -0
  46. sglang/srt/models/gemma2_reward.py +0 -1
  47. sglang/srt/models/granite.py +517 -0
  48. sglang/srt/models/grok.py +73 -9
  49. sglang/srt/models/llama.py +22 -0
  50. sglang/srt/models/llama_classification.py +11 -23
  51. sglang/srt/models/llama_reward.py +0 -2
  52. sglang/srt/models/llava.py +37 -14
  53. sglang/srt/models/mixtral.py +2 -2
  54. sglang/srt/models/olmoe.py +1 -1
  55. sglang/srt/models/qwen2.py +20 -0
  56. sglang/srt/models/qwen2_moe.py +1 -1
  57. sglang/srt/models/xverse_moe.py +1 -1
  58. sglang/srt/openai_api/adapter.py +8 -0
  59. sglang/srt/openai_api/protocol.py +9 -4
  60. sglang/srt/server.py +2 -1
  61. sglang/srt/server_args.py +19 -9
  62. sglang/srt/utils.py +40 -54
  63. sglang/test/test_block_fp8.py +341 -0
  64. sglang/test/test_utils.py +3 -2
  65. sglang/utils.py +10 -3
  66. sglang/version.py +1 -1
  67. {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/METADATA +12 -7
  68. {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/RECORD +73 -67
  69. sglang/srt/layers/fused_moe_patch.py +0 -133
  70. /sglang/srt/layers/{ep_moe → moe/ep_moe}/__init__.py +0 -0
  71. /sglang/srt/layers/{ep_moe → moe/ep_moe}/kernels.py +0 -0
  72. {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/LICENSE +0 -0
  73. {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/WHEEL +0 -0
  74. {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -6,16 +6,20 @@ import functools
  import json
  import logging
  import os
- from typing import Any, Callable, Dict, Optional, Tuple
+ from typing import Any, Callable, Dict, List, Optional, Tuple

  import torch
  import triton
  import triton.language as tl
+ from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
  from vllm import _custom_ops as ops

+ from sglang.srt.layers.moe.topk import select_experts
+ from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
  from sglang.srt.utils import direct_register_custom_op, get_device_name

  logger = logging.getLogger(__name__)
+ padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0


  @triton.jit
@@ -46,8 +50,14 @@ def fused_moe_kernel(
  stride_bn,
  stride_cm,
  stride_cn,
+ stride_asm,
+ stride_ask,
  stride_bse,
+ stride_bsk,
  stride_bsn,
+ # Block size for block-wise quantization
+ group_n: tl.constexpr,
+ group_k: tl.constexpr,
  # Meta-parameters
  BLOCK_SIZE_M: tl.constexpr,
  BLOCK_SIZE_N: tl.constexpr,
@@ -58,6 +68,7 @@ def fused_moe_kernel(
  compute_type: tl.constexpr,
  use_fp8_w8a8: tl.constexpr,
  use_int8_w8a16: tl.constexpr,
+ even_Ks: tl.constexpr,
  ):
  """
  Implements the fused computation for a Mixture of Experts (MOE) using
@@ -130,8 +141,15 @@ def fused_moe_kernel(
  b_scale = tl.load(b_scale_ptrs)

  if use_fp8_w8a8:
- a_scale = tl.load(a_scale_ptr)
- b_scale = tl.load(b_scale_ptr + off_experts)
+ if group_k > 0 and group_n > 0:
+ a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
+ offs_bsn = offs_bn // group_n
+ b_scale_ptrs = (
+ b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
+ )
+ else:
+ a_scale = tl.load(a_scale_ptr)
+ b_scale = tl.load(b_scale_ptr + off_experts)

  # -----------------------------------------------------------
  # Iterate to compute a block of the C matrix.
@@ -143,17 +161,36 @@ def fused_moe_kernel(
  for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
  # Load the next block of A and B, generate a mask by checking the
  # K dimension.
- a = tl.load(
- a_ptrs,
- mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
- other=0.0,
- )
- b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+ if even_Ks:
+ a = tl.load(
+ a_ptrs,
+ mask=token_mask[:, None],
+ other=0.0,
+ )
+ b = tl.load(b_ptrs)
+ else:
+ a = tl.load(
+ a_ptrs,
+ mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
+ other=0.0,
+ )
+ b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+
  # We accumulate along the K dimension.
  if use_int8_w8a16:
  accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
  elif use_fp8_w8a8:
- accumulator = tl.dot(a, b, acc=accumulator)
+ if group_k > 0 and group_n > 0:
+ k_start = k * BLOCK_SIZE_K
+ offs_ks = k_start // group_k
+ a_scale = tl.load(
+ a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0
+ )
+ b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
+
+ accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
+ else:
+ accumulator = tl.dot(a, b, acc=accumulator)
  else:
  accumulator += tl.dot(a, b)
  # Advance the ptrs to the next K block.
@@ -166,7 +203,10 @@
  if use_int8_w8a16:
  accumulator = (accumulator * b_scale).to(compute_type)
  elif use_fp8_w8a8:
- accumulator = (accumulator * a_scale * b_scale).to(compute_type)
+ if group_k > 0 and group_n > 0:
+ accumulator = accumulator.to(compute_type)
+ else:
+ accumulator = (accumulator * a_scale * b_scale).to(compute_type)
  else:
  accumulator = accumulator.to(compute_type)
  # -----------------------------------------------------------
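Note: in the new group_k/group_n branch, dequantization is folded into the K loop: each partial tl.dot is scaled by the activation scale of the current K-group and the weight scale of the current (N, K) block, so nothing is rescaled after the loop (the accumulator is only cast). A small PyTorch reference of that arithmetic; the function and variable names below are illustrative, not taken from this file:

import torch

def blockwise_dequant_matmul(a_q, a_scale, b_q, b_scale, group_n, group_k):
    # a_q: [M, K] quantized activations (kept as float here for clarity)
    # a_scale: [M, ceil(K / group_k)]   one scale per token per K-group
    # b_q: [N, K] quantized weights
    # b_scale: [ceil(N / group_n), ceil(K / group_k)]   one scale per (N, K) block
    M, K = a_q.shape
    N = b_q.shape[0]
    out = torch.zeros(M, N, dtype=torch.float32)
    for ks in range(0, K, group_k):
        # partial product over one K-group, then scale rows by a_scale and columns by b_scale
        partial = a_q[:, ks:ks + group_k].float() @ b_q[:, ks:ks + group_k].float().t()
        a_s = a_scale[:, ks // group_k].unsqueeze(1)                      # [M, 1]
        b_s = b_scale[:, ks // group_k].repeat_interleave(group_n)[:N]    # [N]
        out += partial * a_s * b_s.unsqueeze(0)
    return out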
@@ -227,9 +267,25 @@ def moe_align_block_size(
  (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
  )
  num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
- ops.moe_align_block_size(
- topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
- )
+ # FIXME(zhyncs)
+ if num_experts >= 256:
+ sgl_moe_align_block_size(
+ topk_ids,
+ num_experts,
+ block_size,
+ sorted_ids,
+ expert_ids,
+ num_tokens_post_pad,
+ )
+ else:
+ ops.moe_align_block_size(
+ topk_ids,
+ num_experts,
+ block_size,
+ sorted_ids,
+ expert_ids,
+ num_tokens_post_pad,
+ )
  return sorted_ids, expert_ids, num_tokens_post_pad


@@ -250,13 +306,24 @@ def invoke_fused_moe_kernel(
  compute_type: tl.dtype,
  use_fp8_w8a8: bool,
  use_int8_w8a16: bool,
+ block_shape: Optional[List[int]] = None,
  ) -> None:
  assert topk_weights.stride(1) == 1
  assert sorted_token_ids.stride(0) == 1

+ padded_size = 0
  if use_fp8_w8a8:
- A, A_scale = ops.scaled_fp8_quant(A, A_scale)
+ padded_size = padding_size
  assert B_scale is not None
+ if block_shape is None:
+ A, A_scale = ops.scaled_fp8_quant(A, A_scale)
+ else:
+ assert len(block_shape) == 2
+ block_n, block_k = block_shape[0], block_shape[1]
+ A, A_scale = per_token_group_quant_fp8(A, block_k)
+ assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
+ assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
+ assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
  elif use_int8_w8a16:
  assert B_scale is not None
  else:
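Note: the three assertions above pin down the scale-tensor layout for block-wise fp8. A quick worked example of the implied shapes (the sizes are hypothetical; only the ceil-division relationship comes from the asserts):

import math

block_n, block_k = 128, 128            # block_shape = [block_n, block_k]
M, K = 7, 7168                         # activations A: [M, K]
E, N = 256, 4096                       # expert weights B: [E, N, K]

A_scale_shape = (M, math.ceil(K / block_k))                            # one scale per token per K-group
B_scale_shape = (E, math.ceil(N / block_n), math.ceil(K / block_k))    # one scale per expert per (N, K) block
print(A_scale_shape, B_scale_shape)    # (7, 56) (256, 32, 56)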
@@ -268,6 +335,12 @@ def invoke_fused_moe_kernel(
  * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
  )

+ K = B.shape[2] - padded_size
+ if K % config["BLOCK_SIZE_K"] == 0:
+ even_Ks = True
+ else:
+ even_Ks = False
+
  fused_moe_kernel[grid](
  A,
  B,
@@ -279,7 +352,7 @@ def invoke_fused_moe_kernel(
  expert_ids,
  num_tokens_post_padded,
  B.shape[1],
- B.shape[2],
+ B.shape[2] - padded_size,
  sorted_token_ids.shape[0],
  topk_ids.numel(),
  A.stride(0),
@@ -289,13 +362,19 @@ def invoke_fused_moe_kernel(
  B.stride(1),
  C.stride(1),
  C.stride(2),
- B_scale.stride(0) if B_scale is not None and use_int8_w8a16 else 0,
- B_scale.stride(1) if B_scale is not None and use_int8_w8a16 else 0,
+ A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
+ A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
+ B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
+ B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
+ B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
+ 0 if block_shape is None else block_shape[0],
+ 0 if block_shape is None else block_shape[1],
  MUL_ROUTED_WEIGHT=mul_routed_weight,
  top_k=top_k,
  compute_type=compute_type,
  use_fp8_w8a8=use_fp8_w8a8,
  use_int8_w8a16=use_int8_w8a16,
+ even_Ks=even_Ks,
  **config,
  )

@@ -351,20 +430,39 @@ def get_default_config(
  dtype: Optional[str],
  is_marlin: bool,
  ) -> Dict[str, int]:
- config = {
- "BLOCK_SIZE_M": 64,
- "BLOCK_SIZE_N": 64,
- "BLOCK_SIZE_K": 32,
- "GROUP_SIZE_M": 8,
- }
- # A heuristic: fused marlin works faster with this config for small M
- if M <= E or (is_marlin and M <= 32):
+ if dtype == "fp8_w8a8":
  config = {
- "BLOCK_SIZE_M": 16,
- "BLOCK_SIZE_N": 32,
- "BLOCK_SIZE_K": 64,
- "GROUP_SIZE_M": 1,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 8,
+ "num_stages": 4,
  }
+ if M <= E:
+ config = {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4,
+ }
+ else:
+ config = {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 32,
+ "GROUP_SIZE_M": 8,
+ }
+ # A heuristic: fused marlin works faster with this config for small M
+ if M <= E or (is_marlin and M <= 32):
+ config = {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ }
  return config

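Note: when no tuned JSON config is found, the new fp8_w8a8 branch picks a large-tile default and drops to a smaller BLOCK_SIZE_M when the token count M is no larger than the expert count E. A worked example with hypothetical sizes (arguments are passed positionally because only the trailing parameter names are visible in this hunk):

# Decode-sized batch: M=8 tokens, E=256 experts, fp8 weights, not marlin.
cfg = get_default_config(8, 256, 2048, 7168, 8, "fp8_w8a8", False)
# 8 <= 256, so the small-M fp8 config is returned:
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128,
#  "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4}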
@@ -375,8 +473,9 @@ def try_get_optimal_moe_config(
  dtype: Optional[str],
  M: int,
  is_marlin: bool = False,
+ block_shape: Optional[List[int]] = None,
  ):
- from sglang.srt.layers.fused_moe_triton import get_config
+ from sglang.srt.layers.moe.fused_moe_triton import get_config

  override_config = get_config()
  if override_config:
@@ -393,77 +492,16 @@ def try_get_optimal_moe_config(
  else:
  # Else use the default config
  config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, is_marlin)
+ # TODO(HandH1998): Optimize the configs of block-wise quant.
+ # NOTE(HandH1998): For block-wise quant,
+ # BLOCK_K must be divisable by block_shape[1]
+ # BLOCK_N and BLOCK_M has no requirements
+ if block_shape is not None:
+ config["BLOCK_SIZE_N"] = block_shape[0]
+ config["BLOCK_SIZE_K"] = block_shape[1]
  return config


- def fused_topk(
- hidden_states: torch.Tensor,
- gating_output: torch.Tensor,
- topk: int,
- renormalize: bool,
- ):
- assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
-
- M, _ = hidden_states.shape
-
- topk_weights = torch.empty(
- M, topk, dtype=torch.float32, device=hidden_states.device
- )
- topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
- token_expert_indicies = torch.empty(
- M, topk, dtype=torch.int32, device=hidden_states.device
- )
-
- ops.topk_softmax(
- topk_weights,
- topk_ids,
- token_expert_indicies,
- gating_output.float(), # TODO(woosuk): Optimize this.
- )
- del token_expert_indicies # Not used. Will be used in the future.
-
- if renormalize:
- topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
-
- return topk_weights, topk_ids
-
-
- # This is used by the Deepseek-V2 model
- def grouped_topk(
- hidden_states: torch.Tensor,
- gating_output: torch.Tensor,
- topk: int,
- renormalize: bool,
- num_expert_group: int = 0,
- topk_group: int = 0,
- ):
-
- assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
-
- scores = torch.softmax(gating_output, dim=-1)
- num_token = scores.shape[0]
- group_scores = (
- scores.view(num_token, num_expert_group, -1).max(dim=-1).values
- ) # [n, n_group]
- group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
- 1
- ] # [n, top_k_group]
- group_mask = torch.zeros_like(group_scores) # [n, n_group]
- group_mask.scatter_(1, group_idx, 1) # [n, n_group]
- score_mask = (
- group_mask.unsqueeze(-1)
- .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
- .reshape(num_token, -1)
- ) # [n, e]
- tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
- topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
-
- if renormalize:
- topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
-
- return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
-
-
  def get_config_dtype_str(
  dtype: torch.dtype,
  use_int8_w8a16: Optional[bool] = False,
@@ -492,6 +530,7 @@ def inplace_fused_experts(
  w2_scale: Optional[torch.Tensor] = None,
  a1_scale: Optional[torch.Tensor] = None,
  a2_scale: Optional[torch.Tensor] = None,
+ block_shape: Optional[List[int]] = None,
  ) -> None:
  fused_experts_impl(
  hidden_states,
@@ -506,6 +545,7 @@ def inplace_fused_experts(
  w2_scale,
  a1_scale,
  a2_scale,
+ block_shape,
  )


@@ -521,6 +561,7 @@ def inplace_fused_experts_fake(
  w2_scale: Optional[torch.Tensor] = None,
  a1_scale: Optional[torch.Tensor] = None,
  a2_scale: Optional[torch.Tensor] = None,
+ block_shape: Optional[List[int]] = None,
  ) -> None:
  pass

@@ -545,6 +586,7 @@ def outplace_fused_experts(
  w2_scale: Optional[torch.Tensor] = None,
  a1_scale: Optional[torch.Tensor] = None,
  a2_scale: Optional[torch.Tensor] = None,
+ block_shape: Optional[List[int]] = None,
  ) -> torch.Tensor:
  return fused_experts_impl(
  hidden_states,
@@ -559,6 +601,7 @@ def outplace_fused_experts(
  w2_scale,
  a1_scale,
  a2_scale,
+ block_shape,
  )


@@ -574,6 +617,7 @@ def outplace_fused_experts_fake(
  w2_scale: Optional[torch.Tensor] = None,
  a1_scale: Optional[torch.Tensor] = None,
  a2_scale: Optional[torch.Tensor] = None,
+ block_shape: Optional[List[int]] = None,
  ) -> torch.Tensor:
  return torch.empty_like(hidden_states)

@@ -599,6 +643,7 @@ def fused_experts(
  w2_scale: Optional[torch.Tensor] = None,
  a1_scale: Optional[torch.Tensor] = None,
  a2_scale: Optional[torch.Tensor] = None,
+ block_shape: Optional[List[int]] = None,
  ):
  if inplace:
  torch.ops.sglang.inplace_fused_experts(
@@ -613,6 +658,7 @@ def fused_experts(
  w2_scale,
  a1_scale,
  a2_scale,
+ block_shape,
  )
  return hidden_states
  else:
@@ -628,6 +674,7 @@ def fused_experts(
  w2_scale,
  a1_scale,
  a2_scale,
+ block_shape,
  )


@@ -644,9 +691,14 @@ def fused_experts_impl(
  w2_scale: Optional[torch.Tensor] = None,
  a1_scale: Optional[torch.Tensor] = None,
  a2_scale: Optional[torch.Tensor] = None,
+ block_shape: Optional[List[int]] = None,
  ):
+ padded_size = padding_size
+ if not use_fp8_w8a8:
+ padded_size = 0
+
  # Check constraints.
- assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
+ assert hidden_states.shape[1] == w1.shape[2] - padded_size, "Hidden size mismatch"
  assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
  assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
  assert w1.is_contiguous(), "Expert weights1 must be contiguous"
@@ -668,9 +720,10 @@ def fused_experts_impl(
  get_config_func = functools.partial(
  try_get_optimal_moe_config,
  w1.shape,
- w2.shape,
+ (w2.shape[0], w2.shape[1], w2.shape[2] - padded_size),
  topk_ids.shape[1],
  config_dtype,
+ block_shape=block_shape,
  )

  config = get_config_func(M)
@@ -743,6 +796,7 @@ def fused_experts_impl(
  compute_type=compute_type,
  use_fp8_w8a8=use_fp8_w8a8,
  use_int8_w8a16=use_int8_w8a16,
+ block_shape=block_shape,
  )

  ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
@@ -764,6 +818,7 @@ def fused_experts_impl(
  compute_type=compute_type,
  use_fp8_w8a8=use_fp8_w8a8,
  use_int8_w8a16=use_int8_w8a16,
+ block_shape=block_shape,
  )

  torch.sum(
@@ -792,6 +847,7 @@ def fused_moe(
  w2_scale: Optional[torch.Tensor] = None,
  a1_scale: Optional[torch.Tensor] = None,
  a2_scale: Optional[torch.Tensor] = None,
+ block_shape: Optional[List[int]] = None,
  ) -> torch.Tensor:
  """
  This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -819,6 +875,12 @@ def fused_moe(
  w1.
  - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
  w2.
+ - a1_scale (Optional[torch.Tensor]): Optional scale to be used for
+ a1.
+ - a2_scale (Optional[torch.Tensor]): Optional scale to be used for
+ a2.
+ - block_shape: (Optional[List[int]]): Optional block size for block-wise
+ quantization.

  Returns:
  - torch.Tensor: The output tensor after applying the MoE layer.
@@ -826,24 +888,16 @@ def fused_moe(
  # Check constraints.
  assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"

- if use_grouped_topk:
- assert num_expert_group is not None and topk_group is not None
- topk_weights, topk_ids = grouped_topk(
- hidden_states,
- gating_output,
- topk,
- renormalize,
- num_expert_group,
- topk_group,
- )
- elif custom_routing_function is None:
- topk_weights, topk_ids = fused_topk(
- hidden_states, gating_output, topk, renormalize
- )
- else:
- topk_weights, topk_ids = custom_routing_function(
- hidden_states, gating_output, topk, renormalize
- )
+ topk_weights, topk_ids = select_experts(
+ hidden_states=hidden_states,
+ router_logits=gating_output,
+ use_grouped_topk=use_grouped_topk,
+ top_k=topk,
+ renormalize=renormalize,
+ topk_group=topk_group,
+ num_expert_group=num_expert_group,
+ custom_routing_function=custom_routing_function,
+ )

  return fused_experts(
  hidden_states,
@@ -858,4 +912,5 @@ def fused_moe(
  w2_scale=w2_scale,
  a1_scale=a1_scale,
  a2_scale=a2_scale,
+ block_shape=block_shape,
  )
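Note: with the new block_shape argument, callers can run the fused MoE on block-wise fp8 weights; activations are quantized on the fly by per_token_group_quant_fp8. A rough usage sketch, not from the diff: sizes are illustrative, and the use_fp8_w8a8 / w1_scale keywords follow the vLLM-style signature this file is derived from rather than anything visible in this hunk.

import torch
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe

M, K, E, N, topk = 4, 7168, 256, 2048, 8
hidden_states = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
gating_output = torch.randn(M, E, device="cuda", dtype=torch.float32)
w1 = torch.randn(E, 2 * N, K, device="cuda").to(torch.float8_e4m3fn)   # gate/up projection
w2 = torch.randn(E, K, N, device="cuda").to(torch.float8_e4m3fn)       # down projection
w1_scale = torch.rand(E, (2 * N + 127) // 128, (K + 127) // 128, device="cuda")
w2_scale = torch.rand(E, (K + 127) // 128, (N + 127) // 128, device="cuda")

out = fused_moe(
    hidden_states, w1, w2, gating_output,
    topk=topk,
    renormalize=True,
    use_fp8_w8a8=True,          # assumed keyword, not shown in this hunk
    w1_scale=w1_scale,          # assumed keyword, not shown in this hunk
    w2_scale=w2_scale,
    block_shape=[128, 128],     # [block_n, block_k] quantization granularity
)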
sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -13,14 +13,15 @@ from vllm.distributed import (
  from vllm.model_executor.custom_op import CustomOp

  from sglang.srt.layers.custom_op_util import register_custom_op
+ from sglang.srt.layers.moe.topk import select_experts
  from sglang.srt.layers.quantization.base_config import (
  QuantizationConfig,
  QuantizeMethodBase,
  )
  from sglang.srt.utils import set_weight_attrs

- if torch.cuda.is_available() or torch.hip.is_available():
- from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts
+ if torch.cuda.is_available():
+ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
  else:
  fused_experts = None # type: ignore

@@ -33,6 +34,7 @@ class FusedMoeWeightScaleSupported(Enum):
  TENSOR = "tensor"
  CHANNEL = "channel"
  GROUP = "group"
+ BLOCK = "block"


  class FusedMoEMethodBase(QuantizeMethodBase):
@@ -106,6 +108,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
  topk_group: Optional[int] = None,
  num_expert_group: Optional[int] = None,
  custom_routing_function: Optional[Callable] = None,
+ correction_bias: Optional[torch.Tensor] = None,
  ) -> torch.Tensor:
  return self.forward(
  x=x,
@@ -117,6 +120,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
  topk_group=topk_group,
  num_expert_group=num_expert_group,
  custom_routing_function=custom_routing_function,
+ correction_bias=correction_bias,
  )

  def forward_cuda(
@@ -130,8 +134,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
  topk_group: Optional[int] = None,
  num_expert_group: Optional[int] = None,
  custom_routing_function: Optional[Callable] = None,
+ correction_bias: Optional[torch.Tensor] = None,
  ) -> torch.Tensor:
- topk_weights, topk_ids = FusedMoE.select_experts(
+ topk_weights, topk_ids = select_experts(
  hidden_states=x,
  router_logits=router_logits,
  use_grouped_topk=use_grouped_topk,
@@ -140,6 +145,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
  topk_group=topk_group,
  num_expert_group=num_expert_group,
  custom_routing_function=custom_routing_function,
+ correction_bias=correction_bias,
  )

  return fused_experts(
@@ -197,6 +203,7 @@ class FusedMoE(torch.nn.Module):
  tp_size: Optional[int] = None,
  prefix: str = "",
  custom_routing_function: Optional[Callable] = None,
+ correction_bias: Optional[torch.Tensor] = None,
  ):
  super().__init__()

@@ -208,6 +215,7 @@ class FusedMoE(torch.nn.Module):
  )
  self.top_k = top_k
  self.num_experts = num_experts
+ assert intermediate_size % self.tp_size == 0
  self.intermediate_size_per_partition = intermediate_size // self.tp_size
  self.reduce_results = reduce_results
  self.renormalize = renormalize
@@ -217,6 +225,7 @@ class FusedMoE(torch.nn.Module):
  self.num_expert_group = num_expert_group
  self.topk_group = topk_group
  self.custom_routing_function = custom_routing_function
+ self.correction_bias = correction_bias

  if quant_config is None:
  self.quant_method: Optional[QuantizeMethodBase] = (
@@ -463,7 +472,10 @@ class FusedMoE(torch.nn.Module):
  expert_data=expert_data,
  tp_rank=tp_rank,
  )
- elif quant_method == FusedMoeWeightScaleSupported.GROUP.value:
+ elif quant_method in [
+ FusedMoeWeightScaleSupported.GROUP.value,
+ FusedMoeWeightScaleSupported.BLOCK.value,
+ ]:
  self._load_model_weight_or_group_weight_scale(
  shard_id=shard_id,
  shard_dim=shard_dim,
@@ -503,51 +515,6 @@ class FusedMoE(torch.nn.Module):
  )
  return

- @staticmethod
- def select_experts(
- hidden_states: torch.Tensor,
- router_logits: torch.Tensor,
- top_k: int,
- use_grouped_topk: bool,
- renormalize: bool,
- topk_group: Optional[int] = None,
- num_expert_group: Optional[int] = None,
- custom_routing_function: Optional[Callable] = None,
- ):
- from sglang.srt.layers.fused_moe_triton.fused_moe import (
- fused_topk,
- grouped_topk,
- )
-
- # DeekSeekv2 uses grouped_top_k
- if use_grouped_topk:
- assert topk_group is not None
- assert num_expert_group is not None
- topk_weights, topk_ids = grouped_topk(
- hidden_states=hidden_states,
- gating_output=router_logits,
- topk=top_k,
- renormalize=renormalize,
- num_expert_group=num_expert_group,
- topk_group=topk_group,
- )
- elif custom_routing_function is None:
- topk_weights, topk_ids = fused_topk(
- hidden_states=hidden_states,
- gating_output=router_logits,
- topk=top_k,
- renormalize=renormalize,
- )
- else:
- topk_weights, topk_ids = custom_routing_function(
- hidden_states=hidden_states,
- gating_output=router_logits,
- topk=top_k,
- renormalize=renormalize,
- )
-
- return topk_weights, topk_ids
-
  def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
  assert self.quant_method is not None

@@ -562,6 +529,7 @@ class FusedMoE(torch.nn.Module):
  topk_group=self.topk_group,
  num_expert_group=self.num_expert_group,
  custom_routing_function=self.custom_routing_function,
+ correction_bias=self.correction_bias,
  )

  if self.reduce_results and self.tp_size > 1:
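Note: expert selection now lives in the new sglang.srt.layers.moe.topk module, and the added correction_bias argument is threaded from the FusedMoE layer down to it (useful for routers that add a per-expert score bias, as in DeepSeek-V3-style gating). A hedged call sketch using only keyword names visible in this diff; tensor sizes and values are illustrative.

import torch
from sglang.srt.layers.moe.topk import select_experts

x = torch.randn(4, 7168, device="cuda", dtype=torch.bfloat16)    # [num_tokens, hidden]
router_logits = torch.randn(4, 256, device="cuda")               # [num_tokens, num_experts]
correction_bias = torch.zeros(256, device="cuda")                # hypothetical per-expert bias

topk_weights, topk_ids = select_experts(
    hidden_states=x,
    router_logits=router_logits,
    use_grouped_topk=True,
    top_k=8,
    renormalize=True,
    topk_group=4,
    num_expert_group=8,
    custom_routing_function=None,
    correction_bias=correction_bias,
)
# topk_weights: [num_tokens, top_k] routing weights; topk_ids: [num_tokens, top_k] expert indices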