sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl
This diff compares publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +4 -2
- sglang/bench_one_batch.py +3 -13
- sglang/bench_one_batch_server.py +143 -15
- sglang/bench_serving.py +158 -8
- sglang/compile_deep_gemm.py +1 -1
- sglang/eval/loogle_eval.py +157 -0
- sglang/lang/chat_template.py +119 -75
- sglang/lang/tracer.py +1 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +5 -2
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/internvl.py +696 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +18 -0
- sglang/srt/constrained/base_grammar_backend.py +55 -72
- sglang/srt/constrained/llguidance_backend.py +25 -21
- sglang/srt/constrained/outlines_backend.py +27 -26
- sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
- sglang/srt/constrained/xgrammar_backend.py +71 -53
- sglang/srt/conversation.py +78 -46
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +11 -3
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +74 -23
- sglang/srt/disaggregation/mooncake/conn.py +236 -138
- sglang/srt/disaggregation/nixl/conn.py +242 -71
- sglang/srt/disaggregation/prefill.py +7 -4
- sglang/srt/disaggregation/utils.py +51 -2
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
- sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
- sglang/srt/distributed/device_communicators/pynccl.py +2 -1
- sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
- sglang/srt/distributed/parallel_state.py +22 -1
- sglang/srt/entrypoints/engine.py +31 -4
- sglang/srt/entrypoints/http_server.py +45 -3
- sglang/srt/entrypoints/verl_engine.py +3 -2
- sglang/srt/function_call_parser.py +2 -2
- sglang/srt/hf_transformers_utils.py +20 -1
- sglang/srt/layers/attention/flashattention_backend.py +147 -51
- sglang/srt/layers/attention/flashinfer_backend.py +23 -13
- sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
- sglang/srt/layers/attention/merge_state.py +46 -0
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
- sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
- sglang/srt/layers/attention/utils.py +4 -2
- sglang/srt/layers/attention/vision.py +290 -163
- sglang/srt/layers/dp_attention.py +71 -21
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/logits_processor.py +46 -11
- sglang/srt/layers/moe/ep_moe/kernels.py +343 -8
- sglang/srt/layers/moe/ep_moe/layer.py +121 -2
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/topk.py +1 -1
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/blockwise_int8.py +2 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
- sglang/srt/layers/quantization/deep_gemm.py +77 -71
- sglang/srt/layers/quantization/fp8.py +110 -97
- sglang/srt/layers/quantization/fp8_kernel.py +81 -62
- sglang/srt/layers/quantization/fp8_utils.py +71 -23
- sglang/srt/layers/quantization/int8_kernel.py +2 -2
- sglang/srt/layers/quantization/kv_cache.py +3 -10
- sglang/srt/layers/quantization/utils.py +0 -5
- sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
- sglang/srt/layers/sampler.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +18 -7
- sglang/srt/lora/lora_manager.py +11 -14
- sglang/srt/lora/mem_pool.py +4 -4
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +115 -119
- sglang/srt/managers/data_parallel_controller.py +3 -3
- sglang/srt/managers/detokenizer_manager.py +21 -8
- sglang/srt/managers/io_struct.py +13 -1
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
- sglang/srt/managers/multimodal_processors/internvl.py +232 -0
- sglang/srt/managers/multimodal_processors/llava.py +46 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
- sglang/srt/managers/schedule_batch.py +93 -23
- sglang/srt/managers/schedule_policy.py +11 -8
- sglang/srt/managers/scheduler.py +140 -100
- sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
- sglang/srt/managers/tokenizer_manager.py +157 -47
- sglang/srt/managers/tp_worker.py +21 -21
- sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
- sglang/srt/mem_cache/chunk_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +4 -2
- sglang/srt/metrics/collector.py +312 -37
- sglang/srt/model_executor/cuda_graph_runner.py +10 -11
- sglang/srt/model_executor/forward_batch_info.py +1 -1
- sglang/srt/model_executor/model_runner.py +57 -41
- sglang/srt/model_loader/loader.py +18 -11
- sglang/srt/models/clip.py +4 -4
- sglang/srt/models/deepseek_janus_pro.py +3 -3
- sglang/srt/models/deepseek_nextn.py +1 -20
- sglang/srt/models/deepseek_v2.py +77 -39
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/internlm2.py +3 -0
- sglang/srt/models/internvl.py +670 -0
- sglang/srt/models/llama.py +3 -1
- sglang/srt/models/llama4.py +58 -13
- sglang/srt/models/llava.py +248 -5
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mixtral.py +98 -34
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/phi3_small.py +16 -2
- sglang/srt/models/pixtral.py +467 -0
- sglang/srt/models/qwen2_5_vl.py +8 -4
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/roberta.py +1 -1
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/xiaomi_mimo.py +171 -0
- sglang/srt/openai_api/adapter.py +52 -42
- sglang/srt/openai_api/protocol.py +20 -16
- sglang/srt/reasoning_parser.py +1 -1
- sglang/srt/sampling/custom_logit_processor.py +18 -3
- sglang/srt/sampling/sampling_batch_info.py +2 -2
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +64 -10
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +7 -7
- sglang/srt/speculative/eagle_worker.py +22 -19
- sglang/srt/utils.py +41 -6
- sglang/test/few_shot_gsm8k.py +2 -2
- sglang/test/few_shot_gsm8k_engine.py +2 -2
- sglang/test/run_eval.py +2 -2
- sglang/test/runners.py +8 -1
- sglang/test/send_one.py +13 -3
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +1 -1
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deepep_utils.py +219 -0
- sglang/test/test_programs.py +5 -5
- sglang/test/test_utils.py +92 -15
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +18 -9
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +150 -137
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +1 -1
- /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -36,13 +36,13 @@ from sglang.srt.distributed import (
 )
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.dp_attention import (
+    attn_tp_all_gather,
+    attn_tp_reduce_scatter,
     dp_gather_partial,
     dp_scatter,
-    get_attention_dp_size,
     get_attention_tp_rank,
     get_attention_tp_size,
-
-    tp_reduce_scatter,
+    get_local_attention_dp_size,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -59,10 +59,11 @@ from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
 from sglang.srt.layers.quantization.fp8_kernel import (
-    per_tensor_quant_mla_deep_gemm_masked_fp8,
     per_tensor_quant_mla_fp8,
+    per_token_group_quant_mla_deep_gemm_masked_fp8,
 )
 from sglang.srt.layers.quantization.fp8_utils import (
+    block_quant_dequant,
     block_quant_to_tensor_quant,
     channel_quant_to_tensor_quant,
     normalize_e4m3fn_to_e4m3fnuz,
@@ -88,6 +89,7 @@ from sglang.srt.utils import (
     get_int_env_var,
     is_cuda,
     is_hip,
+    log_info_on_rank0,
 )

 _is_hip = is_hip()
@@ -356,6 +358,7 @@ class DeepseekV2MoE(nn.Module):
             topk_idx,
             topk_weights,
             reorder_topk_ids,
+            num_recv_tokens_per_expert,
             seg_indptr,
             masked_m,
             expected_m,
@@ -367,10 +370,13 @@ class DeepseekV2MoE(nn.Module):
         )
         final_hidden_states = self.experts(
             hidden_states=hidden_states,
+            topk_idx=topk_idx,
+            topk_weights=topk_weights,
             reorder_topk_ids=reorder_topk_ids,
             seg_indptr=seg_indptr,
             masked_m=masked_m,
             expected_m=expected_m,
+            num_recv_tokens_per_expert=num_recv_tokens_per_expert,
             forward_mode=forward_mode,
         )
         if self.ep_size > 1:
@@ -421,6 +427,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         reduce_results: bool = True,
         layer_id: int = None,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.layer_id = layer_id
@@ -431,7 +438,6 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.v_head_dim = v_head_dim
         self.q_lora_rank = q_lora_rank
         self.kv_lora_rank = kv_lora_rank
-        self.dp_size = get_attention_dp_size()
         attn_tp_rank = get_attention_tp_rank()
         attn_tp_size = get_attention_tp_size()

@@ -543,6 +549,8 @@ class DeepseekV2AttentionMLA(nn.Module):
             prefix=add_prefix("attn_mha", prefix),
         )

+        self.alt_stream = alt_stream
+
         self.w_kc = None
         self.w_vc = None
         self.w_scale = None
@@ -706,20 +714,36 @@ class DeepseekV2AttentionMLA(nn.Module):
             q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
                 [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
             )
-
+            k_nope = latent_cache[..., : self.kv_lora_rank]
+
+            # overlap qk norm
+            if self.alt_stream is not None and torch.cuda.is_current_stream_capturing():
+                current_stream = torch.cuda.current_stream()
+                self.alt_stream.wait_stream(current_stream)
+                q = self.q_a_layernorm(q)
+                with torch.cuda.stream(self.alt_stream):
+                    k_nope = self.kv_a_layernorm(k_nope)
+                current_stream.wait_stream(self.alt_stream)
+            else:
+                q = self.q_a_layernorm(q)
+                k_nope = self.kv_a_layernorm(k_nope)
+
+            k_nope = k_nope.unsqueeze(1)
             q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
         else:
             q = self.q_proj(hidden_states)[0].view(
                 -1, self.num_local_heads, self.qk_head_dim
             )
             latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
+            k_nope = latent_cache[..., : self.kv_lora_rank]
+            k_nope = self.kv_a_layernorm(k_nope).unsqueeze(1)
+
         q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+        k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)

         if self.use_deep_gemm_bmm:
             q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = (
-
-                    q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn
-                )
+                per_token_group_quant_mla_deep_gemm_masked_fp8(q_nope.transpose(0, 1))
             )
             q_nope_out = q_nope.new_empty(
                 (self.num_local_heads, aligned_m, self.kv_lora_rank)
@@ -750,14 +774,9 @@ class DeepseekV2AttentionMLA(nn.Module):
             q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)

         q_nope_out = q_nope_out.transpose(0, 1)
-
-        k_nope = latent_cache[..., : self.kv_lora_rank]
-        k_nope = self.kv_a_layernorm(k_nope.contiguous()).unsqueeze(1)
-        k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)
-
         q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)

-        if self.attention_backend == "fa3":
+        if self.attention_backend == "fa3" or self.attention_backend == "flashinfer":
             attn_output = self.attn_mqa(
                 q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
             )
@@ -769,8 +788,8 @@ class DeepseekV2AttentionMLA(nn.Module):

         if self.use_deep_gemm_bmm:
             attn_output_val, attn_output_scale, masked_m, expected_m, aligned_m = (
-
-                    attn_output.transpose(0, 1)
+                per_token_group_quant_mla_deep_gemm_masked_fp8(
+                    attn_output.transpose(0, 1)
                 )
             )
             attn_bmm_output = attn_output.new_empty(
@@ -1104,6 +1123,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         is_nextn: bool = False,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -1112,7 +1132,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
         self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
         self.layer_id = layer_id
-        self.
+        self.local_dp_size = get_local_attention_dp_size()
         self.attn_tp_size = get_attention_tp_size()
         self.attn_tp_rank = get_attention_tp_rank()
         self.self_attn = DeepseekV2AttentionMLA(
@@ -1133,6 +1153,7 @@ class DeepseekV2DecoderLayer(nn.Module):
             layer_id=layer_id,
             reduce_results=False,
             prefix=add_prefix("self_attn", prefix),
+            alt_stream=alt_stream,
         )

         self.info = self._compute_info(config, layer_id=layer_id, is_nextn=is_nextn)
@@ -1162,7 +1183,8 @@ class DeepseekV2DecoderLayer(nn.Module):
         )

         self.input_is_scattered = (
-
+            layer_id > 0
+            and previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
         )
         self.is_last_layer = self.layer_id == config.num_hidden_layers - 1

@@ -1242,7 +1264,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         # Gather
         if get_tensor_model_parallel_world_size() > 1:
             # all gather and all reduce
-            if self.
+            if self.local_dp_size != 1:
                 if self.attn_tp_rank == 0:
                     hidden_states += residual
                 hidden_states, local_hidden_states = (
@@ -1265,9 +1287,9 @@ class DeepseekV2DecoderLayer(nn.Module):
         # Fully Connected
         hidden_states = self.mlp(hidden_states)

-        # TODO(ch-wan):
+        # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
         # Scatter
-        if self.
+        if self.local_dp_size != 1:
             # important: forward batch.gathered_buffer is used both after scatter and after gather.
             # be careful about this!
             hidden_states, global_hidden_states = (
@@ -1301,7 +1323,7 @@ class DeepseekV2DecoderLayer(nn.Module):
                 forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
                 hidden_states,
             )
-
+            attn_tp_all_gather(
                 list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
             )

@@ -1317,7 +1339,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         if self.input_is_scattered:
             tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
             hidden_states = tensor_list[self.attn_tp_rank]
-
+            attn_tp_reduce_scatter(hidden_states, tensor_list)
             if hidden_states.shape[0] != 0:
                 hidden_states, residual = self.post_attention_layernorm(
                     hidden_states, residual
@@ -1327,7 +1349,7 @@ class DeepseekV2DecoderLayer(nn.Module):
                 hidden_states += residual
             tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
             hidden_states = tensor_list[self.attn_tp_rank]
-
+            attn_tp_reduce_scatter(hidden_states, tensor_list)
             residual = hidden_states
             if hidden_states.shape[0] != 0:
                 hidden_states = self.post_attention_layernorm(hidden_states)
@@ -1351,7 +1373,7 @@ class DeepseekV2DecoderLayer(nn.Module):
                 forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
                 hidden_states,
             )
-
+            attn_tp_all_gather(
                 list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
             )

@@ -1376,6 +1398,7 @@ class DeepseekV2Model(nn.Module):
             config.hidden_size,
             enable_tp=not global_server_args_dict["enable_dp_attention"],
         )
+        self.alt_stream = torch.cuda.Stream()
         self.layers = nn.ModuleList(
             [
                 DeepseekV2DecoderLayer(
@@ -1383,13 +1406,14 @@ class DeepseekV2Model(nn.Module):
                     layer_id,
                     quant_config=quant_config,
                     prefix=add_prefix(f"layers.{layer_id}", prefix),
+                    alt_stream=self.alt_stream,
                 )
                 for layer_id in range(config.num_hidden_layers)
             ]
         )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

-        self.dp_size =
+        self.dp_size = get_local_attention_dp_size()

     def get_input_embeddings(self) -> torch.Tensor:
         return self.embed_tokens
@@ -1451,9 +1475,10 @@ class DeepseekV2ForCausalLM(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("lm_head", prefix),
+            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
         )
         self.logits_processor = LogitsProcessor(config)
-        self.dp_size =
+        self.dp_size = get_local_attention_dp_size()

     def determine_n_share_experts_fusion(
         self, architecture: str = "DeepseekV3ForCausalLM"
@@ -1462,29 +1487,33 @@ class DeepseekV2ForCausalLM(nn.Module):
         if self.n_share_experts_fusion > 0:
             # Only Deepseek V3/R1 can use shared experts fusion optimization now.
             if (
-
+                not _is_cuda
+                or self.config.architectures[0] != architecture
                 or self.config.n_routed_experts != 256
             ):
                 self.n_share_experts_fusion = 0
                 global_server_args_dict["n_share_experts_fusion"] = 0
-
-
+                log_info_on_rank0(
+                    logger,
+                    "Only Deepseek V3/R1 on NV-platform can use shared experts fusion optimization. Shared experts fusion optimization is disabled.",
                 )
             else:
                 assert (
                     self.n_share_experts_fusion == self.tp_size
-                ), f"Shared experts fusion optimization is enabled in DeepSeek V3/R1, set it to {self.tp_size} can get best optimized
+                ), f"Shared experts fusion optimization is enabled in DeepSeek V3/R1, set it to {self.tp_size} can get best optimized performance."
         elif self.n_share_experts_fusion == 0:
             if (
-
+                _is_cuda
+                and torch.cuda.get_device_capability("cuda") >= (9, 0)
                 and self.config.architectures[0] == architecture
                 and self.config.n_routed_experts == 256
                 and (not global_server_args_dict["enable_deepep_moe"])
             ):
                 self.n_share_experts_fusion = self.tp_size
                 global_server_args_dict["n_share_experts_fusion"] = self.tp_size
-
-
+                log_info_on_rank0(
+                    logger,
+                    "Deepseek V3/R1 with fp8 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled.",
                 )

     def get_input_embeddings(self) -> nn.Embedding:
@@ -1564,13 +1593,22 @@ class DeepseekV2ForCausalLM(nn.Module):

         if (
             _is_cuda
-            and _ENABLE_JIT_DEEPGEMM
             and weight_block_size[0] == 128
             and weight_block_size[1] == 128
             and model_dtype == torch.bfloat16
         ):
-
-
+            if _ENABLE_JIT_DEEPGEMM and get_bool_env_var(
+                "SGL_USE_DEEPGEMM_BMM", "false"
+            ):
+                block_scale = weight_scale
+                use_deep_gemm_bmm = True
+            else:
+                w = block_quant_dequant(
+                    weight,
+                    weight_scale,
+                    weight_block_size,
+                    model_dtype,
+                )
         else:
             w, scale = block_quant_to_tensor_quant(
                 weight, weight_scale, weight_block_size
@@ -1628,7 +1666,7 @@ class DeepseekV2ForCausalLM(nn.Module):
         if is_nextn:
             if hasattr(self.config, "num_nextn_predict_layers"):
                 num_nextn_layers = self.config.num_nextn_predict_layers
-                assert num_nextn_layers == 1, "Only 1 nextn layer is
+                assert num_nextn_layers == 1, "Only 1 nextn layer is supported"
                 # compatible with old design
                 nextn_layer_id = (
                     0
sglang/srt/models/gemma3_mm.py
CHANGED
@@ -281,7 +281,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         pixel_values = torch.stack(
             flatten_nested_list([item.pixel_values for item in items]), dim=0
         )
-        pixel_values = pixel_values.to(
+        pixel_values = pixel_values.to(device=self.vision_tower.device)
         pixel_values = pixel_values.to(dtype=self.language_model.dtype())

         vision_outputs = self.vision_tower(pixel_values=pixel_values).last_hidden_state
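The gemma3_mm.py hunk makes the image-tensor handoff explicit: move pixel_values to the vision tower's device, then cast to the language model's compute dtype before the vision forward. A minimal sketch of that pattern, assuming ordinary nn.Module handles (the helper name and the use of parameters() to discover device/dtype are illustrative, not the sglang code):

import torch

def prepare_pixel_values(pixel_values, vision_tower, language_model):
    # Move the stacked image tensor to the vision encoder's device first,
    # then cast to the language model's dtype, mirroring the hunk above.
    pixel_values = pixel_values.to(device=next(vision_tower.parameters()).device)
    pixel_values = pixel_values.to(dtype=next(language_model.parameters()).dtype)
    return pixel_values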
sglang/srt/models/internlm2.py
CHANGED