sglang 0.4.8.post1__py3-none-any.whl → 0.4.9.post1__py3-none-any.whl

This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (158)
  1. sglang/bench_one_batch_server.py +17 -2
  2. sglang/bench_serving.py +170 -24
  3. sglang/srt/configs/internvl.py +4 -2
  4. sglang/srt/configs/janus_pro.py +1 -1
  5. sglang/srt/configs/model_config.py +60 -1
  6. sglang/srt/configs/update_config.py +119 -0
  7. sglang/srt/conversation.py +69 -1
  8. sglang/srt/disaggregation/decode.py +21 -5
  9. sglang/srt/disaggregation/mooncake/conn.py +35 -4
  10. sglang/srt/disaggregation/nixl/conn.py +6 -6
  11. sglang/srt/disaggregation/prefill.py +2 -2
  12. sglang/srt/disaggregation/utils.py +1 -1
  13. sglang/srt/distributed/parallel_state.py +44 -17
  14. sglang/srt/entrypoints/EngineBase.py +8 -0
  15. sglang/srt/entrypoints/engine.py +40 -6
  16. sglang/srt/entrypoints/http_server.py +111 -24
  17. sglang/srt/entrypoints/http_server_engine.py +1 -1
  18. sglang/srt/entrypoints/openai/protocol.py +4 -2
  19. sglang/srt/eplb/__init__.py +0 -0
  20. sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
  21. sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
  22. sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
  23. sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
  24. sglang/srt/{managers → eplb}/expert_location.py +1 -1
  25. sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
  26. sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
  27. sglang/srt/hf_transformers_utils.py +2 -1
  28. sglang/srt/layers/activation.py +2 -2
  29. sglang/srt/layers/amx_utils.py +86 -0
  30. sglang/srt/layers/attention/ascend_backend.py +219 -0
  31. sglang/srt/layers/attention/flashattention_backend.py +32 -9
  32. sglang/srt/layers/attention/tbo_backend.py +37 -9
  33. sglang/srt/layers/communicator.py +20 -2
  34. sglang/srt/layers/dp_attention.py +9 -3
  35. sglang/srt/layers/elementwise.py +76 -12
  36. sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
  37. sglang/srt/layers/layernorm.py +26 -0
  38. sglang/srt/layers/linear.py +84 -14
  39. sglang/srt/layers/logits_processor.py +4 -4
  40. sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
  41. sglang/srt/layers/moe/ep_moe/kernels.py +81 -8
  42. sglang/srt/layers/moe/ep_moe/layer.py +176 -15
  43. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
  44. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -2
  45. sglang/srt/layers/moe/fused_moe_triton/layer.py +211 -74
  46. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
  47. sglang/srt/layers/moe/router.py +60 -22
  48. sglang/srt/layers/moe/topk.py +10 -28
  49. sglang/srt/layers/parameter.py +67 -7
  50. sglang/srt/layers/quantization/__init__.py +2 -0
  51. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
  52. sglang/srt/layers/quantization/fp8.py +72 -7
  53. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  54. sglang/srt/layers/quantization/fp8_utils.py +1 -2
  55. sglang/srt/layers/quantization/gptq.py +5 -1
  56. sglang/srt/layers/quantization/modelopt_quant.py +244 -1
  57. sglang/srt/layers/quantization/moe_wna16.py +1 -1
  58. sglang/srt/layers/quantization/quant_utils.py +166 -0
  59. sglang/srt/layers/quantization/w4afp8.py +264 -0
  60. sglang/srt/layers/quantization/w8a8_int8.py +52 -1
  61. sglang/srt/layers/rotary_embedding.py +2 -2
  62. sglang/srt/layers/vocab_parallel_embedding.py +20 -10
  63. sglang/srt/lora/lora.py +4 -5
  64. sglang/srt/lora/lora_manager.py +73 -20
  65. sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
  66. sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
  67. sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
  68. sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
  69. sglang/srt/managers/cache_controller.py +41 -195
  70. sglang/srt/managers/configure_logging.py +1 -1
  71. sglang/srt/managers/io_struct.py +58 -14
  72. sglang/srt/managers/mm_utils.py +77 -61
  73. sglang/srt/managers/multimodal_processor.py +2 -6
  74. sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
  75. sglang/srt/managers/schedule_batch.py +78 -85
  76. sglang/srt/managers/scheduler.py +130 -64
  77. sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
  78. sglang/srt/managers/session_controller.py +12 -3
  79. sglang/srt/managers/tokenizer_manager.py +314 -103
  80. sglang/srt/managers/tp_worker.py +13 -1
  81. sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
  82. sglang/srt/mem_cache/allocator.py +290 -0
  83. sglang/srt/mem_cache/chunk_cache.py +34 -2
  84. sglang/srt/mem_cache/hiradix_cache.py +2 -0
  85. sglang/srt/mem_cache/memory_pool.py +402 -66
  86. sglang/srt/mem_cache/memory_pool_host.py +6 -109
  87. sglang/srt/mem_cache/multimodal_cache.py +3 -0
  88. sglang/srt/mem_cache/radix_cache.py +8 -4
  89. sglang/srt/model_executor/cuda_graph_runner.py +2 -1
  90. sglang/srt/model_executor/forward_batch_info.py +17 -4
  91. sglang/srt/model_executor/model_runner.py +297 -56
  92. sglang/srt/model_loader/loader.py +41 -0
  93. sglang/srt/model_loader/weight_utils.py +72 -4
  94. sglang/srt/models/deepseek_nextn.py +1 -3
  95. sglang/srt/models/deepseek_v2.py +195 -45
  96. sglang/srt/models/deepseek_vl2.py +3 -5
  97. sglang/srt/models/gemma3_causal.py +1 -2
  98. sglang/srt/models/gemma3n_causal.py +4 -3
  99. sglang/srt/models/gemma3n_mm.py +4 -20
  100. sglang/srt/models/hunyuan.py +1 -1
  101. sglang/srt/models/kimi_vl.py +1 -2
  102. sglang/srt/models/llama.py +10 -4
  103. sglang/srt/models/llama4.py +32 -45
  104. sglang/srt/models/llama_eagle3.py +61 -11
  105. sglang/srt/models/llava.py +5 -5
  106. sglang/srt/models/minicpmo.py +2 -2
  107. sglang/srt/models/mistral.py +1 -1
  108. sglang/srt/models/mllama4.py +402 -89
  109. sglang/srt/models/phi4mm.py +1 -3
  110. sglang/srt/models/pixtral.py +3 -7
  111. sglang/srt/models/qwen2.py +31 -3
  112. sglang/srt/models/qwen2_5_vl.py +1 -3
  113. sglang/srt/models/qwen2_audio.py +200 -0
  114. sglang/srt/models/qwen2_moe.py +32 -6
  115. sglang/srt/models/qwen2_vl.py +1 -4
  116. sglang/srt/models/qwen3.py +94 -25
  117. sglang/srt/models/qwen3_moe.py +68 -21
  118. sglang/srt/models/vila.py +3 -8
  119. sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +2 -2
  120. sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
  121. sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
  122. sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
  123. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
  124. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
  125. sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
  126. sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
  127. sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
  128. sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
  129. sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
  130. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
  131. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +65 -66
  132. sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
  133. sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
  134. sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
  135. sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
  136. sglang/srt/operations_strategy.py +6 -2
  137. sglang/srt/reasoning_parser.py +26 -0
  138. sglang/srt/sampling/sampling_batch_info.py +39 -1
  139. sglang/srt/server_args.py +84 -22
  140. sglang/srt/speculative/build_eagle_tree.py +57 -18
  141. sglang/srt/speculative/eagle_worker.py +6 -4
  142. sglang/srt/two_batch_overlap.py +203 -27
  143. sglang/srt/utils.py +343 -163
  144. sglang/srt/warmup.py +12 -3
  145. sglang/test/runners.py +10 -1
  146. sglang/test/test_cutlass_w4a8_moe.py +281 -0
  147. sglang/test/test_utils.py +15 -3
  148. sglang/utils.py +5 -5
  149. sglang/version.py +1 -1
  150. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +12 -8
  151. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +157 -146
  152. sglang/math_utils.py +0 -8
  153. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
  154. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
  155. /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
  156. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
  157. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
  158. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0
sglang/srt/models/qwen3.py CHANGED
@@ -2,7 +2,7 @@
 
 import logging
 from functools import partial
-from typing import Any, Dict, Iterable, Optional, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Tuple
 
 import torch
 from torch import nn
@@ -11,9 +11,9 @@ from sglang.srt.distributed import (
     get_pp_group,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
-    split_tensor_along_last_dim,
-    tensor_model_parallel_all_gather,
 )
+from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
+from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import QKVParallelLinear, RowParallelLinear
 from sglang.srt.layers.logits_processor import LogitsProcessor
@@ -23,15 +23,17 @@ from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
 from sglang.srt.models.qwen2 import Qwen2Model
-from sglang.srt.utils import add_prefix
+from sglang.srt.utils import add_prefix, is_cuda
 
 Qwen3Config = None
 
 logger = logging.getLogger(__name__)
+_is_cuda = is_cuda()
 
 
 class Qwen3Attention(nn.Module):
@@ -49,23 +51,27 @@ class Qwen3Attention(nn.Module):
         rms_norm_eps: float = None,
         attention_bias: bool = False,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
         self.tp_size = get_tensor_model_parallel_world_size()
         self.total_num_heads = num_heads
-        assert self.total_num_heads % self.tp_size == 0
-        self.num_heads = self.total_num_heads // self.tp_size
+        attn_tp_rank = get_attention_tp_rank()
+        attn_tp_size = get_attention_tp_size()
+
+        assert self.total_num_heads % attn_tp_size == 0
+        self.num_heads = self.total_num_heads // attn_tp_size
         self.total_num_kv_heads = num_kv_heads
-        if self.total_num_kv_heads >= self.tp_size:
+        if self.total_num_kv_heads >= attn_tp_size:
             # Number of KV heads is greater than TP size, so we partition
             # the KV heads across multiple tensor parallel GPUs.
-            assert self.total_num_kv_heads % self.tp_size == 0
+            assert self.total_num_kv_heads % attn_tp_size == 0
         else:
             # Number of KV heads is less than TP size, so we replicate
             # the KV heads across multiple tensor parallel GPUs.
-            assert self.tp_size % self.total_num_kv_heads == 0
-        self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
+            assert attn_tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // attn_tp_size)
         self.head_dim = head_dim or hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
@@ -84,6 +90,8 @@ class Qwen3Attention(nn.Module):
             self.total_num_kv_heads,
             bias=attention_bias,
             quant_config=quant_config,
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
             prefix=add_prefix("qkv_proj", prefix),
         )
         self.o_proj = RowParallelLinear(
@@ -91,6 +99,9 @@ class Qwen3Attention(nn.Module):
             hidden_size,
             bias=attention_bias,
             quant_config=quant_config,
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
+            reduce_results=False,
             prefix=add_prefix("o_proj", prefix),
         )
 
@@ -109,15 +120,27 @@ class Qwen3Attention(nn.Module):
             layer_id=layer_id,
             prefix=add_prefix("attn", prefix),
         )
+        self.alt_stream = alt_stream
 
     def _apply_qk_norm(
         self, q: torch.Tensor, k: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        q_by_head = q.reshape(-1, self.head_dim)
-        q_by_head = self.q_norm(q_by_head)
+        # overlap qk norm
+        if self.alt_stream is not None and get_is_capture_mode():
+            current_stream = torch.cuda.current_stream()
+            self.alt_stream.wait_stream(current_stream)
+            q_by_head = q.reshape(-1, self.head_dim)
+            q_by_head = self.q_norm(q_by_head)
+            with torch.cuda.stream(self.alt_stream):
+                k_by_head = k.reshape(-1, self.head_dim)
+                k_by_head = self.k_norm(k_by_head)
+            current_stream.wait_stream(self.alt_stream)
+        else:
+            q_by_head = q.reshape(-1, self.head_dim)
+            q_by_head = self.q_norm(q_by_head)
+            k_by_head = k.reshape(-1, self.head_dim)
+            k_by_head = self.k_norm(k_by_head)
         q = q_by_head.view(q.shape)
-        k_by_head = k.reshape(-1, self.head_dim)
-        k_by_head = self.k_norm(k_by_head)
         k = k_by_head.view(k.shape)
         return q, k
 
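Note: both `_apply_qk_norm` rewrites in this release (here and in qwen3_moe.py below) gate the overlap on `get_is_capture_mode()`, so the second stream is only exercised while a CUDA graph is being captured (and later replayed). Reduced to a standalone sketch, the synchronization skeleton looks like this (`q_norm`/`k_norm` stand in for the RMSNorm modules; shapes are illustrative):

    import torch

    def overlapped_qk_norm(q, k, q_norm, k_norm, head_dim, alt_stream):
        # The side stream must not start before the main stream has produced q and k.
        current_stream = torch.cuda.current_stream()
        alt_stream.wait_stream(current_stream)
        q_by_head = q_norm(q.reshape(-1, head_dim))      # main stream
        with torch.cuda.stream(alt_stream):
            k_by_head = k_norm(k.reshape(-1, head_dim))  # side stream, concurrent
        # The main stream must not consume k before the side stream has finished.
        current_stream.wait_stream(alt_stream)
        return q_by_head.view(q.shape), k_by_head.view(k.shape)

The two `wait_stream` calls are what make the overlap safe: dropping the second one would let downstream kernels on the main stream read `k_by_head` before `k_norm` completes.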
@@ -143,6 +166,7 @@ class Qwen3DecoderLayer(nn.Module):
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -163,6 +187,7 @@ class Qwen3DecoderLayer(nn.Module):
             rms_norm_eps=config.rms_norm_eps,
             attention_bias=config.attention_bias,
             prefix=add_prefix("self_attn", prefix),
+            alt_stream=alt_stream,
         )
         self.mlp = Qwen3MLP(
             hidden_size=self.hidden_size,
@@ -176,6 +201,18 @@ class Qwen3DecoderLayer(nn.Module):
             config.hidden_size, eps=config.rms_norm_eps
         )
 
+        self.layer_scatter_modes = LayerScatterModes.init_new(
+            layer_id=layer_id,
+            num_layers=config.num_hidden_layers,
+            is_layer_sparse=False,
+            is_previous_layer_sparse=False,
+        )
+        self.layer_communicator = LayerCommunicator(
+            layer_scatter_modes=self.layer_scatter_modes,
+            input_layernorm=self.input_layernorm,
+            post_attention_layernorm=self.post_attention_layernorm,
+        )
+
     def forward(
         self,
         positions: torch.Tensor,
@@ -184,20 +221,24 @@ class Qwen3DecoderLayer(nn.Module):
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
-        if residual is None:
-            residual = hidden_states
-            hidden_states = self.input_layernorm(hidden_states)
-        else:
-            hidden_states, residual = self.input_layernorm(hidden_states, residual)
-        hidden_states = self.self_attn(
-            positions=positions,
-            hidden_states=hidden_states,
-            forward_batch=forward_batch,
+        hidden_states, residual = self.layer_communicator.prepare_attn(
+            hidden_states, residual, forward_batch
         )
+        if hidden_states.shape[0] != 0:
+            hidden_states = self.self_attn(
+                positions=positions,
+                hidden_states=hidden_states,
+                forward_batch=forward_batch,
+            )
 
         # Fully Connected
-        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+        hidden_states, residual = self.layer_communicator.prepare_mlp(
+            hidden_states, residual, forward_batch
+        )
         hidden_states = self.mlp(hidden_states)
+        hidden_states, residual = self.layer_communicator.postprocess_layer(
+            hidden_states, residual, forward_batch
+        )
         return hidden_states, residual
 
 
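Note: `prepare_attn`/`prepare_mlp` absorb the residual-plus-RMSNorm handling that the layer previously did inline. In the plain tensor-parallel case, `prepare_attn` reduces to roughly the deleted logic; this sketch is inferred from the removed lines, not taken from the actual LayerCommunicator code:

    def prepare_attn_plain_tp(hidden_states, residual, input_layernorm):
        # What the removed inline code did. The real prepare_attn additionally
        # scatters/gathers tokens when data-parallel attention is enabled.
        if residual is None:
            residual = hidden_states
            hidden_states = input_layernorm(hidden_states)
        else:
            hidden_states, residual = input_layernorm(hidden_states, residual)
        return hidden_states, residual

The new `hidden_states.shape[0] != 0` guard presumably exists because, with DP attention, a rank's local sub-batch can be empty after the token scatter, so attention must be skipped on zero-token ranks.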
@@ -208,11 +249,13 @@ class Qwen3Model(Qwen2Model):
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:
+        alt_stream = torch.cuda.Stream() if _is_cuda else None
         super().__init__(
             config=config,
             quant_config=quant_config,
             prefix=prefix,
             decoder_layer_type=Qwen3DecoderLayer,
+            alt_stream=alt_stream,
         )
 
 
@@ -282,6 +325,9 @@ class Qwen3ForCausalLM(nn.Module):
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
 
+        # For EAGLE3 support
+        self.capture_aux_hidden_states = False
+
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
 
@@ -303,10 +349,18 @@ class Qwen3ForCausalLM(nn.Module):
             pp_proxy_tensors=pp_proxy_tensors,
         )
 
+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
+
         if self.pp_group.is_last_rank:
             if not get_embedding:
                 return self.logits_processor(
-                    input_ids, hidden_states, self.lm_head, forward_batch
+                    input_ids,
+                    hidden_states,
+                    self.lm_head,
+                    forward_batch,
+                    aux_hidden_states,
                 )
             else:
                 return self.pooler(hidden_states, forward_batch)
@@ -404,5 +458,20 @@ class Qwen3ForCausalLM(nn.Module):
     def load_kv_cache_scales(self, quantization_param_path: str) -> None:
         self.model.load_kv_cache_scales(quantization_param_path)
 
+    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
+        if not self.pp_group.is_last_rank:
+            return
+
+        self.capture_aux_hidden_states = True
+        if layer_ids is None:
+            num_layers = self.config.num_hidden_layers
+            self.model.layers_to_capture = [
+                2,
+                num_layers // 2,
+                num_layers - 3,
+            ]  # Specific layers for EAGLE3 support
+        else:
+            self.model.layers_to_capture = [val + 1 for val in layer_ids]
+
 
 EntryClass = Qwen3ForCausalLM
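Note: the default EAGLE3 capture points mix one early, one middle, and one late layer. Illustrative arithmetic for a hypothetical 36-layer model:

    num_layers = 36
    default_capture = [2, num_layers // 2, num_layers - 3]
    print(default_capture)  # [2, 18, 33]

The `val + 1` shift in the explicit-list branch appears to convert caller-supplied layer indices into the model's internal capture numbering.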
sglang/srt/models/qwen3_moe.py CHANGED
@@ -18,7 +18,7 @@
 """Inference-only Qwen3MoE model compatible with HuggingFace weights."""
 
 import logging
-from typing import Any, Dict, Iterable, Optional, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Tuple
 
 import torch
 from torch import nn
@@ -32,6 +32,9 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_gather,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
+from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
+from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
 from sglang.srt.layers.dp_attention import (
@@ -63,12 +66,8 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.expert_distribution import (
-    get_global_expert_distribution_recorder,
-)
-from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
-from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,
@@ -78,11 +77,12 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
 from sglang.srt.models.qwen2_moe import Qwen2MoeModel
 from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
-from sglang.srt.utils import DeepEPMode, add_prefix, is_non_idle_and_non_empty
+from sglang.srt.utils import DeepEPMode, add_prefix, is_cuda, is_non_idle_and_non_empty
 
 Qwen3MoeConfig = None
 
 logger = logging.getLogger(__name__)
+_is_cuda = is_cuda()
 
 
 class Qwen3MoeSparseMoeBlock(nn.Module):
@@ -117,6 +117,15 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
                 if global_server_args_dict["enable_deepep_moe"]
                 else {}
             ),
+            # Additional args for FusedMoE
+            **(
+                dict(
+                    enable_flashinfer_moe=True,
+                    enable_ep_moe=global_server_args_dict["enable_ep_moe"],
+                )
+                if global_server_args_dict["enable_flashinfer_moe"]
+                else {}
+            ),
         )
 
         self.gate = ReplicatedLinear(
  self.gate = ReplicatedLinear(
@@ -220,7 +229,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
220
229
  hidden_states=hidden_states,
221
230
  topk_idx=topk_idx,
222
231
  topk_weights=topk_weights,
223
- forward_mode=forward_mode,
232
+ forward_batch=forward_batch,
224
233
  )
225
234
  final_hidden_states = self.experts(
226
235
  hidden_states=hidden_states,
@@ -231,14 +240,14 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
231
240
  masked_m=masked_m,
232
241
  expected_m=expected_m,
233
242
  num_recv_tokens_per_expert=num_recv_tokens_per_expert,
234
- forward_mode=forward_mode,
243
+ forward_batch=forward_batch,
235
244
  )
236
245
  if self.ep_size > 1:
237
246
  final_hidden_states = self.deepep_dispatcher.combine(
238
247
  hidden_states=final_hidden_states,
239
248
  topk_idx=topk_idx,
240
249
  topk_weights=topk_weights,
241
- forward_mode=forward_mode,
250
+ forward_batch=forward_batch,
242
251
  )
243
252
  return final_hidden_states
244
253
 
@@ -284,7 +293,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             hidden_states=state.pop("hidden_states_mlp_input"),
             topk_idx=state.pop("topk_idx_local"),
             topk_weights=state.pop("topk_weights_local"),
-            forward_mode=state.forward_batch.forward_mode,
+            forward_batch=state.forward_batch,
             tbo_subbatch_index=state.get("tbo_subbatch_index"),
         )
 
@@ -316,7 +325,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             masked_m=state.pop("masked_m"),
             expected_m=state.pop("expected_m"),
             num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
-            forward_mode=state.forward_batch.forward_mode,
+            forward_batch=state.forward_batch,
         )
 
     def op_combine_a(self, state):
@@ -325,7 +334,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             hidden_states=state.pop("hidden_states_experts_output"),
             topk_idx=state.pop("topk_idx_dispatched"),
             topk_weights=state.pop("topk_weights_dispatched"),
-            forward_mode=state.forward_batch.forward_mode,
+            forward_batch=state.forward_batch,
             tbo_subbatch_index=state.get("tbo_subbatch_index"),
         )
 
@@ -354,6 +363,7 @@ class Qwen3MoeAttention(nn.Module):
         attention_bias: bool = False,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -423,15 +433,27 @@ class Qwen3MoeAttention(nn.Module):
 
         self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
         self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+        self.alt_stream = alt_stream
 
     def _apply_qk_norm(
         self, q: torch.Tensor, k: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        q_by_head = q.reshape(-1, self.head_dim)
-        q_by_head = self.q_norm(q_by_head)
+        # overlap qk norm
+        if self.alt_stream is not None and get_is_capture_mode():
+            current_stream = torch.cuda.current_stream()
+            self.alt_stream.wait_stream(current_stream)
+            q_by_head = q.reshape(-1, self.head_dim)
+            q_by_head = self.q_norm(q_by_head)
+            with torch.cuda.stream(self.alt_stream):
+                k_by_head = k.reshape(-1, self.head_dim)
+                k_by_head = self.k_norm(k_by_head)
+            current_stream.wait_stream(self.alt_stream)
+        else:
+            q_by_head = q.reshape(-1, self.head_dim)
+            q_by_head = self.q_norm(q_by_head)
+            k_by_head = k.reshape(-1, self.head_dim)
+            k_by_head = self.k_norm(k_by_head)
         q = q_by_head.view(q.shape)
-        k_by_head = k.reshape(-1, self.head_dim)
-        k_by_head = self.k_norm(k_by_head)
         k = k_by_head.view(k.shape)
         return q, k
 
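Note: this is the same dual-stream qk-norm overlap added to Qwen3Attention in qwen3.py; see the synchronization sketch after that hunk above.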
@@ -491,6 +513,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
         layer_id: int,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -516,6 +539,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
             attention_bias=attention_bias,
             quant_config=quant_config,
             prefix=add_prefix("self_attn", prefix),
+            alt_stream=alt_stream,
         )
 
         self.layer_id = layer_id
@@ -623,9 +647,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
 
     def op_mlp(self, state):
         hidden_states = state.pop("hidden_states_mlp_input")
-        state.hidden_states_mlp_output = self.mlp(
-            hidden_states, state.forward_batch.forward_mode
-        )
+        state.hidden_states_mlp_output = self.mlp(hidden_states, state.forward_batch)
 
     def op_comm_postprocess_layer(self, state):
         hidden_states, residual = self.layer_communicator.postprocess_layer(
@@ -659,11 +681,13 @@ class Qwen3MoeModel(Qwen2MoeModel):
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:
+        alt_stream = torch.cuda.Stream() if _is_cuda else None
         super().__init__(
             config=config,
             quant_config=quant_config,
             prefix=prefix,
             decoder_layer_type=Qwen3MoeDecoderLayer,
+            alt_stream=alt_stream,
        )
 
@@ -691,6 +715,7 @@ class Qwen3MoeForCausalLM(nn.Module):
             use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
         )
         self.logits_processor = LogitsProcessor(config)
+        self.capture_aux_hidden_states = False
 
     @torch.no_grad()
     def forward(
@@ -709,9 +734,13 @@ class Qwen3MoeForCausalLM(nn.Module):
             pp_proxy_tensors=pp_proxy_tensors,
         )
 
+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
+
         if self.pp_group.is_last_rank:
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head, forward_batch
+                input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
             )
         else:
             return hidden_states
@@ -724,6 +753,24 @@ class Qwen3MoeForCausalLM(nn.Module):
     def end_layer(self):
         return self.model.end_layer
 
+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
+        if not self.pp_group.is_last_rank:
+            return
+
+        self.capture_aux_hidden_states = True
+        if layer_ids is None:
+            num_layers = self.config.num_hidden_layers
+            self.model.layers_to_capture = [
+                2,
+                num_layers // 2,
+                num_layers - 3,
+            ]  # Specific layers for EAGLE3 support
+        else:
+            self.model.layers_to_capture = [val + 1 for val in layer_ids]
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
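Note: `get_embed_and_head` exposes the embedding and LM-head weight tensors so that an EAGLE3 draft model can tie its head to the target's without copying. A minimal sketch of that kind of tying (the draft-side wiring is an assumption; it is not shown in this diff):

    from torch import nn

    head = nn.Linear(16, 100, bias=False)        # stands in for lm_head
    draft_head = nn.Linear(16, 100, bias=False)  # stands in for the draft's head
    draft_head.weight = head.weight              # share the tensor, no copy
    assert draft_head.weight.data_ptr() == head.weight.data_ptr()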
sglang/srt/models/vila.py CHANGED
@@ -270,15 +270,10 @@ class VILAForConditionalGeneration(nn.Module):
             weight_loader(param, loaded_weight)
 
     def pad_input_ids(
-        self,
-        input_ids: List[int],
-        image_inputs: MultimodalInputs,
+        self, input_ids: List[int], mm_inputs: MultimodalInputs
     ) -> List[int]:
-        pattern = MultiModalityDataPaddingPatternMultimodalTokens(
-            token_ids=[self.config.image_token_id],
-        )
-
-        return pattern.pad_input_tokens(input_ids, image_inputs)
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
+        return pattern.pad_input_tokens(input_ids, mm_inputs)
 
     ##### BEGIN COPY modeling_vila.py #####
sglang/srt/{mm_utils.py → multimodal/mm_utils.py} RENAMED
@@ -28,12 +28,12 @@ LLaVA-Onevision : https://arxiv.org/pdf/2408.03326
 
 """
 import ast
-import base64
 import math
 import re
 from io import BytesIO
 
 import numpy as np
+import pybase64
 from PIL import Image
 
 from sglang.srt.utils import flatten_nested_list
@@ -252,7 +252,7 @@ def process_anyres_image(image, processor, grid_pinpoints):
 
 
 def load_image_from_base64(image):
-    return Image.open(BytesIO(base64.b64decode(image)))
+    return Image.open(BytesIO(pybase64.b64decode(image, validate=True)))
 
 
 def expand2square(pil_img, background_color):
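Note: pybase64 is an API-compatible, SIMD-accelerated replacement for the stdlib codec, and `validate=True` makes malformed input raise an error instead of silently skipping non-alphabet characters. A quick equivalence check:

    import base64

    import pybase64

    payload = base64.b64encode(b"\x89PNG")
    assert pybase64.b64decode(payload, validate=True) == b"\x89PNG"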