sglang 0.4.8.post1__py3-none-any.whl → 0.4.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (141)
  1. sglang/bench_one_batch_server.py +17 -2
  2. sglang/bench_serving.py +168 -22
  3. sglang/srt/configs/internvl.py +4 -2
  4. sglang/srt/configs/janus_pro.py +1 -1
  5. sglang/srt/configs/model_config.py +48 -0
  6. sglang/srt/configs/update_config.py +119 -0
  7. sglang/srt/conversation.py +34 -0
  8. sglang/srt/disaggregation/decode.py +21 -5
  9. sglang/srt/disaggregation/nixl/conn.py +6 -6
  10. sglang/srt/disaggregation/prefill.py +2 -2
  11. sglang/srt/disaggregation/utils.py +1 -1
  12. sglang/srt/distributed/parallel_state.py +44 -17
  13. sglang/srt/entrypoints/EngineBase.py +8 -0
  14. sglang/srt/entrypoints/engine.py +40 -6
  15. sglang/srt/entrypoints/http_server.py +111 -24
  16. sglang/srt/entrypoints/openai/protocol.py +4 -2
  17. sglang/srt/eplb/__init__.py +0 -0
  18. sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
  19. sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
  20. sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
  21. sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
  22. sglang/srt/{managers → eplb}/expert_location.py +1 -1
  23. sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
  24. sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
  25. sglang/srt/hf_transformers_utils.py +2 -1
  26. sglang/srt/layers/activation.py +2 -2
  27. sglang/srt/layers/amx_utils.py +86 -0
  28. sglang/srt/layers/attention/ascend_backend.py +219 -0
  29. sglang/srt/layers/attention/flashattention_backend.py +32 -9
  30. sglang/srt/layers/attention/tbo_backend.py +37 -9
  31. sglang/srt/layers/communicator.py +18 -2
  32. sglang/srt/layers/dp_attention.py +9 -3
  33. sglang/srt/layers/elementwise.py +76 -12
  34. sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
  35. sglang/srt/layers/layernorm.py +26 -0
  36. sglang/srt/layers/linear.py +84 -14
  37. sglang/srt/layers/logits_processor.py +4 -4
  38. sglang/srt/layers/moe/ep_moe/kernels.py +23 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +36 -13
  40. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
  41. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -2
  42. sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -16
  43. sglang/srt/layers/moe/router.py +60 -22
  44. sglang/srt/layers/moe/topk.py +10 -28
  45. sglang/srt/layers/parameter.py +67 -7
  46. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
  47. sglang/srt/layers/quantization/fp8.py +44 -0
  48. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  49. sglang/srt/layers/quantization/fp8_utils.py +1 -2
  50. sglang/srt/layers/quantization/gptq.py +5 -1
  51. sglang/srt/layers/quantization/moe_wna16.py +1 -1
  52. sglang/srt/layers/quantization/quant_utils.py +166 -0
  53. sglang/srt/layers/quantization/w8a8_int8.py +52 -1
  54. sglang/srt/layers/rotary_embedding.py +2 -2
  55. sglang/srt/layers/vocab_parallel_embedding.py +11 -7
  56. sglang/srt/lora/lora.py +4 -5
  57. sglang/srt/lora/lora_manager.py +73 -20
  58. sglang/srt/managers/configure_logging.py +1 -1
  59. sglang/srt/managers/io_struct.py +50 -13
  60. sglang/srt/managers/mm_utils.py +73 -59
  61. sglang/srt/managers/multimodal_processor.py +2 -6
  62. sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
  63. sglang/srt/managers/schedule_batch.py +77 -84
  64. sglang/srt/managers/scheduler.py +113 -59
  65. sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
  66. sglang/srt/managers/session_controller.py +12 -3
  67. sglang/srt/managers/tokenizer_manager.py +314 -103
  68. sglang/srt/managers/tp_worker.py +13 -1
  69. sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
  70. sglang/srt/mem_cache/allocator.py +290 -0
  71. sglang/srt/mem_cache/chunk_cache.py +34 -2
  72. sglang/srt/mem_cache/memory_pool.py +289 -3
  73. sglang/srt/mem_cache/multimodal_cache.py +3 -0
  74. sglang/srt/model_executor/cuda_graph_runner.py +2 -1
  75. sglang/srt/model_executor/forward_batch_info.py +17 -4
  76. sglang/srt/model_executor/model_runner.py +297 -56
  77. sglang/srt/model_loader/loader.py +41 -0
  78. sglang/srt/model_loader/weight_utils.py +72 -4
  79. sglang/srt/models/deepseek_nextn.py +1 -3
  80. sglang/srt/models/deepseek_v2.py +181 -45
  81. sglang/srt/models/deepseek_vl2.py +3 -5
  82. sglang/srt/models/gemma3_causal.py +1 -2
  83. sglang/srt/models/gemma3n_causal.py +4 -3
  84. sglang/srt/models/gemma3n_mm.py +4 -20
  85. sglang/srt/models/hunyuan.py +1 -1
  86. sglang/srt/models/kimi_vl.py +1 -2
  87. sglang/srt/models/llama.py +10 -4
  88. sglang/srt/models/llama4.py +32 -45
  89. sglang/srt/models/llama_eagle3.py +61 -11
  90. sglang/srt/models/llava.py +5 -5
  91. sglang/srt/models/minicpmo.py +2 -2
  92. sglang/srt/models/mistral.py +1 -1
  93. sglang/srt/models/mllama4.py +43 -11
  94. sglang/srt/models/phi4mm.py +1 -3
  95. sglang/srt/models/pixtral.py +3 -7
  96. sglang/srt/models/qwen2.py +31 -3
  97. sglang/srt/models/qwen2_5_vl.py +1 -3
  98. sglang/srt/models/qwen2_audio.py +200 -0
  99. sglang/srt/models/qwen2_moe.py +32 -6
  100. sglang/srt/models/qwen2_vl.py +1 -4
  101. sglang/srt/models/qwen3.py +94 -25
  102. sglang/srt/models/qwen3_moe.py +68 -21
  103. sglang/srt/models/vila.py +3 -8
  104. sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
  105. sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
  106. sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
  107. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
  108. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
  109. sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
  110. sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
  111. sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
  112. sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
  113. sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
  114. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
  115. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +3 -6
  116. sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
  117. sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
  118. sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
  119. sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
  120. sglang/srt/operations_strategy.py +6 -2
  121. sglang/srt/reasoning_parser.py +26 -0
  122. sglang/srt/sampling/sampling_batch_info.py +39 -1
  123. sglang/srt/server_args.py +69 -22
  124. sglang/srt/speculative/build_eagle_tree.py +57 -18
  125. sglang/srt/speculative/eagle_worker.py +6 -4
  126. sglang/srt/two_batch_overlap.py +200 -27
  127. sglang/srt/utils.py +306 -146
  128. sglang/srt/warmup.py +12 -3
  129. sglang/test/runners.py +10 -1
  130. sglang/test/test_utils.py +15 -3
  131. sglang/version.py +1 -1
  132. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/METADATA +9 -6
  133. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/RECORD +140 -133
  134. sglang/math_utils.py +0 -8
  135. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
  136. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
  137. /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
  138. /sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +0 -0
  139. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/WHEEL +0 -0
  140. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/licenses/LICENSE +0 -0
  141. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/top_level.txt +0 -0

sglang/srt/{model_executor → eplb}/expert_location_updater.py
@@ -20,7 +20,7 @@ import torch
 import torch.distributed
 from torch.distributed import P2POp
 
-from sglang.srt.managers.expert_location import (
+from sglang.srt.eplb.expert_location import (
     ExpertLocationMetadata,
     get_global_expert_location_metadata,
 )
@@ -30,6 +30,9 @@ from sglang.srt.utils import get_bool_env_var
 logger = logging.getLogger(__name__)
 
 
+_LOG_INPUT = get_bool_env_var("SGLANG_EXPERT_LOCATION_UPDATER_LOG_INPUT")
+
+
 class ExpertLocationUpdater:
     def __init__(self):
         self._first_execution = True
@@ -175,6 +178,19 @@ def update_expert_weights_single_layer(
     assert isinstance(old_physical_to_logical_map, list)
     assert isinstance(new_physical_to_logical_map, list)
 
+    if _LOG_INPUT:
+        logger.info(
+            "update_expert_weights_single_layer "
+            f"{[x.shape for x in routed_experts_weights]=} "
+            f"{[x.shape for x in temp_buffers]=} "
+            f"{old_physical_to_logical_map=} "
+            f"{new_physical_to_logical_map=} "
+            f"{num_local_physical_experts=} "
+            f"{num_gpu_per_node=} "
+            f"{rank=} "
+            f"{world_size=} "
+        )
+
     output_logs = [] if debug else None
 
     num_physical_experts = len(old_physical_to_logical_map)
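
The hunk above gates verbose input logging behind the SGLANG_EXPERT_LOCATION_UPDATER_LOG_INPUT environment variable. A minimal stand-alone sketch of the same pattern (get_bool_env_var is re-implemented here for illustration; sglang's own helper lives in sglang.srt.utils and may parse values differently):

    # Env-var-gated debug logging, mirroring the pattern in the diff above.
    import logging
    import os

    logger = logging.getLogger(__name__)

    def get_bool_env_var(name: str, default: str = "false") -> bool:
        # Stand-in for sglang.srt.utils.get_bool_env_var (assumption).
        return os.getenv(name, default).strip().lower() in ("1", "true", "yes")

    _LOG_INPUT = get_bool_env_var("SGLANG_EXPERT_LOCATION_UPDATER_LOG_INPUT")

    def update_expert_weights(old_map, new_map, rank=0, world_size=1):
        if _LOG_INPUT:
            # The f-string "=" specifier logs both the expression and its value.
            logger.info(f"update_expert_weights {old_map=} {new_map=} {rank=} {world_size=}")

    logging.basicConfig(level=logging.INFO)
    update_expert_weights([0, 1, 2], [2, 0, 1])  # logs only if the env var is truthy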

sglang/srt/hf_transformers_utils.py
@@ -42,7 +42,7 @@ from sglang.srt.configs import (
 )
 from sglang.srt.configs.internvl import InternVLChatConfig
 from sglang.srt.connector import create_remote_connector
-from sglang.srt.utils import is_remote_url
+from sglang.srt.utils import is_remote_url, lru_cache_frozenset
 
 _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     ChatGLMConfig.model_type: ChatGLMConfig,
@@ -103,6 +103,7 @@ def get_hf_text_config(config: PretrainedConfig):
     return config
 
 
+@lru_cache_frozenset(maxsize=32)
 def get_config(
     model: str,
     trust_remote_code: bool,
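
get_config is now memoized via lru_cache_frozenset, so repeated config lookups for the same model are served from cache. A hypothetical sketch of such a decorator, assuming it freezes unhashable arguments (dicts/lists) into hashable cache keys; the real implementation in sglang.srt.utils may differ:

    import collections

    def lru_cache_frozenset(maxsize=32):
        """Hypothetical cache decorator that also accepts dict/list arguments
        by freezing them into hashable keys before the LRU lookup."""

        def _freeze(value):
            if isinstance(value, dict):
                return frozenset((k, _freeze(v)) for k, v in value.items())
            if isinstance(value, (list, tuple, set)):
                return tuple(_freeze(v) for v in value)
            return value

        def decorator(fn):
            cache = collections.OrderedDict()

            def wrapper(*args, **kwargs):
                key = (_freeze(args), _freeze(sorted(kwargs.items())))
                if key in cache:
                    cache.move_to_end(key)  # mark as most recently used
                    return cache[key]
                result = fn(*args, **kwargs)
                cache[key] = result
                if len(cache) > maxsize:
                    cache.popitem(last=False)  # evict least recently used
                return result

            return wrapper

        return decorator

    @lru_cache_frozenset(maxsize=32)
    def get_config(model, trust_remote_code, revision=None, model_override_args=None):
        # The real function performs an expensive AutoConfig.from_pretrained() call.
        return {"model": model, "revision": revision}

    get_config("meta-llama/Llama-3-8B", True, model_override_args={"rope_scaling": None})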

sglang/srt/layers/activation.py
@@ -46,11 +46,11 @@ _is_cpu = is_cpu()
 if _is_cuda:
     from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
 
-logger = logging.getLogger(__name__)
-
 if is_npu():
     import torch_npu
 
+logger = logging.getLogger(__name__)
+
 
 class SiluAndMul(CustomOp):
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:

sglang/srt/layers/amx_utils.py
@@ -0,0 +1,86 @@
+import logging
+
+import torch
+
+from sglang.srt.utils import cpu_has_amx_support
+
+logger = logging.getLogger(__name__)
+
+
+def amx_process_weight_after_loading(weight):
+    if weight.device != torch.device("cpu"):
+        return weight
+    if not cpu_has_amx_support():
+        return weight
+
+    return torch.ops.sgl_kernel.convert_weight_packed(weight)
+
+
+# TODO: currently gemm kernel has the below requirements:
+# OC % TILE_N == 0, where TILE_N = 16
+# IC % TILE_K == 0, where TILE_K = 32
+def dim_is_supported(weight):
+    TILE_N = 16
+    TILE_K = 32
+    ndim = weight.ndim
+    OC = weight.size(1) if ndim == 3 else weight.size(0)
+    IC = weight.size(2) if ndim == 3 else weight.size(1)
+    return OC % TILE_N == 0 and IC % TILE_K == 0
+
+
+def _amx_process_weight_after_loading(
+    module, weight_names, transpose_dims=None
+) -> None:
+    # Pack weight for get better performance on CPU
+    devices = {getattr(module, weight_name).device for weight_name in weight_names}
+    assert len(devices) == 1, f"Expects all weights to be on the same device"
+    device = devices.pop()
+
+    if transpose_dims:
+        assert len(weight_names) == len(
+            transpose_dims
+        ), "len(weight_names) should be equal to len(transpose_dims)"
+
+    for i, weight_name in enumerate(weight_names):
+        weight_tensor = getattr(module, weight_name)
+
+        if transpose_dims and transpose_dims[i]:
+            weight_tensor = weight_tensor.transpose(*transpose_dims[i])
+
+        # We don't pack weight or use intel amx backend if any weight of this module has unsupported dim.
+        if not dim_is_supported(weight_tensor):
+            logger.warning(
+                f"Unsupported dimension for prepacking for weight '{weight_name}' with shape {weight_tensor.shape} in {module}. "
+                f"The derived (OC, IC) dimensions must be divisible by (16, 32). "
+            )
+            module.use_intel_amx_backend = False
+            return
+
+        packed_weight = torch.nn.Parameter(
+            amx_process_weight_after_loading(weight_tensor),
+            requires_grad=False,
+        )
+        packed_weight.__dict__ = weight_tensor.__dict__
+        setattr(module, weight_name, packed_weight)
+
+    module.use_intel_amx_backend = (
+        device == torch.device("cpu") and cpu_has_amx_support()
+    )
+
+    if (
+        module.use_intel_amx_backend
+        and hasattr(module, "bias")
+        and module.bias is not None
+    ):
+        module.bias = torch.nn.Parameter(module.bias.data.float(), requires_grad=False)
+
+
+class PackWeightMethod:
+    def __init__(self, weight_names, transpose_dims=None):
+        self.weight_names = weight_names
+        self.transpose_dims = transpose_dims
+
+    def process_weights_after_loading(self, module) -> None:
+        _amx_process_weight_after_loading(
+            module, self.weight_names, self.transpose_dims
+        )
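
To make the (OC, IC) divisibility requirement concrete, here is a small stand-alone check that restates dim_is_supported with plain tensors (it never calls the sgl_kernel packing op; the shapes are made up):

    import torch

    TILE_N, TILE_K = 16, 32

    def dim_is_supported(weight: torch.Tensor) -> bool:
        # 3-D weights are read as (num_experts, OC, IC), 2-D as (OC, IC).
        OC = weight.size(1) if weight.ndim == 3 else weight.size(0)
        IC = weight.size(2) if weight.ndim == 3 else weight.size(1)
        return OC % TILE_N == 0 and IC % TILE_K == 0

    print(dim_is_supported(torch.empty(4096, 11008)))     # True
    print(dim_is_supported(torch.empty(8, 4096, 11008)))  # True (per-expert layout)
    print(dim_is_supported(torch.empty(4096, 11000)))     # False: 11000 % 32 != 0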

sglang/srt/layers/attention/ascend_backend.py
@@ -0,0 +1,219 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Optional
+
+import torch
+import torch_npu
+from torch.nn.functional import scaled_dot_product_attention
+
+from sglang.srt.configs.model_config import AttentionArch
+from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
+from sglang.srt.layers.radix_attention import AttentionType
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.radix_attention import RadixAttention
+    from sglang.srt.model_executor.model_runner import ModelRunner
+
+
+@dataclass
+class ForwardMetadata:
+
+    # calculated map for kv positions [bs * maxseqlen]
+    block_tables: Optional[torch.Tensor] = None
+
+    # seq len inputs
+    extend_seq_lens_cpu_int: Optional[torch.Tensor] = None
+    seq_lens_cpu_int: Optional[torch.Tensor] = None
+
+
+class AscendAttnBackend(AttentionBackend):
+
+    def gen_attention_mask(self, max_seq_len: int, dtype=torch.float16):
+        mask_flag = torch.tril(
+            torch.ones((max_seq_len, max_seq_len), dtype=torch.bool)
+        ).view(max_seq_len, max_seq_len)
+        mask_flag = ~mask_flag
+        if dtype == torch.float16:
+            mask_value = torch.finfo(torch.float32).min
+        else:
+            mask_value = 1
+        self.mask = (
+            torch.masked_fill(
+                torch.zeros(size=(max_seq_len, max_seq_len)), mask_flag, mask_value
+            )
+            .to(dtype)
+            .to(self.device)
+        )
+        self.mask_len = max_seq_len
+
+    def __init__(self, model_runner: ModelRunner):
+        super().__init__()
+        self.forward_metadata = ForwardMetadata()
+        self.device = model_runner.device
+        self.gen_attention_mask(128, model_runner.dtype)
+        self.page_size = model_runner.page_size
+        self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
+        if self.use_mla:
+            self.kv_lora_rank = model_runner.model_config.kv_lora_rank
+            self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
+        self.native_attn = TorchNativeAttnBackend(model_runner)
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        """Init the metadata for a forward pass."""
+        self.forward_metadata.block_tables = (
+            forward_batch.req_to_token_pool.req_to_token[
+                forward_batch.req_pool_indices, : forward_batch.seq_lens.max()
+            ][:, :: self.page_size]
+            // self.page_size
+        )
+        if forward_batch.extend_seq_lens is not None:
+            self.forward_metadata.extend_seq_lens_cpu_int = (
+                forward_batch.extend_seq_lens.cpu().int()
+            )
+        self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
+
+    def forward_extend(
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
+    ):
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v
+            )
+
+        k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+        v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
+
+        if not self.use_mla:
+            query = q.view(-1, layer.tp_q_head_num * layer.qk_head_dim)
+            output = torch.empty(
+                (query.shape[0], layer.tp_q_head_num * layer.v_head_dim),
+                dtype=query.dtype,
+                device=query.device,
+            )
+
+            torch_npu._npu_flash_attention_qlens(
+                query=query,
+                key_cache=k_cache,
+                value_cache=v_cache,
+                mask=self.mask,
+                block_table=self.forward_metadata.block_tables,
+                seq_len=self.forward_metadata.extend_seq_lens_cpu_int,
+                context_lens=self.forward_metadata.seq_lens_cpu_int,
+                scale_value=layer.scaling,
+                num_heads=layer.tp_q_head_num,
+                num_kv_heads=layer.tp_k_head_num,
+                out=output,
+            )
+            return output
+        else:
+            if layer.qk_head_dim != layer.v_head_dim:
+                o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
+            else:
+                o = torch.empty_like(q)
+
+            use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
+
+            q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
+            o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+
+            causal = True
+            if (
+                layer.is_cross_attention
+                or layer.attn_type == AttentionType.ENCODER_ONLY
+            ):
+                causal = False
+
+            self.native_attn._run_sdpa_forward_extend(
+                q_,
+                o_,
+                k_cache.view(
+                    -1, layer.tp_k_head_num, (self.kv_lora_rank + self.qk_rope_head_dim)
+                ),
+                v_cache.view(-1, layer.tp_v_head_num, self.kv_lora_rank),
+                forward_batch.req_to_token_pool.req_to_token,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                forward_batch.extend_prefix_lens,
+                forward_batch.extend_seq_lens,
+                scaling=layer.scaling,
+                enable_gqa=use_gqa,
+                causal=causal,
+            )
+            return o
+
+    def forward_decode(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
+    ):
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v
+            )
+        if not self.use_mla:
+            k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
+
+            query = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
+            num_tokens = query.shape[0]
+            output = torch.empty(
+                (num_tokens, layer.tp_q_head_num, layer.v_head_dim),
+                dtype=query.dtype,
+                device=query.device,
+            )
+
+            torch_npu._npu_paged_attention(
+                query=query,
+                key_cache=k_cache,
+                value_cache=v_cache,
+                num_heads=layer.tp_q_head_num,
+                num_kv_heads=layer.tp_k_head_num,
+                scale_value=layer.scaling,
+                block_table=self.forward_metadata.block_tables,
+                context_lens=self.forward_metadata.seq_lens_cpu_int,
+                out=output,
+            )
+            return output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
+        else:
+            query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+            num_tokens = query.shape[0]
+            kv_c_and_k_pe_cache = forward_batch.token_to_kv_pool.get_key_buffer(
+                layer.layer_id
+            )
+            kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(
+                -1,
+                self.page_size,
+                layer.tp_k_head_num,
+                self.kv_lora_rank + self.qk_rope_head_dim,
+            )
+
+            attn_output = torch.empty(
+                [num_tokens, layer.tp_q_head_num, self.kv_lora_rank],
+                dtype=q.dtype,
+                device=q.device,
+            )
+            torch_npu._npu_paged_attention_mla(
+                query=query,
+                key_cache=kv_c_and_k_pe_cache,
+                num_kv_heads=layer.tp_k_head_num,
+                num_heads=layer.tp_q_head_num,
+                scale_value=layer.scaling,
+                block_table=self.forward_metadata.block_tables,
+                context_lens=self.forward_metadata.seq_lens_cpu_int,
+                mla_vheadsize=self.kv_lora_rank,
+                out=attn_output,
+            )
+            return attn_output.view(num_tokens, layer.tp_q_head_num * self.kv_lora_rank)
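
The block-table construction in init_forward_metadata above takes every page_size-th KV slot of req_to_token and divides by page_size to obtain per-request page indices. A toy illustration with made-up values (the req_pool_indices gather is omitted):

    # Derive page indices from a per-request token-slot map, as in the backend above.
    import torch

    page_size = 4
    req_to_token = torch.tensor(
        [
            [8, 9, 10, 11, 20, 21, 22, 23],  # request 0 uses KV pages 2 and 5
            [4, 5, 6, 7, 0, 0, 0, 0],        # request 1 uses KV page 1 (rest is padding)
        ]
    )
    seq_lens = torch.tensor([8, 4])

    block_tables = req_to_token[:, : seq_lens.max()][:, ::page_size] // page_size
    print(block_tables)  # tensor([[2, 5], [1, 0]]) -- the trailing 0 is padding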

sglang/srt/layers/attention/flashattention_backend.py
@@ -9,6 +9,7 @@ import torch
 from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.mem_cache.memory_pool import SWAKVPool
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
 
@@ -320,6 +321,11 @@ class FlashAttentionBackend(AttentionBackend):
         self.page_size = model_runner.page_size
         self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
         self.skip_prefill = skip_prefill
+        self.is_hybrid = model_runner.is_hybrid
+        if self.is_hybrid:
+            self.full_to_swa_index_mapping = (
+                model_runner.token_to_kv_pool.full_to_swa_index_mapping
+            )
         self.topk = model_runner.server_args.speculative_eagle_topk or 0
         self.speculative_num_steps = speculative_num_steps
         self.speculative_num_draft_tokens = (
@@ -428,7 +434,7 @@ class FlashAttentionBackend(AttentionBackend):
                 forward_batch.req_pool_indices, : metadata.max_seq_len_k
             ]
             # TODO: we need to test this part for llama 4 eagle case
-            self._init_local_attn_metadata(metadata, device)
+            self._init_local_attn_metadata(forward_batch, metadata, device)
         elif forward_batch.forward_mode.is_target_verify():
             if self.topk <= 1:
                 metadata.cache_seqlens_int32 = (
@@ -456,7 +462,7 @@ class FlashAttentionBackend(AttentionBackend):
                     forward_batch.req_pool_indices, : metadata.max_seq_len_k
                 ]
 
-                self._init_local_attn_metadata(metadata, device)
+                self._init_local_attn_metadata(forward_batch, metadata, device)
             else:
                 metadata.cache_seqlens_int32 = forward_batch.seq_lens.to(torch.int32)
                 metadata.max_seq_len_q = self.speculative_num_draft_tokens
@@ -575,7 +581,7 @@ class FlashAttentionBackend(AttentionBackend):
 
         # Setup local attention if enabled
         if forward_batch.forward_mode == ForwardMode.EXTEND:
-            self._init_local_attn_metadata(metadata, device)
+            self._init_local_attn_metadata(forward_batch, metadata, device)
 
         # Encoder metadata for cross attention
         if forward_batch.encoder_lens is not None:
@@ -1588,7 +1594,7 @@ class FlashAttentionBackend(AttentionBackend):
         forward_mode: ForwardMode,
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
         seq_lens_cpu: Optional[torch.Tensor],
-        out_cache_loc: torch.Tensor = None,
+        out_cache_loc: Optional[torch.Tensor] = None,
     ):
         """Initialize forward metadata for replaying CUDA graph."""
         seq_lens = seq_lens[:bs]
@@ -1673,7 +1679,10 @@ class FlashAttentionBackend(AttentionBackend):
                 self.page_size,
             )
 
-            self._update_local_attn_metadata_for_replay(metadata, bs)
+            self._update_local_attn_metadata_for_replay(
+                metadata,
+                bs,
+            )
         elif forward_mode.is_target_verify():
             if self.topk <= 1:
                 metadata = self.target_verify_metadata[bs]
@@ -1829,7 +1838,9 @@ class FlashAttentionBackend(AttentionBackend):
         """Get the fill value for sequence length in CUDA graph."""
         return 1
 
-    def _init_local_attn_metadata(self, metadata: FlashAttentionMetadata, device):
+    def _init_local_attn_metadata(
+        self, forwardbatch: ForwardBatch, metadata: FlashAttentionMetadata, device
+    ):
         """Centralized utility to initialize local_attn_metadata if chunked attention is enabled."""
         if self.attention_chunk_size is None:
             metadata.local_attn_metadata = None
@@ -1837,7 +1848,12 @@ class FlashAttentionBackend(AttentionBackend):
 
         cu_seqlens_q = metadata.cu_seqlens_q
         cache_seqlens_int32 = metadata.cache_seqlens_int32
-        page_table = metadata.page_table
+        if self.is_hybrid:
+            page_table = self.full_to_swa_index_mapping[metadata.page_table].to(
+                torch.int32
+            )
+        else:
+            page_table = metadata.page_table
         if cu_seqlens_q is None or cache_seqlens_int32 is None or page_table is None:
             metadata.local_attn_metadata = None
             return
@@ -1923,7 +1939,9 @@ class FlashAttentionBackend(AttentionBackend):
         )
 
     def _update_local_attn_metadata_for_replay(
-        self, metadata: FlashAttentionMetadata, bs: int
+        self,
+        metadata: FlashAttentionMetadata,
+        bs: int,
     ):
         """Update preallocated local attention metadata in-place before CUDA graph replay."""
         if self.attention_chunk_size is None:
@@ -1954,7 +1972,12 @@ class FlashAttentionBackend(AttentionBackend):
         # Without this slicing, the pre-allocated page_table may contain zeros or invalid indices
        # beyond the actual sequence length, leading to incorrect attention calculations
         max_seq_len = int(seqlens.max().item())
-        sliced_page_table = metadata.page_table[:bs, :max_seq_len]
+        if self.is_hybrid:
+            sliced_page_table = self.full_to_swa_index_mapping[
+                metadata.page_table[:bs, :max_seq_len]
+            ].to(torch.int32)
+        else:
+            sliced_page_table = metadata.page_table[:bs, :max_seq_len]
 
         cu_seqlens_q_np = cu_seqlens_q.cpu().numpy()
         seqlens_np = seqlens.cpu().numpy()
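
For hybrid models (full-attention plus sliding-window layers), the local-attention page table is now translated into sliding-window pool indices through full_to_swa_index_mapping. A toy version of that lookup with made-up values:

    # Remap a full-attention page table into SWA pool indices via a lookup tensor.
    import torch

    full_to_swa_index_mapping = torch.tensor([0, 3, 1, 4, 2, 5])
    page_table = torch.tensor([[1, 2, 5], [4, 0, 0]])

    swa_page_table = full_to_swa_index_mapping[page_table].to(torch.int32)
    print(swa_page_table)  # tensor([[3, 1, 5], [2, 0, 0]], dtype=torch.int32)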

sglang/srt/layers/attention/tbo_backend.py
@@ -119,21 +119,27 @@ class TboAttnBackend(AttentionBackend):
         replay_seq_lens_sum: int = None,
         replay_seq_lens_cpu: Optional[torch.Tensor] = None,
     ):
+        token_num_per_seq = two_batch_overlap.get_token_num_per_seq(
+            forward_mode=forward_mode, spec_info=spec_info
+        )
         if fn_name == "init_forward_metadata_capture_cuda_graph":
-            assert capture_num_tokens == bs, "Only support num_tokens==bs currently"
-            num_tokens = bs
+            assert (
+                capture_num_tokens == bs * token_num_per_seq
+            ), "For target-verify or decode mode, num_tokens should be equal to token_num_per_seq * bs"
+            num_tokens = bs * token_num_per_seq
 
         tbo_split_seq_index, tbo_split_token_index = (
             two_batch_overlap.compute_split_indices_for_cuda_graph_replay(
                 forward_mode=forward_mode,
                 cuda_graph_num_tokens=num_tokens,
+                spec_info=spec_info,
             )
         )
 
         num_tokens_child_left = tbo_split_token_index
         num_tokens_child_right = num_tokens - tbo_split_token_index
-        bs_child_left = num_tokens_child_left
-        bs_child_right = num_tokens_child_right
+        bs_child_left = tbo_split_seq_index
+        bs_child_right = bs - bs_child_left
 
         assert (
             num_tokens_child_left > 0 and num_tokens_child_right > 0
@@ -190,16 +196,36 @@ def _init_forward_metadata_cuda_graph_split(
     seq_lens: torch.Tensor,
     encoder_lens: Optional[torch.Tensor],
     forward_mode: "ForwardMode",
-    spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+    spec_info: Optional[EagleVerifyInput],
     # capture args
     capture_num_tokens: int = None,
     # replay args
     replay_seq_lens_sum: int = None,
     replay_seq_lens_cpu: Optional[torch.Tensor] = None,
 ):
+    token_num_per_seq = two_batch_overlap.get_token_num_per_seq(
+        forward_mode=forward_mode, spec_info=spec_info
+    )
     assert encoder_lens is None, "encoder_lens is not supported yet"
-    assert spec_info is None, "spec_info is not supported yet"
+    if spec_info is not None:
+        output_spec_info = two_batch_overlap.split_spec_info(
+            spec_info=spec_info,
+            start_seq_index=seq_slice.start if seq_slice.start is not None else 0,
+            end_seq_index=seq_slice.stop if seq_slice.stop is not None else bs,
+            start_token_index=(
+                seq_slice.start * token_num_per_seq
+                if seq_slice.start is not None
+                else 0
+            ),
+            end_token_index=(
+                seq_slice.stop * token_num_per_seq
+                if seq_slice.stop is not None
+                else bs * token_num_per_seq
+            ),
+        )
 
+    else:
+        output_spec_info = None
     ans = dict(
         bs=output_bs,
         req_pool_indices=req_pool_indices[seq_slice],
@@ -208,14 +234,16 @@ def _init_forward_metadata_cuda_graph_split(
         forward_mode=forward_mode,
         # ignore
         encoder_lens=None,
-        spec_info=None,
+        spec_info=output_spec_info,
     )
 
     if fn_name == "init_forward_metadata_capture_cuda_graph":
-        assert capture_num_tokens == bs, "Only support num_tokens==bs currently"
+        assert (
+            capture_num_tokens == bs * token_num_per_seq
+        ), "Only support num_tokens==bs * token_num_per_seq for target-verify or decode mode"
         ans.update(
             dict(
-                num_tokens=output_bs,
+                num_tokens=output_bs * token_num_per_seq,
             )
        )
    elif fn_name == "init_forward_metadata_replay_cuda_graph":
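
The two-batch-overlap split now counts tokens per sequence instead of assuming one token per sequence, so target-verify CUDA graphs (token_num_per_seq = number of draft tokens) split correctly. A sketch of the arithmetic, assuming a simple midpoint split for illustration (sglang computes the real split via compute_split_indices_for_cuda_graph_replay):

    # Split a captured batch into two children by sequence and by token count.
    bs = 6
    token_num_per_seq = 4                        # e.g. speculative_num_draft_tokens
    num_tokens = bs * token_num_per_seq          # graph captured for 24 tokens

    tbo_split_seq_index = bs // 2                # assumed midpoint split
    tbo_split_token_index = tbo_split_seq_index * token_num_per_seq

    bs_child_left = tbo_split_seq_index
    bs_child_right = bs - bs_child_left
    num_tokens_child_left = tbo_split_token_index
    num_tokens_child_right = num_tokens - tbo_split_token_index

    print(bs_child_left, bs_child_right)                  # 3 3
    print(num_tokens_child_left, num_tokens_child_right)  # 12 12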

sglang/srt/layers/communicator.py
@@ -32,8 +32,13 @@ from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
 )
+from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.utils import is_cuda, is_flashinfer_available
+
+_is_flashinfer_available = is_flashinfer_available()
+_is_sm100_supported = is_cuda() and is_sm100_supported()
 
 
 class ScatterMode(Enum):
@@ -397,8 +402,19 @@ class CommunicateWithAllReduceAndLayerNormFn:
             if hidden_states.shape[0] != 0:
                 hidden_states = layernorm(hidden_states)
         else:
-            hidden_states = tensor_model_parallel_all_reduce(hidden_states)
-            hidden_states, residual = layernorm(hidden_states, residual)
+            if (
+                _is_sm100_supported
+                and _is_flashinfer_available
+                and hasattr(layernorm, "forward_with_allreduce_fusion")
+                and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
+                and hidden_states.shape[0] <= 1024
+            ):
+                hidden_states, residual = layernorm.forward_with_allreduce_fusion(
+                    hidden_states, residual
+                )
+            else:
+                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
+                hidden_states, residual = layernorm(hidden_states, residual)
         return hidden_states, residual
 
     @staticmethod
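
The fused allreduce+layernorm path above is taken only when every condition holds. The same gate written as a plain predicate with hypothetical inputs (the real checks come from sglang.srt.layers.utils, sglang.srt.utils, and the global server args dict):

    # Decision logic for the flashinfer allreduce fusion, restated standalone.
    def use_flashinfer_allreduce_fusion(
        sm100_supported: bool,
        flashinfer_available: bool,
        layernorm_has_fused_forward: bool,
        fusion_enabled_in_server_args: bool,
        num_tokens: int,
    ) -> bool:
        return (
            sm100_supported
            and flashinfer_available
            and layernorm_has_fused_forward
            and fusion_enabled_in_server_args
            and num_tokens <= 1024
        )

    print(use_flashinfer_allreduce_fusion(True, True, True, True, 512))   # True
    print(use_flashinfer_allreduce_fusion(True, True, True, True, 4096))  # False: too many tokens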

sglang/srt/layers/dp_attention.py
@@ -79,14 +79,12 @@ def initialize_dp_attention(
     )
 
     if enable_dp_attention:
-        local_rank = tp_rank % (tp_size // dp_size)
         _ATTN_DP_SIZE = dp_size
         if moe_dense_tp_size is None:
             _LOCAL_ATTN_DP_SIZE = _ATTN_DP_SIZE
         else:
             _LOCAL_ATTN_DP_SIZE = max(1, dp_size // (tp_size // moe_dense_tp_size))
     else:
-        local_rank = tp_rank
         _ATTN_DP_SIZE = 1
         _LOCAL_ATTN_DP_SIZE = 1
 
@@ -96,7 +94,7 @@ def initialize_dp_attention(
             list(range(head, head + _ATTN_TP_SIZE))
             for head in range(0, pp_size * tp_size, _ATTN_TP_SIZE)
         ],
-        local_rank,
+        tp_group.local_rank,
         torch.distributed.get_backend(tp_group.device_group),
         use_pynccl=SYNC_TOKEN_IDS_ACROSS_TP,
         use_pymscclpp=False,
@@ -239,6 +237,10 @@ def _dp_gather(
     assert (
         local_tokens.untyped_storage() is not global_tokens.untyped_storage()
     ), "aliasing between global_tokens and local_tokens not allowed"
+
+    # NOTE: During draft extend, the gathered_buffer is padded to num_tokens * (speculative_num_steps + 1).
+    # But the size of local_tokens is total accepted tokens. We need to reduce the local_num_tokens to the
+    # actual size of the accepted tokens.
    if forward_batch.forward_mode.is_draft_extend():
        shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
        local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
@@ -293,6 +295,10 @@ def dp_scatter(
     assert (
         local_tokens.untyped_storage() is not global_tokens.untyped_storage()
     ), "aliasing between local_tokens and global_tokens not allowed"
+
+    # NOTE: During draft extend, the gathered_buffer is padded to num_tokens * (speculative_num_steps + 1).
+    # But the size of local_tokens is total accepted tokens. We need to reduce the local_num_tokens to the
+    # actual size of the accepted tokens.
    if forward_batch.forward_mode.is_draft_extend():
        shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
        local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
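
The draft-extend clamp above shrinks local_num_tokens from the padded gather size down to the number of accepted tokens actually present in local_tokens. A toy demonstration with made-up sizes:

    # Clamp the advertised token count to the tokens that are really present.
    import torch

    local_tokens = torch.randn(5, 8)     # 5 accepted tokens, hidden size 8 (made up)
    local_num_tokens = torch.tensor(12)  # padded token count assumed by the buffer

    shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
    local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
    print(local_num_tokens)  # tensor(5)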