sglang 0.4.8.post1-py3-none-any.whl → 0.4.9.post1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (158)
  1. sglang/bench_one_batch_server.py +17 -2
  2. sglang/bench_serving.py +170 -24
  3. sglang/srt/configs/internvl.py +4 -2
  4. sglang/srt/configs/janus_pro.py +1 -1
  5. sglang/srt/configs/model_config.py +60 -1
  6. sglang/srt/configs/update_config.py +119 -0
  7. sglang/srt/conversation.py +69 -1
  8. sglang/srt/disaggregation/decode.py +21 -5
  9. sglang/srt/disaggregation/mooncake/conn.py +35 -4
  10. sglang/srt/disaggregation/nixl/conn.py +6 -6
  11. sglang/srt/disaggregation/prefill.py +2 -2
  12. sglang/srt/disaggregation/utils.py +1 -1
  13. sglang/srt/distributed/parallel_state.py +44 -17
  14. sglang/srt/entrypoints/EngineBase.py +8 -0
  15. sglang/srt/entrypoints/engine.py +40 -6
  16. sglang/srt/entrypoints/http_server.py +111 -24
  17. sglang/srt/entrypoints/http_server_engine.py +1 -1
  18. sglang/srt/entrypoints/openai/protocol.py +4 -2
  19. sglang/srt/eplb/__init__.py +0 -0
  20. sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
  21. sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
  22. sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
  23. sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
  24. sglang/srt/{managers → eplb}/expert_location.py +1 -1
  25. sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
  26. sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
  27. sglang/srt/hf_transformers_utils.py +2 -1
  28. sglang/srt/layers/activation.py +2 -2
  29. sglang/srt/layers/amx_utils.py +86 -0
  30. sglang/srt/layers/attention/ascend_backend.py +219 -0
  31. sglang/srt/layers/attention/flashattention_backend.py +32 -9
  32. sglang/srt/layers/attention/tbo_backend.py +37 -9
  33. sglang/srt/layers/communicator.py +20 -2
  34. sglang/srt/layers/dp_attention.py +9 -3
  35. sglang/srt/layers/elementwise.py +76 -12
  36. sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
  37. sglang/srt/layers/layernorm.py +26 -0
  38. sglang/srt/layers/linear.py +84 -14
  39. sglang/srt/layers/logits_processor.py +4 -4
  40. sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
  41. sglang/srt/layers/moe/ep_moe/kernels.py +81 -8
  42. sglang/srt/layers/moe/ep_moe/layer.py +176 -15
  43. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
  44. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -2
  45. sglang/srt/layers/moe/fused_moe_triton/layer.py +211 -74
  46. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
  47. sglang/srt/layers/moe/router.py +60 -22
  48. sglang/srt/layers/moe/topk.py +10 -28
  49. sglang/srt/layers/parameter.py +67 -7
  50. sglang/srt/layers/quantization/__init__.py +2 -0
  51. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
  52. sglang/srt/layers/quantization/fp8.py +72 -7
  53. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  54. sglang/srt/layers/quantization/fp8_utils.py +1 -2
  55. sglang/srt/layers/quantization/gptq.py +5 -1
  56. sglang/srt/layers/quantization/modelopt_quant.py +244 -1
  57. sglang/srt/layers/quantization/moe_wna16.py +1 -1
  58. sglang/srt/layers/quantization/quant_utils.py +166 -0
  59. sglang/srt/layers/quantization/w4afp8.py +264 -0
  60. sglang/srt/layers/quantization/w8a8_int8.py +52 -1
  61. sglang/srt/layers/rotary_embedding.py +2 -2
  62. sglang/srt/layers/vocab_parallel_embedding.py +20 -10
  63. sglang/srt/lora/lora.py +4 -5
  64. sglang/srt/lora/lora_manager.py +73 -20
  65. sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
  66. sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
  67. sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
  68. sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
  69. sglang/srt/managers/cache_controller.py +41 -195
  70. sglang/srt/managers/configure_logging.py +1 -1
  71. sglang/srt/managers/io_struct.py +58 -14
  72. sglang/srt/managers/mm_utils.py +77 -61
  73. sglang/srt/managers/multimodal_processor.py +2 -6
  74. sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
  75. sglang/srt/managers/schedule_batch.py +78 -85
  76. sglang/srt/managers/scheduler.py +130 -64
  77. sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
  78. sglang/srt/managers/session_controller.py +12 -3
  79. sglang/srt/managers/tokenizer_manager.py +314 -103
  80. sglang/srt/managers/tp_worker.py +13 -1
  81. sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
  82. sglang/srt/mem_cache/allocator.py +290 -0
  83. sglang/srt/mem_cache/chunk_cache.py +34 -2
  84. sglang/srt/mem_cache/hiradix_cache.py +2 -0
  85. sglang/srt/mem_cache/memory_pool.py +402 -66
  86. sglang/srt/mem_cache/memory_pool_host.py +6 -109
  87. sglang/srt/mem_cache/multimodal_cache.py +3 -0
  88. sglang/srt/mem_cache/radix_cache.py +8 -4
  89. sglang/srt/model_executor/cuda_graph_runner.py +2 -1
  90. sglang/srt/model_executor/forward_batch_info.py +17 -4
  91. sglang/srt/model_executor/model_runner.py +297 -56
  92. sglang/srt/model_loader/loader.py +41 -0
  93. sglang/srt/model_loader/weight_utils.py +72 -4
  94. sglang/srt/models/deepseek_nextn.py +1 -3
  95. sglang/srt/models/deepseek_v2.py +195 -45
  96. sglang/srt/models/deepseek_vl2.py +3 -5
  97. sglang/srt/models/gemma3_causal.py +1 -2
  98. sglang/srt/models/gemma3n_causal.py +4 -3
  99. sglang/srt/models/gemma3n_mm.py +4 -20
  100. sglang/srt/models/hunyuan.py +1 -1
  101. sglang/srt/models/kimi_vl.py +1 -2
  102. sglang/srt/models/llama.py +10 -4
  103. sglang/srt/models/llama4.py +32 -45
  104. sglang/srt/models/llama_eagle3.py +61 -11
  105. sglang/srt/models/llava.py +5 -5
  106. sglang/srt/models/minicpmo.py +2 -2
  107. sglang/srt/models/mistral.py +1 -1
  108. sglang/srt/models/mllama4.py +402 -89
  109. sglang/srt/models/phi4mm.py +1 -3
  110. sglang/srt/models/pixtral.py +3 -7
  111. sglang/srt/models/qwen2.py +31 -3
  112. sglang/srt/models/qwen2_5_vl.py +1 -3
  113. sglang/srt/models/qwen2_audio.py +200 -0
  114. sglang/srt/models/qwen2_moe.py +32 -6
  115. sglang/srt/models/qwen2_vl.py +1 -4
  116. sglang/srt/models/qwen3.py +94 -25
  117. sglang/srt/models/qwen3_moe.py +68 -21
  118. sglang/srt/models/vila.py +3 -8
  119. sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +2 -2
  120. sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
  121. sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
  122. sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
  123. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
  124. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
  125. sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
  126. sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
  127. sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
  128. sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
  129. sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
  130. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
  131. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +65 -66
  132. sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
  133. sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
  134. sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
  135. sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
  136. sglang/srt/operations_strategy.py +6 -2
  137. sglang/srt/reasoning_parser.py +26 -0
  138. sglang/srt/sampling/sampling_batch_info.py +39 -1
  139. sglang/srt/server_args.py +84 -22
  140. sglang/srt/speculative/build_eagle_tree.py +57 -18
  141. sglang/srt/speculative/eagle_worker.py +6 -4
  142. sglang/srt/two_batch_overlap.py +203 -27
  143. sglang/srt/utils.py +343 -163
  144. sglang/srt/warmup.py +12 -3
  145. sglang/test/runners.py +10 -1
  146. sglang/test/test_cutlass_w4a8_moe.py +281 -0
  147. sglang/test/test_utils.py +15 -3
  148. sglang/utils.py +5 -5
  149. sglang/version.py +1 -1
  150. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +12 -8
  151. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +157 -146
  152. sglang/math_utils.py +0 -8
  153. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
  154. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
  155. /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
  156. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
  157. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
  158. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0
