sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (245)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +10 -1
  3. sglang/bench_serving.py +251 -26
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/srt/configs/__init__.py +4 -0
  6. sglang/srt/configs/internvl.py +6 -0
  7. sglang/srt/configs/longcat_flash.py +104 -0
  8. sglang/srt/configs/model_config.py +37 -7
  9. sglang/srt/configs/qwen3_next.py +326 -0
  10. sglang/srt/connector/__init__.py +1 -1
  11. sglang/srt/connector/base_connector.py +1 -2
  12. sglang/srt/connector/redis.py +2 -2
  13. sglang/srt/connector/serde/__init__.py +1 -1
  14. sglang/srt/connector/serde/safe_serde.py +4 -3
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/ascend/conn.py +75 -0
  21. sglang/srt/disaggregation/base/conn.py +1 -1
  22. sglang/srt/disaggregation/common/conn.py +15 -12
  23. sglang/srt/disaggregation/decode.py +6 -4
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -420
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +6 -4
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +94 -58
  31. sglang/srt/entrypoints/engine.py +34 -14
  32. sglang/srt/entrypoints/http_server.py +172 -47
  33. sglang/srt/entrypoints/openai/protocol.py +63 -3
  34. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  35. sglang/srt/entrypoints/openai/serving_chat.py +34 -19
  36. sglang/srt/entrypoints/openai/serving_completions.py +10 -4
  37. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  38. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  39. sglang/srt/eplb/eplb_manager.py +28 -4
  40. sglang/srt/eplb/expert_distribution.py +55 -15
  41. sglang/srt/eplb/expert_location.py +8 -3
  42. sglang/srt/eplb/expert_location_updater.py +1 -1
  43. sglang/srt/function_call/ebnf_composer.py +11 -9
  44. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  45. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  46. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  47. sglang/srt/hf_transformers_utils.py +12 -0
  48. sglang/srt/layers/activation.py +44 -9
  49. sglang/srt/layers/attention/aiter_backend.py +93 -68
  50. sglang/srt/layers/attention/ascend_backend.py +250 -112
  51. sglang/srt/layers/attention/fla/chunk.py +242 -0
  52. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  53. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  54. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  55. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  56. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  57. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  58. sglang/srt/layers/attention/fla/index.py +37 -0
  59. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  60. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  61. sglang/srt/layers/attention/fla/op.py +66 -0
  62. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  63. sglang/srt/layers/attention/fla/utils.py +331 -0
  64. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  65. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  66. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  67. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  68. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  69. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  70. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  71. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  72. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  73. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  74. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  75. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  76. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  77. sglang/srt/layers/communicator.py +45 -7
  78. sglang/srt/layers/layernorm.py +54 -12
  79. sglang/srt/layers/logits_processor.py +10 -3
  80. sglang/srt/layers/moe/__init__.py +2 -1
  81. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  82. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  83. sglang/srt/layers/moe/ep_moe/layer.py +110 -49
  84. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  85. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  86. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  88. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
  89. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  90. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  94. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  95. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  96. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  97. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  98. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  99. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  100. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  101. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  102. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  103. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  104. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  105. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  106. sglang/srt/layers/moe/topk.py +43 -12
  107. sglang/srt/layers/moe/utils.py +6 -5
  108. sglang/srt/layers/quantization/awq.py +19 -7
  109. sglang/srt/layers/quantization/base_config.py +11 -6
  110. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  111. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  112. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  113. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
  114. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
  115. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  116. sglang/srt/layers/quantization/fp8.py +76 -47
  117. sglang/srt/layers/quantization/fp8_utils.py +43 -29
  118. sglang/srt/layers/quantization/gptq.py +25 -17
  119. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  120. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  121. sglang/srt/layers/quantization/mxfp4.py +77 -45
  122. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  123. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  124. sglang/srt/layers/quantization/quark/utils.py +97 -0
  125. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  126. sglang/srt/layers/quantization/unquant.py +135 -47
  127. sglang/srt/layers/quantization/utils.py +13 -0
  128. sglang/srt/layers/quantization/w4afp8.py +60 -42
  129. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  130. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  131. sglang/srt/layers/rocm_linear_utils.py +44 -0
  132. sglang/srt/layers/rotary_embedding.py +28 -19
  133. sglang/srt/layers/sampler.py +29 -5
  134. sglang/srt/lora/backend/base_backend.py +50 -8
  135. sglang/srt/lora/backend/triton_backend.py +90 -2
  136. sglang/srt/lora/layers.py +32 -0
  137. sglang/srt/lora/lora.py +4 -1
  138. sglang/srt/lora/lora_manager.py +35 -112
  139. sglang/srt/lora/mem_pool.py +24 -10
  140. sglang/srt/lora/utils.py +18 -9
  141. sglang/srt/managers/cache_controller.py +242 -278
  142. sglang/srt/managers/data_parallel_controller.py +30 -15
  143. sglang/srt/managers/detokenizer_manager.py +13 -2
  144. sglang/srt/managers/disagg_service.py +46 -0
  145. sglang/srt/managers/io_struct.py +160 -11
  146. sglang/srt/managers/mm_utils.py +6 -1
  147. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  148. sglang/srt/managers/schedule_batch.py +27 -44
  149. sglang/srt/managers/schedule_policy.py +4 -3
  150. sglang/srt/managers/scheduler.py +90 -115
  151. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  152. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  153. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  154. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  155. sglang/srt/managers/template_manager.py +3 -3
  156. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  157. sglang/srt/managers/tokenizer_manager.py +41 -477
  158. sglang/srt/managers/tp_worker.py +16 -4
  159. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  160. sglang/srt/mem_cache/allocator.py +1 -1
  161. sglang/srt/mem_cache/chunk_cache.py +1 -1
  162. sglang/srt/mem_cache/hicache_storage.py +24 -22
  163. sglang/srt/mem_cache/hiradix_cache.py +184 -101
  164. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  165. sglang/srt/mem_cache/memory_pool.py +324 -41
  166. sglang/srt/mem_cache/memory_pool_host.py +25 -18
  167. sglang/srt/mem_cache/radix_cache.py +5 -6
  168. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  169. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  170. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  171. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  172. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
  173. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  174. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  175. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
  176. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  177. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  178. sglang/srt/metrics/collector.py +484 -63
  179. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  180. sglang/srt/metrics/utils.py +48 -0
  181. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  182. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  183. sglang/srt/model_executor/forward_batch_info.py +72 -18
  184. sglang/srt/model_executor/model_runner.py +189 -31
  185. sglang/srt/model_loader/__init__.py +9 -3
  186. sglang/srt/model_loader/loader.py +33 -28
  187. sglang/srt/model_loader/utils.py +12 -0
  188. sglang/srt/model_loader/weight_utils.py +2 -1
  189. sglang/srt/models/deepseek_v2.py +311 -50
  190. sglang/srt/models/gemma3n_mm.py +1 -1
  191. sglang/srt/models/glm4_moe.py +10 -1
  192. sglang/srt/models/glm4v.py +4 -2
  193. sglang/srt/models/gpt_oss.py +5 -18
  194. sglang/srt/models/internvl.py +28 -0
  195. sglang/srt/models/llama4.py +9 -0
  196. sglang/srt/models/llama_eagle3.py +17 -0
  197. sglang/srt/models/longcat_flash.py +1026 -0
  198. sglang/srt/models/longcat_flash_nextn.py +699 -0
  199. sglang/srt/models/minicpmv.py +165 -3
  200. sglang/srt/models/mllama4.py +25 -0
  201. sglang/srt/models/opt.py +637 -0
  202. sglang/srt/models/qwen2.py +33 -3
  203. sglang/srt/models/qwen2_5_vl.py +90 -42
  204. sglang/srt/models/qwen2_moe.py +79 -14
  205. sglang/srt/models/qwen3.py +8 -2
  206. sglang/srt/models/qwen3_moe.py +39 -8
  207. sglang/srt/models/qwen3_next.py +1039 -0
  208. sglang/srt/models/qwen3_next_mtp.py +109 -0
  209. sglang/srt/models/torch_native_llama.py +1 -1
  210. sglang/srt/models/transformers.py +1 -1
  211. sglang/srt/multimodal/processors/base_processor.py +4 -2
  212. sglang/srt/multimodal/processors/glm4v.py +9 -9
  213. sglang/srt/multimodal/processors/internvl.py +141 -129
  214. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  215. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  216. sglang/srt/sampling/sampling_batch_info.py +18 -15
  217. sglang/srt/server_args.py +297 -79
  218. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  219. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  220. sglang/srt/speculative/eagle_worker.py +216 -120
  221. sglang/srt/speculative/spec_info.py +5 -0
  222. sglang/srt/speculative/standalone_worker.py +109 -0
  223. sglang/srt/utils.py +37 -2
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  226. sglang/test/few_shot_gsm8k.py +1 -0
  227. sglang/test/runners.py +4 -0
  228. sglang/test/test_cutlass_moe.py +24 -6
  229. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  230. sglang/test/test_disaggregation_utils.py +66 -0
  231. sglang/test/test_utils.py +25 -1
  232. sglang/utils.py +5 -0
  233. sglang/version.py +1 -1
  234. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
  235. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
  236. sglang/srt/disaggregation/launch_lb.py +0 -131
  237. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  238. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  239. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  240. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  241. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  242. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  243. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  244. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  245. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/models/qwen2_5_vl.py

@@ -31,7 +31,6 @@ import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
 from transformers.activations import ACT2FN
-from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm
 from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
     Qwen2_5_VLConfig,
     Qwen2_5_VLVisionConfig,
@@ -43,7 +42,12 @@ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
 
 from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.attention.vision import VisionAttention
-from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+    ColumnParallelLinear,
+    MergedColumnParallelLinear,
+    RowParallelLinear,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -62,7 +66,6 @@ logger = logging.getLogger(__name__)
 
 
 class Qwen2_5_VLMLP(nn.Module):
-
     def __init__(
         self,
         in_features: int,
@@ -73,19 +76,12 @@ class Qwen2_5_VLMLP(nn.Module):
         prefix: str = "",
     ):
         super().__init__()
-        self.gate_proj = ColumnParallelLinear(
-            in_features,
-            hidden_features,
+        self.gate_up_proj = MergedColumnParallelLinear(
+            input_size=in_features,
+            output_sizes=[hidden_features] * 2,  # [gate_proj, up_proj]
             bias=bias,
             quant_config=quant_config,
-            prefix=add_prefix("gate_proj", prefix),
-        )
-        self.up_proj = ColumnParallelLinear(
-            in_features,
-            hidden_features,
-            bias=bias,
-            quant_config=quant_config,
-            prefix=add_prefix("up_proj", prefix),
+            prefix=add_prefix("gate_up_proj", prefix),
         )
         self.down_proj = RowParallelLinear(
             hidden_features,
@@ -97,12 +93,11 @@ class Qwen2_5_VLMLP(nn.Module):
         self.act = ACT2FN[hidden_act]
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x_parallel_gate, _ = self.gate_proj(x)
-        x_parallel_gate = self.act(x_parallel_gate)
-        x_parallel_up, _ = self.up_proj(x)
-        x_parallel = x_parallel_gate * x_parallel_up
-        x, _ = self.down_proj(x_parallel)
-        return x
+        gate_up, _ = self.gate_up_proj(x)
+        gate, up = gate_up.chunk(2, dim=-1)
+        x = self.act(gate) * up
+        x_down, _ = self.down_proj(x)
+        return x_down
 
 
 class Qwen2_5_VisionBlock(nn.Module):
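The rewritten Qwen2_5_VLMLP above replaces two separate projections with one fused gate_up GEMM whose output is split with chunk. A minimal, self-contained sketch of the same pattern, with plain nn.Linear standing in for MergedColumnParallelLinear and SiLU assumed as the activation:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F


    class ToySwiGLUMLP(nn.Module):
        # Toy stand-in for the merged gate/up MLP pattern (not sglang's class).
        def __init__(self, in_features: int, hidden_features: int, bias: bool = True):
            super().__init__()
            # One GEMM produces gate and up projections concatenated along the last dim.
            self.gate_up_proj = nn.Linear(in_features, 2 * hidden_features, bias=bias)
            self.down_proj = nn.Linear(hidden_features, in_features, bias=bias)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            gate_up = self.gate_up_proj(x)
            gate, up = gate_up.chunk(2, dim=-1)  # split back into the two halves
            return self.down_proj(F.silu(gate) * up)


    print(ToySwiGLUMLP(32, 64)(torch.randn(4, 32)).shape)  # torch.Size([4, 32])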
@@ -118,12 +113,13 @@ class Qwen2_5_VisionBlock(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         num_dummy_heads: int = 0,
+        rms_norm_eps: float = 1e-6,
     ) -> None:
         super().__init__()
         if norm_layer is None:
             norm_layer = partial(nn.LayerNorm, eps=1e-6)
-        self.norm1 = Qwen2RMSNorm(dim, eps=1e-6)
-        self.norm2 = Qwen2RMSNorm(dim, eps=1e-6)
+        self.norm1 = RMSNorm(dim, eps=rms_norm_eps)
+        self.norm2 = RMSNorm(dim, eps=rms_norm_eps)
 
         if attn_implementation is None:
             softmax_in_single_precision = False
@@ -174,18 +170,29 @@ class Qwen2_5_VisionBlock(nn.Module):
         cu_seqlens: torch.Tensor,
         position_embeddings: torch.Tensor,
     ) -> torch.Tensor:
-        hidden_states = self.norm1(x)
-        hidden_states = rearrange(hidden_states, "s b ... -> b s ...")
+        S, B, H = x.shape
+        # norm1: flatten to 2D -> [S*B, H], then reshape back
+        x2d = x.reshape(-1, H)
+        hidden_states = self.norm1(x2d).reshape(S, B, H)
+
+        # Attention expects [B, S, H]
+        hidden_states = rearrange(hidden_states, "s b h -> b s h")
         attn = self.attn(
             hidden_states,
             cu_seqlens=cu_seqlens,
             position_embeddings=position_embeddings,
         )
-        attn = rearrange(attn, "b s ... -> s b ...")
-        x = x + attn
-        norm2 = self.norm2(x)
-        mlp = self.mlp(norm2)
-        x = x + mlp
+        attn = rearrange(attn, "b s h -> s b h")
+
+        # norm2 with fused residual-add: also 2D
+        attn2d = attn.reshape(-1, H)
+        x_norm_2d, x_after_add_2d = self.norm2(x2d, residual=attn2d)
+        x_norm = x_norm_2d.reshape(S, B, H)
+        x_after_add = x_after_add_2d.reshape(S, B, H)
+
+        # MLP and final residual
+        mlp_out = self.mlp(x_norm)
+        x = x_after_add + mlp_out
         return x
 
 
@@ -201,7 +208,7 @@ class Qwen2_5_VisionPatchMerger(nn.Module):
     ) -> None:
         super().__init__()
         self.hidden_size = context_dim * (spatial_merge_size**2)
-        self.ln_q = Qwen2RMSNorm(context_dim, eps=1e-6)
+        self.ln_q = RMSNorm(context_dim, eps=1e-6)
         self.mlp = nn.ModuleList(
             [
                 ColumnParallelLinear(
@@ -223,11 +230,13 @@ class Qwen2_5_VisionPatchMerger(nn.Module):
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = self.ln_q(x)
-        x = x.view(-1, self.hidden_size)
-
+        # x expected shape: [S, B, context_dim]
+        S, B, D = x.shape
+        x2d = x.reshape(-1, D)
+        x2d = self.ln_q(x2d)  # RMSNorm expects 2D
+        x2d = x2d.view(-1, self.hidden_size)  # group into spatial_merge_unit
         mlp_fc1, mlp_act, mlp_fc2 = self.mlp
-        x_parallel, _ = mlp_fc1(x)
+        x_parallel, _ = mlp_fc1(x2d)
         x_parallel = mlp_act(x_parallel)
         out, _ = mlp_fc2(x_parallel)
         return out
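The vision block and patch merger above now call sglang's RMSNorm on flattened 2-D inputs, and the vision block additionally relies on the fused residual form, where norm(x, residual=r) returns both the normalized sum and the sum itself (which becomes the next residual). A pure-PyTorch reference of that contract, assuming the fused kernel is functionally equivalent:

    import torch
    import torch.nn as nn


    class ReferenceRMSNorm(nn.Module):
        # Reference semantics only; sglang's RMSNorm uses a fused kernel.
        def __init__(self, dim: int, eps: float = 1e-6):
            super().__init__()
            self.weight = nn.Parameter(torch.ones(dim))
            self.eps = eps

        def forward(self, x: torch.Tensor, residual: torch.Tensor = None):
            if residual is not None:
                x = x + residual  # fused residual add
            normed = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
            normed = normed * self.weight
            # Two outputs when a residual is passed: (normed, new_residual).
            return normed if residual is None else (normed, x)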
@@ -340,7 +349,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
 
     @property
     def device(self) -> torch.device:
-        return self.blocks[0].mlp.gate_proj.weight.device
+        return self.patch_embed.proj.weight.device
 
     def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
         pos_ids = []
@@ -394,6 +403,12 @@ class Qwen2_5_VisionTransformer(nn.Module):
         )
         cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
 
+        # Move window_index to the same device as x before using it to index x
+        window_index = window_index.to(device=x.device)
+
+        # Ensure rotary_pos_emb is on the same device/dtype as x
+        rotary_pos_emb = rotary_pos_emb.to(device=x.device, dtype=x.dtype)
+
         seq_len, _ = x.size()
 
         x = x.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
@@ -406,12 +421,19 @@
         rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
         emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
         position_embeddings = (emb.cos(), emb.sin())
+        # After building position_embeddings, make sure both cos and sin are on the same device/dtype as the attention input
+        position_embeddings = (
+            position_embeddings[0].to(x.device, x.dtype),
+            position_embeddings[1].to(x.device, x.dtype),
+        )
 
-        # compute cu_seqlens
+        # compute cu_seqlens - move cu_seqlens to GPU and make it int32
         cu_seqlens = torch.cat(
             [
-                torch.tensor([0], device=grid_thw.device),
-                (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).cumsum(dim=0),
+                torch.tensor([0], device=x.device, dtype=torch.int32),
+                (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2])
+                .cumsum(dim=0)
+                .to(device=x.device, dtype=torch.int32),
             ]
         )
         cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
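A hypothetical example of the cu_seqlens construction above: two images whose grid_thw values are assumed here contribute 4 and 16 patch tokens, and the result is an int32 tensor on the activations' device:

    import torch
    import torch.nn.functional as F

    grid_thw = torch.tensor([[1, 2, 2], [1, 4, 4]])  # assumed (t, h, w) per image
    device = "cuda" if torch.cuda.is_available() else "cpu"

    cu_seqlens = torch.cat(
        [
            torch.tensor([0], device=device, dtype=torch.int32),
            (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2])
            .cumsum(dim=0)
            .to(device=device, dtype=torch.int32),
        ]
    )
    cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
    print(cu_seqlens.tolist())  # [0, 0, 4, 20]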
@@ -442,9 +464,8 @@ cached_get_processor = lru_cache(get_processor)
 class Qwen2_5_VLForConditionalGeneration(nn.Module):
     # BitandBytes specific attributes
     default_bitsandbytes_target_modules = [
-        ".gate_proj.",
+        ".gate_up_proj.",
         ".down_proj.",
-        ".up_proj.",
         ".q_proj.",
         ".k_proj.",
         ".v_proj.",
@@ -497,6 +518,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
 
+        # For EAGLE3 support
+        self.capture_aux_hidden_states = False
+
     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
         pattern = MultiModalityDataPaddingPatternMultimodalTokens()
         return pattern.pad_input_tokens(input_ids, mm_inputs)
@@ -567,9 +591,13 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
             positions=positions,
         )
 
+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
+
         if not get_embedding:
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head, forward_batch
+                input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
             )
         else:
             return self.pooler(hidden_states, forward_batch)
@@ -591,7 +619,11 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
-                if "visual" in name:
+                if (
+                    "visual" in name
+                    and "up_proj" not in name
+                    and "gate_proj" not in name
+                ):
                     continue
                 name = name.replace(weight_name, param_name)
 
@@ -619,5 +651,21 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
 
+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
+        self.capture_aux_hidden_states = True
+        self.model.capture_aux_hidden_states = True
+        if layer_ids is None:
+            num_layers = self.config.num_hidden_layers
+            self.model.layers_to_capture = [
+                2,
+                num_layers // 2,
+                num_layers - 3,
+            ]  # Specific layers for EAGLE3 support
+        else:
+            self.model.layers_to_capture = [val + 1 for val in layer_ids]
+
 
 EntryClass = [Qwen2_5_VLForConditionalGeneration]
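The relaxed "visual" skip in load_weights above exists because the checkpoint still stores separate visual gate_proj and up_proj tensors, while the runtime module now has a single merged gate_up_proj parameter. A toy illustration of the stacking with plain tensors; sglang's real path goes through stacked_params_mapping and MergedColumnParallelLinear's weight_loader, which is not shown here:

    import torch

    in_features, hidden_features = 32, 64
    gate_w = torch.randn(hidden_features, in_features)  # checkpoint: ...gate_proj.weight
    up_w = torch.randn(hidden_features, in_features)    # checkpoint: ...up_proj.weight

    # Runtime parameter: ...gate_up_proj.weight, with gate rows first, then up rows.
    merged = torch.empty(2 * hidden_features, in_features)
    merged[:hidden_features].copy_(gate_w)  # shard 0 <- gate_proj
    merged[hidden_features:].copy_(up_w)    # shard 1 <- up_proj

    x = torch.randn(4, in_features)
    gate, up = (x @ merged.t()).chunk(2, dim=-1)
    assert torch.allclose(gate, x @ gate_w.t(), atol=1e-5)
    assert torch.allclose(up, x @ up_w.t(), atol=1e-5)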
sglang/srt/models/qwen2_moe.py

@@ -17,7 +17,7 @@
 """Inference-only Qwen2MoE model compatible with HuggingFace weights."""
 
 import logging
-from typing import Any, Dict, Iterable, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -65,10 +65,12 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.two_batch_overlap import model_forward_maybe_tbo
-from sglang.srt.utils import add_prefix, make_layers
+from sglang.srt.utils import add_prefix, is_cuda, make_layers
 
 logger = logging.getLogger(__name__)
 
+_is_cuda = is_cuda()
+
 
 class Qwen2MoeMLP(nn.Module):
     def __init__(
@@ -105,11 +107,14 @@ class Qwen2MoeMLP(nn.Module):
     def forward(
         self,
         x,
+        should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
     ):
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
-        x, _ = self.down_proj(x, skip_all_reduce=use_reduce_scatter)
+        x, _ = self.down_proj(
+            x, skip_all_reduce=should_allreduce_fusion or use_reduce_scatter
+        )
         return x
 
 
@@ -119,11 +124,13 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
         layer_id: int,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        alt_stream: Optional[torch.cuda.Stream] = None,
         prefix: str = "",
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
         self.layer_id = layer_id
+        self.alt_stream = alt_stream
         if self.tp_size > config.num_experts:
             raise ValueError(
                 f"Tensor parallel size {self.tp_size} is greater than "
@@ -165,14 +172,7 @@
             self.shared_expert = None
         self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)
 
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        forward_batch: Optional[ForwardBatch] = None,
-        use_reduce_scatter: bool = False,
-    ) -> torch.Tensor:
-        num_tokens, hidden_dim = hidden_states.shape
-        hidden_states = hidden_states.view(-1, hidden_dim)
+    def _forward_shared_experts(self, hidden_states: torch.Tensor):
         shared_output = None
         if self.shared_expert is not None:
             shared_output = self.shared_expert(hidden_states)
@@ -180,11 +180,51 @@
             shared_output = (
                 F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_output
             )
+        return shared_output
 
+    def _forward_router_experts(self, hidden_states: torch.Tensor):
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
         topk_output = self.topk(hidden_states, router_logits)
-        final_hidden_states = self.experts(hidden_states, topk_output)
+        return self.experts(hidden_states, topk_output)
+
+    def forward_normal_dual_stream(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        current_stream = torch.cuda.current_stream()
+        self.alt_stream.wait_stream(current_stream)
+        shared_output = self._forward_shared_experts(hidden_states)
+
+        with torch.cuda.stream(self.alt_stream):
+            router_output = self._forward_router_experts(hidden_states)
+
+        current_stream.wait_stream(self.alt_stream)
+
+        return router_output, shared_output
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        forward_batch: Optional[ForwardBatch] = None,
+        use_reduce_scatter: bool = False,
+    ) -> torch.Tensor:
+        num_tokens, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+
+        DUAL_STREAM_TOKEN_THRESHOLD = 1024
+        if (
+            self.alt_stream is not None
+            and hidden_states.shape[0] > 0
+            and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
+        ):
+            final_hidden_states, shared_output = self.forward_normal_dual_stream(
+                hidden_states
+            )
+        else:
+            shared_output = self._forward_shared_experts(hidden_states)
+            final_hidden_states = self._forward_router_experts(hidden_states)
+
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
         if self.tp_size > 1 and not use_reduce_scatter:
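forward_normal_dual_stream above overlaps the shared expert with the routed experts by running them on two CUDA streams and joining with wait_stream. A generic sketch of the same pattern with stand-in matmuls in place of the real expert modules:

    import torch

    if torch.cuda.is_available():
        x = torch.randn(256, 1024, device="cuda")
        w_shared = torch.randn(1024, 1024, device="cuda")
        w_routed = torch.randn(1024, 1024, device="cuda")

        alt_stream = torch.cuda.Stream()
        current_stream = torch.cuda.current_stream()

        alt_stream.wait_stream(current_stream)  # alt stream must see x as ready
        shared_out = x @ w_shared               # shared-expert work on the current stream

        with torch.cuda.stream(alt_stream):
            routed_out = x @ w_routed           # routed-expert work overlaps on alt stream

        current_stream.wait_stream(alt_stream)  # join before consuming routed_out
        out = shared_out + routed_out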
@@ -343,6 +383,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
                 layer_id=layer_id,
                 config=config,
                 quant_config=quant_config,
+                alt_stream=alt_stream,
                 prefix=add_prefix("mlp", prefix),
             )
         else:
@@ -525,8 +566,12 @@ class Qwen2MoeForCausalLM(nn.Module):
         self.pp_group = get_pp_group()
         self.config = config
         self.quant_config = quant_config
+        alt_stream = torch.cuda.Stream() if _is_cuda else None
         self.model = Qwen2MoeModel(
-            config, quant_config, prefix=add_prefix("model", prefix)
+            config,
+            quant_config,
+            prefix=add_prefix("model", prefix),
+            alt_stream=alt_stream,
         )
         self.lm_head = ParallelLMHead(
             config.vocab_size,
@@ -536,6 +581,8 @@
             use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
         )
         self.logits_processor = LogitsProcessor(config)
+        # For EAGLE3 support
+        self.capture_aux_hidden_states = False
 
     @torch.no_grad()
     def forward(
@@ -553,9 +600,12 @@
             input_embeds,
             pp_proxy_tensors=pp_proxy_tensors,
         )
+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
         if self.pp_group.is_last_rank:
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head, forward_batch
+                input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
             )
         else:
             return hidden_states
@@ -705,5 +755,20 @@
             num_groups=None,
         )
 
+    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
+        if not self.pp_group.is_last_rank:
+            return
+
+        self.capture_aux_hidden_states = True
+        if layer_ids is None:
+            num_layers = self.config.num_hidden_layers
+            self.model.layers_to_capture = [
+                2,
+                num_layers // 2,
+                num_layers - 3,
+            ]  # Specific layers for EAGLE3 support
+        else:
+            self.model.layers_to_capture = [val + 1 for val in layer_ids]
+
 
 EntryClass = Qwen2MoeForCausalLM
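Both set_eagle3_layers_to_capture methods above default to capturing auxiliary hidden states from an early, a middle, and a late decoder layer. A small illustration with an assumed layer count:

    num_hidden_layers = 28  # assumed value for illustration
    default_capture = [2, num_hidden_layers // 2, num_hidden_layers - 3]
    print(default_capture)  # [2, 14, 25]
    # Explicit layer_ids are shifted by one: layer_ids=[1, 13, 24] -> [2, 14, 25]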
sglang/srt/models/qwen3.py

@@ -24,7 +24,10 @@ from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
-from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    maybe_remap_kv_scale_name,
+)
 from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
 from sglang.srt.models.qwen2 import Qwen2Model
 from sglang.srt.utils import add_prefix, is_cuda
@@ -458,7 +461,10 @@ class Qwen3ForCausalLM(nn.Module):
                 continue
             if name.startswith("model.vision_tower") and name not in params_dict:
                 continue
-
+            if "scale" in name:
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
sglang/srt/models/qwen3_moe.py

@@ -42,7 +42,10 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe import get_moe_a2a_backend
+from sglang.srt.layers.moe import (
+    get_moe_a2a_backend,
+    should_use_flashinfer_cutlass_moe_fp4_allgather,
+)
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
@@ -57,10 +60,17 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTe
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
 from sglang.srt.models.qwen2_moe import Qwen2MoeModel
-from sglang.srt.utils import add_prefix, is_cuda, is_non_idle_and_non_empty
+from sglang.srt.utils import (
+    add_prefix,
+    is_cuda,
+    is_flashinfer_available,
+    is_non_idle_and_non_empty,
+)
 
 Qwen3MoeConfig = None
 
+_is_flashinfer_available = is_flashinfer_available()
+
 logger = logging.getLogger(__name__)
 _is_cuda = is_cuda()
 
@@ -119,11 +129,14 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         self,
         hidden_states: torch.Tensor,
         forward_batch: Optional[ForwardBatch] = None,
+        should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
 
         if not get_moe_a2a_backend().is_deepep():
-            return self.forward_normal(hidden_states, use_reduce_scatter)
+            return self.forward_normal(
+                hidden_states, should_allreduce_fusion, use_reduce_scatter
+            )
         else:
             return self.forward_deepep(hidden_states, forward_batch)
 
@@ -137,6 +150,7 @@
     def forward_normal(
         self,
         hidden_states: torch.Tensor,
+        should_allreduce_fusion: bool = False,
        use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
@@ -146,7 +160,12 @@
         router_logits, _ = self.gate(hidden_states)
         topk_output = self.topk(hidden_states, router_logits)
         final_hidden_states = self.experts(hidden_states, topk_output)
-        if self.tp_size > 1 and not use_reduce_scatter:
+        if (
+            self.tp_size > 1
+            and not should_allreduce_fusion
+            and not use_reduce_scatter
+            and not should_use_flashinfer_cutlass_moe_fp4_allgather()
+        ):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
 
         return final_hidden_states.view(num_tokens, hidden_dim)
@@ -500,6 +519,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
             input_layernorm=self.input_layernorm,
             post_attention_layernorm=self.post_attention_layernorm,
             allow_reduce_scatter=True,
+            is_last_layer=(self.layer_id == self.config.num_hidden_layers - 1),
         )
 
     def forward(
@@ -525,17 +545,28 @@
             hidden_states, residual, forward_batch
         )
 
+        should_allreduce_fusion = (
+            self.layer_communicator.should_fuse_mlp_allreduce_with_next_layer(
+                forward_batch
+            )
+        )
+
         # For DP with padding, reduce scatter can be used instead of all-reduce.
         use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
             forward_batch
         )
 
-        hidden_states = self.mlp(hidden_states, forward_batch, use_reduce_scatter)
-
-        hidden_states, residual = self.layer_communicator.postprocess_layer(
-            hidden_states, residual, forward_batch
+        hidden_states = self.mlp(
+            hidden_states, forward_batch, should_allreduce_fusion, use_reduce_scatter
         )
 
+        if should_allreduce_fusion:
+            hidden_states._sglang_needs_allreduce_fusion = True
+        else:
+            hidden_states, residual = self.layer_communicator.postprocess_layer(
+                hidden_states, residual, forward_batch
+            )
+
         return hidden_states, residual
 
     def op_comm_prepare_attn(
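When should_allreduce_fusion is true, the decoder layer above skips postprocess_layer and only tags the tensor with _sglang_needs_allreduce_fusion, deferring the tensor-parallel reduction so it can be fused into the next layer's add-RMSNorm. The consumer below is a hypothetical sketch of honoring such a tag with a plain all-reduce, not sglang's actual fused path:

    import torch
    import torch.distributed as dist


    def maybe_deferred_allreduce(hidden_states: torch.Tensor) -> torch.Tensor:
        # Hypothetical helper: perform the reduction that the producing layer deferred.
        if getattr(hidden_states, "_sglang_needs_allreduce_fusion", False):
            if dist.is_initialized():
                dist.all_reduce(hidden_states)  # stand-in for the fused kernel
            hidden_states._sglang_needs_allreduce_fusion = False
        return hidden_states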