sglang 0.5.2rc0__py3-none-any.whl → 0.5.2rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/lang/interpreter.py +1 -1
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/model_config.py +2 -1
- sglang/srt/disaggregation/mini_lb.py +2 -2
- sglang/srt/distributed/parallel_state.py +46 -41
- sglang/srt/entrypoints/engine.py +1 -1
- sglang/srt/entrypoints/http_server.py +5 -1
- sglang/srt/entrypoints/openai/protocol.py +3 -3
- sglang/srt/entrypoints/openai/serving_chat.py +3 -3
- sglang/srt/entrypoints/openai/serving_completions.py +3 -1
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -1
- sglang/srt/entrypoints/openai/serving_responses.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/communicator.py +45 -7
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +1 -9
- sglang/srt/layers/moe/ep_moe/layer.py +2 -7
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -1048
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +796 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/utils.py +0 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +8 -0
- sglang/srt/layers/quantization/modelopt_quant.py +35 -2
- sglang/srt/layers/quantization/mxfp4.py +4 -1
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +30 -25
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +0 -18
- sglang/srt/managers/cache_controller.py +42 -39
- sglang/srt/managers/detokenizer_manager.py +0 -34
- sglang/srt/managers/multi_tokenizer_mixin.py +48 -6
- sglang/srt/managers/schedule_policy.py +3 -2
- sglang/srt/managers/scheduler.py +7 -100
- sglang/srt/managers/scheduler_metrics_mixin.py +113 -7
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_manager.py +1 -0
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +15 -10
- sglang/srt/mem_cache/hiradix_cache.py +16 -0
- sglang/srt/mem_cache/memory_pool_host.py +18 -11
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +35 -6
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +32 -13
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/metrics/collector.py +12 -4
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/forward_batch_info.py +16 -17
- sglang/srt/model_executor/model_runner.py +1 -1
- sglang/srt/models/deepseek_v2.py +245 -36
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/gpt_oss.py +5 -4
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/longcat_flash.py +26 -15
- sglang/srt/models/longcat_flash_nextn.py +23 -15
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/qwen2_moe.py +4 -1
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/server_args.py +79 -2
- sglang/srt/speculative/eagle_worker.py +158 -112
- sglang/srt/utils.py +12 -10
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/utils.py +1 -0
- sglang/version.py +1 -1
- {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc2.dist-info}/METADATA +2 -2
- {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc2.dist-info}/RECORD +83 -76
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc2.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -112,6 +112,7 @@ from sglang.srt.utils import (
     is_cpu,
     is_cuda,
     is_flashinfer_available,
+    is_gfx95_supported,
     is_hip,
     is_non_idle_and_non_empty,
     is_npu,

@@ -129,6 +130,22 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
 _device_sm = get_device_sm()
+_is_gfx95_supported = is_gfx95_supported()
+
+_use_aiter_gfx95 = _use_aiter and _is_gfx95_supported
+
+if _use_aiter_gfx95:
+    from sglang.srt.layers.quantization.quark.utils import quark_post_load_weights
+    from sglang.srt.layers.quantization.rocm_mxfp4_utils import (
+        batched_gemm_afp4wfp4_pre_quant,
+        fused_flatten_mxfp4_quant,
+        fused_rms_mxfp4_quant,
+    )
+    from sglang.srt.layers.rocm_linear_utils import (
+        aiter_dsv3_router_gemm,
+        fused_qk_rope_cat,
+        get_dsv3_gemm_output_zero_allocator_size,
+    )
 
 if _is_cuda:
     from sgl_kernel import (

@@ -224,10 +241,17 @@ class DeepseekV2MLP(nn.Module):
         forward_batch=None,
         should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
+        gemm_output_zero_allocator: BumpAllocator = None,
     ):
         if (self.tp_size == 1) and x.shape[0] == 0:
             return x
 
+        if gemm_output_zero_allocator != None and x.shape[0] <= 256:
+            y = gemm_output_zero_allocator.allocate(
+                x.shape[0] * self.gate_up_proj.output_size_per_partition
+            ).view(x.shape[0], self.gate_up_proj.output_size_per_partition)
+            x = (x, None, y)
+
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(

@@ -257,7 +281,7 @@ class MoEGate(nn.Module):
         if _is_cpu and _is_cpu_amx_available:
             self.quant_method = PackWeightMethod(weight_names=["weight"])
 
-    def forward(self, hidden_states):
+    def forward(self, hidden_states, gemm_output_zero_allocator: BumpAllocator = None):
         if use_intel_amx_backend(self):
             return torch.ops.sgl_kernel.weight_packed_linear(
                 hidden_states,

@@ -276,6 +300,10 @@ class MoEGate(nn.Module):
         ):
             # router gemm output float32
             logits = dsv3_router_gemm(hidden_states, self.weight)
+        elif _use_aiter_gfx95 and hidden_states.shape[0] <= 256:
+            logits = aiter_dsv3_router_gemm(
+                hidden_states, self.weight, gemm_output_zero_allocator
+            )
         else:
             logits = F.linear(hidden_states, self.weight, None)
 

@@ -439,6 +467,7 @@ class DeepseekV2MoE(nn.Module):
         forward_batch: Optional[ForwardBatch] = None,
         should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
+        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:
         if not self._enable_deepep_moe:
             DUAL_STREAM_TOKEN_THRESHOLD = 1024

@@ -452,12 +481,14 @@ class DeepseekV2MoE(nn.Module):
                     hidden_states,
                     should_allreduce_fusion,
                     use_reduce_scatter,
+                    gemm_output_zero_allocator,
                 )
             else:
                 return self.forward_normal(
                     hidden_states,
                     should_allreduce_fusion,
                     use_reduce_scatter,
+                    gemm_output_zero_allocator,
                 )
         else:
             return self.forward_deepep(hidden_states, forward_batch)

@@ -467,15 +498,18 @@ class DeepseekV2MoE(nn.Module):
         hidden_states: torch.Tensor,
         should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
+        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:
 
         current_stream = torch.cuda.current_stream()
         self.alt_stream.wait_stream(current_stream)
-        shared_output = self._forward_shared_experts(hidden_states)
+        shared_output = self._forward_shared_experts(
+            hidden_states, gemm_output_zero_allocator
+        )
 
         with torch.cuda.stream(self.alt_stream):
             # router_logits: (num_tokens, n_experts)
-            router_logits = self.gate(hidden_states)
+            router_logits = self.gate(hidden_states, gemm_output_zero_allocator)
             topk_output = self.topk(hidden_states, router_logits)
             final_hidden_states = self.experts(hidden_states, topk_output)
             if not _is_cuda:

@@ -502,6 +536,7 @@ class DeepseekV2MoE(nn.Module):
         hidden_states: torch.Tensor,
         should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
+        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:
         if hasattr(self, "shared_experts") and use_intel_amx_backend(
             self.shared_experts.gate_up_proj

@@ -509,9 +544,11 @@ class DeepseekV2MoE(nn.Module):
             return self.forward_cpu(hidden_states, should_allreduce_fusion)
 
         if hidden_states.shape[0] > 0:
-            shared_output = self._forward_shared_experts(hidden_states)
+            shared_output = self._forward_shared_experts(
+                hidden_states, gemm_output_zero_allocator
+            )
             # router_logits: (num_tokens, n_experts)
-            router_logits = self.gate(hidden_states)
+            router_logits = self.gate(hidden_states, gemm_output_zero_allocator)
             topk_output = self.topk(hidden_states, router_logits)
         else:
             shared_output = None

@@ -631,9 +668,13 @@ class DeepseekV2MoE(nn.Module):
 
         return final_hidden_states
 
-    def _forward_shared_experts(self, hidden_states):
+    def _forward_shared_experts(
+        self, hidden_states, gemm_output_zero_allocator: BumpAllocator = None
+    ):
         if self.num_fused_shared_experts == 0:
-            return self.shared_experts(hidden_states)
+            return self.shared_experts(
+                hidden_states, gemm_output_zero_allocator=gemm_output_zero_allocator
+            )
         else:
             return None
 

@@ -1044,7 +1085,13 @@ class DeepseekV2AttentionMLA(nn.Module):
                 and not forward_batch.forward_mode.is_target_verify()
                 and not forward_batch.forward_mode.is_draft_extend()
             ):
-                return AttnForwardMethod.MHA
+                if is_dp_attention_enabled():
+                    if sum(forward_batch.extend_prefix_lens_cpu) == 0:
+                        return AttnForwardMethod.MHA
+                    else:
+                        return AttnForwardMethod.MLA
+                else:
+                    return AttnForwardMethod.MHA
             else:
                 return AttnForwardMethod.MLA
         else:

@@ -1097,11 +1144,19 @@ class DeepseekV2AttentionMLA(nn.Module):
         if self.attn_mha.kv_b_proj is None:
             self.attn_mha.kv_b_proj = self.kv_b_proj
 
-        if hidden_states.shape[0] == 0:
-            assert (
-                not self.o_proj.reduce_results
-            ), "short-circuiting allreduce will lead to hangs"
-            return hidden_states, None, forward_batch, None
+        # when hidden_states is a tuple of tensors, the tuple will include quantized weight and scale tensor
+        if isinstance(hidden_states, tuple):
+            if hidden_states[0].shape[0] == 0:
+                assert (
+                    not self.o_proj.reduce_results
+                ), "short-circuiting allreduce will lead to hangs"
+                return hidden_states[0]
+        else:
+            if hidden_states.shape[0] == 0:
+                assert (
+                    not self.o_proj.reduce_results
+                ), "short-circuiting allreduce will lead to hangs"
+                return hidden_states, None, forward_batch, None
 
         attn_forward_method = self.dispatch_attn_forward_method(forward_batch)
 

@@ -1225,7 +1280,11 @@ class DeepseekV2AttentionMLA(nn.Module):
         from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 
         if self.q_lora_rank is not None:
-            if hidden_states.shape[0] <= 16 and self.use_min_latency_fused_a_gemm:
+            if (
+                (not isinstance(hidden_states, tuple))
+                and hidden_states.shape[0] <= 16
+                and self.use_min_latency_fused_a_gemm
+            ):
                 fused_qkv_a_proj_out = dsv3_fused_a_gemm(
                     hidden_states, self.fused_qkv_a_proj_with_mqa.weight.T
                 )

@@ -1245,8 +1304,18 @@ class DeepseekV2AttentionMLA(nn.Module):
                     k_nope = self.kv_a_layernorm(k_nope)
                 current_stream.wait_stream(self.alt_stream)
             else:
-                q = self.q_a_layernorm(q)
-                k_nope = self.kv_a_layernorm(k_nope)
+                if _use_aiter_gfx95 and self.q_b_proj.weight.dtype == torch.uint8:
+                    q, k_nope = fused_rms_mxfp4_quant(
+                        q,
+                        self.q_a_layernorm.weight,
+                        self.q_a_layernorm.variance_epsilon,
+                        k_nope,
+                        self.kv_a_layernorm.weight,
+                        self.kv_a_layernorm.variance_epsilon,
+                    )
+                else:
+                    q = self.q_a_layernorm(q)
+                    k_nope = self.kv_a_layernorm(k_nope)
 
             k_nope = k_nope.unsqueeze(1)
             q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)