@@ -209,6 +209,17 @@ def get_quant_config(
  config["adapter_name_or_path"] = model_name_or_path
  elif model_config.quantization == "modelopt":
  if config["producer"]["name"] == "modelopt":
+ # (yizhang2077) workaround for nvidia/Llama-4-Maverick-17B-128E-Eagle3
+ if config["quantization"]["quant_algo"] is None:
+ if (
+ model_config.hf_config.architectures[0]
+ != "LlamaForCausalLMEagle3"
+ ):
+ raise ValueError(
+ f"Invalid quant_config, quantization method: {model_config.quantization},"
+ f"hf architectures: {model_config.hf_config.architectures[0]}. "
+ )
+ return None
  if "FP4" in config["quantization"]["quant_algo"]:
  return ModelOptFp4Config.from_config(config)
  else:
@@ -449,10 +460,12 @@ def safetensors_weights_iterator(
  if disable_mmap:
  with open(st_file, "rb") as f:
  result = safetensors.torch.load(f.read())
+ for name, param in result.items():
+ yield name, param
  else:
- result = safetensors.torch.load_file(st_file, device="cpu")
- for name, param in result.items():
- yield name, param
+ with safetensors.safe_open(st_file, framework="pt", device="cpu") as f:
+ for name in f.keys():
+ yield name, f.get_tensor(name)


  def multi_thread_safetensors_weights_iterator(
@@ -485,7 +498,8 @@ def multi_thread_safetensors_weights_iterator(
  with open(st_file, "rb") as f:
  result = safetensors.torch.load(f.read())
  else:
- result = safetensors.torch.load_file(st_file, device="cpu")
+ with safetensors.safe_open(st_file, framework="pt", device="cpu") as f:
+ result = {k: f.get_tensor(k) for k in f.keys()}

  return result

@@ -947,3 +961,57 @@ def kv_cache_scales_loader(
  tp_rank,
  )
  return []
+
+
+ def get_actual_shard_size(shard_size, weight_start, weight_end):
+ if weight_end < weight_start:
+ return 0
+
+ return min(shard_size, weight_end - weight_start)
+
+
+ def reset_param_data_if_needed(param_data, dim, start, length):
+ if length == 0:
+ return
+
+ assert length > 0, f"Length should be positive, but got {length}"
+
+ param_data.narrow(dim, start, length).zero_()
+ return
+
+
+ def narrow_padded_param_and_loaded_weight(
+ param_data,
+ loaded_weight,
+ param_data_start,
+ weight_start,
+ dim,
+ shard_size,
+ narrow_weight=True,
+ ):
+ actual_shard_size = get_actual_shard_size(
+ shard_size, weight_start, loaded_weight.size(dim)
+ )
+
+ if narrow_weight:
+ if actual_shard_size > 0:
+ loaded_weight = loaded_weight.narrow(dim, weight_start, actual_shard_size)
+ else:
+ # No real data to load; create a dummy tensor filled with zeros
+ loaded_weight = torch.zeros_like(
+ param_data.narrow(dim, param_data_start, actual_shard_size)
+ )
+
+ # [Note] Reset padded weights to zero.
+ # If the actual shard size is less than the shard size, we need to reset
+ # the padded param_data to zero and then copy the loaded_weight into it.
+ reset_param_data_if_needed(
+ param_data,
+ dim,
+ param_data_start + actual_shard_size,
+ shard_size - actual_shard_size,
+ )
+
+ param_data = param_data.narrow(dim, param_data_start, actual_shard_size)
+
+ return param_data, loaded_weight
@@ -21,6 +21,7 @@ from torch import nn
  from transformers import PretrainedConfig

  from sglang.srt.distributed import get_tensor_model_parallel_world_size
+ from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -28,9 +29,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
  ParallelLMHead,
  VocabParallelEmbedding,
  )
- from sglang.srt.managers.expert_distribution import (
- get_global_expert_distribution_recorder,
- )
  from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
  from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
@@ -32,7 +32,11 @@ from sglang.srt.distributed import (
  parallel_state,
  tensor_model_parallel_all_reduce,
  )
+ from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
+ from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
+ from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
  from sglang.srt.layers.activation import SiluAndMul
+ from sglang.srt.layers.amx_utils import PackWeightMethod
  from sglang.srt.layers.communicator import (
  LayerCommunicator,
  LayerScatterModes,
@@ -77,11 +81,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
  ParallelLMHead,
  VocabParallelEmbedding,
  )
- from sglang.srt.managers.expert_distribution import (
- get_global_expert_distribution_recorder,
- )
- from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
- from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
  from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
  from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -93,17 +92,18 @@ from sglang.srt.utils import (
  BumpAllocator,
  DeepEPMode,
  LazyValue,
- PackWeightMethod,
  add_prefix,
  bind_or_assign,
  cpu_has_amx_support,
  get_bool_env_var,
+ get_device_sm,
  get_int_env_var,
  is_cpu,
  is_cuda,
  is_hip,
  is_non_idle_and_non_empty,
  log_info_on_rank0,
+ use_intel_amx_backend,
  )

  _is_hip = is_hip()
@@ -112,9 +112,16 @@ _is_fp8_fnuz = is_fp8_fnuz()
  _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
  _is_cpu_amx_available = cpu_has_amx_support()
  _is_cpu = is_cpu()
+ _device_sm = get_device_sm()

  if _is_cuda:
- from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
+ from sgl_kernel import (
+ awq_dequantize,
+ bmm_fp8,
+ dsv3_fused_a_gemm,
+ dsv3_router_gemm,
+ merge_state_v2,
+ )
  elif _is_cpu and _is_cpu_amx_available:
  pass
  else:
@@ -203,8 +210,10 @@ class MoEGate(nn.Module):
  self,
  config,
  prefix: str = "",
+ is_nextn: bool = False,
  ):
  super().__init__()
+ self.is_nextn = is_nextn
  self.weight = nn.Parameter(
  torch.empty((config.n_routed_experts, config.hidden_size))
  )
@@ -218,7 +227,7 @@ class MoEGate(nn.Module):
  self.quant_method = PackWeightMethod(weight_names=["weight"])

  def forward(self, hidden_states):
- if getattr(self, "use_intel_amx_backend", False):
+ if use_intel_amx_backend(self):
  return torch.ops.sgl_kernel.weight_packed_linear(
  hidden_states,
  self.weight,
@@ -226,7 +235,21 @@ class MoEGate(nn.Module):
  True, # is_vnni
  )

- logits = F.linear(hidden_states, self.weight, None)
+ # NOTE: For some unknown reason, router_gemm seems degrade accept length.
+ if (
+ _is_cuda
+ and not self.is_nextn
+ and hidden_states.shape[0] < 4
+ and hidden_states.shape[1] == 7168
+ and self.weight.shape[0] == 256
+ and _device_sm >= 90
+ ):
+ logits = dsv3_router_gemm(hidden_states, self.weight).to(
+ hidden_states.dtype
+ )
+ else:
+ logits = F.linear(hidden_states, self.weight, None)
+
  return logits


@@ -239,6 +262,7 @@ class DeepseekV2MoE(nn.Module):
  quant_config: Optional[QuantizationConfig] = None,
  prefix: str = "",
  alt_stream: Optional[torch.cuda.Stream] = None,
+ is_nextn: bool = False,
  ):
  super().__init__()
  self.tp_size = get_tensor_model_parallel_world_size()
@@ -265,7 +289,9 @@ class DeepseekV2MoE(nn.Module):
  "Only silu is supported for now."
  )

