sglang 0.4.8__py3-none-any.whl → 0.4.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. sglang/bench_one_batch_server.py +17 -2
  2. sglang/bench_serving.py +168 -22
  3. sglang/srt/configs/internvl.py +4 -2
  4. sglang/srt/configs/janus_pro.py +1 -1
  5. sglang/srt/configs/model_config.py +49 -0
  6. sglang/srt/configs/update_config.py +119 -0
  7. sglang/srt/conversation.py +35 -0
  8. sglang/srt/custom_op.py +7 -1
  9. sglang/srt/disaggregation/base/conn.py +2 -0
  10. sglang/srt/disaggregation/decode.py +22 -6
  11. sglang/srt/disaggregation/mooncake/conn.py +289 -48
  12. sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
  13. sglang/srt/disaggregation/nixl/conn.py +100 -52
  14. sglang/srt/disaggregation/prefill.py +5 -4
  15. sglang/srt/disaggregation/utils.py +13 -12
  16. sglang/srt/distributed/parallel_state.py +44 -17
  17. sglang/srt/entrypoints/EngineBase.py +8 -0
  18. sglang/srt/entrypoints/engine.py +45 -9
  19. sglang/srt/entrypoints/http_server.py +111 -24
  20. sglang/srt/entrypoints/openai/protocol.py +51 -6
  21. sglang/srt/entrypoints/openai/serving_chat.py +52 -76
  22. sglang/srt/entrypoints/openai/serving_completions.py +1 -0
  23. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  24. sglang/srt/eplb/__init__.py +0 -0
  25. sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
  26. sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
  27. sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
  28. sglang/srt/{managers → eplb}/expert_distribution.py +18 -1
  29. sglang/srt/{managers → eplb}/expert_location.py +1 -1
  30. sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
  31. sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
  32. sglang/srt/hf_transformers_utils.py +2 -1
  33. sglang/srt/layers/activation.py +7 -0
  34. sglang/srt/layers/amx_utils.py +86 -0
  35. sglang/srt/layers/attention/ascend_backend.py +219 -0
  36. sglang/srt/layers/attention/flashattention_backend.py +56 -23
  37. sglang/srt/layers/attention/tbo_backend.py +37 -9
  38. sglang/srt/layers/communicator.py +18 -2
  39. sglang/srt/layers/dp_attention.py +9 -3
  40. sglang/srt/layers/elementwise.py +76 -12
  41. sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
  42. sglang/srt/layers/layernorm.py +41 -0
  43. sglang/srt/layers/linear.py +99 -12
  44. sglang/srt/layers/logits_processor.py +15 -6
  45. sglang/srt/layers/moe/ep_moe/kernels.py +23 -8
  46. sglang/srt/layers/moe/ep_moe/layer.py +115 -25
  47. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +42 -19
  48. sglang/srt/layers/moe/fused_moe_native.py +7 -0
  49. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +8 -4
  50. sglang/srt/layers/moe/fused_moe_triton/layer.py +129 -10
  51. sglang/srt/layers/moe/router.py +60 -22
  52. sglang/srt/layers/moe/topk.py +36 -28
  53. sglang/srt/layers/parameter.py +67 -7
  54. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
  55. sglang/srt/layers/quantization/fp8.py +44 -0
  56. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  57. sglang/srt/layers/quantization/fp8_utils.py +6 -6
  58. sglang/srt/layers/quantization/gptq.py +5 -1
  59. sglang/srt/layers/quantization/moe_wna16.py +1 -1
  60. sglang/srt/layers/quantization/quant_utils.py +166 -0
  61. sglang/srt/layers/quantization/w8a8_int8.py +52 -1
  62. sglang/srt/layers/rotary_embedding.py +105 -13
  63. sglang/srt/layers/vocab_parallel_embedding.py +19 -2
  64. sglang/srt/lora/lora.py +4 -5
  65. sglang/srt/lora/lora_manager.py +73 -20
  66. sglang/srt/managers/configure_logging.py +1 -1
  67. sglang/srt/managers/io_struct.py +60 -15
  68. sglang/srt/managers/mm_utils.py +73 -59
  69. sglang/srt/managers/multimodal_processor.py +2 -6
  70. sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
  71. sglang/srt/managers/schedule_batch.py +80 -79
  72. sglang/srt/managers/scheduler.py +153 -63
  73. sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
  74. sglang/srt/managers/session_controller.py +12 -3
  75. sglang/srt/managers/tokenizer_manager.py +314 -103
  76. sglang/srt/managers/tp_worker.py +13 -1
  77. sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
  78. sglang/srt/mem_cache/allocator.py +290 -0
  79. sglang/srt/mem_cache/chunk_cache.py +34 -2
  80. sglang/srt/mem_cache/memory_pool.py +289 -3
  81. sglang/srt/mem_cache/multimodal_cache.py +3 -0
  82. sglang/srt/model_executor/cuda_graph_runner.py +3 -2
  83. sglang/srt/model_executor/forward_batch_info.py +17 -4
  84. sglang/srt/model_executor/model_runner.py +302 -58
  85. sglang/srt/model_loader/loader.py +86 -10
  86. sglang/srt/model_loader/weight_utils.py +160 -3
  87. sglang/srt/models/deepseek_nextn.py +5 -4
  88. sglang/srt/models/deepseek_v2.py +305 -26
  89. sglang/srt/models/deepseek_vl2.py +3 -5
  90. sglang/srt/models/gemma3_causal.py +1 -2
  91. sglang/srt/models/gemma3n_audio.py +949 -0
  92. sglang/srt/models/gemma3n_causal.py +1010 -0
  93. sglang/srt/models/gemma3n_mm.py +495 -0
  94. sglang/srt/models/hunyuan.py +771 -0
  95. sglang/srt/models/kimi_vl.py +1 -2
  96. sglang/srt/models/llama.py +10 -4
  97. sglang/srt/models/llama4.py +32 -45
  98. sglang/srt/models/llama_eagle3.py +61 -11
  99. sglang/srt/models/llava.py +5 -5
  100. sglang/srt/models/minicpmo.py +2 -2
  101. sglang/srt/models/mistral.py +1 -1
  102. sglang/srt/models/mllama4.py +43 -11
  103. sglang/srt/models/phi4mm.py +1 -3
  104. sglang/srt/models/pixtral.py +3 -7
  105. sglang/srt/models/qwen2.py +31 -3
  106. sglang/srt/models/qwen2_5_vl.py +1 -3
  107. sglang/srt/models/qwen2_audio.py +200 -0
  108. sglang/srt/models/qwen2_moe.py +32 -6
  109. sglang/srt/models/qwen2_vl.py +1 -4
  110. sglang/srt/models/qwen3.py +94 -25
  111. sglang/srt/models/qwen3_moe.py +68 -21
  112. sglang/srt/models/vila.py +3 -8
  113. sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +150 -133
  114. sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
  115. sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
  116. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
  117. sglang/srt/multimodal/processors/gemma3n.py +82 -0
  118. sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
  119. sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
  120. sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
  121. sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
  122. sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
  123. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
  124. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +3 -6
  125. sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
  126. sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
  127. sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
  128. sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
  129. sglang/srt/operations_strategy.py +6 -2
  130. sglang/srt/reasoning_parser.py +26 -0
  131. sglang/srt/sampling/sampling_batch_info.py +39 -1
  132. sglang/srt/server_args.py +85 -24
  133. sglang/srt/speculative/build_eagle_tree.py +57 -18
  134. sglang/srt/speculative/eagle_worker.py +6 -4
  135. sglang/srt/two_batch_overlap.py +204 -28
  136. sglang/srt/utils.py +369 -138
  137. sglang/srt/warmup.py +12 -3
  138. sglang/test/runners.py +10 -1
  139. sglang/test/test_utils.py +15 -3
  140. sglang/version.py +1 -1
  141. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/METADATA +9 -6
  142. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/RECORD +149 -137
  143. sglang/math_utils.py +0 -8
  144. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
  145. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
  146. /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
  147. /sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +0 -0
  148. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/WHEEL +0 -0
  149. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py

