sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105)
  1. sglang/bench_one_batch.py +113 -17
  2. sglang/srt/configs/model_config.py +35 -0
  3. sglang/srt/conversation.py +9 -5
  4. sglang/srt/disaggregation/base/conn.py +5 -2
  5. sglang/srt/disaggregation/decode.py +6 -1
  6. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
  7. sglang/srt/disaggregation/mooncake/conn.py +243 -135
  8. sglang/srt/disaggregation/prefill.py +2 -0
  9. sglang/srt/distributed/parallel_state.py +11 -9
  10. sglang/srt/entrypoints/context.py +244 -0
  11. sglang/srt/entrypoints/engine.py +4 -3
  12. sglang/srt/entrypoints/harmony_utils.py +370 -0
  13. sglang/srt/entrypoints/http_server.py +71 -0
  14. sglang/srt/entrypoints/openai/protocol.py +227 -1
  15. sglang/srt/entrypoints/openai/serving_chat.py +278 -42
  16. sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
  17. sglang/srt/entrypoints/openai/tool_server.py +174 -0
  18. sglang/srt/entrypoints/tool.py +87 -0
  19. sglang/srt/eplb/expert_location.py +5 -1
  20. sglang/srt/function_call/harmony_tool_parser.py +130 -0
  21. sglang/srt/hf_transformers_utils.py +30 -3
  22. sglang/srt/jinja_template_utils.py +8 -1
  23. sglang/srt/layers/attention/aiter_backend.py +5 -8
  24. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
  25. sglang/srt/layers/attention/triton_backend.py +85 -14
  26. sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
  27. sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
  28. sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
  29. sglang/srt/layers/attention/vision.py +13 -5
  30. sglang/srt/layers/communicator.py +21 -4
  31. sglang/srt/layers/dp_attention.py +12 -0
  32. sglang/srt/layers/linear.py +2 -7
  33. sglang/srt/layers/moe/cutlass_moe.py +20 -6
  34. sglang/srt/layers/moe/ep_moe/layer.py +77 -73
  35. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
  36. sglang/srt/layers/moe/fused_moe_triton/layer.py +416 -35
  37. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
  38. sglang/srt/layers/moe/topk.py +12 -3
  39. sglang/srt/layers/moe/utils.py +16 -0
  40. sglang/srt/layers/quantization/__init__.py +22 -0
  41. sglang/srt/layers/quantization/fp4.py +557 -0
  42. sglang/srt/layers/quantization/fp8.py +3 -6
  43. sglang/srt/layers/quantization/fp8_utils.py +29 -0
  44. sglang/srt/layers/quantization/modelopt_quant.py +259 -64
  45. sglang/srt/layers/quantization/mxfp4.py +651 -0
  46. sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
  47. sglang/srt/layers/quantization/quark/__init__.py +0 -0
  48. sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
  49. sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  50. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
  51. sglang/srt/layers/quantization/quark/utils.py +107 -0
  52. sglang/srt/layers/quantization/unquant.py +60 -6
  53. sglang/srt/layers/quantization/w4afp8.py +1 -1
  54. sglang/srt/layers/rotary_embedding.py +225 -1
  55. sglang/srt/layers/utils.py +9 -0
  56. sglang/srt/layers/vocab_parallel_embedding.py +8 -3
  57. sglang/srt/lora/lora_manager.py +70 -14
  58. sglang/srt/lora/lora_registry.py +3 -2
  59. sglang/srt/lora/mem_pool.py +43 -5
  60. sglang/srt/managers/cache_controller.py +55 -30
  61. sglang/srt/managers/detokenizer_manager.py +1 -1
  62. sglang/srt/managers/io_struct.py +15 -3
  63. sglang/srt/managers/mm_utils.py +5 -11
  64. sglang/srt/managers/schedule_batch.py +28 -7
  65. sglang/srt/managers/scheduler.py +26 -12
  66. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
  67. sglang/srt/managers/scheduler_recv_skipper.py +37 -0
  68. sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
  69. sglang/srt/managers/template_manager.py +35 -1
  70. sglang/srt/managers/tokenizer_manager.py +24 -6
  71. sglang/srt/managers/tp_worker.py +3 -0
  72. sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
  73. sglang/srt/mem_cache/hiradix_cache.py +53 -5
  74. sglang/srt/mem_cache/memory_pool_host.py +1 -1
  75. sglang/srt/mem_cache/multimodal_cache.py +33 -13
  76. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
  77. sglang/srt/model_executor/cuda_graph_runner.py +7 -6
  78. sglang/srt/model_executor/forward_batch_info.py +35 -14
  79. sglang/srt/model_executor/model_runner.py +19 -2
  80. sglang/srt/model_loader/weight_utils.py +10 -0
  81. sglang/srt/models/bailing_moe.py +425 -0
  82. sglang/srt/models/deepseek_v2.py +72 -33
  83. sglang/srt/models/ernie4.py +426 -0
  84. sglang/srt/models/ernie4_eagle.py +203 -0
  85. sglang/srt/models/gemma3n_mm.py +39 -0
  86. sglang/srt/models/glm4_moe.py +24 -12
  87. sglang/srt/models/gpt_oss.py +1134 -0
  88. sglang/srt/models/qwen2.py +6 -0
  89. sglang/srt/models/qwen2_moe.py +6 -0
  90. sglang/srt/models/qwen3_moe.py +32 -6
  91. sglang/srt/models/step3_vl.py +9 -0
  92. sglang/srt/models/transformers.py +2 -5
  93. sglang/srt/multimodal/processors/step3_vl.py +3 -1
  94. sglang/srt/reasoning_parser.py +18 -39
  95. sglang/srt/server_args.py +142 -7
  96. sglang/srt/two_batch_overlap.py +157 -5
  97. sglang/srt/utils.py +38 -2
  98. sglang/test/runners.py +2 -2
  99. sglang/test/test_utils.py +1 -1
  100. sglang/version.py +1 -1
  101. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/METADATA +16 -14
  102. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/RECORD +105 -84
  103. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/WHEEL +0 -0
  104. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/licenses/LICENSE +0 -0
  105. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/trtllm_mha_backend.py
@@ -0,0 +1,332 @@
+ from __future__ import annotations
+
+ """
+ Support attention backend for TRTLLM MHA kernels from flashinfer.
+ The kernel supports sm100 only, with sliding window and attention sink features.
+ """
+
+ from dataclasses import dataclass
+ from typing import TYPE_CHECKING, Optional
+
+ import torch
+
+ from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+ from sglang.srt.utils import is_flashinfer_available
+
+ if is_flashinfer_available():
+     import flashinfer
+
+ if TYPE_CHECKING:
+     from sglang.srt.layers.radix_attention import RadixAttention
+     from sglang.srt.model_executor.model_runner import ModelRunner
+     from sglang.srt.speculative.spec_info import SpecInfo
+
+ # Constants
+ DEFAULT_WORKSPACE_SIZE_MB = 128  # Memory workspace size in MB
+
+ # Reuse this workspace buffer across all TRTLLM MHA wrappers
+ global_workspace_buffer = None
+
+
+ @dataclass
+ class TRTLLMMHAMetadata:
+     # Sequence lengths for the forward batch
+     cache_seqlens_int32: torch.Tensor = None
+     # Maximum sequence length for query
+     max_seq_len_q: int = 1
+     # Maximum sequence length for key
+     max_seq_len_k: int = 0
+     # Cumulative sequence lengths for query
+     cu_seqlens_q: torch.Tensor = None
+     # Cumulative sequence lengths for key
+     cu_seqlens_k: torch.Tensor = None
+     # Page table, the index of KV Cache Tables/Blocks
+     page_table: torch.Tensor = None
+
+
+ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
+     """TRTLLM MHA attention kernel from flashinfer."""
+
+     def __init__(
+         self,
+         model_runner: ModelRunner,
+         skip_prefill: bool = False,
+         kv_indptr_buf: Optional[torch.Tensor] = None,
+         q_indptr_decode_buf: Optional[torch.Tensor] = None,
+     ):
+         super().__init__(model_runner, skip_prefill, kv_indptr_buf, q_indptr_decode_buf)
+
+         config = model_runner.model_config
+
+         # MHA-specific dimensions
+         self.max_context_len = model_runner.model_config.context_len
+         self.hidden_size = config.hidden_size
+
+         # Runtime parameters
+         self.data_type = model_runner.kv_cache_dtype
+         self.q_data_type = model_runner.dtype
+         self.page_size = model_runner.page_size
+         self.req_to_token = model_runner.req_to_token_pool.req_to_token
+         self.device = model_runner.device
+
+         # Workspace allocation
+         self.workspace_size = DEFAULT_WORKSPACE_SIZE_MB * 1024 * 1024
+         # Allocate buffers
+         global global_workspace_buffer
+         if global_workspace_buffer is None:
+             global_workspace_buffer = torch.empty(
+                 self.workspace_size,
+                 dtype=torch.uint8,
+                 device=model_runner.device,
+             )
+         self.workspace_buffer = global_workspace_buffer
+
+         # CUDA graph state
+         self.decode_cuda_graph_metadata = {}
+
+         # Forward metadata
+         self.forward_metadata: Optional[TRTLLMMHAMetadata] = None
+
+     def init_cuda_graph_state(
+         self,
+         max_bs: int,
+         max_num_tokens: int,
+         kv_indices_buf: Optional[torch.Tensor] = None,
+     ):
+         """Initialize CUDA graph state for TRTLLM MHA."""
+         self.decode_cuda_graph_metadata = {
+             "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
+             "page_table": torch.zeros(
+                 max_bs,
+                 (self.max_context_len + self.page_size - 1) // self.page_size,
+                 dtype=torch.int32,
+                 device=self.device,
+             ),
+             "strided_indices": torch.arange(
+                 0, self.max_context_len, self.page_size, device=self.device
+             ),
+         }
+
+     def init_forward_metadata_capture_cuda_graph(
+         self,
+         bs: int,
+         num_tokens: int,
+         req_pool_indices: torch.Tensor,
+         seq_lens: torch.Tensor,
+         encoder_lens: Optional[torch.Tensor],
+         forward_mode: ForwardMode,
+         spec_info: Optional[SpecInfo],
+     ):
+         """Initialize metadata for CUDA graph capture."""
+         metadata = TRTLLMMHAMetadata()
+
+         # Get sequence information
+         metadata.cache_seqlens_int32 = seq_lens[:bs].to(torch.int32)
+
+         # Precompute maximum sequence length
+         metadata.max_seq_len_k = self.max_context_len
+
+         # Precompute page table
+         metadata.page_table = self.decode_cuda_graph_metadata["page_table"][:bs, :]
+         self.decode_cuda_graph_metadata[bs] = metadata
+         self.forward_metadata = metadata
+
+     def init_forward_metadata_replay_cuda_graph(
+         self,
+         bs: int,
+         req_pool_indices: torch.Tensor,
+         seq_lens: torch.Tensor,
+         seq_lens_sum: int,
+         encoder_lens: Optional[torch.Tensor],
+         forward_mode: ForwardMode,
+         spec_info: Optional[SpecInfo],
+         seq_lens_cpu: Optional[torch.Tensor],
+     ):
+         """Replay CUDA graph with new inputs."""
+         seq_lens = seq_lens[:bs]
+         seq_lens_cpu = seq_lens_cpu[:bs]
+         req_pool_indices = req_pool_indices[:bs]
+         device = seq_lens.device
+         metadata = None
+
+         # Normal Decode
+         metadata = self.decode_cuda_graph_metadata[bs]
+         max_len = seq_lens_cpu.max().item()
+         max_seq_pages = (max_len + self.page_size - 1) // self.page_size
+         metadata.max_seq_len_k = self.max_context_len
+
+         metadata.cache_seqlens_int32.copy_(seq_lens)
+         page_indices = self.req_to_token[
+             req_pool_indices[:, None],
+             self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages][None, :],
+         ]
+         metadata.page_table[:, :max_seq_pages].copy_(page_indices // self.page_size)
+         self.forward_metadata = metadata
+
+     def get_cuda_graph_seq_len_fill_value(self) -> int:
+         """Get the fill value for sequence lengths in CUDA graph."""
+         return 1
+
+     def init_forward_metadata(self, forward_batch: ForwardBatch):
+         """Initialize the metadata for a forward pass."""
+
+         metadata = TRTLLMMHAMetadata()
+         seqlens_in_batch = forward_batch.seq_lens
+         batch_size = forward_batch.batch_size
+         device = seqlens_in_batch.device
+
+         if forward_batch.forward_mode.is_decode_or_idle():
+             # Normal Decode
+             metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+             metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
+             metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                 forward_batch.req_pool_indices, : metadata.max_seq_len_k
+             ]
+         else:
+             metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+             metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
+             metadata.cu_seqlens_k = torch.nn.functional.pad(
+                 torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+             )
+             metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                 forward_batch.req_pool_indices, : metadata.max_seq_len_k
+             ]
+
+             if any(forward_batch.extend_prefix_lens_cpu):
+                 extend_seq_lens = forward_batch.extend_seq_lens
+                 metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
+                 metadata.cu_seqlens_q = torch.nn.functional.pad(
+                     torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)
+                 )
+             else:
+                 metadata.max_seq_len_q = metadata.max_seq_len_k
+                 metadata.cu_seqlens_q = metadata.cu_seqlens_k
+
+         # Convert the page table to a strided format
+         if self.page_size > 1:
+             self.strided_indices = torch.arange(
+                 0, metadata.page_table.shape[1], self.page_size, device=self.device
+             )
+             metadata.page_table = (
+                 metadata.page_table[:, self.strided_indices] // self.page_size
+             )
+
+         self.forward_metadata = metadata
+
+     def forward_decode(
+         self,
+         q: torch.Tensor,
+         k: torch.Tensor,
+         v: torch.Tensor,
+         layer: RadixAttention,
+         forward_batch: ForwardBatch,
+         save_kv_cache: bool = True,
+         **kwargs,
+     ) -> torch.Tensor:
+         """Run forward for decode using TRTLLM MHA kernel."""
+         cache_loc = forward_batch.out_cache_loc
+         if save_kv_cache and k is not None:
+             forward_batch.token_to_kv_pool.set_kv_buffer(
+                 layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+             )
+
+         q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+         k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
+         # shape conversion:
+         # [num_pages, page_size, num_kv_heads, head_dim] -> [num_pages, num_kv_heads, page_size, head_dim]
+         k_cache = k_cache.view(
+             -1, self.page_size, layer.tp_k_head_num, layer.head_dim
+         ).permute(0, 2, 1, 3)
+         v_cache = v_cache.view(
+             -1, self.page_size, layer.tp_v_head_num, layer.head_dim
+         ).permute(0, 2, 1, 3)
+         kv_cache = (k_cache, v_cache)
+
+         # TODO: add support for quantization
+         q_scale = 1.0
+         k_scale = (
+             layer.k_scale_float
+             if getattr(layer, "k_scale_float", None) is not None
+             else 1.0
+         )
+         bmm1_scale = q_scale * k_scale * layer.scaling
+         bmm2_scale = 1.0
+         # sink: additional value per head in the denominator of the softmax.
+         attention_sink = kwargs.get("sinks", None)
+
+         # Call TRT-LLM kernel
+         # raw_out: like q, [bs, acc_q_len, num_q_heads, head_dim] but with output dtype
+         o = flashinfer.decode.trtllm_batch_decode_with_kv_cache(
+             query=q,
+             kv_cache=kv_cache,
+             workspace_buffer=self.workspace_buffer,
+             block_tables=self.forward_metadata.page_table,
+             seq_lens=self.forward_metadata.cache_seqlens_int32,
+             max_seq_len=self.forward_metadata.max_seq_len_k,
+             bmm1_scale=bmm1_scale,
+             bmm2_scale=bmm2_scale,
+             window_left=layer.sliding_window_size,
+             # TODO: add attention_sink operation or nvfp4 scale factor if needed
+             sinks=attention_sink,
+         )
+
+         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+
+     def forward_extend(
+         self,
+         q: torch.Tensor,
+         k: torch.Tensor,
+         v: torch.Tensor,
+         layer: RadixAttention,
+         forward_batch: ForwardBatch,
+         save_kv_cache=True,
+         **kwargs,
+     ):
+         cache_loc = forward_batch.out_cache_loc
+         if save_kv_cache and k is not None:
+             forward_batch.token_to_kv_pool.set_kv_buffer(
+                 layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+             )
+         q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+         # [num_pages, page_size, num_kv_heads, head_dim] -> [num_pages, num_kv_heads, page_size, head_dim]
+         k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
+         k_cache = k_cache.view(
+             -1, self.page_size, layer.tp_k_head_num, layer.head_dim
+         ).permute(0, 2, 1, 3)
+         v_cache = v_cache.view(
+             -1, self.page_size, layer.tp_v_head_num, layer.head_dim
+         ).permute(0, 2, 1, 3)
+         kv_cache = (k_cache, v_cache)
+
+         # sink: additional value per head in the denominator of the softmax.
+         attention_sink = kwargs.get("sinks", None)
+         # TODO: add support for quantization
+         q_scale = 1.0
+         k_scale = (
+             layer.k_scale_float
+             if getattr(layer, "k_scale_float", None) is not None
+             else 1.0
+         )
+         bmm1_scale = q_scale * k_scale * layer.scaling
+         bmm2_scale = 1.0
+
+         o = flashinfer.prefill.trtllm_batch_context_with_kv_cache(
+             query=q,
+             kv_cache=kv_cache,
+             workspace_buffer=self.workspace_buffer,
+             block_tables=self.forward_metadata.page_table,
+             seq_lens=self.forward_metadata.cache_seqlens_int32,
+             max_q_len=self.forward_metadata.max_seq_len_q,
+             max_kv_len=self.forward_metadata.max_seq_len_k,
+             bmm1_scale=bmm1_scale,
+             bmm2_scale=bmm2_scale,
+             batch_size=forward_batch.batch_size,
+             cum_seq_lens_q=self.forward_metadata.cu_seqlens_q,
+             cum_seq_lens_kv=self.forward_metadata.cu_seqlens_k,
+             window_left=layer.sliding_window_size,
+             # TODO: add attention_sink operation or nvfp4 scale factor if needed
+             sinks=attention_sink,
+         )
+
+         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
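
Note: a minimal standalone sketch (not part of the package) of the page-table arithmetic used by init_forward_metadata above. When page_size > 1, the per-token req_to_token mapping is sampled at every page_size-th slot and divided by page_size to yield per-page block indices; the shapes and values below are illustrative.

import torch

# Illustrative shapes: 2 requests, page_size 4, up to 16 token slots per request.
page_size = 4
req_to_token = torch.arange(2 * 16).view(2, 16)  # token slot -> KV-cache location

strided_indices = torch.arange(0, req_to_token.shape[1], page_size)  # [0, 4, 8, 12]
page_table = req_to_token[:, strided_indices] // page_size           # per-page block ids

print(page_table)
# tensor([[0, 1, 2, 3],
#         [4, 5, 6, 7]])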
sglang/srt/layers/attention/vision.py
@@ -11,6 +11,7 @@ import torch.nn as nn
  import torch.nn.functional as F
  from einops import rearrange
 
+ from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
  from sglang.srt.utils import is_cuda, print_info_once
 
  _is_cuda = is_cuda()
@@ -365,19 +366,20 @@ class VisionAttention(nn.Module):
          **kwargs,
      ):
          super().__init__()
-         world_size = parallel_state.get_tensor_model_parallel_world_size()
-         self.tp_size = world_size
-         self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
+         attn_tp_rank = get_attention_tp_rank()
+         attn_tp_size = get_attention_tp_size()
+         self.tp_size = attn_tp_size
+         self.tp_rank = attn_tp_rank
          self.dropout = dropout
          self.head_size = embed_dim // num_heads
          self.hidden_size_per_attention_head = dist_utils.divide(
              projection_size, num_heads
          )
          self.num_attention_heads_per_partition = dist_utils.divide(
-             num_dummy_heads + num_heads, world_size
+             num_dummy_heads + num_heads, self.tp_size
          )
          self.num_attention_kv_heads_per_partition = dist_utils.divide(
-             num_dummy_heads + num_heads, world_size
+             num_dummy_heads + num_heads, self.tp_size
          )
 
          self.q_size = self.num_attention_heads_per_partition * self.head_size
@@ -427,6 +429,8 @@ class VisionAttention(nn.Module):
                  total_num_kv_heads=num_dummy_heads + num_heads,
                  bias=qkv_bias,
                  quant_config=quant_config,
+                 tp_rank=self.tp_rank,
+                 tp_size=self.tp_size,
                  prefix=add_prefix("qkv_proj", prefix),
              )
          else:
@@ -435,6 +439,8 @@ class VisionAttention(nn.Module):
                  output_size=3 * self.dummy_dim,
                  bias=qkv_bias,
                  quant_config=quant_config,
+                 tp_rank=self.tp_rank,
+                 tp_size=self.tp_size,
                  prefix=add_prefix("qkv_proj", prefix),
              )
          self.proj = RowParallelLinear(
@@ -442,6 +448,8 @@ class VisionAttention(nn.Module):
              output_size=embed_dim,
              bias=proj_bias,
              quant_config=quant_config,
+             tp_rank=self.tp_rank,
+             tp_size=self.tp_size,
              prefix=add_prefix("proj", prefix),
          )
 
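
Note: the vision.py change above partitions the vision attention heads over the attention tensor-parallel group rather than the full TP world, and passes that rank/size pair down to the projection layers. The per-rank sizes follow from an exact integer division; a small standalone sketch (the divide helper below is an illustrative stand-in for dist_utils.divide, not the package's code):

def divide(numerator: int, denominator: int) -> int:
    # Megatron-style helper: the split must be exact.
    assert numerator % denominator == 0, f"{numerator} is not divisible by {denominator}"
    return numerator // denominator

embed_dim, num_heads, num_dummy_heads, attn_tp_size = 1024, 16, 0, 4
head_size = embed_dim // num_heads                                       # 64
heads_per_partition = divide(num_dummy_heads + num_heads, attn_tp_size)  # 4 heads per rank
q_size = heads_per_partition * head_size                                 # 256 output columns of qkv_proj per rank
print(head_size, heads_per_partition, q_size)  # 64 4 256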
sglang/srt/layers/communicator.py
@@ -27,6 +27,7 @@ from sglang.srt.layers.dp_attention import (
      attn_tp_all_gather_into_tensor,
      attn_tp_reduce_scatter_tensor,
      dp_gather_partial,
+     dp_reduce_scatter_tensor,
      dp_scatter,
      get_attention_dp_size,
      get_attention_tp_rank,
@@ -149,10 +150,13 @@ class LayerCommunicator:
          layer_scatter_modes: LayerScatterModes,
          input_layernorm: torch.nn.Module,
          post_attention_layernorm: torch.nn.Module,
+         # Reduce scatter requires skipping all-reduce in model code after MoE/MLP, so only enable for models which have that implemented. Remove flag once done for all models that use LayerCommunicator.
+         allow_reduce_scatter: bool = False,
      ):
          self.layer_scatter_modes = layer_scatter_modes
          self.input_layernorm = input_layernorm
          self.post_attention_layernorm = post_attention_layernorm
+         self.allow_reduce_scatter = allow_reduce_scatter
 
          self._context = CommunicateContext.init_new()
          self._communicate_simple_fn = CommunicateSimpleFn.get_fn(
@@ -239,6 +243,15 @@ class LayerCommunicator:
              residual=residual,
              forward_batch=forward_batch,
              context=self._context,
+             allow_reduce_scatter=self.allow_reduce_scatter,
+         )
+ 
+     def should_use_reduce_scatter(self, forward_batch: ForwardBatch):
+         return (
+             self.allow_reduce_scatter
+             and self._communicate_summable_tensor_pair_fn
+             is CommunicateSummableTensorPairFn._scatter_hidden_states
+             and forward_batch.dp_padding_mode.is_max_len()
          )
 
 
@@ -524,6 +537,7 @@ class CommunicateSummableTensorPairFn:
          residual: torch.Tensor,
          forward_batch: ForwardBatch,
          context: CommunicateContext,
+         **kwargs,
      ):
          return hidden_states, residual
 
@@ -533,15 +547,17 @@ class CommunicateSummableTensorPairFn:
          residual: torch.Tensor,
          forward_batch: ForwardBatch,
          context: CommunicateContext,
+         allow_reduce_scatter: bool = False,
      ):
-         # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
-         # important: forward_batch.gathered_buffer is used both after scatter and after gather.
-         # be careful about this!
          hidden_states, global_hidden_states = (
              forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
              hidden_states,
          )
-         dp_scatter(hidden_states, global_hidden_states, forward_batch)
+         if allow_reduce_scatter and forward_batch.dp_padding_mode.is_max_len():
+             # When using padding, all_reduce is skipped after MLP and MOE and reduce scatter is used here instead.
+             dp_reduce_scatter_tensor(hidden_states, global_hidden_states)
+         else:
+             dp_scatter(hidden_states, global_hidden_states, forward_batch)
          return hidden_states, residual
 
      @staticmethod
@@ -550,6 +566,7 @@ class CommunicateSummableTensorPairFn:
          residual: torch.Tensor,
          forward_batch: ForwardBatch,
          context: CommunicateContext,
+         **kwargs,
      ):
          hidden_states += residual
          residual = None
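
Note: the communicator.py change rests on a standard identity: a reduce-scatter equals an all-reduce followed by keeping only the local chunk. When every DP rank pads to the same length (dp_padding_mode.is_max_len()), the all-reduce after MLP/MoE can therefore be skipped and the reduction folded into the scatter. A single-process sketch of that identity with plain tensors, no process groups involved:

import torch

world_size = 4
torch.manual_seed(0)
# One partial tensor per rank, e.g. partial MLP outputs before any reduction.
partials = [torch.randn(8, 16) for _ in range(world_size)]

# all-reduce then scatter: every rank sums everything, then keeps its own chunk.
summed = torch.stack(partials).sum(dim=0)
scattered = summed.tensor_split(world_size, dim=0)

# reduce-scatter: each rank only ever materializes the sum of its own chunk.
reduce_scattered = [
    torch.stack([p.tensor_split(world_size, dim=0)[r] for p in partials]).sum(dim=0)
    for r in range(world_size)
]

for r in range(world_size):
    assert torch.allclose(scattered[r], reduce_scattered[r], atol=1e-6)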
sglang/srt/layers/dp_attention.py
@@ -12,6 +12,7 @@ import triton.language as tl
 
  from sglang.srt.distributed import (
      GroupCoordinator,
+     get_tensor_model_parallel_rank,
      get_tensor_model_parallel_world_size,
      get_tp_group,
      tensor_model_parallel_all_reduce,
@@ -355,6 +356,17 @@ def dp_scatter(
      )
 
 
+ def dp_reduce_scatter_tensor(output: torch.Tensor, input: torch.Tensor):
+     if get_tensor_model_parallel_world_size() == get_attention_dp_size():
+         get_tp_group().reduce_scatter_tensor(output, input)
+     else:
+         scattered_local_tokens = input.tensor_split(
+             get_tensor_model_parallel_world_size()
+         )[get_tensor_model_parallel_rank()]
+         get_tp_group().reduce_scatter_tensor(scattered_local_tokens, input)
+         get_attention_tp_group().all_gather_into_tensor(output, scattered_local_tokens)
+ 
+ 
  def attn_tp_reduce_scatter_tensor(output: torch.Tensor, input: torch.Tensor):
      return get_attention_tp_group().reduce_scatter_tensor(output, input)
 
sglang/srt/layers/linear.py
@@ -1191,11 +1191,6 @@ class RowParallelLinear(LinearBase):
                  else self.weight_loader
              ),
          )
-         if not reduce_results and (bias and not skip_bias_add):
-             raise ValueError(
-                 "When not reduce the results, adding bias to the "
-                 "results can lead to incorrect results"
-             )
 
          if bias:
              self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))
@@ -1282,7 +1277,7 @@ class RowParallelLinear(LinearBase):
          # It does not support additional parameters.
          param.load_row_parallel_weight(loaded_weight)
 
-     def forward(self, input_, can_fuse_mlp_allreduce=False):
+     def forward(self, input_, skip_all_reduce=False):
          if self.input_is_parallel:
              input_parallel = input_
          else:
@@ -1299,7 +1294,7 @@ class RowParallelLinear(LinearBase):
          with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
              output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
              sm.tag(output_parallel)
-         if self.reduce_results and self.tp_size > 1 and not can_fuse_mlp_allreduce:
+         if self.reduce_results and self.tp_size > 1 and not skip_all_reduce:
              output = tensor_model_parallel_all_reduce(output_parallel)
          else:
              output = output_parallel
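
Note: in RowParallelLinear the weight is split along the input dimension, so each TP rank produces a partial output that must be summed across ranks; the renamed skip_all_reduce flag lets a caller omit that all-reduce when it performs its own reduction (for example the reduce-scatter path above). A single-process sketch of why the partial products sum to the full result:

import torch

torch.manual_seed(0)
x = torch.randn(4, 8)          # [tokens, input_size]
w = torch.randn(8, 16)         # [input_size, output_size]
x0, x1 = x.chunk(2, dim=1)     # each "rank" holds half of the input features
w0, w1 = w.chunk(2, dim=0)     # and the matching rows of the weight
partial0, partial1 = x0 @ w0, x1 @ w1                          # per-rank partial outputs
assert torch.allclose(partial0 + partial1, x @ w, atol=1e-5)   # the sum is what all-reduce would produce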
sglang/srt/layers/moe/cutlass_moe.py
@@ -9,6 +9,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
  import torch
 
  from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams
+ from sglang.srt.layers.utils import is_sm100_supported
  from sglang.srt.utils import is_cuda
 
  _is_cuda = is_cuda()
@@ -123,6 +124,7 @@ def cutlass_fused_experts_fp8(
 
      if is_cuda:
          from sglang.srt.layers.quantization.fp8_kernel import (
+             per_token_group_quant_fp8_hopper_moe_mn_major,
              sglang_per_token_group_quant_fp8,
          )
 
@@ -133,9 +135,7 @@ def cutlass_fused_experts_fp8(
      n = w2_q.size(1)
 
      topk = topk_ids.size(1)
- 
-     a_q, a1_scale = sglang_per_token_group_quant_fp8(a, 128)
-     device = a_q.device
+     device = a.device
 
      a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
      c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
@@ -152,8 +152,16 @@
          k,
      )
 
-     rep_a_q = shuffle_rows(a_q, a_map, (m * topk, k))
-     rep_a1_scales = shuffle_rows(a1_scale, a_map, (m * topk, int(k / 128)))
+     if is_sm100_supported():
+         a_q, a1_scale = sglang_per_token_group_quant_fp8(a, 128)
+         rep_a_q = shuffle_rows(a_q, a_map, (m * topk, k))
+         rep_a1_scales = shuffle_rows(a1_scale, a_map, (m * topk, int(k / 128)))
+     else:
+         rep_a = shuffle_rows(a, a_map, (m * topk, k))
+         rep_a_q, rep_a1_scales = per_token_group_quant_fp8_hopper_moe_mn_major(
+             rep_a, expert_offsets, problem_sizes1, 128
+         )
+         w1_scale = w1_scale.contiguous()
 
      c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
      c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype)
@@ -185,7 +193,13 @@
      intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype)
      silu_and_mul(c1, intermediate)
 
-     intemediate_q, a2_scale = sglang_per_token_group_quant_fp8(intermediate, 128)
+     if is_sm100_supported():
+         intemediate_q, a2_scale = sglang_per_token_group_quant_fp8(intermediate, 128)
+     else:
+         intemediate_q, a2_scale = per_token_group_quant_fp8_hopper_moe_mn_major(
+             intermediate, expert_offsets, problem_sizes2, 128
+         )
+         w2_scale = w2_scale.contiguous()
 
      fp8_blockwise_scaled_grouped_mm(
          c2,
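
Note: both quantization helpers used in the cutlass_moe.py diff produce one FP8 scale per 128-element group along the hidden dimension, which is why the activation scales have shape (m * topk, k / 128). A rough reference of that layout (an assumption-level sketch, not the package's kernel; the Hopper MN-major variant additionally reorders the scales for the grouped GEMM):

import torch

def per_token_group_quant_fp8_reference(a: torch.Tensor, group_size: int = 128):
    """Quantize [m, k] activations to fp8 with one scale per (token, 128-wide group)."""
    m, k = a.shape
    assert k % group_size == 0
    groups = a.float().view(m, k // group_size, group_size)
    amax = groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    scales = amax / 448.0  # 448 is the largest finite value of float8_e4m3fn
    q = (groups / scales).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
    return q.view(m, k), scales.view(m, k // group_size)

a = torch.randn(4, 256)
a_q, a_scale = per_token_group_quant_fp8_reference(a, 128)
print(a_q.shape, a_scale.shape)  # torch.Size([4, 256]) torch.Size([4, 2])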