- self.gate = MoEGate(config=config, prefix=add_prefix("gate", prefix))
+ self.gate = MoEGate(
+ config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
+ )

  self.experts = get_moe_impl_class()(
  num_experts=config.n_routed_experts
@@ -300,6 +326,9 @@ class DeepseekV2MoE(nn.Module):
  ),
  )

+ self.shared_experts_is_int8 = False
+ self.shared_experts_is_fp8 = False
+ self.shared_experts_weight_block_size = None
  if config.n_shared_experts is not None and self.num_fused_shared_experts == 0:
  intermediate_size = config.moe_intermediate_size * config.n_shared_experts
  # disable tp for shared experts when enable deepep moe
@@ -316,6 +345,28 @@ class DeepseekV2MoE(nn.Module):
  else {}
  ),
  )
+ is_packed_weight = hasattr(
+ self.shared_experts.gate_up_proj.quant_method, "quant_config"
+ ) and self.shared_experts.gate_up_proj.quant_method.quant_config.get_name() in {
+ "awq",
+ "moe_wna16",
+ }
+ self.shared_experts_is_int8 = (
+ not is_packed_weight
+ and self.shared_experts.gate_up_proj.weight.dtype == torch.int8
+ )
+ self.shared_experts_is_fp8 = (
+ not is_packed_weight
+ and self.shared_experts.gate_up_proj.weight.dtype == torch.float8_e4m3fn
+ )
+ if self.shared_experts_is_fp8:
+ assert (
+ self.shared_experts.gate_up_proj.quant_method.quant_config.weight_block_size
+ == self.shared_experts.down_proj.quant_method.quant_config.weight_block_size
+ )
+ self.shared_experts_weight_block_size = (
+ self.shared_experts.gate_up_proj.quant_method.quant_config.weight_block_size
+ )

  self.top_k = config.num_experts_per_tok