@@ -32,7 +32,11 @@ from sglang.srt.distributed import (
      parallel_state,
      tensor_model_parallel_all_reduce,
  )
+ from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
+ from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
+ from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
  from sglang.srt.layers.activation import SiluAndMul
+ from sglang.srt.layers.amx_utils import PackWeightMethod
  from sglang.srt.layers.communicator import (
      LayerCommunicator,
      LayerScatterModes,
@@ -77,11 +81,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
      ParallelLMHead,
      VocabParallelEmbedding,
  )
- from sglang.srt.managers.expert_distribution import (
-     get_global_expert_distribution_recorder,
- )
- from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
- from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
  from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
  from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -97,12 +96,14 @@ from sglang.srt.utils import (
      bind_or_assign,
      cpu_has_amx_support,
      get_bool_env_var,
+     get_device_sm,
      get_int_env_var,
      is_cpu,
      is_cuda,
      is_hip,
      is_non_idle_and_non_empty,
      log_info_on_rank0,
+     use_intel_amx_backend,
  )

  _is_hip = is_hip()
@@ -111,9 +112,16 @@ _is_fp8_fnuz = is_fp8_fnuz()
  _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
  _is_cpu_amx_available = cpu_has_amx_support()
  _is_cpu = is_cpu()
+ _device_sm = get_device_sm()

  if _is_cuda:
-     from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
+     from sgl_kernel import (
+         awq_dequantize,
+         bmm_fp8,
+         dsv3_fused_a_gemm,
+         dsv3_router_gemm,
+         merge_state_v2,
+     )
  elif _is_cpu and _is_cpu_amx_available:
      pass
  else:
@@ -124,8 +132,6 @@ if _is_hip:
          decode_attention_fwd_grouped_rope,
      )

-     if _use_aiter:
-         from aiter.rotary_embedding import get_rope

  logger = logging.getLogger(__name__)

@@ -144,6 +150,9 @@ class AttnForwardMethod(IntEnum):
      # Use MLA but with fused RoPE
      MLA_FUSED_ROPE = auto()

+     # Use MLA with fused RoPE kernel for CPU
+     MLA_FUSED_ROPE_CPU = auto()
+

  class DeepseekV2MLP(nn.Module):
      def __init__(
@@ -212,9 +221,31 @@ class MoEGate(nn.Module):
              )
          else:
              self.e_score_correction_bias = None
+         if _is_cpu and _is_cpu_amx_available:
+             self.quant_method = PackWeightMethod(weight_names=["weight"])

      def forward(self, hidden_states):
-         logits = F.linear(hidden_states, self.weight, None)
+         if use_intel_amx_backend(self):
+             return torch.ops.sgl_kernel.weight_packed_linear(
+                 hidden_states,
+                 self.weight,
+                 None, # bias
+                 True, # is_vnni
+             )
+
+         if (
+             _is_cuda
+             and hidden_states.shape[0] < 4
+             and hidden_states.shape[1] == 7168
+             and self.weight.shape[0] == 256
+             and _device_sm >= 90
+         ):
+             logits = dsv3_router_gemm(hidden_states, self.weight).to(
+                 hidden_states.dtype
+             )
+         else:
+             logits = F.linear(hidden_states, self.weight, None)
+
          return logits


@@ -288,6 +319,9 @@ class DeepseekV2MoE(nn.Module):
              ),
          )

+         self.shared_experts_is_int8 = False
+         self.shared_experts_is_fp8 = False
+         self.shared_experts_weight_block_size = None
          if config.n_shared_experts is not None and self.num_fused_shared_experts == 0:
              intermediate_size = config.moe_intermediate_size * config.n_shared_experts
              # disable tp for shared experts when enable deepep moe
@@ -304,6 +338,28 @@ class DeepseekV2MoE(nn.Module):
                      else {}
                  ),
              )
+             is_packed_weight = hasattr(
+                 self.shared_experts.gate_up_proj.quant_method, "quant_config"
+             ) and self.shared_experts.gate_up_proj.quant_method.quant_config.get_name() in {
+                 "awq",
+                 "moe_wna16",
+             }
+             self.shared_experts_is_int8 = (
+                 not is_packed_weight
+                 and self.shared_experts.gate_up_proj.weight.dtype == torch.int8
+             )
+             self.shared_experts_is_fp8 = (
+                 not is_packed_weight
+                 and self.shared_experts.gate_up_proj.weight.dtype == torch.float8_e4m3fn
+             )
+             if self.shared_experts_is_fp8:
+                 assert (
+                     self.shared_experts.gate_up_proj.quant_method.quant_config.weight_block_size
+                     == self.shared_experts.down_proj.quant_method.quant_config.weight_block_size
+                 )
+                 self.shared_experts_weight_block_size = (
+                     self.shared_experts.gate_up_proj.quant_method.quant_config.weight_block_size
+                 )

          self.top_k = config.num_experts_per_tok

