sglang 0.4.8.post1__py3-none-any.whl → 0.4.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (141)
  1. sglang/bench_one_batch_server.py +17 -2
  2. sglang/bench_serving.py +168 -22
  3. sglang/srt/configs/internvl.py +4 -2
  4. sglang/srt/configs/janus_pro.py +1 -1
  5. sglang/srt/configs/model_config.py +48 -0
  6. sglang/srt/configs/update_config.py +119 -0
  7. sglang/srt/conversation.py +34 -0
  8. sglang/srt/disaggregation/decode.py +21 -5
  9. sglang/srt/disaggregation/nixl/conn.py +6 -6
  10. sglang/srt/disaggregation/prefill.py +2 -2
  11. sglang/srt/disaggregation/utils.py +1 -1
  12. sglang/srt/distributed/parallel_state.py +44 -17
  13. sglang/srt/entrypoints/EngineBase.py +8 -0
  14. sglang/srt/entrypoints/engine.py +40 -6
  15. sglang/srt/entrypoints/http_server.py +111 -24
  16. sglang/srt/entrypoints/openai/protocol.py +4 -2
  17. sglang/srt/eplb/__init__.py +0 -0
  18. sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
  19. sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
  20. sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
  21. sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
  22. sglang/srt/{managers → eplb}/expert_location.py +1 -1
  23. sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
  24. sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
  25. sglang/srt/hf_transformers_utils.py +2 -1
  26. sglang/srt/layers/activation.py +2 -2
  27. sglang/srt/layers/amx_utils.py +86 -0
  28. sglang/srt/layers/attention/ascend_backend.py +219 -0
  29. sglang/srt/layers/attention/flashattention_backend.py +32 -9
  30. sglang/srt/layers/attention/tbo_backend.py +37 -9
  31. sglang/srt/layers/communicator.py +18 -2
  32. sglang/srt/layers/dp_attention.py +9 -3
  33. sglang/srt/layers/elementwise.py +76 -12
  34. sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
  35. sglang/srt/layers/layernorm.py +26 -0
  36. sglang/srt/layers/linear.py +84 -14
  37. sglang/srt/layers/logits_processor.py +4 -4
  38. sglang/srt/layers/moe/ep_moe/kernels.py +23 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +36 -13
  40. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
  41. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -2
  42. sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -16
  43. sglang/srt/layers/moe/router.py +60 -22
  44. sglang/srt/layers/moe/topk.py +10 -28
  45. sglang/srt/layers/parameter.py +67 -7
  46. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
  47. sglang/srt/layers/quantization/fp8.py +44 -0
  48. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  49. sglang/srt/layers/quantization/fp8_utils.py +1 -2
  50. sglang/srt/layers/quantization/gptq.py +5 -1
  51. sglang/srt/layers/quantization/moe_wna16.py +1 -1
  52. sglang/srt/layers/quantization/quant_utils.py +166 -0
  53. sglang/srt/layers/quantization/w8a8_int8.py +52 -1
  54. sglang/srt/layers/rotary_embedding.py +2 -2
  55. sglang/srt/layers/vocab_parallel_embedding.py +11 -7
  56. sglang/srt/lora/lora.py +4 -5
  57. sglang/srt/lora/lora_manager.py +73 -20
  58. sglang/srt/managers/configure_logging.py +1 -1
  59. sglang/srt/managers/io_struct.py +50 -13
  60. sglang/srt/managers/mm_utils.py +73 -59
  61. sglang/srt/managers/multimodal_processor.py +2 -6
  62. sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
  63. sglang/srt/managers/schedule_batch.py +77 -84
  64. sglang/srt/managers/scheduler.py +113 -59
  65. sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
  66. sglang/srt/managers/session_controller.py +12 -3
  67. sglang/srt/managers/tokenizer_manager.py +314 -103
  68. sglang/srt/managers/tp_worker.py +13 -1
  69. sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
  70. sglang/srt/mem_cache/allocator.py +290 -0
  71. sglang/srt/mem_cache/chunk_cache.py +34 -2
  72. sglang/srt/mem_cache/memory_pool.py +289 -3
  73. sglang/srt/mem_cache/multimodal_cache.py +3 -0
  74. sglang/srt/model_executor/cuda_graph_runner.py +2 -1
  75. sglang/srt/model_executor/forward_batch_info.py +17 -4
  76. sglang/srt/model_executor/model_runner.py +297 -56
  77. sglang/srt/model_loader/loader.py +41 -0
  78. sglang/srt/model_loader/weight_utils.py +72 -4
  79. sglang/srt/models/deepseek_nextn.py +1 -3
  80. sglang/srt/models/deepseek_v2.py +181 -45
  81. sglang/srt/models/deepseek_vl2.py +3 -5
  82. sglang/srt/models/gemma3_causal.py +1 -2
  83. sglang/srt/models/gemma3n_causal.py +4 -3
  84. sglang/srt/models/gemma3n_mm.py +4 -20
  85. sglang/srt/models/hunyuan.py +1 -1
  86. sglang/srt/models/kimi_vl.py +1 -2
  87. sglang/srt/models/llama.py +10 -4
  88. sglang/srt/models/llama4.py +32 -45
  89. sglang/srt/models/llama_eagle3.py +61 -11
  90. sglang/srt/models/llava.py +5 -5
  91. sglang/srt/models/minicpmo.py +2 -2
  92. sglang/srt/models/mistral.py +1 -1
  93. sglang/srt/models/mllama4.py +43 -11
  94. sglang/srt/models/phi4mm.py +1 -3
  95. sglang/srt/models/pixtral.py +3 -7
  96. sglang/srt/models/qwen2.py +31 -3
  97. sglang/srt/models/qwen2_5_vl.py +1 -3
  98. sglang/srt/models/qwen2_audio.py +200 -0
  99. sglang/srt/models/qwen2_moe.py +32 -6
  100. sglang/srt/models/qwen2_vl.py +1 -4
  101. sglang/srt/models/qwen3.py +94 -25
  102. sglang/srt/models/qwen3_moe.py +68 -21
  103. sglang/srt/models/vila.py +3 -8
  104. sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
  105. sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
  106. sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
  107. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
  108. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
  109. sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
  110. sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
  111. sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
  112. sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
  113. sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
  114. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
  115. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +3 -6
  116. sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
  117. sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
  118. sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
  119. sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
  120. sglang/srt/operations_strategy.py +6 -2
  121. sglang/srt/reasoning_parser.py +26 -0
  122. sglang/srt/sampling/sampling_batch_info.py +39 -1
  123. sglang/srt/server_args.py +69 -22
  124. sglang/srt/speculative/build_eagle_tree.py +57 -18
  125. sglang/srt/speculative/eagle_worker.py +6 -4
  126. sglang/srt/two_batch_overlap.py +200 -27
  127. sglang/srt/utils.py +306 -146
  128. sglang/srt/warmup.py +12 -3
  129. sglang/test/runners.py +10 -1
  130. sglang/test/test_utils.py +15 -3
  131. sglang/version.py +1 -1
  132. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/METADATA +9 -6
  133. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/RECORD +140 -133
  134. sglang/math_utils.py +0 -8
  135. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
  136. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
  137. /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
  138. /sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +0 -0
  139. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/WHEEL +0 -0
  140. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/licenses/LICENSE +0 -0
  141. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/top_level.txt +0 -0
@@ -209,6 +209,17 @@ def get_quant_config(
         config["adapter_name_or_path"] = model_name_or_path
     elif model_config.quantization == "modelopt":
         if config["producer"]["name"] == "modelopt":
+            # (yizhang2077) workaround for nvidia/Llama-4-Maverick-17B-128E-Eagle3
+            if config["quantization"]["quant_algo"] is None:
+                if (
+                    model_config.hf_config.architectures[0]
+                    != "LlamaForCausalLMEagle3"
+                ):
+                    raise ValueError(
+                        f"Invalid quant_config, quantization method: {model_config.quantization},"
+                        f"hf architectures: {model_config.hf_config.architectures[0]}. "
+                    )
+                return None
             if "FP4" in config["quantization"]["quant_algo"]:
                 return ModelOptFp4Config.from_config(config)
             else:
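
The added guard is easier to read outside the diff. The sketch below only restates the branch above for illustration: the config keys (producer.name, quantization.quant_algo) and the LlamaForCausalLMEagle3 architecture name come from the hunk, while the standalone function and its name are hypothetical, not part of sglang's API.

# Hedged sketch of the workaround above; the real logic sits inside
# get_quant_config() in sglang/srt/model_loader/weight_utils.py.
def modelopt_quant_algo_or_skip(config: dict, architectures: list, quantization: str):
    # ModelOpt checkpoints for the Eagle3 draft model ship quant_algo == null.
    if config["quantization"]["quant_algo"] is None:
        if architectures[0] != "LlamaForCausalLMEagle3":
            raise ValueError(
                f"Invalid quant_config, quantization method: {quantization}, "
                f"hf architectures: {architectures[0]}."
            )
        return None  # treat the draft model as unquantized
    return config["quantization"]["quant_algo"]


# The Eagle3 draft model is accepted and simply skipped:
print(
    modelopt_quant_algo_or_skip(
        {"producer": {"name": "modelopt"}, "quantization": {"quant_algo": None}},
        ["LlamaForCausalLMEagle3"],
        "modelopt",
    )
)  # -> None
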
@@ -449,10 +460,12 @@ def safetensors_weights_iterator(
         if disable_mmap:
             with open(st_file, "rb") as f:
                 result = safetensors.torch.load(f.read())
+                for name, param in result.items():
+                    yield name, param
         else:
-            result = safetensors.torch.load_file(st_file, device="cpu")
-        for name, param in result.items():
-            yield name, param
+            with safetensors.safe_open(st_file, framework="pt", device="cpu") as f:
+                for name in f.keys():
+                    yield name, f.get_tensor(name)


 def multi_thread_safetensors_weights_iterator(
@@ -485,7 +498,8 @@ def multi_thread_safetensors_weights_iterator(
             with open(st_file, "rb") as f:
                 result = safetensors.torch.load(f.read())
         else:
-            result = safetensors.torch.load_file(st_file, device="cpu")
+            with safetensors.safe_open(st_file, framework="pt", device="cpu") as f:
+                result = {k: f.get_tensor(k) for k in f.keys()}

         return result

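
Both hunks above move from safetensors.torch.load_file to safetensors.safe_open, which exposes tensors individually so the iterator variant can yield them one at a time instead of materializing a full state dict per shard. A minimal standalone comparison of the two APIs, assuming a local model.safetensors file as a placeholder path:

import safetensors.torch
from safetensors import safe_open

ST_FILE = "model.safetensors"  # placeholder; any local safetensors checkpoint

# Eager path (what the old code used): every tensor is materialized at once.
state_dict = safetensors.torch.load_file(ST_FILE, device="cpu")
for name, param in state_dict.items():
    print(name, tuple(param.shape))

# Per-tensor path (what the diff switches to): tensors are fetched on demand,
# so peak host memory stays close to the size of a single tensor.
with safe_open(ST_FILE, framework="pt", device="cpu") as f:
    for name in f.keys():
        print(name, tuple(f.get_tensor(name).shape))
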
@@ -947,3 +961,57 @@ def kv_cache_scales_loader(
             tp_rank,
         )
     return []
+
+
+def get_actual_shard_size(shard_size, weight_start, weight_end):
+    if weight_end < weight_start:
+        return 0
+
+    return min(shard_size, weight_end - weight_start)
+
+
+def reset_param_data_if_needed(param_data, dim, start, length):
+    if length == 0:
+        return
+
+    assert length > 0, f"Length should be positive, but got {length}"
+
+    param_data.narrow(dim, start, length).zero_()
+    return
+
+
+def narrow_padded_param_and_loaded_weight(
+    param_data,
+    loaded_weight,
+    param_data_start,
+    weight_start,
+    dim,
+    shard_size,
+    narrow_weight=True,
+):
+    actual_shard_size = get_actual_shard_size(
+        shard_size, weight_start, loaded_weight.size(dim)
+    )
+
+    if narrow_weight:
+        if actual_shard_size > 0:
+            loaded_weight = loaded_weight.narrow(dim, weight_start, actual_shard_size)
+        else:
+            # No real data to load; create a dummy tensor filled with zeros
+            loaded_weight = torch.zeros_like(
+                param_data.narrow(dim, param_data_start, actual_shard_size)
+            )
+
+    # [Note] Reset padded weights to zero.
+    # If the actual shard size is less than the shard size, we need to reset
+    # the padded param_data to zero and then copy the loaded_weight into it.
+    reset_param_data_if_needed(
+        param_data,
+        dim,
+        param_data_start + actual_shard_size,
+        shard_size - actual_shard_size,
+    )
+
+    param_data = param_data.narrow(dim, param_data_start, actual_shard_size)
+
+    return param_data, loaded_weight
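
The three helpers added above deal with loading a checkpoint shard into a parameter that was padded to a larger size: the padded tail is zeroed and both tensors are narrowed to the overlapping region. A minimal usage sketch with made-up toy shapes, assuming sglang 0.4.9 is installed so the helper can be imported from the module this hunk patches:

import torch

from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight

# A parameter padded to 8 rows while the checkpoint only provides 6.
param = torch.empty(8, 4)
loaded_weight = torch.arange(24, dtype=torch.float32).reshape(6, 4)

param_view, weight_view = narrow_padded_param_and_loaded_weight(
    param,
    loaded_weight,
    param_data_start=0,
    weight_start=0,
    dim=0,
    shard_size=8,
)
param_view.copy_(weight_view)      # rows 0-5 now hold checkpoint data
assert torch.all(param[6:] == 0)   # rows 6-7 were zeroed by the helper
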
@@ -21,6 +21,7 @@ from torch import nn
 from transformers import PretrainedConfig

 from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -28,9 +29,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.expert_distribution import (
-    get_global_expert_distribution_recorder,
-)
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
@@ -32,7 +32,11 @@ from sglang.srt.distributed import (
     parallel_state,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
+from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
+from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.amx_utils import PackWeightMethod
 from sglang.srt.layers.communicator import (
     LayerCommunicator,
     LayerScatterModes,
@@ -77,11 +81,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.expert_distribution import (
-    get_global_expert_distribution_recorder,
-)
-from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
-from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -93,17 +92,18 @@ from sglang.srt.utils import (
     BumpAllocator,
     DeepEPMode,
     LazyValue,
-    PackWeightMethod,
     add_prefix,
     bind_or_assign,
     cpu_has_amx_support,
     get_bool_env_var,
+    get_device_sm,
     get_int_env_var,
     is_cpu,
     is_cuda,
     is_hip,
     is_non_idle_and_non_empty,
     log_info_on_rank0,
+    use_intel_amx_backend,
 )

 _is_hip = is_hip()
@@ -112,9 +112,16 @@ _is_fp8_fnuz = is_fp8_fnuz()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
+_device_sm = get_device_sm()

 if _is_cuda:
-    from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
+    from sgl_kernel import (
+        awq_dequantize,
+        bmm_fp8,
+        dsv3_fused_a_gemm,
+        dsv3_router_gemm,
+        merge_state_v2,
+    )
 elif _is_cpu and _is_cpu_amx_available:
     pass
 else:
@@ -218,7 +225,7 @@ class MoEGate(nn.Module):
             self.quant_method = PackWeightMethod(weight_names=["weight"])

     def forward(self, hidden_states):
-        if getattr(self, "use_intel_amx_backend", False):
+        if use_intel_amx_backend(self):
             return torch.ops.sgl_kernel.weight_packed_linear(
                 hidden_states,
                 self.weight,
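
This hunk, and several later ones in deepseek_v2.py, replace the getattr probe with a use_intel_amx_backend helper that the import hunk above pulls from sglang.srt.utils. Judging only from the before/after at the call sites, the helper presumably amounts to something like the one-liner below; this is a guess, not the actual implementation:

def use_intel_amx_backend(layer) -> bool:
    # Same check the old code spelled out inline at each call site.
    return getattr(layer, "use_intel_amx_backend", False)
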
@@ -226,7 +233,19 @@ class MoEGate(nn.Module):
                 True,  # is_vnni
             )

-        logits = F.linear(hidden_states, self.weight, None)
+        if (
+            _is_cuda
+            and hidden_states.shape[0] < 4
+            and hidden_states.shape[1] == 7168
+            and self.weight.shape[0] == 256
+            and _device_sm >= 90
+        ):
+            logits = dsv3_router_gemm(hidden_states, self.weight).to(
+                hidden_states.dtype
+            )
+        else:
+            logits = F.linear(hidden_states, self.weight, None)
+
         return logits

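
The new branch sends tiny decode batches with DeepSeek-V3 gate shapes (hidden size 7168, 256 routed experts) to the dsv3_router_gemm kernel on SM90+ GPUs and keeps F.linear as the general path. A rough standalone sketch of that dispatch; the route() name is invented for the example and the _device_sm check is folded into the comment:

import torch
import torch.nn.functional as F


def route(hidden_states: torch.Tensor, gate_weight: torch.Tensor) -> torch.Tensor:
    if (
        hidden_states.is_cuda
        and hidden_states.shape[0] < 4      # tiny decode batches only
        and hidden_states.shape[1] == 7168  # DeepSeek-V3 hidden size
        and gate_weight.shape[0] == 256     # 256 routed experts
    ):
        # Requires an sgl_kernel build with SM90 support.
        from sgl_kernel import dsv3_router_gemm

        return dsv3_router_gemm(hidden_states, gate_weight).to(hidden_states.dtype)
    return F.linear(hidden_states, gate_weight, None)


# On CPU this always takes the fallback branch.
logits = route(torch.randn(2, 7168), torch.randn(256, 7168))
print(logits.shape)  # torch.Size([2, 256])
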
@@ -300,6 +319,9 @@ class DeepseekV2MoE(nn.Module):
             ),
         )

+        self.shared_experts_is_int8 = False
+        self.shared_experts_is_fp8 = False
+        self.shared_experts_weight_block_size = None
         if config.n_shared_experts is not None and self.num_fused_shared_experts == 0:
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
             # disable tp for shared experts when enable deepep moe
@@ -316,6 +338,28 @@ class DeepseekV2MoE(nn.Module):
                     else {}
                 ),
             )
+            is_packed_weight = hasattr(
+                self.shared_experts.gate_up_proj.quant_method, "quant_config"
+            ) and self.shared_experts.gate_up_proj.quant_method.quant_config.get_name() in {
+                "awq",
+                "moe_wna16",
+            }
+            self.shared_experts_is_int8 = (
+                not is_packed_weight
+                and self.shared_experts.gate_up_proj.weight.dtype == torch.int8
+            )
+            self.shared_experts_is_fp8 = (
+                not is_packed_weight
+                and self.shared_experts.gate_up_proj.weight.dtype == torch.float8_e4m3fn
+            )
+            if self.shared_experts_is_fp8:
+                assert (
+                    self.shared_experts.gate_up_proj.quant_method.quant_config.weight_block_size
+                    == self.shared_experts.down_proj.quant_method.quant_config.weight_block_size
+                )
+                self.shared_experts_weight_block_size = (
+                    self.shared_experts.gate_up_proj.quant_method.quant_config.weight_block_size
+                )

         self.top_k = config.num_experts_per_tok

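
The detection added above first rules out packed formats (awq, moe_wna16) and then infers int8 vs. FP8 from the stored weight dtype. A toy version of that probing, assuming a PyTorch build that exposes torch.float8_e4m3fn (2.1 or newer); the function name is invented:

from typing import Optional

import torch


def detect_shared_expert_quant(weight: torch.Tensor, quant_name: Optional[str]):
    is_packed = quant_name in {"awq", "moe_wna16"}
    is_int8 = not is_packed and weight.dtype == torch.int8
    is_fp8 = not is_packed and weight.dtype == torch.float8_e4m3fn
    return is_int8, is_fp8


print(detect_shared_expert_quant(torch.zeros(4, 4, dtype=torch.int8), None))
# -> (True, False)
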
@@ -394,6 +438,11 @@ class DeepseekV2MoE(nn.Module):
         return final_hidden_states

     def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        if hasattr(self, "shared_experts") and use_intel_amx_backend(
+            self.shared_experts.gate_up_proj
+        ):
+            return self.forward_cpu(hidden_states)
+
         shared_output = self._forward_shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
@@ -409,6 +458,59 @@ class DeepseekV2MoE(nn.Module):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states

+    def forward_cpu(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        # router_logits: (num_tokens, n_experts)
+        router_logits = self.gate(hidden_states)
+        fused_experts_out = self.experts(
+            hidden_states=hidden_states, router_logits=router_logits
+        )
+
+        assert use_intel_amx_backend(
+            self.shared_experts.gate_up_proj
+        ) == use_intel_amx_backend(self.shared_experts.down_proj)
+        # [Note] inplace should be False in fused_experts.
+        # If inplace is True in fused_experts (self.experts), hidden_states will be changed after fused_experts
+        # While hidden_states is still needed in shared_expert.
+        final_hidden_states = torch.ops.sgl_kernel.shared_expert_cpu(
+            hidden_states,
+            self.shared_experts.gate_up_proj.weight,
+            self.shared_experts.down_proj.weight,
+            fused_experts_out,
+            self.routed_scaling_factor,
+            True,  # inplace
+            self.shared_experts_is_int8,  # use_int8_w8a8
+            self.shared_experts_is_fp8,  # use_fp8_w8a16
+            (
+                self.shared_experts.gate_up_proj.weight_scale
+                if self.shared_experts_is_int8
+                else (
+                    self.shared_experts.gate_up_proj.weight_scale_inv
+                    if self.shared_experts_is_fp8
+                    else None
+                )
+            ),  # w1_scale
+            (
+                self.shared_experts.down_proj.weight_scale
+                if self.shared_experts_is_int8
+                else (
+                    self.shared_experts.down_proj.weight_scale_inv
+                    if self.shared_experts_is_fp8
+                    else None
+                )
+            ),  # w2_scale
+            (
+                self.shared_experts_weight_block_size
+                if self.shared_experts_is_fp8
+                else None
+            ),  # block_size
+            None,  # a1_scale
+            None,  # a2_scale
+            True,  # is_vnni
+        )
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+        return final_hidden_states
+
     def forward_deepep(
         self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
     ) -> torch.Tensor:
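
To make the "[Note] inplace should be False" comment concrete: the routed-experts output has to be produced out of place because the shared expert consumes the same hidden_states afterwards. A plain-torch illustration, with no sglang or sgl_kernel involved:

import torch

hidden_states = torch.randn(4, 8)
w_shared = torch.randn(8, 8)

# Routed-experts stand-in. Computing it OUT of place keeps hidden_states intact...
routed_out = torch.relu(hidden_states)

# ...so the shared-expert matmul below still sees the original activations.
# An in-place variant (e.g. hidden_states.relu_()) would silently feed it
# already-activated values instead.
shared_out = hidden_states @ w_shared

final = routed_out + shared_out
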
@@ -456,7 +558,7 @@ class DeepseekV2MoE(nn.Module):
                 hidden_states=hidden_states,
                 topk_idx=topk_idx,
                 topk_weights=topk_weights,
-                forward_mode=forward_mode,
+                forward_batch=forward_batch,
             )
         final_hidden_states = self.experts(
             hidden_states=hidden_states,
@@ -467,14 +569,14 @@ class DeepseekV2MoE(nn.Module):
             masked_m=masked_m,
             expected_m=expected_m,
             num_recv_tokens_per_expert=num_recv_tokens_per_expert,
-            forward_mode=forward_mode,
+            forward_batch=forward_batch,
         )
         if self.ep_size > 1:
             final_hidden_states = self.deepep_dispatcher.combine(
                 hidden_states=final_hidden_states,
                 topk_idx=topk_idx,
                 topk_weights=topk_weights,
-                forward_mode=forward_mode,
+                forward_batch=forward_batch,
             )

         if shared_output is not None:
@@ -549,7 +651,7 @@ class DeepseekV2MoE(nn.Module):
             hidden_states=state.hidden_states_mlp_input,
             topk_idx=state.pop("topk_idx_local"),
             topk_weights=state.pop("topk_weights_local"),
-            forward_mode=state.forward_batch.forward_mode,
+            forward_batch=state.forward_batch,
             tbo_subbatch_index=state.get("tbo_subbatch_index"),
         )

@@ -581,7 +683,7 @@ class DeepseekV2MoE(nn.Module):
             masked_m=state.pop("masked_m"),
             expected_m=state.pop("expected_m"),
             num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
-            forward_mode=state.forward_batch.forward_mode,
+            forward_batch=state.forward_batch,
         )

     def op_combine_a(self, state):
@@ -590,7 +692,7 @@ class DeepseekV2MoE(nn.Module):
             hidden_states=state.pop("hidden_states_experts_output"),
             topk_idx=state.pop("topk_idx_dispatched"),
             topk_weights=state.pop("topk_weights_dispatched"),
-            forward_mode=state.forward_batch.forward_mode,
+            forward_batch=state.forward_batch,
             tbo_subbatch_index=state.get("tbo_subbatch_index"),
         )

@@ -793,33 +895,56 @@ class DeepseekV2AttentionMLA(nn.Module):
         # If we have self.fused_qkv_a_proj_with_mqa and we're running on CPU, we will choose the torch.ops.sgl_kernel.qkv_proj_with_rope_fused_weight kernel
         # which requires self.w_kc and self.w_vc to be packed.
        # If not, we will use torch.bmm and weight shouldn't be packed in this case
-        if (
-            hasattr(self, "fused_qkv_a_proj_with_mqa")
-            and _is_cpu
-            and _is_cpu_amx_available
-        ):
+        has_fused_proj = hasattr(self, "fused_qkv_a_proj_with_mqa")
+        if has_fused_proj and _is_cpu and _is_cpu_amx_available:
             self.quant_method = PackWeightMethod(
                 weight_names=["w_kc", "w_vc"], transpose_dims=[[1, 2], [1, 2]]
             )

+        is_packed_weight = (
+            has_fused_proj
+            and hasattr(self.fused_qkv_a_proj_with_mqa.quant_method, "quant_config")
+            and self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.get_name()
+            in {"awq", "moe_wna16"}
+        )
+        self.use_min_latency_fused_a_gemm = (
+            has_fused_proj
+            and not is_packed_weight
+            and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.bfloat16
+            and self.fused_qkv_a_proj_with_mqa.weight.shape[0] == 2112
+            and self.fused_qkv_a_proj_with_mqa.weight.shape[1] == 7168
+            and _is_cuda
+            and _device_sm >= 90
+        )
+
         self.qkv_proj_with_rope_is_int8 = (
-            hasattr(self, "fused_qkv_a_proj_with_mqa")
+            has_fused_proj
+            and not is_packed_weight
             and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.int8
         )
         self.qkv_proj_with_rope_is_fp8 = (
-            hasattr(self, "fused_qkv_a_proj_with_mqa")
+            has_fused_proj
+            and not is_packed_weight
             and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.float8_e4m3fn
         )

         self.weight_block_size = None
-        if self.qkv_proj_with_rope_is_fp8:
-            assert (
-                self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
-                == self.q_b_proj.quant_method.quant_config.weight_block_size
-            )
-            self.weight_block_size = (
-                self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
-            )
+        if self.qkv_proj_with_rope_is_fp8 and _is_cpu and _is_cpu_amx_available:
+            assert getattr(
+                self.fused_qkv_a_proj_with_mqa.quant_method, "block_quant", False
+            ) == getattr(self.q_b_proj.quant_method, "block_quant", False)
+            use_block_quant = getattr(
+                self.fused_qkv_a_proj_with_mqa.quant_method, "block_quant", False
+            )
+
+            if use_block_quant:
+                assert (
+                    self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
+                    == self.q_b_proj.quant_method.quant_config.weight_block_size
+                )
+                self.weight_block_size = (
+                    self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
+                )

     def dispatch_attn_forward_method(
         self, forward_batch: ForwardBatch
@@ -834,14 +959,16 @@ class DeepseekV2AttentionMLA(nn.Module):
             else:
                 return AttnForwardMethod.MLA
         else:
-            if hasattr(self, "fused_qkv_a_proj_with_mqa") and getattr(
-                self, "use_intel_amx_backend", False
+            if hasattr(self, "fused_qkv_a_proj_with_mqa") and use_intel_amx_backend(
+                self
             ):
                 return AttnForwardMethod.MLA_FUSED_ROPE_CPU
             else:
                 return AttnForwardMethod.MLA

-        if self.attention_backend == "flashinfer":
+        if self.attention_backend == "ascend":
+            return AttnForwardMethod.MLA
+        elif self.attention_backend == "flashinfer":
             # Flashinfer MLA: Do not absorb when enabling ragged prefill
             if (
                 not self.flashinfer_mla_disable_ragged
@@ -1041,7 +1168,13 @@ class DeepseekV2AttentionMLA(nn.Module):
         from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode

         if self.q_lora_rank is not None:
-            q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
+            if hidden_states.shape[0] <= 16 and self.use_min_latency_fused_a_gemm:
+                fused_qkv_a_proj_out = dsv3_fused_a_gemm(
+                    hidden_states, self.fused_qkv_a_proj_with_mqa.weight.T
+                )
+            else:
+                fused_qkv_a_proj_out = self.fused_qkv_a_proj_with_mqa(hidden_states)[0]
+            q, latent_cache = fused_qkv_a_proj_out.split(
                 [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
             )
             k_nope = latent_cache[..., : self.kv_lora_rank]
@@ -1302,8 +1435,8 @@ class DeepseekV2AttentionMLA(nn.Module):
         forward_batch: ForwardBatch,
         zero_allocator: BumpAllocator,
     ):
-        assert self.q_lora_rank is not None and getattr(
-            self, "use_intel_amx_backend", False
+        assert self.q_lora_rank is not None and use_intel_amx_backend(
+            self
         ), "forward_absorb_fused_mla_rope_cpu_prepare requires q_lora_rank is not None and use_intel_amx_backend"

         q_input, k_input, v_input = (
@@ -1422,8 +1555,8 @@ class DeepseekV2AttentionMLA(nn.Module):
     def forward_absorb_fused_mla_rope_cpu_core(
         self, q_input, k_input, v_input, forward_batch, zero_allocator
     ):
-        assert self.q_lora_rank is not None and getattr(
-            self, "use_intel_amx_backend", False
+        assert self.q_lora_rank is not None and use_intel_amx_backend(
+            self
         ), "forward_absorb_fused_mla_rope_cpu_core requires q_lora_rank is not None and use_intel_amx_backend"

         attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
@@ -1707,11 +1840,6 @@ class DeepseekV2DecoderLayer(nn.Module):
             hidden_states, residual, forward_batch
         )

-        if self.enable_dp_attention and self.speculative_algorithm.is_eagle():
-            # NOTE: this line resolves the degradation of MTP reception rate for non-zero DP ranks.
-            # See discussion here (https://github.com/sgl-project/sglang/pull/6081#discussion_r2147452251).
-            hidden_states = hidden_states.clone()
-
         return hidden_states, residual

     def op_comm_prepare_attn(
@@ -1753,7 +1881,7 @@ class DeepseekV2DecoderLayer(nn.Module):
             and hidden_states.shape[0] == 0
         ):
             state.hidden_states_mlp_output = self.mlp(
-                hidden_states, state.forward_batch.forward_mode
+                hidden_states, state.forward_batch
             )
         else:
             state.hidden_states_mlp_output = hidden_states
@@ -1802,7 +1930,7 @@ class DeepseekV2Model(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
-            enable_tp=not global_server_args_dict["enable_dp_attention"],
+            use_attn_tp_group=True,
         )
         self.alt_stream = torch.cuda.Stream() if _is_cuda else None
         self.layers = nn.ModuleList(
@@ -2107,6 +2235,14 @@ class DeepseekV2ForCausalLM(nn.Module):
                 )
                 if _is_hip:
                     self_attn.w_scale *= 2.0
+                # TODO: remove this after adding FP8 support in bmm cpu kernel
+                if _is_cpu and _is_cpu_amx_available and w.dtype == torch.float8_e4m3fn:
+                    self_attn.w_kc = (
+                        self_attn.w_kc.to(torch.bfloat16) * self_attn.w_scale
+                    )
+                    self_attn.w_vc = (
+                        self_attn.w_vc.to(torch.bfloat16) * self_attn.w_scale
+                    )
             else:
                 num_tiles_k = self_attn.qk_nope_head_dim // weight_block_size[1]
                 num_tiles_n = self_attn.v_head_dim // weight_block_size[0]
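
A toy version of the CPU fallback above, assuming a PyTorch build that exposes torch.float8_e4m3fn: until the CPU bmm kernel supports FP8, the scaled FP8 w_kc/w_vc tensors are widened to bfloat16 ahead of time. The shapes and scale are made up:

import torch

w_fp8 = torch.randn(8, 16).to(torch.float8_e4m3fn)
w_scale = 0.5  # stand-in for the per-tensor scale loaded from the checkpoint

w_bf16 = w_fp8.to(torch.bfloat16) * w_scale
print(w_bf16.dtype)  # torch.bfloat16
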
@@ -253,11 +253,9 @@ class DeepseekVL2ForCausalLM(nn.Module):
         weights_loader = getattr(param, "weight_loader", default_weight_loader)
         weights_loader(param, loaded_weight)

-    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
-        helper = MultiModalityDataPaddingPatternMultimodalTokens(
-            [image_inputs.im_token_id]
-        )
-        return helper.pad_input_tokens(input_ids, image_inputs)
+    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
+        return pattern.pad_input_tokens(input_ids, mm_inputs)

     def get_image_feature(self, items: List[MultimodalDataItem]):

@@ -166,8 +166,7 @@ class Gemma3Attention(nn.Module):
             prefix=add_prefix("o_proj", prefix),
         )

-        # Determine if layer uses sliding window based on pattern
-        self.is_sliding = bool((layer_id + 1) % config.sliding_window_pattern)
+        self.is_sliding = config.layer_types[layer_id] == "sliding_attention"

         # Initialize the rotary embedding.
         if self.is_sliding:
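
The hunk above switches Gemma3 from the sliding_window_pattern modulo to the per-layer layer_types list that newer transformers configs expose. A small sketch contrasting the two checks; the example layer_types values are illustrative, not taken from a real config:

def is_sliding_old(layer_id: int, sliding_window_pattern: int) -> bool:
    return bool((layer_id + 1) % sliding_window_pattern)


def is_sliding_new(layer_id: int, layer_types: list) -> bool:
    return layer_types[layer_id] == "sliding_attention"


# A pattern of 6 means every sixth layer uses global (full) attention.
layer_types = ["sliding_attention"] * 5 + ["full_attention"]
assert is_sliding_old(0, 6) and is_sliding_new(0, layer_types)
assert not is_sliding_old(5, 6) and not is_sliding_new(5, layer_types)
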
@@ -62,7 +62,7 @@ class Gemma3nTextScaledWordEmbedding(Gemma3TextScaledWordEmbedding):
     pass


-class Gemma3nMLP(nn.Module):
+class Gemma3nTextMLP(nn.Module):
     def __init__(
         self,
         hidden_size: int,
@@ -514,10 +514,11 @@ class Gemma3nDecoderLayer(nn.Module):
             prefix=add_prefix("self_attn", prefix),
         )

+        intermediate_size = config.intermediate_size[layer_id]
         activation_sparsity = config.activation_sparsity_pattern[layer_id]
-        self.mlp = Gemma3nMLP(
+        self.mlp = Gemma3nTextMLP(
             hidden_size=self.hidden_size,
-            intermediate_size=config.intermediate_size,
+            intermediate_size=intermediate_size,
             hidden_activation=config.hidden_activation,
             activation_sparsity=activation_sparsity,
             quant_config=quant_config,
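
A minimal illustration of the change above: Gemma3n text configs carry per-layer lists, so the MLP width is now indexed by layer_id the same way the activation sparsity already was. The numbers are made up:

intermediate_size = [16384, 16384, 8192, 8192]
activation_sparsity_pattern = [0.95, 0.95, 0.0, 0.0]

layer_id = 2
mlp_width = intermediate_size[layer_id]           # 8192
sparsity = activation_sparsity_pattern[layer_id]  # 0.0
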
@@ -21,7 +21,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.managers.mm_utils import (
-    MultiModalityDataPaddingPatternTokenPairs,
+    MultiModalityDataPaddingPatternMultimodalTokens,
     general_mm_embed_routine,
 )
 from sglang.srt.managers.schedule_batch import (
@@ -244,26 +244,11 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
     def pad_input_ids(
         self,
         input_ids: List[int],
-        mm_inputs: Optional[MultimodalInputs] = None,
+        mm_inputs: MultimodalInputs,
     ) -> List[int]:
         """Pad input IDs with image and audio tokens."""
-        if mm_inputs is None:
-            return input_ids
-
-        # Collect available media token pairs
-        media_token_pairs = []
-        for attr_name in ["im_start_id", "audio_start_id"]:
-            if hasattr(mm_inputs, attr_name):
-                start_id = getattr(mm_inputs, attr_name)
-                end_id = getattr(mm_inputs, attr_name.replace("start", "end"))
-                media_token_pairs.append((start_id, end_id))
-
-        # Apply padding pattern if we have media tokens
-        if media_token_pairs:
-            pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
-            return pattern.pad_input_tokens(input_ids, mm_inputs)
-
-        return input_ids
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
+        return pattern.pad_input_tokens(input_ids, mm_inputs)

     def get_input_embeddings(self) -> nn.Embedding:
         return self.language_model.get_input_embeddings()
@@ -431,7 +416,6 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
         )

         positions += 1
-
         if input_ids is not None:
             # Prepare per-layer inputs from inputs_ids
             per_layer_inputs_mask = torch.logical_and(
@@ -28,6 +28,7 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.eplb.expert_distribution import ExpertDistributionRecorder
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -48,7 +49,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
@@ -154,8 +154,7 @@ class KimiVLForConditionalGeneration(nn.Module):
         return res

     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
-        # Get all special token IDs
-        pattern = MultiModalityDataPaddingPatternMultimodalTokens(mm_inputs.im_token_id)
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
         return pattern.pad_input_tokens(input_ids, mm_inputs)

     def forward(