sglang 0.4.9.post3__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/srt/_custom_ops.py +29 -1
- sglang/srt/configs/model_config.py +1 -1
- sglang/srt/conversation.py +1 -1
- sglang/srt/disaggregation/common/conn.py +34 -6
- sglang/srt/disaggregation/mini_lb.py +3 -2
- sglang/srt/disaggregation/mooncake/conn.py +49 -20
- sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
- sglang/srt/disaggregation/nixl/conn.py +17 -13
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
- sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
- sglang/srt/distributed/parallel_state.py +70 -15
- sglang/srt/entrypoints/engine.py +2 -8
- sglang/srt/entrypoints/http_server.py +20 -32
- sglang/srt/entrypoints/openai/protocol.py +3 -3
- sglang/srt/entrypoints/openai/serving_chat.py +27 -4
- sglang/srt/function_call/base_format_detector.py +74 -12
- sglang/srt/function_call/deepseekv3_detector.py +26 -11
- sglang/srt/function_call/ebnf_composer.py +95 -63
- sglang/srt/function_call/function_call_parser.py +4 -4
- sglang/srt/function_call/kimik2_detector.py +41 -16
- sglang/srt/function_call/llama32_detector.py +6 -3
- sglang/srt/function_call/mistral_detector.py +11 -3
- sglang/srt/function_call/pythonic_detector.py +16 -14
- sglang/srt/function_call/qwen25_detector.py +12 -3
- sglang/srt/function_call/{qwen3_detector.py → qwen3_coder_detector.py} +10 -9
- sglang/srt/layers/activation.py +11 -3
- sglang/srt/layers/attention/base_attn_backend.py +3 -1
- sglang/srt/layers/communicator.py +12 -12
- sglang/srt/layers/dp_attention.py +72 -24
- sglang/srt/layers/logits_processor.py +34 -24
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +25 -224
- sglang/srt/layers/moe/topk.py +5 -13
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -9
- sglang/srt/layers/quantization/modelopt_quant.py +8 -4
- sglang/srt/layers/quantization/utils.py +0 -9
- sglang/srt/layers/radix_attention.py +5 -3
- sglang/srt/lora/lora_manager.py +133 -169
- sglang/srt/lora/lora_registry.py +124 -0
- sglang/srt/lora/mem_pool.py +2 -2
- sglang/srt/managers/cache_controller.py +53 -6
- sglang/srt/managers/io_struct.py +19 -1
- sglang/srt/managers/schedule_batch.py +13 -3
- sglang/srt/managers/scheduler.py +13 -25
- sglang/srt/managers/tokenizer_manager.py +28 -25
- sglang/srt/managers/tp_worker.py +2 -4
- sglang/srt/mem_cache/allocator.py +67 -7
- sglang/srt/mem_cache/hicache_storage.py +17 -1
- sglang/srt/mem_cache/hiradix_cache.py +30 -16
- sglang/srt/mem_cache/memory_pool_host.py +3 -0
- sglang/srt/model_executor/cuda_graph_runner.py +61 -25
- sglang/srt/model_executor/forward_batch_info.py +201 -29
- sglang/srt/model_executor/model_runner.py +41 -23
- sglang/srt/models/deepseek_v2.py +1 -2
- sglang/srt/models/mllama4.py +10 -3
- sglang/srt/models/qwen2_moe.py +0 -4
- sglang/srt/models/qwen3_moe.py +1 -6
- sglang/srt/reasoning_parser.py +46 -4
- sglang/srt/sampling/sampling_batch_info.py +6 -5
- sglang/srt/server_args.py +76 -55
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +37 -36
- sglang/srt/speculative/eagle_utils.py +51 -23
- sglang/srt/speculative/eagle_worker.py +59 -44
- sglang/srt/two_batch_overlap.py +9 -5
- sglang/srt/utils.py +17 -68
- sglang/test/test_activation.py +50 -1
- sglang/version.py +1 -1
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +5 -5
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +75 -72
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
@@ -18,16 +18,21 @@ logger = logging.getLogger(__name__)
 
 
 class KimiK2Detector(BaseFormatDetector):
+    """
+    Detector for Kimi K2 model function call format.
+
+    Format Structure:
+    ```
+    <|tool_calls_section_begin|>
+    <|tool_call_begin|>functions.{func_name}:{index} <|tool_call_argument_begin|>{json_args}<|tool_call_end|>
+    <|tool_calls_section_end|>
+    ```
+
+    Reference: https://huggingface.co/moonshotai/Kimi-K2-Instruct/blob/main/docs/tool_call_guidance.md
+    """
 
     def __init__(self):
         super().__init__()
-        self._buffer = ""
-        self.current_tool_name_sent: bool = False
-        self.prev_tool_call_arr: list[dict] = []
-        self.current_tool_id: int = -1
-        self.streamed_args_for_tool: list[str] = (
-            []
-        )  # map what has been streamed for each tool so far to a list
 
         self.bot_token: str = "<|tool_calls_section_begin|>"
         self.eot_token: str = "<|tool_calls_section_end|>"
@@ -114,11 +119,7 @@ class KimiK2Detector(BaseFormatDetector):
             return StreamingParseResult(normal_text=new_text)
 
         if not hasattr(self, "_tool_indices"):
-            self._tool_indices = {
-                tool.function.name: i
-                for i, tool in enumerate(tools)
-                if tool.function and tool.function.name
-            }
+            self._tool_indices = self._get_tool_indices(tools)
 
         calls: list[ToolCallItem] = []
         try:
@@ -150,7 +151,7 @@ class KimiK2Detector(BaseFormatDetector):
                         )
                     )
                     self.current_tool_name_sent = True
-                    # Store the tool call info for
+                    # Store the tool call info for serving layer completions endpoint
                     self.prev_tool_call_arr[self.current_tool_id] = {
                         "name": function_name,
                         "arguments": {},
@@ -214,7 +215,31 @@ class KimiK2Detector(BaseFormatDetector):
             return StreamingParseResult(normal_text=current_text)
 
     def structure_info(self) -> _GetInfoFunc:
-
+        """Return function that creates StructureInfo for guided generation."""
+
+        def get_info(name: str) -> StructureInfo:
+            return StructureInfo(
+                begin=f"<|tool_calls_section_begin|><|tool_call_begin|>functions.{name}:0 <|tool_call_argument_begin|>",
+                end="<|tool_call_end|><|tool_calls_section_end|>",
+                trigger="<|tool_calls_section_begin|>",
+            )
+
+        return get_info
 
-    def build_ebnf(self, tools: List[Tool]):
-
+    def build_ebnf(self, tools: List[Tool]) -> str:
+        """
+        Build EBNF grammar for KimiK2 tool call format.
+
+        NOTE: The call_rule_fmt uses [0-9]+ for the function index to allow the grammar
+        to accept any numeric index (0, 1, 2, etc.) for proper sequential indexing in
+        multiple function call scenarios, while still maintaining the correct KimiK2
+        format structure for constrained generation.
+        """
+        return EBNFComposer.build_ebnf(
+            tools,
+            sequence_start_token=self.bot_token,
+            sequence_end_token=self.eot_token,
+            tool_call_separator="",
+            call_rule_fmt='"<|tool_call_begin|>functions.{name}:" [0-9]+ " <|tool_call_argument_begin|>" {arguments_rule} "<|tool_call_end|>"',
+            function_format="json",
+        )
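Not part of the diff, but to make the `[0-9]+` index note above concrete, a minimal sketch of a Kimi K2 tool-call section with two back-to-back calls, built from the format documented in the new docstring (the function names and arguments are invented):

```python
import json

# Invented tool calls, used only to illustrate the documented format.
calls = [("get_weather", {"city": "Paris"}), ("get_time", {"tz": "UTC"})]

section = "<|tool_calls_section_begin|>"
for index, (func_name, args) in enumerate(calls):
    # Each call carries a sequential index (":0", ":1", ...), which is why the
    # EBNF call rule accepts `[0-9]+` rather than hard-coding ":0".
    section += (
        f"<|tool_call_begin|>functions.{func_name}:{index} "
        f"<|tool_call_argument_begin|>{json.dumps(args)}<|tool_call_end|>"
    )
section += "<|tool_calls_section_end|>"
print(section)
```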
sglang/srt/function_call/llama32_detector.py
CHANGED
@@ -16,9 +16,12 @@ logger = logging.getLogger(__name__)
 
 class Llama32Detector(BaseFormatDetector):
     """
-    Detector for Llama 3.2 models.
-
-
+    Detector for Llama 3.2 models with json tool call format.
+
+    Format Structure:
+    ```
+    <python_tag>{"name":"xxx", "arguments":{...}}
+    ```
     """
 
     def __init__(self):
sglang/srt/function_call/mistral_detector.py
CHANGED
@@ -17,9 +17,17 @@ logger = logging.getLogger(__name__)
 
 class MistralDetector(BaseFormatDetector):
     """
-    Detector for Mistral
-
-
+    Detector for Mistral model function call format.
+
+    The Mistral format uses a simple bracket-delimited structure with JSON arrays
+    containing function call objects.
+
+    Format Structure:
+    ```
+    [TOOL_CALLS] [{"name": "function_name", "arguments": {json_args}}, ...]
+    ```
+
+    Reference: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3?chat_template=default
     """
 
     def __init__(self):
sglang/srt/function_call/pythonic_detector.py
CHANGED
@@ -8,7 +8,6 @@ from sglang.srt.entrypoints.openai.protocol import Tool
 from sglang.srt.function_call.base_format_detector import BaseFormatDetector
 from sglang.srt.function_call.core_types import (
     StreamingParseResult,
-    StructureInfo,
     ToolCallItem,
     _GetInfoFunc,
 )
@@ -19,10 +18,17 @@ logger = logging.getLogger(__name__)
 
 class PythonicDetector(BaseFormatDetector):
     """
-    Detector for Llama-
-
-
-
+    Detector for Llama-4 models with Pythonic tool call format.
+
+    The Pythonic format uses Python function call syntax within square brackets,
+    with arguments as Python literals rather than JSON.
+
+    Format Structure:
+    ```
+    [tool1(arg1=val1, arg2=val2), tool2(arg1=val3)]
+    ```
+
+    Reference: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct?chat_template=default
     """
 
     def __init__(self):
@@ -75,11 +81,7 @@ class PythonicDetector(BaseFormatDetector):
             return StreamingParseResult(normal_text=normal_text, calls=[])
 
         calls = []
-        tool_indices = {
-            tool.function.name: i
-            for i, tool in enumerate(tools)
-            if tool.function.name
-        }
+        tool_indices = self._get_tool_indices(tools)
         for call_index, call in enumerate(parsed.elts):
             if not isinstance(call.func, ast.Name):
                 continue
@@ -213,11 +215,11 @@ class PythonicDetector(BaseFormatDetector):
         else:
             raise ValueError("Tool call arguments must be literals")
 
-    def
-
-        return StructureInfo(begin=f"[{name}(", end=")]", trigger=f"[{name}(")
+    def supports_structural_tag(self) -> bool:
+        return False
 
-
+    def structure_info(self) -> _GetInfoFunc:
+        raise NotImplementedError
 
     def build_ebnf(self, tools: List[Tool]) -> Optional[str]:
         return EBNFComposer.build_ebnf(
sglang/srt/function_call/qwen25_detector.py
CHANGED
@@ -17,9 +17,18 @@ logger = logging.getLogger(__name__)
 
 class Qwen25Detector(BaseFormatDetector):
     """
-    Detector for Qwen 2.5
-
-
+    Detector for Qwen 2.5 and Qwen 3 model function call format.
+
+    Format Structure:
+    ```
+    <tool_call>\n{"name":"func1", "arguments":{...}}\n</tool_call>\n<tool_call>\n{"name":"func2", "arguments":{...}}\n</tool_call>
+    ```
+
+    Key Components:
+    - Tool Call Tags: `<tool_call>` and `</tool_call>` wrap each individual call
+    - Function Call Object: JSON object with "name" and "arguments" fields
+
+    Reference: https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct?chat_template=default
     """
 
     def __init__(self):
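As an aside, a minimal sketch (hypothetical helper, not the detector's actual implementation) of pulling name/arguments pairs out of text in the `<tool_call>` format documented above:

```python
import json
import re

def extract_tool_calls(text: str) -> list[dict]:
    """Toy, non-streaming extraction of Qwen 2.5-style tool calls."""
    pattern = re.compile(r"<tool_call>\n(.*?)\n</tool_call>", re.DOTALL)
    return [json.loads(block) for block in pattern.findall(text)]

sample = (
    '<tool_call>\n{"name": "func1", "arguments": {"x": 1}}\n</tool_call>\n'
    '<tool_call>\n{"name": "func2", "arguments": {"y": 2}}\n</tool_call>'
)
print(extract_tool_calls(sample))
# [{'name': 'func1', 'arguments': {'x': 1}}, {'name': 'func2', 'arguments': {'y': 2}}]
```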
sglang/srt/function_call/{qwen3_detector.py → qwen3_coder_detector.py}
RENAMED
@@ -9,7 +9,6 @@ from sglang.srt.entrypoints.openai.protocol import Tool
 from sglang.srt.function_call.base_format_detector import BaseFormatDetector
 from sglang.srt.function_call.core_types import (
     StreamingParseResult,
-    StructureInfo,
     ToolCallItem,
     _GetInfoFunc,
 )
@@ -29,7 +28,7 @@ def _safe_val(raw: str) -> Any:
     return raw
 
 
-class Qwen3XMLDetector(BaseFormatDetector):
+class Qwen3CoderDetector(BaseFormatDetector):
     """
     Detector for Qwen 3 models.
     Assumes function call format:
@@ -127,24 +126,26 @@ class Qwen3XMLDetector(BaseFormatDetector):
                     params[pname] = _safe_val(pval)
             raw = {"name": fname, "arguments": params}
             try:
+                # TODO: fix idx in function call, the index for a function
+                # call will always be -1 in parse_base_json
                 res.extend(self.parse_base_json(raw, tools))
             except Exception:
                 logger.warning("invalid tool call for %s dropped", fname)
         return res
 
+    def supports_structural_tag(self) -> bool:
+        return False
+
     def structure_info(self) -> _GetInfoFunc:
-
-            begin=f"{self.tool_call_start_token}\n<function={n}>",
-            end=f"</function>\n{self.tool_call_end_token}",
-            trigger=self.tool_call_start_token,
-        )
+        raise NotImplementedError
 
-    # TODO: fake ebnf for xml + outlines backend
     def build_ebnf(self, tools: List[Tool]):
         return EBNFComposer.build_ebnf(
             tools,
             individual_call_start_token=self.tool_call_start_token.replace("\n", "\\n"),
             individual_call_end_token=self.tool_call_end_token.replace("\n", "\\n"),
             tool_call_separator="\\n",
-            function_format="
+            function_format="xml",
+            call_rule_fmt='"<function={name}>\\n" {arguments_rule} "\\n</function>"',
+            key_value_rule_fmt='"<parameter={key}>\\n" {valrule} "\\n</parameter>"',
         )
sglang/srt/layers/activation.py
CHANGED
@@ -33,6 +33,7 @@ from sglang.srt.utils import (
     cpu_has_amx_support,
     is_cpu,
     is_cuda,
+    is_hip,
     is_npu,
     set_weight_attrs,
 )
@@ -42,9 +43,12 @@ _is_cuda = is_cuda()
 _is_npu = is_npu()
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
+_is_hip = is_hip()
 
 if _is_cuda:
     from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
+elif _is_hip:
+    from sgl_kernel import gelu_and_mul, gelu_quick, gelu_tanh_and_mul, silu_and_mul
 
 if is_npu():
     import torch_npu
@@ -126,9 +130,13 @@ class QuickGELU(CustomOp):
         return x * torch.sigmoid(1.702 * x)
 
     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
-        # TODO(zhyncs): Implement the CUDA kernel for QuickGELU in sgl-kernel
         return self.forward_native(x)
 
+    def forward_hip(self, x: torch.Tensor) -> torch.Tensor:
+        out = torch.empty(x.shape, dtype=x.dtype, device=x.device)
+        gelu_quick(x, out)
+        return out
+
 
 class ScaledActivation(nn.Module):
     """An activation function with post-scale parameters.
@@ -222,8 +230,8 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
     return nn.Identity()
 
 
-if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
+if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip):
     logger.info(
-        "sgl-kernel is not available on Non-NV platforms or Non-AMX CPUs. Fallback to other kernel libraries."
+        "sgl-kernel is not available on Non-NV, Non-AMD platforms or Non-AMX CPUs. Fallback to other kernel libraries."
     )
     from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul
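The new `forward_hip` path fills a preallocated output tensor through sgl-kernel's `gelu_quick`; as a rough sanity check of what that is expected to compute, here is the native QuickGELU formula from the same class applied out-of-place (pure PyTorch, no HIP kernel involved):

```python
import torch

def quick_gelu_reference(x: torch.Tensor) -> torch.Tensor:
    # Same formula as QuickGELU.forward_native in the hunk above.
    return x * torch.sigmoid(1.702 * x)

x = torch.randn(4, 8)
out = torch.empty(x.shape, dtype=x.dtype, device=x.device)  # mirrors forward_hip's buffer
out.copy_(quick_gelu_reference(x))                          # stand-in for gelu_quick(x, out)
assert torch.allclose(out, x * torch.sigmoid(1.702 * x))
```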
sglang/srt/layers/attention/base_attn_backend.py
CHANGED
@@ -65,7 +65,9 @@ class AttentionBackend(ABC):
         **kwargs,
     ):
         """Run forward on an attention layer."""
-        if forward_batch.forward_mode.
+        if forward_batch.forward_mode.is_idle():
+            return q.new_empty(q.shape[0], layer.tp_q_head_num * layer.v_head_dim)
+        elif forward_batch.forward_mode.is_decode():
             return self.forward_decode(
                 q,
                 k,
sglang/srt/layers/communicator.py
CHANGED
@@ -24,8 +24,8 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.layers.dp_attention import (
-
-
+    attn_tp_all_gather_into_tensor,
+    attn_tp_reduce_scatter_tensor,
     dp_gather_partial,
     dp_scatter,
     get_attention_dp_size,
@@ -309,8 +309,8 @@ class CommunicateSimpleFn:
             forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
             hidden_states,
         )
-
-
+        attn_tp_all_gather_into_tensor(
+            hidden_states,
             local_hidden_states,
         )
         return hidden_states
@@ -400,9 +400,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
                 ].clone(),
                 residual,
             )
-
-                list(residual.tensor_split(context.attn_tp_size)), local_residual
-            )
+            attn_tp_all_gather_into_tensor(residual, local_residual)
         if context.attn_dp_size != 1:
             if context.attn_tp_rank == 0:
                 hidden_states += residual
@@ -442,9 +440,11 @@ class CommunicateWithAllReduceAndLayerNormFn:
         *,
         residual_input_mode,
     ):
-
-        hidden_states =
-
+        input_hidden_states = hidden_states
+        hidden_states = hidden_states.tensor_split(context.attn_tp_size)[
+            context.attn_tp_rank
+        ]
+        attn_tp_reduce_scatter_tensor(hidden_states, input_hidden_states)
         if residual_input_mode == ScatterMode.TP_ATTN_FULL:
             residual = residual.tensor_split(context.attn_tp_size)[context.attn_tp_rank]
         if hidden_states.shape[0] != 0:
@@ -547,8 +547,8 @@ class CommunicateSummableTensorPairFn:
             forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
             hidden_states,
         )
-
-
+        attn_tp_all_gather_into_tensor(
+            hidden_states,
             local_hidden_states,
         )
         return hidden_states, residual
sglang/srt/layers/dp_attention.py
CHANGED
@@ -3,7 +3,8 @@ from __future__ import annotations
 import functools
 import logging
 from contextlib import contextmanager
-from
+from enum import IntEnum, auto
+from typing import TYPE_CHECKING, List, Tuple
 
 import torch
 import triton
@@ -30,6 +31,34 @@ _LOCAL_ATTN_DP_SIZE = None
 _LOCAL_ATTN_DP_RANK = None
 
 
+class DPPaddingMode(IntEnum):
+
+    # Padding tokens to max length and then gather tokens using `all_gather_into_tensor`
+    MAX_LEN = auto()
+    # Padding tokens to sum length and then gather tokens using `all_reduce`
+    SUM_LEN = auto()
+
+    def is_max_len(self):
+        return self == DPPaddingMode.MAX_LEN
+
+    def is_sum_len(self):
+        return self == DPPaddingMode.SUM_LEN
+
+    @classmethod
+    def get_dp_padding_mode(cls, global_num_tokens: List[int]) -> DPPaddingMode:
+        # we choose the mode that minimizes the communication cost
+        max_len = max(global_num_tokens)
+        sum_len = sum(global_num_tokens)
+        if sum_len * 2 > max_len * get_attention_dp_size():
+            return cls.MAX_LEN
+        else:
+            return cls.SUM_LEN
+
+    @classmethod
+    def get_default_mode_in_cuda_graph(cls) -> DPPaddingMode:
+        return cls.MAX_LEN
+
+
 def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
     if not enable_dp_attention:
         return tp_rank, tp_size, 0
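To make the selection rule in `get_dp_padding_mode` concrete, a standalone sketch of the same comparison (token counts and DP size invented; one plausible reading of the factor of two is that an all-reduce moves roughly twice the payload of a plain gather):

```python
from typing import List

def choose_padding_mode(global_num_tokens: List[int], dp_size: int) -> str:
    # MAX_LEN pads every rank to the longest batch and gathers max_len * dp_size tokens;
    # SUM_LEN pads to the total length and all-reduces roughly 2 * sum_len tokens.
    max_len = max(global_num_tokens)
    sum_len = sum(global_num_tokens)
    return "MAX_LEN" if sum_len * 2 > max_len * dp_size else "SUM_LEN"

# Skewed batches: one rank dominates, so the all-reduce over the summed length is cheaper.
print(choose_padding_mode([512, 8, 8, 8], dp_size=4))        # SUM_LEN (1072 <= 2048)
# Balanced batches: padding to the max wastes little, so the all-gather path wins.
print(choose_padding_mode([256, 250, 248, 252], dp_size=4))  # MAX_LEN (2012 > 1024)
```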
@@ -162,7 +191,7 @@ def disable_dp_size():
         _ATTN_DP_SIZE = old_dp_size
 
 
-def get_dp_local_info(forward_batch: ForwardBatch):
+def get_dp_local_info(forward_batch: ForwardBatch) -> Tuple[torch.Tensor, torch.Tensor]:
     # `get_dp_local_info` is only called in global DP gather and scatter. We use global DP rank here.
     dp_rank = get_attention_dp_rank()
 
@@ -221,7 +250,7 @@ def memcpy_triton(dst, src, dim, offset, sz, offset_src):
     memcpy_triton_kernel[grid](dst, src, offset, sz, offset_src, chunk_size, BLOCK_SIZE)
 
 
-def _dp_gather(
+def _dp_gather_via_all_reduce(
     global_tokens: torch.Tensor,
     local_tokens: torch.Tensor,
     forward_batch: ForwardBatch,
@@ -238,13 +267,6 @@ def _dp_gather(
         local_tokens.untyped_storage() is not global_tokens.untyped_storage()
     ), "aliasing between global_tokens and local_tokens not allowed"
 
-    # NOTE: During draft extend, the gathered_buffer is padded to num_tokens * (speculative_num_steps + 1).
-    # But the size of local_tokens is total accepted tokens. We need to reduce the local_num_tokens to the
-    # actual size of the accepted tokens.
-    if forward_batch.forward_mode.is_draft_extend():
-        shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
-        local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
-
     memcpy_triton(
         global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
     )
@@ -263,6 +285,38 @@
     global_tokens[:] = tensor_model_parallel_all_reduce(global_tokens)
 
 
+def _dp_gather_via_all_gather(
+    global_tokens: torch.Tensor,
+    local_tokens: torch.Tensor,
+    forward_batch: ForwardBatch,
+    is_partial: bool,
+):
+    if not is_partial:
+        if get_attention_tp_rank() != 0:
+            local_tokens.fill_(0)
+    scattered_local_tokens = local_tokens.tensor_split(get_attention_tp_size())[
+        get_attention_tp_rank()
+    ]
+    get_attention_tp_group().reduce_scatter_tensor(scattered_local_tokens, local_tokens)
+    get_tp_group().all_gather_into_tensor(global_tokens, scattered_local_tokens)
+
+
+def _dp_gather(
+    global_tokens: torch.Tensor,
+    local_tokens: torch.Tensor,
+    forward_batch: ForwardBatch,
+    is_partial: bool,
+):
+    if forward_batch.dp_padding_mode.is_max_len():
+        _dp_gather_via_all_gather(
+            global_tokens, local_tokens, forward_batch, is_partial
+        )
+    else:
+        _dp_gather_via_all_reduce(
+            global_tokens, local_tokens, forward_batch, is_partial
+        )
+
+
 def dp_gather_partial(
     global_tokens: torch.Tensor,
     local_tokens: torch.Tensor,
@@ -296,24 +350,18 @@ def dp_scatter(
         local_tokens.untyped_storage() is not global_tokens.untyped_storage()
     ), "aliasing between local_tokens and global_tokens not allowed"
 
-    # NOTE: During draft extend, the gathered_buffer is padded to num_tokens * (speculative_num_steps + 1).
-    # But the size of local_tokens is total accepted tokens. We need to reduce the local_num_tokens to the
-    # actual size of the accepted tokens.
-    if forward_batch.forward_mode.is_draft_extend():
-        shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
-        local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
-
     memcpy_triton(
         local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
    )
 
 
-def
-
-
-
-
+def attn_tp_reduce_scatter_tensor(output: torch.Tensor, input: torch.Tensor):
+    return get_attention_tp_group().reduce_scatter_tensor(output, input)
+
+
+def attn_tp_all_gather_into_tensor(output: torch.Tensor, input: torch.Tensor):
+    return get_attention_tp_group().all_gather_into_tensor(output, input)
 
 
-def attn_tp_all_gather(output_list: List[torch.Tensor],
-    return get_attention_tp_group().all_gather(
+def attn_tp_all_gather(output_list: List[torch.Tensor], input: torch.Tensor):
+    return get_attention_tp_group().all_gather(input, output_tensor_list=output_list)
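For context on the two new wrappers, `reduce_scatter_tensor` and `all_gather_into_tensor` follow the usual torch.distributed shape contract: the reduce-scatter output is 1/group_size of its input along dim 0, and the all-gather output is group_size times its input. A shape-only sketch of the bookkeeping `_dp_gather_via_all_gather` relies on (no process groups are created; the sizes are invented):

```python
import torch

attn_tp_size, attn_dp_size, hidden = 2, 4, 128
local_tokens = torch.randn(64, hidden)  # one rank's padded batch in MAX_LEN mode

# reduce_scatter_tensor within the attention-TP group leaves each of the
# attn_tp_size ranks holding one contiguous slice of the summed tensor.
scattered = local_tokens.tensor_split(attn_tp_size)[0]
assert scattered.shape == (64 // attn_tp_size, hidden)

# all_gather_into_tensor across the full TP group (attn_tp_size * attn_dp_size ranks)
# then concatenates those slices into the global buffer of max_len * dp_size tokens.
global_tokens = torch.empty(64 * attn_dp_size, hidden)
assert global_tokens.shape[0] == scattered.shape[0] * attn_tp_size * attn_dp_size
```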
sglang/srt/layers/logits_processor.py
CHANGED
@@ -27,7 +27,9 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_gather,
 )
 from sglang.srt.layers.dp_attention import (
+    DPPaddingMode,
     attn_tp_all_gather,
+    attn_tp_all_gather_into_tensor,
     dp_gather_replicate,
     dp_scatter,
     get_attention_dp_rank,
@@ -111,7 +113,8 @@ class LogitsMetadata:
     # Number of tokens to sample per DP rank
     global_num_tokens_for_logprob_cpu: Optional[torch.Tensor] = None
     global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None
-
+    # The gather mode for DP attention
+    dp_padding_mode: Optional[DPPaddingMode] = None
     # for padding
     padded_static_len: int = -1
 
@@ -163,12 +166,12 @@ class LogitsMetadata:
             forward_batch_gathered_buffer=forward_batch.gathered_buffer,
             global_num_tokens_for_logprob_cpu=forward_batch.global_num_tokens_for_logprob_cpu,
             global_num_tokens_for_logprob_gpu=forward_batch.global_num_tokens_for_logprob_gpu,
+            dp_padding_mode=DPPaddingMode.SUM_LEN,
         )
 
-    def compute_dp_attention_metadata(self
-
-
-        return
+    def compute_dp_attention_metadata(self):
+        # TODO(ch-wan): gathered_buffer here is larger than the actual required size in draft extend,
+        # we may use a smaller buffer in draft extend.
 
         cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0)
         dp_rank = get_attention_dp_rank()
@@ -179,18 +182,9 @@ class LogitsMetadata:
         else:
             dp_local_start_pos = cumtokens[dp_rank - 1]
         dp_local_num_tokens = self.global_num_tokens_for_logprob_gpu[dp_rank]
-        gathered_buffer = torch.zeros(
-            (
-                sum(self.global_num_tokens_for_logprob_cpu),
-                hidden_states.shape[1],
-            ),
-            dtype=hidden_states.dtype,
-            device=hidden_states.device,
-        )
 
         self.dp_local_start_pos = dp_local_start_pos
         self.dp_local_num_tokens = dp_local_num_tokens
-        self.gathered_buffer = gathered_buffer
 
 
 class LogitsProcessor(nn.Module):
@@ -434,7 +428,7 @@ class LogitsProcessor(nn.Module):
         guarantee the given hidden_states follow this constraint.
         """
         if self.do_tensor_parallel_all_gather_dp_attn:
-            logits_metadata.compute_dp_attention_metadata(
+            logits_metadata.compute_dp_attention_metadata()
             hidden_states, local_hidden_states = (
                 torch.empty_like(logits_metadata.gathered_buffer),
                 hidden_states,
@@ -463,15 +457,31 @@
 
         if self.do_tensor_parallel_all_gather:
             if self.use_attn_tp_group:
-
-
-
-
-
-
-
-
-
+                if self.config.vocab_size % self.attn_tp_size == 0:
+                    global_logits = torch.empty(
+                        (
+                            self.attn_tp_size,
+                            logits.shape[0],
+                            self.config.vocab_size // self.attn_tp_size,
+                        ),
+                        device=logits.device,
+                        dtype=logits.dtype,
+                    )
+                    attn_tp_all_gather_into_tensor(global_logits, logits)
+                    global_logits = global_logits.permute(1, 0, 2).reshape(
+                        logits.shape[0], self.config.vocab_size
+                    )
+                else:
+                    global_logits = torch.empty(
+                        (self.config.vocab_size, logits.shape[0]),
+                        device=logits.device,
+                        dtype=logits.dtype,
+                    )
+                    global_logits = global_logits.T
+                    attn_tp_all_gather(
+                        list(global_logits.tensor_split(self.attn_tp_size, dim=-1)),
+                        logits,
+                    )
                 logits = global_logits
             else:
                 logits = tensor_model_parallel_all_gather(logits)
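A standalone check (shapes and values invented) of why the permute/reshape in the new `use_attn_tp_group` branch recovers full-vocab logits: each attention-TP rank holds one contiguous vocab shard, `all_gather_into_tensor` stacks the shards along a new leading dimension, and the permute puts tokens first before the reshape flattens the shards back into vocab order:

```python
import torch

attn_tp_size, num_tokens, vocab_size = 4, 3, 16
shard = vocab_size // attn_tp_size

# Full logits that the sharded computation should reproduce.
full = torch.arange(num_tokens * vocab_size, dtype=torch.float32).reshape(num_tokens, vocab_size)

# Each rank computes logits only for its contiguous vocab shard.
per_rank = [full[:, r * shard : (r + 1) * shard] for r in range(attn_tp_size)]

# all_gather_into_tensor stacks rank outputs along dim 0: [attn_tp_size, num_tokens, shard].
gathered = torch.stack(per_rank, dim=0)

# Same permute + reshape as in the hunk above.
reconstructed = gathered.permute(1, 0, 2).reshape(num_tokens, vocab_size)
assert torch.equal(reconstructed, full)
```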