@@ -382,13 +438,19 @@ class DeepseekV2MoE(nn.Module):
          return final_hidden_states

      def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         if hasattr(self, "shared_experts") and use_intel_amx_backend(
+             self.shared_experts.gate_up_proj
+         ):
+             return self.forward_cpu(hidden_states)
+
          shared_output = self._forward_shared_experts(hidden_states)
          # router_logits: (num_tokens, n_experts)
          router_logits = self.gate(hidden_states)
          final_hidden_states = self.experts(
              hidden_states=hidden_states, router_logits=router_logits
          )
-         if not _is_cuda:
+         if not _is_cuda and not _use_aiter:
+             # fused in biased_grouped_topk so we can skip here
              final_hidden_states *= self.routed_scaling_factor
          if shared_output is not None:
              final_hidden_states = final_hidden_states + shared_output
@@ -396,6 +458,59 @@ class DeepseekV2MoE(nn.Module):
              final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
          return final_hidden_states

+     def forward_cpu(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         # router_logits: (num_tokens, n_experts)
+         router_logits = self.gate(hidden_states)
+         fused_experts_out = self.experts(
+             hidden_states=hidden_states, router_logits=router_logits
+         )
+
+         assert use_intel_amx_backend(
+             self.shared_experts.gate_up_proj
+         ) == use_intel_amx_backend(self.shared_experts.down_proj)
+         # [Note] inplace should be False in fused_experts.
+         # If inplace is True in fused_experts (self.experts), hidden_states will be changed after fused_experts
+         # While hidden_states is still needed in shared_expert.
+         final_hidden_states = torch.ops.sgl_kernel.shared_expert_cpu(
+             hidden_states,
+             self.shared_experts.gate_up_proj.weight,
+             self.shared_experts.down_proj.weight,
+             fused_experts_out,
+             self.routed_scaling_factor,
+             True, # inplace
+             self.shared_experts_is_int8, # use_int8_w8a8
+             self.shared_experts_is_fp8, # use_fp8_w8a16
+             (
+                 self.shared_experts.gate_up_proj.weight_scale
+                 if self.shared_experts_is_int8
+                 else (
+                     self.shared_experts.gate_up_proj.weight_scale_inv
+                     if self.shared_experts_is_fp8
+                     else None
+                 )
+             ), # w1_scale
+             (
+                 self.shared_experts.down_proj.weight_scale
+                 if self.shared_experts_is_int8
+                 else (
+                     self.shared_experts.down_proj.weight_scale_inv
+                     if self.shared_experts_is_fp8
+                     else None
+                 )
+             ), # w2_scale
+             (
+                 self.shared_experts_weight_block_size
+                 if self.shared_experts_is_fp8
+                 else None
+             ), # block_size
+             None, # a1_scale
+             None, # a2_scale
+             True, # is_vnni
+         )
+         if self.tp_size > 1:
+             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+         return final_hidden_states
+
      def forward_deepep(
          self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
      ) -> torch.Tensor:
@@ -443,7 +558,7 @@ class DeepseekV2MoE(nn.Module):
                  hidden_states=hidden_states,
                  topk_idx=topk_idx,
                  topk_weights=topk_weights,
-                 forward_mode=forward_mode,
+                 forward_batch=forward_batch,
              )
          final_hidden_states = self.experts(
              hidden_states=hidden_states,
@@ -454,14 +569,14 @@ class DeepseekV2MoE(nn.Module):
              masked_m=masked_m,
              expected_m=expected_m,
              num_recv_tokens_per_expert=num_recv_tokens_per_expert,
-             forward_mode=forward_mode,
+             forward_batch=forward_batch,
          )
          if self.ep_size > 1:
              final_hidden_states = self.deepep_dispatcher.combine(
                  hidden_states=final_hidden_states,
                  topk_idx=topk_idx,
                  topk_weights=topk_weights,
-                 forward_mode=forward_mode,
+                 forward_batch=forward_batch,
              )

          if shared_output is not None:
@@ -536,7 +651,7 @@ class DeepseekV2MoE(nn.Module):
                  hidden_states=state.hidden_states_mlp_input,
                  topk_idx=state.pop("topk_idx_local"),
                  topk_weights=state.pop("topk_weights_local"),
-                 forward_mode=state.forward_batch.forward_mode,
+                 forward_batch=state.forward_batch,
                  tbo_subbatch_index=state.get("tbo_subbatch_index"),
              )

