sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. sglang/bench_offline_throughput.py +4 -2
  2. sglang/bench_one_batch.py +3 -13
  3. sglang/bench_one_batch_server.py +143 -15
  4. sglang/bench_serving.py +158 -8
  5. sglang/compile_deep_gemm.py +1 -1
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +119 -75
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +5 -2
  11. sglang/srt/configs/device_config.py +1 -1
  12. sglang/srt/configs/internvl.py +696 -0
  13. sglang/srt/configs/janus_pro.py +3 -0
  14. sglang/srt/configs/model_config.py +18 -0
  15. sglang/srt/constrained/base_grammar_backend.py +55 -72
  16. sglang/srt/constrained/llguidance_backend.py +25 -21
  17. sglang/srt/constrained/outlines_backend.py +27 -26
  18. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  19. sglang/srt/constrained/xgrammar_backend.py +71 -53
  20. sglang/srt/conversation.py +78 -46
  21. sglang/srt/disaggregation/base/conn.py +1 -0
  22. sglang/srt/disaggregation/decode.py +11 -3
  23. sglang/srt/disaggregation/fake/conn.py +1 -1
  24. sglang/srt/disaggregation/mini_lb.py +74 -23
  25. sglang/srt/disaggregation/mooncake/conn.py +236 -138
  26. sglang/srt/disaggregation/nixl/conn.py +242 -71
  27. sglang/srt/disaggregation/prefill.py +7 -4
  28. sglang/srt/disaggregation/utils.py +51 -2
  29. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
  30. sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
  31. sglang/srt/distributed/device_communicators/pynccl.py +2 -1
  32. sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
  33. sglang/srt/distributed/parallel_state.py +22 -1
  34. sglang/srt/entrypoints/engine.py +31 -4
  35. sglang/srt/entrypoints/http_server.py +45 -3
  36. sglang/srt/entrypoints/verl_engine.py +3 -2
  37. sglang/srt/function_call_parser.py +2 -2
  38. sglang/srt/hf_transformers_utils.py +20 -1
  39. sglang/srt/layers/attention/flashattention_backend.py +147 -51
  40. sglang/srt/layers/attention/flashinfer_backend.py +23 -13
  41. sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
  42. sglang/srt/layers/attention/merge_state.py +46 -0
  43. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  44. sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
  45. sglang/srt/layers/attention/utils.py +4 -2
  46. sglang/srt/layers/attention/vision.py +290 -163
  47. sglang/srt/layers/dp_attention.py +71 -21
  48. sglang/srt/layers/layernorm.py +1 -1
  49. sglang/srt/layers/logits_processor.py +46 -11
  50. sglang/srt/layers/moe/ep_moe/kernels.py +343 -8
  51. sglang/srt/layers/moe/ep_moe/layer.py +121 -2
  52. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  56. sglang/srt/layers/moe/topk.py +1 -1
  57. sglang/srt/layers/quantization/__init__.py +1 -1
  58. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  59. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
  60. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
  61. sglang/srt/layers/quantization/deep_gemm.py +77 -71
  62. sglang/srt/layers/quantization/fp8.py +110 -97
  63. sglang/srt/layers/quantization/fp8_kernel.py +81 -62
  64. sglang/srt/layers/quantization/fp8_utils.py +71 -23
  65. sglang/srt/layers/quantization/int8_kernel.py +2 -2
  66. sglang/srt/layers/quantization/kv_cache.py +3 -10
  67. sglang/srt/layers/quantization/utils.py +0 -5
  68. sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
  69. sglang/srt/layers/sampler.py +0 -4
  70. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  71. sglang/srt/lora/lora_manager.py +11 -14
  72. sglang/srt/lora/mem_pool.py +4 -4
  73. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  74. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  75. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  76. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  77. sglang/srt/lora/utils.py +1 -1
  78. sglang/srt/managers/cache_controller.py +115 -119
  79. sglang/srt/managers/data_parallel_controller.py +3 -3
  80. sglang/srt/managers/detokenizer_manager.py +21 -8
  81. sglang/srt/managers/io_struct.py +13 -1
  82. sglang/srt/managers/mm_utils.py +1 -1
  83. sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
  84. sglang/srt/managers/multimodal_processors/internvl.py +232 -0
  85. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  86. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  87. sglang/srt/managers/schedule_batch.py +93 -23
  88. sglang/srt/managers/schedule_policy.py +11 -8
  89. sglang/srt/managers/scheduler.py +140 -100
  90. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  91. sglang/srt/managers/tokenizer_manager.py +157 -47
  92. sglang/srt/managers/tp_worker.py +21 -21
  93. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  94. sglang/srt/mem_cache/chunk_cache.py +2 -0
  95. sglang/srt/mem_cache/memory_pool.py +4 -2
  96. sglang/srt/metrics/collector.py +312 -37
  97. sglang/srt/model_executor/cuda_graph_runner.py +10 -11
  98. sglang/srt/model_executor/forward_batch_info.py +1 -1
  99. sglang/srt/model_executor/model_runner.py +57 -41
  100. sglang/srt/model_loader/loader.py +18 -11
  101. sglang/srt/models/clip.py +4 -4
  102. sglang/srt/models/deepseek_janus_pro.py +3 -3
  103. sglang/srt/models/deepseek_nextn.py +1 -20
  104. sglang/srt/models/deepseek_v2.py +77 -39
  105. sglang/srt/models/gemma3_mm.py +1 -1
  106. sglang/srt/models/internlm2.py +3 -0
  107. sglang/srt/models/internvl.py +670 -0
  108. sglang/srt/models/llama.py +3 -1
  109. sglang/srt/models/llama4.py +58 -13
  110. sglang/srt/models/llava.py +248 -5
  111. sglang/srt/models/minicpmv.py +1 -1
  112. sglang/srt/models/mixtral.py +98 -34
  113. sglang/srt/models/mllama.py +1 -1
  114. sglang/srt/models/phi3_small.py +16 -2
  115. sglang/srt/models/pixtral.py +467 -0
  116. sglang/srt/models/qwen2_5_vl.py +8 -4
  117. sglang/srt/models/qwen2_vl.py +4 -4
  118. sglang/srt/models/roberta.py +1 -1
  119. sglang/srt/models/torch_native_llama.py +1 -1
  120. sglang/srt/models/xiaomi_mimo.py +171 -0
  121. sglang/srt/openai_api/adapter.py +52 -42
  122. sglang/srt/openai_api/protocol.py +20 -16
  123. sglang/srt/reasoning_parser.py +1 -1
  124. sglang/srt/sampling/custom_logit_processor.py +18 -3
  125. sglang/srt/sampling/sampling_batch_info.py +2 -2
  126. sglang/srt/sampling/sampling_params.py +2 -0
  127. sglang/srt/server_args.py +64 -10
  128. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  129. sglang/srt/speculative/eagle_utils.py +7 -7
  130. sglang/srt/speculative/eagle_worker.py +22 -19
  131. sglang/srt/utils.py +41 -6
  132. sglang/test/few_shot_gsm8k.py +2 -2
  133. sglang/test/few_shot_gsm8k_engine.py +2 -2
  134. sglang/test/run_eval.py +2 -2
  135. sglang/test/runners.py +8 -1
  136. sglang/test/send_one.py +13 -3
  137. sglang/test/simple_eval_common.py +1 -1
  138. sglang/test/simple_eval_humaneval.py +1 -1
  139. sglang/test/test_block_fp8.py +2 -2
  140. sglang/test/test_deepep_utils.py +219 -0
  141. sglang/test/test_programs.py +5 -5
  142. sglang/test/test_utils.py +92 -15
  143. sglang/utils.py +1 -1
  144. sglang/version.py +1 -1
  145. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +18 -9
  146. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +150 -137
  147. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +1 -1
  148. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  149. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/dp_attention.py
@@ -24,8 +24,10 @@ if TYPE_CHECKING:
 _ATTN_TP_GROUP = None
 _ATTN_TP_RANK = None
 _ATTN_TP_SIZE = None
-_DP_RANK = None
-_DP_SIZE = None
+_ATTN_DP_RANK = None
+_ATTN_DP_SIZE = None
+_LOCAL_ATTN_DP_SIZE = None
+_LOCAL_ATTN_DP_RANK = None
 
 
 def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
@@ -33,9 +35,27 @@ def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
         return tp_rank, tp_size, 0
 
     attn_tp_size = tp_size // dp_size
-    dp_rank = tp_rank // attn_tp_size
+    attn_dp_rank = tp_rank // attn_tp_size
     attn_tp_rank = tp_rank % attn_tp_size
-    return attn_tp_rank, attn_tp_size, dp_rank
+
+    return attn_tp_rank, attn_tp_size, attn_dp_rank
+
+
+def compute_dp_attention_local_info(
+    enable_dp_attention, tp_rank, tp_size, dp_size, moe_dense_tp_size
+):
+    if not enable_dp_attention:
+        return tp_rank, tp_size, 0
+
+    local_tp_size = moe_dense_tp_size if moe_dense_tp_size else tp_size
+    local_tp_rank = tp_rank % local_tp_size
+    local_dp_size = max(1, dp_size // (tp_size // local_tp_size))
+
+    local_attn_tp_size = local_tp_size // local_dp_size
+    local_attn_dp_rank = local_tp_rank // local_attn_tp_size
+    local_attn_tp_rank = local_tp_rank % local_attn_tp_size
+
+    return local_attn_tp_rank, local_attn_tp_size, local_attn_dp_rank
 
 
 def initialize_dp_attention(
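For orientation, a minimal worked example of the two helpers above, using hypothetical parallel sizes that are not taken from this diff (tp_size=16, dp_size=8, moe_dense_tp_size=2). Globally each attention-TP group spans tp_size // dp_size = 2 ranks; in the "local" view, restricted to a moe_dense_tp_size-wide slice, the DP dimension collapses to a single group.

    from sglang.srt.layers.dp_attention import (
        compute_dp_attention_local_info,
        compute_dp_attention_world_info,
    )

    tp_size, dp_size, moe_dense_tp_size = 16, 8, 2  # hypothetical sizes
    for tp_rank in range(tp_size):
        attn_tp_rank, attn_tp_size, attn_dp_rank = compute_dp_attention_world_info(
            True, tp_rank, tp_size, dp_size
        )
        local_tp_rank, local_tp_size, local_dp_rank = compute_dp_attention_local_info(
            True, tp_rank, tp_size, dp_size, moe_dense_tp_size
        )
        # Globally: 8 DP groups of 2 attention-TP ranks each.
        assert attn_tp_size == 2 and attn_dp_rank == tp_rank // 2
        # Locally: local_dp_size = max(1, 8 // (16 // 2)) == 1, so the rank is always 0.
        assert local_tp_size == 2 and local_dp_rank == 0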
@@ -43,22 +63,32 @@ def initialize_dp_attention(
     tp_rank: int,
     tp_size: int,
     dp_size: int,
+    moe_dense_tp_size: int,
     pp_size: int,
 ):
-    global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE
+    global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK, _ATTN_DP_SIZE
+    global _LOCAL_ATTN_DP_SIZE, _LOCAL_ATTN_DP_RANK
 
     from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP
 
-    _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK = compute_dp_attention_world_info(
+    _ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK = compute_dp_attention_world_info(
         enable_dp_attention, tp_rank, tp_size, dp_size
     )
+    _, _, _LOCAL_ATTN_DP_RANK = compute_dp_attention_local_info(
+        enable_dp_attention, tp_rank, tp_size, dp_size, moe_dense_tp_size
+    )
 
     if enable_dp_attention:
        local_rank = tp_rank % (tp_size // dp_size)
-        _DP_SIZE = dp_size
+        _ATTN_DP_SIZE = dp_size
+        if moe_dense_tp_size is None:
+            _LOCAL_ATTN_DP_SIZE = _ATTN_DP_SIZE
+        else:
+            _LOCAL_ATTN_DP_SIZE = max(1, dp_size // (tp_size // moe_dense_tp_size))
     else:
         local_rank = tp_rank
-        _DP_SIZE = 1
+        _ATTN_DP_SIZE = 1
+        _LOCAL_ATTN_DP_SIZE = 1
 
     tp_group = get_tp_group()
     _ATTN_TP_GROUP = GroupCoordinator(
@@ -93,13 +123,33 @@ def get_attention_tp_size():
 
 
 def get_attention_dp_rank():
-    assert _DP_RANK is not None, "dp attention not initialized!"
-    return _DP_RANK
+    assert _ATTN_DP_RANK is not None, "dp attention not initialized!"
+    return _ATTN_DP_RANK
 
 
 def get_attention_dp_size():
-    assert _DP_SIZE is not None, "dp attention not initialized!"
-    return _DP_SIZE
+    assert _ATTN_DP_SIZE is not None, "dp attention not initialized!"
+    return _ATTN_DP_SIZE
+
+
+def get_local_attention_dp_rank():
+    assert _LOCAL_ATTN_DP_RANK is not None, "dp attention not initialized!"
+    return _LOCAL_ATTN_DP_RANK
+
+
+def get_local_attention_dp_size():
+    assert _LOCAL_ATTN_DP_SIZE is not None, "dp attention not initialized!"
+    return _LOCAL_ATTN_DP_SIZE
+
+
+def get_local_attention_dp_rank():
+    assert _LOCAL_ATTN_DP_RANK is not None, "dp attention not initialized!"
+    return _LOCAL_ATTN_DP_RANK
+
+
+def get_local_attention_dp_size():
+    assert _LOCAL_ATTN_DP_SIZE is not None, "dp attention not initialized!"
+    return _LOCAL_ATTN_DP_SIZE
 
 
 @contextmanager
@@ -112,19 +162,19 @@ def disable_dp_size():
     Args:
         tp_group (GroupCoordinator): the tp group coordinator
     """
-    global _DP_SIZE
-    assert _DP_SIZE is not None, "dp attention not initialized!"
+    global _ATTN_DP_SIZE
+    assert _ATTN_DP_SIZE is not None, "dp attention not initialized!"
 
-    old_dp_size = _DP_SIZE
-    _DP_SIZE = 1
+    old_dp_size = _ATTN_DP_SIZE
+    _ATTN_DP_SIZE = 1
     try:
         yield
     finally:
-        _DP_SIZE = old_dp_size
+        _ATTN_DP_SIZE = old_dp_size
 
 
 def get_dp_local_info(forward_batch: ForwardBatch):
-    dp_rank = get_attention_dp_rank()
+    dp_rank = get_local_attention_dp_rank()
 
     if forward_batch.dp_local_start_pos is None:
         cumtokens = torch.cumsum(forward_batch.global_num_tokens_gpu, dim=0)
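The switch to the local DP rank matters mainly for get_dp_local_info above, which slices this rank's tokens out of the cumulative sum of global_num_tokens_gpu. A standalone sketch with made-up counts; the non-zero-rank branch is inferred from the surrounding code, so treat it as an approximation rather than the exact implementation:

    import torch

    # Made-up per-DP-rank token counts; in sglang this is forward_batch.global_num_tokens_gpu.
    global_num_tokens = torch.tensor([3, 5, 2, 4])
    dp_rank = 2  # would come from get_local_attention_dp_rank()

    cumtokens = torch.cumsum(global_num_tokens, dim=0)
    local_start_pos = torch.tensor(0) if dp_rank == 0 else cumtokens[dp_rank - 1]
    local_num_tokens = global_num_tokens[dp_rank]
    assert int(local_start_pos) == 8 and int(local_num_tokens) == 2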
@@ -201,7 +251,7 @@ def _dp_gather(
         global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
     )
 
-    # Input IDs are in int 32. We should use inplace_all_reduce for local case becaues of custom all reduce.
+    # Input IDs are in int 32. We should use inplace_all_reduce for local case because of custom all reduce.
     NUM_GPUS_PER_NODE = 8
     if (
         not local_tokens.dtype.is_floating_point
@@ -252,12 +302,12 @@ def dp_scatter(
     )
 
 
-def tp_reduce_scatter(
+def attn_tp_reduce_scatter(
     output: torch.Tensor,
     input_list: List[torch.Tensor],
 ):
     return get_attention_tp_group().reduce_scatter(output, input_list)
 
 
-def tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor):
+def attn_tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor):
     return get_attention_tp_group().all_gather(input_, tensor_list=output_list)
sglang/srt/layers/layernorm.py
@@ -76,7 +76,7 @@ class RMSNorm(CustomOp):
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         if not x.is_contiguous():
-            # NOTE: Romove this if aiter kernel supports discontinuous input
+            # NOTE: Remove this if aiter kernel supports discontinuous input
             x = x.contiguous()
         if residual is not None:
             fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
sglang/srt/layers/logits_processor.py
@@ -23,15 +23,17 @@ import triton.language as tl
 from torch import nn
 
 from sglang.srt.distributed import (
-    get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_gather,
 )
 from sglang.srt.layers.dp_attention import (
+    attn_tp_all_gather,
     dp_gather_replicate,
     dp_scatter,
-    get_attention_dp_rank,
     get_attention_dp_size,
+    get_attention_tp_size,
+    get_local_attention_dp_rank,
+    get_local_attention_dp_size,
 )
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -45,6 +47,18 @@ from sglang.srt.utils import dump_to_file
 logger = logging.getLogger(__name__)
 
 
+from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.forward_batch_info import (
+    CaptureHiddenMode,
+    ForwardBatch,
+    ForwardMode,
+)
+from sglang.srt.utils import dump_to_file
+
+logger = logging.getLogger(__name__)
+
+
 @dataclasses.dataclass
 class LogitsProcessorOutput:
     ## Part 1: This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
@@ -169,7 +183,7 @@ class LogitsMetadata:
             return
 
         cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0)
-        dp_rank = get_attention_dp_rank()
+        dp_rank = get_local_attention_dp_rank()
        if dp_rank == 0:
            dp_local_start_pos = torch.zeros_like(
                self.global_num_tokens_for_logprob_gpu[0]
@@ -198,12 +212,20 @@ class LogitsProcessor(nn.Module):
         super().__init__()
         self.config = config
         self.logit_scale = logit_scale
-        self.do_tensor_parallel_all_gather = (
-            not skip_all_gather and get_tensor_model_parallel_world_size() > 1
-        )
-        self.do_tensor_parallel_all_gather_dp_attn = (
-            self.do_tensor_parallel_all_gather and get_attention_dp_size() != 1
-        )
+        self.use_attn_tp_group = global_server_args_dict["enable_dp_lm_head"]
+        if self.use_attn_tp_group:
+            self.attn_tp_size = get_attention_tp_size()
+            self.do_tensor_parallel_all_gather = (
+                not skip_all_gather and self.attn_tp_size > 1
+            )
+            self.do_tensor_parallel_all_gather_dp_attn = False
+        else:
+            self.do_tensor_parallel_all_gather = (
+                not skip_all_gather and get_tensor_model_parallel_world_size() > 1
+            )
+            self.do_tensor_parallel_all_gather_dp_attn = (
+                self.do_tensor_parallel_all_gather and get_attention_dp_size() != 1
+            )
         self.final_logit_softcapping = getattr(
             self.config, "final_logit_softcapping", None
         )
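A rough sketch of what the new enable_dp_lm_head branch changes, with a hypothetical configuration (tp_size=8, dp_size=4, skip_all_gather=False); the booleans simply mirror the constructor logic above:

    tp_size, dp_size, skip_all_gather = 8, 4, False   # hypothetical configuration
    attn_tp_size = tp_size // dp_size                 # 2 ranks per attention-TP group

    # With enable_dp_lm_head (use_attn_tp_group=True): the vocab all-gather runs
    # over the attention-TP group only, and the dp-attention gather path is skipped.
    gather_with_flag = not skip_all_gather and attn_tp_size > 1           # True, 2 ranks
    dp_attn_gather_with_flag = False

    # Without the flag (use_attn_tp_group=False): gather over the full TP group.
    gather_without_flag = not skip_all_gather and tp_size > 1             # True, 8 ranks
    dp_attn_gather_without_flag = gather_without_flag and dp_size != 1    # True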
@@ -315,7 +337,8 @@ class LogitsProcessor(nn.Module):
 
         if self.debug_tensor_dump_output_folder:
             assert (
-                not self.do_tensor_parallel_all_gather or get_attention_dp_size() == 1
+                not self.do_tensor_parallel_all_gather
+                or get_local_attention_dp_size() == 1
             ), "dp attention + sharded lm_head doesn't support full logits"
             full_logits = self._get_logits(hidden_states, lm_head, logits_metadata)
             dump_to_file(self.debug_tensor_dump_output_folder, "logits", full_logits)
@@ -442,7 +465,19 @@ class LogitsProcessor(nn.Module):
             logits.mul_(self.logit_scale)
 
         if self.do_tensor_parallel_all_gather:
-            logits = tensor_model_parallel_all_gather(logits)
+            if self.use_attn_tp_group:
+                global_logits = torch.empty(
+                    (self.config.vocab_size, logits.shape[0]),
+                    device=logits.device,
+                    dtype=logits.dtype,
+                )
+                global_logits = global_logits.T
+                attn_tp_all_gather(
+                    list(global_logits.tensor_split(self.attn_tp_size, dim=-1)), logits
+                )
+                logits = global_logits
+            else:
+                logits = tensor_model_parallel_all_gather(logits)
 
         if self.do_tensor_parallel_all_gather_dp_attn:
             logits, global_logits = (
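The transposed-buffer gather above can be sanity-checked without any distributed setup. Allocating the buffer as (vocab_size, num_tokens) and transposing makes each vocab shard a contiguous slab of memory, which is presumably why the code does not allocate (num_tokens, vocab_size) directly. The copy loop below stands in for what attn_tp_all_gather would write; the shapes are illustrative only:

    import torch

    num_tokens, vocab_size, attn_tp_size = 3, 8, 2  # illustrative shapes
    # Each attention-TP rank holds one vocab shard of the logits.
    local_shards = [
        torch.randn(num_tokens, vocab_size // attn_tp_size) for _ in range(attn_tp_size)
    ]

    # Same allocation pattern as the new code path.
    global_logits = torch.empty((vocab_size, num_tokens)).T
    splits = list(global_logits.tensor_split(attn_tp_size, dim=-1))

    # Stand-in for attn_tp_all_gather: rank r's logits land in split r.
    for rank, split in enumerate(splits):
        split.copy_(local_shards[rank])

    assert torch.equal(global_logits, torch.cat(local_shards, dim=-1))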
sglang/srt/layers/moe/ep_moe/kernels.py
@@ -5,16 +5,23 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
 from sglang.srt.utils import is_cuda
 
+logger = logging.getLogger(__name__)
+
 _is_cuda = is_cuda()
 if _is_cuda:
     from sglang.srt.layers.quantization.fp8_kernel import (
-        sglang_per_token_group_quant_fp8,
+        sglang_per_token_group_quant_fp8 as per_token_group_quant_fp8,
     )
-    logger = logging.getLogger(__name__)
+
+    try:
+        from deep_gemm import ceil_div
+    except ImportError:
+        logger.error(f"Failed to import ceil_div from deep_gemm.")
+
+    import triton.language as tl
 
 
 @triton.jit
@@ -109,7 +116,7 @@ def deepep_run_moe_deep_preprocess(topk_ids: torch.Tensor, num_experts: int):
     seg_indptr = torch.empty(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
     src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int64)
 
-    # Find offet
+    # Find offset
     expert_ids = torch.arange(
         num_experts + 1, device=topk_ids.device, dtype=reorder_topk_ids.dtype
     )
@@ -654,10 +661,7 @@ def grouped_gemm_triton(
     if block_shape is not None:
         assert len(block_shape) == 2
         block_n, block_k = block_shape[0], block_shape[1]
-        if _is_cuda:
-            a, scale_a = sglang_per_token_group_quant_fp8(a, block_k)
-        else:
-            a, scale_a = per_token_group_quant_fp8(a, block_k)
+        a, scale_a = per_token_group_quant_fp8(a, block_k)
 
         assert triton.cdiv(a.shape[-1], block_k) == scale_a.shape[-1]
         assert triton.cdiv(b.shape[-2], block_n) == scale_b.shape[-2]
@@ -707,3 +711,334 @@
         **config,
     )
     return c
+
+
+@triton.jit
+def _fwd_kernel_ep_scatter_1(
+    num_recv_tokens_per_expert,
+    expert_start_loc,
+    m_indices,
+    num_experts: tl.constexpr,
+    BLOCK_E: tl.constexpr,
+    BLOCK_EXPERT_NUM: tl.constexpr,
+):
+    cur_expert = tl.program_id(0)
+
+    offset_cumsum = tl.arange(0, BLOCK_EXPERT_NUM)
+    tokens_per_expert = tl.load(
+        num_recv_tokens_per_expert + offset_cumsum,
+        mask=offset_cumsum < num_experts,
+        other=0,
+    )
+    cumsum = tl.cumsum(tokens_per_expert) - tokens_per_expert
+    tl.store(expert_start_loc + offset_cumsum, cumsum, mask=offset_cumsum < num_experts)
+
+    cur_expert_start = tl.load(expert_start_loc + cur_expert)
+    cur_expert_token_num = tl.load(num_recv_tokens_per_expert + cur_expert)
+
+    m_indices_start_ptr = m_indices + cur_expert_start
+    off_expert = tl.arange(0, BLOCK_E)
+
+    for start_m in tl.range(0, cur_expert_token_num, BLOCK_E, num_stages=4):
+        tl.store(
+            m_indices_start_ptr + start_m + off_expert,
+            cur_expert,
+        )
+
+
+@triton.jit
+def _fwd_kernel_ep_scatter_2(
+    total_token_num,
+    expert_start_loc,
+    recv_x,
+    recv_x_stride0,
+    recv_x_stride1,
+    recv_x_scale,
+    recv_x_scale_stride0,
+    recv_x_scale_stride1,
+    recv_topk,
+    recv_topk_stride0,
+    recv_topk_stride1,
+    output_tensor,
+    output_tensor_stride0,
+    output_tensor_stride1,
+    output_tensor_scale,
+    output_tensor_scale_stride0,
+    output_tensor_scale_stride1,
+    output_index,
+    output_index_stride0,
+    output_index_stride1,
+    topk_num: tl.constexpr,
+    HIDDEN_SIZE: tl.constexpr,
+    HIDDEN_SIZE_PAD: tl.constexpr,
+    SCALE_HIDDEN_SIZE: tl.constexpr,
+    SCALE_HIDDEN_SIZE_PAD: tl.constexpr,
+):
+    start_token_id = tl.program_id(0)
+    grid_num = tl.num_programs(0)
+
+    offset_in = tl.arange(0, HIDDEN_SIZE_PAD)
+    mask = offset_in < HIDDEN_SIZE
+
+    offset_in_s = tl.arange(0, SCALE_HIDDEN_SIZE_PAD)
+    mask_s = offset_in_s < SCALE_HIDDEN_SIZE
+
+    for token_id in range(start_token_id, total_token_num, grid_num):
+        to_copy = tl.load(recv_x + token_id * recv_x_stride0 + offset_in, mask=mask)
+        to_copy_s = tl.load(
+            recv_x_scale + token_id * recv_x_scale_stride0 + offset_in_s, mask=mask_s
+        )
+
+        for topk_index in tl.range(0, topk_num, 1, num_stages=4):
+            expert_id = tl.load(recv_topk + token_id * recv_topk_stride0 + topk_index)
+            if expert_id >= 0:
+                dest_token_index = tl.atomic_add(expert_start_loc + expert_id, 1)
+                tl.store(
+                    output_index + token_id * output_index_stride0 + topk_index,
+                    dest_token_index,
+                )
+                output_tensor_ptr = (
+                    output_tensor + dest_token_index * output_tensor_stride0
+                )
+                output_tensor_scale_ptr = (
+                    output_tensor_scale + dest_token_index * output_tensor_scale_stride0
+                )
+                tl.store(output_tensor_ptr + offset_in, to_copy, mask=mask)
+                tl.store(output_tensor_scale_ptr + offset_in_s, to_copy_s, mask=mask_s)
+
+
+# copy from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/deepep_scatter_gather.py
+@torch.no_grad()
+def ep_scatter(
+    recv_x: torch.Tensor,
+    recv_x_scale: torch.Tensor,
+    recv_topk: torch.Tensor,
+    num_recv_tokens_per_expert: torch.Tensor,
+    expert_start_loc: torch.Tensor,
+    output_tensor: torch.Tensor,
+    output_tensor_scale: torch.Tensor,
+    m_indices: torch.Tensor,
+    output_index: torch.Tensor,
+):
+    BLOCK_E = 128  # token num of per expert is aligned to 128
+    BLOCK_D = 128  # block size of quantization
+    num_warps = 8
+    num_experts = num_recv_tokens_per_expert.shape[0]
+    hidden_size = recv_x.shape[1]
+    # grid = (triton.cdiv(hidden_size, BLOCK_D), num_experts)
+    grid = num_experts
+
+    assert m_indices.shape[0] % BLOCK_E == 0
+
+    _fwd_kernel_ep_scatter_1[(grid,)](
+        num_recv_tokens_per_expert,
+        expert_start_loc,
+        m_indices,
+        num_experts=num_experts,
+        num_warps=num_warps,
+        BLOCK_E=BLOCK_E,
+        BLOCK_EXPERT_NUM=triton.next_power_of_2(num_experts),
+    )
+
+    grid = min(recv_topk.shape[0], 1024 * 8)
+
+    _fwd_kernel_ep_scatter_2[(grid,)](
+        recv_topk.shape[0],
+        expert_start_loc,
+        recv_x,
+        recv_x.stride(0),
+        recv_x.stride(1),
+        recv_x_scale,
+        recv_x_scale.stride(0),
+        recv_x_scale.stride(1),
+        recv_topk,
+        recv_topk.stride(0),
+        recv_topk.stride(1),
+        output_tensor,
+        output_tensor.stride(0),
+        output_tensor.stride(1),
+        output_tensor_scale,
+        output_tensor_scale.stride(0),
+        output_tensor_scale.stride(1),
+        output_index,
+        output_index.stride(0),
+        output_index.stride(1),
+        topk_num=recv_topk.shape[1],
+        num_warps=num_warps,
+        HIDDEN_SIZE=hidden_size,
+        HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size),
+        SCALE_HIDDEN_SIZE=hidden_size // BLOCK_D,
+        SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size // BLOCK_D),
+    )
+    return
+
+
+@triton.jit
+def _fwd_kernel_ep_gather(
+    total_token_num,
+    input_tensor,
+    input_tensor_stride0,
+    input_tensor_stride1,
+    recv_topk_ids,
+    recv_topk_ids_stride0,
+    recv_topk_ids_stride1,
+    recv_topk_weight,
+    recv_topk_weight_stride0,
+    recv_topk_weight_stride1,
+    input_index,
+    input_index_stride0,
+    input_index_stride1,
+    output_tensor,
+    output_tensor_stride0,
+    output_tensor_stride1,
+    topk_num: tl.constexpr,
+    BLOCK_D: tl.constexpr,
+):
+    cur_block = tl.program_id(0)
+    start_cur_token = tl.program_id(1)
+    grid_num = tl.num_programs(1)
+
+    for cur_token in range(start_cur_token, total_token_num, grid_num):
+        off_d = tl.arange(0, BLOCK_D)
+        accumulator = tl.zeros([BLOCK_D], dtype=tl.float32)
+        for topk_index in range(0, topk_num):
+            expert_id = tl.load(
+                recv_topk_ids + cur_token * recv_topk_ids_stride0 + topk_index
+            )
+            if expert_id >= 0:
+                source_token_index = tl.load(
+                    input_index + cur_token * input_index_stride0 + topk_index
+                )
+                acc_weight = tl.load(
+                    recv_topk_weight + cur_token * recv_topk_weight_stride0 + topk_index
+                )
+                tmp = tl.load(
+                    input_tensor
+                    + source_token_index * input_tensor_stride0
+                    + cur_block * BLOCK_D
+                    + off_d
+                )
+                accumulator += tmp.to(tl.float32) * acc_weight
+
+        tl.store(
+            output_tensor
+            + cur_token * output_tensor_stride0
+            + cur_block * BLOCK_D
+            + off_d,
+            accumulator.to(output_tensor.dtype.element_ty),
+        )
+
+
+@torch.no_grad()
+def ep_gather(
+    input_tensor: torch.Tensor,
+    recv_topk_ids: torch.Tensor,
+    recv_topk_weight: torch.Tensor,
+    input_index: torch.Tensor,
+    output_tensor: torch.Tensor,
+):
+    BLOCK_D = 1024  # block size of quantization
+    num_warps = 2
+    num_tokens = output_tensor.shape[0]
+    hidden_size = input_tensor.shape[1]
+    assert hidden_size % BLOCK_D == 0
+    grid = (triton.cdiv(hidden_size, BLOCK_D), min(num_tokens, 1024))
+    _fwd_kernel_ep_gather[grid](
+        num_tokens,
+        input_tensor,
+        input_tensor.stride(0),
+        input_tensor.stride(1),
+        recv_topk_ids,
+        recv_topk_ids.stride(0),
+        recv_topk_ids.stride(1),
+        recv_topk_weight,
+        recv_topk_weight.stride(0),
+        recv_topk_weight.stride(1),
+        input_index,
+        input_index.stride(0),
+        input_index.stride(1),
+        output_tensor,
+        output_tensor.stride(0),
+        output_tensor.stride(1),
+        topk_num=recv_topk_ids.shape[1],
+        num_warps=num_warps,
+        BLOCK_D=BLOCK_D,
+    )
+    return
+
+
+# copy from
+# https://github.com/deepseek-ai/DeepGEMM/blob/bd2a77552886b98c205af12f8d7d2d61247c4b27/deep_gemm/jit_kernels/utils.py#L58
+def get_tma_aligned_size(x: int, element_size: int) -> int:
+    """
+    Global memory address of TMA must be 16-byte aligned.
+    Since we use column-major layout for the LHS scaling tensor,
+    the M-axis of the LHS scaling tensor needs to be padded to a multiple of 16 bytes.
+
+    Arguments:
+        x: original M-axis shape of the LHS scaling tensor.
+        element_size: element size of the LHS scaling tensor.
+
+    Returns:
+        M-axis shape of the LHS scaling tensor after padding.
+    """
+    tma_alignment_bytes = 16
+    assert tma_alignment_bytes % element_size == 0
+    alignment = tma_alignment_bytes // element_size
+    return ceil_div(x, alignment) * alignment
+
+
+@triton.jit
+def _tma_align_input_scale_kernel(
+    input_scale_ptr,
+    output_ptr,
+    m,
+    k_div_block_size,
+    input_scale_stride_m,
+    input_scale_stride_k,
+    output_stride_m,
+    output_stride_k,
+    BLOCK_SIZE_K: tl.constexpr,
+):
+    pid_m = tl.program_id(axis=0)
+    grid_m = tl.num_programs(0)
+    k_offsets = tl.arange(0, BLOCK_SIZE_K)
+
+    for m_base in range(pid_m, m, grid_m):
+        input_offset = (
+            input_scale_ptr
+            + m_base * input_scale_stride_m
+            + k_offsets * input_scale_stride_k
+        )
+        input_data = tl.load(input_offset, mask=k_offsets < k_div_block_size)
+
+        output_offset = (
+            output_ptr + k_offsets * output_stride_k + m_base * output_stride_m
+        )
+        tl.store(output_offset, input_data, mask=k_offsets < k_div_block_size)
+
+
+# copy from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/quantization/triton_quant/fp8/fp8act_quant_kernel.py
+def tma_align_input_scale(input_scale: torch.Tensor):
+    assert input_scale.dim() == 2
+    m, k_div_block_size = input_scale.shape
+    padd_m = get_tma_aligned_size(m, input_scale.element_size())
+    output = torch.empty(
+        (k_div_block_size, padd_m), dtype=input_scale.dtype, device=input_scale.device
+    )
+
+    grid_m = min(m, 8192)
+    BLOCK_SIZE_K = triton.next_power_of_2(k_div_block_size)
+
+    _tma_align_input_scale_kernel[(grid_m,)](
+        input_scale_ptr=input_scale,
+        output_ptr=output,
+        m=m,
+        k_div_block_size=k_div_block_size,
+        input_scale_stride_m=input_scale.stride(0),
+        input_scale_stride_k=input_scale.stride(1),
+        output_stride_m=output.stride(1),  # Note: these are swapped
+        output_stride_k=output.stride(0),  # for column-major
+        BLOCK_SIZE_K=BLOCK_SIZE_K,
+    )
+    return output.t()[:m]
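Finally, a small worked example of the TMA-alignment helpers at the end of this hunk. ceil_div below is a local stand-in for the deep_gemm import, and the sizes are hypothetical; the point is that tma_align_input_scale returns an (m, k) view whose M stride is 1 and whose K stride is the padded M.

    def ceil_div(a: int, b: int) -> int:
        # Local stand-in for deep_gemm.ceil_div.
        return (a + b - 1) // b

    def get_tma_aligned_size(x: int, element_size: int) -> int:
        # Same arithmetic as the function added above.
        tma_alignment_bytes = 16
        assert tma_alignment_bytes % element_size == 0
        alignment = tma_alignment_bytes // element_size
        return ceil_div(x, alignment) * alignment

    # float32 scales (element_size=4): the M axis is padded to a multiple of 4.
    assert get_tma_aligned_size(30, 4) == 32
    # So for a (30, k) float32 scale tensor, tma_align_input_scale allocates a
    # column-major (k, 32) buffer and returns output.t()[:30], i.e. a (30, k)
    # view with strides (1, 32).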