sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -3
- sglang/bench_one_batch.py +119 -17
- sglang/lang/chat_template.py +18 -0
- sglang/srt/bench_utils.py +137 -0
- sglang/srt/configs/model_config.py +42 -7
- sglang/srt/conversation.py +9 -5
- sglang/srt/disaggregation/base/conn.py +5 -2
- sglang/srt/disaggregation/decode.py +14 -4
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
- sglang/srt/disaggregation/mooncake/conn.py +286 -160
- sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
- sglang/srt/disaggregation/prefill.py +2 -0
- sglang/srt/distributed/parallel_state.py +15 -11
- sglang/srt/entrypoints/context.py +227 -0
- sglang/srt/entrypoints/engine.py +15 -9
- sglang/srt/entrypoints/harmony_utils.py +372 -0
- sglang/srt/entrypoints/http_server.py +74 -4
- sglang/srt/entrypoints/openai/protocol.py +218 -1
- sglang/srt/entrypoints/openai/serving_chat.py +41 -11
- sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
- sglang/srt/entrypoints/openai/tool_server.py +175 -0
- sglang/srt/entrypoints/tool.py +87 -0
- sglang/srt/eplb/expert_location.py +5 -1
- sglang/srt/function_call/ebnf_composer.py +1 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +331 -0
- sglang/srt/function_call/kimik2_detector.py +3 -3
- sglang/srt/function_call/qwen3_coder_detector.py +219 -9
- sglang/srt/hf_transformers_utils.py +30 -3
- sglang/srt/jinja_template_utils.py +14 -1
- sglang/srt/layers/attention/aiter_backend.py +375 -115
- sglang/srt/layers/attention/ascend_backend.py +3 -0
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
- sglang/srt/layers/attention/flashattention_backend.py +18 -0
- sglang/srt/layers/attention/flashinfer_backend.py +52 -13
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +85 -14
- sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
- sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
- sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
- sglang/srt/layers/attention/vision.py +22 -6
- sglang/srt/layers/attention/wave_backend.py +627 -0
- sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
- sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
- sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
- sglang/srt/layers/communicator.py +29 -14
- sglang/srt/layers/dp_attention.py +12 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
- sglang/srt/layers/linear.py +3 -7
- sglang/srt/layers/moe/cutlass_moe.py +12 -3
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
- sglang/srt/layers/moe/ep_moe/layer.py +135 -73
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
- sglang/srt/layers/moe/fused_moe_triton/layer.py +412 -33
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
- sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
- sglang/srt/layers/moe/topk.py +16 -4
- sglang/srt/layers/moe/utils.py +16 -0
- sglang/srt/layers/quantization/__init__.py +27 -3
- sglang/srt/layers/quantization/fp4.py +557 -0
- sglang/srt/layers/quantization/fp8.py +3 -6
- sglang/srt/layers/quantization/fp8_kernel.py +277 -0
- sglang/srt/layers/quantization/fp8_utils.py +51 -10
- sglang/srt/layers/quantization/modelopt_quant.py +258 -68
- sglang/srt/layers/quantization/mxfp4.py +654 -0
- sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
- sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
- sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
- sglang/srt/layers/quantization/quark/utils.py +107 -0
- sglang/srt/layers/quantization/unquant.py +60 -6
- sglang/srt/layers/quantization/w4afp8.py +21 -12
- sglang/srt/layers/quantization/w8a8_int8.py +48 -34
- sglang/srt/layers/rotary_embedding.py +506 -3
- sglang/srt/layers/utils.py +9 -0
- sglang/srt/layers/vocab_parallel_embedding.py +8 -3
- sglang/srt/lora/backend/base_backend.py +3 -23
- sglang/srt/lora/layers.py +60 -114
- sglang/srt/lora/lora.py +17 -62
- sglang/srt/lora/lora_manager.py +82 -62
- sglang/srt/lora/lora_registry.py +23 -11
- sglang/srt/lora/mem_pool.py +63 -68
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/utils.py +25 -58
- sglang/srt/managers/cache_controller.py +75 -58
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +20 -8
- sglang/srt/managers/mm_utils.py +6 -13
- sglang/srt/managers/multimodal_processor.py +1 -1
- sglang/srt/managers/schedule_batch.py +61 -25
- sglang/srt/managers/schedule_policy.py +6 -6
- sglang/srt/managers/scheduler.py +41 -19
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
- sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
- sglang/srt/managers/scheduler_recv_skipper.py +37 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
- sglang/srt/managers/template_manager.py +35 -1
- sglang/srt/managers/tokenizer_manager.py +47 -30
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
- sglang/srt/mem_cache/allocator.py +61 -87
- sglang/srt/mem_cache/hicache_storage.py +1 -1
- sglang/srt/mem_cache/hiradix_cache.py +80 -22
- sglang/srt/mem_cache/lora_radix_cache.py +421 -0
- sglang/srt/mem_cache/memory_pool_host.py +34 -36
- sglang/srt/mem_cache/multimodal_cache.py +33 -13
- sglang/srt/mem_cache/radix_cache.py +2 -5
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
- sglang/srt/model_executor/cuda_graph_runner.py +29 -9
- sglang/srt/model_executor/forward_batch_info.py +61 -19
- sglang/srt/model_executor/model_runner.py +148 -37
- sglang/srt/model_loader/loader.py +18 -6
- sglang/srt/model_loader/weight_utils.py +10 -0
- sglang/srt/models/bailing_moe.py +425 -0
- sglang/srt/models/deepseek_v2.py +137 -59
- sglang/srt/models/ernie4.py +426 -0
- sglang/srt/models/ernie4_eagle.py +203 -0
- sglang/srt/models/gemma2.py +0 -34
- sglang/srt/models/gemma3n_mm.py +38 -0
- sglang/srt/models/glm4.py +6 -0
- sglang/srt/models/glm4_moe.py +28 -16
- sglang/srt/models/glm4v.py +589 -0
- sglang/srt/models/glm4v_moe.py +400 -0
- sglang/srt/models/gpt_oss.py +1251 -0
- sglang/srt/models/granite.py +0 -25
- sglang/srt/models/llama.py +0 -25
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_5_vl.py +7 -3
- sglang/srt/models/qwen2_audio.py +10 -9
- sglang/srt/models/qwen2_moe.py +6 -0
- sglang/srt/models/qwen3.py +0 -24
- sglang/srt/models/qwen3_moe.py +32 -6
- sglang/srt/models/registry.py +1 -1
- sglang/srt/models/step3_vl.py +9 -0
- sglang/srt/models/torch_native_llama.py +0 -24
- sglang/srt/models/transformers.py +2 -5
- sglang/srt/multimodal/processors/base_processor.py +23 -13
- sglang/srt/multimodal/processors/glm4v.py +132 -0
- sglang/srt/multimodal/processors/qwen_audio.py +4 -2
- sglang/srt/multimodal/processors/step3_vl.py +3 -1
- sglang/srt/reasoning_parser.py +332 -37
- sglang/srt/server_args.py +186 -75
- sglang/srt/speculative/eagle_worker.py +16 -0
- sglang/srt/two_batch_overlap.py +169 -9
- sglang/srt/utils.py +41 -5
- sglang/srt/weight_sync/tensor_bucket.py +106 -0
- sglang/test/attention/test_trtllm_mla_backend.py +186 -36
- sglang/test/doc_patch.py +59 -0
- sglang/test/few_shot_gsm8k.py +1 -1
- sglang/test/few_shot_gsm8k_engine.py +1 -1
- sglang/test/run_eval.py +4 -1
- sglang/test/runners.py +2 -2
- sglang/test/simple_eval_common.py +6 -0
- sglang/test/simple_eval_gpqa.py +2 -0
- sglang/test/test_fp4_moe.py +118 -36
- sglang/test/test_utils.py +1 -1
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +36 -38
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +174 -141
- sglang/srt/lora/backend/flashinfer_backend.py +0 -131
- /sglang/{api.py → lang/api.py} +0 -0
- /sglang/{lang/backend → srt/layers/quantization/quark}/__init__.py +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -60,12 +60,9 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import (
-    DeepEPMoE,
-    get_moe_impl_class,
-    should_use_flashinfer_trtllm_moe,
-)
+from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
 from sglang.srt.layers.moe.topk import TopK
+from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_kernel import (
@@ -211,13 +208,21 @@ class DeepseekV2MLP(nn.Module):
         )
         self.act_fn = SiluAndMul()
 
-    def forward(
+    def forward(
+        self,
+        x,
+        forward_batch=None,
+        should_allreduce_fusion: bool = False,
+        use_reduce_scatter: bool = False,
+    ):
         if (self.tp_size == 1) and x.shape[0] == 0:
             return x
 
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
-        x, _ = self.down_proj(
+        x, _ = self.down_proj(
+            x, skip_all_reduce=should_allreduce_fusion or use_reduce_scatter
+        )
         return x
 
 
@@ -307,19 +312,15 @@ class DeepseekV2MoE(nn.Module):
             config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
         )
 
-        self.topk = (
-
-
-
-
-
-
-
-
-            routed_scaling_factor=self.routed_scaling_factor,
-        )
-        if not should_use_flashinfer_trtllm_moe()
-        else None
+        self.topk = TopK(
+            top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
+            renormalize=config.norm_topk_prob,
+            use_grouped_topk=True,
+            num_expert_group=config.n_group,
+            num_fused_shared_experts=self.num_fused_shared_experts,
+            topk_group=config.topk_group,
+            correction_bias=self.gate.e_score_correction_bias,
+            routed_scaling_factor=self.routed_scaling_factor,
         )
 
         self.experts = get_moe_impl_class()(
@@ -447,7 +448,8 @@ class DeepseekV2MoE(nn.Module):
         self,
         hidden_states: torch.Tensor,
         forward_batch: Optional[ForwardBatch] = None,
-
+        should_allreduce_fusion: bool = False,
+        use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
         if not self._enable_deepep_moe:
             DUAL_STREAM_TOKEN_THRESHOLD = 1024
@@ -457,15 +459,20 @@
                 and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
             ):
                 return self.forward_normal_dual_stream(
-                    hidden_states,
+                    hidden_states, should_allreduce_fusion, use_reduce_scatter
                 )
             else:
-                return self.forward_normal(
+                return self.forward_normal(
+                    hidden_states, should_allreduce_fusion, use_reduce_scatter
+                )
         else:
             return self.forward_deepep(hidden_states, forward_batch)
 
     def forward_normal_dual_stream(
-        self,
+        self,
+        hidden_states: torch.Tensor,
+        should_allreduce_fusion: bool = False,
+        use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
 
         current_stream = torch.cuda.current_stream()
@@ -476,10 +483,14 @@ class DeepseekV2MoE(nn.Module):
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states)
             kwargs = {"hidden_states": hidden_states}
-
-
+
+            # FlashInferFP4MoE (TRTLLM path) expects (TopK, router_logits) tuple
+            # Regular FusedMoE (CUTLASS path) expects StandardTopKOutput
+            if should_use_flashinfer_trtllm_moe():
+                kwargs["topk_output"] = (self.topk, router_logits)
             else:
-                kwargs["
+                kwargs["topk_output"] = self.topk(hidden_states, router_logits)
+
             final_hidden_states = self.experts(**kwargs)
             if not _is_cuda:
                 final_hidden_states *= self.routed_scaling_factor
@@ -489,26 +500,33 @@ class DeepseekV2MoE(nn.Module):
             torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
             final_hidden_states = final_hidden_states_out
             sm.tag(final_hidden_states)
-        if self.tp_size > 1 and not
+        if self.tp_size > 1 and not should_allreduce_fusion and not use_reduce_scatter:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states
 
     def forward_normal(
-        self,
+        self,
+        hidden_states: torch.Tensor,
+        should_allreduce_fusion: bool = False,
+        use_reduce_scatter: bool = False,
    ) -> torch.Tensor:
        if hasattr(self, "shared_experts") and use_intel_amx_backend(
            self.shared_experts.gate_up_proj
        ):
-            return self.forward_cpu(hidden_states,
+            return self.forward_cpu(hidden_states, should_allreduce_fusion)
 
        shared_output = self._forward_shared_experts(hidden_states)
        # router_logits: (num_tokens, n_experts)
        router_logits = self.gate(hidden_states)
        kwargs = {"hidden_states": hidden_states}
-
-
+
+        # FlashInferFP4MoE (TRTLLM path) expects (TopK, router_logits) tuple
+        # Regular FusedMoE (CUTLASS path) expects StandardTopKOutput
+        if should_use_flashinfer_trtllm_moe():
+            kwargs["topk_output"] = (self.topk, router_logits)
        else:
-            kwargs["
+            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
+
        final_hidden_states = self.experts(**kwargs)
        if not _is_cuda and not _use_aiter:
            # fused in biased_grouped_topk so we can skip here
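Note: the routing change above (applied identically in forward_normal_dual_stream and forward_normal) relies on the two topk_output shapes named in the new comments. The following minimal sketch only illustrates that dispatch pattern; run_moe and its parameters are stand-ins, not sglang's actual modules or call signatures.

# Illustrative sketch only (not sglang code): dispatching the two
# topk_output shapes described in the comments above.
def run_moe(hidden_states, gate, topk, experts, use_trtllm_moe: bool):
    router_logits = gate(hidden_states)  # (num_tokens, n_experts)
    kwargs = {"hidden_states": hidden_states}
    if use_trtllm_moe:
        # TRT-LLM path: pass the un-evaluated TopK module together with the
        # raw router logits, and let the expert layer consume them itself.
        kwargs["topk_output"] = (topk, router_logits)
    else:
        # CUTLASS/Triton path: compute the top-k routing result up front.
        kwargs["topk_output"] = topk(hidden_states, router_logits)
    return experts(**kwargs)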
@@ -519,12 +537,14 @@ class DeepseekV2MoE(nn.Module):
             torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
             final_hidden_states = final_hidden_states_out
             sm.tag(final_hidden_states)
-        if self.tp_size > 1 and not
+        if self.tp_size > 1 and not should_allreduce_fusion and not use_reduce_scatter:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states
 
     def forward_cpu(
-        self,
+        self,
+        hidden_states: torch.Tensor,
+        should_allreduce_fusion: bool = False,
     ) -> torch.Tensor:
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
@@ -575,7 +595,7 @@ class DeepseekV2MoE(nn.Module):
             None, # a2_scale
             True, # is_vnni
         )
-        if self.tp_size > 1 and not
+        if self.tp_size > 1 and not should_allreduce_fusion:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states
 
@@ -1176,6 +1196,16 @@ class DeepseekV2AttentionMLA(nn.Module):
         output, _ = self.o_proj(attn_output)
         return output
 
+    def _fuse_rope_for_trtllm_mla(self, forward_batch: ForwardBatch) -> bool:
+        """
+        Check if we should skip rope and do fused rope+quantize for TRTLLM MLA decode in fp8_e4m3 path.
+        """
+        return (
+            self.current_attention_backend == "trtllm_mla"
+            and forward_batch.forward_mode.is_decode_or_idle()
+            and forward_batch.attn_backend.data_type == torch.float8_e4m3fn
+        )
+
     def forward_absorb_prepare(
         self,
         positions: torch.Tensor,
@@ -1255,7 +1285,9 @@ class DeepseekV2AttentionMLA(nn.Module):
         q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
 
         q_nope_out = q_nope_out.transpose(0, 1)
-
+
+        if not self._fuse_rope_for_trtllm_mla(forward_batch):
+            q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
 
         return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
 
@@ -1268,8 +1300,20 @@ class DeepseekV2AttentionMLA(nn.Module):
             or self.current_attention_backend == "cutlass_mla"
             or self.current_attention_backend == "trtllm_mla"
         ):
+            extra_args = {}
+            if self._fuse_rope_for_trtllm_mla(forward_batch):
+                extra_args = {
+                    "cos_sin_cache": self.rotary_emb.cos_sin_cache,
+                    "is_neox": self.rotary_emb.is_neox_style,
+                }
             attn_output = self.attn_mqa(
-                q_nope_out,
+                q_nope_out,
+                k_nope,
+                k_nope,
+                forward_batch,
+                q_rope=q_pe,
+                k_rope=k_pe,
+                **extra_args,
             )
         else:
             q = torch.cat([q_nope_out, q_pe], dim=-1)
@@ -1821,8 +1865,11 @@ class DeepseekV2DecoderLayer(nn.Module):
             layer_scatter_modes=self.layer_scatter_modes,
             input_layernorm=self.input_layernorm,
             post_attention_layernorm=self.post_attention_layernorm,
+            allow_reduce_scatter=True,
         )
 
+        self._fuse_allreduce_lookup_table = self._build_fuse_allreduce_lookup_table()
+
     def _is_layer_sparse(self, layer_id: int, is_nextn: bool) -> bool:
         return is_nextn or (
             self.config.n_routed_experts is not None
@@ -1831,27 +1878,18 @@ class DeepseekV2DecoderLayer(nn.Module):
         )
 
     def _should_fuse_mlp_allreduce_with_next_layer(self, forward_batch) -> bool:
-        """Check if MLP allreduce can be fused with next layer's
+        """Check if MLP allreduce can be fused with next layer's residual_rmsnorm"""
 
-
-
-
-
-
-
-        if not global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False):
-            return False
-
-        if not _is_sm100_supported or not _is_flashinfer_available:
-            return False
+        batch_size = (
+            forward_batch.input_ids.shape[0]
+            if hasattr(forward_batch, "input_ids")
+            else 0
+        )
 
-        if
-            forward_batch.input_ids.shape[0] == 0
-            or forward_batch.input_ids.shape[0] > 128
-        ):
+        if batch_size > 128:
             return False
 
-        return
+        return self._fuse_allreduce_lookup_table.get(batch_size, False)
 
     def forward(
         self,
@@ -1877,18 +1915,24 @@ class DeepseekV2DecoderLayer(nn.Module):
             hidden_states, residual, forward_batch
         )
 
-
+        should_allreduce_fusion = (
             self._should_fuse_mlp_allreduce_with_next_layer(forward_batch)
             and not (self.enable_dp_attention and self.speculative_algorithm.is_eagle())
             and not self.is_nextn
         )
 
-
+        # For DP with padding, reduce scatter can be used instead of all-reduce.
+        use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
+            forward_batch
+        )
+        hidden_states = self.mlp(
+            hidden_states, forward_batch, should_allreduce_fusion, use_reduce_scatter
+        )
 
-        if
+        if should_allreduce_fusion:
             hidden_states._sglang_needs_allreduce_fusion = True
 
-        if not
+        if not should_allreduce_fusion:
             hidden_states, residual = self.layer_communicator.postprocess_layer(
                 hidden_states, residual, forward_batch
             )
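Note: the use_reduce_scatter flag threaded through forward and the MLP follows the new comment: with DP padding, every rank holds the same number of tokens, so the MLP output can be combined with a reduce-scatter (each rank keeps one shard) rather than a full all-reduce, and the down-projection skips its internal all-reduce. The torch.distributed sketch below is purely illustrative of that collective swap, not sglang's implementation; it assumes the process group is already initialized and the row dimension is padded to a multiple of the world size.

# Illustrative only: reduce-scatter vs. all-reduce for an evenly padded activation.
import torch
import torch.distributed as dist

def combine_partial_output(partial: torch.Tensor, use_reduce_scatter: bool) -> torch.Tensor:
    if use_reduce_scatter:
        world_size = dist.get_world_size()
        # Each rank receives only its 1/world_size slice of the summed rows.
        out = torch.empty(
            partial.shape[0] // world_size, *partial.shape[1:],
            dtype=partial.dtype, device=partial.device,
        )
        dist.reduce_scatter_tensor(out, partial)
        return out
    # Full all-reduce: every rank ends up with the complete summed tensor.
    dist.all_reduce(partial)
    return partial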
@@ -1965,6 +2009,26 @@ class DeepseekV2DecoderLayer(nn.Module):
         )
         return output
 
+    def _build_fuse_allreduce_lookup_table(self):
+        static_conditions_met = (
+            self.layer_id != self.config.num_hidden_layers - 1
+            and get_tensor_model_parallel_world_size() > 1
+            and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
+            and _is_sm100_supported
+            and _is_flashinfer_available
+        )
+
+        if not static_conditions_met:
+            return {}
+
+        lookup_table = {}
+        for batch_size in range(129): # 0 to 128
+            is_last_layer = self.layer_id == self.config.num_hidden_layers - 1
+            should_fuse = batch_size > 0 and batch_size <= 128 and not is_last_layer
+            lookup_table[batch_size] = should_fuse
+
+        return lookup_table
+
 
 class DeepseekV2Model(nn.Module):
     fall_back_to_pt_during_load = False
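Note: _build_fuse_allreduce_lookup_table above evaluates the static eligibility checks (not the last layer, TP world size > 1, the FlashInfer allreduce-fusion flag, SM100 hardware, FlashInfer availability) once at construction time, so the per-forward check in _should_fuse_mlp_allreduce_with_next_layer collapses to a bounds check plus a dict lookup keyed by batch size. The standalone sketch below shows the same precompute-then-lookup pattern; the function names are illustrative, not sglang internals.

# Illustrative sketch of the precompute-then-lookup pattern used above.
def build_fuse_lookup_table(static_conditions_met: bool) -> dict:
    if not static_conditions_met:
        return {}
    # Batch sizes 1..128 are eligible for fusion; 0 is not.
    return {batch_size: batch_size > 0 for batch_size in range(129)}

def should_fuse_allreduce(table: dict, batch_size: int) -> bool:
    if batch_size > 128:  # outside the precomputed range
        return False
    return table.get(batch_size, False)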
@@ -2060,6 +2124,8 @@ class DeepseekV2Model(nn.Module):
 
 
 class DeepseekV2ForCausalLM(nn.Module):
+    # for quark model load
+    packed_modules_mapping = {}
 
     def __init__(
         self,
@@ -2068,6 +2134,18 @@ class DeepseekV2ForCausalLM(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+
+        # for quark model load
+        # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
+        self.fuse_qkv_a_proj = (
+            hasattr(config, "q_lora_rank") and config.q_lora_rank is not None
+        )
+        if self.fuse_qkv_a_proj:
+            self.packed_modules_mapping["fused_qkv_a_proj_with_mqa"] = [
+                "q_a_proj",
+                "kv_a_proj_with_mqa",
+            ]
+
         self.config = config
         self.tp_size = get_tensor_model_parallel_world_size()
         self.quant_config = quant_config