@@ -394,6 +445,11 @@ class DeepseekV2MoE(nn.Module):
  return final_hidden_states

  def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ if hasattr(self, "shared_experts") and use_intel_amx_backend(
+ self.shared_experts.gate_up_proj
+ ):
+ return self.forward_cpu(hidden_states)
+
  shared_output = self._forward_shared_experts(hidden_states)
  # router_logits: (num_tokens, n_experts)
  router_logits = self.gate(hidden_states)
@@ -409,6 +465,59 @@ class DeepseekV2MoE(nn.Module):
  final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
  return final_hidden_states

+ def forward_cpu(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ # router_logits: (num_tokens, n_experts)
+ router_logits = self.gate(hidden_states)
+ fused_experts_out = self.experts(
+ hidden_states=hidden_states, router_logits=router_logits
+ )
+
+ assert use_intel_amx_backend(
+ self.shared_experts.gate_up_proj
+ ) == use_intel_amx_backend(self.shared_experts.down_proj)
+ # [Note] inplace should be False in fused_experts.
+ # If inplace is True in fused_experts (self.experts), hidden_states will be changed after fused_experts
+ # While hidden_states is still needed in shared_expert.
+ final_hidden_states = torch.ops.sgl_kernel.shared_expert_cpu(
+ hidden_states,
+ self.shared_experts.gate_up_proj.weight,
+ self.shared_experts.down_proj.weight,
+ fused_experts_out,
+ self.routed_scaling_factor,
+ True, # inplace
+ self.shared_experts_is_int8, # use_int8_w8a8
+ self.shared_experts_is_fp8, # use_fp8_w8a16
+ (
+ self.shared_experts.gate_up_proj.weight_scale
+ if self.shared_experts_is_int8
+ else (
+ self.shared_experts.gate_up_proj.weight_scale_inv
+ if self.shared_experts_is_fp8
+ else None
+ )
+ ), # w1_scale
+ (
+ self.shared_experts.down_proj.weight_scale
+ if self.shared_experts_is_int8
+ else (
+ self.shared_experts.down_proj.weight_scale_inv
+ if self.shared_experts_is_fp8
+ else None
+ )
+ ), # w2_scale
+ (
+ self.shared_experts_weight_block_size
+ if self.shared_experts_is_fp8
+ else None
+ ), # block_size
+ None, # a1_scale
+ None, # a2_scale
+ True, # is_vnni
+ )
+ if self.tp_size > 1:
+ final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+ return final_hidden_states
+
  def forward_deepep(
  self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
  ) -> torch.Tensor:
@@ -456,7 +565,7 @@ class DeepseekV2MoE(nn.Module):
  hidden_states=hidden_states,
  topk_idx=topk_idx,
  topk_weights=topk_weights,
- forward_mode=forward_mode,
+ forward_batch=forward_batch,
  )
  final_hidden_states = self.experts(
  hidden_states=hidden_states,
@@ -467,14 +576,14 @@ class DeepseekV2MoE(nn.Module):
  masked_m=masked_m,
  expected_m=expected_m,
  num_recv_tokens_per_expert=num_recv_tokens_per_expert,
- forward_mode=forward_mode,
+ forward_batch=forward_batch,
  )
  if self.ep_size > 1:
  final_hidden_states = self.deepep_dispatcher.combine(
  hidden_states=final_hidden_states,
  topk_idx=topk_idx,
  topk_weights=topk_weights,
- forward_mode=forward_mode,
+ forward_batch=forward_batch,
  )

  if shared_output is not None:
