sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +10 -8
- sglang/bench_one_batch.py +7 -6
- sglang/bench_one_batch_server.py +157 -21
- sglang/bench_serving.py +137 -59
- sglang/compile_deep_gemm.py +5 -5
- sglang/eval/loogle_eval.py +157 -0
- sglang/lang/chat_template.py +78 -78
- sglang/lang/tracer.py +1 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +2 -2
- sglang/srt/configs/model_config.py +40 -28
- sglang/srt/constrained/base_grammar_backend.py +55 -72
- sglang/srt/constrained/llguidance_backend.py +25 -21
- sglang/srt/constrained/outlines_backend.py +27 -26
- sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -43
- sglang/srt/conversation.py +49 -44
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +129 -135
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
- sglang/srt/disaggregation/fake/conn.py +3 -13
- sglang/srt/disaggregation/kv_events.py +357 -0
- sglang/srt/disaggregation/mini_lb.py +57 -24
- sglang/srt/disaggregation/mooncake/conn.py +238 -122
- sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
- sglang/srt/disaggregation/nixl/conn.py +10 -19
- sglang/srt/disaggregation/prefill.py +132 -47
- sglang/srt/disaggregation/utils.py +123 -6
- sglang/srt/distributed/utils.py +3 -3
- sglang/srt/entrypoints/EngineBase.py +5 -0
- sglang/srt/entrypoints/engine.py +44 -9
- sglang/srt/entrypoints/http_server.py +23 -6
- sglang/srt/entrypoints/http_server_engine.py +5 -2
- sglang/srt/function_call/base_format_detector.py +250 -0
- sglang/srt/function_call/core_types.py +34 -0
- sglang/srt/function_call/deepseekv3_detector.py +157 -0
- sglang/srt/function_call/ebnf_composer.py +234 -0
- sglang/srt/function_call/function_call_parser.py +175 -0
- sglang/srt/function_call/llama32_detector.py +74 -0
- sglang/srt/function_call/mistral_detector.py +84 -0
- sglang/srt/function_call/pythonic_detector.py +163 -0
- sglang/srt/function_call/qwen25_detector.py +67 -0
- sglang/srt/function_call/utils.py +35 -0
- sglang/srt/hf_transformers_utils.py +46 -7
- sglang/srt/layers/attention/aiter_backend.py +513 -0
- sglang/srt/layers/attention/flashattention_backend.py +64 -18
- sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
- sglang/srt/layers/attention/flashmla_backend.py +340 -78
- sglang/srt/layers/attention/triton_backend.py +3 -0
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
- sglang/srt/layers/attention/utils.py +6 -4
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/communicator.py +451 -0
- sglang/srt/layers/dp_attention.py +61 -21
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/logits_processor.py +46 -11
- sglang/srt/layers/moe/cutlass_moe.py +207 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +34 -12
- sglang/srt/layers/moe/ep_moe/layer.py +105 -51
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
- sglang/srt/layers/moe/topk.py +67 -10
- sglang/srt/layers/multimodal.py +70 -0
- sglang/srt/layers/quantization/__init__.py +8 -3
- sglang/srt/layers/quantization/blockwise_int8.py +2 -2
- sglang/srt/layers/quantization/deep_gemm.py +77 -74
- sglang/srt/layers/quantization/fp8.py +92 -2
- sglang/srt/layers/quantization/fp8_kernel.py +3 -3
- sglang/srt/layers/quantization/fp8_utils.py +6 -0
- sglang/srt/layers/quantization/gptq.py +298 -6
- sglang/srt/layers/quantization/int8_kernel.py +20 -7
- sglang/srt/layers/quantization/qoq.py +244 -0
- sglang/srt/layers/sampler.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +18 -7
- sglang/srt/lora/lora_manager.py +2 -4
- sglang/srt/lora/mem_pool.py +4 -4
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +3 -3
- sglang/srt/managers/deepseek_eplb.py +278 -0
- sglang/srt/managers/detokenizer_manager.py +21 -8
- sglang/srt/managers/eplb_manager.py +55 -0
- sglang/srt/managers/expert_distribution.py +704 -56
- sglang/srt/managers/expert_location.py +394 -0
- sglang/srt/managers/expert_location_dispatch.py +91 -0
- sglang/srt/managers/io_struct.py +19 -4
- sglang/srt/managers/mm_utils.py +294 -140
- sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
- sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
- sglang/srt/managers/multimodal_processors/internvl.py +14 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
- sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
- sglang/srt/managers/multimodal_processors/llava.py +46 -0
- sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
- sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
- sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
- sglang/srt/managers/schedule_batch.py +122 -42
- sglang/srt/managers/schedule_policy.py +1 -5
- sglang/srt/managers/scheduler.py +205 -138
- sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +232 -58
- sglang/srt/managers/tp_worker.py +12 -9
- sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
- sglang/srt/mem_cache/base_prefix_cache.py +3 -0
- sglang/srt/mem_cache/chunk_cache.py +3 -1
- sglang/srt/mem_cache/hiradix_cache.py +4 -4
- sglang/srt/mem_cache/memory_pool.py +76 -52
- sglang/srt/mem_cache/multimodal_cache.py +45 -0
- sglang/srt/mem_cache/radix_cache.py +58 -5
- sglang/srt/metrics/collector.py +314 -39
- sglang/srt/mm_utils.py +10 -0
- sglang/srt/model_executor/cuda_graph_runner.py +29 -19
- sglang/srt/model_executor/expert_location_updater.py +422 -0
- sglang/srt/model_executor/forward_batch_info.py +5 -1
- sglang/srt/model_executor/model_runner.py +163 -68
- sglang/srt/model_loader/loader.py +10 -6
- sglang/srt/models/clip.py +5 -1
- sglang/srt/models/deepseek_janus_pro.py +2 -2
- sglang/srt/models/deepseek_v2.py +308 -351
- sglang/srt/models/exaone.py +8 -3
- sglang/srt/models/gemma3_mm.py +70 -33
- sglang/srt/models/llama.py +2 -0
- sglang/srt/models/llama4.py +15 -8
- sglang/srt/models/llava.py +258 -7
- sglang/srt/models/mimo_mtp.py +220 -0
- sglang/srt/models/minicpmo.py +5 -12
- sglang/srt/models/mistral.py +71 -1
- sglang/srt/models/mixtral.py +98 -34
- sglang/srt/models/mllama.py +3 -3
- sglang/srt/models/pixtral.py +467 -0
- sglang/srt/models/qwen2.py +95 -26
- sglang/srt/models/qwen2_5_vl.py +8 -0
- sglang/srt/models/qwen2_moe.py +330 -60
- sglang/srt/models/qwen2_vl.py +6 -0
- sglang/srt/models/qwen3.py +52 -10
- sglang/srt/models/qwen3_moe.py +411 -48
- sglang/srt/models/roberta.py +1 -1
- sglang/srt/models/siglip.py +294 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/openai_api/adapter.py +58 -20
- sglang/srt/openai_api/protocol.py +6 -8
- sglang/srt/operations.py +154 -0
- sglang/srt/operations_strategy.py +31 -0
- sglang/srt/reasoning_parser.py +3 -3
- sglang/srt/sampling/custom_logit_processor.py +18 -3
- sglang/srt/sampling/sampling_batch_info.py +4 -56
- sglang/srt/sampling/sampling_params.py +2 -2
- sglang/srt/server_args.py +162 -22
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +138 -7
- sglang/srt/speculative/eagle_worker.py +69 -21
- sglang/srt/utils.py +74 -17
- sglang/test/few_shot_gsm8k.py +2 -2
- sglang/test/few_shot_gsm8k_engine.py +2 -2
- sglang/test/run_eval.py +2 -2
- sglang/test/runners.py +8 -1
- sglang/test/send_one.py +13 -3
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +1 -1
- sglang/test/test_cutlass_moe.py +278 -0
- sglang/test/test_programs.py +5 -5
- sglang/test/test_utils.py +55 -14
- sglang/utils.py +3 -3
- sglang/version.py +1 -1
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +23 -13
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +178 -149
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
- sglang/srt/function_call_parser.py +0 -858
- sglang/srt/platforms/interface.py +0 -371
- /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
- /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,451 @@
|
|
1
|
+
# Copyright 2023-2024 SGLang Team
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
#
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
7
|
+
#
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
11
|
+
# See the License for the specific language governing permissions and
|
12
|
+
# limitations under the License.
|
13
|
+
# ==============================================================================
|
14
|
+
|
15
|
+
from dataclasses import dataclass
|
16
|
+
from enum import Enum, auto
|
17
|
+
from typing import Dict, Optional, Tuple
|
18
|
+
|
19
|
+
import torch.distributed
|
20
|
+
|
21
|
+
from sglang.srt.distributed import (
|
22
|
+
get_tensor_model_parallel_world_size,
|
23
|
+
tensor_model_parallel_all_reduce,
|
24
|
+
)
|
25
|
+
from sglang.srt.layers.dp_attention import (
|
26
|
+
attn_tp_all_gather,
|
27
|
+
attn_tp_reduce_scatter,
|
28
|
+
dp_gather_partial,
|
29
|
+
dp_scatter,
|
30
|
+
get_attention_tp_rank,
|
31
|
+
get_attention_tp_size,
|
32
|
+
get_local_attention_dp_size,
|
33
|
+
)
|
34
|
+
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
35
|
+
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
36
|
+
|
37
|
+
|
38
|
+
class ScatterMode(Enum):
    """How the token dimension of an activation is laid out across ranks.

    The expected leading-dimension size per mode is computed by
    ``_compute_num_tokens_of_mode``; the process-group size each mode spans is
    set up in ``LayerCommunicator.__init__``.

    - SCATTERED: this rank holds only its ``tensor_split`` share of the
      attention-TP group's tokens (group size 1).
    - TP_ATTN_FULL: every rank of the attention-TP group holds the group's
      full local batch.
    - FULL: the tensor spans the whole TP group (the gathered buffer when DP
      attention is enabled).
    """

    # Explicit values equal to what ``auto()`` assigns (1, 2, 3).
    SCATTERED = 1
    TP_ATTN_FULL = 2
    FULL = 3
