sglang 0.5.2rc0__py3-none-any.whl → 0.5.2rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (84)
  1. sglang/lang/interpreter.py +1 -1
  2. sglang/srt/configs/internvl.py +6 -0
  3. sglang/srt/configs/model_config.py +2 -1
  4. sglang/srt/disaggregation/mini_lb.py +2 -2
  5. sglang/srt/distributed/parallel_state.py +46 -41
  6. sglang/srt/entrypoints/engine.py +1 -1
  7. sglang/srt/entrypoints/http_server.py +5 -1
  8. sglang/srt/entrypoints/openai/protocol.py +3 -3
  9. sglang/srt/entrypoints/openai/serving_chat.py +3 -3
  10. sglang/srt/entrypoints/openai/serving_completions.py +3 -1
  11. sglang/srt/entrypoints/openai/serving_embedding.py +1 -1
  12. sglang/srt/entrypoints/openai/serving_responses.py +1 -1
  13. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  14. sglang/srt/layers/attention/aiter_backend.py +93 -68
  15. sglang/srt/layers/communicator.py +45 -7
  16. sglang/srt/layers/moe/cutlass_w4a8_moe.py +1 -9
  17. sglang/srt/layers/moe/ep_moe/layer.py +2 -7
  18. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  19. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  20. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -1048
  21. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  22. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +796 -0
  23. sglang/srt/layers/moe/fused_moe_triton/layer.py +5 -2
  24. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  25. sglang/srt/layers/moe/utils.py +0 -1
  26. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +8 -0
  27. sglang/srt/layers/quantization/modelopt_quant.py +35 -2
  28. sglang/srt/layers/quantization/mxfp4.py +4 -1
  29. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  30. sglang/srt/layers/quantization/quark/utils.py +97 -0
  31. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  32. sglang/srt/layers/quantization/w4afp8.py +30 -25
  33. sglang/srt/layers/rocm_linear_utils.py +44 -0
  34. sglang/srt/layers/rotary_embedding.py +0 -18
  35. sglang/srt/managers/cache_controller.py +42 -39
  36. sglang/srt/managers/detokenizer_manager.py +0 -34
  37. sglang/srt/managers/multi_tokenizer_mixin.py +48 -6
  38. sglang/srt/managers/schedule_policy.py +3 -2
  39. sglang/srt/managers/scheduler.py +7 -100
  40. sglang/srt/managers/scheduler_metrics_mixin.py +113 -7
  41. sglang/srt/managers/template_manager.py +3 -3
  42. sglang/srt/managers/tokenizer_manager.py +1 -0
  43. sglang/srt/mem_cache/allocator.py +1 -1
  44. sglang/srt/mem_cache/hicache_storage.py +15 -10
  45. sglang/srt/mem_cache/hiradix_cache.py +16 -0
  46. sglang/srt/mem_cache/memory_pool_host.py +18 -11
  47. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  48. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +35 -6
  49. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +32 -13
  50. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  51. sglang/srt/metrics/collector.py +12 -4
  52. sglang/srt/metrics/utils.py +48 -0
  53. sglang/srt/model_executor/forward_batch_info.py +16 -17
  54. sglang/srt/model_executor/model_runner.py +1 -1
  55. sglang/srt/models/deepseek_v2.py +245 -36
  56. sglang/srt/models/glm4_moe.py +10 -1
  57. sglang/srt/models/gpt_oss.py +5 -4
  58. sglang/srt/models/internvl.py +28 -0
  59. sglang/srt/models/longcat_flash.py +26 -15
  60. sglang/srt/models/longcat_flash_nextn.py +23 -15
  61. sglang/srt/models/minicpmv.py +165 -3
  62. sglang/srt/models/qwen2_moe.py +4 -1
  63. sglang/srt/models/qwen3.py +8 -2
  64. sglang/srt/models/qwen3_moe.py +39 -8
  65. sglang/srt/models/torch_native_llama.py +1 -1
  66. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  67. sglang/srt/server_args.py +79 -2
  68. sglang/srt/speculative/eagle_worker.py +158 -112
  69. sglang/srt/utils.py +12 -10
  70. sglang/test/few_shot_gsm8k.py +1 -0
  71. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  72. sglang/utils.py +1 -0
  73. sglang/version.py +1 -1
  74. {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc2.dist-info}/METADATA +2 -2
  75. {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc2.dist-info}/RECORD +83 -76
  76. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  77. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  78. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  79. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  80. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  81. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  82. {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc2.dist-info}/WHEEL +0 -0
  83. {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc2.dist-info}/licenses/LICENSE +0 -0
  84. {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc2.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py

@@ -112,6 +112,7 @@ from sglang.srt.utils import (
  is_cpu,
  is_cuda,
  is_flashinfer_available,
+ is_gfx95_supported,
  is_hip,
  is_non_idle_and_non_empty,
  is_npu,
@@ -129,6 +130,22 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
  _is_cpu_amx_available = cpu_has_amx_support()
  _is_cpu = is_cpu()
  _device_sm = get_device_sm()
+ _is_gfx95_supported = is_gfx95_supported()
+
+ _use_aiter_gfx95 = _use_aiter and _is_gfx95_supported
+
+ if _use_aiter_gfx95:
+ from sglang.srt.layers.quantization.quark.utils import quark_post_load_weights
+ from sglang.srt.layers.quantization.rocm_mxfp4_utils import (
+ batched_gemm_afp4wfp4_pre_quant,
+ fused_flatten_mxfp4_quant,
+ fused_rms_mxfp4_quant,
+ )
+ from sglang.srt.layers.rocm_linear_utils import (
+ aiter_dsv3_router_gemm,
+ fused_qk_rope_cat,
+ get_dsv3_gemm_output_zero_allocator_size,
+ )

  if _is_cuda:
  from sgl_kernel import (
@@ -224,10 +241,17 @@ class DeepseekV2MLP(nn.Module):
  forward_batch=None,
  should_allreduce_fusion: bool = False,
  use_reduce_scatter: bool = False,
+ gemm_output_zero_allocator: BumpAllocator = None,
  ):
  if (self.tp_size == 1) and x.shape[0] == 0:
  return x

+ if gemm_output_zero_allocator != None and x.shape[0] <= 256:
+ y = gemm_output_zero_allocator.allocate(
+ x.shape[0] * self.gate_up_proj.output_size_per_partition
+ ).view(x.shape[0], self.gate_up_proj.output_size_per_partition)
+ x = (x, None, y)
+
  gate_up, _ = self.gate_up_proj(x)
  x = self.act_fn(gate_up)
  x, _ = self.down_proj(
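The DeepseekV2MLP hunk above threads a pre-zeroed scratch buffer into the gate/up GEMM for small batches (at most 256 tokens), so the projection can write into memory that is already allocated and zero-filled instead of allocating per call. The sketch below is illustrative only: ToyBumpAllocator is not the BumpAllocator shipped in sglang, and the sizes are made up; only the allocate()/view() usage pattern is taken from the diff.

import torch


class ToyBumpAllocator:
    # Illustrative stand-in for a bump allocator: one pre-zeroed buffer,
    # handed out by advancing an offset instead of allocating per call.
    def __init__(self, buffer_size: int, dtype: torch.dtype, device: str):
        self._buffer = torch.zeros(buffer_size, dtype=dtype, device=device)
        self._offset = 0

    def allocate(self, num_elements: int) -> torch.Tensor:
        assert self._offset + num_elements <= self._buffer.numel()
        out = self._buffer[self._offset : self._offset + num_elements]
        self._offset += num_elements
        return out


# Usage mirroring the hunk above (made-up sizes):
allocator = ToyBumpAllocator(buffer_size=256 * 4096, dtype=torch.float32, device="cpu")
num_tokens, output_size_per_partition = 8, 4096
y = allocator.allocate(num_tokens * output_size_per_partition).view(
    num_tokens, output_size_per_partition
)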
@@ -257,7 +281,7 @@ class MoEGate(nn.Module):
  if _is_cpu and _is_cpu_amx_available:
  self.quant_method = PackWeightMethod(weight_names=["weight"])

- def forward(self, hidden_states):
+ def forward(self, hidden_states, gemm_output_zero_allocator: BumpAllocator = None):
  if use_intel_amx_backend(self):
  return torch.ops.sgl_kernel.weight_packed_linear(
  hidden_states,
@@ -276,6 +300,10 @@ class MoEGate(nn.Module):
  ):
  # router gemm output float32
  logits = dsv3_router_gemm(hidden_states, self.weight)
+ elif _use_aiter_gfx95 and hidden_states.shape[0] <= 256:
+ logits = aiter_dsv3_router_gemm(
+ hidden_states, self.weight, gemm_output_zero_allocator
+ )
  else:
  logits = F.linear(hidden_states, self.weight, None)

@@ -439,6 +467,7 @@ class DeepseekV2MoE(nn.Module):
  forward_batch: Optional[ForwardBatch] = None,
  should_allreduce_fusion: bool = False,
  use_reduce_scatter: bool = False,
+ gemm_output_zero_allocator: BumpAllocator = None,
  ) -> torch.Tensor:
  if not self._enable_deepep_moe:
  DUAL_STREAM_TOKEN_THRESHOLD = 1024
@@ -452,12 +481,14 @@ class DeepseekV2MoE(nn.Module):
  hidden_states,
  should_allreduce_fusion,
  use_reduce_scatter,
+ gemm_output_zero_allocator,
  )
  else:
  return self.forward_normal(
  hidden_states,
  should_allreduce_fusion,
  use_reduce_scatter,
+ gemm_output_zero_allocator,
  )
  else:
  return self.forward_deepep(hidden_states, forward_batch)
@@ -467,15 +498,18 @@ class DeepseekV2MoE(nn.Module):
  hidden_states: torch.Tensor,
  should_allreduce_fusion: bool = False,
  use_reduce_scatter: bool = False,
+ gemm_output_zero_allocator: BumpAllocator = None,
  ) -> torch.Tensor:

  current_stream = torch.cuda.current_stream()
  self.alt_stream.wait_stream(current_stream)
- shared_output = self._forward_shared_experts(hidden_states)
+ shared_output = self._forward_shared_experts(
+ hidden_states, gemm_output_zero_allocator
+ )

  with torch.cuda.stream(self.alt_stream):
  # router_logits: (num_tokens, n_experts)
- router_logits = self.gate(hidden_states)
+ router_logits = self.gate(hidden_states, gemm_output_zero_allocator)
  topk_output = self.topk(hidden_states, router_logits)
  final_hidden_states = self.experts(hidden_states, topk_output)
  if not _is_cuda:
@@ -502,6 +536,7 @@ class DeepseekV2MoE(nn.Module):
  hidden_states: torch.Tensor,
  should_allreduce_fusion: bool = False,
  use_reduce_scatter: bool = False,
+ gemm_output_zero_allocator: BumpAllocator = None,
  ) -> torch.Tensor:
  if hasattr(self, "shared_experts") and use_intel_amx_backend(
  self.shared_experts.gate_up_proj
@@ -509,9 +544,11 @@ class DeepseekV2MoE(nn.Module):
  return self.forward_cpu(hidden_states, should_allreduce_fusion)

  if hidden_states.shape[0] > 0:
- shared_output = self._forward_shared_experts(hidden_states)
+ shared_output = self._forward_shared_experts(
+ hidden_states, gemm_output_zero_allocator
+ )
  # router_logits: (num_tokens, n_experts)
- router_logits = self.gate(hidden_states)
+ router_logits = self.gate(hidden_states, gemm_output_zero_allocator)
  topk_output = self.topk(hidden_states, router_logits)
  else:
  shared_output = None
@@ -631,9 +668,13 @@ class DeepseekV2MoE(nn.Module):

  return final_hidden_states

- def _forward_shared_experts(self, hidden_states):
+ def _forward_shared_experts(
+ self, hidden_states, gemm_output_zero_allocator: BumpAllocator = None
+ ):
  if self.num_fused_shared_experts == 0:
- return self.shared_experts(hidden_states)
+ return self.shared_experts(
+ hidden_states, gemm_output_zero_allocator=gemm_output_zero_allocator
+ )
  else:
  return None

@@ -1044,7 +1085,13 @@ class DeepseekV2AttentionMLA(nn.Module):
  and not forward_batch.forward_mode.is_target_verify()
  and not forward_batch.forward_mode.is_draft_extend()
  ):
- return AttnForwardMethod.MHA
+ if is_dp_attention_enabled():
+ if sum(forward_batch.extend_prefix_lens_cpu) == 0:
+ return AttnForwardMethod.MHA
+ else:
+ return AttnForwardMethod.MLA
+ else:
+ return AttnForwardMethod.MHA
  else:
  return AttnForwardMethod.MLA
  else:
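The dispatch hunk above changes prefill behavior under data-parallel attention: MHA is kept only when no request in the batch extends a cached prefix, otherwise the layer falls back to MLA. A standalone restatement of that decision, with a local stand-in enum and plain arguments in place of the real forward_batch:

from enum import Enum, auto
from typing import List


class AttnForwardMethod(Enum):  # local stand-in for the enum used in deepseek_v2.py
    MHA = auto()
    MLA = auto()


def choose_prefill_attn(
    dp_attention_enabled: bool, extend_prefix_lens_cpu: List[int]
) -> AttnForwardMethod:
    # With DP attention, fall back to MLA whenever any request in the batch
    # extends a non-empty cached prefix; otherwise keep the MHA prefill path.
    if dp_attention_enabled and sum(extend_prefix_lens_cpu) > 0:
        return AttnForwardMethod.MLA
    return AttnForwardMethod.MHA


assert choose_prefill_attn(True, [0, 0]) is AttnForwardMethod.MHA
assert choose_prefill_attn(True, [5, 0]) is AttnForwardMethod.MLA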
@@ -1097,11 +1144,19 @@ class DeepseekV2AttentionMLA(nn.Module):
  if self.attn_mha.kv_b_proj is None:
  self.attn_mha.kv_b_proj = self.kv_b_proj

- if hidden_states.shape[0] == 0:
- assert (
- not self.o_proj.reduce_results
- ), "short-circuiting allreduce will lead to hangs"
- return hidden_states, None, forward_batch, None
+ # when hidden_states is a tuple of tensors, the tuple will include quantized weight and scale tensor
+ if isinstance(hidden_states, tuple):
+ if hidden_states[0].shape[0] == 0:
+ assert (
+ not self.o_proj.reduce_results
+ ), "short-circuiting allreduce will lead to hangs"
+ return hidden_states[0]
+ else:
+ if hidden_states.shape[0] == 0:
+ assert (
+ not self.o_proj.reduce_results
+ ), "short-circuiting allreduce will lead to hangs"
+ return hidden_states, None, forward_batch, None

  attn_forward_method = self.dispatch_attn_forward_method(forward_batch)

@@ -1225,7 +1280,11 @@ class DeepseekV2AttentionMLA(nn.Module):
  from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode

  if self.q_lora_rank is not None:
- if hidden_states.shape[0] <= 16 and self.use_min_latency_fused_a_gemm:
+ if (
+ (not isinstance(hidden_states, tuple))
+ and hidden_states.shape[0] <= 16
+ and self.use_min_latency_fused_a_gemm
+ ):
  fused_qkv_a_proj_out = dsv3_fused_a_gemm(
  hidden_states, self.fused_qkv_a_proj_with_mqa.weight.T
  )
@@ -1245,8 +1304,18 @@ class DeepseekV2AttentionMLA(nn.Module):
  k_nope = self.kv_a_layernorm(k_nope)
  current_stream.wait_stream(self.alt_stream)
  else:
- q = self.q_a_layernorm(q)
- k_nope = self.kv_a_layernorm(k_nope)
+ if _use_aiter_gfx95 and self.q_b_proj.weight.dtype == torch.uint8:
+ q, k_nope = fused_rms_mxfp4_quant(
+ q,
+ self.q_a_layernorm.weight,
+ self.q_a_layernorm.variance_epsilon,
+ k_nope,
+ self.kv_a_layernorm.weight,
+ self.kv_a_layernorm.variance_epsilon,
+ )
+ else:
+ q = self.q_a_layernorm(q)
+ k_nope = self.kv_a_layernorm(k_nope)

  k_nope = k_nope.unsqueeze(1)
  q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
@@ -1278,10 +1347,27 @@ class DeepseekV2AttentionMLA(nn.Module):
  q_nope_out = q_nope_out[:, :expected_m, :]
  elif _is_hip:
  # TODO(haishaw): add bmm_fp8 to ROCm
- q_nope_out = torch.bmm(
- q_nope.to(torch.bfloat16).transpose(0, 1),
- self.w_kc.to(torch.bfloat16) * self.w_scale,
- )
+ if _use_aiter_gfx95 and self.w_kc.dtype == torch.uint8:
+ x = q_nope.transpose(0, 1)
+ q_nope_out = torch.empty(
+ x.shape[0],
+ x.shape[1],
+ self.w_kc.shape[2],
+ device=x.device,
+ dtype=torch.bfloat16,
+ )
+ batched_gemm_afp4wfp4_pre_quant(
+ x,
+ self.w_kc.transpose(-2, -1),
+ self.w_scale_k.transpose(-2, -1),
+ torch.bfloat16,
+ q_nope_out,
+ )
+ else:
+ q_nope_out = torch.bmm(
+ q_nope.to(torch.bfloat16).transpose(0, 1),
+ self.w_kc.to(torch.bfloat16) * self.w_scale,
+ )
  elif self.w_kc.dtype == torch.float8_e4m3fn:
  q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
  q_nope.transpose(0, 1),
@@ -1295,13 +1381,15 @@ class DeepseekV2AttentionMLA(nn.Module):

  q_nope_out = q_nope_out.transpose(0, 1)

- if not self._fuse_rope_for_trtllm_mla(forward_batch):
+ if not self._fuse_rope_for_trtllm_mla(forward_batch) and (
+ not _use_aiter or not _is_gfx95_supported
+ ):
  q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)

- return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
+ return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator, positions

  def forward_absorb_core(
- self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
+ self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator, positions
  ):
  if (
  self.current_attention_backend == "fa3"
@@ -1326,8 +1414,23 @@ class DeepseekV2AttentionMLA(nn.Module):
  **extra_args,
  )
  else:
- q = torch.cat([q_nope_out, q_pe], dim=-1)
- k = torch.cat([k_nope, k_pe], dim=-1)
+ if _use_aiter_gfx95:
+ cos = self.rotary_emb.cos_cache
+ sin = self.rotary_emb.sin_cache
+ q, k = fused_qk_rope_cat(
+ q_nope_out,
+ q_pe,
+ k_nope,
+ k_pe,
+ positions,
+ cos,
+ sin,
+ self.rotary_emb.is_neox_style,
+ )
+ else:
+ q = torch.cat([q_nope_out, q_pe], dim=-1)
+ k = torch.cat([k_nope, k_pe], dim=-1)
+
  attn_output = self.attn_mqa(q, k, k_nope, forward_batch)
  attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
@@ -1352,11 +1455,34 @@ class DeepseekV2AttentionMLA(nn.Module):
  )
  elif _is_hip:
  # TODO(haishaw): add bmm_fp8 to ROCm
- attn_bmm_output = torch.bmm(
- attn_output.to(torch.bfloat16).transpose(0, 1),
- self.w_vc.to(torch.bfloat16) * self.w_scale,
- )
- attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
+ if _use_aiter_gfx95 and self.w_vc.dtype == torch.uint8:
+ x = attn_output.transpose(0, 1)
+ attn_bmm_output = torch.empty(
+ x.shape[0],
+ x.shape[1],
+ self.w_vc.shape[2],
+ device=x.device,
+ dtype=torch.bfloat16,
+ )
+ batched_gemm_afp4wfp4_pre_quant(
+ x,
+ self.w_vc.transpose(-2, -1),
+ self.w_scale_v.transpose(-2, -1),
+ torch.bfloat16,
+ attn_bmm_output,
+ )
+ else:
+ attn_bmm_output = torch.bmm(
+ attn_output.to(torch.bfloat16).transpose(0, 1),
+ self.w_vc.to(torch.bfloat16) * self.w_scale,
+ )
+
+ if self.o_proj.weight.dtype == torch.uint8:
+ attn_bmm_output = attn_bmm_output.transpose(0, 1)
+ attn_bmm_output = fused_flatten_mxfp4_quant(attn_bmm_output)
+ else:
+ attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
+
  elif self.w_vc.dtype == torch.float8_e4m3fn:
  attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
@@ -1678,9 +1804,11 @@ class DeepseekV2AttentionMLA(nn.Module):
  latent_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
  self.attn_mha.layer_id
  )
- latent_cache = latent_cache_buf[
- forward_batch.prefix_chunk_kv_indices[i]
- ].contiguous()
+ latent_cache = (
+ latent_cache_buf[forward_batch.prefix_chunk_kv_indices[i]]
+ .contiguous()
+ .to(q.dtype)
+ )

  kv_a_normed, k_pe = latent_cache.split(
  [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
@@ -1864,10 +1992,21 @@ class DeepseekV2DecoderLayer(nn.Module):
  forward_batch: ForwardBatch,
  residual: Optional[torch.Tensor],
  zero_allocator: BumpAllocator,
+ gemm_output_zero_allocator: BumpAllocator = None,
  ) -> torch.Tensor:

+ quant_format = (
+ "mxfp4"
+ if _is_gfx95_supported
+ and self.self_attn.fused_qkv_a_proj_with_mqa.weight == torch.uint8
+ else ""
+ )
+
  hidden_states, residual = self.layer_communicator.prepare_attn(
- hidden_states, residual, forward_batch
+ hidden_states,
+ residual,
+ forward_batch,
+ quant_format,
  )

  hidden_states = self.self_attn(
@@ -1891,8 +2030,16 @@ class DeepseekV2DecoderLayer(nn.Module):
  use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
  forward_batch
  )
+
+ if isinstance(self.mlp, DeepseekV2MLP):
+ gemm_output_zero_allocator = None
+
  hidden_states = self.mlp(
- hidden_states, forward_batch, should_allreduce_fusion, use_reduce_scatter
+ hidden_states,
+ forward_batch,
+ should_allreduce_fusion,
+ use_reduce_scatter,
+ gemm_output_zero_allocator,
  )

  if should_allreduce_fusion:
@@ -2036,6 +2183,37 @@ class DeepseekV2Model(nn.Module):
  else:
  self.norm = PPMissingLayer(return_tuple=True)

+ self.gemm_output_zero_allocator_size = 0
+ if (
+ _use_aiter_gfx95
+ and config.n_routed_experts == 256
+ and self.embed_tokens.embedding_dim == 7168
+ ):
+ num_moe_layers = sum(
+ [
+ 1
+ for i in range(len(self.layers))
+ if isinstance(self.layers[i].mlp, DeepseekV2MoE)
+ ]
+ )
+
+ allocate_size = 0
+ for i in range(len(self.layers)):
+ if isinstance(self.layers[i].mlp, DeepseekV2MoE):
+ allocate_size = self.layers[
+ i
+ ].mlp.shared_experts.gate_up_proj.output_size_per_partition
+ break
+
+ self.gemm_output_zero_allocator_size = (
+ get_dsv3_gemm_output_zero_allocator_size(
+ config.n_routed_experts,
+ num_moe_layers,
+ allocate_size,
+ self.embed_tokens.embedding_dim,
+ )
+ )
+
  def get_input_embeddings(self) -> torch.Tensor:
  return self.embed_tokens

@@ -2055,6 +2233,21 @@ class DeepseekV2Model(nn.Module):
  device=device,
  )

+ has_gemm_output_zero_allocator = hasattr(
+ self, "gemm_output_zero_allocator_size"
+ )
+
+ gemm_output_zero_allocator = (
+ BumpAllocator(
+ buffer_size=self.gemm_output_zero_allocator_size,
+ dtype=torch.float32,
+ device=device,
+ )
+ if has_gemm_output_zero_allocator
+ and self.gemm_output_zero_allocator_size > 0
+ else None
+ )
+
  if self.pp_group.is_first_rank:
  if input_embeds is None:
  hidden_states = self.embed_tokens(input_ids)
@@ -2081,7 +2274,12 @@ class DeepseekV2Model(nn.Module):
  with get_global_expert_distribution_recorder().with_current_layer(i):
  layer = self.layers[i]
  hidden_states, residual = layer(
- positions, hidden_states, forward_batch, residual, zero_allocator
+ positions,
+ hidden_states,
+ forward_batch,
+ residual,
+ zero_allocator,
+ gemm_output_zero_allocator,
  )

  if normal_end_layer != self.end_layer:
@@ -2185,6 +2383,8 @@ class DeepseekV2ForCausalLM(nn.Module):
  disable_reason = "Only Deepseek V3/R1 on NV-platform with capability >= 80 can use shared experts fusion optimization."
  elif get_moe_expert_parallel_world_size() > 1:
  disable_reason = "Deepseek V3/R1 can not use shared experts fusion optimization under expert parallelism."
+ elif self.quant_config.get_name() == "w4afp8":
+ disable_reason = "Deepseek V3/R1 W4AFP8 model uses different quant method for routed experts and shared experts."

  if disable_reason is not None:
  global_server_args_dict["disable_shared_experts_fusion"] = True
@@ -2352,6 +2552,12 @@ class DeepseekV2ForCausalLM(nn.Module):
  w_kc, w_vc = w.unflatten(
  0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
  ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+
+ if _use_aiter_gfx95 and self.quant_config.get_name() == "quark":
+ w_kc, self_attn.w_scale_k, w_vc, self_attn.w_scale_v = (
+ quark_post_load_weights(self_attn, w, "mxfp4")
+ )
+
  if not use_deep_gemm_bmm:
  self_attn.w_kc = bind_or_assign(
  self_attn.w_kc, w_kc.transpose(1, 2).contiguous().transpose(1, 2)
@@ -2496,6 +2702,9 @@ class DeepseekV2ForCausalLM(nn.Module):
  ckpt_up_proj_name="up_proj",
  num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
  )
+ # Params for special naming rules in mixed-precision models, for example:
+ # model.layers.xx.mlp.experts.xx.w1.input_scale. For details,
+ # see https://huggingface.co/Barrrrry/DeepSeek-R1-W4AFP8/blob/main.
  if self.quant_config and self.quant_config.get_name() == "w4afp8":
  expert_params_mapping += FusedMoE.make_expert_input_scale_params_mapping(
  num_experts=self.config.n_routed_experts

sglang/srt/models/glm4_moe.py

@@ -153,7 +153,13 @@ class Glm4MoeMLP(nn.Module):
  )
  self.act_fn = SiluAndMul()

- def forward(self, x, forward_batch=None, should_allreduce_fusion=False):
+ def forward(
+ self,
+ x,
+ forward_batch=None,
+ should_allreduce_fusion=False,
+ gemm_output_zero_allocator: BumpAllocator = None,
+ ):
  if (self.tp_size == 1) and x.shape[0] == 0:
  return x

@@ -501,6 +507,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
  hidden_states: torch.Tensor,
  should_allreduce_fusion: bool = False,
  use_reduce_scatter: bool = False,
+ gemm_output_zero_allocator: BumpAllocator = None,
  ) -> torch.Tensor:

  current_stream = torch.cuda.current_stream()
@@ -543,6 +550,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
  hidden_states: torch.Tensor,
  should_allreduce_fusion: bool = False,
  use_reduce_scatter: bool = False,
+ gemm_output_zero_allocator: BumpAllocator = None,
  ) -> torch.Tensor:
  if hasattr(self, "shared_experts") and use_intel_amx_backend(
  self.shared_experts.gate_up_proj
@@ -666,6 +674,7 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
  forward_batch: ForwardBatch,
  residual: Optional[torch.Tensor],
  zero_allocator: BumpAllocator,
+ gemm_output_zero_allocator: BumpAllocator = None,
  ) -> torch.Tensor:
  hidden_states, residual = self.layer_communicator.prepare_attn(
  hidden_states, residual, forward_batch

sglang/srt/models/gpt_oss.py

@@ -193,8 +193,9 @@ class GptOssSparseMoeBlock(nn.Module):
  return ans


- def _enable_fused_set_kv_buffer():
- return _is_cuda
+ def _enable_fused_set_kv_buffer(forward_batch: ForwardBatch):
+ """Enable fused set_kv_buffer only on CUDA with bfloat16 KV cache."""
+ return _is_cuda and forward_batch.token_to_kv_pool.dtype == torch.bfloat16


  # TODO maybe move to a model-common utils
@@ -341,7 +342,7 @@ class GptOssAttention(nn.Module):
  layer=self.attn,
  forward_batch=forward_batch,
  )
- if _enable_fused_set_kv_buffer()
+ if _enable_fused_set_kv_buffer(forward_batch)
  else None
  ),
  )
@@ -355,7 +356,7 @@ class GptOssAttention(nn.Module):
  attn_output = self.attn(
  *inner_state,
  sinks=self.sinks,
- save_kv_cache=not _enable_fused_set_kv_buffer(),
+ save_kv_cache=not _enable_fused_set_kv_buffer(forward_batch),
  )
  output, _ = self.o_proj(attn_output)
  return output

sglang/srt/models/internvl.py

@@ -26,8 +26,10 @@ from sglang.srt.managers.schedule_batch import (
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
  from sglang.srt.model_loader.weight_utils import default_weight_loader
  from sglang.srt.models.deepseek_janus_pro import DropPath
+ from sglang.srt.models.gpt_oss import GptOssForCausalLM
  from sglang.srt.models.internlm2 import InternLM2ForCausalLM
  from sglang.srt.models.qwen2 import Qwen2ForCausalLM
+ from sglang.srt.models.qwen3 import Qwen3ForCausalLM
  from sglang.srt.models.qwen3_moe import Qwen3MoeForCausalLM
  from sglang.utils import logger

@@ -445,6 +447,14 @@ class InternVLChatModel(nn.Module):
  self.language_model = Qwen3MoeForCausalLM(
  config=config.llm_config, quant_config=quant_config
  )
+ elif config.llm_config.architectures[0] == "GptOssForCausalLM":
+ self.language_model = GptOssForCausalLM(
+ config=config.llm_config, quant_config=quant_config
+ )
+ elif config.llm_config.architectures[0] == "Qwen3ForCausalLM":
+ self.language_model = Qwen3ForCausalLM(
+ config=config.llm_config, quant_config=quant_config
+ )
  else:
  raise NotImplementedError(
  f"{config.llm_config.architectures[0]} is not implemented."
@@ -577,6 +587,15 @@ class InternVLChatModel(nn.Module):
  ckpt_up_proj_name="up_proj",
  num_experts=self.config.num_experts,
  )
+ elif "Qwen3ForCausalLM" in self.config.llm_config.architectures:
+ stacked_params_mapping = [
+ # (param_name, shard_name, shard_id)
+ ("qkv_proj", "q_proj", "q"),
+ ("qkv_proj", "k_proj", "k"),
+ ("qkv_proj", "v_proj", "v"),
+ ("gate_up_proj", "gate_proj", 0),
+ ("gate_up_proj", "up_proj", 1),
+ ]

  params_dict = dict(self.named_parameters())
  loaded_params: Set[str] = set()
@@ -661,6 +680,15 @@ class InternVLChatModel(nn.Module):

  loaded_params.add(name)
  unloaded_params = params_dict.keys() - loaded_params
+ # Skip params that are created by quantization wrappers and are not expected in the ckpt
+ _quant_only_fragments = (
+ "weight_scale", # per-matrix FP8 scales (e.g., w2_weight_scale, w13_weight_scale)
+ )
+ unloaded_params = {
+ n
+ for n in unloaded_params
+ if not any(frag in n for frag in _quant_only_fragments)
+ }
  if unloaded_params:
  raise RuntimeError(
  f"Some weights are not initialized from checkpoints: {unloaded_params}"