@@ -1278,10 +1347,27 @@ class DeepseekV2AttentionMLA(nn.Module):
                 q_nope_out = q_nope_out[:, :expected_m, :]
         elif _is_hip:
             # TODO(haishaw): add bmm_fp8 to ROCm
-            q_nope_out = torch.bmm(
-                q_nope.to(torch.bfloat16).transpose(0, 1),
-                self.w_kc.to(torch.bfloat16) * self.w_scale,
-            )
+            if _use_aiter_gfx95 and self.w_kc.dtype == torch.uint8:
+                x = q_nope.transpose(0, 1)
+                q_nope_out = torch.empty(
+                    x.shape[0],
+                    x.shape[1],
+                    self.w_kc.shape[2],
+                    device=x.device,
+                    dtype=torch.bfloat16,
+                )
+                batched_gemm_afp4wfp4_pre_quant(
+                    x,
+                    self.w_kc.transpose(-2, -1),
+                    self.w_scale_k.transpose(-2, -1),
+                    torch.bfloat16,
+                    q_nope_out,
+                )
+            else:
+                q_nope_out = torch.bmm(
+                    q_nope.to(torch.bfloat16).transpose(0, 1),
+                    self.w_kc.to(torch.bfloat16) * self.w_scale,
+                )
         elif self.w_kc.dtype == torch.float8_e4m3fn:
             q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
                 q_nope.transpose(0, 1),