|
42
|
+
|
43
|
+
|
44
|
+
@dataclass
|
45
|
+
class _LayerModeComputationContext:
|
46
|
+
num_layers: int
|
47
|
+
layer_id: int
|
48
|
+
is_layer_sparse: bool
|
49
|
+
is_previous_layer_sparse: Optional[bool]
|
50
|
+
|
51
|
+
def previous_layer(self):
|
52
|
+
assert self.is_previous_layer_sparse is not None
|
53
|
+
return _LayerModeComputationContext(
|
54
|
+
layer_id=self.layer_id - 1,
|
55
|
+
is_layer_sparse=self.is_previous_layer_sparse,
|
56
|
+
is_previous_layer_sparse=None,
|
57
|
+
num_layers=self.num_layers,
|
58
|
+
)
|
59
|
+
|
60
|
+
|
61
|
+
@dataclass
class LayerScatterModes:
    """The :class:`ScatterMode` of each tensor at each stage of one decoder layer.

    ``LayerCommunicator`` reads these fields to choose which communication to
    run between the layer input, attention, MLP and layer output.
    """

    layer_input_mode: ScatterMode
    attn_mode: ScatterMode
    # Can be further split into e.g. mlp_input_mode and mlp_output_mode if needed
    mlp_mode: ScatterMode
    middle_residual_mode: ScatterMode
    layer_output_mode: ScatterMode

    @classmethod
    def init_new(cls, **kwargs):
        """Compute the modes for one layer.

        ``kwargs`` are forwarded to ``_LayerModeComputationContext``
        (``num_layers``, ``layer_id``, ``is_layer_sparse``,
        ``is_previous_layer_sparse``). Attention always runs in TP_ATTN_FULL.
        """
        ctx = _LayerModeComputationContext(**kwargs)
        layer_input_mode = cls._compute_layer_input_mode(ctx)
        mlp_mode = cls._compute_mlp_mode(ctx)
        middle_residual_mode = cls._compute_middle_residual_mode(ctx)
        layer_output_mode = cls._compute_layer_output_mode(ctx)
        return cls(
            layer_input_mode=layer_input_mode,
            attn_mode=ScatterMode.TP_ATTN_FULL,
            mlp_mode=mlp_mode,
            middle_residual_mode=middle_residual_mode,
            layer_output_mode=layer_output_mode,
        )

    @classmethod
    def _compute_layer_input_mode(cls, context: _LayerModeComputationContext):
        # Layer 0 starts in TP_ATTN_FULL; every later layer consumes the mode
        # its predecessor emitted.
        if context.layer_id != 0:
            return cls._compute_layer_output_mode(context.previous_layer())
        return ScatterMode.TP_ATTN_FULL

    @classmethod
    def _compute_mlp_mode(cls, context: _LayerModeComputationContext):
        # Sparse layers are scattered iff `enable_deepep_moe` is set; dense
        # layers iff the dense MLP TP size is 1.
        if context.is_layer_sparse:
            use_scattered = global_server_args_dict["enable_deepep_moe"]
        else:
            use_scattered = enable_moe_dense_fully_dp()
        return ScatterMode.SCATTERED if use_scattered else ScatterMode.FULL

    @classmethod
    def _mode_after_mlp(cls, mlp_mode: ScatterMode):
        """Map an MLP mode to the mode of tensors that are summed with its output."""
        if mlp_mode == ScatterMode.SCATTERED:
            return ScatterMode.SCATTERED
        if mlp_mode == ScatterMode.FULL:
            return ScatterMode.TP_ATTN_FULL
        raise NotImplementedError

    @classmethod
    def _compute_middle_residual_mode(cls, context: _LayerModeComputationContext):
        return cls._mode_after_mlp(cls._compute_mlp_mode(context))

    @classmethod
    def _compute_layer_output_mode(cls, context: _LayerModeComputationContext):
        mlp_mode = cls._compute_mlp_mode(context)
        # The last layer always hands TP_ATTN_FULL tensors onward.
        is_last_layer = context.layer_id == context.num_layers - 1
        if is_last_layer:
            return ScatterMode.TP_ATTN_FULL
        return cls._mode_after_mlp(mlp_mode)