@@ -549,7 +658,7 @@ class DeepseekV2MoE(nn.Module):
  hidden_states=state.hidden_states_mlp_input,
  topk_idx=state.pop("topk_idx_local"),
  topk_weights=state.pop("topk_weights_local"),
- forward_mode=state.forward_batch.forward_mode,
+ forward_batch=state.forward_batch,
  tbo_subbatch_index=state.get("tbo_subbatch_index"),
  )

@@ -581,7 +690,7 @@ class DeepseekV2MoE(nn.Module):
  masked_m=state.pop("masked_m"),
  expected_m=state.pop("expected_m"),
  num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
- forward_mode=state.forward_batch.forward_mode,
+ forward_batch=state.forward_batch,
  )

  def op_combine_a(self, state):
@@ -590,7 +699,7 @@ class DeepseekV2MoE(nn.Module):
  hidden_states=state.pop("hidden_states_experts_output"),
  topk_idx=state.pop("topk_idx_dispatched"),
  topk_weights=state.pop("topk_weights_dispatched"),
- forward_mode=state.forward_batch.forward_mode,
+ forward_batch=state.forward_batch,
  tbo_subbatch_index=state.get("tbo_subbatch_index"),
  )

@@ -793,33 +902,56 @@ class DeepseekV2AttentionMLA(nn.Module):
  # If we have self.fused_qkv_a_proj_with_mqa and we're running on CPU, we will choose the torch.ops.sgl_kernel.qkv_proj_with_rope_fused_weight kernel
  # which requires self.w_kc and self.w_vc to be packed.
  # If not, we will use torch.bmm and weight shouldn't be packed in this case
