sglang 0.4.6.post4__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130)
  1. sglang/bench_offline_throughput.py +6 -6
  2. sglang/bench_one_batch.py +5 -4
  3. sglang/bench_one_batch_server.py +23 -15
  4. sglang/bench_serving.py +133 -57
  5. sglang/compile_deep_gemm.py +4 -4
  6. sglang/srt/configs/model_config.py +39 -28
  7. sglang/srt/conversation.py +1 -1
  8. sglang/srt/disaggregation/decode.py +122 -133
  9. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
  10. sglang/srt/disaggregation/fake/conn.py +3 -13
  11. sglang/srt/disaggregation/kv_events.py +357 -0
  12. sglang/srt/disaggregation/mini_lb.py +57 -24
  13. sglang/srt/disaggregation/mooncake/conn.py +11 -2
  14. sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
  15. sglang/srt/disaggregation/nixl/conn.py +9 -19
  16. sglang/srt/disaggregation/prefill.py +126 -44
  17. sglang/srt/disaggregation/utils.py +116 -5
  18. sglang/srt/distributed/utils.py +3 -3
  19. sglang/srt/entrypoints/EngineBase.py +5 -0
  20. sglang/srt/entrypoints/engine.py +28 -8
  21. sglang/srt/entrypoints/http_server.py +6 -4
  22. sglang/srt/entrypoints/http_server_engine.py +5 -2
  23. sglang/srt/function_call/base_format_detector.py +250 -0
  24. sglang/srt/function_call/core_types.py +34 -0
  25. sglang/srt/function_call/deepseekv3_detector.py +157 -0
  26. sglang/srt/function_call/ebnf_composer.py +234 -0
  27. sglang/srt/function_call/function_call_parser.py +175 -0
  28. sglang/srt/function_call/llama32_detector.py +74 -0
  29. sglang/srt/function_call/mistral_detector.py +84 -0
  30. sglang/srt/function_call/pythonic_detector.py +163 -0
  31. sglang/srt/function_call/qwen25_detector.py +67 -0
  32. sglang/srt/function_call/utils.py +35 -0
  33. sglang/srt/hf_transformers_utils.py +46 -7
  34. sglang/srt/layers/attention/aiter_backend.py +513 -0
  35. sglang/srt/layers/attention/flashattention_backend.py +63 -17
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
  37. sglang/srt/layers/attention/flashmla_backend.py +340 -78
  38. sglang/srt/layers/attention/triton_backend.py +3 -0
  39. sglang/srt/layers/attention/utils.py +2 -2
  40. sglang/srt/layers/attention/vision.py +1 -1
  41. sglang/srt/layers/communicator.py +451 -0
  42. sglang/srt/layers/dp_attention.py +0 -10
  43. sglang/srt/layers/moe/cutlass_moe.py +207 -0
  44. sglang/srt/layers/moe/ep_moe/kernels.py +33 -11
  45. sglang/srt/layers/moe/ep_moe/layer.py +104 -50
  46. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
  47. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
  48. sglang/srt/layers/moe/topk.py +66 -9
  49. sglang/srt/layers/multimodal.py +70 -0
  50. sglang/srt/layers/quantization/__init__.py +7 -2
  51. sglang/srt/layers/quantization/deep_gemm.py +5 -3
  52. sglang/srt/layers/quantization/fp8.py +90 -0
  53. sglang/srt/layers/quantization/fp8_utils.py +6 -0
  54. sglang/srt/layers/quantization/gptq.py +298 -6
  55. sglang/srt/layers/quantization/int8_kernel.py +18 -5
  56. sglang/srt/layers/quantization/qoq.py +244 -0
  57. sglang/srt/lora/lora_manager.py +1 -3
  58. sglang/srt/managers/deepseek_eplb.py +278 -0
  59. sglang/srt/managers/eplb_manager.py +55 -0
  60. sglang/srt/managers/expert_distribution.py +704 -56
  61. sglang/srt/managers/expert_location.py +394 -0
  62. sglang/srt/managers/expert_location_dispatch.py +91 -0
  63. sglang/srt/managers/io_struct.py +16 -3
  64. sglang/srt/managers/mm_utils.py +293 -139
  65. sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
  66. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
  67. sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
  68. sglang/srt/managers/multimodal_processors/internvl.py +14 -5
  69. sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
  70. sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
  71. sglang/srt/managers/multimodal_processors/llava.py +3 -3
  72. sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
  73. sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
  74. sglang/srt/managers/multimodal_processors/pixtral.py +9 -9
  75. sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
  76. sglang/srt/managers/schedule_batch.py +49 -21
  77. sglang/srt/managers/schedule_policy.py +4 -5
  78. sglang/srt/managers/scheduler.py +92 -50
  79. sglang/srt/managers/session_controller.py +1 -1
  80. sglang/srt/managers/tokenizer_manager.py +99 -24
  81. sglang/srt/mem_cache/base_prefix_cache.py +3 -0
  82. sglang/srt/mem_cache/chunk_cache.py +3 -1
  83. sglang/srt/mem_cache/hiradix_cache.py +4 -4
  84. sglang/srt/mem_cache/memory_pool.py +74 -52
  85. sglang/srt/mem_cache/multimodal_cache.py +45 -0
  86. sglang/srt/mem_cache/radix_cache.py +58 -5
  87. sglang/srt/metrics/collector.py +2 -2
  88. sglang/srt/mm_utils.py +10 -0
  89. sglang/srt/model_executor/cuda_graph_runner.py +20 -9
  90. sglang/srt/model_executor/expert_location_updater.py +422 -0
  91. sglang/srt/model_executor/forward_batch_info.py +4 -0
  92. sglang/srt/model_executor/model_runner.py +144 -54
  93. sglang/srt/model_loader/loader.py +10 -6
  94. sglang/srt/models/clip.py +5 -1
  95. sglang/srt/models/deepseek_v2.py +297 -343
  96. sglang/srt/models/exaone.py +8 -3
  97. sglang/srt/models/gemma3_mm.py +70 -33
  98. sglang/srt/models/llama4.py +10 -2
  99. sglang/srt/models/llava.py +26 -18
  100. sglang/srt/models/mimo_mtp.py +220 -0
  101. sglang/srt/models/minicpmo.py +5 -12
  102. sglang/srt/models/mistral.py +71 -1
  103. sglang/srt/models/mllama.py +3 -3
  104. sglang/srt/models/qwen2.py +95 -26
  105. sglang/srt/models/qwen2_5_vl.py +8 -0
  106. sglang/srt/models/qwen2_moe.py +330 -60
  107. sglang/srt/models/qwen2_vl.py +6 -0
  108. sglang/srt/models/qwen3.py +52 -10
  109. sglang/srt/models/qwen3_moe.py +411 -48
  110. sglang/srt/models/siglip.py +294 -0
  111. sglang/srt/openai_api/adapter.py +28 -16
  112. sglang/srt/openai_api/protocol.py +6 -0
  113. sglang/srt/operations.py +154 -0
  114. sglang/srt/operations_strategy.py +31 -0
  115. sglang/srt/server_args.py +134 -24
  116. sglang/srt/speculative/eagle_utils.py +131 -0
  117. sglang/srt/speculative/eagle_worker.py +47 -2
  118. sglang/srt/utils.py +68 -12
  119. sglang/test/test_cutlass_moe.py +278 -0
  120. sglang/test/test_utils.py +2 -36
  121. sglang/utils.py +2 -2
  122. sglang/version.py +1 -1
  123. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +20 -11
  124. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +128 -102
  125. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
  126. sglang/srt/function_call_parser.py +0 -858
  127. sglang/srt/platforms/interface.py +0 -371
  128. /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
  129. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
  130. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
sglang/srt/layers/communicator.py
@@ -0,0 +1,451 @@
+ # Copyright 2023-2024 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ from dataclasses import dataclass
+ from enum import Enum, auto
+ from typing import Dict, Optional, Tuple
+
+ import torch.distributed
+
+ from sglang.srt.distributed import (
+     get_tensor_model_parallel_world_size,
+     tensor_model_parallel_all_reduce,
+ )
+ from sglang.srt.layers.dp_attention import (
+     attn_tp_all_gather,
+     attn_tp_reduce_scatter,
+     dp_gather_partial,
+     dp_scatter,
+     get_attention_tp_rank,
+     get_attention_tp_size,
+     get_local_attention_dp_size,
+ )
+ from sglang.srt.managers.schedule_batch import global_server_args_dict
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+
+
+ class ScatterMode(Enum):
+     SCATTERED = auto()
+     TP_ATTN_FULL = auto()
+     FULL = auto()
+
+
+ @dataclass
+ class _LayerModeComputationContext:
+     num_layers: int
+     layer_id: int
+     is_layer_sparse: bool
+     is_previous_layer_sparse: Optional[bool]
+
+     def previous_layer(self):
+         assert self.is_previous_layer_sparse is not None
+         return _LayerModeComputationContext(
+             layer_id=self.layer_id - 1,
+             is_layer_sparse=self.is_previous_layer_sparse,
+             is_previous_layer_sparse=None,
+             num_layers=self.num_layers,
+         )
+
+
+ @dataclass
+ class LayerScatterModes:
+     layer_input_mode: ScatterMode
+     attn_mode: ScatterMode
+     # Can be further split into e.g. mlp_input_mode and mlp_output_mode if needed
+     mlp_mode: ScatterMode
+     middle_residual_mode: ScatterMode
+     layer_output_mode: ScatterMode
+
+     @classmethod
+     def init_new(cls, **kwargs):
+         context = _LayerModeComputationContext(**kwargs)
+         return cls(
+             layer_input_mode=cls._compute_layer_input_mode(context),
+             attn_mode=ScatterMode.TP_ATTN_FULL,
+             mlp_mode=cls._compute_mlp_mode(context),
+             middle_residual_mode=cls._compute_middle_residual_mode(context),
+             layer_output_mode=cls._compute_layer_output_mode(context),
+         )
+
+     @classmethod
+     def _compute_layer_input_mode(cls, context: _LayerModeComputationContext):
+         if context.layer_id == 0:
+             return ScatterMode.TP_ATTN_FULL
+         return cls._compute_layer_output_mode(context.previous_layer())
+
+     @classmethod
+     def _compute_mlp_mode(cls, context: _LayerModeComputationContext):
+         if context.is_layer_sparse:
+             return (
+                 ScatterMode.SCATTERED
+                 if global_server_args_dict["enable_deepep_moe"]
+                 else ScatterMode.FULL
+             )
+         else:
+             return (
+                 ScatterMode.SCATTERED
+                 if enable_moe_dense_fully_dp()
+                 else ScatterMode.FULL
+             )
+
+     @classmethod
+     def _compute_middle_residual_mode(cls, context: _LayerModeComputationContext):
+         mlp_mode = cls._compute_mlp_mode(context)
+         if mlp_mode == ScatterMode.SCATTERED:
+             return ScatterMode.SCATTERED
+         if mlp_mode == ScatterMode.FULL:
+             return ScatterMode.TP_ATTN_FULL
+         raise NotImplementedError
+
+     @classmethod
+     def _compute_layer_output_mode(cls, context: _LayerModeComputationContext):
+         mlp_mode = cls._compute_mlp_mode(context)
+         if context.layer_id == context.num_layers - 1:
+             return ScatterMode.TP_ATTN_FULL
+         if mlp_mode == ScatterMode.SCATTERED:
+             return ScatterMode.SCATTERED
+         if mlp_mode == ScatterMode.FULL:
+             return ScatterMode.TP_ATTN_FULL
+         raise NotImplementedError
+
+
+ def enable_moe_dense_fully_dp():
+     return global_server_args_dict["moe_dense_tp_size"] == 1
+
+
+ class LayerCommunicator:
+     def __init__(
+         self,
+         layer_scatter_modes: LayerScatterModes,
+         input_layernorm: torch.nn.Module,
+         post_attention_layernorm: torch.nn.Module,
+     ):
+         self.layer_scatter_modes = layer_scatter_modes
+         self.input_layernorm = input_layernorm
+         self.post_attention_layernorm = post_attention_layernorm
+
+         self.attn_tp_rank = get_attention_tp_rank()
+         self.attn_tp_size = get_attention_tp_size()
+         self.local_attn_dp_size = get_local_attention_dp_size()
+         self.tp_size = get_tensor_model_parallel_world_size()
+         self.process_group_sizes = {
+             ScatterMode.SCATTERED: 1,
+             ScatterMode.TP_ATTN_FULL: self.attn_tp_size,
+             ScatterMode.FULL: self.tp_size,
+         }
+
+     def prepare_attn(
+         self,
+         hidden_states: torch.Tensor,
+         residual: torch.Tensor,
+         forward_batch: ForwardBatch,
+     ):
+         if hidden_states.shape[0] == 0:
+             residual = hidden_states
+         else:
+             if residual is None:
+                 residual = hidden_states
+                 hidden_states = self.input_layernorm(hidden_states)
+             else:
+                 hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+         hidden_states = _communicate_simple(
+             hidden_states=hidden_states,
+             forward_batch=forward_batch,
+             input_mode=self.layer_scatter_modes.layer_input_mode,
+             output_mode=self.layer_scatter_modes.attn_mode,
+             context=self._compute_context(forward_batch),
+         )
+
+         return hidden_states, residual
+
+     def prepare_mlp(
+         self,
+         hidden_states: torch.Tensor,
+         residual: torch.Tensor,
+         forward_batch: ForwardBatch,
+     ):
+         return _communicate_with_all_reduce_and_layer_norm(
+             hidden_states=hidden_states,
+             residual=residual,
+             forward_batch=forward_batch,
+             hidden_states_input_mode=self.layer_scatter_modes.attn_mode,
+             residual_input_mode=self.layer_scatter_modes.layer_input_mode,
+             hidden_states_output_mode=self.layer_scatter_modes.mlp_mode,
+             residual_output_mode=self.layer_scatter_modes.middle_residual_mode,
+             layernorm=self.post_attention_layernorm,
+             context=self._compute_context(forward_batch),
+         )
+
+     def postprocess_layer(
+         self,
+         hidden_states: torch.Tensor,
+         residual: torch.Tensor,
+         forward_batch: ForwardBatch,
+     ):
+         return _communicate_summable_tensor_pair(
+             hidden_states=hidden_states,
+             residual=residual,
+             forward_batch=forward_batch,
+             hidden_states_input_mode=self.layer_scatter_modes.mlp_mode,
+             residual_input_mode=self.layer_scatter_modes.middle_residual_mode,
+             output_mode=self.layer_scatter_modes.layer_output_mode,
+             context=self._compute_context(forward_batch),
+         )
+
+     def _compute_context(self, forward_batch: ForwardBatch):
+         return _Context(
+             num_tokens_of_mode=_compute_num_tokens_of_mode(
+                 forward_batch,
+                 attn_tp_rank=self.attn_tp_rank,
+                 attn_tp_size=self.attn_tp_size,
+             ),
+             process_group_sizes=self.process_group_sizes,
+             attn_tp_rank=self.attn_tp_rank,
+             attn_tp_size=self.attn_tp_size,
+             local_attn_dp_size=self.local_attn_dp_size,
+             tp_size=self.tp_size,
+         )
+
+
+ def _compute_num_tokens_of_mode(
+     forward_batch: ForwardBatch, attn_tp_rank: int, attn_tp_size: int
+ ):
+     tp_attn_full_num_tokens = forward_batch.input_ids.shape[0]
+     return {
+         ScatterMode.SCATTERED: _torch_tensor_split_len(
+             tp_attn_full_num_tokens, attn_tp_size, attn_tp_rank
+         ),
+         ScatterMode.TP_ATTN_FULL: tp_attn_full_num_tokens,
+         ScatterMode.FULL: (
+             forward_batch.gathered_buffer.shape[0]
+             if global_server_args_dict["enable_dp_attention"]
+             else forward_batch.input_ids.shape[0]
+         ),
+     }
+
+
+ def _torch_tensor_split_len(tensor_len: int, n: int, output_index: int):
+     if output_index < int(tensor_len % n):
+         return int(tensor_len / n) + 1
+     else:
+         return int(tensor_len / n)
+
+
+ @dataclass
+ class _Context:
+     num_tokens_of_mode: Dict["ScatterMode", int]
+     process_group_sizes: Dict["ScatterMode", int]
+     attn_tp_rank: int
+     attn_tp_size: int
+     local_attn_dp_size: int
+     tp_size: int
+
+     def is_same_group_size(self, a: "ScatterMode", b: "ScatterMode"):
+         return self.process_group_sizes[a] == self.process_group_sizes[b]
+
+     def check_shape(self, x: torch.Tensor, mode: ScatterMode):
+         if x is None:
+             return
+
+         actual_num_tokens = x.shape[0]
+         expect_num_tokens = self.num_tokens_of_mode[mode]
+         assert (
+             actual_num_tokens == expect_num_tokens
+         ), f"{actual_num_tokens=} {expect_num_tokens=} {mode=} {x.shape=} {self.num_tokens_of_mode=} {self.process_group_sizes=}"
+         return x
+
+     def check_shapes(
+         self, xs: Tuple[torch.Tensor, ...], modes: Tuple[ScatterMode, ...]
+     ) -> Tuple[torch.Tensor, ...]:
+         return tuple(
+             [self.check_shape(x, mode) for x, mode in zip(xs, modes, strict=True)]
+         )
+
+
+ def _communicate_simple(
+     hidden_states: torch.Tensor,
+     forward_batch: ForwardBatch,
+     input_mode: ScatterMode,
+     output_mode: ScatterMode,
+     context: _Context,
+ ) -> torch.Tensor:
+     def _inner():
+         nonlocal hidden_states
+
+         if context.is_same_group_size(input_mode, output_mode):
+             return hidden_states
+
+         if (input_mode == ScatterMode.SCATTERED) and (
+             output_mode == ScatterMode.TP_ATTN_FULL
+         ):
+             hidden_states, local_hidden_states = (
+                 forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                 hidden_states,
+             )
+             attn_tp_all_gather(
+                 list(hidden_states.tensor_split(context.attn_tp_size)),
+                 local_hidden_states,
+             )
+             return hidden_states
+
+         raise NotImplementedError(f"{input_mode=} {output_mode=}")
+
+     context.check_shape(hidden_states, input_mode)
+     return context.check_shape(_inner(), output_mode)
+
+
+ def _communicate_with_all_reduce_and_layer_norm(
+     hidden_states: torch.Tensor,
+     residual: torch.Tensor,
+     hidden_states_input_mode: ScatterMode,
+     residual_input_mode: ScatterMode,
+     hidden_states_output_mode: ScatterMode,
+     residual_output_mode: ScatterMode,
+     forward_batch: ForwardBatch,
+     layernorm: torch.nn.Module,
+     context: _Context,
+ ):
+     """Besides communication, needs to
+     1. All reduce in tp_attn_group on hidden_states
+     2. Apply layer norm
+     """
+
+     def _inner():
+         nonlocal hidden_states, residual
+
+         if (
+             context.is_same_group_size(
+                 hidden_states_input_mode, hidden_states_output_mode
+             )
+             and context.is_same_group_size(residual_input_mode, residual_output_mode)
+             and context.attn_tp_size == 1
+         ):
+             # TODO move these `if shape != 0` into LayerNorm itself
+             if hidden_states.shape[0] != 0:
+                 hidden_states, residual = layernorm(hidden_states, residual)
+             return hidden_states, residual
+
+         if (
+             (hidden_states_input_mode == ScatterMode.TP_ATTN_FULL)
+             and (residual_input_mode == ScatterMode.TP_ATTN_FULL)
+             and (hidden_states_output_mode == ScatterMode.FULL)
+             and (residual_output_mode == ScatterMode.TP_ATTN_FULL)
+         ):
+             if context.local_attn_dp_size != 1:
+                 if context.attn_tp_rank == 0:
+                     hidden_states += residual
+                 hidden_states, local_hidden_states = (
+                     forward_batch.gathered_buffer,
+                     hidden_states,
+                 )
+                 dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
+                 dp_scatter(residual, hidden_states, forward_batch)
+                 if hidden_states.shape[0] != 0:
+                     hidden_states = layernorm(hidden_states)
+             else:
+                 hidden_states = tensor_model_parallel_all_reduce(hidden_states)
+                 hidden_states, residual = layernorm(hidden_states, residual)
+             return hidden_states, residual
+
+         if (
+             (hidden_states_input_mode == ScatterMode.TP_ATTN_FULL)
+             and (
+                 residual_input_mode in [ScatterMode.SCATTERED, ScatterMode.TP_ATTN_FULL]
+             )
+             and (hidden_states_output_mode == ScatterMode.SCATTERED)
+             and (residual_output_mode == ScatterMode.SCATTERED)
+         ):
+             tensor_list = list(hidden_states.tensor_split(context.attn_tp_size))
+             hidden_states = tensor_list[context.attn_tp_rank]
+             attn_tp_reduce_scatter(hidden_states, tensor_list)
+             if residual_input_mode == ScatterMode.TP_ATTN_FULL:
+                 residual = residual.tensor_split(context.attn_tp_size)[
+                     context.attn_tp_rank
+                 ]
+             if hidden_states.shape[0] != 0:
+                 hidden_states, residual = layernorm(hidden_states, residual)
+             return hidden_states, residual
+
+         raise NotImplementedError(
+             f"{hidden_states_input_mode=} {residual_input_mode=} {residual_output_mode=} {residual_output_mode=}"
+         )
+
+     context.check_shapes(
+         (hidden_states, residual), (hidden_states_input_mode, residual_input_mode)
+     )
+     return context.check_shapes(
+         _inner(), (hidden_states_output_mode, residual_output_mode)
+     )
+
+
+ def _communicate_summable_tensor_pair(
+     hidden_states: torch.Tensor,
+     residual: torch.Tensor,
+     forward_batch: ForwardBatch,
+     hidden_states_input_mode: ScatterMode,
+     residual_input_mode: ScatterMode,
+     output_mode: ScatterMode,
+     context: _Context,
+ ):
+     """It is allowed to make (hidden_states, residual) := (hidden_states + residual, None) if needed."""
+
+     def _inner():
+         nonlocal hidden_states, residual
+
+         if context.is_same_group_size(
+             hidden_states_input_mode, output_mode
+         ) and context.is_same_group_size(residual_input_mode, output_mode):
+             return hidden_states, residual
+
+         if (
+             (hidden_states_input_mode == ScatterMode.FULL)
+             and (residual_input_mode == ScatterMode.TP_ATTN_FULL)
+             and (output_mode == ScatterMode.TP_ATTN_FULL)
+         ):
+             # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
+             # important: forward batch.gathered_buffer is used both after scatter and after gather.
+             # be careful about this!
+             hidden_states, global_hidden_states = (
+                 forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                 hidden_states,
+             )
+             dp_scatter(hidden_states, global_hidden_states, forward_batch)
+             return hidden_states, residual
+
+         if (
+             (hidden_states_input_mode == ScatterMode.SCATTERED)
+             and (residual_input_mode == ScatterMode.SCATTERED)
+             and (output_mode == ScatterMode.TP_ATTN_FULL)
+         ):
+             hidden_states += residual
+             residual = None
+             hidden_states, local_hidden_states = (
+                 forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                 hidden_states,
+             )
+             attn_tp_all_gather(
+                 list(hidden_states.tensor_split(context.attn_tp_size)),
+                 local_hidden_states,
+             )
+             return hidden_states, residual
+
+         raise NotImplementedError(
+             f"{hidden_states_input_mode=} {residual_input_mode=} {output_mode=}"
+         )
+
+     context.check_shapes(
+         (hidden_states, residual), (hidden_states_input_mode, residual_input_mode)
+     )
+     return context.check_shapes(_inner(), (output_mode, output_mode))
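
In the hunk above, `_compute_num_tokens_of_mode` relies on `_torch_tensor_split_len` to predict the per-rank token count for `ScatterMode.SCATTERED`. A minimal, self-contained sketch (illustrative only, not part of the wheel) checking that this helper reproduces the chunk lengths produced by `torch.tensor_split`:

import torch

def _torch_tensor_split_len(tensor_len: int, n: int, output_index: int):
    # Copied from the hunk above: length of the output_index-th chunk when a
    # length-tensor_len tensor is split into n parts.
    if output_index < int(tensor_len % n):
        return int(tensor_len / n) + 1
    else:
        return int(tensor_len / n)

for tensor_len in (0, 5, 16, 17):
    for n in (1, 2, 4):
        chunks = torch.tensor_split(torch.arange(tensor_len), n)
        for i, chunk in enumerate(chunks):
            assert _torch_tensor_split_len(tensor_len, n, i) == len(chunk)
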
sglang/srt/layers/dp_attention.py
@@ -142,16 +142,6 @@ def get_local_attention_dp_size():
      return _LOCAL_ATTN_DP_SIZE


- def get_local_attention_dp_rank():
-     assert _LOCAL_ATTN_DP_RANK is not None, "dp attention not initialized!"
-     return _LOCAL_ATTN_DP_RANK
-
-
- def get_local_attention_dp_size():
-     assert _LOCAL_ATTN_DP_SIZE is not None, "dp attention not initialized!"
-     return _LOCAL_ATTN_DP_SIZE
-
-
  @contextmanager
  def disable_dp_size():
      """Patch the tp group temporarily until this function ends.
sglang/srt/layers/moe/cutlass_moe.py
@@ -0,0 +1,207 @@
+ """Cutlass MoE kernel."""
+
+ import functools
+ import json
+ import logging
+ import os
+ from typing import Any, Callable, Dict, List, Optional, Tuple
+
+ import torch
+
+ from sglang.srt.utils import is_cuda
+
+ _is_cuda = is_cuda()
+ if _is_cuda:
+     import sgl_kernel
+     from sgl_kernel import (
+         fp8_blockwise_scaled_grouped_mm,
+         prepare_moe_input,
+         silu_and_mul,
+     )
+
+
+ def cutlass_fused_experts(
+     a: torch.Tensor,
+     w1_q: torch.Tensor,
+     w2_q: torch.Tensor,
+     w1_scale: torch.Tensor,
+     w2_scale: torch.Tensor,
+     topk_weights: torch.Tensor,
+     topk_ids: torch.Tensor,
+     a1_strides: torch.Tensor,
+     c1_strides: torch.Tensor,
+     a2_strides: torch.Tensor,
+     c2_strides: torch.Tensor,
+     workspace: torch.Tensor,
+     a_ptrs: torch.Tensor,
+     b_ptrs: torch.Tensor,
+     out_ptrs: torch.Tensor,
+     a_scales_ptrs: torch.Tensor,
+     b_scales_ptrs: torch.Tensor,
+     expert_offsets: torch.Tensor,
+     problem_sizes1: torch.Tensor,
+     problem_sizes2: torch.Tensor,
+     use_fp8_blockscale: bool = True,
+ ) -> torch.Tensor:
+     """Performs Fused MoE computation using CUTLASS-like kernels with FP8 weights and activations.
+
+     This function implements a Mixture of Experts (MoE) layer with a SwiGLU/SiLU
+     activation, leveraging custom kernels likely derived from CUTLASS principles
+     for grouped matrix multiplication (`fp8_blockwise_scaled_grouped_mm`) and
+     data preparation (`prepare_moe_input`, `silu_and_mul`).
+
+     It handles per-token routing, quantizes input activations to FP8 with
+     per-token scales, performs the expert computations using FP8 GEMMs with
+     pre-quantized FP8 weights (per-block scales), applies the SiLU activation,
+     and combines the results weighted by the router scores.
+
+     Args:
+         a (torch.Tensor): Input activations. Shape: `(m, k)`, where `m` is the total
+             number of tokens and `k` is the hidden size. Expected dtype: `torch.half`
+             or `torch.bfloat16`.
+         w1_q (torch.Tensor): Pre-quantized FP8 weight tensor for the first GEMM
+             (up-projection part of SwiGLU). Expected shape: `(E, k, n*2)`, where
+             `E` is the number of experts, `k` is the hidden size, and `n*2` is the
+             intermediate size (`I`). Expected dtype: `torch.float8_e4m3fn`.
+             Note: This shape implies weights are stored as (num_experts, hidden_size, intermediate_size).
+         w2_q (torch.Tensor): Pre-quantized FP8 weight tensor for the second GEMM
+             (down-projection). Expected shape: `(E, n, k)`, where `n` is half the
+             intermediate size (`I // 2`). Expected dtype: `torch.float8_e4m3fn`.
+             Note: This shape implies weights are stored as (num_experts, intermediate_size // 2, hidden_size).
+         w1_scale (torch.Tensor): Scales corresponding to `w1_q` (per-block scales).
+             Shape: `(E, num_blocks_n, num_blocks_k)`. Dtype: `torch.float32`.
+         w2_scale (torch.Tensor): Scales corresponding to `w2_q` (per-block scales).
+             Shape: `(E, num_blocks_k, num_blocks_n)`. Dtype: `torch.float32`.
+         topk_weights (torch.Tensor): Router weights for the selected top-k experts
+             for each token. Shape: `(m, topk)`. Dtype should ideally match `a`.
+         topk_ids (torch.Tensor): Indices of the selected top-k experts for each token.
+             Shape: `(m, topk)`. Dtype: `torch.int32`.
+         a1_strides (torch.Tensor): Stride information for the first GEMM's 'a' input.
+             Passed directly to the underlying kernel. Expected shape `(E,)`, dtype `torch.int64`.
+             Note: Its exact usage within `fp8_blockwise_scaled_grouped_mm` needs clarification
+             as it's passed as both a_stride and b_stride in the first call.
+         c1_strides (torch.Tensor): Stride information for the first GEMM's 'c' output.
+             Passed directly to the underlying kernel. Expected shape `(E,)`, dtype `torch.int64`.
+         a2_strides (torch.Tensor): Stride information for the second GEMM's 'a' input.
+             Passed directly to the underlying kernel. Expected shape `(E,)`, dtype `torch.int64`.
+             Note: Its exact usage within `fp8_blockwise_scaled_grouped_mm` needs clarification
+             as it's passed as both a_stride and b_stride in the second call.
+         c2_strides (torch.Tensor): Stride information for the second GEMM's 'c' output.
+             Passed directly to the underlying kernel. Expected shape `(E,)`, dtype `torch.int64`.
+         workspace (torch.Tensor): Reusable workspace for the underlying kernel.
+         a_ptrs (torch.Tensor): Pointers container for calculating offsets of the input activations for each expert.
+         b_ptrs (torch.Tensor): Pointers container for calculating offsets of the input weights for each expert.
+         out_ptrs (torch.Tensor): Pointers container for calculating offsets of the output activations for each expert.
+         a_scales_ptrs (torch.Tensor): Pointers container for calculating offsets of the input scales for each expert.
+         b_scales_ptrs (torch.Tensor): Pointers container for calculating offsets of the input scales for each expert.
+         use_fp8_blockscale (bool, optional): Flag indicating usage of FP8 with
+             block scaling. Currently, only `True` is supported. Defaults to `True`.
+
+     Returns:
+         torch.Tensor: The computed MoE layer output. Shape: `(m, k)`, dtype matches `a`.
+
+     Raises:
+         AssertionError: If input shapes, dtypes, or flags are inconsistent or unsupported.
+         NotImplementedError: If CUDA is not available or `sgl_kernel` is not properly installed.
+     """
+     assert use_fp8_blockscale, "Only support fp8 blockscale for now"
+     assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
+     assert w1_q.dtype == torch.float8_e4m3fn
+     assert w2_q.dtype == torch.float8_e4m3fn
+     assert a.shape[1] == w1_q.shape[1], "Hidden size mismatch w1"
+     assert w1_q.shape[2] == w2_q.shape[1] * 2, "Hidden size mismatch w2"
+     assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
+     assert w1_q.shape[0] == w2_q.shape[0], "Weights expert number mismatch"
+     assert w1_q.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
+     assert w1_q.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
+     assert a.dtype in [torch.half, torch.bfloat16], "Invalid output dtype"
+
+     if is_cuda:
+         from sglang.srt.layers.quantization.fp8_kernel import (
+             sglang_per_token_group_quant_fp8,
+         )
+
+     out_dtype = a.dtype
+     num_experts = w1_q.size(0)
+     m = a.size(0)
+     k = w1_q.size(1)
+     n = w2_q.size(1)
+
+     topk = topk_ids.size(1)
+
+     a_q, a1_scale = sglang_per_token_group_quant_fp8(a, 128)
+     device = a_q.device
+
+     a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
+     c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
+
+     prepare_moe_input(
+         topk_ids,
+         expert_offsets,
+         problem_sizes1,
+         problem_sizes2,
+         a_map,
+         c_map,
+         num_experts,
+         n,
+         k,
+     )
+
+     rep_a_q = a_q.view(dtype=torch.uint8)[a_map].view(dtype=a_q.dtype)
+     rep_a1_scales = a1_scale[a_map]
+
+     c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
+     c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype)
+
+     a_sf_layout = torch.empty((num_experts, 5), device=device, dtype=torch.int)
+     w_sf_layout = torch.empty((num_experts, 5), device=device, dtype=torch.int)
+
+     fp8_blockwise_scaled_grouped_mm(
+         c1,
+         a_ptrs,
+         b_ptrs,
+         out_ptrs,
+         a_scales_ptrs,
+         b_scales_ptrs,
+         rep_a_q,
+         w1_q,
+         rep_a1_scales,
+         w1_scale,
+         a1_strides,
+         a1_strides,
+         c1_strides,
+         a_sf_layout,
+         w_sf_layout,
+         problem_sizes1,
+         expert_offsets[:-1],
+         workspace,
+     )
+
+     intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype)
+     silu_and_mul(c1, intermediate)
+
+     intemediate_q, a2_scale = sglang_per_token_group_quant_fp8(intermediate, 128)
+
+     fp8_blockwise_scaled_grouped_mm(
+         c2,
+         a_ptrs,
+         b_ptrs,
+         out_ptrs,
+         a_scales_ptrs,
+         b_scales_ptrs,
+         intemediate_q,
+         w2_q,
+         a2_scale,
+         w2_scale,
+         a2_strides,
+         a2_strides,
+         c2_strides,
+         a_sf_layout,
+         w_sf_layout,
+         problem_sizes2,
+         expert_offsets[:-1],
+         workspace,
+     )
+     return (
+         c2[c_map].view(m, topk, k) * topk_weights.view(m, topk, 1).to(out_dtype)
+     ).sum(dim=1)
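
The docstring above notes that the expert results are combined "weighted by the router scores"; the return statement implements that as a gather with c_map followed by a weighted sum over the top-k dimension. A small, self-contained sketch of just that combine step (illustrative only, not part of the wheel): expert_out, c_map and topk_weights here are random stand-ins for c2, c_map and topk_weights, with the permutation merely mimicking the index layout that prepare_moe_input would produce.

import torch

m, topk, k = 4, 2, 8                   # tokens, experts per token, hidden size
expert_out = torch.randn(m * topk, k)  # stand-in for c2 (rows in expert-major order)
c_map = torch.randperm(m * topk)       # stand-in gather indices back to token-major order
topk_weights = torch.rand(m, topk)

# Vectorized combine, as in the return statement above.
combined = (
    expert_out[c_map].view(m, topk, k) * topk_weights.view(m, topk, 1)
).sum(dim=1)

# Reference: combine each token's top-k expert outputs one by one.
reference = torch.stack(
    [
        sum(
            topk_weights[t, j] * expert_out[c_map[t * topk + j]]
            for j in range(topk)
        )
        for t in range(m)
    ]
)
assert torch.allclose(combined, reference, atol=1e-6)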