@@ -1295,13 +1381,15 @@ class DeepseekV2AttentionMLA(nn.Module):
 
         q_nope_out = q_nope_out.transpose(0, 1)
 
-        if not self._fuse_rope_for_trtllm_mla(forward_batch):
+        if not self._fuse_rope_for_trtllm_mla(forward_batch) and (
+            not _use_aiter or not _is_gfx95_supported
+        ):
             q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
 
-        return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
+        return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator, positions
 
     def forward_absorb_core(
-        self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
+        self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator, positions
     ):
         if (
             self.current_attention_backend == "fa3"

@@ -1326,8 +1414,23 @@ class DeepseekV2AttentionMLA(nn.Module):
                 **extra_args,
             )
         else:
-            q = torch.cat([q_nope_out, q_pe], dim=-1)
-            k = torch.cat([k_nope, k_pe], dim=-1)
+            if _use_aiter_gfx95:
+                cos = self.rotary_emb.cos_cache
+                sin = self.rotary_emb.sin_cache
+                q, k = fused_qk_rope_cat(
+                    q_nope_out,
+                    q_pe,
+                    k_nope,
+                    k_pe,
+                    positions,
+                    cos,
+                    sin,
+                    self.rotary_emb.is_neox_style,
+                )
+            else:
+                q = torch.cat([q_nope_out, q_pe], dim=-1)
+                k = torch.cat([k_nope, k_pe], dim=-1)
+
             attn_output = self.attn_mqa(q, k, k_nope, forward_batch)
         attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
 