- if (
- hasattr(self, "fused_qkv_a_proj_with_mqa")
- and _is_cpu
- and _is_cpu_amx_available
- ):
+ has_fused_proj = hasattr(self, "fused_qkv_a_proj_with_mqa")
+ if has_fused_proj and _is_cpu and _is_cpu_amx_available:
  self.quant_method = PackWeightMethod(
  weight_names=["w_kc", "w_vc"], transpose_dims=[[1, 2], [1, 2]]
  )

+ is_packed_weight = (
+ has_fused_proj
+ and hasattr(self.fused_qkv_a_proj_with_mqa.quant_method, "quant_config")
+ and self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.get_name()
+ in {"awq", "moe_wna16"}
+ )
+ self.use_min_latency_fused_a_gemm = (
+ has_fused_proj
+ and not is_packed_weight
+ and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.bfloat16
+ and self.fused_qkv_a_proj_with_mqa.weight.shape[0] == 2112
+ and self.fused_qkv_a_proj_with_mqa.weight.shape[1] == 7168
+ and _is_cuda
+ and _device_sm >= 90
+ )
+
  self.qkv_proj_with_rope_is_int8 = (
- hasattr(self, "fused_qkv_a_proj_with_mqa")
+ has_fused_proj
+ and not is_packed_weight
  and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.int8
  )
  self.qkv_proj_with_rope_is_fp8 = (
- hasattr(self, "fused_qkv_a_proj_with_mqa")
+ has_fused_proj
+ and not is_packed_weight
  and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.float8_e4m3fn
  )

  self.weight_block_size = None