|
121
|
+
|
122
|
+
|
123
|
+
def enable_moe_dense_fully_dp():
    """Return True when ``moe_dense_tp_size`` in the server args equals 1.

    When this holds, non-sparse layers use ``ScatterMode.SCATTERED`` for their
    MLP (see ``LayerScatterModes._compute_mlp_mode``).
    """
    dense_tp_size = global_server_args_dict["moe_dense_tp_size"]
    return dense_tp_size == 1
|
125
|
+
|
126
|
+
|
127
|
+
class LayerCommunicator:
    """Runs a decoder layer's norms plus the communication between its stages.

    ``layer_scatter_modes`` tells, for each stage boundary, which
    :class:`ScatterMode` the tensors leave in and must arrive in.
    """

    def __init__(
        self,
        layer_scatter_modes: LayerScatterModes,
        input_layernorm: torch.nn.Module,
        post_attention_layernorm: torch.nn.Module,
    ):
        self.layer_scatter_modes = layer_scatter_modes
        self.input_layernorm = input_layernorm
        self.post_attention_layernorm = post_attention_layernorm

        # Parallel topology is queried once, at construction time.
        self.attn_tp_rank = get_attention_tp_rank()
        self.attn_tp_size = get_attention_tp_size()
        self.local_attn_dp_size = get_local_attention_dp_size()
        self.tp_size = get_tensor_model_parallel_world_size()
        # Size of the process group each scatter mode spans.
        self.process_group_sizes = {
            ScatterMode.SCATTERED: 1,
            ScatterMode.TP_ATTN_FULL: self.attn_tp_size,
            ScatterMode.FULL: self.tp_size,
        }

    def prepare_attn(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        forward_batch: ForwardBatch,
    ):
        """Apply the input layernorm and move hidden states into attention mode.

        If ``residual`` is None, the un-normalized input becomes the residual.
        Otherwise the layernorm is called with ``(hidden_states, residual)`` and
        both outputs are kept. An empty batch skips the norm and uses
        ``hidden_states`` as the residual.

        Returns:
            ``(hidden_states, residual)``.
        """
        if hidden_states.shape[0] != 0:
            if residual is not None:
                hidden_states, residual = self.input_layernorm(hidden_states, residual)
            else:
                residual = hidden_states
                hidden_states = self.input_layernorm(hidden_states)
        else:
            residual = hidden_states

        modes = self.layer_scatter_modes
        hidden_states = _communicate_simple(
            hidden_states=hidden_states,
            forward_batch=forward_batch,
            input_mode=modes.layer_input_mode,
            output_mode=modes.attn_mode,
            context=self._compute_context(forward_batch),
        )
        return hidden_states, residual

    def prepare_mlp(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        forward_batch: ForwardBatch,
    ):
        """Reduce the attention output, apply the post-attention norm, enter MLP mode."""
        modes = self.layer_scatter_modes
        return _communicate_with_all_reduce_and_layer_norm(
            hidden_states=hidden_states,
            residual=residual,
            forward_batch=forward_batch,
            hidden_states_input_mode=modes.attn_mode,
            residual_input_mode=modes.layer_input_mode,
            hidden_states_output_mode=modes.mlp_mode,
            residual_output_mode=modes.middle_residual_mode,
            layernorm=self.post_attention_layernorm,
            context=self._compute_context(forward_batch),
        )

    def postprocess_layer(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        forward_batch: ForwardBatch,
    ):
        """Move the MLP output and residual into the layer's output mode."""
        modes = self.layer_scatter_modes
        return _communicate_summable_tensor_pair(
            hidden_states=hidden_states,
            residual=residual,
            forward_batch=forward_batch,
            hidden_states_input_mode=modes.mlp_mode,
            residual_input_mode=modes.middle_residual_mode,
            output_mode=modes.layer_output_mode,
            context=self._compute_context(forward_batch),
        )

    def _compute_context(self, forward_batch: ForwardBatch):
        """Bundle this batch's per-mode token counts with this rank's topology."""
        num_tokens_of_mode = _compute_num_tokens_of_mode(
            forward_batch,
            attn_tp_rank=self.attn_tp_rank,
            attn_tp_size=self.attn_tp_size,
        )
        return _Context(
            num_tokens_of_mode=num_tokens_of_mode,
            process_group_sizes=self.process_group_sizes,
            attn_tp_rank=self.attn_tp_rank,
            attn_tp_size=self.attn_tp_size,
            local_attn_dp_size=self.local_attn_dp_size,
            tp_size=self.tp_size,
        )