@@ -1352,11 +1455,34 @@ class DeepseekV2AttentionMLA(nn.Module):
             )
         elif _is_hip:
             # TODO(haishaw): add bmm_fp8 to ROCm
-            attn_bmm_output = torch.bmm(
-                attn_output.to(torch.bfloat16).transpose(0, 1),
-                self.w_vc.to(torch.bfloat16) * self.w_scale,
-            )
-            attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
+            if _use_aiter_gfx95 and self.w_vc.dtype == torch.uint8:
+                x = attn_output.transpose(0, 1)
+                attn_bmm_output = torch.empty(
+                    x.shape[0],
+                    x.shape[1],
+                    self.w_vc.shape[2],
+                    device=x.device,
+                    dtype=torch.bfloat16,
+                )
+                batched_gemm_afp4wfp4_pre_quant(
+                    x,
+                    self.w_vc.transpose(-2, -1),
+                    self.w_scale_v.transpose(-2, -1),
+                    torch.bfloat16,
+                    attn_bmm_output,
+                )
+            else:
+                attn_bmm_output = torch.bmm(
+                    attn_output.to(torch.bfloat16).transpose(0, 1),
+                    self.w_vc.to(torch.bfloat16) * self.w_scale,
+                )
+
+            if self.o_proj.weight.dtype == torch.uint8:
+                attn_bmm_output = attn_bmm_output.transpose(0, 1)
+                attn_bmm_output = fused_flatten_mxfp4_quant(attn_bmm_output)
+            else:
+                attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
+
         elif self.w_vc.dtype == torch.float8_e4m3fn:
             attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
                 attn_output.transpose(0, 1),

