sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
Files changed (180)
  1. sglang/bench_offline_throughput.py +10 -8
  2. sglang/bench_one_batch.py +7 -6
  3. sglang/bench_one_batch_server.py +157 -21
  4. sglang/bench_serving.py +137 -59
  5. sglang/compile_deep_gemm.py +5 -5
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +78 -78
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +2 -2
  11. sglang/srt/configs/model_config.py +40 -28
  12. sglang/srt/constrained/base_grammar_backend.py +55 -72
  13. sglang/srt/constrained/llguidance_backend.py +25 -21
  14. sglang/srt/constrained/outlines_backend.py +27 -26
  15. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  16. sglang/srt/constrained/xgrammar_backend.py +69 -43
  17. sglang/srt/conversation.py +49 -44
  18. sglang/srt/disaggregation/base/conn.py +1 -0
  19. sglang/srt/disaggregation/decode.py +129 -135
  20. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
  21. sglang/srt/disaggregation/fake/conn.py +3 -13
  22. sglang/srt/disaggregation/kv_events.py +357 -0
  23. sglang/srt/disaggregation/mini_lb.py +57 -24
  24. sglang/srt/disaggregation/mooncake/conn.py +238 -122
  25. sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
  26. sglang/srt/disaggregation/nixl/conn.py +10 -19
  27. sglang/srt/disaggregation/prefill.py +132 -47
  28. sglang/srt/disaggregation/utils.py +123 -6
  29. sglang/srt/distributed/utils.py +3 -3
  30. sglang/srt/entrypoints/EngineBase.py +5 -0
  31. sglang/srt/entrypoints/engine.py +44 -9
  32. sglang/srt/entrypoints/http_server.py +23 -6
  33. sglang/srt/entrypoints/http_server_engine.py +5 -2
  34. sglang/srt/function_call/base_format_detector.py +250 -0
  35. sglang/srt/function_call/core_types.py +34 -0
  36. sglang/srt/function_call/deepseekv3_detector.py +157 -0
  37. sglang/srt/function_call/ebnf_composer.py +234 -0
  38. sglang/srt/function_call/function_call_parser.py +175 -0
  39. sglang/srt/function_call/llama32_detector.py +74 -0
  40. sglang/srt/function_call/mistral_detector.py +84 -0
  41. sglang/srt/function_call/pythonic_detector.py +163 -0
  42. sglang/srt/function_call/qwen25_detector.py +67 -0
  43. sglang/srt/function_call/utils.py +35 -0
  44. sglang/srt/hf_transformers_utils.py +46 -7
  45. sglang/srt/layers/attention/aiter_backend.py +513 -0
  46. sglang/srt/layers/attention/flashattention_backend.py +64 -18
  47. sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
  48. sglang/srt/layers/attention/flashmla_backend.py +340 -78
  49. sglang/srt/layers/attention/triton_backend.py +3 -0
  50. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  51. sglang/srt/layers/attention/utils.py +6 -4
  52. sglang/srt/layers/attention/vision.py +1 -1
  53. sglang/srt/layers/communicator.py +451 -0
  54. sglang/srt/layers/dp_attention.py +61 -21
  55. sglang/srt/layers/layernorm.py +1 -1
  56. sglang/srt/layers/logits_processor.py +46 -11
  57. sglang/srt/layers/moe/cutlass_moe.py +207 -0
  58. sglang/srt/layers/moe/ep_moe/kernels.py +34 -12
  59. sglang/srt/layers/moe/ep_moe/layer.py +105 -51
  60. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
  63. sglang/srt/layers/moe/topk.py +67 -10
  64. sglang/srt/layers/multimodal.py +70 -0
  65. sglang/srt/layers/quantization/__init__.py +8 -3
  66. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  67. sglang/srt/layers/quantization/deep_gemm.py +77 -74
  68. sglang/srt/layers/quantization/fp8.py +92 -2
  69. sglang/srt/layers/quantization/fp8_kernel.py +3 -3
  70. sglang/srt/layers/quantization/fp8_utils.py +6 -0
  71. sglang/srt/layers/quantization/gptq.py +298 -6
  72. sglang/srt/layers/quantization/int8_kernel.py +20 -7
  73. sglang/srt/layers/quantization/qoq.py +244 -0
  74. sglang/srt/layers/sampler.py +0 -4
  75. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  76. sglang/srt/lora/lora_manager.py +2 -4
  77. sglang/srt/lora/mem_pool.py +4 -4
  78. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  79. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  80. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  81. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  82. sglang/srt/lora/utils.py +1 -1
  83. sglang/srt/managers/data_parallel_controller.py +3 -3
  84. sglang/srt/managers/deepseek_eplb.py +278 -0
  85. sglang/srt/managers/detokenizer_manager.py +21 -8
  86. sglang/srt/managers/eplb_manager.py +55 -0
  87. sglang/srt/managers/expert_distribution.py +704 -56
  88. sglang/srt/managers/expert_location.py +394 -0
  89. sglang/srt/managers/expert_location_dispatch.py +91 -0
  90. sglang/srt/managers/io_struct.py +19 -4
  91. sglang/srt/managers/mm_utils.py +294 -140
  92. sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
  93. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
  94. sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
  95. sglang/srt/managers/multimodal_processors/internvl.py +14 -5
  96. sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
  97. sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
  98. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  99. sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
  100. sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
  101. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  102. sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
  103. sglang/srt/managers/schedule_batch.py +122 -42
  104. sglang/srt/managers/schedule_policy.py +1 -5
  105. sglang/srt/managers/scheduler.py +205 -138
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/tokenizer_manager.py +232 -58
  109. sglang/srt/managers/tp_worker.py +12 -9
  110. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  111. sglang/srt/mem_cache/base_prefix_cache.py +3 -0
  112. sglang/srt/mem_cache/chunk_cache.py +3 -1
  113. sglang/srt/mem_cache/hiradix_cache.py +4 -4
  114. sglang/srt/mem_cache/memory_pool.py +76 -52
  115. sglang/srt/mem_cache/multimodal_cache.py +45 -0
  116. sglang/srt/mem_cache/radix_cache.py +58 -5
  117. sglang/srt/metrics/collector.py +314 -39
  118. sglang/srt/mm_utils.py +10 -0
  119. sglang/srt/model_executor/cuda_graph_runner.py +29 -19
  120. sglang/srt/model_executor/expert_location_updater.py +422 -0
  121. sglang/srt/model_executor/forward_batch_info.py +5 -1
  122. sglang/srt/model_executor/model_runner.py +163 -68
  123. sglang/srt/model_loader/loader.py +10 -6
  124. sglang/srt/models/clip.py +5 -1
  125. sglang/srt/models/deepseek_janus_pro.py +2 -2
  126. sglang/srt/models/deepseek_v2.py +308 -351
  127. sglang/srt/models/exaone.py +8 -3
  128. sglang/srt/models/gemma3_mm.py +70 -33
  129. sglang/srt/models/llama.py +2 -0
  130. sglang/srt/models/llama4.py +15 -8
  131. sglang/srt/models/llava.py +258 -7
  132. sglang/srt/models/mimo_mtp.py +220 -0
  133. sglang/srt/models/minicpmo.py +5 -12
  134. sglang/srt/models/mistral.py +71 -1
  135. sglang/srt/models/mixtral.py +98 -34
  136. sglang/srt/models/mllama.py +3 -3
  137. sglang/srt/models/pixtral.py +467 -0
  138. sglang/srt/models/qwen2.py +95 -26
  139. sglang/srt/models/qwen2_5_vl.py +8 -0
  140. sglang/srt/models/qwen2_moe.py +330 -60
  141. sglang/srt/models/qwen2_vl.py +6 -0
  142. sglang/srt/models/qwen3.py +52 -10
  143. sglang/srt/models/qwen3_moe.py +411 -48
  144. sglang/srt/models/roberta.py +1 -1
  145. sglang/srt/models/siglip.py +294 -0
  146. sglang/srt/models/torch_native_llama.py +1 -1
  147. sglang/srt/openai_api/adapter.py +58 -20
  148. sglang/srt/openai_api/protocol.py +6 -8
  149. sglang/srt/operations.py +154 -0
  150. sglang/srt/operations_strategy.py +31 -0
  151. sglang/srt/reasoning_parser.py +3 -3
  152. sglang/srt/sampling/custom_logit_processor.py +18 -3
  153. sglang/srt/sampling/sampling_batch_info.py +4 -56
  154. sglang/srt/sampling/sampling_params.py +2 -2
  155. sglang/srt/server_args.py +162 -22
  156. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  157. sglang/srt/speculative/eagle_utils.py +138 -7
  158. sglang/srt/speculative/eagle_worker.py +69 -21
  159. sglang/srt/utils.py +74 -17
  160. sglang/test/few_shot_gsm8k.py +2 -2
  161. sglang/test/few_shot_gsm8k_engine.py +2 -2
  162. sglang/test/run_eval.py +2 -2
  163. sglang/test/runners.py +8 -1
  164. sglang/test/send_one.py +13 -3
  165. sglang/test/simple_eval_common.py +1 -1
  166. sglang/test/simple_eval_humaneval.py +1 -1
  167. sglang/test/test_cutlass_moe.py +278 -0
  168. sglang/test/test_programs.py +5 -5
  169. sglang/test/test_utils.py +55 -14
  170. sglang/utils.py +3 -3
  171. sglang/version.py +1 -1
  172. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +23 -13
  173. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +178 -149
  174. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
  175. sglang/srt/function_call_parser.py +0 -858
  176. sglang/srt/platforms/interface.py +0 -371
  177. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  178. /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
  179. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ b/sglang/srt/layers/communicator.py
@@ -0,0 +1,451 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from dataclasses import dataclass
+from enum import Enum, auto
+from typing import Dict, Optional, Tuple
+
+import torch.distributed
+
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
+)
+from sglang.srt.layers.dp_attention import (
+    attn_tp_all_gather,
+    attn_tp_reduce_scatter,
+    dp_gather_partial,
+    dp_scatter,
+    get_attention_tp_rank,
+    get_attention_tp_size,
+    get_local_attention_dp_size,
+)
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+
+
+class ScatterMode(Enum):
+    SCATTERED = auto()
+    TP_ATTN_FULL = auto()
+    FULL = auto()
+
+
+@dataclass
+class _LayerModeComputationContext:
+    num_layers: int
+    layer_id: int
+    is_layer_sparse: bool
+    is_previous_layer_sparse: Optional[bool]
+
+    def previous_layer(self):
+        assert self.is_previous_layer_sparse is not None
+        return _LayerModeComputationContext(
+            layer_id=self.layer_id - 1,
+            is_layer_sparse=self.is_previous_layer_sparse,
+            is_previous_layer_sparse=None,
+            num_layers=self.num_layers,
+        )
+
+
+@dataclass
+class LayerScatterModes:
+    layer_input_mode: ScatterMode
+    attn_mode: ScatterMode
+    # Can be further split into e.g. mlp_input_mode and mlp_output_mode if needed
+    mlp_mode: ScatterMode
+    middle_residual_mode: ScatterMode
+    layer_output_mode: ScatterMode
+
+    @classmethod
+    def init_new(cls, **kwargs):
+        context = _LayerModeComputationContext(**kwargs)
+        return cls(
+            layer_input_mode=cls._compute_layer_input_mode(context),
+            attn_mode=ScatterMode.TP_ATTN_FULL,
+            mlp_mode=cls._compute_mlp_mode(context),
+            middle_residual_mode=cls._compute_middle_residual_mode(context),
+            layer_output_mode=cls._compute_layer_output_mode(context),
+        )
+
+    @classmethod
+    def _compute_layer_input_mode(cls, context: _LayerModeComputationContext):
+        if context.layer_id == 0:
+            return ScatterMode.TP_ATTN_FULL
+        return cls._compute_layer_output_mode(context.previous_layer())
+
+    @classmethod
+    def _compute_mlp_mode(cls, context: _LayerModeComputationContext):
+        if context.is_layer_sparse:
+            return (
+                ScatterMode.SCATTERED
+                if global_server_args_dict["enable_deepep_moe"]
+                else ScatterMode.FULL
+            )
+        else:
+            return (
+                ScatterMode.SCATTERED
+                if enable_moe_dense_fully_dp()
+                else ScatterMode.FULL
+            )
+
+    @classmethod
+    def _compute_middle_residual_mode(cls, context: _LayerModeComputationContext):
+        mlp_mode = cls._compute_mlp_mode(context)
+        if mlp_mode == ScatterMode.SCATTERED:
+            return ScatterMode.SCATTERED
+        if mlp_mode == ScatterMode.FULL:
+            return ScatterMode.TP_ATTN_FULL
+        raise NotImplementedError
+
+    @classmethod
+    def _compute_layer_output_mode(cls, context: _LayerModeComputationContext):
+        mlp_mode = cls._compute_mlp_mode(context)
+        if context.layer_id == context.num_layers - 1:
+            return ScatterMode.TP_ATTN_FULL
+        if mlp_mode == ScatterMode.SCATTERED:
+            return ScatterMode.SCATTERED
+        if mlp_mode == ScatterMode.FULL:
+            return ScatterMode.TP_ATTN_FULL
+        raise NotImplementedError
+
+
+def enable_moe_dense_fully_dp():
+    return global_server_args_dict["moe_dense_tp_size"] == 1
+
+
+class LayerCommunicator:
+    def __init__(
+        self,
+        layer_scatter_modes: LayerScatterModes,
+        input_layernorm: torch.nn.Module,
+        post_attention_layernorm: torch.nn.Module,
+    ):
+        self.layer_scatter_modes = layer_scatter_modes
+        self.input_layernorm = input_layernorm
+        self.post_attention_layernorm = post_attention_layernorm
+
+        self.attn_tp_rank = get_attention_tp_rank()
+        self.attn_tp_size = get_attention_tp_size()
+        self.local_attn_dp_size = get_local_attention_dp_size()
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.process_group_sizes = {
+            ScatterMode.SCATTERED: 1,
+            ScatterMode.TP_ATTN_FULL: self.attn_tp_size,
+            ScatterMode.FULL: self.tp_size,
+        }
+
+    def prepare_attn(
+        self,
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ):
+        if hidden_states.shape[0] == 0:
+            residual = hidden_states
+        else:
+            if residual is None:
+                residual = hidden_states
+                hidden_states = self.input_layernorm(hidden_states)
+            else:
+                hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+        hidden_states = _communicate_simple(
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+            input_mode=self.layer_scatter_modes.layer_input_mode,
+            output_mode=self.layer_scatter_modes.attn_mode,
+            context=self._compute_context(forward_batch),
+        )
+
+        return hidden_states, residual
+
+    def prepare_mlp(
+        self,
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ):
+        return _communicate_with_all_reduce_and_layer_norm(
+            hidden_states=hidden_states,
+            residual=residual,
+            forward_batch=forward_batch,
+            hidden_states_input_mode=self.layer_scatter_modes.attn_mode,
+            residual_input_mode=self.layer_scatter_modes.layer_input_mode,
+            hidden_states_output_mode=self.layer_scatter_modes.mlp_mode,
+            residual_output_mode=self.layer_scatter_modes.middle_residual_mode,
+            layernorm=self.post_attention_layernorm,
+            context=self._compute_context(forward_batch),
+        )
+
+    def postprocess_layer(
+        self,
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ):
+        return _communicate_summable_tensor_pair(
+            hidden_states=hidden_states,
+            residual=residual,
+            forward_batch=forward_batch,
+            hidden_states_input_mode=self.layer_scatter_modes.mlp_mode,
+            residual_input_mode=self.layer_scatter_modes.middle_residual_mode,
+            output_mode=self.layer_scatter_modes.layer_output_mode,
+            context=self._compute_context(forward_batch),
+        )
+
+    def _compute_context(self, forward_batch: ForwardBatch):
+        return _Context(
+            num_tokens_of_mode=_compute_num_tokens_of_mode(
+                forward_batch,
+                attn_tp_rank=self.attn_tp_rank,
+                attn_tp_size=self.attn_tp_size,
+            ),
+            process_group_sizes=self.process_group_sizes,
+            attn_tp_rank=self.attn_tp_rank,
+            attn_tp_size=self.attn_tp_size,
+            local_attn_dp_size=self.local_attn_dp_size,
+            tp_size=self.tp_size,
+        )
+
+
+def _compute_num_tokens_of_mode(
+    forward_batch: ForwardBatch, attn_tp_rank: int, attn_tp_size: int
+):
+    tp_attn_full_num_tokens = forward_batch.input_ids.shape[0]
+    return {
+        ScatterMode.SCATTERED: _torch_tensor_split_len(
+            tp_attn_full_num_tokens, attn_tp_size, attn_tp_rank
+        ),
+        ScatterMode.TP_ATTN_FULL: tp_attn_full_num_tokens,
+        ScatterMode.FULL: (
+            forward_batch.gathered_buffer.shape[0]
+            if global_server_args_dict["enable_dp_attention"]
+            else forward_batch.input_ids.shape[0]
+        ),
+    }
+
+
+def _torch_tensor_split_len(tensor_len: int, n: int, output_index: int):
+    if output_index < int(tensor_len % n):
+        return int(tensor_len / n) + 1
+    else:
+        return int(tensor_len / n)
+
+
+@dataclass
+class _Context:
+    num_tokens_of_mode: Dict["ScatterMode", int]
+    process_group_sizes: Dict["ScatterMode", int]
+    attn_tp_rank: int
+    attn_tp_size: int
+    local_attn_dp_size: int
+    tp_size: int
+
+    def is_same_group_size(self, a: "ScatterMode", b: "ScatterMode"):
+        return self.process_group_sizes[a] == self.process_group_sizes[b]
+
+    def check_shape(self, x: torch.Tensor, mode: ScatterMode):
+        if x is None:
+            return
+
+        actual_num_tokens = x.shape[0]
+        expect_num_tokens = self.num_tokens_of_mode[mode]
+        assert (
+            actual_num_tokens == expect_num_tokens
+        ), f"{actual_num_tokens=} {expect_num_tokens=} {mode=} {x.shape=} {self.num_tokens_of_mode=} {self.process_group_sizes=}"
+        return x
+
+    def check_shapes(
+        self, xs: Tuple[torch.Tensor, ...], modes: Tuple[ScatterMode, ...]
+    ) -> Tuple[torch.Tensor, ...]:
+        return tuple(
+            [self.check_shape(x, mode) for x, mode in zip(xs, modes, strict=True)]
+        )
+
+
+def _communicate_simple(
+    hidden_states: torch.Tensor,
+    forward_batch: ForwardBatch,
+    input_mode: ScatterMode,
+    output_mode: ScatterMode,
+    context: _Context,
+) -> torch.Tensor:
+    def _inner():
+        nonlocal hidden_states
+
+        if context.is_same_group_size(input_mode, output_mode):
+            return hidden_states
+
+        if (input_mode == ScatterMode.SCATTERED) and (
+            output_mode == ScatterMode.TP_ATTN_FULL
+        ):
+            hidden_states, local_hidden_states = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                hidden_states,
+            )
+            attn_tp_all_gather(
+                list(hidden_states.tensor_split(context.attn_tp_size)),
+                local_hidden_states,
+            )
+            return hidden_states
+
+        raise NotImplementedError(f"{input_mode=} {output_mode=}")
+
+    context.check_shape(hidden_states, input_mode)
+    return context.check_shape(_inner(), output_mode)
+
+
+def _communicate_with_all_reduce_and_layer_norm(
+    hidden_states: torch.Tensor,
+    residual: torch.Tensor,
+    hidden_states_input_mode: ScatterMode,
+    residual_input_mode: ScatterMode,
+    hidden_states_output_mode: ScatterMode,
+    residual_output_mode: ScatterMode,
+    forward_batch: ForwardBatch,
+    layernorm: torch.nn.Module,
+    context: _Context,
+):
+    """Besides communication, needs to
+    1. All reduce in tp_attn_group on hidden_states
+    2. Apply layer norm
+    """
+
+    def _inner():
+        nonlocal hidden_states, residual
+
+        if (
+            context.is_same_group_size(
+                hidden_states_input_mode, hidden_states_output_mode
+            )
+            and context.is_same_group_size(residual_input_mode, residual_output_mode)
+            and context.attn_tp_size == 1
+        ):
+            # TODO move these `if shape != 0` into LayerNorm itself
+            if hidden_states.shape[0] != 0:
+                hidden_states, residual = layernorm(hidden_states, residual)
+            return hidden_states, residual
+
+        if (
+            (hidden_states_input_mode == ScatterMode.TP_ATTN_FULL)
+            and (residual_input_mode == ScatterMode.TP_ATTN_FULL)
+            and (hidden_states_output_mode == ScatterMode.FULL)
+            and (residual_output_mode == ScatterMode.TP_ATTN_FULL)
+        ):
+            if context.local_attn_dp_size != 1:
+                if context.attn_tp_rank == 0:
+                    hidden_states += residual
+                hidden_states, local_hidden_states = (
+                    forward_batch.gathered_buffer,
+                    hidden_states,
+                )
+                dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
+                dp_scatter(residual, hidden_states, forward_batch)
+                if hidden_states.shape[0] != 0:
+                    hidden_states = layernorm(hidden_states)
+            else:
+                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
+                hidden_states, residual = layernorm(hidden_states, residual)
+            return hidden_states, residual
+
+        if (
+            (hidden_states_input_mode == ScatterMode.TP_ATTN_FULL)
+            and (
+                residual_input_mode in [ScatterMode.SCATTERED, ScatterMode.TP_ATTN_FULL]
+            )
+            and (hidden_states_output_mode == ScatterMode.SCATTERED)
+            and (residual_output_mode == ScatterMode.SCATTERED)
+        ):
+            tensor_list = list(hidden_states.tensor_split(context.attn_tp_size))
+            hidden_states = tensor_list[context.attn_tp_rank]
+            attn_tp_reduce_scatter(hidden_states, tensor_list)
+            if residual_input_mode == ScatterMode.TP_ATTN_FULL:
+                residual = residual.tensor_split(context.attn_tp_size)[
+                    context.attn_tp_rank
+                ]
+            if hidden_states.shape[0] != 0:
+                hidden_states, residual = layernorm(hidden_states, residual)
+            return hidden_states, residual
+
+        raise NotImplementedError(
+            f"{hidden_states_input_mode=} {residual_input_mode=} {hidden_states_output_mode=} {residual_output_mode=}"
+        )
+
+    context.check_shapes(
+        (hidden_states, residual), (hidden_states_input_mode, residual_input_mode)
+    )
+    return context.check_shapes(
+        _inner(), (hidden_states_output_mode, residual_output_mode)
+    )
+
+
+def _communicate_summable_tensor_pair(
+    hidden_states: torch.Tensor,
+    residual: torch.Tensor,
+    forward_batch: ForwardBatch,
+    hidden_states_input_mode: ScatterMode,
+    residual_input_mode: ScatterMode,
+    output_mode: ScatterMode,
+    context: _Context,
+):
+    """It is allowed to make (hidden_states, residual) := (hidden_states + residual, None) if needed."""
+
+    def _inner():
+        nonlocal hidden_states, residual
+
+        if context.is_same_group_size(
+            hidden_states_input_mode, output_mode
+        ) and context.is_same_group_size(residual_input_mode, output_mode):
+            return hidden_states, residual
+
+        if (
+            (hidden_states_input_mode == ScatterMode.FULL)
+            and (residual_input_mode == ScatterMode.TP_ATTN_FULL)
+            and (output_mode == ScatterMode.TP_ATTN_FULL)
+        ):
+            # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
+            # Important: forward_batch.gathered_buffer is used both after scatter and after gather.
+            # Be careful about this!
+            hidden_states, global_hidden_states = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                hidden_states,
+            )
+            dp_scatter(hidden_states, global_hidden_states, forward_batch)
+            return hidden_states, residual
+
+        if (
+            (hidden_states_input_mode == ScatterMode.SCATTERED)
+            and (residual_input_mode == ScatterMode.SCATTERED)
+            and (output_mode == ScatterMode.TP_ATTN_FULL)
+        ):
+            hidden_states += residual
+            residual = None
+            hidden_states, local_hidden_states = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                hidden_states,
+            )
+            attn_tp_all_gather(
+                list(hidden_states.tensor_split(context.attn_tp_size)),
+                local_hidden_states,
+            )
+            return hidden_states, residual
+
+        raise NotImplementedError(
+            f"{hidden_states_input_mode=} {residual_input_mode=} {output_mode=}"
+        )
+
+    context.check_shapes(
+        (hidden_states, residual), (hidden_states_input_mode, residual_input_mode)
+    )
+    return context.check_shapes(_inner(), (output_mode, output_mode))
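The helper `_torch_tensor_split_len` above reproduces the per-section lengths of `torch.tensor_split`, which is what lets `_Context.check_shape` validate SCATTERED tensors without materializing a split. A minimal standalone check of that arithmetic (assumes only PyTorch; the helper is transcribed from the diff above, renamed without the leading underscore, with `//` standing in for `int(x / n)`, which is equivalent for the non-negative sizes used here):

import torch

def torch_tensor_split_len(tensor_len: int, n: int, output_index: int) -> int:
    # torch.tensor_split gives the first (tensor_len % n) sections one extra element.
    if output_index < tensor_len % n:
        return tensor_len // n + 1
    return tensor_len // n

x = torch.arange(10)
expected = [t.shape[0] for t in x.tensor_split(4)]               # [3, 3, 2, 2]
computed = [torch_tensor_split_len(10, 4, i) for i in range(4)]  # [3, 3, 2, 2]
assert expected == computed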
--- a/sglang/srt/layers/dp_attention.py
+++ b/sglang/srt/layers/dp_attention.py
@@ -24,8 +24,10 @@ if TYPE_CHECKING:
 _ATTN_TP_GROUP = None
 _ATTN_TP_RANK = None
 _ATTN_TP_SIZE = None
-_DP_RANK = None
-_DP_SIZE = None
+_ATTN_DP_RANK = None
+_ATTN_DP_SIZE = None
+_LOCAL_ATTN_DP_SIZE = None
+_LOCAL_ATTN_DP_RANK = None
 
 
 def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
@@ -33,9 +35,27 @@ def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
         return tp_rank, tp_size, 0
 
     attn_tp_size = tp_size // dp_size
-    dp_rank = tp_rank // attn_tp_size
+    attn_dp_rank = tp_rank // attn_tp_size
     attn_tp_rank = tp_rank % attn_tp_size
-    return attn_tp_rank, attn_tp_size, dp_rank
+
+    return attn_tp_rank, attn_tp_size, attn_dp_rank
+
+
+def compute_dp_attention_local_info(
+    enable_dp_attention, tp_rank, tp_size, dp_size, moe_dense_tp_size
+):
+    if not enable_dp_attention:
+        return tp_rank, tp_size, 0
+
+    local_tp_size = moe_dense_tp_size if moe_dense_tp_size else tp_size
+    local_tp_rank = tp_rank % local_tp_size
+    local_dp_size = max(1, dp_size // (tp_size // local_tp_size))
+
+    local_attn_tp_size = local_tp_size // local_dp_size
+    local_attn_dp_rank = local_tp_rank // local_attn_tp_size
+    local_attn_tp_rank = local_tp_rank % local_attn_tp_size
+
+    return local_attn_tp_rank, local_attn_tp_size, local_attn_dp_rank
 
 
 def initialize_dp_attention(
@@ -43,22 +63,32 @@
     tp_rank: int,
     tp_size: int,
     dp_size: int,
+    moe_dense_tp_size: int,
     pp_size: int,
 ):
-    global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE
+    global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK, _ATTN_DP_SIZE
+    global _LOCAL_ATTN_DP_SIZE, _LOCAL_ATTN_DP_RANK
 
     from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP
 
-    _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK = compute_dp_attention_world_info(
+    _ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK = compute_dp_attention_world_info(
         enable_dp_attention, tp_rank, tp_size, dp_size
     )
+    _, _, _LOCAL_ATTN_DP_RANK = compute_dp_attention_local_info(
+        enable_dp_attention, tp_rank, tp_size, dp_size, moe_dense_tp_size
+    )
 
     if enable_dp_attention:
         local_rank = tp_rank % (tp_size // dp_size)
-        _DP_SIZE = dp_size
+        _ATTN_DP_SIZE = dp_size
+        if moe_dense_tp_size is None:
+            _LOCAL_ATTN_DP_SIZE = _ATTN_DP_SIZE
+        else:
+            _LOCAL_ATTN_DP_SIZE = max(1, dp_size // (tp_size // moe_dense_tp_size))
     else:
         local_rank = tp_rank
-        _DP_SIZE = 1
+        _ATTN_DP_SIZE = 1
+        _LOCAL_ATTN_DP_SIZE = 1
 
     tp_group = get_tp_group()
     _ATTN_TP_GROUP = GroupCoordinator(
@@ -93,13 +123,23 @@ def get_attention_tp_size():
 
 
 def get_attention_dp_rank():
-    assert _DP_RANK is not None, "dp attention not initialized!"
-    return _DP_RANK
+    assert _ATTN_DP_RANK is not None, "dp attention not initialized!"
+    return _ATTN_DP_RANK
 
 
 def get_attention_dp_size():
-    assert _DP_SIZE is not None, "dp attention not initialized!"
-    return _DP_SIZE
+    assert _ATTN_DP_SIZE is not None, "dp attention not initialized!"
+    return _ATTN_DP_SIZE
+
+
+def get_local_attention_dp_rank():
+    assert _LOCAL_ATTN_DP_RANK is not None, "dp attention not initialized!"
+    return _LOCAL_ATTN_DP_RANK
+
+
+def get_local_attention_dp_size():
+    assert _LOCAL_ATTN_DP_SIZE is not None, "dp attention not initialized!"
+    return _LOCAL_ATTN_DP_SIZE
 
 
 @contextmanager
@@ -112,19 +152,19 @@ def disable_dp_size():
     Args:
         tp_group (GroupCoordinator): the tp group coordinator
     """
-    global _DP_SIZE
-    assert _DP_SIZE is not None, "dp attention not initialized!"
+    global _ATTN_DP_SIZE
+    assert _ATTN_DP_SIZE is not None, "dp attention not initialized!"
 
-    old_dp_size = _DP_SIZE
-    _DP_SIZE = 1
+    old_dp_size = _ATTN_DP_SIZE
+    _ATTN_DP_SIZE = 1
     try:
         yield
     finally:
-        _DP_SIZE = old_dp_size
+        _ATTN_DP_SIZE = old_dp_size
 
 
 def get_dp_local_info(forward_batch: ForwardBatch):
-    dp_rank = get_attention_dp_rank()
+    dp_rank = get_local_attention_dp_rank()
 
     if forward_batch.dp_local_start_pos is None:
         cumtokens = torch.cumsum(forward_batch.global_num_tokens_gpu, dim=0)
@@ -201,7 +241,7 @@ def _dp_gather(
         global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
     )
 
-    # Input IDs are in int 32. We should use inplace_all_reduce for local case becaues of custom all reduce.
+    # Input IDs are in int 32. We should use inplace_all_reduce for local case because of custom all reduce.
     NUM_GPUS_PER_NODE = 8
     if (
         not local_tokens.dtype.is_floating_point
@@ -252,12 +292,12 @@ def dp_scatter(
     )
 
 
-def tp_reduce_scatter(
+def attn_tp_reduce_scatter(
     output: torch.Tensor,
     input_list: List[torch.Tensor],
 ):
     return get_attention_tp_group().reduce_scatter(output, input_list)
 
 
-def tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor):
+def attn_tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor):
     return get_attention_tp_group().all_gather(input_, tensor_list=output_list)
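The dp_attention.py hunks above rename the global DP state (`_DP_RANK`/`_DP_SIZE` become `_ATTN_DP_RANK`/`_ATTN_DP_SIZE`) and introduce a "local" DP view that accounts for `moe_dense_tp_size`. Both rank-math helpers are pure Python, so the new layout can be sanity-checked standalone; a worked example follows (functions transcribed from the hunks above; the figures tp_size=8, dp_size=4, moe_dense_tp_size=2 are illustrative choices, not defaults):

def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
    if not enable_dp_attention:
        return tp_rank, tp_size, 0
    attn_tp_size = tp_size // dp_size       # ranks per attention-TP group
    attn_dp_rank = tp_rank // attn_tp_size  # which attention-DP replica owns this rank
    attn_tp_rank = tp_rank % attn_tp_size
    return attn_tp_rank, attn_tp_size, attn_dp_rank

def compute_dp_attention_local_info(
    enable_dp_attention, tp_rank, tp_size, dp_size, moe_dense_tp_size
):
    if not enable_dp_attention:
        return tp_rank, tp_size, 0
    # When set, moe_dense_tp_size defines the smaller "local" TP world for dense MLPs.
    local_tp_size = moe_dense_tp_size if moe_dense_tp_size else tp_size
    local_tp_rank = tp_rank % local_tp_size
    local_dp_size = max(1, dp_size // (tp_size // local_tp_size))
    local_attn_tp_size = local_tp_size // local_dp_size
    local_attn_dp_rank = local_tp_rank // local_attn_tp_size
    local_attn_tp_rank = local_tp_rank % local_attn_tp_size
    return local_attn_tp_rank, local_attn_tp_size, local_attn_dp_rank

# 8 TP ranks split into 4 attention-DP groups of 2 ranks each; with
# moe_dense_tp_size=2, rank 5 lives in a local world of 2 ranks with a
# single local DP replica.
assert compute_dp_attention_world_info(True, 5, 8, 4) == (1, 2, 2)
assert compute_dp_attention_local_info(True, 5, 8, 4, 2) == (1, 2, 0)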
--- a/sglang/srt/layers/layernorm.py
+++ b/sglang/srt/layers/layernorm.py
@@ -76,7 +76,7 @@ class RMSNorm(CustomOp):
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         if not x.is_contiguous():
-            # NOTE: Romove this if aiter kernel supports discontinuous input
+            # NOTE: Remove this if aiter kernel supports discontinuous input
             x = x.contiguous()
         if residual is not None:
             fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)