|
220
|
+
|
221
|
+
|
222
|
+
def _compute_num_tokens_of_mode(
    forward_batch: ForwardBatch, attn_tp_rank: int, attn_tp_size: int
):
    """Return the expected leading (token) dimension for each ScatterMode.

    - TP_ATTN_FULL: ``forward_batch.input_ids.shape[0]``.
    - SCATTERED: this rank's ``tensor_split`` share of that length.
    - FULL: ``forward_batch.gathered_buffer.shape[0]`` when
      ``enable_dp_attention`` is set, otherwise ``input_ids.shape[0]``.
    """
    local_num_tokens = forward_batch.input_ids.shape[0]
    scattered_num_tokens = _torch_tensor_split_len(
        local_num_tokens, attn_tp_size, attn_tp_rank
    )
    if global_server_args_dict["enable_dp_attention"]:
        full_num_tokens = forward_batch.gathered_buffer.shape[0]
    else:
        full_num_tokens = forward_batch.input_ids.shape[0]

    return {
        ScatterMode.SCATTERED: scattered_num_tokens,
        ScatterMode.TP_ATTN_FULL: local_num_tokens,
        ScatterMode.FULL: full_num_tokens,
    }
|
237
|
+
|
238
|
+
|
239
|
+
def _torch_tensor_split_len(tensor_len: int, n: int, output_index: int):
|
240
|
+
if output_index < int(tensor_len % n):
|
241
|
+
return int(tensor_len / n) + 1
|
242
|
+
else:
|
243
|
+
return int(tensor_len / n)
|
244
|
+
|
245
|
+
|
246
|
+
@dataclass
|
247
|
+
class _Context:
|
248
|
+
num_tokens_of_mode: Dict["ScatterMode", int]
|
249
|
+
process_group_sizes: Dict["ScatterMode", int]
|
250
|
+
attn_tp_rank: int
|
251
|
+
attn_tp_size: int
|
252
|
+
local_attn_dp_size: int
|
253
|
+
tp_size: int
|
254
|
+
|
255
|
+
def is_same_group_size(self, a: "ScatterMode", b: "ScatterMode"):
|
256
|
+
return self.process_group_sizes[a] == self.process_group_sizes[b]
|
257
|
+
|
258
|
+
def check_shape(self, x: torch.Tensor, mode: ScatterMode):
|
259
|
+
if x is None:
|
260
|
+
return
|
261
|
+
|
262
|
+
actual_num_tokens = x.shape[0]
|
263
|
+
expect_num_tokens = self.num_tokens_of_mode[mode]
|
264
|
+
assert (
|
265
|
+
actual_num_tokens == expect_num_tokens
|
266
|
+
), f"{actual_num_tokens=} {expect_num_tokens=} {mode=} {x.shape=} {self.num_tokens_of_mode=} {self.process_group_sizes=}"
|
267
|
+
return x
|
268
|
+
|
269
|
+
def check_shapes(
|
270
|
+
self, xs: Tuple[torch.Tensor, ...], modes: Tuple[ScatterMode, ...]
|
271
|
+
) -> Tuple[torch.Tensor, ...]:
|
272
|
+
return tuple(
|
273
|
+
[self.check_shape(x, mode) for x, mode in zip(xs, modes, strict=True)]
|
274
|
+
)
|
275
|
+
|
276
|
+
|
277
|
+
def _communicate_simple(
    hidden_states: torch.Tensor,
    forward_batch: ForwardBatch,
    input_mode: ScatterMode,
    output_mode: ScatterMode,
    context: _Context,
) -> torch.Tensor:
    """Move ``hidden_states`` from ``input_mode`` into ``output_mode``.

    This is a no-op when both modes span equally sized process groups. The only
    other supported pair is SCATTERED -> TP_ATTN_FULL: every rank's slice is
    all-gathered across the attention-TP group into the leading
    ``input_ids.shape[0]`` rows of ``forward_batch.gathered_buffer``.

    Raises:
        NotImplementedError: for any other mode pair.
    """
    context.check_shape(hidden_states, input_mode)

    if context.is_same_group_size(input_mode, output_mode):
        result = hidden_states
    elif input_mode == ScatterMode.SCATTERED and output_mode == ScatterMode.TP_ATTN_FULL:
        num_tokens = forward_batch.input_ids.shape[0]
        result = forward_batch.gathered_buffer[:num_tokens]
        # Each tensor_split view of the buffer receives one rank's slice.
        attn_tp_all_gather(
            list(result.tensor_split(context.attn_tp_size)),
            hidden_states,
        )
    else:
        raise NotImplementedError(f"{input_mode=} {output_mode=}")

    return context.check_shape(result, output_mode)