- if self.qkv_proj_with_rope_is_fp8:
- assert (
- self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
- == self.q_b_proj.quant_method.quant_config.weight_block_size
- )
- self.weight_block_size = (
- self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
- )
+ if self.qkv_proj_with_rope_is_fp8 and _is_cpu and _is_cpu_amx_available:
+ assert getattr(
+ self.fused_qkv_a_proj_with_mqa.quant_method, "block_quant", False
+ ) == getattr(self.q_b_proj.quant_method, "block_quant", False)
+ use_block_quant = getattr(
+ self.fused_qkv_a_proj_with_mqa.quant_method, "block_quant", False
+ )
+
+ if use_block_quant:
+ assert (
+ self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
+ == self.q_b_proj.quant_method.quant_config.weight_block_size
+ )
+ self.weight_block_size = (
+ self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
+ )

  def dispatch_attn_forward_method(
  self, forward_batch: ForwardBatch
@@ -834,14 +966,16 @@ class DeepseekV2AttentionMLA(nn.Module):
  else:
  return AttnForwardMethod.MLA
  else:
- if hasattr(self, "fused_qkv_a_proj_with_mqa") and getattr(
- self, "use_intel_amx_backend", False
+ if hasattr(self, "fused_qkv_a_proj_with_mqa") and use_intel_amx_backend(
+ self
  ):
  return AttnForwardMethod.MLA_FUSED_ROPE_CPU
  else:
  return AttnForwardMethod.MLA

- if self.attention_backend == "flashinfer":
+ if self.attention_backend == "ascend":
+ return AttnForwardMethod.MLA
+ elif self.attention_backend == "flashinfer":
  # Flashinfer MLA: Do not absorb when enabling ragged prefill
  if (
  not self.flashinfer_mla_disable_ragged
@@ -1041,7 +1175,13 @@ class DeepseekV2AttentionMLA(nn.Module):
  from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode

  if self.q_lora_rank is not None:
- q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
+ if hidden_states.shape[0] <= 16 and self.use_min_latency_fused_a_gemm:
+ fused_qkv_a_proj_out = dsv3_fused_a_gemm(
+ hidden_states, self.fused_qkv_a_proj_with_mqa.weight.T
+ )
+ else:
+ fused_qkv_a_proj_out = self.fused_qkv_a_proj_with_mqa(hidden_states)[0]
+ q, latent_cache = fused_qkv_a_proj_out.split(
  [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
  )
  k_nope = latent_cache[..., : self.kv_lora_rank]
@@ -1302,8 +1442,8 @@ class DeepseekV2AttentionMLA(nn.Module):
  forward_batch: ForwardBatch,
  zero_allocator: BumpAllocator,
  ):
- assert self.q_lora_rank is not None and getattr(
- self, "use_intel_amx_backend", False
+ assert self.q_lora_rank is not None and use_intel_amx_backend(
+ self
  ), "forward_absorb_fused_mla_rope_cpu_prepare requires q_lora_rank is not None and use_intel_amx_backend"

  q_input, k_input, v_input = (
@@ -1422,8 +1562,8 @@ class DeepseekV2AttentionMLA(nn.Module):
  def forward_absorb_fused_mla_rope_cpu_core(
  self, q_input, k_input, v_input, forward_batch, zero_allocator
  ):
- assert self.q_lora_rank is not None and getattr(
- self, "use_intel_amx_backend", False
+ assert self.q_lora_rank is not None and use_intel_amx_backend(
+ self
  ), "forward_absorb_fused_mla_rope_cpu_core requires q_lora_rank is not None and use_intel_amx_backend"

  attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
@@ -1643,6 +1783,7 @@ class DeepseekV2DecoderLayer(nn.Module):
  prefix=add_prefix("mlp", prefix),
  layer_id=self.layer_id,
  alt_stream=alt_stream,
+ is_nextn=is_nextn,
  )
  else:
  if enable_moe_dense_fully_dp():
@@ -1707,11 +1848,6 @@ class DeepseekV2DecoderLayer(nn.Module):
  hidden_states, residual, forward_batch
  )

- if self.enable_dp_attention and self.speculative_algorithm.is_eagle():
- # NOTE: this line resolves the degradation of MTP reception rate for non-zero DP ranks.
- # See discussion here (https://github.com/sgl-project/sglang/pull/6081#discussion_r2147452251).
- hidden_states = hidden_states.clone()
-
  return hidden_states, residual

  def op_comm_prepare_attn(
@@ -1753,7 +1889,7 @@ class DeepseekV2DecoderLayer(nn.Module):
  and hidden_states.shape[0] == 0
  ):
  state.hidden_states_mlp_output = self.mlp(
- hidden_states, state.forward_batch.forward_mode
+ hidden_states, state.forward_batch
  )
  else:
  state.hidden_states_mlp_output = hidden_states
@@ -2107,6 +2243,14 @@ class DeepseekV2ForCausalLM(nn.Module):
  )
  if _is_hip:
  self_attn.w_scale *= 2.0
+ # TODO: remove this after adding FP8 support in bmm cpu kernel
+ if _is_cpu and _is_cpu_amx_available and w.dtype == torch.float8_e4m3fn:
+ self_attn.w_kc = (
+ self_attn.w_kc.to(torch.bfloat16) * self_attn.w_scale
+ )
+ self_attn.w_vc = (
+ self_attn.w_vc.to(torch.bfloat16) * self_attn.w_scale
+ )
  else:
  num_tiles_k = self_attn.qk_nope_head_dim // weight_block_size[1]
  num_tiles_n = self_attn.v_head_dim // weight_block_size[0]
@@ -2219,6 +2363,12 @@ class DeepseekV2ForCausalLM(nn.Module):
  ckpt_up_proj_name="up_proj",
  num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
  )
+ if self.quant_config and self.quant_config.get_name() == "w4afp8":
+ expert_params_mapping += (
+ get_moe_impl_class().make_expert_input_scale_params_mapping(
+ num_experts=self.config.n_routed_experts
+ )
+ )

  # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
  fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
@@ -253,11 +253,9 @@ class DeepseekVL2ForCausalLM(nn.Module):
  weights_loader = getattr(param, "weight_loader", default_weight_loader)
  weights_loader(param, loaded_weight)

- def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
- helper = MultiModalityDataPaddingPatternMultimodalTokens(
- [image_inputs.im_token_id]
- )
- return helper.pad_input_tokens(input_ids, image_inputs)
+ def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
+ pattern = MultiModalityDataPaddingPatternMultimodalTokens()
+ return pattern.pad_input_tokens(input_ids, mm_inputs)

  def get_image_feature(self, items: List[MultimodalDataItem]):

@@ -166,8 +166,7 @@ class Gemma3Attention(nn.Module):
  prefix=add_prefix("o_proj", prefix),
  )

- # Determine if layer uses sliding window based on pattern
- self.is_sliding = bool((layer_id + 1) % config.sliding_window_pattern)
+ self.is_sliding = config.layer_types[layer_id] == "sliding_attention"

  # Initialize the rotary embedding.
  if self.is_sliding:
@@ -62,7 +62,7 @@ class Gemma3nTextScaledWordEmbedding(Gemma3TextScaledWordEmbedding):
  pass


- class Gemma3nMLP(nn.Module):
+ class Gemma3nTextMLP(nn.Module):
  def __init__(
  self,
  hidden_size: int,
@@ -514,10 +514,11 @@ class Gemma3nDecoderLayer(nn.Module):
  prefix=add_prefix("self_attn", prefix),
  )

+ intermediate_size = config.intermediate_size[layer_id]
  activation_sparsity = config.activation_sparsity_pattern[layer_id]
- self.mlp = Gemma3nMLP(
+ self.mlp = Gemma3nTextMLP(
  hidden_size=self.hidden_size,
- intermediate_size=config.intermediate_size,
+ intermediate_size=intermediate_size,
  hidden_activation=config.hidden_activation,
  activation_sparsity=activation_sparsity,
  quant_config=quant_config,
@@ -21,7 +21,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
  from sglang.srt.managers.mm_utils import (
- MultiModalityDataPaddingPatternTokenPairs,
+ MultiModalityDataPaddingPatternMultimodalTokens,
  general_mm_embed_routine,
  )
  from sglang.srt.managers.schedule_batch import (
@@ -244,26 +244,11 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
  def pad_input_ids(
  self,
  input_ids: List[int],
- mm_inputs: Optional[MultimodalInputs] = None,
+ mm_inputs: MultimodalInputs,
  ) -> List[int]:
  """Pad input IDs with image and audio tokens."""
- if mm_inputs is None:
- return input_ids
-
- # Collect available media token pairs
- media_token_pairs = []
- for attr_name in ["im_start_id", "audio_start_id"]:
- if hasattr(mm_inputs, attr_name):
- start_id = getattr(mm_inputs, attr_name)
- end_id = getattr(mm_inputs, attr_name.replace("start", "end"))
- media_token_pairs.append((start_id, end_id))
-
- # Apply padding pattern if we have media tokens
- if media_token_pairs:
- pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
- return pattern.pad_input_tokens(input_ids, mm_inputs)
-
- return input_ids
+ pattern = MultiModalityDataPaddingPatternMultimodalTokens()
+ return pattern.pad_input_tokens(input_ids, mm_inputs)

  def get_input_embeddings(self) -> nn.Embedding:
  return self.language_model.get_input_embeddings()
@@ -431,7 +416,6 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
  )

  positions += 1
-
  if input_ids is not None:
  # Prepare per-layer inputs from inputs_ids
  per_layer_inputs_mask = torch.logical_and(