@@ -568,7 +683,7 @@ class DeepseekV2MoE(nn.Module):
              masked_m=state.pop("masked_m"),
              expected_m=state.pop("expected_m"),
              num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
-             forward_mode=state.forward_batch.forward_mode,
+             forward_batch=state.forward_batch,
          )

      def op_combine_a(self, state):
@@ -577,7 +692,7 @@ class DeepseekV2MoE(nn.Module):
                  hidden_states=state.pop("hidden_states_experts_output"),
                  topk_idx=state.pop("topk_idx_dispatched"),
                  topk_weights=state.pop("topk_weights_dispatched"),
-                 forward_mode=state.forward_batch.forward_mode,
+                 forward_batch=state.forward_batch,
                  tbo_subbatch_index=state.get("tbo_subbatch_index"),
              )

@@ -777,6 +892,60 @@ class DeepseekV2AttentionMLA(nn.Module):
              "SGL_CHUNKED_PREFIX_CACHE_THRESHOLD", 8192
          )

+         # If we have self.fused_qkv_a_proj_with_mqa and we're running on CPU, we will choose the torch.ops.sgl_kernel.qkv_proj_with_rope_fused_weight kernel
+         # which requires self.w_kc and self.w_vc to be packed.
+         # If not, we will use torch.bmm and weight shouldn't be packed in this case
+         has_fused_proj = hasattr(self, "fused_qkv_a_proj_with_mqa")
+         if has_fused_proj and _is_cpu and _is_cpu_amx_available:
+             self.quant_method = PackWeightMethod(
+                 weight_names=["w_kc", "w_vc"], transpose_dims=[[1, 2], [1, 2]]
+             )
+
+         is_packed_weight = (
+             has_fused_proj
+             and hasattr(self.fused_qkv_a_proj_with_mqa.quant_method, "quant_config")
+             and self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.get_name()
+             in {"awq", "moe_wna16"}
+         )
+         self.use_min_latency_fused_a_gemm = (
+             has_fused_proj
+             and not is_packed_weight
+             and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.bfloat16
+             and self.fused_qkv_a_proj_with_mqa.weight.shape[0] == 2112
+             and self.fused_qkv_a_proj_with_mqa.weight.shape[1] == 7168
+             and _is_cuda
+             and _device_sm >= 90
+         )
+
+         self.qkv_proj_with_rope_is_int8 = (
+             has_fused_proj
+             and not is_packed_weight
+             and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.int8
+         )
+         self.qkv_proj_with_rope_is_fp8 = (
+             has_fused_proj
+             and not is_packed_weight
+             and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.float8_e4m3fn
+         )
+
+         self.weight_block_size = None
+         if self.qkv_proj_with_rope_is_fp8 and _is_cpu and _is_cpu_amx_available:
+             assert getattr(
+                 self.fused_qkv_a_proj_with_mqa.quant_method, "block_quant", False
+             ) == getattr(self.q_b_proj.quant_method, "block_quant", False)
+             use_block_quant = getattr(
+                 self.fused_qkv_a_proj_with_mqa.quant_method, "block_quant", False
+             )
+
+             if use_block_quant:
+                 assert (
+                     self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
+                     == self.q_b_proj.quant_method.quant_config.weight_block_size
+                 )
+                 self.weight_block_size = (
+                     self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
+                 )
+
      def dispatch_attn_forward_method(
          self, forward_batch: ForwardBatch
      ) -> AttnForwardMethod:
@@ -790,9 +959,16 @@ class DeepseekV2AttentionMLA(nn.Module):
                  else:
                      return AttnForwardMethod.MLA
              else:
-                 return AttnForwardMethod.MLA
+                 if hasattr(self, "fused_qkv_a_proj_with_mqa") and use_intel_amx_backend(
+                     self
+                 ):
+                     return AttnForwardMethod.MLA_FUSED_ROPE_CPU
+                 else:
+                     return AttnForwardMethod.MLA