|
307
|
+
|
308
|
+
|
309
|
+
def _communicate_with_all_reduce_and_layer_norm(
    hidden_states: torch.Tensor,
    residual: torch.Tensor,
    hidden_states_input_mode: ScatterMode,
    residual_input_mode: ScatterMode,
    hidden_states_output_mode: ScatterMode,
    residual_output_mode: ScatterMode,
    forward_batch: ForwardBatch,
    layernorm: torch.nn.Module,
    context: _Context,
):
    """Besides communication, needs to
    1. All reduce in tp_attn_group on hidden_states
    2. Apply layer norm

    Supported transitions, written as
    (hidden_states_in, residual_in) -> (hidden_states_out, residual_out):

    - Equal group sizes for both tensors and ``attn_tp_size == 1``: norm only.
    - (TP_ATTN_FULL, TP_ATTN_FULL) -> (FULL, TP_ATTN_FULL):
      - If ``local_attn_dp_size != 1``: attention-TP rank 0 adds the residual,
        hidden states are gathered into ``forward_batch.gathered_buffer`` via
        ``dp_gather_partial``, ``dp_scatter`` is called with the residual as
        destination, and the norm is applied without a residual.
      - Otherwise: a TP all-reduce followed by a residual norm.
    - (TP_ATTN_FULL, SCATTERED | TP_ATTN_FULL) -> (SCATTERED, SCATTERED):
      reduce-scatter across the attention-TP group. A TP_ATTN_FULL residual is
      sliced to this rank's ``tensor_split`` chunk.

    The layernorm is skipped when the token dimension is empty (except in the
    all-reduce path).

    Returns:
        ``(hidden_states, residual)`` in the requested output modes.

    Raises:
        NotImplementedError: for unsupported mode combinations.
    """

    def _inner():
        nonlocal hidden_states, residual

        if (
            context.is_same_group_size(
                hidden_states_input_mode, hidden_states_output_mode
            )
            and context.is_same_group_size(residual_input_mode, residual_output_mode)
            and context.attn_tp_size == 1
        ):
            # TODO move these `if shape != 0` into LayerNorm itself
            if hidden_states.shape[0] != 0:
                hidden_states, residual = layernorm(hidden_states, residual)
            return hidden_states, residual

        if (
            (hidden_states_input_mode == ScatterMode.TP_ATTN_FULL)
            and (residual_input_mode == ScatterMode.TP_ATTN_FULL)
            and (hidden_states_output_mode == ScatterMode.FULL)
            and (residual_output_mode == ScatterMode.TP_ATTN_FULL)
        ):
            if context.local_attn_dp_size != 1:
                # Only attention-TP rank 0 adds the residual -- presumably so it
                # is counted once when dp_gather_partial combines the ranks'
                # partial outputs (TODO confirm against dp_gather_partial).
                if context.attn_tp_rank == 0:
                    hidden_states += residual
                hidden_states, local_hidden_states = (
                    forward_batch.gathered_buffer,
                    hidden_states,
                )
                dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
                dp_scatter(residual, hidden_states, forward_batch)
                if hidden_states.shape[0] != 0:
                    hidden_states = layernorm(hidden_states)
            else:
                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
                hidden_states, residual = layernorm(hidden_states, residual)
            return hidden_states, residual

        if (
            (hidden_states_input_mode == ScatterMode.TP_ATTN_FULL)
            and (
                residual_input_mode in [ScatterMode.SCATTERED, ScatterMode.TP_ATTN_FULL]
            )
            and (hidden_states_output_mode == ScatterMode.SCATTERED)
            and (residual_output_mode == ScatterMode.SCATTERED)
        ):
            tensor_list = list(hidden_states.tensor_split(context.attn_tp_size))
            # The reduce-scatter output is this rank's own view of the input,
            # i.e. it aliases tensor_list[attn_tp_rank].
            hidden_states = tensor_list[context.attn_tp_rank]
            attn_tp_reduce_scatter(hidden_states, tensor_list)
            if residual_input_mode == ScatterMode.TP_ATTN_FULL:
                residual = residual.tensor_split(context.attn_tp_size)[
                    context.attn_tp_rank
                ]
            if hidden_states.shape[0] != 0:
                hidden_states, residual = layernorm(hidden_states, residual)
            return hidden_states, residual

        # Report all four modes (previously residual_output_mode was printed
        # twice and hidden_states_output_mode was missing).
        raise NotImplementedError(
            f"{hidden_states_input_mode=} {residual_input_mode=} "
            f"{hidden_states_output_mode=} {residual_output_mode=}"
        )

    context.check_shapes(
        (hidden_states, residual), (hidden_states_input_mode, residual_input_mode)
    )
    return context.check_shapes(
        _inner(), (hidden_states_output_mode, residual_output_mode)
    )
