sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. sglang/bench_offline_throughput.py +4 -2
  2. sglang/bench_one_batch.py +3 -13
  3. sglang/bench_one_batch_server.py +143 -15
  4. sglang/bench_serving.py +158 -8
  5. sglang/compile_deep_gemm.py +1 -1
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +119 -75
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +5 -2
  11. sglang/srt/configs/device_config.py +1 -1
  12. sglang/srt/configs/internvl.py +696 -0
  13. sglang/srt/configs/janus_pro.py +3 -0
  14. sglang/srt/configs/model_config.py +18 -0
  15. sglang/srt/constrained/base_grammar_backend.py +55 -72
  16. sglang/srt/constrained/llguidance_backend.py +25 -21
  17. sglang/srt/constrained/outlines_backend.py +27 -26
  18. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  19. sglang/srt/constrained/xgrammar_backend.py +71 -53
  20. sglang/srt/conversation.py +78 -46
  21. sglang/srt/disaggregation/base/conn.py +1 -0
  22. sglang/srt/disaggregation/decode.py +11 -3
  23. sglang/srt/disaggregation/fake/conn.py +1 -1
  24. sglang/srt/disaggregation/mini_lb.py +74 -23
  25. sglang/srt/disaggregation/mooncake/conn.py +236 -138
  26. sglang/srt/disaggregation/nixl/conn.py +242 -71
  27. sglang/srt/disaggregation/prefill.py +7 -4
  28. sglang/srt/disaggregation/utils.py +51 -2
  29. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
  30. sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
  31. sglang/srt/distributed/device_communicators/pynccl.py +2 -1
  32. sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
  33. sglang/srt/distributed/parallel_state.py +22 -1
  34. sglang/srt/entrypoints/engine.py +31 -4
  35. sglang/srt/entrypoints/http_server.py +45 -3
  36. sglang/srt/entrypoints/verl_engine.py +3 -2
  37. sglang/srt/function_call_parser.py +2 -2
  38. sglang/srt/hf_transformers_utils.py +20 -1
  39. sglang/srt/layers/attention/flashattention_backend.py +147 -51
  40. sglang/srt/layers/attention/flashinfer_backend.py +23 -13
  41. sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
  42. sglang/srt/layers/attention/merge_state.py +46 -0
  43. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  44. sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
  45. sglang/srt/layers/attention/utils.py +4 -2
  46. sglang/srt/layers/attention/vision.py +290 -163
  47. sglang/srt/layers/dp_attention.py +71 -21
  48. sglang/srt/layers/layernorm.py +1 -1
  49. sglang/srt/layers/logits_processor.py +46 -11
  50. sglang/srt/layers/moe/ep_moe/kernels.py +343 -8
  51. sglang/srt/layers/moe/ep_moe/layer.py +121 -2
  52. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  56. sglang/srt/layers/moe/topk.py +1 -1
  57. sglang/srt/layers/quantization/__init__.py +1 -1
  58. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  59. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
  60. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
  61. sglang/srt/layers/quantization/deep_gemm.py +77 -71
  62. sglang/srt/layers/quantization/fp8.py +110 -97
  63. sglang/srt/layers/quantization/fp8_kernel.py +81 -62
  64. sglang/srt/layers/quantization/fp8_utils.py +71 -23
  65. sglang/srt/layers/quantization/int8_kernel.py +2 -2
  66. sglang/srt/layers/quantization/kv_cache.py +3 -10
  67. sglang/srt/layers/quantization/utils.py +0 -5
  68. sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
  69. sglang/srt/layers/sampler.py +0 -4
  70. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  71. sglang/srt/lora/lora_manager.py +11 -14
  72. sglang/srt/lora/mem_pool.py +4 -4
  73. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  74. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  75. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  76. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  77. sglang/srt/lora/utils.py +1 -1
  78. sglang/srt/managers/cache_controller.py +115 -119
  79. sglang/srt/managers/data_parallel_controller.py +3 -3
  80. sglang/srt/managers/detokenizer_manager.py +21 -8
  81. sglang/srt/managers/io_struct.py +13 -1
  82. sglang/srt/managers/mm_utils.py +1 -1
  83. sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
  84. sglang/srt/managers/multimodal_processors/internvl.py +232 -0
  85. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  86. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  87. sglang/srt/managers/schedule_batch.py +93 -23
  88. sglang/srt/managers/schedule_policy.py +11 -8
  89. sglang/srt/managers/scheduler.py +140 -100
  90. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  91. sglang/srt/managers/tokenizer_manager.py +157 -47
  92. sglang/srt/managers/tp_worker.py +21 -21
  93. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  94. sglang/srt/mem_cache/chunk_cache.py +2 -0
  95. sglang/srt/mem_cache/memory_pool.py +4 -2
  96. sglang/srt/metrics/collector.py +312 -37
  97. sglang/srt/model_executor/cuda_graph_runner.py +10 -11
  98. sglang/srt/model_executor/forward_batch_info.py +1 -1
  99. sglang/srt/model_executor/model_runner.py +57 -41
  100. sglang/srt/model_loader/loader.py +18 -11
  101. sglang/srt/models/clip.py +4 -4
  102. sglang/srt/models/deepseek_janus_pro.py +3 -3
  103. sglang/srt/models/deepseek_nextn.py +1 -20
  104. sglang/srt/models/deepseek_v2.py +77 -39
  105. sglang/srt/models/gemma3_mm.py +1 -1
  106. sglang/srt/models/internlm2.py +3 -0
  107. sglang/srt/models/internvl.py +670 -0
  108. sglang/srt/models/llama.py +3 -1
  109. sglang/srt/models/llama4.py +58 -13
  110. sglang/srt/models/llava.py +248 -5
  111. sglang/srt/models/minicpmv.py +1 -1
  112. sglang/srt/models/mixtral.py +98 -34
  113. sglang/srt/models/mllama.py +1 -1
  114. sglang/srt/models/phi3_small.py +16 -2
  115. sglang/srt/models/pixtral.py +467 -0
  116. sglang/srt/models/qwen2_5_vl.py +8 -4
  117. sglang/srt/models/qwen2_vl.py +4 -4
  118. sglang/srt/models/roberta.py +1 -1
  119. sglang/srt/models/torch_native_llama.py +1 -1
  120. sglang/srt/models/xiaomi_mimo.py +171 -0
  121. sglang/srt/openai_api/adapter.py +52 -42
  122. sglang/srt/openai_api/protocol.py +20 -16
  123. sglang/srt/reasoning_parser.py +1 -1
  124. sglang/srt/sampling/custom_logit_processor.py +18 -3
  125. sglang/srt/sampling/sampling_batch_info.py +2 -2
  126. sglang/srt/sampling/sampling_params.py +2 -0
  127. sglang/srt/server_args.py +64 -10
  128. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  129. sglang/srt/speculative/eagle_utils.py +7 -7
  130. sglang/srt/speculative/eagle_worker.py +22 -19
  131. sglang/srt/utils.py +41 -6
  132. sglang/test/few_shot_gsm8k.py +2 -2
  133. sglang/test/few_shot_gsm8k_engine.py +2 -2
  134. sglang/test/run_eval.py +2 -2
  135. sglang/test/runners.py +8 -1
  136. sglang/test/send_one.py +13 -3
  137. sglang/test/simple_eval_common.py +1 -1
  138. sglang/test/simple_eval_humaneval.py +1 -1
  139. sglang/test/test_block_fp8.py +2 -2
  140. sglang/test/test_deepep_utils.py +219 -0
  141. sglang/test/test_programs.py +5 -5
  142. sglang/test/test_utils.py +92 -15
  143. sglang/utils.py +1 -1
  144. sglang/version.py +1 -1
  145. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +18 -9
  146. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +150 -137
  147. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +1 -1
  148. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  149. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0
The hunks below are from sglang/srt/models/deepseek_v2.py.

@@ -36,13 +36,13 @@ from sglang.srt.distributed import (
 )
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.dp_attention import (
+    attn_tp_all_gather,
+    attn_tp_reduce_scatter,
     dp_gather_partial,
     dp_scatter,
-    get_attention_dp_size,
     get_attention_tp_rank,
     get_attention_tp_size,
-    tp_all_gather,
-    tp_reduce_scatter,
+    get_local_attention_dp_size,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -59,10 +59,11 @@ from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
 from sglang.srt.layers.quantization.fp8_kernel import (
-    per_tensor_quant_mla_deep_gemm_masked_fp8,
     per_tensor_quant_mla_fp8,
+    per_token_group_quant_mla_deep_gemm_masked_fp8,
 )
 from sglang.srt.layers.quantization.fp8_utils import (
+    block_quant_dequant,
     block_quant_to_tensor_quant,
     channel_quant_to_tensor_quant,
     normalize_e4m3fn_to_e4m3fnuz,
@@ -88,6 +89,7 @@ from sglang.srt.utils import (
     get_int_env_var,
     is_cuda,
     is_hip,
+    log_info_on_rank0,
 )

 _is_hip = is_hip()
@@ -356,6 +358,7 @@ class DeepseekV2MoE(nn.Module):
                 topk_idx,
                 topk_weights,
                 reorder_topk_ids,
+                num_recv_tokens_per_expert,
                 seg_indptr,
                 masked_m,
                 expected_m,
@@ -367,10 +370,13 @@ class DeepseekV2MoE(nn.Module):
             )
             final_hidden_states = self.experts(
                 hidden_states=hidden_states,
+                topk_idx=topk_idx,
+                topk_weights=topk_weights,
                 reorder_topk_ids=reorder_topk_ids,
                 seg_indptr=seg_indptr,
                 masked_m=masked_m,
                 expected_m=expected_m,
+                num_recv_tokens_per_expert=num_recv_tokens_per_expert,
                 forward_mode=forward_mode,
             )
             if self.ep_size > 1:
@@ -421,6 +427,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         reduce_results: bool = True,
         layer_id: int = None,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.layer_id = layer_id
@@ -431,7 +438,6 @@
         self.v_head_dim = v_head_dim
         self.q_lora_rank = q_lora_rank
         self.kv_lora_rank = kv_lora_rank
-        self.dp_size = get_attention_dp_size()
         attn_tp_rank = get_attention_tp_rank()
         attn_tp_size = get_attention_tp_size()

@@ -543,6 +549,8 @@
             prefix=add_prefix("attn_mha", prefix),
         )

+        self.alt_stream = alt_stream
+
         self.w_kc = None
         self.w_vc = None
         self.w_scale = None
@@ -706,20 +714,36 @@
             q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
                 [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
             )
-            q = self.q_a_layernorm(q)
+            k_nope = latent_cache[..., : self.kv_lora_rank]
+
+            # overlap qk norm
+            if self.alt_stream is not None and torch.cuda.is_current_stream_capturing():
+                current_stream = torch.cuda.current_stream()
+                self.alt_stream.wait_stream(current_stream)
+                q = self.q_a_layernorm(q)
+                with torch.cuda.stream(self.alt_stream):
+                    k_nope = self.kv_a_layernorm(k_nope)
+                current_stream.wait_stream(self.alt_stream)
+            else:
+                q = self.q_a_layernorm(q)
+                k_nope = self.kv_a_layernorm(k_nope)
+
+            k_nope = k_nope.unsqueeze(1)
             q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
         else:
             q = self.q_proj(hidden_states)[0].view(
                 -1, self.num_local_heads, self.qk_head_dim
             )
             latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
+            k_nope = latent_cache[..., : self.kv_lora_rank]
+            k_nope = self.kv_a_layernorm(k_nope).unsqueeze(1)
+
         q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+        k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)

         if self.use_deep_gemm_bmm:
             q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = (
-                per_tensor_quant_mla_deep_gemm_masked_fp8(
-                    q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn
-                )
+                per_token_group_quant_mla_deep_gemm_masked_fp8(q_nope.transpose(0, 1))
             )
             q_nope_out = q_nope.new_empty(
                 (self.num_local_heads, aligned_m, self.kv_lora_rank)
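
The new alt_stream path above runs the two independent RMSNorms (q_a_layernorm on the current stream, kv_a_layernorm on a side stream) concurrently while a CUDA graph is being captured. A minimal sketch of the same overlap pattern, with placeholder norm_q/norm_k callables standing in for the model's norm modules:

import torch

# Minimal sketch of the dual-stream overlap used for the q/k norms above.
# norm_q and norm_k are placeholders; the model only takes the overlapped
# path while a CUDA graph is being captured.
def overlapped_norms(norm_q, norm_k, q, k, alt_stream=None):
    if alt_stream is not None and torch.cuda.is_current_stream_capturing():
        main = torch.cuda.current_stream()
        alt_stream.wait_stream(main)        # alt stream must see prior work
        q = norm_q(q)                       # runs on the main stream
        with torch.cuda.stream(alt_stream):
            k = norm_k(k)                   # runs concurrently on alt stream
        main.wait_stream(alt_stream)        # rejoin before dependent ops
    else:
        q = norm_q(q)
        k = norm_k(k)
    return q, k
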
@@ -750,14 +774,9 @@
             q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)

         q_nope_out = q_nope_out.transpose(0, 1)
-
-        k_nope = latent_cache[..., : self.kv_lora_rank]
-        k_nope = self.kv_a_layernorm(k_nope.contiguous()).unsqueeze(1)
-        k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)
-
         q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)

-        if self.attention_backend == "fa3":
+        if self.attention_backend == "fa3" or self.attention_backend == "flashinfer":
             attn_output = self.attn_mqa(
                 q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
             )
@@ -769,8 +788,8 @@

         if self.use_deep_gemm_bmm:
             attn_output_val, attn_output_scale, masked_m, expected_m, aligned_m = (
-                per_tensor_quant_mla_deep_gemm_masked_fp8(
-                    attn_output.transpose(0, 1), dtype=torch.float8_e4m3fn
+                per_token_group_quant_mla_deep_gemm_masked_fp8(
+                    attn_output.transpose(0, 1)
                 )
             )
             attn_bmm_output = attn_output.new_empty(
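
The rename above swaps the per-tensor quantizer for a per-token-group one on the DeepGEMM bmm path. A reference-style sketch of the general idea, unfused and without the masking/alignment handling of the real Triton kernel in sglang.srt.layers.quantization.fp8_kernel; the group size of 128 is an assumption for illustration:

import torch

# Unfused reference for per-token-group fp8 quantization: every group of
# `group_size` consecutive features in a token gets its own scale.
def per_token_group_quant_fp8_ref(x: torch.Tensor, group_size: int = 128):
    assert x.shape[-1] % group_size == 0
    finfo = torch.finfo(torch.float8_e4m3fn)
    grouped = x.reshape(*x.shape[:-1], x.shape[-1] // group_size, group_size)
    # one scale per (token, group), chosen so the group's max maps to fp8 max
    scale = grouped.abs().amax(dim=-1, keepdim=True).float().clamp(min=1e-12) / finfo.max
    quant = (grouped / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return quant.reshape(x.shape), scale.squeeze(-1)
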
@@ -1104,6 +1123,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         is_nextn: bool = False,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -1112,7 +1132,7 @@
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
         self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
         self.layer_id = layer_id
-        self.dp_size = get_attention_dp_size()
+        self.local_dp_size = get_local_attention_dp_size()
         self.attn_tp_size = get_attention_tp_size()
         self.attn_tp_rank = get_attention_tp_rank()
         self.self_attn = DeepseekV2AttentionMLA(
@@ -1133,6 +1153,7 @@
             layer_id=layer_id,
             reduce_results=False,
             prefix=add_prefix("self_attn", prefix),
+            alt_stream=alt_stream,
         )

         self.info = self._compute_info(config, layer_id=layer_id, is_nextn=is_nextn)
@@ -1162,7 +1183,8 @@
         )

         self.input_is_scattered = (
-            previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
+            layer_id > 0
+            and previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
         )
         self.is_last_layer = self.layer_id == config.num_hidden_layers - 1

@@ -1242,7 +1264,7 @@
         # Gather
         if get_tensor_model_parallel_world_size() > 1:
             # all gather and all reduce
-            if self.dp_size != 1:
+            if self.local_dp_size != 1:
                 if self.attn_tp_rank == 0:
                     hidden_states += residual
                 hidden_states, local_hidden_states = (
@@ -1265,9 +1287,9 @@
         # Fully Connected
         hidden_states = self.mlp(hidden_states)

-        # TODO(ch-wan): ues reduce-scatter in MLP to avoid this scatter
+        # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
         # Scatter
-        if self.dp_size != 1:
+        if self.local_dp_size != 1:
             # important: forward batch.gathered_buffer is used both after scatter and after gather.
             # be careful about this!
             hidden_states, global_hidden_states = (
@@ -1301,7 +1323,7 @@
                 forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
                 hidden_states,
             )
-            tp_all_gather(
+            attn_tp_all_gather(
                 list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
             )

@@ -1317,7 +1339,7 @@
         if self.input_is_scattered:
             tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
             hidden_states = tensor_list[self.attn_tp_rank]
-            tp_reduce_scatter(hidden_states, tensor_list)
+            attn_tp_reduce_scatter(hidden_states, tensor_list)
             if hidden_states.shape[0] != 0:
                 hidden_states, residual = self.post_attention_layernorm(
                     hidden_states, residual
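
The renamed attn_tp_all_gather / attn_tp_reduce_scatter collectives above operate on per-rank views produced by tensor_split over a preallocated buffer. A rough sketch of that pattern using plain torch.distributed calls; attn_tp_group, attn_tp_size, and attn_tp_rank are stand-ins for sglang's attention-TP process group, and the real helpers in sglang.srt.layers.dp_attention may differ:

import torch
import torch.distributed as dist

# Gather every rank's shard into per-rank views of a preallocated buffer.
def all_gather_into_buffer(local_shard, gathered_buffer, attn_tp_size, attn_tp_group):
    views = list(gathered_buffer.tensor_split(attn_tp_size, dim=0))
    dist.all_gather(views, local_shard, group=attn_tp_group)
    return gathered_buffer

# Sum-reduce a full tensor across ranks and keep only this rank's shard.
def reduce_scatter_to_shard(full_tensor, attn_tp_rank, attn_tp_size, attn_tp_group):
    shards = list(full_tensor.tensor_split(attn_tp_size, dim=0))
    out = shards[attn_tp_rank]
    dist.reduce_scatter(out, shards, group=attn_tp_group)
    return out
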
@@ -1327,7 +1349,7 @@
                 hidden_states += residual
             tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
             hidden_states = tensor_list[self.attn_tp_rank]
-            tp_reduce_scatter(hidden_states, tensor_list)
+            attn_tp_reduce_scatter(hidden_states, tensor_list)
             residual = hidden_states
             if hidden_states.shape[0] != 0:
                 hidden_states = self.post_attention_layernorm(hidden_states)
@@ -1351,7 +1373,7 @@
                 forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
                 hidden_states,
             )
-            tp_all_gather(
+            attn_tp_all_gather(
                 list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
             )

@@ -1376,6 +1398,7 @@ class DeepseekV2Model(nn.Module):
             config.hidden_size,
             enable_tp=not global_server_args_dict["enable_dp_attention"],
         )
+        self.alt_stream = torch.cuda.Stream()
         self.layers = nn.ModuleList(
             [
                 DeepseekV2DecoderLayer(
@@ -1383,13 +1406,14 @@
                     layer_id,
                     quant_config=quant_config,
                     prefix=add_prefix(f"layers.{layer_id}", prefix),
+                    alt_stream=self.alt_stream,
                 )
                 for layer_id in range(config.num_hidden_layers)
             ]
         )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

-        self.dp_size = get_attention_dp_size()
+        self.dp_size = get_local_attention_dp_size()

     def get_input_embeddings(self) -> torch.Tensor:
         return self.embed_tokens
@@ -1451,9 +1475,10 @@ class DeepseekV2ForCausalLM(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("lm_head", prefix),
+            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
         )
         self.logits_processor = LogitsProcessor(config)
-        self.dp_size = get_attention_dp_size()
+        self.dp_size = get_local_attention_dp_size()

     def determine_n_share_experts_fusion(
         self, architecture: str = "DeepseekV3ForCausalLM"
@@ -1462,29 +1487,33 @@
         if self.n_share_experts_fusion > 0:
             # Only Deepseek V3/R1 can use shared experts fusion optimization now.
             if (
-                self.config.architectures[0] != architecture
+                not _is_cuda
+                or self.config.architectures[0] != architecture
                 or self.config.n_routed_experts != 256
             ):
                 self.n_share_experts_fusion = 0
                 global_server_args_dict["n_share_experts_fusion"] = 0
-                logger.info(
-                    "Only Deepseek V3/R1 can use shared experts fusion optimization. Shared experts fusion optimization is disabled."
+                log_info_on_rank0(
+                    logger,
+                    "Only Deepseek V3/R1 on NV-platform can use shared experts fusion optimization. Shared experts fusion optimization is disabled.",
                 )
             else:
                 assert (
                     self.n_share_experts_fusion == self.tp_size
-                ), f"Shared experts fusion optimization is enabled in DeepSeek V3/R1, set it to {self.tp_size} can get best optimized performace."
+                ), f"Shared experts fusion optimization is enabled in DeepSeek V3/R1, set it to {self.tp_size} can get best optimized performance."
         elif self.n_share_experts_fusion == 0:
             if (
-                torch.cuda.get_device_capability("cuda") >= (9, 0)
+                _is_cuda
+                and torch.cuda.get_device_capability("cuda") >= (9, 0)
                 and self.config.architectures[0] == architecture
                 and self.config.n_routed_experts == 256
                 and (not global_server_args_dict["enable_deepep_moe"])
             ):
                 self.n_share_experts_fusion = self.tp_size
                 global_server_args_dict["n_share_experts_fusion"] = self.tp_size
-                logger.info(
-                    "Deepseek V3/R1 with fp8 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled."
+                log_info_on_rank0(
+                    logger,
+                    "Deepseek V3/R1 with fp8 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled.",
                 )

     def get_input_embeddings(self) -> nn.Embedding:
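
The hunk above replaces logger.info with log_info_on_rank0 so the message is emitted once rather than once per rank. A hedged sketch of what such a helper typically looks like; the actual implementation in sglang.srt.utils may key off the TP-group rank instead of the global rank:

import logging
import torch.distributed as dist

# Assumed shape of a rank-0-only logging helper; illustrative, not sglang's exact code.
def log_info_on_rank0_sketch(logger: logging.Logger, msg: str) -> None:
    if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0:
        logger.info(msg)
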
@@ -1564,13 +1593,22 @@

             if (
                 _is_cuda
-                and _ENABLE_JIT_DEEPGEMM
                 and weight_block_size[0] == 128
                 and weight_block_size[1] == 128
                 and model_dtype == torch.bfloat16
             ):
-                block_scale = weight_scale
-                use_deep_gemm_bmm = True
+                if _ENABLE_JIT_DEEPGEMM and get_bool_env_var(
+                    "SGL_USE_DEEPGEMM_BMM", "false"
+                ):
+                    block_scale = weight_scale
+                    use_deep_gemm_bmm = True
+                else:
+                    w = block_quant_dequant(
+                        weight,
+                        weight_scale,
+                        weight_block_size,
+                        model_dtype,
+                    )
             else:
                 w, scale = block_quant_to_tensor_quant(
                     weight, weight_scale, weight_block_size
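
With this change the DeepGEMM bmm path is gated behind the SGL_USE_DEEPGEMM_BMM environment flag; otherwise the block-quantized fp8 weight is dequantized back to the model dtype via block_quant_dequant. A reference-style sketch of block dequantization, assuming one scale per (block_n, block_k) tile; the real helper in sglang.srt.layers.quantization.fp8_utils may be fused:

import torch

# Illustrative block dequantization: weight is fp8 with shape (n, k) and
# weight_scale holds one float per (block_n, block_k) tile.
def block_quant_dequant_ref(weight, weight_scale, weight_block_size, dtype):
    block_n, block_k = weight_block_size
    n, k = weight.shape
    # broadcast each tile's scale over its tile, then trim any padding
    scale = weight_scale.repeat_interleave(block_n, dim=0)[:n]
    scale = scale.repeat_interleave(block_k, dim=1)[:, :k]
    return (weight.to(torch.float32) * scale).to(dtype)
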
@@ -1628,7 +1666,7 @@
         if is_nextn:
             if hasattr(self.config, "num_nextn_predict_layers"):
                 num_nextn_layers = self.config.num_nextn_predict_layers
-                assert num_nextn_layers == 1, "Only 1 nextn layer is supportted"
+                assert num_nextn_layers == 1, "Only 1 nextn layer is supported"
                 # compatible with old design
                 nextn_layer_id = (
                     0

The following hunk is from sglang/srt/models/gemma3_mm.py.

@@ -281,7 +281,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         pixel_values = torch.stack(
             flatten_nested_list([item.pixel_values for item in items]), dim=0
         )
-        pixel_values = pixel_values.to("cuda")
+        pixel_values = pixel_values.to(device=self.vision_tower.device)
         pixel_values = pixel_values.to(dtype=self.language_model.dtype())

         vision_outputs = self.vision_tower(pixel_values=pixel_values).last_hidden_state
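
The fix moves pixel_values to whatever device the vision tower was actually loaded on instead of hardcoding "cuda". The .device property used here comes from Hugging Face PreTrainedModel; for a plain nn.Module, a common equivalent is to read the device off a parameter, as in this small illustrative helper:

import torch
from torch import nn

# Illustrative: infer a module's device from its first parameter.
def module_device(module: nn.Module) -> torch.device:
    return next(module.parameters()).device

# e.g. pixel_values = pixel_values.to(device=module_device(vision_tower))
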
The following hunk is from sglang/srt/models/internlm2.py.

@@ -290,6 +290,9 @@ class InternLM2ForCausalLM(nn.Module):
         )
         self.logits_processor = LogitsProcessor(config)

+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.model.tok_embeddings
+
     @torch.no_grad()
     def forward(
         self,