sglang 0.4.9.post3__py3-none-any.whl → 0.4.9.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (128)
  1. sglang/lang/chat_template.py +21 -0
  2. sglang/srt/_custom_ops.py +29 -1
  3. sglang/srt/configs/internvl.py +3 -0
  4. sglang/srt/configs/model_config.py +5 -1
  5. sglang/srt/constrained/base_grammar_backend.py +10 -2
  6. sglang/srt/constrained/xgrammar_backend.py +7 -5
  7. sglang/srt/conversation.py +17 -2
  8. sglang/srt/debug_utils/__init__.py +0 -0
  9. sglang/srt/debug_utils/dump_comparator.py +131 -0
  10. sglang/srt/debug_utils/dumper.py +108 -0
  11. sglang/srt/debug_utils/text_comparator.py +172 -0
  12. sglang/srt/disaggregation/common/conn.py +34 -6
  13. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +13 -1
  14. sglang/srt/disaggregation/mini_lb.py +3 -2
  15. sglang/srt/disaggregation/mooncake/conn.py +65 -20
  16. sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
  17. sglang/srt/disaggregation/nixl/conn.py +17 -13
  18. sglang/srt/disaggregation/prefill.py +13 -1
  19. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
  20. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
  21. sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
  22. sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
  23. sglang/srt/distributed/parallel_state.py +70 -15
  24. sglang/srt/entrypoints/engine.py +5 -9
  25. sglang/srt/entrypoints/http_server.py +20 -32
  26. sglang/srt/entrypoints/openai/protocol.py +3 -3
  27. sglang/srt/entrypoints/openai/serving_chat.py +148 -72
  28. sglang/srt/function_call/base_format_detector.py +74 -12
  29. sglang/srt/function_call/deepseekv3_detector.py +26 -11
  30. sglang/srt/function_call/ebnf_composer.py +105 -66
  31. sglang/srt/function_call/function_call_parser.py +6 -4
  32. sglang/srt/function_call/glm4_moe_detector.py +164 -0
  33. sglang/srt/function_call/kimik2_detector.py +41 -16
  34. sglang/srt/function_call/llama32_detector.py +6 -3
  35. sglang/srt/function_call/mistral_detector.py +11 -3
  36. sglang/srt/function_call/pythonic_detector.py +16 -14
  37. sglang/srt/function_call/qwen25_detector.py +12 -3
  38. sglang/srt/function_call/{qwen3_detector.py → qwen3_coder_detector.py} +11 -9
  39. sglang/srt/layers/activation.py +11 -3
  40. sglang/srt/layers/attention/base_attn_backend.py +3 -1
  41. sglang/srt/layers/attention/hybrid_attn_backend.py +100 -0
  42. sglang/srt/layers/attention/vision.py +56 -8
  43. sglang/srt/layers/communicator.py +12 -12
  44. sglang/srt/layers/dp_attention.py +72 -24
  45. sglang/srt/layers/layernorm.py +26 -1
  46. sglang/srt/layers/logits_processor.py +46 -25
  47. sglang/srt/layers/moe/ep_moe/layer.py +172 -206
  48. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  49. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +25 -224
  51. sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -48
  52. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +11 -8
  53. sglang/srt/layers/moe/topk.py +88 -34
  54. sglang/srt/layers/multimodal.py +11 -8
  55. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -9
  56. sglang/srt/layers/quantization/fp8.py +25 -247
  57. sglang/srt/layers/quantization/fp8_kernel.py +78 -48
  58. sglang/srt/layers/quantization/modelopt_quant.py +33 -14
  59. sglang/srt/layers/quantization/unquant.py +24 -76
  60. sglang/srt/layers/quantization/utils.py +0 -9
  61. sglang/srt/layers/quantization/w4afp8.py +68 -17
  62. sglang/srt/layers/radix_attention.py +5 -3
  63. sglang/srt/lora/lora_manager.py +133 -169
  64. sglang/srt/lora/lora_registry.py +188 -0
  65. sglang/srt/lora/mem_pool.py +2 -2
  66. sglang/srt/managers/cache_controller.py +62 -13
  67. sglang/srt/managers/io_struct.py +19 -1
  68. sglang/srt/managers/mm_utils.py +154 -35
  69. sglang/srt/managers/multimodal_processor.py +3 -14
  70. sglang/srt/managers/schedule_batch.py +27 -11
  71. sglang/srt/managers/scheduler.py +48 -26
  72. sglang/srt/managers/tokenizer_manager.py +62 -28
  73. sglang/srt/managers/tp_worker.py +5 -4
  74. sglang/srt/mem_cache/allocator.py +67 -7
  75. sglang/srt/mem_cache/hicache_storage.py +17 -1
  76. sglang/srt/mem_cache/hiradix_cache.py +35 -18
  77. sglang/srt/mem_cache/memory_pool_host.py +3 -0
  78. sglang/srt/model_executor/cuda_graph_runner.py +61 -25
  79. sglang/srt/model_executor/forward_batch_info.py +201 -29
  80. sglang/srt/model_executor/model_runner.py +109 -37
  81. sglang/srt/models/deepseek_v2.py +63 -30
  82. sglang/srt/models/glm4_moe.py +1035 -0
  83. sglang/srt/models/glm4_moe_nextn.py +167 -0
  84. sglang/srt/models/interns1.py +328 -0
  85. sglang/srt/models/internvl.py +143 -47
  86. sglang/srt/models/llava.py +9 -5
  87. sglang/srt/models/minicpmo.py +4 -1
  88. sglang/srt/models/mllama4.py +10 -3
  89. sglang/srt/models/qwen2_moe.py +2 -6
  90. sglang/srt/models/qwen3_moe.py +6 -8
  91. sglang/srt/multimodal/processors/base_processor.py +20 -6
  92. sglang/srt/multimodal/processors/clip.py +2 -2
  93. sglang/srt/multimodal/processors/deepseek_vl_v2.py +2 -2
  94. sglang/srt/multimodal/processors/gemma3.py +2 -2
  95. sglang/srt/multimodal/processors/gemma3n.py +2 -2
  96. sglang/srt/multimodal/processors/internvl.py +21 -8
  97. sglang/srt/multimodal/processors/janus_pro.py +2 -2
  98. sglang/srt/multimodal/processors/kimi_vl.py +2 -2
  99. sglang/srt/multimodal/processors/llava.py +4 -4
  100. sglang/srt/multimodal/processors/minicpm.py +2 -3
  101. sglang/srt/multimodal/processors/mlama.py +2 -2
  102. sglang/srt/multimodal/processors/mllama4.py +18 -111
  103. sglang/srt/multimodal/processors/phi4mm.py +2 -2
  104. sglang/srt/multimodal/processors/pixtral.py +2 -2
  105. sglang/srt/multimodal/processors/qwen_audio.py +2 -2
  106. sglang/srt/multimodal/processors/qwen_vl.py +2 -2
  107. sglang/srt/multimodal/processors/vila.py +3 -1
  108. sglang/srt/reasoning_parser.py +48 -5
  109. sglang/srt/sampling/sampling_batch_info.py +6 -5
  110. sglang/srt/server_args.py +132 -60
  111. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
  112. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +37 -36
  113. sglang/srt/speculative/eagle_utils.py +51 -23
  114. sglang/srt/speculative/eagle_worker.py +59 -44
  115. sglang/srt/two_batch_overlap.py +9 -5
  116. sglang/srt/utils.py +113 -69
  117. sglang/srt/weight_sync/utils.py +119 -0
  118. sglang/test/runners.py +4 -0
  119. sglang/test/test_activation.py +50 -1
  120. sglang/test/test_utils.py +65 -5
  121. sglang/utils.py +19 -0
  122. sglang/version.py +1 -1
  123. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/METADATA +6 -6
  124. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/RECORD +127 -114
  125. sglang/srt/debug_utils.py +0 -74
  126. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/WHEEL +0 -0
  127. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/licenses/LICENSE +0 -0
  128. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/top_level.txt +0 -0
sglang/srt/function_call/pythonic_detector.py

@@ -8,7 +8,6 @@ from sglang.srt.entrypoints.openai.protocol import Tool
 from sglang.srt.function_call.base_format_detector import BaseFormatDetector
 from sglang.srt.function_call.core_types import (
     StreamingParseResult,
-    StructureInfo,
     ToolCallItem,
     _GetInfoFunc,
 )
@@ -19,10 +18,17 @@ logger = logging.getLogger(__name__)

 class PythonicDetector(BaseFormatDetector):
     """
-    Detector for Llama-3.2 and Llama-4 models with pythonic tool call format.
-    Assumes function call format:
-    [tool1(arg1=val1, arg2=val2), tool2(arg1=val3)]
-    Arguments are Python literals (not JSON).
+    Detector for Llama-4 models with Pythonic tool call format.
+
+    The Pythonic format uses Python function call syntax within square brackets,
+    with arguments as Python literals rather than JSON.
+
+    Format Structure:
+    ```
+    [tool1(arg1=val1, arg2=val2), tool2(arg1=val3)]
+    ```
+
+    Reference: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct?chat_template=default
     """

     def __init__(self):
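As an aside, the bracketed format documented in this docstring is plain Python syntax, so it can be inspected with the standard-library `ast` module (which is also the general approach the detector takes). A minimal, hypothetical sketch, separate from the detector's actual streaming logic:

```python
import ast

# Illustrative input in the documented pythonic format; tool names are made up.
text = '[get_weather(city="Paris", unit="celsius"), search(query="news")]'

parsed = ast.parse(text, mode="eval").body  # an ast.List of ast.Call nodes
for call in parsed.elts:
    name = call.func.id
    # Arguments are Python literals, so literal_eval recovers their values.
    args = {kw.arg: ast.literal_eval(kw.value) for kw in call.keywords}
    print(name, args)
# get_weather {'city': 'Paris', 'unit': 'celsius'}
# search {'query': 'news'}
```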
@@ -75,11 +81,7 @@ class PythonicDetector(BaseFormatDetector):
             return StreamingParseResult(normal_text=normal_text, calls=[])

         calls = []
-        tool_indices = {
-            tool.function.name: i
-            for i, tool in enumerate(tools)
-            if tool.function.name
-        }
+        tool_indices = self._get_tool_indices(tools)
         for call_index, call in enumerate(parsed.elts):
             if not isinstance(call.func, ast.Name):
                 continue
@@ -213,11 +215,11 @@ class PythonicDetector(BaseFormatDetector):
         else:
             raise ValueError("Tool call arguments must be literals")

-    def structure_info(self) -> _GetInfoFunc:
-        def info(name: str):
-            return StructureInfo(begin=f"[{name}(", end=")]", trigger=f"[{name}(")
+    def supports_structural_tag(self) -> bool:
+        return False

-        return info
+    def structure_info(self) -> _GetInfoFunc:
+        raise NotImplementedError

     def build_ebnf(self, tools: List[Tool]) -> Optional[str]:
         return EBNFComposer.build_ebnf(
sglang/srt/function_call/qwen25_detector.py

@@ -17,9 +17,18 @@ logger = logging.getLogger(__name__)

 class Qwen25Detector(BaseFormatDetector):
     """
-    Detector for Qwen 2.5 models.
-    Assumes function call format:
-    <tool_call>\n{"name":"func1", "arguments":{...}}\n</tool_call>\n<tool_call>\n{"name":"func2", "arguments":{...}}\n</tool_call>
+    Detector for Qwen 2.5 and Qwen 3 model function call format.
+
+    Format Structure:
+    ```
+    <tool_call>\n{"name":"func1", "arguments":{...}}\n</tool_call>\n<tool_call>\n{"name":"func2", "arguments":{...}}\n</tool_call>
+    ```
+
+    Key Components:
+    - Tool Call Tags: `<tool_call>` and `</tool_call>` wrap each individual call
+    - Function Call Object: JSON object with "name" and "arguments" fields
+
+    Reference: https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct?chat_template=default
     """

     def __init__(self):
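For context only, the wrapped-JSON format documented above can be extracted from a complete (non-streaming) response with a regex plus `json.loads`; this sketch is not the detector's incremental parser and the tool names are made up:

```python
import json
import re

# Illustrative text in the documented <tool_call> format.
text = (
    '<tool_call>\n{"name": "get_weather", "arguments": {"city": "Paris"}}\n</tool_call>\n'
    '<tool_call>\n{"name": "search", "arguments": {"query": "news"}}\n</tool_call>'
)

calls = [
    json.loads(block)
    for block in re.findall(r"<tool_call>\n(.*?)\n</tool_call>", text, re.DOTALL)
]
print([c["name"] for c in calls])  # ['get_weather', 'search']
```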
sglang/srt/function_call/qwen3_coder_detector.py (renamed from qwen3_detector.py)

@@ -9,7 +9,6 @@ from sglang.srt.entrypoints.openai.protocol import Tool
 from sglang.srt.function_call.base_format_detector import BaseFormatDetector
 from sglang.srt.function_call.core_types import (
     StreamingParseResult,
-    StructureInfo,
     ToolCallItem,
     _GetInfoFunc,
 )
@@ -29,7 +28,7 @@ def _safe_val(raw: str) -> Any:
     return raw


-class Qwen3XMLDetector(BaseFormatDetector):
+class Qwen3CoderDetector(BaseFormatDetector):
     """
     Detector for Qwen 3 models.
     Assumes function call format:
@@ -127,24 +126,27 @@ class Qwen3XMLDetector(BaseFormatDetector):
                 params[pname] = _safe_val(pval)
             raw = {"name": fname, "arguments": params}
             try:
+                # TODO: fix idx in function call, the index for a function
+                # call will always be -1 in parse_base_json
                 res.extend(self.parse_base_json(raw, tools))
             except Exception:
                 logger.warning("invalid tool call for %s dropped", fname)
         return res

+    def supports_structural_tag(self) -> bool:
+        return False
+
     def structure_info(self) -> _GetInfoFunc:
-        return lambda n: StructureInfo(
-            begin=f"{self.tool_call_start_token}\n<function={n}>",
-            end=f"</function>\n{self.tool_call_end_token}",
-            trigger=self.tool_call_start_token,
-        )
+        raise NotImplementedError

-    # TODO: fake ebnf for xml + outlines backend
     def build_ebnf(self, tools: List[Tool]):
         return EBNFComposer.build_ebnf(
             tools,
             individual_call_start_token=self.tool_call_start_token.replace("\n", "\\n"),
             individual_call_end_token=self.tool_call_end_token.replace("\n", "\\n"),
             tool_call_separator="\\n",
-            function_format="json",
+            function_format="xml",
+            call_rule_fmt='"<function={name}>\\n" {arguments_rule} "\\n</function>"',
+            key_value_rule_fmt='"<parameter={key}>\\n" {valrule} "\\n</parameter>"',
+            key_value_separator="\\n",
         )
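The new `call_rule_fmt`/`key_value_rule_fmt` templates above constrain generation to an XML-style body rather than JSON. Assuming `<tool_call>`/`</tool_call>` as the start and end tokens (an assumption; the actual token strings are defined elsewhere in the detector), one constrained call would look roughly like this illustration:

```python
# Illustrative shape of a single tool call under the XML-style EBNF templates;
# the function and parameter names are hypothetical.
example_call = (
    "<tool_call>\n"
    "<function=get_weather>\n"
    "<parameter=city>\n"
    "Paris\n"
    "</parameter>\n"
    "</function>\n"
    "</tool_call>"
)
print(example_call)
```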
sglang/srt/layers/activation.py

@@ -33,6 +33,7 @@ from sglang.srt.utils import (
     cpu_has_amx_support,
     is_cpu,
     is_cuda,
+    is_hip,
     is_npu,
     set_weight_attrs,
 )
@@ -42,9 +43,12 @@ _is_cuda = is_cuda()
 _is_npu = is_npu()
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
+_is_hip = is_hip()

 if _is_cuda:
     from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
+elif _is_hip:
+    from sgl_kernel import gelu_and_mul, gelu_quick, gelu_tanh_and_mul, silu_and_mul

 if is_npu():
     import torch_npu
@@ -126,9 +130,13 @@ class QuickGELU(CustomOp):
         return x * torch.sigmoid(1.702 * x)

     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
-        # TODO(zhyncs): Implement the CUDA kernel for QuickGELU in sgl-kernel
         return self.forward_native(x)

+    def forward_hip(self, x: torch.Tensor) -> torch.Tensor:
+        out = torch.empty(x.shape, dtype=x.dtype, device=x.device)
+        gelu_quick(x, out)
+        return out
+

 class ScaledActivation(nn.Module):
     """An activation function with post-scale parameters.
@@ -222,8 +230,8 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
         return nn.Identity()


-if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
+if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip):
     logger.info(
-        "sgl-kernel is not available on Non-NV platforms or Non-AMX CPUs. Fallback to other kernel libraries."
+        "sgl-kernel is not available on Non-NV, Non-AMD platforms or Non-AMX CPUs. Fallback to other kernel libraries."
     )
     from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul
sglang/srt/layers/attention/base_attn_backend.py

@@ -65,7 +65,9 @@ class AttentionBackend(ABC):
         **kwargs,
     ):
         """Run forward on an attention layer."""
-        if forward_batch.forward_mode.is_decode():
+        if forward_batch.forward_mode.is_idle():
+            return q.new_empty(q.shape[0], layer.tp_q_head_num * layer.v_head_dim)
+        elif forward_batch.forward_mode.is_decode():
             return self.forward_decode(
                 q,
                 k,
sglang/srt/layers/attention/hybrid_attn_backend.py (new file)

@@ -0,0 +1,100 @@
+from typing import TYPE_CHECKING, Optional, Union
+
+import torch
+
+from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+
+
+class HybridAttnBackend(AttentionBackend):
+    """Support different backends for prefill and decode."""
+
+    def __init__(
+        self, prefill_backend: AttentionBackend, decode_backend: AttentionBackend
+    ):
+        self.prefill_backend = prefill_backend
+        self.decode_backend = decode_backend
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        if forward_batch.forward_mode.is_decode():
+            self.decode_backend.init_forward_metadata(forward_batch)
+        else:
+            self.prefill_backend.init_forward_metadata(forward_batch)
+
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
+        self.decode_backend.init_cuda_graph_state(max_bs, max_num_tokens)
+
+    def init_forward_metadata_capture_cuda_graph(
+        self,
+        bs: int,
+        num_tokens: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+    ):
+        self.decode_backend.init_forward_metadata_capture_cuda_graph(
+            bs,
+            num_tokens,
+            req_pool_indices,
+            seq_lens,
+            encoder_lens,
+            forward_mode,
+            spec_info,
+        )
+
+    def init_forward_metadata_replay_cuda_graph(
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        seq_lens_cpu: Optional[torch.Tensor],
+    ):
+        self.decode_backend.init_forward_metadata_replay_cuda_graph(
+            bs,
+            req_pool_indices,
+            seq_lens,
+            seq_lens_sum,
+            encoder_lens,
+            forward_mode,
+            spec_info,
+            seq_lens_cpu,
+        )
+
+    def get_cuda_graph_seq_len_fill_value(self):
+        return self.decode_backend.get_cuda_graph_seq_len_fill_value()
+
+    def forward_decode(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+        **kwargs,
+    ):
+        return self.decode_backend.forward_decode(
+            q, k, v, layer, forward_batch, save_kv_cache, **kwargs
+        )
+
+    def forward_extend(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+        **kwargs,
+    ):
+        return self.prefill_backend.forward_extend(
+            q, k, v, layer, forward_batch, save_kv_cache, **kwargs
+        )
sglang/srt/layers/attention/vision.py

@@ -3,7 +3,7 @@ from __future__ import annotations
 import dataclasses
 import functools
 import math
-from functools import lru_cache
+from functools import lru_cache, partial
 from typing import Any, Optional, Tuple, Union

 import torch
@@ -18,11 +18,16 @@ _is_cuda = is_cuda()
 if _is_cuda:
     from sgl_kernel.flash_attn import flash_attn_varlen_func

-from sglang.srt.distributed import parallel_state
+from sglang.srt.distributed import (
+    parallel_state,
+    split_tensor_along_last_dim,
+    tensor_model_parallel_all_gather,
+)
 from sglang.srt.distributed import utils as dist_utils
 from sglang.srt.layers.attention.triton_ops.prefill_attention import (
     context_attention_fwd,
 )
+from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
     QKVParallelLinear,
@@ -349,25 +354,44 @@ class VisionAttention(nn.Module):
         flatten_batch: bool = False,
         prefix: str = "",
         proj_bias: bool = True,
+        num_dummy_heads: int = 0,
+        qkv_bias: bool = True,
+        qk_normalization: bool = False,
+        layer_norm_eps: float = 1e-06,
         **kwargs,
     ):
         super().__init__()
         world_size = parallel_state.get_tensor_model_parallel_world_size()
+        self.tp_size = world_size
+        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
         self.dropout = dropout
         self.head_size = embed_dim // num_heads
         self.hidden_size_per_attention_head = dist_utils.divide(
             projection_size, num_heads
         )
         self.num_attention_heads_per_partition = dist_utils.divide(
-            num_heads, world_size
+            num_dummy_heads + num_heads, world_size
         )
         self.num_attention_kv_heads_per_partition = dist_utils.divide(
-            num_heads, world_size
+            num_dummy_heads + num_heads, world_size
         )

         self.q_size = self.num_attention_heads_per_partition * self.head_size
         self.kv_size = self.num_attention_kv_heads_per_partition * self.head_size

+        self.qk_normalization = qk_normalization
+
+        # Additional dummy heads are used to enable TP for common GPU counts.
+        self.dummy_dim = (num_dummy_heads + num_heads) * self.head_size
+
+        if self.qk_normalization:
+            self.q_norm = RMSNorm(
+                self.dummy_dim, eps=layer_norm_eps, var_hidden_size=embed_dim
+            )
+            self.k_norm = RMSNorm(
+                self.dummy_dim, eps=layer_norm_eps, var_hidden_size=embed_dim
+            )
+
         if global_server_args_dict["mm_attention_backend"] is None:
             if qkv_backend is None:
                 qkv_backend = "sdpa"
@@ -391,26 +415,46 @@ class VisionAttention(nn.Module):
             self.qkv_proj = QKVParallelLinear(
                 hidden_size=embed_dim,
                 head_size=self.head_size,
-                total_num_heads=num_heads,
-                total_num_kv_heads=num_heads,
+                total_num_heads=num_dummy_heads + num_heads,
+                total_num_kv_heads=num_dummy_heads + num_heads,
+                bias=qkv_bias,
                 quant_config=quant_config,
                 prefix=add_prefix("qkv_proj", prefix),
             )
         else:
             self.qkv_proj = ColumnParallelLinear(
                 input_size=embed_dim,
-                output_size=3 * projection_size,
+                output_size=3 * self.dummy_dim,
+                bias=qkv_bias,
                 quant_config=quant_config,
                 prefix=add_prefix("qkv_proj", prefix),
             )
         self.proj = RowParallelLinear(
-            input_size=embed_dim,
+            input_size=self.dummy_dim,
             output_size=embed_dim,
             bias=proj_bias,
             quant_config=quant_config,
             prefix=add_prefix("proj", prefix),
         )

+    def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor):
+        """apply qk norm for internvl vit attn"""
+        q = q.flatten(1, 2)
+        k = k.flatten(1, 2)
+
+        if self.tp_size > 1:
+            q = tensor_model_parallel_all_gather(q.contiguous())
+            k = tensor_model_parallel_all_gather(k.contiguous())
+        q = self.q_norm(q)
+        k = self.k_norm(k)
+        if self.tp_size > 1:
+            splitter = partial(split_tensor_along_last_dim, num_partitions=self.tp_size)
+            q = splitter(q)[self.tp_rank]
+            k = splitter(k)[self.tp_rank]
+        q = q.unflatten(-1, (-1, self.head_size))
+        k = k.unflatten(-1, (-1, self.head_size))
+        return q, k
+
     def forward(
         self,
         x: torch.Tensor,
@@ -489,6 +533,10 @@ class VisionAttention(nn.Module):
             assert k.dim() == 3, k.dim()
             assert v.dim() == 3, v.dim()

+            # internvl
+            if self.qk_normalization:
+                q, k = self._apply_qk_norm(q, k)
+
         output = self.qkv_backend.forward(
             q=q,
             k=k,
sglang/srt/layers/communicator.py

@@ -24,8 +24,8 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.layers.dp_attention import (
-    attn_tp_all_gather,
-    attn_tp_reduce_scatter,
+    attn_tp_all_gather_into_tensor,
+    attn_tp_reduce_scatter_tensor,
     dp_gather_partial,
     dp_scatter,
     get_attention_dp_size,
@@ -309,8 +309,8 @@ class CommunicateSimpleFn:
             forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
             hidden_states,
         )
-        attn_tp_all_gather(
-            list(hidden_states.tensor_split(context.attn_tp_size)),
+        attn_tp_all_gather_into_tensor(
+            hidden_states,
             local_hidden_states,
         )
         return hidden_states
@@ -400,9 +400,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
             ].clone(),
             residual,
         )
-        attn_tp_all_gather(
-            list(residual.tensor_split(context.attn_tp_size)), local_residual
-        )
+        attn_tp_all_gather_into_tensor(residual, local_residual)
         if context.attn_dp_size != 1:
             if context.attn_tp_rank == 0:
                 hidden_states += residual
@@ -442,9 +440,11 @@
         *,
         residual_input_mode,
     ):
-        tensor_list = list(hidden_states.tensor_split(context.attn_tp_size))
-        hidden_states = tensor_list[context.attn_tp_rank]
-        attn_tp_reduce_scatter(hidden_states, tensor_list)
+        input_hidden_states = hidden_states
+        hidden_states = hidden_states.tensor_split(context.attn_tp_size)[
+            context.attn_tp_rank
+        ]
+        attn_tp_reduce_scatter_tensor(hidden_states, input_hidden_states)
         if residual_input_mode == ScatterMode.TP_ATTN_FULL:
             residual = residual.tensor_split(context.attn_tp_size)[context.attn_tp_rank]
         if hidden_states.shape[0] != 0:
@@ -547,8 +547,8 @@ class CommunicateSummableTensorPairFn:
             forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
             hidden_states,
         )
-        attn_tp_all_gather(
-            list(hidden_states.tensor_split(context.attn_tp_size)),
+        attn_tp_all_gather_into_tensor(
+            hidden_states,
             local_hidden_states,
         )
         return hidden_states, residual
sglang/srt/layers/dp_attention.py

@@ -3,7 +3,8 @@ from __future__ import annotations
 import functools
 import logging
 from contextlib import contextmanager
-from typing import TYPE_CHECKING, List
+from enum import IntEnum, auto
+from typing import TYPE_CHECKING, List, Tuple

 import torch
 import triton
@@ -30,6 +31,34 @@ _LOCAL_ATTN_DP_SIZE = None
 _LOCAL_ATTN_DP_RANK = None


+class DPPaddingMode(IntEnum):
+
+    # Padding tokens to max length and then gather tokens using `all_gather_into_tensor`
+    MAX_LEN = auto()
+    # Padding tokens to sum length and then gather tokens using `all_reduce`
+    SUM_LEN = auto()
+
+    def is_max_len(self):
+        return self == DPPaddingMode.MAX_LEN
+
+    def is_sum_len(self):
+        return self == DPPaddingMode.SUM_LEN
+
+    @classmethod
+    def get_dp_padding_mode(cls, global_num_tokens: List[int]) -> DPPaddingMode:
+        # we choose the mode that minimizes the communication cost
+        max_len = max(global_num_tokens)
+        sum_len = sum(global_num_tokens)
+        if sum_len * 2 > max_len * get_attention_dp_size():
+            return cls.MAX_LEN
+        else:
+            return cls.SUM_LEN
+
+    @classmethod
+    def get_default_mode_in_cuda_graph(cls) -> DPPaddingMode:
+        return cls.MAX_LEN
+
+
 def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
     if not enable_dp_attention:
        return tp_rank, tp_size, 0
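A small worked example of the selection rule in `get_dp_padding_mode` above, restated as standalone arithmetic (the factor of two on the SUM_LEN side reflects the cost comparison as written in the code):

```python
# Standalone restatement of DPPaddingMode.get_dp_padding_mode's heuristic.
def choose_padding_mode(global_num_tokens, dp_size):
    max_len = max(global_num_tokens)
    sum_len = sum(global_num_tokens)
    return "MAX_LEN" if sum_len * 2 > max_len * dp_size else "SUM_LEN"

# Skewed batch: 2 * 112 = 224 <= 3 * 100 = 300, so sum-length padding wins.
print(choose_padding_mode([5, 7, 100], dp_size=3))    # SUM_LEN
# Balanced batch: 2 * 285 = 570 > 300, so max-length padding wins.
print(choose_padding_mode([90, 95, 100], dp_size=3))  # MAX_LEN
```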
@@ -162,7 +191,7 @@ def disable_dp_size():
        _ATTN_DP_SIZE = old_dp_size


-def get_dp_local_info(forward_batch: ForwardBatch):
+def get_dp_local_info(forward_batch: ForwardBatch) -> Tuple[torch.Tensor, torch.Tensor]:
     # `get_dp_local_info` is only called in global DP gather and scatter. We use global DP rank here.
     dp_rank = get_attention_dp_rank()
@@ -221,7 +250,7 @@ def memcpy_triton(dst, src, dim, offset, sz, offset_src):
     memcpy_triton_kernel[grid](dst, src, offset, sz, offset_src, chunk_size, BLOCK_SIZE)


-def _dp_gather(
+def _dp_gather_via_all_reduce(
     global_tokens: torch.Tensor,
     local_tokens: torch.Tensor,
     forward_batch: ForwardBatch,
@@ -238,13 +267,6 @@ def _dp_gather(
         local_tokens.untyped_storage() is not global_tokens.untyped_storage()
     ), "aliasing between global_tokens and local_tokens not allowed"

-    # NOTE: During draft extend, the gathered_buffer is padded to num_tokens * (speculative_num_steps + 1).
-    # But the size of local_tokens is total accepted tokens. We need to reduce the local_num_tokens to the
-    # actual size of the accepted tokens.
-    if forward_batch.forward_mode.is_draft_extend():
-        shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
-        local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
-
     memcpy_triton(
         global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
     )
@@ -263,6 +285,38 @@
     global_tokens[:] = tensor_model_parallel_all_reduce(global_tokens)


+def _dp_gather_via_all_gather(
+    global_tokens: torch.Tensor,
+    local_tokens: torch.Tensor,
+    forward_batch: ForwardBatch,
+    is_partial: bool,
+):
+    if not is_partial:
+        if get_attention_tp_rank() != 0:
+            local_tokens.fill_(0)
+    scattered_local_tokens = local_tokens.tensor_split(get_attention_tp_size())[
+        get_attention_tp_rank()
+    ]
+    get_attention_tp_group().reduce_scatter_tensor(scattered_local_tokens, local_tokens)
+    get_tp_group().all_gather_into_tensor(global_tokens, scattered_local_tokens)
+
+
+def _dp_gather(
+    global_tokens: torch.Tensor,
+    local_tokens: torch.Tensor,
+    forward_batch: ForwardBatch,
+    is_partial: bool,
+):
+    if forward_batch.dp_padding_mode.is_max_len():
+        _dp_gather_via_all_gather(
+            global_tokens, local_tokens, forward_batch, is_partial
+        )
+    else:
+        _dp_gather_via_all_reduce(
+            global_tokens, local_tokens, forward_batch, is_partial
+        )
+
+
 def dp_gather_partial(
     global_tokens: torch.Tensor,
     local_tokens: torch.Tensor,
@@ -296,24 +350,18 @@ def dp_scatter(
         local_tokens.untyped_storage() is not global_tokens.untyped_storage()
     ), "aliasing between local_tokens and global_tokens not allowed"

-    # NOTE: During draft extend, the gathered_buffer is padded to num_tokens * (speculative_num_steps + 1).
-    # But the size of local_tokens is total accepted tokens. We need to reduce the local_num_tokens to the
-    # actual size of the accepted tokens.
-    if forward_batch.forward_mode.is_draft_extend():
-        shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
-        local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
-
     memcpy_triton(
         local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
     )


-def attn_tp_reduce_scatter(
-    output: torch.Tensor,
-    input_list: List[torch.Tensor],
-):
-    return get_attention_tp_group().reduce_scatter(output, input_list)
+def attn_tp_reduce_scatter_tensor(output: torch.Tensor, input: torch.Tensor):
+    return get_attention_tp_group().reduce_scatter_tensor(output, input)
+
+
+def attn_tp_all_gather_into_tensor(output: torch.Tensor, input: torch.Tensor):
+    return get_attention_tp_group().all_gather_into_tensor(output, input)


-def attn_tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor):
-    return get_attention_tp_group().all_gather(input_, output_tensor_list=output_list)
+def attn_tp_all_gather(output_list: List[torch.Tensor], input: torch.Tensor):
+    return get_attention_tp_group().all_gather(input, output_tensor_list=output_list)
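The renames above track the move from list-based `all_gather`/`reduce_scatter` to the flat-tensor variants. For reference, the underlying torch.distributed calls look like this; a sketch that assumes an initialized process group (sglang routes these through its group coordinator rather than calling torch.distributed directly):

```python
import torch
import torch.distributed as dist

def tensor_style_collectives(local_chunk: torch.Tensor, world_size: int):
    # all_gather_into_tensor writes every rank's chunk into one flat output
    # tensor instead of a Python list of per-rank tensors.
    gathered = torch.empty(
        local_chunk.shape[0] * world_size, *local_chunk.shape[1:],
        dtype=local_chunk.dtype, device=local_chunk.device,
    )
    dist.all_gather_into_tensor(gathered, local_chunk)

    # reduce_scatter_tensor is the inverse: reduce a flat input across ranks
    # and scatter equal slices back, one slice per rank.
    scattered = torch.empty_like(local_chunk)
    dist.reduce_scatter_tensor(scattered, gathered)
    return gathered, scattered
```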
sglang/srt/layers/layernorm.py

@@ -61,10 +61,15 @@ class RMSNorm(CustomOp):
         self,
         hidden_size: int,
         eps: float = 1e-6,
+        var_hidden_size: Optional[int] = None,
     ) -> None:
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
+        self.hidden_size = hidden_size
+        self.variance_size_override = (
+            None if var_hidden_size == hidden_size else var_hidden_size
+        )
         if _use_aiter:
             self._forward_method = self.forward_aiter

@@ -73,6 +78,8 @@ class RMSNorm(CustomOp):
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if self.variance_size_override is not None:
+            return self.forward_native(x, residual)
         if residual is not None:
             fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
             return x, residual
@@ -138,7 +145,25 @@ class RMSNorm(CustomOp):
             x = x + residual.to(torch.float32)
             residual = x.to(orig_dtype)

-        variance = x.pow(2).mean(dim=-1, keepdim=True)
+        hidden_size = x.shape[-1]
+        if hidden_size != self.hidden_size:
+            raise ValueError(
+                "Expected hidden_size to be "
+                f"{self.hidden_size}, but found: {hidden_size}"
+            )
+
+        if self.variance_size_override is None:
+            x_var = x
+        else:
+            if hidden_size < self.variance_size_override:
+                raise ValueError(
+                    "Expected hidden_size to be at least "
+                    f"{self.variance_size_override}, but found: {hidden_size}"
+                )
+
+            x_var = x[..., : self.variance_size_override]
+
+        variance = x_var.pow(2).mean(dim=-1, keepdim=True)
         x = x * torch.rsqrt(variance + self.variance_epsilon)
         x = (x * self.weight).to(orig_dtype)
         if residual is None:
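The `var_hidden_size` override lets the variance be computed over only the leading channels (the real `embed_dim`) while the weight still spans the padded `dummy_dim` introduced for dummy attention heads in vision.py above. A self-contained sketch of that partial-variance behavior in plain PyTorch (not the sglang class itself; the sizes below are hypothetical):

```python
from typing import Optional

import torch

def rms_norm_partial(
    x: torch.Tensor,
    weight: torch.Tensor,
    var_hidden_size: Optional[int] = None,
    eps: float = 1e-6,
) -> torch.Tensor:
    # Variance over only the leading var_hidden_size channels (if given),
    # mirroring forward_native when variance_size_override is set.
    x_var = x if var_hidden_size is None else x[..., :var_hidden_size]
    variance = x_var.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight

embed_dim, dummy_dim = 8, 12  # hypothetical sizes
x = torch.randn(2, dummy_dim)
out = rms_norm_partial(x, torch.ones(dummy_dim), var_hidden_size=embed_dim)
print(out.shape)  # torch.Size([2, 12])
```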