|
391
|
+
|
392
|
+
|
393
|
+
def _communicate_summable_tensor_pair(
    hidden_states: torch.Tensor,
    residual: torch.Tensor,
    forward_batch: ForwardBatch,
    hidden_states_input_mode: ScatterMode,
    residual_input_mode: ScatterMode,
    output_mode: ScatterMode,
    context: _Context,
):
    """It is allowed to make (hidden_states, residual) := (hidden_states + residual, None) if needed.

    Moves both tensors into ``output_mode``. Supported cases:

    - Both input modes already span the same group size as ``output_mode``:
      returned unchanged.
    - hidden_states FULL, residual TP_ATTN_FULL -> TP_ATTN_FULL: hidden states
      are ``dp_scatter``-ed into the front of ``forward_batch.gathered_buffer``.
    - Both SCATTERED -> TP_ATTN_FULL: the residual is added in place into
      ``hidden_states``, the sum is all-gathered across the attention-TP group
      into the front of ``forward_batch.gathered_buffer``, and the residual
      becomes ``None``.

    Raises:
        NotImplementedError: for unsupported mode combinations.
    """
    context.check_shapes(
        (hidden_states, residual), (hidden_states_input_mode, residual_input_mode)
    )

    modes = (hidden_states_input_mode, residual_input_mode, output_mode)
    if context.is_same_group_size(
        hidden_states_input_mode, output_mode
    ) and context.is_same_group_size(residual_input_mode, output_mode):
        result = (hidden_states, residual)
    elif modes == (ScatterMode.FULL, ScatterMode.TP_ATTN_FULL, ScatterMode.TP_ATTN_FULL):
        # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
        # important: forward batch.gathered_buffer is used both after scatter and after gather.
        # be careful about this!
        global_hidden_states = hidden_states
        hidden_states = forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]]
        dp_scatter(hidden_states, global_hidden_states, forward_batch)
        result = (hidden_states, residual)
    elif modes == (ScatterMode.SCATTERED, ScatterMode.SCATTERED, ScatterMode.TP_ATTN_FULL):
        # Fold the residual in before gathering so only one tensor is moved.
        hidden_states += residual
        residual = None
        local_hidden_states = hidden_states
        hidden_states = forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]]
        attn_tp_all_gather(
            list(hidden_states.tensor_split(context.attn_tp_size)),
            local_hidden_states,
        )
        result = (hidden_states, residual)
    else:
        raise NotImplementedError(
            f"{hidden_states_input_mode=} {residual_input_mode=} {output_mode=}"
        )

    return context.check_shapes(result, (output_mode, output_mode))