@@ -1678,9 +1804,11 @@ class DeepseekV2AttentionMLA(nn.Module):
             latent_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
                 self.attn_mha.layer_id
            )
-            latent_cache = latent_cache_buf[
-                forward_batch.prefix_chunk_kv_indices[i]
-            ].contiguous()
+            latent_cache = (
+                latent_cache_buf[forward_batch.prefix_chunk_kv_indices[i]]
+                .contiguous()
+                .to(q.dtype)
+            )
 
             kv_a_normed, k_pe = latent_cache.split(
                 [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1

@@ -1864,10 +1992,21 @@ class DeepseekV2DecoderLayer(nn.Module):
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
         zero_allocator: BumpAllocator,
+        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:
 
+        quant_format = (
+            "mxfp4"
+            if _is_gfx95_supported
+            and self.self_attn.fused_qkv_a_proj_with_mqa.weight == torch.uint8
+            else ""
+        )
+
         hidden_states, residual = self.layer_communicator.prepare_attn(
-            hidden_states, residual, forward_batch
+            hidden_states,
+            residual,
+            forward_batch,
+            quant_format,
         )
 
         hidden_states = self.self_attn(

@@ -1891,8 +2030,16 @@ class DeepseekV2DecoderLayer(nn.Module):
         use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
             forward_batch
         )
+
+        if isinstance(self.mlp, DeepseekV2MLP):
+            gemm_output_zero_allocator = None
+
         hidden_states = self.mlp(
-            hidden_states, forward_batch, should_allreduce_fusion, use_reduce_scatter
+            hidden_states,
+            forward_batch,
+            should_allreduce_fusion,
+            use_reduce_scatter,
+            gemm_output_zero_allocator,
         )
 
         if should_allreduce_fusion:

@@ -2036,6 +2183,37 @@ class DeepseekV2Model(nn.Module):
         else:
             self.norm = PPMissingLayer(return_tuple=True)
 
+        self.gemm_output_zero_allocator_size = 0
+        if (
+            _use_aiter_gfx95
+            and config.n_routed_experts == 256
+            and self.embed_tokens.embedding_dim == 7168
+        ):
+            num_moe_layers = sum(
+                [
+                    1
+                    for i in range(len(self.layers))
+                    if isinstance(self.layers[i].mlp, DeepseekV2MoE)
+                ]
+            )
+
+            allocate_size = 0
+            for i in range(len(self.layers)):
+                if isinstance(self.layers[i].mlp, DeepseekV2MoE):
+                    allocate_size = self.layers[
+                        i
+                    ].mlp.shared_experts.gate_up_proj.output_size_per_partition
+                    break
+
+            self.gemm_output_zero_allocator_size = (
+                get_dsv3_gemm_output_zero_allocator_size(
+                    config.n_routed_experts,
+                    num_moe_layers,
+                    allocate_size,
+                    self.embed_tokens.embedding_dim,
+                )
+            )
+
     def get_input_embeddings(self) -> torch.Tensor:
         return self.embed_tokens
 

@@ -2055,6 +2233,21 @@ class DeepseekV2Model(nn.Module):
             device=device,
         )
 
+        has_gemm_output_zero_allocator = hasattr(
+            self, "gemm_output_zero_allocator_size"
+        )
+
+        gemm_output_zero_allocator = (
+            BumpAllocator(
+                buffer_size=self.gemm_output_zero_allocator_size,
+                dtype=torch.float32,
+                device=device,
+            )
+            if has_gemm_output_zero_allocator
+            and self.gemm_output_zero_allocator_size > 0
+            else None
+        )
+
         if self.pp_group.is_first_rank:
             if input_embeds is None:
                 hidden_states = self.embed_tokens(input_ids)

@@ -2081,7 +2274,12 @@ class DeepseekV2Model(nn.Module):
             with get_global_expert_distribution_recorder().with_current_layer(i):
                 layer = self.layers[i]
                 hidden_states, residual = layer(
-                    positions, hidden_states, forward_batch, residual, zero_allocator
+                    positions,
+                    hidden_states,
+                    forward_batch,
+                    residual,
+                    zero_allocator,
+                    gemm_output_zero_allocator,
                 )
 
         if normal_end_layer != self.end_layer:

@@ -2185,6 +2383,8 @@ class DeepseekV2ForCausalLM(nn.Module):
             disable_reason = "Only Deepseek V3/R1 on NV-platform with capability >= 80 can use shared experts fusion optimization."
         elif get_moe_expert_parallel_world_size() > 1:
             disable_reason = "Deepseek V3/R1 can not use shared experts fusion optimization under expert parallelism."
+        elif self.quant_config.get_name() == "w4afp8":
+            disable_reason = "Deepseek V3/R1 W4AFP8 model uses different quant method for routed experts and shared experts."
 
         if disable_reason is not None:
             global_server_args_dict["disable_shared_experts_fusion"] = True

@@ -2352,6 +2552,12 @@ class DeepseekV2ForCausalLM(nn.Module):
             w_kc, w_vc = w.unflatten(
                 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
             ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+
+            if _use_aiter_gfx95 and self.quant_config.get_name() == "quark":
+                w_kc, self_attn.w_scale_k, w_vc, self_attn.w_scale_v = (
+                    quark_post_load_weights(self_attn, w, "mxfp4")
+                )
+
             if not use_deep_gemm_bmm:
                 self_attn.w_kc = bind_or_assign(
                     self_attn.w_kc, w_kc.transpose(1, 2).contiguous().transpose(1, 2)

@@ -2496,6 +2702,9 @@ class DeepseekV2ForCausalLM(nn.Module):
             ckpt_up_proj_name="up_proj",
             num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
         )
+        # Params for special naming rules in mixed-precision models, for example:
+        # model.layers.xx.mlp.experts.xx.w1.input_scale. For details,
+        # see https://huggingface.co/Barrrrry/DeepSeek-R1-W4AFP8/blob/main.
        if self.quant_config and self.quant_config.get_name() == "w4afp8":
            expert_params_mapping += FusedMoE.make_expert_input_scale_params_mapping(
                num_experts=self.config.n_routed_experts
sglang/srt/models/glm4_moe.py
CHANGED
@@ -153,7 +153,13 @@ class Glm4MoeMLP(nn.Module):
         )
         self.act_fn = SiluAndMul()
 
-    def forward(self, x, forward_batch=None, should_allreduce_fusion=False):
+    def forward(
+        self,
+        x,
+        forward_batch=None,
+        should_allreduce_fusion=False,
+        gemm_output_zero_allocator: BumpAllocator = None,
+    ):
         if (self.tp_size == 1) and x.shape[0] == 0:
             return x
 

@@ -501,6 +507,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
         hidden_states: torch.Tensor,
         should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
+        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:
 
         current_stream = torch.cuda.current_stream()

@@ -543,6 +550,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
         hidden_states: torch.Tensor,
         should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
+        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:
         if hasattr(self, "shared_experts") and use_intel_amx_backend(
             self.shared_experts.gate_up_proj

@@ -666,6 +674,7 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
         zero_allocator: BumpAllocator,
+        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:
         hidden_states, residual = self.layer_communicator.prepare_attn(
             hidden_states, residual, forward_batch
sglang/srt/models/gpt_oss.py
CHANGED
@@ -193,8 +193,9 @@ class GptOssSparseMoeBlock(nn.Module):
         return ans
 
 
-def _enable_fused_set_kv_buffer():
-    return _is_cuda
+def _enable_fused_set_kv_buffer(forward_batch: ForwardBatch):
+    """Enable fused set_kv_buffer only on CUDA with bfloat16 KV cache."""
+    return _is_cuda and forward_batch.token_to_kv_pool.dtype == torch.bfloat16
 
 
 # TODO maybe move to a model-common utils

@@ -341,7 +342,7 @@ class GptOssAttention(nn.Module):
                     layer=self.attn,
                     forward_batch=forward_batch,
                 )
-                if _enable_fused_set_kv_buffer()
+                if _enable_fused_set_kv_buffer(forward_batch)
                 else None
             ),
         )

@@ -355,7 +356,7 @@ class GptOssAttention(nn.Module):
         attn_output = self.attn(
             *inner_state,
             sinks=self.sinks,
-            save_kv_cache=not _enable_fused_set_kv_buffer(),
+            save_kv_cache=not _enable_fused_set_kv_buffer(forward_batch),
         )
         output, _ = self.o_proj(attn_output)
         return output
sglang/srt/models/internvl.py
CHANGED
@@ -26,8 +26,10 @@ from sglang.srt.managers.schedule_batch import (
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.deepseek_janus_pro import DropPath
+from sglang.srt.models.gpt_oss import GptOssForCausalLM
 from sglang.srt.models.internlm2 import InternLM2ForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
+from sglang.srt.models.qwen3 import Qwen3ForCausalLM
 from sglang.srt.models.qwen3_moe import Qwen3MoeForCausalLM
 from sglang.utils import logger
 

@@ -445,6 +447,14 @@ class InternVLChatModel(nn.Module):
             self.language_model = Qwen3MoeForCausalLM(
                 config=config.llm_config, quant_config=quant_config
             )
+        elif config.llm_config.architectures[0] == "GptOssForCausalLM":
+            self.language_model = GptOssForCausalLM(
+                config=config.llm_config, quant_config=quant_config
+            )
+        elif config.llm_config.architectures[0] == "Qwen3ForCausalLM":
+            self.language_model = Qwen3ForCausalLM(
+                config=config.llm_config, quant_config=quant_config
+            )
         else:
             raise NotImplementedError(
                 f"{config.llm_config.architectures[0]} is not implemented."

@@ -577,6 +587,15 @@ class InternVLChatModel(nn.Module):
                 ckpt_up_proj_name="up_proj",
                 num_experts=self.config.num_experts,
             )
+        elif "Qwen3ForCausalLM" in self.config.llm_config.architectures:
+            stacked_params_mapping = [
+                # (param_name, shard_name, shard_id)
+                ("qkv_proj", "q_proj", "q"),
+                ("qkv_proj", "k_proj", "k"),
+                ("qkv_proj", "v_proj", "v"),
+                ("gate_up_proj", "gate_proj", 0),
+                ("gate_up_proj", "up_proj", 1),
+            ]
 
         params_dict = dict(self.named_parameters())
         loaded_params: Set[str] = set()

@@ -661,6 +680,15 @@ class InternVLChatModel(nn.Module):
 
             loaded_params.add(name)
         unloaded_params = params_dict.keys() - loaded_params
+        # Skip params that are created by quantization wrappers and are not expected in the ckpt
+        _quant_only_fragments = (
+            "weight_scale",  # per-matrix FP8 scales (e.g., w2_weight_scale, w13_weight_scale)
+        )
+        unloaded_params = {
+            n
+            for n in unloaded_params
+            if not any(frag in n for frag in _quant_only_fragments)
+        }
         if unloaded_params:
             raise RuntimeError(
                 f"Some weights are not initialized from checkpoints: {unloaded_params}"