sglang 0.4.7__py3-none-any.whl → 0.4.7.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +7 -0
  3. sglang/bench_serving.py +1 -1
  4. sglang/lang/interpreter.py +40 -1
  5. sglang/lang/ir.py +27 -0
  6. sglang/math_utils.py +8 -0
  7. sglang/srt/configs/model_config.py +6 -0
  8. sglang/srt/conversation.py +6 -0
  9. sglang/srt/disaggregation/base/__init__.py +1 -1
  10. sglang/srt/disaggregation/base/conn.py +25 -11
  11. sglang/srt/disaggregation/common/__init__.py +5 -1
  12. sglang/srt/disaggregation/common/utils.py +42 -0
  13. sglang/srt/disaggregation/decode.py +196 -51
  14. sglang/srt/disaggregation/fake/__init__.py +1 -1
  15. sglang/srt/disaggregation/fake/conn.py +15 -9
  16. sglang/srt/disaggregation/mooncake/__init__.py +1 -1
  17. sglang/srt/disaggregation/mooncake/conn.py +18 -13
  18. sglang/srt/disaggregation/nixl/__init__.py +6 -1
  19. sglang/srt/disaggregation/nixl/conn.py +17 -12
  20. sglang/srt/disaggregation/prefill.py +128 -43
  21. sglang/srt/disaggregation/utils.py +127 -123
  22. sglang/srt/entrypoints/engine.py +15 -1
  23. sglang/srt/entrypoints/http_server.py +13 -2
  24. sglang/srt/eplb_simulator/__init__.py +1 -0
  25. sglang/srt/eplb_simulator/reader.py +51 -0
  26. sglang/srt/layers/activation.py +19 -0
  27. sglang/srt/layers/attention/aiter_backend.py +15 -2
  28. sglang/srt/layers/attention/cutlass_mla_backend.py +38 -15
  29. sglang/srt/layers/attention/flashattention_backend.py +53 -64
  30. sglang/srt/layers/attention/flashinfer_backend.py +1 -2
  31. sglang/srt/layers/attention/flashinfer_mla_backend.py +22 -24
  32. sglang/srt/layers/attention/flashmla_backend.py +2 -10
  33. sglang/srt/layers/attention/triton_backend.py +119 -119
  34. sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
  35. sglang/srt/layers/attention/vision.py +51 -24
  36. sglang/srt/layers/communicator.py +23 -5
  37. sglang/srt/layers/linear.py +0 -4
  38. sglang/srt/layers/logits_processor.py +0 -12
  39. sglang/srt/layers/moe/ep_moe/kernels.py +6 -5
  40. sglang/srt/layers/moe/ep_moe/layer.py +42 -32
  41. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
  42. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -4
  43. sglang/srt/layers/moe/topk.py +16 -8
  44. sglang/srt/layers/pooler.py +56 -0
  45. sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
  46. sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
  47. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
  48. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
  49. sglang/srt/layers/quantization/fp8_kernel.py +44 -15
  50. sglang/srt/layers/quantization/fp8_utils.py +87 -22
  51. sglang/srt/layers/radix_attention.py +2 -3
  52. sglang/srt/lora/lora_manager.py +79 -34
  53. sglang/srt/lora/mem_pool.py +4 -5
  54. sglang/srt/managers/cache_controller.py +2 -1
  55. sglang/srt/managers/io_struct.py +28 -4
  56. sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
  57. sglang/srt/managers/multimodal_processors/vila.py +85 -0
  58. sglang/srt/managers/schedule_batch.py +39 -6
  59. sglang/srt/managers/scheduler.py +73 -17
  60. sglang/srt/managers/tokenizer_manager.py +29 -2
  61. sglang/srt/mem_cache/chunk_cache.py +1 -0
  62. sglang/srt/mem_cache/hiradix_cache.py +4 -2
  63. sglang/srt/mem_cache/memory_pool.py +111 -407
  64. sglang/srt/mem_cache/memory_pool_host.py +380 -0
  65. sglang/srt/mem_cache/radix_cache.py +36 -12
  66. sglang/srt/model_executor/cuda_graph_runner.py +122 -55
  67. sglang/srt/model_executor/forward_batch_info.py +14 -5
  68. sglang/srt/model_executor/model_runner.py +6 -6
  69. sglang/srt/model_loader/loader.py +8 -1
  70. sglang/srt/models/bert.py +113 -13
  71. sglang/srt/models/deepseek_v2.py +113 -155
  72. sglang/srt/models/internvl.py +46 -102
  73. sglang/srt/models/roberta.py +117 -9
  74. sglang/srt/models/vila.py +305 -0
  75. sglang/srt/openai_api/adapter.py +162 -4
  76. sglang/srt/openai_api/protocol.py +37 -1
  77. sglang/srt/sampling/sampling_batch_info.py +24 -0
  78. sglang/srt/sampling/sampling_params.py +2 -0
  79. sglang/srt/server_args.py +318 -233
  80. sglang/srt/speculative/build_eagle_tree.py +1 -1
  81. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -3
  82. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +5 -2
  83. sglang/srt/speculative/eagle_utils.py +389 -109
  84. sglang/srt/speculative/eagle_worker.py +134 -43
  85. sglang/srt/two_batch_overlap.py +4 -2
  86. sglang/srt/utils.py +58 -0
  87. sglang/test/attention/test_prefix_chunk_info.py +2 -0
  88. sglang/test/runners.py +38 -3
  89. sglang/test/test_block_fp8.py +1 -0
  90. sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
  91. sglang/test/test_block_fp8_ep.py +1 -0
  92. sglang/test/test_utils.py +3 -1
  93. sglang/utils.py +9 -0
  94. sglang/version.py +1 -1
  95. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/METADATA +5 -5
  96. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/RECORD +99 -88
  97. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/WHEEL +0 -0
  98. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/licenses/LICENSE +0 -0
  99. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/top_level.txt +0 -0

sglang/srt/layers/attention/triton_ops/decode_attention.py

@@ -31,11 +31,6 @@ _is_hip = is_hip()
 
  logger = logging.getLogger(__name__)
 
- # TODO: Remove this when triton>=3.2.0. This issue will not affect performance and accuracy.
- logger.warning(
- "The following error message 'operation scheduled before its operands' can be ignored."
- )
-
 
  _MIN_BLOCK_KV = 32
 
@@ -713,7 +708,7 @@ def decode_attention_fwd(
  num_kv_splits,
  max_kv_splits,
  sm_scale,
- logit_cap,
+ logit_cap=logit_cap,
  )
  else:
  # GQA/MQA/MLA
@@ -729,5 +724,5 @@ def decode_attention_fwd(
  num_kv_splits,
  max_kv_splits,
  sm_scale,
- logit_cap,
+ logit_cap=logit_cap,
  )
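
The only functional change in the two `decode_attention_fwd` hunks is that `logit_cap` is now passed as a keyword argument rather than positionally. A minimal sketch of why that is safer when a callee's signature later gains a parameter ahead of the trailing argument; the function names and signatures below are hypothetical, not the actual kernel wrappers:

```python
def reduce_v1(q, k, v, sm_scale, logit_cap):
    return q, k, v, sm_scale, logit_cap

# Hypothetical later revision that inserts a defaulted parameter before logit_cap.
def reduce_v2(q, k, v, sm_scale, num_kv_splits=8, logit_cap=0.0):
    return q, k, v, sm_scale, num_kv_splits, logit_cap

# A positional call written against v1 silently binds 30.0 to num_kv_splits:
assert reduce_v2(1, 2, 3, 0.5, 30.0) == (1, 2, 3, 0.5, 30.0, 0.0)
# The keyword form keeps its meaning across the signature change:
assert reduce_v2(1, 2, 3, 0.5, logit_cap=30.0) == (1, 2, 3, 0.5, 8, 30.0)
```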

sglang/srt/layers/attention/vision.py

@@ -1,15 +1,17 @@
  from __future__ import annotations
 
+ import dataclasses
+ import functools
  import math
- from functools import lru_cache, wraps
- from typing import Optional, Tuple
+ from functools import lru_cache
+ from typing import Any, Optional, Tuple, Union
 
  import torch
  import torch.nn as nn
  import torch.nn.functional as F
  from einops import rearrange
 
- from sglang.srt.utils import is_cuda
+ from sglang.srt.utils import is_cuda, print_info_once
 
  _is_cuda = is_cuda()
 
@@ -29,29 +31,42 @@ from sglang.srt.layers.linear import (
  from sglang.srt.layers.quantization import QuantizationConfig
  from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb
  from sglang.srt.managers.schedule_batch import global_server_args_dict
- from sglang.srt.utils import add_prefix, logger
+ from sglang.srt.utils import add_prefix
 
  ROTARY_EMBED_CLASSES = {
  "normal": apply_rotary_pos_emb,
  }
 
 
- def execute_once(func):
- has_run = None
+ @dataclasses.dataclass
+ class SingletonCache:
+ data: Any = None
 
- @wraps(func)
- def wrapper(*args, **kwargs):
- nonlocal has_run
- if not has_run:
- func(*args, **kwargs)
- has_run = True
+ def set_data(self, value: Any) -> None:
+ self.data = value
 
- return wrapper
+ def get_data(self) -> Optional[Any]:
+ return self.data
 
+ def empty(self) -> bool:
+ return self.get_data() is None
 
- @execute_once
- def info_once(message: str):
- logger.info(message)
+
+ # TODO: requires real seqlens from images
+ @functools.lru_cache(maxsize=128)
+ def _get_cu_seqlens_for_shape(batch_size: int, seqlen: int, device) -> torch.Tensor:
+ """
+ Generates cumulative sequence lengths (cu_seqlens) for a given batch_size, seqlen, and device.
+ Caches the result based on these parameters.
+ """
+ cu_seqlens = torch.arange(
+ 0,
+ (batch_size + 1) * seqlen,
+ step=seqlen,
+ dtype=torch.int32,
+ device=device,
+ )
+ return cu_seqlens
 
 
  class VisionSdpaAttention(nn.Module):
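
The new `_get_cu_seqlens_for_shape` helper builds the cumulative-sequence-length offsets that FlashAttention's varlen interface expects for a batch of equal-length sequences, and `functools.lru_cache` memoizes one tensor per `(batch_size, seqlen, device)` triple. A small standalone sketch of the values it produces (same arithmetic as the diff, without importing sglang):

```python
import torch

def cu_seqlens_for_shape(batch_size: int, seqlen: int, device="cpu") -> torch.Tensor:
    # Offsets 0, seqlen, 2*seqlen, ..., batch_size*seqlen (batch_size + 1 entries).
    return torch.arange(
        0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, device=device
    )

print(cu_seqlens_for_shape(3, 4))  # tensor([ 0,  4,  8, 12], dtype=torch.int32)
```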

@@ -265,8 +280,9 @@ class VisionFlash3Attention(nn.Module):
  q: torch.Tensor,
  k: torch.Tensor,
  v: torch.Tensor,
- cu_seqlens: Optional[torch.Tensor],
- attention_mask: Optional[torch.Tensor] = None,
+ cu_seqlens: Optional[Union[SingletonCache, torch.Tensor]],
+ bsz: int,
+ seq_len: int,
  **kwargs,
  ) -> torch.Tensor:
  r"""
@@ -275,7 +291,16 @@ class VisionFlash3Attention(nn.Module):
  Returns:
  [b * s, h, head_size]
  """
- cu_seqlens = cu_seqlens.to(dtype=torch.int32).cuda()
+ if cu_seqlens is None:
+ cu_seqlens = _get_cu_seqlens_for_shape(bsz, seq_len, device=q.device)
+ elif isinstance(cu_seqlens, SingletonCache):
+ if cu_seqlens.empty():
+ cu_seqlens.set_data(
+ _get_cu_seqlens_for_shape(bsz, seq_len, device=q.device)
+ )
+ cu_seqlens = cu_seqlens.get_data()
+
+ cu_seqlens = cu_seqlens.to(dtype=torch.int32).to(q.device)
  seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
  max_seqlen = seq_lens.max().item()
  output = flash_attn_varlen_func(
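
`forward` now accepts a plain tensor, a `SingletonCache`, or `None` for `cu_seqlens`; the cache variant is filled on first use and reused afterwards. A rough, self-contained sketch of that fill-once resolution (a trimmed stand-in for the dataclass above, using plain lists instead of tensors):

```python
import dataclasses
from typing import Any, Optional

@dataclasses.dataclass
class OnceCache:  # mirrors SingletonCache from this diff
    data: Any = None
    def set_data(self, value: Any) -> None: self.data = value
    def get_data(self) -> Optional[Any]: return self.data
    def empty(self) -> bool: return self.get_data() is None

def resolve_cu_seqlens(cache: Optional[OnceCache], bsz: int, seqlen: int):
    if cache is None:
        return list(range(0, (bsz + 1) * seqlen, seqlen))  # recompute every call
    if cache.empty():
        cache.set_data(list(range(0, (bsz + 1) * seqlen, seqlen)))  # computed once
    return cache.get_data()  # reused by later calls sharing the cache

shared = OnceCache()
assert resolve_cu_seqlens(shared, 2, 4) == [0, 4, 8]
assert not shared.empty()
```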

@@ -346,11 +371,11 @@ class VisionAttention(nn.Module):
  if global_server_args_dict["mm_attention_backend"] is None:
  if qkv_backend is None:
  qkv_backend = "sdpa"
- info_once(f"Multimodal attention backend not set. Use {qkv_backend}.")
+ print_info_once(f"Multimodal attention backend not set. Use {qkv_backend}.")
  else:
  qkv_backend = global_server_args_dict["mm_attention_backend"]
 
- info_once(f"Using {qkv_backend} as multimodal attention backend.")
+ print_info_once(f"Using {qkv_backend} as multimodal attention backend.")
 
  self.qkv_backend = QKV_BACKEND_IMPL[qkv_backend](
  head_dim=self.head_size,
@@ -423,15 +448,16 @@ class VisionAttention(nn.Module):
  # [s, b, embed_dim] --> [s, b, head * 3 * head_size]
  qkv, _ = self.qkv_proj(x)
 
- # [s, b, head * 3 * head_size] --> [s, b, head, 3 * head_size]
+ # [s, b, head, head_dim_sum]
  new_x_shape = qkv.size()[:-1] + (
  head,
- 3 * self.hidden_size_per_attention_head,
+ self.q_size + 2 * self.kv_size,
  )
  qkv = qkv.view(*new_x_shape)
 
  # [s, b, head, 3 * head_size] --> 3 [s, b, head, head_size]
- q, k, v = dist_utils.split_tensor_along_last_dim(qkv, 3)
+ q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+
  # [s, b, head, head_size] --> [b, s, head, head_size]
  q, k, v = [
  rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)
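
The fused QKV output is now split by explicit widths `[q_size, kv_size, kv_size]` instead of three equal chunks, which stays correct when the key/value projection is narrower than the query projection (fewer KV heads than query heads). A toy illustration of the difference at the tensor level (sizes invented for the example):

```python
import torch

head_dim, num_q_heads, num_kv_heads = 64, 8, 2
q_size, kv_size = num_q_heads * head_dim, num_kv_heads * head_dim  # 512, 128

qkv = torch.randn(5, q_size + 2 * kv_size)                # fused projection output
q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
print(q.shape, k.shape, v.shape)                          # (5, 512) (5, 128) (5, 128)

# An equal three-way split assumes q_size == kv_size and mis-slices otherwise:
a, b, c = torch.chunk(qkv, 3, dim=-1)                     # three (5, 256) tensors
```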

@@ -468,6 +494,7 @@ class VisionAttention(nn.Module):
  k=k,
  v=v,
  bsz=bsz,
+ seq_len=s,
  cu_seqlens=cu_seqlens,
  attention_mask=attention_mask,
  )

sglang/srt/layers/communicator.py

@@ -226,13 +226,13 @@ class LayerCommunicator:
 
  @dataclass
  class CommunicateContext:
- process_group_sizes: Dict["ScatterMode", int]
+ process_group_sizes: Dict[ScatterMode, int]
  attn_tp_rank: int
  attn_tp_size: int
  local_attn_dp_size: int
  tp_size: int
 
- def is_same_group_size(self, a: "ScatterMode", b: "ScatterMode"):
+ def is_same_group_size(self, a: ScatterMode, b: ScatterMode):
  return self.process_group_sizes[a] == self.process_group_sizes[b]
 
  @classmethod
@@ -244,6 +244,7 @@ class CommunicateContext:
  process_group_sizes = {
  ScatterMode.SCATTERED: 1,
  ScatterMode.TP_ATTN_FULL: attn_tp_size,
+ # TODO: support --moe-dense-tp-size > 1
  ScatterMode.FULL: tp_size,
  }
  return cls(
@@ -323,11 +324,16 @@ class CommunicateWithAllReduceAndLayerNormFn:
 
  if (
  (hidden_states_input_mode == ScatterMode.TP_ATTN_FULL)
- and (residual_input_mode == ScatterMode.TP_ATTN_FULL)
+ and (
+ residual_input_mode in [ScatterMode.SCATTERED, ScatterMode.TP_ATTN_FULL]
+ )
  and (hidden_states_output_mode == ScatterMode.FULL)
  and (residual_output_mode == ScatterMode.TP_ATTN_FULL)
  ):
- return CommunicateWithAllReduceAndLayerNormFn._gather_hidden_states
+ return partial(
+ CommunicateWithAllReduceAndLayerNormFn._gather_hidden_states_and_residual,
+ residual_input_mode=residual_input_mode,
+ )
 
  if (
  (hidden_states_input_mode == ScatterMode.TP_ATTN_FULL)
@@ -360,13 +366,25 @@ class CommunicateWithAllReduceAndLayerNormFn:
  return hidden_states, residual
 
  @staticmethod
- def _gather_hidden_states(
+ def _gather_hidden_states_and_residual(
  hidden_states: torch.Tensor,
  residual: torch.Tensor,
  forward_batch: ForwardBatch,
  layernorm: torch.nn.Module,
  context: CommunicateContext,
+ *,
+ residual_input_mode,
  ):
+ if residual_input_mode == ScatterMode.SCATTERED and context.attn_tp_size > 1:
+ residual, local_residual = (
+ forward_batch.gathered_buffer[
+ : forward_batch.input_ids.shape[0]
+ ].clone(),
+ residual,
+ )
+ attn_tp_all_gather(
+ list(residual.tensor_split(context.attn_tp_size)), local_residual
+ )
  if context.local_attn_dp_size != 1:
  if context.attn_tp_rank == 0:
  hidden_states += residual
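
The dispatch table now returns `functools.partial(..., residual_input_mode=...)` rather than a bare staticmethod, so the call site keeps its fixed positional signature while the handler still learns how the residual was scattered. A generic sketch of the pattern (the handler and mode values here are stand-ins, not the sglang types):

```python
from functools import partial

def gather_hidden_and_residual(hidden_states, residual, *, residual_input_mode):
    # Would all-gather the residual first when it arrives scattered.
    return f"gather(residual_input_mode={residual_input_mode})"

def pick_handler(residual_input_mode: str):
    # Bind the extra keyword now; callers keep calling handler(h, r) unchanged.
    return partial(gather_hidden_and_residual, residual_input_mode=residual_input_mode)

handler = pick_handler("scattered")
print(handler("h", "r"))  # gather(residual_input_mode=scattered)
```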

sglang/srt/layers/linear.py

@@ -546,8 +546,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
  param.shard_id.append(loaded_shard_id)
  param.shard_id_map[loaded_shard_id] = len(param.data_container)
  param.data_container.append(loaded_weight)
- if len(param.data_container) == 2:
- self.qweight = param.materialize_nested()
  return
 
  param_data = param.data
@@ -961,8 +959,6 @@ class QKVParallelLinear(ColumnParallelLinear):
  param.shard_id.append(loaded_shard_id)
  param.shard_id_map[loaded_shard_id] = len(param.data_container)
  param.data_container.append(loaded_weight)
- if len(param.data_container) == 3:
- self.qweight = param.materialize_nested()
  return
 
  param_data = param.data

sglang/srt/layers/logits_processor.py

@@ -47,18 +47,6 @@ from sglang.srt.utils import dump_to_file
  logger = logging.getLogger(__name__)
 
 
- from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
- from sglang.srt.managers.schedule_batch import global_server_args_dict
- from sglang.srt.model_executor.forward_batch_info import (
- CaptureHiddenMode,
- ForwardBatch,
- ForwardMode,
- )
- from sglang.srt.utils import dump_to_file
-
- logger = logging.getLogger(__name__)
-
-
  @dataclasses.dataclass
  class LogitsProcessorOutput:
  ## Part 1: This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor

sglang/srt/layers/moe/ep_moe/kernels.py

@@ -4,6 +4,7 @@ from typing import List, Optional
  import torch
  import triton
 
+ from sglang.math_utils import ceil_div
  from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
  from sglang.srt.utils import dispose_tensor, is_cuda
 
@@ -15,11 +16,6 @@ if _is_cuda:
  sglang_per_token_group_quant_fp8 as per_token_group_quant_fp8,
  )
 
- try:
- from deep_gemm import ceil_div
- except ImportError:
- logger.error(f"Failed to import ceil_div from deep_gemm.")
-
  import triton.language as tl
 
 
@@ -278,6 +274,7 @@ _silu_and_mul_post_quant_kernel(
  fp8_min,
  BLOCK_N: tl.constexpr,
  NUM_STAGE: tl.constexpr,
+ SCALE_UE8M0: tl.constexpr,
  ):
  expert_id = tl.program_id(2)
  token_id = tl.program_id(1)
@@ -319,6 +316,8 @@ _silu_and_mul_post_quant_kernel(
  gate_up = up * gate
  _absmax = tl.maximum(tl.max(tl.abs(gate_up)), 1e-10)
  output_s = _absmax / fp8_max
+ if SCALE_UE8M0:
+ output_s = tl.exp2(tl.ceil(tl.log2(tl.abs(output_s))))
  output_q = tl.clamp(gate_up / output_s, fp8_min, fp8_max).to(
  output_ptr.dtype.element_ty
  )
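
When `SCALE_UE8M0` is set, the per-group dequantization scale is rounded up to the nearest power of two (`exp2(ceil(log2(s)))`), i.e. a value representable with an 8-bit exponent and no mantissa, matching the `DEEPGEMM_SCALE_UE8M0` flag used elsewhere in this diff. The same arithmetic in plain Python as a quick sanity check, outside the Triton kernel:

```python
import math

def round_scale_ue8m0(s: float) -> float:
    # Round a positive scale up to the next power of two: exp2(ceil(log2(|s|))).
    return 2.0 ** math.ceil(math.log2(abs(s)))

print(round_scale_ue8m0(0.0123))  # 0.015625  (2**-6)
print(round_scale_ue8m0(3.0))     # 4.0       (2**2)
print(round_scale_ue8m0(0.25))    # 0.25      (already a power of two)
```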

@@ -339,6 +338,7 @@ def silu_and_mul_masked_post_quant_fwd(
  output_scale: torch.Tensor,
  quant_group_size: int,
  masked_m: torch.Tensor,
+ scale_ue8m0: bool = False,
  ):
  """
  input shape [expert_num, token_num_padded, hidden_dim]
@@ -395,6 +395,7 @@
  BLOCK_N=BLOCK_N,
  NUM_STAGE=NUM_STAGES,
  num_warps=num_warps,
+ SCALE_UE8M0=scale_ue8m0,
  )
  return
 

sglang/srt/layers/moe/ep_moe/layer.py

@@ -1,30 +1,11 @@
  import logging
  from typing import Callable, List, Optional, Tuple
 
+ import einops
  import torch
+ from sgl_kernel import silu_and_mul
  from torch.nn import Module
 
- from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
- from sglang.srt.managers.expert_location import get_global_expert_location_metadata
- from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
- from sglang.srt.managers.schedule_batch import global_server_args_dict
-
- try:
- from deep_gemm import (
- get_col_major_tma_aligned_tensor,
- m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
- m_grouped_gemm_fp8_fp8_bf16_nt_masked,
- )
- from sgl_kernel import silu_and_mul
-
- from sglang.srt.layers.quantization.fp8_kernel import (
- sglang_per_token_group_quant_fp8,
- )
-
- use_deep_gemm = True
- except ImportError:
- use_deep_gemm = False
-
  from sglang.srt.custom_op import CustomOp
  from sglang.srt.distributed import (
  get_tensor_model_parallel_rank,
@@ -45,6 +26,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
  from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
  from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE, FusedMoEMethodBase
  from sglang.srt.layers.moe.topk import select_experts
+ from sglang.srt.layers.quantization import deep_gemm_wrapper
  from sglang.srt.layers.quantization.base_config import (
  QuantizationConfig,
  QuantizeMethodBase,
@@ -52,10 +34,20 @@ from sglang.srt.layers.quantization.base_config import (
  from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
  from sglang.srt.layers.quantization.fp8_kernel import (
  scaled_fp8_quant,
+ sglang_per_token_group_quant_fp8,
  sglang_per_token_quant_fp8,
  )
+ from sglang.srt.managers.expert_location import get_global_expert_location_metadata
+ from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
+ from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import ForwardMode
- from sglang.srt.utils import DeepEPMode, dispose_tensor, is_hip, set_weight_attrs
+ from sglang.srt.utils import (
+ DeepEPMode,
+ dispose_tensor,
+ get_bool_env_var,
+ is_hip,
+ set_weight_attrs,
+ )
 
  _is_hip = is_hip()
 
@@ -680,7 +672,6 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
  params_dtype: torch.dtype,
  **extra_weight_attrs,
  ):
-
  if self.quant_config.is_checkpoint_fp8_serialized:
  params_dtype = torch.float8_e4m3fn
 
@@ -920,7 +911,9 @@ class DeepEPMoE(EPMoE):
  )
  self.deepep_mode = deepep_mode
  if self.deepep_mode.enable_low_latency():
- assert use_deep_gemm, f"DeepEP {self.deepep_mode} mode requires deep_gemm"
+ assert (
+ deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
+ ), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
  self.w13_weight_fp8 = (
  self.w13_weight,
  (
@@ -948,7 +941,7 @@ class DeepEPMoE(EPMoE):
  ):
  resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
  if resolved_deepep_mode == DeepEPMode.normal:
- if _ENABLE_JIT_DEEPGEMM:
+ if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
  return self.forward_deepgemm_contiguous(
  hidden_states, topk_idx, topk_weights, num_recv_tokens_per_expert
  )
@@ -1145,7 +1138,7 @@ class DeepEPMoE(EPMoE):
  dtype=torch.bfloat16,
  )
  input_tensor[1] = tma_align_input_scale(input_tensor[1])
- m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
+ deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
  input_tensor, self.w13_weight_fp8, gateup_output, m_indices
  )
  del input_tensor
@@ -1169,7 +1162,7 @@ class DeepEPMoE(EPMoE):
  )
  del down_input
  down_input_scale = tma_align_input_scale(down_input_scale)
- m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
+ deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
  (down_input_fp8, down_input_scale),
  self.w2_weight_fp8,
  down_output,
@@ -1202,8 +1195,13 @@ class DeepEPMoE(EPMoE):
  gateup_output = torch.empty(
  (num_groups, m, n), device=hidden_states_fp8[0].device, dtype=torch.bfloat16
  )
- m_grouped_gemm_fp8_fp8_bf16_nt_masked(
- hidden_states_fp8, self.w13_weight_fp8, gateup_output, masked_m, expected_m
+ deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
+ hidden_states_fp8,
+ self.w13_weight_fp8,
+ gateup_output,
+ masked_m,
+ expected_m,
+ recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
  )
  dispose_tensor(hidden_states_fp8[0])
 
@@ -1233,6 +1231,7 @@ class DeepEPMoE(EPMoE):
  down_input_scale,
  scale_block_size,
  masked_m,
+ scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
  )
  del gateup_output
 
@@ -1240,13 +1239,24 @@ class DeepEPMoE(EPMoE):
  n = self.w2_weight.size(1)
  down_input_fp8 = (
  down_input,
- get_col_major_tma_aligned_tensor(down_input_scale),
+ (
+ down_input_scale
+ if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
+ else deep_gemm_wrapper.get_col_major_tma_aligned_tensor(
+ down_input_scale
+ )
+ ),
  )
  down_output = torch.empty(
  (num_groups, m, n), device=down_input.device, dtype=torch.bfloat16
  )
- m_grouped_gemm_fp8_fp8_bf16_nt_masked(
- down_input_fp8, self.w2_weight_fp8, down_output, masked_m, expected_m
+ deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
+ down_input_fp8,
+ self.w2_weight_fp8,
+ down_output,
+ masked_m,
+ expected_m,
+ recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
  )
 
  return down_output

sglang/srt/layers/moe/ep_moe/token_dispatcher.py

@@ -1,7 +1,7 @@
  import logging
  from dataclasses import dataclass
 
- from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
+ from sglang.srt.layers.quantization import deep_gemm_wrapper
  from sglang.srt.managers.expert_distribution import (
  get_global_expert_distribution_recorder,
  )
@@ -107,6 +107,8 @@ class DeepEPBuffer:
  num_rdma_bytes,
  low_latency_mode=deepep_mode.enable_low_latency(),
  num_qps_per_rank=num_qps_per_rank,
+ # TODO can be false when unneeded
+ allow_mnnvl=True,
  )
  return cls._buffer
 
@@ -234,14 +236,14 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
  topk_weights: torch.Tensor,
  ):
  topk_idx = topk_idx.to(torch.int64)
- if _ENABLE_JIT_DEEPGEMM:
+ if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
  # TODO hard code 128 block quant,use fp8 communication
  hidden_states = sglang_per_token_group_quant_fp8(hidden_states, 128)
  previous_event = Buffer.capture() if self.async_finish else None
  return hidden_states, topk_idx, topk_weights, previous_event
 
  def dispatch_b(self, hidden_states, topk_idx, topk_weights, previous_event):
- if _ENABLE_JIT_DEEPGEMM:
+ if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
  (
  hidden_states,
  topk_idx,
@@ -343,7 +345,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
  previous_event=previous_event,
  async_finish=self.async_finish,
  allocate_on_comm_stream=(previous_event is not None) and self.async_finish,
- expert_alignment=128 if _ENABLE_JIT_DEEPGEMM else 1,
+ expert_alignment=128 if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM else 1,
  config=DeepEPConfig.get_instance().normal_dispatch_config,
  )
 
@@ -407,7 +409,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
  topk_idx: torch.Tensor,
  topk_weights: torch.Tensor,
  ):
- if _ENABLE_JIT_DEEPGEMM:
+ if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
  output = hidden_states
  else:
  if hidden_states.shape[0] > 0:
@@ -540,38 +542,6 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
  topk_idx: torch.Tensor,
  use_fp8: bool = False,
  ):
- """
- # For H20, there will be an CUDA error: DeepEP/csrc/kernels/internode_ll.cu:337 'too many blocks in cooperative launch'.
- # Please make sure to change DeepEP code in internode_ll.cu dispatch / combine as below first and then reinstall.
- # More details refer: https://github.com/deepseek-ai/DeepEP/issues/15#issuecomment-2709715782
-
- diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu
- index 76ae2e2..8ecd08f 100644
- --- a/csrc/kernels/internode_ll.cu
- +++ b/csrc/kernels/internode_ll.cu
- @@ -310,8 +310,8 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
- int num_topk, int num_experts, int rank, int num_ranks, bool use_fp8,
- void* workspace, cudaStream_t stream, int phases) {
- constexpr int kNumMaxTopK = 9;
- - constexpr int kNumWarpsPerGroup = 10;
- - constexpr int kNumWarpGroups = 3;
- + constexpr int kNumWarpsPerGroup = 8;
- + constexpr int kNumWarpGroups = 4;
- EP_STATIC_ASSERT(kNumMaxTopK + 1 <= kNumWarpGroups * kNumWarpsPerGroup, "Too many top-k selections");
-
- const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
- @@ -501,8 +501,8 @@ void combine(void* combined_x,
- int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
- int num_topk, int num_experts, int rank, int num_ranks,
- void* workspace, cudaStream_t stream, int phases) {
- - constexpr int kNumWarpsPerGroup = 10;
- - constexpr int kNumWarpGroups = 3;
- + constexpr int kNumWarpsPerGroup = 8;
- + constexpr int kNumWarpGroups = 4;
- constexpr int kNumMaxTopk = 9;
-
- const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
- """
  buffer = self._get_buffer()
  packed_recv_hidden, packed_recv_count, self.handle, event, hook = (
  buffer.low_latency_dispatch(
@@ -582,6 +552,10 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
  use_fp8=use_fp8,
  async_finish=not self.return_recv_hook,
  return_recv_hook=self.return_recv_hook,
+ round_scale=deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
+ and deep_gemm_wrapper.DEEPGEMM_BLACKWELL,
+ use_ue8m0=deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
+ and deep_gemm_wrapper.DEEPGEMM_BLACKWELL,
  )
  )
  return packed_recv_hidden, packed_recv_count, event, hook

sglang/srt/layers/moe/fused_moe_triton/fused_moe.py

@@ -12,6 +12,7 @@ import torch
  import triton
  import triton.language as tl
 
+ from sglang.math_utils import ceil_div
  from sglang.srt.layers.moe.topk import select_experts
  from sglang.srt.layers.quantization.fp8_kernel import (
  per_token_group_quant_fp8,
@@ -518,10 +519,6 @@
  tl.store(c_ptrs, accumulator, mask=c_mask)
 
 
- def ceil_div(a, b):
- return (a + b - 1) // b
-
-
  @triton.jit
  def moe_align_block_size_stage1(
  topk_ids_ptr,
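
The local `ceil_div` helper is removed in favour of the new `sglang.math_utils.ceil_div` (see `sglang/math_utils.py`, added in this release). The dropped two-liner shows the arithmetic; as a standalone reference:

```python
def ceil_div(a: int, b: int) -> int:
    # Integer ceiling division without going through floats: ceil(a / b).
    return (a + b - 1) // b

assert ceil_div(10, 4) == 3
assert ceil_div(8, 4) == 2
assert ceil_div(1, 128) == 1
```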

sglang/srt/layers/moe/topk.py

@@ -249,6 +249,15 @@ def _mask_topk_ids_padded_region(
  topk_ids[indices >= num_token_non_padded, :] = -1
 
 
+ @torch.compile(dynamic=True, backend=get_compiler_backend())
+ def _biased_grouped_topk_postprocess(
+ topk_ids, expert_location_dispatch_info, num_token_non_padded
+ ):
+ topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
+ _mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
+ return topk_ids
+
+
  def biased_grouped_topk(
  hidden_states: torch.Tensor,
  gating_output: torch.Tensor,
@@ -282,14 +291,13 @@ def biased_grouped_topk(
  num_fused_shared_experts,
  routed_scaling_factor,
  )
- # TODO merge into kernel for this branch
- topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
- # TODO will fuse this into kernel, thus use slow manual operation now
- if num_token_non_padded is None:
- return topk_weights, topk_ids
- torch.compile(
- _mask_topk_ids_padded_region, dynamic=True, backend=get_compiler_backend()
- )(topk_ids, num_token_non_padded)
+ # TODO merge into kernel
+ if (expert_location_dispatch_info is not None) or (
+ num_token_non_padded is not None
+ ):
+ topk_ids = _biased_grouped_topk_postprocess(
+ topk_ids, expert_location_dispatch_info, num_token_non_padded
+ )
  return topk_weights, topk_ids
  else:
  biased_grouped_topk_fn = (
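
Instead of wrapping only `_mask_topk_ids_padded_region` in `torch.compile` on every call, the post-processing (logical-to-physical expert remap plus padded-region masking) is now a single module-level compiled function that runs only when either input is actually present. The masking step itself is the context line at the top of the first hunk; a tiny eager illustration of its effect on toy data (no expert remapping here):

```python
import torch

def mask_padded_region(topk_ids: torch.Tensor, num_token_non_padded: int) -> torch.Tensor:
    # Rows at or beyond num_token_non_padded belong to padding tokens; their
    # expert ids are set to -1 so they are ignored downstream.
    indices = torch.arange(topk_ids.shape[0])
    topk_ids[indices >= num_token_non_padded, :] = -1
    return topk_ids

ids = torch.tensor([[3, 7], [1, 2], [5, 0], [4, 6]])
print(mask_padded_region(ids, num_token_non_padded=2))
# tensor([[ 3,  7],
#         [ 1,  2],
#         [-1, -1],
#         [-1, -1]])
```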