|
@@ -24,8 +24,10 @@ if TYPE_CHECKING:
|
|
24
24
|
_ATTN_TP_GROUP = None
|
25
25
|
_ATTN_TP_RANK = None
|
26
26
|
_ATTN_TP_SIZE = None
|
27
|
-
|
28
|
-
|
27
|
+
_ATTN_DP_RANK = None
|
28
|
+
_ATTN_DP_SIZE = None
|
29
|
+
_LOCAL_ATTN_DP_SIZE = None
|
30
|
+
_LOCAL_ATTN_DP_RANK = None
|
29
31
|
|
30
32
|
|
31
33
|
def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
|
@@ -33,9 +35,27 @@ def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_si
|
|
33
35
|
return tp_rank, tp_size, 0
|
34
36
|
|
35
37
|
attn_tp_size = tp_size // dp_size
|
36
|
-
|
38
|
+
attn_dp_rank = tp_rank // attn_tp_size
|
37
39
|
attn_tp_rank = tp_rank % attn_tp_size
|
38
|
-
|
40
|
+
|
41
|
+
return attn_tp_rank, attn_tp_size, attn_dp_rank
|
42
|
+
|
43
|
+
|
44
|
+
def compute_dp_attention_local_info(
    enable_dp_attention, tp_rank, tp_size, dp_size, moe_dense_tp_size
):
    """Return ``(attn_tp_rank, attn_tp_size, attn_dp_rank)`` inside the local group.

    The local group has ``moe_dense_tp_size`` ranks, or all ``tp_size`` ranks
    when ``moe_dense_tp_size`` is falsy. The attention DP groups are re-counted
    within it, with at least one per local group. Without DP attention, returns
    ``(tp_rank, tp_size, 0)``.
    """
    if not enable_dp_attention:
        return tp_rank, tp_size, 0

    group_size = moe_dense_tp_size or tp_size
    rank_in_group = tp_rank % group_size
    num_groups = tp_size // group_size
    dp_size_in_group = max(1, dp_size // num_groups)

    attn_tp_size_in_group = group_size // dp_size_in_group
    attn_dp_rank, attn_tp_rank = divmod(rank_in_group, attn_tp_size_in_group)
    return attn_tp_rank, attn_tp_size_in_group, attn_dp_rank
|
39
59
|
|
40
60
|
|
41
61
|
def initialize_dp_attention(
|
@@ -43,22 +63,32 @@ def initialize_dp_attention(
|
|
43
63
|
tp_rank: int,
|
44
64
|
tp_size: int,
|
45
65
|
dp_size: int,
|
66
|
+
moe_dense_tp_size: int,
|
46
67
|
pp_size: int,
|
47
68
|
):
|
48
|
-
global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE,
|
69
|
+
global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK, _ATTN_DP_SIZE
|
70
|
+
global _LOCAL_ATTN_DP_SIZE, _LOCAL_ATTN_DP_RANK
|
49
71
|
|
50
72
|
from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP
|
51
73
|
|
52
|
-
_ATTN_TP_RANK, _ATTN_TP_SIZE,
|
74
|
+
_ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK = compute_dp_attention_world_info(
|
53
75
|
enable_dp_attention, tp_rank, tp_size, dp_size
|
54
76
|
)
|
77
|
+
_, _, _LOCAL_ATTN_DP_RANK = compute_dp_attention_local_info(
|
78
|
+
enable_dp_attention, tp_rank, tp_size, dp_size, moe_dense_tp_size
|
79
|
+
)
|
55
80
|
|
56
81
|
if enable_dp_attention:
|
57
82
|
local_rank = tp_rank % (tp_size // dp_size)
|
58
|
-
|
83
|
+
_ATTN_DP_SIZE = dp_size
|
84
|
+
if moe_dense_tp_size is None:
|
85
|
+
_LOCAL_ATTN_DP_SIZE = _ATTN_DP_SIZE
|
86
|
+
else:
|
87
|
+
_LOCAL_ATTN_DP_SIZE = max(1, dp_size // (tp_size // moe_dense_tp_size))
|
59
88
|
else:
|
60
89
|
local_rank = tp_rank
|
61
|
-
|
90
|
+
_ATTN_DP_SIZE = 1
|
91
|
+
_LOCAL_ATTN_DP_SIZE = 1
|
62
92
|
|
63
93
|
tp_group = get_tp_group()
|
64
94
|
_ATTN_TP_GROUP = GroupCoordinator(
|
@@ -93,13 +123,23 @@ def get_attention_tp_size():
|
|
93
123
|
|
94
124
|
|
95
125
|
def get_attention_dp_rank():
    """Return the attention DP rank set by ``initialize_dp_attention``.

    Raises:
        AssertionError: if DP attention has not been initialized.
    """
    rank = _ATTN_DP_RANK
    assert rank is not None, "dp attention not initialized!"
    return rank
|
98
128
|
|
99
129
|
|
100
130
|
def get_attention_dp_size():
    """Return the attention DP size set by ``initialize_dp_attention``.

    Raises:
        AssertionError: if DP attention has not been initialized.
    """
    size = _ATTN_DP_SIZE
    assert size is not None, "dp attention not initialized!"
    return size
|
133
|
+
|
134
|
+
|
135
|
+
def get_local_attention_dp_rank():
    """Return the local attention DP rank.

    The value is the third element of ``compute_dp_attention_local_info``,
    stored by ``initialize_dp_attention``.

    Raises:
        AssertionError: if DP attention has not been initialized.
    """
    rank = _LOCAL_ATTN_DP_RANK
    assert rank is not None, "dp attention not initialized!"
    return rank
|
138
|
+
|
139
|
+
|
140
|
+
def get_local_attention_dp_size():
    """Return the number of attention DP groups inside the local group.

    Set by ``initialize_dp_attention``.

    Raises:
        AssertionError: if DP attention has not been initialized.
    """
    size = _LOCAL_ATTN_DP_SIZE
    assert size is not None, "dp attention not initialized!"
    return size
|
103
143
|
|
104
144
|
|
105
145
|
@contextmanager
|
@@ -112,19 +152,19 @@ def disable_dp_size():
|
|
112
152
|
Args:
|
113
153
|
tp_group (GroupCoordinator): the tp group coordinator
|
114
154
|
"""
|
115
|
-
global
|
116
|
-
assert
|
155
|
+
global _ATTN_DP_SIZE
|
156
|
+
assert _ATTN_DP_SIZE is not None, "dp attention not initialized!"
|
117
157
|
|
118
|
-
old_dp_size =
|
119
|
-
|
158
|
+
old_dp_size = _ATTN_DP_SIZE
|
159
|
+
_ATTN_DP_SIZE = 1
|
120
160
|
try:
|
121
161
|
yield
|
122
162
|
finally:
|
123
|
-
|
163
|
+
_ATTN_DP_SIZE = old_dp_size
|
124
164
|
|
125
165
|
|
126
166
|
def get_dp_local_info(forward_batch: ForwardBatch):
|
127
|
-
dp_rank =
|
167
|
+
dp_rank = get_local_attention_dp_rank()
|
128
168
|
|
129
169
|
if forward_batch.dp_local_start_pos is None:
|
130
170
|
cumtokens = torch.cumsum(forward_batch.global_num_tokens_gpu, dim=0)
|
@@ -201,7 +241,7 @@ def _dp_gather(
|
|
201
241
|
global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
|
202
242
|
)
|
203
243
|
|
204
|
-
# Input IDs are in int 32. We should use inplace_all_reduce for local case
|
244
|
+
# Input IDs are in int 32. We should use inplace_all_reduce for local case because of custom all reduce.
|
205
245
|
NUM_GPUS_PER_NODE = 8
|
206
246
|
if (
|
207
247
|
not local_tokens.dtype.is_floating_point
|
@@ -252,12 +292,12 @@ def dp_scatter(
|
|
252
292
|
)
|
253
293
|
|
254
294
|
|
255
|
-
def
|
295
|
+
def attn_tp_reduce_scatter(
    output: torch.Tensor,
    input_list: List[torch.Tensor],
):
    """Delegate to the attention-TP group's ``reduce_scatter(output, input_list)``."""
    group = get_attention_tp_group()
    return group.reduce_scatter(output, input_list)
|
260
300
|
|
261
301
|
|
262
|
-
def
|
302
|
+
def attn_tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor):
    """Delegate to the attention-TP group's ``all_gather``, filling ``output_list``."""
    group = get_attention_tp_group()
    return group.all_gather(input_, tensor_list=output_list)
|
sglang/srt/layers/layernorm.py
CHANGED
@@ -76,7 +76,7 @@ class RMSNorm(CustomOp):
|
|
76
76
|
residual: Optional[torch.Tensor] = None,
|
77
77
|
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
78
78
|
if not x.is_contiguous():
|
79
|
-
# NOTE:
|
79
|
+
# NOTE: Remove this if aiter kernel supports discontinuous input
|
80
80
|
x = x.contiguous()
|
81
81
|
if residual is not None:
|
82
82
|
fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
|