sglang 0.4.9.post3__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75)
  1. sglang/srt/_custom_ops.py +29 -1
  2. sglang/srt/configs/model_config.py +1 -1
  3. sglang/srt/conversation.py +1 -1
  4. sglang/srt/disaggregation/common/conn.py +34 -6
  5. sglang/srt/disaggregation/mini_lb.py +3 -2
  6. sglang/srt/disaggregation/mooncake/conn.py +49 -20
  7. sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
  8. sglang/srt/disaggregation/nixl/conn.py +17 -13
  9. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
  10. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
  11. sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
  12. sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
  13. sglang/srt/distributed/parallel_state.py +70 -15
  14. sglang/srt/entrypoints/engine.py +2 -8
  15. sglang/srt/entrypoints/http_server.py +20 -32
  16. sglang/srt/entrypoints/openai/protocol.py +3 -3
  17. sglang/srt/entrypoints/openai/serving_chat.py +27 -4
  18. sglang/srt/function_call/base_format_detector.py +74 -12
  19. sglang/srt/function_call/deepseekv3_detector.py +26 -11
  20. sglang/srt/function_call/ebnf_composer.py +95 -63
  21. sglang/srt/function_call/function_call_parser.py +4 -4
  22. sglang/srt/function_call/kimik2_detector.py +41 -16
  23. sglang/srt/function_call/llama32_detector.py +6 -3
  24. sglang/srt/function_call/mistral_detector.py +11 -3
  25. sglang/srt/function_call/pythonic_detector.py +16 -14
  26. sglang/srt/function_call/qwen25_detector.py +12 -3
  27. sglang/srt/function_call/{qwen3_detector.py → qwen3_coder_detector.py} +10 -9
  28. sglang/srt/layers/activation.py +11 -3
  29. sglang/srt/layers/attention/base_attn_backend.py +3 -1
  30. sglang/srt/layers/communicator.py +12 -12
  31. sglang/srt/layers/dp_attention.py +72 -24
  32. sglang/srt/layers/logits_processor.py +34 -24
  33. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
  34. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +25 -224
  35. sglang/srt/layers/moe/topk.py +5 -13
  36. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -9
  37. sglang/srt/layers/quantization/modelopt_quant.py +8 -4
  38. sglang/srt/layers/quantization/utils.py +0 -9
  39. sglang/srt/layers/radix_attention.py +5 -3
  40. sglang/srt/lora/lora_manager.py +133 -169
  41. sglang/srt/lora/lora_registry.py +124 -0
  42. sglang/srt/lora/mem_pool.py +2 -2
  43. sglang/srt/managers/cache_controller.py +53 -6
  44. sglang/srt/managers/io_struct.py +19 -1
  45. sglang/srt/managers/schedule_batch.py +13 -3
  46. sglang/srt/managers/scheduler.py +13 -25
  47. sglang/srt/managers/tokenizer_manager.py +28 -25
  48. sglang/srt/managers/tp_worker.py +2 -4
  49. sglang/srt/mem_cache/allocator.py +67 -7
  50. sglang/srt/mem_cache/hicache_storage.py +17 -1
  51. sglang/srt/mem_cache/hiradix_cache.py +30 -16
  52. sglang/srt/mem_cache/memory_pool_host.py +3 -0
  53. sglang/srt/model_executor/cuda_graph_runner.py +61 -25
  54. sglang/srt/model_executor/forward_batch_info.py +201 -29
  55. sglang/srt/model_executor/model_runner.py +41 -23
  56. sglang/srt/models/deepseek_v2.py +1 -2
  57. sglang/srt/models/mllama4.py +10 -3
  58. sglang/srt/models/qwen2_moe.py +0 -4
  59. sglang/srt/models/qwen3_moe.py +1 -6
  60. sglang/srt/reasoning_parser.py +46 -4
  61. sglang/srt/sampling/sampling_batch_info.py +6 -5
  62. sglang/srt/server_args.py +76 -55
  63. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
  64. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +37 -36
  65. sglang/srt/speculative/eagle_utils.py +51 -23
  66. sglang/srt/speculative/eagle_worker.py +59 -44
  67. sglang/srt/two_batch_overlap.py +9 -5
  68. sglang/srt/utils.py +17 -68
  69. sglang/test/test_activation.py +50 -1
  70. sglang/version.py +1 -1
  71. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +5 -5
  72. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +75 -72
  73. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
  74. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
  75. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
sglang/srt/function_call/kimik2_detector.py
@@ -18,16 +18,21 @@ logger = logging.getLogger(__name__)
 
 
 class KimiK2Detector(BaseFormatDetector):
+    """
+    Detector for Kimi K2 model function call format.
+
+    Format Structure:
+    ```
+    <|tool_calls_section_begin|>
+    <|tool_call_begin|>functions.{func_name}:{index} <|tool_call_argument_begin|>{json_args}<|tool_call_end|>
+    <|tool_calls_section_end|>
+    ```
+
+    Reference: https://huggingface.co/moonshotai/Kimi-K2-Instruct/blob/main/docs/tool_call_guidance.md
+    """
 
     def __init__(self):
         super().__init__()
-        self._buffer = ""
-        self.current_tool_name_sent: bool = False
-        self.prev_tool_call_arr: list[dict] = []
-        self.current_tool_id: int = -1
-        self.streamed_args_for_tool: list[str] = (
-            []
-        )  # map what has been streamed for each tool so far to a list
 
         self.bot_token: str = "<|tool_calls_section_begin|>"
         self.eot_token: str = "<|tool_calls_section_end|>"
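For reference, here is a hedged, standalone sketch of how text in the documented Kimi K2 format can be decoded; the token strings come from the docstring above, while the helper name and the `get_weather` sample are illustrative only (the detector itself does this incrementally with streaming support).

```python
import json
import re

# Hypothetical offline parser mirroring the documented Kimi K2 layout.
TOOL_CALL_RE = re.compile(
    r"<\|tool_call_begin\|>functions\.(?P<name>[\w\.]+):(?P<index>\d+)\s*"
    r"<\|tool_call_argument_begin\|>(?P<args>.*?)<\|tool_call_end\|>",
    re.DOTALL,
)

def parse_kimi_k2_tool_calls(text: str) -> list[dict]:
    """Extract (name, index, arguments) triples from a tool call section."""
    calls = []
    for m in TOOL_CALL_RE.finditer(text):
        calls.append(
            {
                "name": m.group("name"),
                "index": int(m.group("index")),
                "arguments": json.loads(m.group("args")),
            }
        )
    return calls

sample = (
    "<|tool_calls_section_begin|>"
    "<|tool_call_begin|>functions.get_weather:0 "
    '<|tool_call_argument_begin|>{"city": "Paris"}<|tool_call_end|>'
    "<|tool_calls_section_end|>"
)
print(parse_kimi_k2_tool_calls(sample))
# [{'name': 'get_weather', 'index': 0, 'arguments': {'city': 'Paris'}}]
```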
@@ -114,11 +119,7 @@ class KimiK2Detector(BaseFormatDetector):
             return StreamingParseResult(normal_text=new_text)
 
         if not hasattr(self, "_tool_indices"):
-            self._tool_indices = {
-                tool.function.name: i
-                for i, tool in enumerate(tools)
-                if tool.function and tool.function.name
-            }
+            self._tool_indices = self._get_tool_indices(tools)
 
         calls: list[ToolCallItem] = []
         try:
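Several detectors in this release (KimiK2, Pythonic, and others below) drop their local dict comprehensions in favor of a shared `_get_tool_indices` helper on `BaseFormatDetector`. A minimal sketch of what that helper plausibly does, reconstructed from the comprehension removed above; the exact signature in base_format_detector.py may differ.

```python
from typing import Dict, List

class BaseFormatDetectorSketch:
    """Illustrative stand-in for BaseFormatDetector; only the helper is shown."""

    def _get_tool_indices(self, tools: List) -> Dict[str, int]:
        # Same mapping the removed per-detector comprehensions built:
        # tool name -> position in the `tools` list, skipping unnamed tools.
        return {
            tool.function.name: i
            for i, tool in enumerate(tools)
            if tool.function and tool.function.name
        }
```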
@@ -150,7 +151,7 @@
                         )
                     )
                     self.current_tool_name_sent = True
-                    # Store the tool call info for adapter.py
+                    # Store the tool call info for serving layer completions endpoint
                     self.prev_tool_call_arr[self.current_tool_id] = {
                         "name": function_name,
                         "arguments": {},
@@ -214,7 +215,31 @@
             return StreamingParseResult(normal_text=current_text)
 
     def structure_info(self) -> _GetInfoFunc:
-        raise NotImplementedError()
+        """Return function that creates StructureInfo for guided generation."""
+
+        def get_info(name: str) -> StructureInfo:
+            return StructureInfo(
+                begin=f"<|tool_calls_section_begin|><|tool_call_begin|>functions.{name}:0 <|tool_call_argument_begin|>",
+                end="<|tool_call_end|><|tool_calls_section_end|>",
+                trigger="<|tool_calls_section_begin|>",
+            )
+
+        return get_info
 
-    def build_ebnf(self, tools: List[Tool]):
-        raise NotImplementedError()
+    def build_ebnf(self, tools: List[Tool]) -> str:
+        """
+        Build EBNF grammar for KimiK2 tool call format.
+
+        NOTE: The call_rule_fmt uses [0-9]+ for the function index to allow the grammar
+        to accept any numeric index (0, 1, 2, etc.) for proper sequential indexing in
+        multiple function call scenarios, while still maintaining the correct KimiK2
+        format structure for constrained generation.
+        """
+        return EBNFComposer.build_ebnf(
+            tools,
+            sequence_start_token=self.bot_token,
+            sequence_end_token=self.eot_token,
+            tool_call_separator="",
+            call_rule_fmt='"<|tool_call_begin|>functions.{name}:" [0-9]+ " <|tool_call_argument_begin|>" {arguments_rule} "<|tool_call_end|>"',
+            function_format="json",
+        )
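To illustrate the new constrained-generation hooks, a hedged usage sketch of `structure_info()`; the function name `get_weather` is only an example, and the printed values follow directly from the StructureInfo fields defined in the hunk above.

```python
from sglang.srt.function_call.kimik2_detector import KimiK2Detector

detector = KimiK2Detector()
info = detector.structure_info()("get_weather")

# Per the definition above, the guided-generation scaffold for one call is:
print(info.begin)
# <|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0 <|tool_call_argument_begin|>
print(info.end)      # <|tool_call_end|><|tool_calls_section_end|>
print(info.trigger)  # <|tool_calls_section_begin|>
```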
sglang/srt/function_call/llama32_detector.py
@@ -16,9 +16,12 @@ logger = logging.getLogger(__name__)
 
 class Llama32Detector(BaseFormatDetector):
     """
-    Detector for Llama 3.2 models.
-    Assumes function call format:
-    <|python_tag|>{"name":"xxx", "arguments":{...}}
+    Detector for Llama 3.2 models with json tool call format.
+
+    Format Structure:
+    ```
+    <python_tag>{"name":"xxx", "arguments":{...}}
+    ```
     """
 
     def __init__(self):
sglang/srt/function_call/mistral_detector.py
@@ -17,9 +17,17 @@ logger = logging.getLogger(__name__)
 
 class MistralDetector(BaseFormatDetector):
     """
-    Detector for Mistral models.
-    Assumes function call format:
-    [TOOL_CALLS] [{"name":"func1", "arguments":{...}}, {"name":"func2", "arguments":{...}}]
+    Detector for Mistral model function call format.
+
+    The Mistral format uses a simple bracket-delimited structure with JSON arrays
+    containing function call objects.
+
+    Format Structure:
+    ```
+    [TOOL_CALLS] [{"name": "function_name", "arguments": {json_args}}, ...]
+    ```
+
+    Reference: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3?chat_template=default
     """
 
     def __init__(self):
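A hedged, standalone sketch of extracting calls from the documented Mistral format; the regex and sample payload are illustrative and not the detector's actual implementation.

```python
import json
import re

def parse_mistral_tool_calls(text: str) -> list[dict]:
    """Pull the JSON array that follows [TOOL_CALLS] and decode it."""
    match = re.search(r"\[TOOL_CALLS\]\s*(\[.*\])", text, re.DOTALL)
    if not match:
        return []
    return json.loads(match.group(1))

sample = '[TOOL_CALLS] [{"name": "get_weather", "arguments": {"city": "Paris"}}]'
print(parse_mistral_tool_calls(sample))
# [{'name': 'get_weather', 'arguments': {'city': 'Paris'}}]
```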
sglang/srt/function_call/pythonic_detector.py
@@ -8,7 +8,6 @@ from sglang.srt.entrypoints.openai.protocol import Tool
 from sglang.srt.function_call.base_format_detector import BaseFormatDetector
 from sglang.srt.function_call.core_types import (
     StreamingParseResult,
-    StructureInfo,
     ToolCallItem,
     _GetInfoFunc,
 )
@@ -19,10 +18,17 @@ logger = logging.getLogger(__name__)
 
 class PythonicDetector(BaseFormatDetector):
     """
-    Detector for Llama-3.2 and Llama-4 models with pythonic tool call format.
-    Assumes function call format:
-    [tool1(arg1=val1, arg2=val2), tool2(arg1=val3)]
-    Arguments are Python literals (not JSON).
+    Detector for Llama-4 models with Pythonic tool call format.
+
+    The Pythonic format uses Python function call syntax within square brackets,
+    with arguments as Python literals rather than JSON.
+
+    Format Structure:
+    ```
+    [tool1(arg1=val1, arg2=val2), tool2(arg1=val3)]
+    ```
+
+    Reference: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct?chat_template=default
     """
 
     def __init__(self):
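As the docstring notes, the Pythonic format is a Python list of call expressions whose arguments are literals, so it can be decoded with the standard `ast` module. A hedged sketch of that approach, which mirrors what the detector does (the function name and sample call are illustrative):

```python
import ast

def parse_pythonic_calls(text: str) -> list[dict]:
    """Parse '[tool(a=1, b="x"), ...]' into name/arguments dicts using ast."""
    parsed = ast.parse(text.strip(), mode="eval").body
    if not isinstance(parsed, ast.List):
        return []
    calls = []
    for call in parsed.elts:
        if not isinstance(call, ast.Call) or not isinstance(call.func, ast.Name):
            continue
        calls.append(
            {
                "name": call.func.id,
                # Keyword values must be Python literals, per the format.
                "arguments": {
                    kw.arg: ast.literal_eval(kw.value) for kw in call.keywords
                },
            }
        )
    return calls

print(parse_pythonic_calls('[get_weather(city="Paris", days=3)]'))
# [{'name': 'get_weather', 'arguments': {'city': 'Paris', 'days': 3}}]
```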
@@ -75,11 +81,7 @@ class PythonicDetector(BaseFormatDetector):
             return StreamingParseResult(normal_text=normal_text, calls=[])
 
         calls = []
-        tool_indices = {
-            tool.function.name: i
-            for i, tool in enumerate(tools)
-            if tool.function.name
-        }
+        tool_indices = self._get_tool_indices(tools)
         for call_index, call in enumerate(parsed.elts):
             if not isinstance(call.func, ast.Name):
                 continue
@@ -213,11 +215,11 @@ class PythonicDetector(BaseFormatDetector):
         else:
             raise ValueError("Tool call arguments must be literals")
 
-    def structure_info(self) -> _GetInfoFunc:
-        def info(name: str):
-            return StructureInfo(begin=f"[{name}(", end=")]", trigger=f"[{name}(")
+    def supports_structural_tag(self) -> bool:
+        return False
 
-        return info
+    def structure_info(self) -> _GetInfoFunc:
+        raise NotImplementedError
 
     def build_ebnf(self, tools: List[Tool]) -> Optional[str]:
         return EBNFComposer.build_ebnf(
sglang/srt/function_call/qwen25_detector.py
@@ -17,9 +17,18 @@ logger = logging.getLogger(__name__)
 
 class Qwen25Detector(BaseFormatDetector):
     """
-    Detector for Qwen 2.5 models.
-    Assumes function call format:
-    <tool_call>\n{"name":"func1", "arguments":{...}}\n</tool_call>\n<tool_call>\n{"name":"func2", "arguments":{...}}\n</tool_call>
+    Detector for Qwen 2.5 and Qwen 3 model function call format.
+
+    Format Structure:
+    ```
+    <tool_call>\n{"name":"func1", "arguments":{...}}\n</tool_call>\n<tool_call>\n{"name":"func2", "arguments":{...}}\n</tool_call>
+    ```
+
+    Key Components:
+    - Tool Call Tags: `<tool_call>` and `</tool_call>` wrap each individual call
+    - Function Call Object: JSON object with "name" and "arguments" fields
+
+    Reference: https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct?chat_template=default
     """
 
     def __init__(self):
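A hedged sketch of decoding the documented Qwen 2.5 layout outside the detector; the regex and sample are illustrative only.

```python
import json
import re

TOOL_CALL_BLOCK_RE = re.compile(r"<tool_call>\s*(.*?)\s*</tool_call>", re.DOTALL)

def parse_qwen_tool_calls(text: str) -> list[dict]:
    """Decode every JSON object wrapped in <tool_call>...</tool_call> tags."""
    return [json.loads(block) for block in TOOL_CALL_BLOCK_RE.findall(text)]

sample = (
    '<tool_call>\n{"name": "func1", "arguments": {"x": 1}}\n</tool_call>\n'
    '<tool_call>\n{"name": "func2", "arguments": {"y": 2}}\n</tool_call>'
)
print(parse_qwen_tool_calls(sample))
# [{'name': 'func1', 'arguments': {'x': 1}}, {'name': 'func2', 'arguments': {'y': 2}}]
```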
sglang/srt/function_call/qwen3_coder_detector.py (renamed from qwen3_detector.py)
@@ -9,7 +9,6 @@ from sglang.srt.entrypoints.openai.protocol import Tool
 from sglang.srt.function_call.base_format_detector import BaseFormatDetector
 from sglang.srt.function_call.core_types import (
     StreamingParseResult,
-    StructureInfo,
     ToolCallItem,
     _GetInfoFunc,
 )
@@ -29,7 +28,7 @@ def _safe_val(raw: str) -> Any:
     return raw
 
 
-class Qwen3XMLDetector(BaseFormatDetector):
+class Qwen3CoderDetector(BaseFormatDetector):
     """
     Detector for Qwen 3 models.
     Assumes function call format:
@@ -127,24 +126,26 @@ class Qwen3XMLDetector(BaseFormatDetector):
                 params[pname] = _safe_val(pval)
             raw = {"name": fname, "arguments": params}
             try:
+                # TODO: fix idx in function call, the index for a function
+                # call will always be -1 in parse_base_json
                 res.extend(self.parse_base_json(raw, tools))
             except Exception:
                 logger.warning("invalid tool call for %s dropped", fname)
         return res
 
+    def supports_structural_tag(self) -> bool:
+        return False
+
     def structure_info(self) -> _GetInfoFunc:
-        return lambda n: StructureInfo(
-            begin=f"{self.tool_call_start_token}\n<function={n}>",
-            end=f"</function>\n{self.tool_call_end_token}",
-            trigger=self.tool_call_start_token,
-        )
+        raise NotImplementedError
 
-    # TODO: fake ebnf for xml + outlines backend
     def build_ebnf(self, tools: List[Tool]):
        return EBNFComposer.build_ebnf(
            tools,
            individual_call_start_token=self.tool_call_start_token.replace("\n", "\\n"),
            individual_call_end_token=self.tool_call_end_token.replace("\n", "\\n"),
            tool_call_separator="\\n",
-           function_format="json",
+           function_format="xml",
+           call_rule_fmt='"<function={name}>\\n" {arguments_rule} "\\n</function>"',
+           key_value_rule_fmt='"<parameter={key}>\\n" {valrule} "\\n</parameter>"',
        )
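The switch to `function_format="xml"` with `call_rule_fmt`/`key_value_rule_fmt` constrains generation to the tag-per-parameter layout this detector parses. A hedged sketch of what one such call looks like and how it can be decoded outside sglang; the tag names follow the rule formats above, while the sample values and the `_safe_val` stand-in are illustrative.

```python
import json
import re

FUNC_RE = re.compile(r"<function=([^>]+)>(.*?)</function>", re.DOTALL)
PARAM_RE = re.compile(r"<parameter=([^>]+)>\n(.*?)\n</parameter>", re.DOTALL)

def _safe_val(raw: str):
    """Best-effort literal decoding, a stand-in for the detector's helper."""
    try:
        return json.loads(raw)
    except json.JSONDecodeError:
        return raw

def parse_qwen3_coder_calls(text: str) -> list[dict]:
    calls = []
    for fname, body in FUNC_RE.findall(text):
        params = {key: _safe_val(val) for key, val in PARAM_RE.findall(body)}
        calls.append({"name": fname, "arguments": params})
    return calls

sample = (
    "<tool_call>\n<function=get_weather>\n"
    "<parameter=city>\nParis\n</parameter>\n"
    "<parameter=days>\n3\n</parameter>\n"
    "</function>\n</tool_call>"
)
print(parse_qwen3_coder_calls(sample))
# [{'name': 'get_weather', 'arguments': {'city': 'Paris', 'days': 3}}]
```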
sglang/srt/layers/activation.py
@@ -33,6 +33,7 @@ from sglang.srt.utils import (
     cpu_has_amx_support,
     is_cpu,
     is_cuda,
+    is_hip,
     is_npu,
     set_weight_attrs,
 )
@@ -42,9 +43,12 @@ _is_cuda = is_cuda()
 _is_npu = is_npu()
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
+_is_hip = is_hip()
 
 if _is_cuda:
     from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
+elif _is_hip:
+    from sgl_kernel import gelu_and_mul, gelu_quick, gelu_tanh_and_mul, silu_and_mul
 
 if is_npu():
     import torch_npu
@@ -126,9 +130,13 @@ class QuickGELU(CustomOp):
         return x * torch.sigmoid(1.702 * x)
 
     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
-        # TODO(zhyncs): Implement the CUDA kernel for QuickGELU in sgl-kernel
         return self.forward_native(x)
 
+    def forward_hip(self, x: torch.Tensor) -> torch.Tensor:
+        out = torch.empty(x.shape, dtype=x.dtype, device=x.device)
+        gelu_quick(x, out)
+        return out
+
 
 class ScaledActivation(nn.Module):
     """An activation function with post-scale parameters.
@@ -222,8 +230,8 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
     return nn.Identity()
 
 
-if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
+if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip):
     logger.info(
-        "sgl-kernel is not available on Non-NV platforms or Non-AMX CPUs. Fallback to other kernel libraries."
+        "sgl-kernel is not available on Non-NV, Non-AMD platforms or Non-AMX CPUs. Fallback to other kernel libraries."
     )
     from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul
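The new `forward_hip` path offloads QuickGELU to the `gelu_quick` kernel from sgl-kernel on ROCm. A hedged parity-check sketch against the reference formula `x * sigmoid(1.702 * x)`; it assumes an AMD GPU with a ROCm build of sgl-kernel and is only illustrative of the kind of check added in sglang/test/test_activation.py, not a copy of it.

```python
import torch
from sglang.srt.layers.activation import QuickGELU

# Illustrative parity check: the HIP kernel path should match the reference
# formula x * sigmoid(1.702 * x) up to numerical tolerance. Requires a ROCm
# build of sgl-kernel; on other platforms only forward_native is exercised.
act = QuickGELU()
x = torch.randn(128, 4096, dtype=torch.float16, device="cuda")  # "cuda" maps to HIP on ROCm

ref = act.forward_native(x)
if hasattr(act, "forward_hip"):
    out = act.forward_hip(x)
    torch.testing.assert_close(out, ref, rtol=2e-2, atol=2e-2)
    print("QuickGELU HIP kernel matches the native reference")
```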
sglang/srt/layers/attention/base_attn_backend.py
@@ -65,7 +65,9 @@ class AttentionBackend(ABC):
         **kwargs,
     ):
         """Run forward on an attention layer."""
-        if forward_batch.forward_mode.is_decode():
+        if forward_batch.forward_mode.is_idle():
+            return q.new_empty(q.shape[0], layer.tp_q_head_num * layer.v_head_dim)
+        elif forward_batch.forward_mode.is_decode():
             return self.forward_decode(
                 q,
                 k,
sglang/srt/layers/communicator.py
@@ -24,8 +24,8 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.layers.dp_attention import (
-    attn_tp_all_gather,
-    attn_tp_reduce_scatter,
+    attn_tp_all_gather_into_tensor,
+    attn_tp_reduce_scatter_tensor,
     dp_gather_partial,
     dp_scatter,
     get_attention_dp_size,
@@ -309,8 +309,8 @@ class CommunicateSimpleFn:
             forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
             hidden_states,
         )
-        attn_tp_all_gather(
-            list(hidden_states.tensor_split(context.attn_tp_size)),
+        attn_tp_all_gather_into_tensor(
+            hidden_states,
             local_hidden_states,
         )
         return hidden_states
@@ -400,9 +400,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
                 ].clone(),
                 residual,
             )
-            attn_tp_all_gather(
-                list(residual.tensor_split(context.attn_tp_size)), local_residual
-            )
+            attn_tp_all_gather_into_tensor(residual, local_residual)
         if context.attn_dp_size != 1:
             if context.attn_tp_rank == 0:
                 hidden_states += residual
@@ -442,9 +440,11 @@
         *,
         residual_input_mode,
     ):
-        tensor_list = list(hidden_states.tensor_split(context.attn_tp_size))
-        hidden_states = tensor_list[context.attn_tp_rank]
-        attn_tp_reduce_scatter(hidden_states, tensor_list)
+        input_hidden_states = hidden_states
+        hidden_states = hidden_states.tensor_split(context.attn_tp_size)[
+            context.attn_tp_rank
+        ]
+        attn_tp_reduce_scatter_tensor(hidden_states, input_hidden_states)
         if residual_input_mode == ScatterMode.TP_ATTN_FULL:
             residual = residual.tensor_split(context.attn_tp_size)[context.attn_tp_rank]
         if hidden_states.shape[0] != 0:
@@ -547,8 +547,8 @@ class CommunicateSummableTensorPairFn:
             forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
             hidden_states,
         )
-        attn_tp_all_gather(
-            list(hidden_states.tensor_split(context.attn_tp_size)),
+        attn_tp_all_gather_into_tensor(
+            hidden_states,
             local_hidden_states,
         )
         return hidden_states, residual
sglang/srt/layers/dp_attention.py
@@ -3,7 +3,8 @@ from __future__ import annotations
 import functools
 import logging
 from contextlib import contextmanager
-from typing import TYPE_CHECKING, List
+from enum import IntEnum, auto
+from typing import TYPE_CHECKING, List, Tuple
 
 import torch
 import triton
@@ -30,6 +31,34 @@ _LOCAL_ATTN_DP_SIZE = None
 _LOCAL_ATTN_DP_RANK = None
 
 
+class DPPaddingMode(IntEnum):
+
+    # Padding tokens to max length and then gather tokens using `all_gather_into_tensor`
+    MAX_LEN = auto()
+    # Padding tokens to sum length and then gather tokens using `all_reduce`
+    SUM_LEN = auto()
+
+    def is_max_len(self):
+        return self == DPPaddingMode.MAX_LEN
+
+    def is_sum_len(self):
+        return self == DPPaddingMode.SUM_LEN
+
+    @classmethod
+    def get_dp_padding_mode(cls, global_num_tokens: List[int]) -> DPPaddingMode:
+        # we choose the mode that minimizes the communication cost
+        max_len = max(global_num_tokens)
+        sum_len = sum(global_num_tokens)
+        if sum_len * 2 > max_len * get_attention_dp_size():
+            return cls.MAX_LEN
+        else:
+            return cls.SUM_LEN
+
+    @classmethod
+    def get_default_mode_in_cuda_graph(cls) -> DPPaddingMode:
+        return cls.MAX_LEN
+
+
 def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
     if not enable_dp_attention:
         return tp_rank, tp_size, 0
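The selection rule above compares the approximate traffic of the two gather strategies: a ring all-reduce over a SUM_LEN buffer moves on the order of `2 * sum_len` tokens, while a padded all-gather moves roughly `max_len * dp_size`. A hedged sketch reproducing that decision for two example batches (token counts are made up, and `dp_size` is passed explicitly instead of read from `get_attention_dp_size()`):

```python
from enum import IntEnum, auto

class DPPaddingMode(IntEnum):
    MAX_LEN = auto()
    SUM_LEN = auto()

def get_dp_padding_mode(global_num_tokens, dp_size):
    # Same comparison as the classmethod above.
    max_len, sum_len = max(global_num_tokens), sum(global_num_tokens)
    return DPPaddingMode.MAX_LEN if sum_len * 2 > max_len * dp_size else DPPaddingMode.SUM_LEN

# Balanced ranks: padding to max_len wastes little, the padded all_gather wins.
print(get_dp_padding_mode([90, 95, 100], dp_size=3))  # DPPaddingMode.MAX_LEN (570 > 300)
# Highly skewed ranks: padding everyone to 100 is wasteful, all_reduce wins.
print(get_dp_padding_mode([5, 7, 100], dp_size=3))    # DPPaddingMode.SUM_LEN (224 <= 300)
```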
@@ -162,7 +191,7 @@ def disable_dp_size():
     _ATTN_DP_SIZE = old_dp_size
 
 
-def get_dp_local_info(forward_batch: ForwardBatch):
+def get_dp_local_info(forward_batch: ForwardBatch) -> Tuple[torch.Tensor, torch.Tensor]:
     # `get_dp_local_info` is only called in global DP gather and scatter. We use global DP rank here.
     dp_rank = get_attention_dp_rank()
 
@@ -221,7 +250,7 @@ def memcpy_triton(dst, src, dim, offset, sz, offset_src):
     memcpy_triton_kernel[grid](dst, src, offset, sz, offset_src, chunk_size, BLOCK_SIZE)
 
 
-def _dp_gather(
+def _dp_gather_via_all_reduce(
     global_tokens: torch.Tensor,
     local_tokens: torch.Tensor,
     forward_batch: ForwardBatch,
@@ -238,13 +267,6 @@
         local_tokens.untyped_storage() is not global_tokens.untyped_storage()
     ), "aliasing between global_tokens and local_tokens not allowed"
 
-    # NOTE: During draft extend, the gathered_buffer is padded to num_tokens * (speculative_num_steps + 1).
-    # But the size of local_tokens is total accepted tokens. We need to reduce the local_num_tokens to the
-    # actual size of the accepted tokens.
-    if forward_batch.forward_mode.is_draft_extend():
-        shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
-        local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
-
     memcpy_triton(
         global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
     )
@@ -263,6 +285,38 @@
         global_tokens[:] = tensor_model_parallel_all_reduce(global_tokens)
 
 
+def _dp_gather_via_all_gather(
+    global_tokens: torch.Tensor,
+    local_tokens: torch.Tensor,
+    forward_batch: ForwardBatch,
+    is_partial: bool,
+):
+    if not is_partial:
+        if get_attention_tp_rank() != 0:
+            local_tokens.fill_(0)
+    scattered_local_tokens = local_tokens.tensor_split(get_attention_tp_size())[
+        get_attention_tp_rank()
+    ]
+    get_attention_tp_group().reduce_scatter_tensor(scattered_local_tokens, local_tokens)
+    get_tp_group().all_gather_into_tensor(global_tokens, scattered_local_tokens)
+
+
+def _dp_gather(
+    global_tokens: torch.Tensor,
+    local_tokens: torch.Tensor,
+    forward_batch: ForwardBatch,
+    is_partial: bool,
+):
+    if forward_batch.dp_padding_mode.is_max_len():
+        _dp_gather_via_all_gather(
+            global_tokens, local_tokens, forward_batch, is_partial
+        )
+    else:
+        _dp_gather_via_all_reduce(
+            global_tokens, local_tokens, forward_batch, is_partial
+        )
+
+
 def dp_gather_partial(
     global_tokens: torch.Tensor,
     local_tokens: torch.Tensor,
@@ -296,24 +350,18 @@ def dp_scatter(
         local_tokens.untyped_storage() is not global_tokens.untyped_storage()
     ), "aliasing between local_tokens and global_tokens not allowed"
 
-    # NOTE: During draft extend, the gathered_buffer is padded to num_tokens * (speculative_num_steps + 1).
-    # But the size of local_tokens is total accepted tokens. We need to reduce the local_num_tokens to the
-    # actual size of the accepted tokens.
-    if forward_batch.forward_mode.is_draft_extend():
-        shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
-        local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
-
     memcpy_triton(
         local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
     )
 
 
-def attn_tp_reduce_scatter(
-    output: torch.Tensor,
-    input_list: List[torch.Tensor],
-):
-    return get_attention_tp_group().reduce_scatter(output, input_list)
+def attn_tp_reduce_scatter_tensor(output: torch.Tensor, input: torch.Tensor):
+    return get_attention_tp_group().reduce_scatter_tensor(output, input)
+
+
+def attn_tp_all_gather_into_tensor(output: torch.Tensor, input: torch.Tensor):
+    return get_attention_tp_group().all_gather_into_tensor(output, input)
 
 
-def attn_tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor):
-    return get_attention_tp_group().all_gather(input_, output_tensor_list=output_list)
+def attn_tp_all_gather(output_list: List[torch.Tensor], input: torch.Tensor):
    return get_attention_tp_group().all_gather(input, output_tensor_list=output_list)
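The new `attn_tp_reduce_scatter_tensor` / `attn_tp_all_gather_into_tensor` wrappers operate on whole tensors instead of Python lists of chunk views, and `_dp_gather_via_all_gather` chains them so that a reduce-scatter followed by an all-gather reproduces the effect of an all-reduce over the padded buffer. A hedged pure-tensor sketch of that equivalence; ranks are simulated locally, no process group is involved.

```python
import torch

# Pure-tensor simulation of the collectives the new wrappers map onto.
world_size = 4
per_rank_inputs = [torch.randn(8, 16) for _ in range(world_size)]

# reduce_scatter_tensor: every rank contributes its full tensor; rank r keeps
# the r-th chunk of the element-wise sum.
summed = torch.stack(per_rank_inputs).sum(dim=0)
reduce_scatter_outputs = list(summed.tensor_split(world_size, dim=0))

# all_gather_into_tensor: each rank contributes its chunk; everyone ends up
# with the chunks concatenated along dim 0.
all_gather_output = torch.cat(reduce_scatter_outputs, dim=0)

# Chaining the two reproduces the all-reduce result.
assert torch.allclose(all_gather_output, summed)
print(all_gather_output.shape)  # torch.Size([8, 16])
```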
sglang/srt/layers/logits_processor.py
@@ -27,7 +27,9 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_gather,
 )
 from sglang.srt.layers.dp_attention import (
+    DPPaddingMode,
     attn_tp_all_gather,
+    attn_tp_all_gather_into_tensor,
     dp_gather_replicate,
     dp_scatter,
     get_attention_dp_rank,
@@ -111,7 +113,8 @@ class LogitsMetadata:
     # Number of tokens to sample per DP rank
     global_num_tokens_for_logprob_cpu: Optional[torch.Tensor] = None
     global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None
-
+    # The gather mode for DP attention
+    dp_padding_mode: Optional[DPPaddingMode] = None
     # for padding
     padded_static_len: int = -1
 
@@ -163,12 +166,12 @@ class LogitsMetadata:
             forward_batch_gathered_buffer=forward_batch.gathered_buffer,
             global_num_tokens_for_logprob_cpu=forward_batch.global_num_tokens_for_logprob_cpu,
             global_num_tokens_for_logprob_gpu=forward_batch.global_num_tokens_for_logprob_gpu,
+            dp_padding_mode=DPPaddingMode.SUM_LEN,
         )
 
-    def compute_dp_attention_metadata(self, hidden_states: torch.Tensor):
-        if self.global_num_tokens_for_logprob_cpu is None:
-            # we are capturing cuda graph
-            return
+    def compute_dp_attention_metadata(self):
+        # TODO(ch-wan): gathered_buffer here is larger than the actual required size in draft extend,
+        # we may use a smaller buffer in draft extend.
 
         cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0)
         dp_rank = get_attention_dp_rank()
@@ -179,18 +182,9 @@
         else:
             dp_local_start_pos = cumtokens[dp_rank - 1]
         dp_local_num_tokens = self.global_num_tokens_for_logprob_gpu[dp_rank]
-        gathered_buffer = torch.zeros(
-            (
-                sum(self.global_num_tokens_for_logprob_cpu),
-                hidden_states.shape[1],
-            ),
-            dtype=hidden_states.dtype,
-            device=hidden_states.device,
-        )
 
         self.dp_local_start_pos = dp_local_start_pos
         self.dp_local_num_tokens = dp_local_num_tokens
-        self.gathered_buffer = gathered_buffer
 
 
 class LogitsProcessor(nn.Module):
@@ -434,7 +428,7 @@ class LogitsProcessor(nn.Module):
         guarantee the given hidden_states follow this constraint.
         """
         if self.do_tensor_parallel_all_gather_dp_attn:
-            logits_metadata.compute_dp_attention_metadata(hidden_states)
+            logits_metadata.compute_dp_attention_metadata()
             hidden_states, local_hidden_states = (
                 torch.empty_like(logits_metadata.gathered_buffer),
                 hidden_states,
@@ -463,15 +457,31 @@
 
         if self.do_tensor_parallel_all_gather:
             if self.use_attn_tp_group:
-                global_logits = torch.empty(
-                    (self.config.vocab_size, logits.shape[0]),
-                    device=logits.device,
-                    dtype=logits.dtype,
-                )
-                global_logits = global_logits.T
-                attn_tp_all_gather(
-                    list(global_logits.tensor_split(self.attn_tp_size, dim=-1)), logits
-                )
+                if self.config.vocab_size % self.attn_tp_size == 0:
+                    global_logits = torch.empty(
+                        (
+                            self.attn_tp_size,
+                            logits.shape[0],
+                            self.config.vocab_size // self.attn_tp_size,
+                        ),
+                        device=logits.device,
+                        dtype=logits.dtype,
+                    )
+                    attn_tp_all_gather_into_tensor(global_logits, logits)
+                    global_logits = global_logits.permute(1, 0, 2).reshape(
+                        logits.shape[0], self.config.vocab_size
+                    )
+                else:
+                    global_logits = torch.empty(
+                        (self.config.vocab_size, logits.shape[0]),
+                        device=logits.device,
+                        dtype=logits.dtype,
+                    )
+                    global_logits = global_logits.T
+                    attn_tp_all_gather(
+                        list(global_logits.tensor_split(self.attn_tp_size, dim=-1)),
+                        logits,
+                    )
                 logits = global_logits
             else:
                 logits = tensor_model_parallel_all_gather(logits)
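The divisible-vocab branch gathers per-rank logit shards into a `(attn_tp_size, num_tokens, vocab_size // attn_tp_size)` buffer and then restores the full vocab ordering with `permute(1, 0, 2).reshape(...)`. A hedged pure-tensor check that this layout trick is equivalent to concatenating the shards along the vocab dimension (sizes are made up):

```python
import torch

attn_tp_size, num_tokens, vocab_size = 4, 3, 32
full = torch.randn(num_tokens, vocab_size)

# Per-rank shards, as produced by a vocab-parallel lm_head: rank r owns columns
# [r * vocab/tp, (r + 1) * vocab/tp).
shards = list(full.tensor_split(attn_tp_size, dim=-1))

# all_gather_into_tensor stacks the shards along a new leading rank dimension.
gathered = torch.stack(shards, dim=0)  # (attn_tp_size, num_tokens, vocab/tp)

# The permute + reshape used in logits_processor.py recovers the original layout.
restored = gathered.permute(1, 0, 2).reshape(num_tokens, vocab_size)
assert torch.equal(restored, full)
print("permute(1, 0, 2).reshape(...) reassembles the vocab-sharded logits")
```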