-         if self.attention_backend == "flashinfer":
+         if self.attention_backend == "ascend":
+             return AttnForwardMethod.MLA
+         elif self.attention_backend == "flashinfer":
              # Flashinfer MLA: Do not absorb when enabling ragged prefill
              if (
                  not self.flashinfer_mla_disable_ragged
@@ -904,6 +1080,10 @@ class DeepseekV2AttentionMLA(nn.Module):
              inner_state = self.forward_absorb_fused_mla_rope_prepare(
                  positions, hidden_states, forward_batch, zero_allocator
              )
+         elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE_CPU:
+             inner_state = self.forward_absorb_fused_mla_rope_cpu_prepare(
+                 positions, hidden_states, forward_batch, zero_allocator
+             )
          else:
              raise NotImplementedError
          return None, attn_forward_method, forward_batch, inner_state
@@ -923,6 +1103,8 @@ class DeepseekV2AttentionMLA(nn.Module):
              return self.forward_absorb_core(*inner_state)
          elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
              return self.forward_absorb_fused_mla_rope_core(*inner_state)
+         elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE_CPU:
+             return self.forward_absorb_fused_mla_rope_cpu_core(*inner_state)
          else:
              raise NotImplementedError

@@ -986,7 +1168,13 @@ class DeepseekV2AttentionMLA(nn.Module):
          from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode

          if self.q_lora_rank is not None:
-             q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
+             if hidden_states.shape[0] <= 16 and self.use_min_latency_fused_a_gemm:
+                 fused_qkv_a_proj_out = dsv3_fused_a_gemm(
+                     hidden_states, self.fused_qkv_a_proj_with_mqa.weight.T
+                 )
+             else:
+                 fused_qkv_a_proj_out = self.fused_qkv_a_proj_with_mqa(hidden_states)[0]
+             q, latent_cache = fused_qkv_a_proj_out.split(
                  [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
              )
              k_nope = latent_cache[..., : self.kv_lora_rank]
@@ -1240,6 +1428,57 @@ class DeepseekV2AttentionMLA(nn.Module):
              zero_allocator,
          )

+     def forward_absorb_fused_mla_rope_cpu_prepare(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         forward_batch: ForwardBatch,
+         zero_allocator: BumpAllocator,
+     ):
+         assert self.q_lora_rank is not None and use_intel_amx_backend(
+             self
+         ), "forward_absorb_fused_mla_rope_cpu_prepare requires q_lora_rank is not None and use_intel_amx_backend"
+
+         q_input, k_input, v_input = (
+             torch.ops.sgl_kernel.qkv_proj_with_rope_fused_weight(
+                 hidden_states,
+                 self.fused_qkv_a_proj_with_mqa.weight,
+                 self.q_b_proj.weight,
+                 self.w_kc,
+                 self.q_a_layernorm.weight,
+                 self.kv_a_layernorm.weight,
+                 positions,
+                 self.rotary_emb.cos_sin_cache,
+                 self.kv_a_layernorm.variance_epsilon,
+                 self.qkv_proj_with_rope_is_int8,
+                 self.qkv_proj_with_rope_is_fp8,
+                 (
+                     self.fused_qkv_a_proj_with_mqa.weight_scale
+                     if self.qkv_proj_with_rope_is_int8
+                     else (
+                         self.fused_qkv_a_proj_with_mqa.weight_scale_inv
+                         if self.qkv_proj_with_rope_is_fp8
+                         else None
+                     )
+                 ),
+                 (
+                     self.q_b_proj.weight_scale
+                     if self.qkv_proj_with_rope_is_int8
+                     else (
+                         self.q_b_proj.weight_scale_inv
+                         if self.qkv_proj_with_rope_is_fp8
+                         else None
+                     )
+                 ),
+                 True, # is_vnni
+                 self.weight_block_size,
+                 self.q_lora_rank,
+                 self.kv_lora_rank,
+                 self.qk_rope_head_dim,
+             )
+         )
+         return (q_input, k_input, v_input, forward_batch, zero_allocator)
+
      def forward_absorb_fused_mla_rope_core(
          self,
          q_input,
@@ -1313,6 +1552,43 @@ class DeepseekV2AttentionMLA(nn.Module):

          return output

+     def forward_absorb_fused_mla_rope_cpu_core(
+         self, q_input, k_input, v_input, forward_batch, zero_allocator
+     ):
+         assert self.q_lora_rank is not None and use_intel_amx_backend(
+             self
+         ), "forward_absorb_fused_mla_rope_cpu_core requires q_lora_rank is not None and use_intel_amx_backend"
+
+         attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
+         attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
+
+         # [Note] Align shapes of bmm inputs.
+         # Shapes of inputs:
+         # q_nope: [M, B, K]
+         # original self.w_kc: [B, K, N]
+         # current self.w_kc (which has been converted in PackWeightMethod): [B, N, K]
+
+         # Shapes of inputs to sgl_kernel.cpu.bmm:
+         # out: [B, M, N]
+         # mat1: [B, M, K]
+         # mat2: [B, N, K]
+         B = self.w_vc.size(0)
+         N = self.w_vc.size(1)
+         M = attn_output.size(0)
+         output = torch.empty([M, int(B * N)], dtype=attn_output.dtype)
+         attn_bmm_output = output.view([M, B, N]).transpose_(0, 1)
+         torch.ops.sgl_kernel.bmm_cpu(
+             attn_bmm_output,
+             attn_output.transpose(0, 1),
+             self.w_vc,
+             True, # is_vnni
+             None, # scale
+         )
+         attn_output = output
+         output, _ = self.o_proj(attn_output)
+
+         return output
+
      def _chunked_prefix_attn_mha(
          self,
          q: torch.Tensor,
@@ -1564,11 +1840,6 @@ class DeepseekV2DecoderLayer(nn.Module):
              hidden_states, residual, forward_batch
          )

-         if self.enable_dp_attention and self.speculative_algorithm.is_eagle():
-             # NOTE: this line resolves the degradation of MTP reception rate for non-zero DP ranks.
-             # See discussion here (https://github.com/sgl-project/sglang/pull/6081#discussion_r2147452251).
-             hidden_states = hidden_states.clone()
-
          return hidden_states, residual

      def op_comm_prepare_attn(
@@ -1610,7 +1881,7 @@ class DeepseekV2DecoderLayer(nn.Module):
              and hidden_states.shape[0] == 0
          ):
              state.hidden_states_mlp_output = self.mlp(
-                 hidden_states, state.forward_batch.forward_mode
+                 hidden_states, state.forward_batch
              )
          else:
              state.hidden_states_mlp_output = hidden_states
@@ -1659,7 +1930,7 @@ class DeepseekV2Model(nn.Module):
          self.embed_tokens = VocabParallelEmbedding(
              config.vocab_size,
              config.hidden_size,
-             enable_tp=not global_server_args_dict["enable_dp_attention"],
+             use_attn_tp_group=True,
          )
          self.alt_stream = torch.cuda.Stream() if _is_cuda else None
          self.layers = nn.ModuleList(
@@ -1964,6 +2235,14 @@ class DeepseekV2ForCausalLM(nn.Module):
                      )
                      if _is_hip:
                          self_attn.w_scale *= 2.0
+                     # TODO: remove this after adding FP8 support in bmm cpu kernel
+                     if _is_cpu and _is_cpu_amx_available and w.dtype == torch.float8_e4m3fn:
+                         self_attn.w_kc = (
+                             self_attn.w_kc.to(torch.bfloat16) * self_attn.w_scale
+                         )
+                         self_attn.w_vc = (
+                             self_attn.w_vc.to(torch.bfloat16) * self_attn.w_scale
+                         )
                  else:
                      num_tiles_k = self_attn.qk_nope_head_dim // weight_block_size[1]
                      num_tiles_n = self_attn.v_head_dim // weight_block_size[0]
sglang/srt/models/deepseek_vl2.py

@@ -253,11 +253,9 @@ class DeepseekVL2ForCausalLM(nn.Module):
                  weights_loader = getattr(param, "weight_loader", default_weight_loader)
                  weights_loader(param, loaded_weight)

-     def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
-         helper = MultiModalityDataPaddingPatternMultimodalTokens(
-             [image_inputs.im_token_id]
-         )
-         return helper.pad_input_tokens(input_ids, image_inputs)
+     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
+         pattern = MultiModalityDataPaddingPatternMultimodalTokens()
+         return pattern.pad_input_tokens(input_ids, mm_inputs)

      def get_image_feature(self, items: List[MultimodalDataItem]):

sglang/srt/models/gemma3_causal.py

@@ -166,8 +166,7 @@ class Gemma3Attention(nn.Module):
              prefix=add_prefix("o_proj", prefix),
          )

-         # Determine if layer uses sliding window based on pattern
-         self.is_sliding = bool((layer_id + 1) % config.sliding_window_pattern)
+         self.is_sliding = config.layer_types[layer_id] == "sliding_attention"

          # Initialize the rotary embedding.
          if self.is_sliding: