sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. sglang/bench_one_batch.py +1 -11
  2. sglang/bench_serving.py +149 -1
  3. sglang/lang/chat_template.py +44 -0
  4. sglang/srt/configs/deepseekvl2.py +3 -0
  5. sglang/srt/configs/device_config.py +1 -1
  6. sglang/srt/configs/internvl.py +696 -0
  7. sglang/srt/configs/janus_pro.py +3 -0
  8. sglang/srt/configs/model_config.py +17 -0
  9. sglang/srt/constrained/xgrammar_backend.py +11 -19
  10. sglang/srt/conversation.py +30 -3
  11. sglang/srt/disaggregation/decode.py +4 -1
  12. sglang/srt/disaggregation/mini_lb.py +74 -23
  13. sglang/srt/disaggregation/mooncake/conn.py +9 -18
  14. sglang/srt/disaggregation/nixl/conn.py +241 -71
  15. sglang/srt/disaggregation/utils.py +44 -1
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
  17. sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
  18. sglang/srt/distributed/device_communicators/pynccl.py +2 -1
  19. sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
  20. sglang/srt/distributed/parallel_state.py +22 -1
  21. sglang/srt/entrypoints/engine.py +14 -2
  22. sglang/srt/entrypoints/http_server.py +28 -1
  23. sglang/srt/entrypoints/verl_engine.py +3 -2
  24. sglang/srt/hf_transformers_utils.py +20 -1
  25. sglang/srt/layers/attention/flashattention_backend.py +146 -50
  26. sglang/srt/layers/attention/flashinfer_backend.py +23 -13
  27. sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
  28. sglang/srt/layers/attention/merge_state.py +46 -0
  29. sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
  30. sglang/srt/layers/attention/vision.py +290 -163
  31. sglang/srt/layers/moe/ep_moe/kernels.py +342 -7
  32. sglang/srt/layers/moe/ep_moe/layer.py +120 -1
  33. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +4 -1
  37. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
  38. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
  39. sglang/srt/layers/quantization/deep_gemm.py +5 -0
  40. sglang/srt/layers/quantization/fp8.py +108 -95
  41. sglang/srt/layers/quantization/fp8_kernel.py +79 -60
  42. sglang/srt/layers/quantization/fp8_utils.py +71 -23
  43. sglang/srt/layers/quantization/kv_cache.py +3 -10
  44. sglang/srt/layers/quantization/utils.py +0 -5
  45. sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
  46. sglang/srt/lora/lora_manager.py +10 -13
  47. sglang/srt/managers/cache_controller.py +115 -119
  48. sglang/srt/managers/io_struct.py +10 -0
  49. sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
  50. sglang/srt/managers/multimodal_processors/internvl.py +232 -0
  51. sglang/srt/managers/schedule_batch.py +19 -1
  52. sglang/srt/managers/schedule_policy.py +11 -5
  53. sglang/srt/managers/scheduler.py +28 -13
  54. sglang/srt/managers/tokenizer_manager.py +24 -13
  55. sglang/srt/managers/tp_worker.py +9 -12
  56. sglang/srt/mem_cache/chunk_cache.py +2 -0
  57. sglang/srt/mem_cache/memory_pool.py +2 -2
  58. sglang/srt/model_executor/model_runner.py +44 -33
  59. sglang/srt/model_loader/loader.py +18 -11
  60. sglang/srt/models/clip.py +4 -4
  61. sglang/srt/models/deepseek_janus_pro.py +1 -1
  62. sglang/srt/models/deepseek_nextn.py +1 -20
  63. sglang/srt/models/deepseek_v2.py +55 -20
  64. sglang/srt/models/gemma3_mm.py +1 -1
  65. sglang/srt/models/internlm2.py +3 -0
  66. sglang/srt/models/internvl.py +670 -0
  67. sglang/srt/models/llama.py +1 -1
  68. sglang/srt/models/llama4.py +53 -7
  69. sglang/srt/models/minicpmv.py +1 -1
  70. sglang/srt/models/mllama.py +1 -1
  71. sglang/srt/models/phi3_small.py +16 -2
  72. sglang/srt/models/qwen2_5_vl.py +8 -4
  73. sglang/srt/models/qwen2_vl.py +4 -4
  74. sglang/srt/models/xiaomi_mimo.py +171 -0
  75. sglang/srt/openai_api/adapter.py +24 -40
  76. sglang/srt/openai_api/protocol.py +28 -16
  77. sglang/srt/reasoning_parser.py +2 -2
  78. sglang/srt/sampling/sampling_batch_info.py +54 -2
  79. sglang/srt/sampling/sampling_params.py +2 -0
  80. sglang/srt/server_args.py +30 -6
  81. sglang/srt/utils.py +35 -1
  82. sglang/test/test_block_fp8.py +2 -2
  83. sglang/test/test_deepep_utils.py +219 -0
  84. sglang/test/test_utils.py +3 -1
  85. sglang/version.py +1 -1
  86. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/METADATA +14 -6
  87. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/RECORD +90 -80
  88. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/WHEEL +1 -1
  89. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/licenses/LICENSE +0 -0
  90. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/model_runner.py CHANGED
@@ -173,6 +173,7 @@ class ModelRunner:
                 "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
                 "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
                 "use_mla_backend": self.use_mla_backend,
+                "mm_attention_backend": server_args.mm_attention_backend,
             }
         )
 
@@ -278,9 +279,10 @@ class ModelRunner:
                 server_args.attention_backend = "fa3"
             else:
                 server_args.attention_backend = "triton"
-            logger.info(
-                f"Attention backend not set. Use {server_args.attention_backend} backend by default."
-            )
+            if self.should_log:
+                logger.info(
+                    f"Attention backend not set. Use {server_args.attention_backend} backend by default."
+                )
         elif self.use_mla_backend:
             if server_args.device != "cpu":
                 if server_args.attention_backend in [
@@ -290,9 +292,10 @@ class ModelRunner:
                     "flashmla",
                     "cutlass_mla",
                 ]:
-                    logger.info(
-                        f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
-                    )
+                    if self.should_log:
+                        logger.info(
+                            f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
+                        )
                 else:
                     raise ValueError(
                         f"Invalid attention backend for MLA: {server_args.attention_backend}"
@@ -311,9 +314,10 @@ class ModelRunner:
             server_args.attention_backend = "triton"
 
         if server_args.enable_double_sparsity:
-            logger.info(
-                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
-            )
+            if self.should_log:
+                logger.info(
+                    "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
+                )
             server_args.attention_backend = "triton"
             server_args.disable_cuda_graph = True
             if server_args.ds_heavy_channel_type is None:
@@ -324,23 +328,26 @@ class ModelRunner:
 
         if self.is_multimodal:
             self.mem_fraction_static *= 0.90
-            logger.info(
-                f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
-                f"because this is a multimodal model."
-            )
-            logger.info(
-                "Automatically turn off --chunked-prefill-size for multimodal model."
-            )
+            if self.should_log:
+                logger.info(
+                    f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
+                    f"because this is a multimodal model."
+                )
+                logger.info(
+                    "Automatically turn off --chunked-prefill-size for multimodal model."
+                )
             server_args.chunked_prefill_size = -1
 
         if not self.use_mla_backend:
             server_args.disable_chunked_prefix_cache = True
         elif self.page_size > 1:
-            logger.info("Disable chunked prefix cache when page size > 1.")
+            if self.should_log:
+                logger.info("Disable chunked prefix cache when page size > 1.")
             server_args.disable_chunked_prefix_cache = True
 
         if not server_args.disable_chunked_prefix_cache:
-            logger.info("Chunked prefix cache is turned on.")
+            if self.should_log:
+                logger.info("Chunked prefix cache is turned on.")
 
     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")
@@ -361,6 +368,8 @@ class ModelRunner:
             backend = "hccl"
         elif self.device == "cpu":
             backend = "gloo"
+        elif self.device == "npu":
+            backend = "hccl"
 
         before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
         if not self.server_args.enable_p2p_check:
@@ -431,9 +440,10 @@ class ModelRunner:
         torch.set_num_threads(1)
         if self.device == "cuda":
             if torch.cuda.get_device_capability()[0] < 8:
-                logger.info(
-                    "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
-                )
+                if self.should_log:
+                    logger.info(
+                        "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
+                    )
                 self.server_args.dtype = "float16"
                 self.model_config.dtype = torch.float16
                 if torch.cuda.get_device_capability()[1] < 5:
@@ -469,10 +479,11 @@ class ModelRunner:
                 self.model.load_kv_cache_scales(
                     self.server_args.quantization_param_path
                 )
-                logger.info(
-                    "Loaded KV cache scaling factors from %s",
-                    self.server_args.quantization_param_path,
-                )
+                if self.should_log:
+                    logger.info(
+                        "Loaded KV cache scaling factors from %s",
+                        self.server_args.quantization_param_path,
+                    )
             else:
                 raise RuntimeError(
                     "Using FP8 KV cache and scaling factors provided but "
@@ -547,12 +558,7 @@ class ModelRunner:
             return iter
 
         def model_load_weights(model, iter):
-            model.load_weights(iter)
-            for _, module in self.model.named_modules():
-                quant_method = getattr(module, "quant_method", None)
-                if quant_method is not None:
-                    with device_loading_context(module, target_device):
-                        quant_method.process_weights_after_loading(module)
+            DefaultModelLoader.load_weights_and_postprocess(model, iter, target_device)
             return model
 
         with set_default_torch_dtype(self.model_config.dtype):
@@ -1019,7 +1025,8 @@ class ModelRunner:
         )
 
     def apply_torch_tp(self):
-        logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
+        if self.should_log:
+            logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
         from sglang.srt.model_parallel import tensor_parallel
 
         device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
@@ -1138,7 +1145,9 @@ class ModelRunner:
                 [self.sample(values, forward_batch) for values in logits_output],
                 axis=-1,
            )
-
+        sampling_info = forward_batch.sampling_info
+        if sampling_info.thinking_budgets is not None:
+            sampling_info.apply_thinking_budgets(logits_output.next_token_logits)
         self._preprocess_logits(logits_output, forward_batch.sampling_info)
 
         # Sample the next tokens
@@ -1149,6 +1158,8 @@ class ModelRunner:
             forward_batch.top_logprobs_nums,
             forward_batch.token_ids_logprobs,
         )
+        if sampling_info.thinking_budgets is not None:
+            sampling_info.update_thinking_budgets(next_token_ids)
         return next_token_ids
 
     @property
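
Note on the new thinking-budget hooks in ModelRunner.sample: the sampling path now consults forward_batch.sampling_info.thinking_budgets before logits preprocessing and again after sampling. Below is a condensed sketch of that call ordering; preprocess_fn and sample_fn are placeholders for the runner's internal preprocessing and sampler call, and the budget method bodies live in sglang/srt/sampling/sampling_batch_info.py, which is not shown in this diff.

    def sample_with_thinking_budgets(logits_output, forward_batch, preprocess_fn, sample_fn):
        # Condensed from ModelRunner.sample above; the placeholders stand in for
        # the runner's _preprocess_logits and sampler invocation.
        sampling_info = forward_batch.sampling_info
        if sampling_info.thinking_budgets is not None:
            # Adjust next-token logits for requests with an active thinking budget.
            sampling_info.apply_thinking_budgets(logits_output.next_token_logits)
        preprocess_fn(logits_output, sampling_info)
        next_token_ids = sample_fn(logits_output, sampling_info)
        if sampling_info.thinking_budgets is not None:
            # Charge the freshly sampled tokens against each request's budget.
            sampling_info.update_thinking_budgets(next_token_ids)
        return next_token_ids
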
sglang/srt/model_loader/loader.py CHANGED
@@ -374,20 +374,27 @@ class DefaultModelLoader(BaseModelLoader):
                     self.load_config,
                 )
 
-            model.load_weights(self._get_all_weights(model_config, model))
+            self.load_weights_and_postprocess(
+                model, self._get_all_weights(model_config, model), target_device
+            )
 
-            for _, module in model.named_modules():
-                quant_method = getattr(module, "quant_method", None)
-                if quant_method is not None:
-                    # When quant methods need to process weights after loading
-                    # (for repacking, quantizing, etc), they expect parameters
-                    # to be on the global target device. This scope is for the
-                    # case where cpu offloading is used, where we will move the
-                    # parameters onto device for processing and back off after.
-                    with device_loading_context(module, target_device):
-                        quant_method.process_weights_after_loading(module)
         return model.eval()
 
+    @staticmethod
+    def load_weights_and_postprocess(model, weights, target_device):
+        model.load_weights(weights)
+
+        for _, module in model.named_modules():
+            quant_method = getattr(module, "quant_method", None)
+            if quant_method is not None:
+                # When quant methods need to process weights after loading
+                # (for repacking, quantizing, etc), they expect parameters
+                # to be on the global target device. This scope is for the
+                # case where cpu offloading is used, where we will move the
+                # parameters onto device for processing and back off after.
+                with device_loading_context(module, target_device):
+                    quant_method.process_weights_after_loading(module)
+
 
 class LayeredModelLoader(DefaultModelLoader):
     """Model loader that loads weights layer by layer so that one can quantize a
sglang/srt/models/clip.py CHANGED
@@ -151,20 +151,20 @@ class CLIPEncoderLayer(nn.Module):
         self.layer_norm1 = norm_layer(config.hidden_size)
         self.layer_norm2 = norm_layer(config.hidden_size)
         if attn_implementation == "sdpa":
-            use_context_forward = False
+            qkv_backend = "sdpa"
             softmax_in_single_precision = False
         elif attn_implementation == "flash_attention_2":
+            qkv_backend = "triton_attn"
             softmax_in_single_precision = False
-            use_context_forward = True
         elif attn_implementation == "eager":
+            qkv_backend = "sdpa"
             softmax_in_single_precision = True
-            use_context_forward = False
         self.self_attn = VisionAttention(
             embed_dim=config.hidden_size,
             num_heads=config.num_attention_heads,
             projection_size=config.hidden_size,
             use_qkv_parallel=True,
-            use_context_forward=use_context_forward,
+            qkv_backend=qkv_backend,
             softmax_in_single_precision=softmax_in_single_precision,
             flatten_batch=True,
             quant_config=quant_config,
@@ -532,7 +532,7 @@ class VisionTransformerBlock(nn.Module):
             num_heads=num_heads,
             projection_size=dim,
             use_qkv_parallel=True,
-            use_context_forward=False,
+            qkv_backend="sdpa",
             softmax_in_single_precision=False,
             dropout=attn_drop,
         )
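
clip.py now passes a qkv_backend string to VisionAttention instead of the removed use_context_forward flag. A small sketch of the mapping that the CLIPEncoderLayer branch above implements; other attn_implementation values are not handled here.

    def pick_qkv_backend(attn_implementation: str) -> tuple:
        """Return (qkv_backend, softmax_in_single_precision) as in CLIPEncoderLayer."""
        if attn_implementation == "sdpa":
            return "sdpa", False
        elif attn_implementation == "flash_attention_2":
            return "triton_attn", False
        elif attn_implementation == "eager":
            # Eager attention keeps the softmax in single precision.
            return "sdpa", True
        raise ValueError(f"Unsupported attn_implementation: {attn_implementation}")
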
sglang/srt/models/deepseek_nextn.py CHANGED
@@ -24,34 +24,15 @@ from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import ReplicatedLinear
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import EPMoE
-from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.layers.quantization.fp8_utils import (
-    block_quant_to_tensor_quant,
-    normalize_e4m3fn_to_e4m3fnuz,
-)
-from sglang.srt.layers.quantization.int8_utils import (
-    block_dequant as int8_block_dequant,
-)
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
-from sglang.srt.utils import BumpAllocator, add_prefix, is_cuda, is_hip
-
-_is_hip = is_hip()
-_is_cuda = is_cuda()
-
-if _is_cuda:
-    from sgl_kernel import awq_dequantize
-else:
-    from vllm._custom_ops import awq_dequantize
-
+from sglang.srt.utils import BumpAllocator, add_prefix
 
 logger = logging.getLogger(__name__)
 
sglang/srt/models/deepseek_v2.py CHANGED
@@ -59,10 +59,11 @@ from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
 from sglang.srt.layers.quantization.fp8_kernel import (
-    per_tensor_quant_mla_deep_gemm_masked_fp8,
     per_tensor_quant_mla_fp8,
+    per_token_group_quant_mla_deep_gemm_masked_fp8,
 )
 from sglang.srt.layers.quantization.fp8_utils import (
+    block_quant_dequant,
     block_quant_to_tensor_quant,
     channel_quant_to_tensor_quant,
     normalize_e4m3fn_to_e4m3fnuz,
@@ -88,6 +89,7 @@ from sglang.srt.utils import (
     get_int_env_var,
     is_cuda,
     is_hip,
+    log_info_on_rank0,
 )
 
 _is_hip = is_hip()
@@ -356,6 +358,7 @@ class DeepseekV2MoE(nn.Module):
             topk_idx,
             topk_weights,
             reorder_topk_ids,
+            num_recv_tokens_per_expert,
             seg_indptr,
             masked_m,
             expected_m,
@@ -367,10 +370,13 @@ class DeepseekV2MoE(nn.Module):
         )
         final_hidden_states = self.experts(
             hidden_states=hidden_states,
+            topk_idx=topk_idx,
+            topk_weights=topk_weights,
             reorder_topk_ids=reorder_topk_ids,
             seg_indptr=seg_indptr,
             masked_m=masked_m,
             expected_m=expected_m,
+            num_recv_tokens_per_expert=num_recv_tokens_per_expert,
             forward_mode=forward_mode,
         )
         if self.ep_size > 1:
@@ -421,6 +427,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         reduce_results: bool = True,
         layer_id: int = None,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.layer_id = layer_id
@@ -543,6 +550,8 @@ class DeepseekV2AttentionMLA(nn.Module):
             prefix=add_prefix("attn_mha", prefix),
         )
 
+        self.alt_stream = alt_stream
+
         self.w_kc = None
         self.w_vc = None
         self.w_scale = None
@@ -706,20 +715,36 @@ class DeepseekV2AttentionMLA(nn.Module):
             q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
                 [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
             )
-            q = self.q_a_layernorm(q)
+            k_nope = latent_cache[..., : self.kv_lora_rank]
+
+            # overlap qk norm
+            if self.alt_stream is not None and torch.cuda.is_current_stream_capturing():
+                current_stream = torch.cuda.current_stream()
+                self.alt_stream.wait_stream(current_stream)
+                q = self.q_a_layernorm(q)
+                with torch.cuda.stream(self.alt_stream):
+                    k_nope = self.kv_a_layernorm(k_nope)
+                current_stream.wait_stream(self.alt_stream)
+            else:
+                q = self.q_a_layernorm(q)
+                k_nope = self.kv_a_layernorm(k_nope)
+
+            k_nope = k_nope.unsqueeze(1)
             q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
         else:
             q = self.q_proj(hidden_states)[0].view(
                 -1, self.num_local_heads, self.qk_head_dim
             )
             latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
+            k_nope = latent_cache[..., : self.kv_lora_rank]
+            k_nope = self.kv_a_layernorm(k_nope).unsqueeze(1)
+
         q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+        k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)
 
         if self.use_deep_gemm_bmm:
             q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = (
-                per_tensor_quant_mla_deep_gemm_masked_fp8(
-                    q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn
-                )
+                per_token_group_quant_mla_deep_gemm_masked_fp8(q_nope.transpose(0, 1))
             )
             q_nope_out = q_nope.new_empty(
                 (self.num_local_heads, aligned_m, self.kv_lora_rank)
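
The q/k layernorm overlap above follows a standard two-stream pattern: while CUDA graph capture is active, the q norm runs on the current stream and the k norm on an alternate stream, with wait_stream calls fencing both sides. A generic sketch of the same pattern (illustrative helper, not an sglang API):

    import torch

    def overlap_two_ops(a, b, fn_a, fn_b, alt_stream: torch.cuda.Stream):
        current_stream = torch.cuda.current_stream()
        # alt_stream must see all work already queued on the current stream.
        alt_stream.wait_stream(current_stream)
        out_a = fn_a(a)
        with torch.cuda.stream(alt_stream):
            out_b = fn_b(b)
        # Rejoin before anything on the current stream consumes out_b.
        current_stream.wait_stream(alt_stream)
        return out_a, out_b
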
@@ -750,14 +775,9 @@ class DeepseekV2AttentionMLA(nn.Module):
             q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
 
         q_nope_out = q_nope_out.transpose(0, 1)
-
-        k_nope = latent_cache[..., : self.kv_lora_rank]
-        k_nope = self.kv_a_layernorm(k_nope.contiguous()).unsqueeze(1)
-        k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)
-
         q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
 
-        if self.attention_backend == "fa3":
+        if self.attention_backend == "fa3" or self.attention_backend == "flashinfer":
             attn_output = self.attn_mqa(
                 q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
             )
@@ -769,8 +789,8 @@ class DeepseekV2AttentionMLA(nn.Module):
 
         if self.use_deep_gemm_bmm:
             attn_output_val, attn_output_scale, masked_m, expected_m, aligned_m = (
-                per_tensor_quant_mla_deep_gemm_masked_fp8(
-                    attn_output.transpose(0, 1), dtype=torch.float8_e4m3fn
+                per_token_group_quant_mla_deep_gemm_masked_fp8(
+                    attn_output.transpose(0, 1)
                 )
             )
             attn_bmm_output = attn_output.new_empty(
@@ -1104,6 +1124,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         is_nextn: bool = False,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -1133,6 +1154,7 @@ class DeepseekV2DecoderLayer(nn.Module):
             layer_id=layer_id,
             reduce_results=False,
             prefix=add_prefix("self_attn", prefix),
+            alt_stream=alt_stream,
         )
 
         self.info = self._compute_info(config, layer_id=layer_id, is_nextn=is_nextn)
@@ -1376,6 +1398,7 @@ class DeepseekV2Model(nn.Module):
             config.hidden_size,
             enable_tp=not global_server_args_dict["enable_dp_attention"],
         )
+        self.alt_stream = torch.cuda.Stream()
         self.layers = nn.ModuleList(
             [
                 DeepseekV2DecoderLayer(
@@ -1383,6 +1406,7 @@ class DeepseekV2Model(nn.Module):
                     layer_id,
                     quant_config=quant_config,
                     prefix=add_prefix(f"layers.{layer_id}", prefix),
+                    alt_stream=self.alt_stream,
                 )
                 for layer_id in range(config.num_hidden_layers)
             ]
@@ -1467,8 +1491,9 @@ class DeepseekV2ForCausalLM(nn.Module):
         ):
             self.n_share_experts_fusion = 0
             global_server_args_dict["n_share_experts_fusion"] = 0
-            logger.info(
-                "Only Deepseek V3/R1 can use shared experts fusion optimization. Shared experts fusion optimization is disabled."
+            log_info_on_rank0(
+                logger,
+                "Only Deepseek V3/R1 can use shared experts fusion optimization. Shared experts fusion optimization is disabled.",
             )
         else:
             assert (
@@ -1483,8 +1508,9 @@ class DeepseekV2ForCausalLM(nn.Module):
         ):
             self.n_share_experts_fusion = self.tp_size
             global_server_args_dict["n_share_experts_fusion"] = self.tp_size
-            logger.info(
-                "Deepseek V3/R1 with fp8 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled."
+            log_info_on_rank0(
+                logger,
+                "Deepseek V3/R1 with fp8 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled.",
             )
 
     def get_input_embeddings(self) -> nn.Embedding:
@@ -1564,13 +1590,22 @@ class DeepseekV2ForCausalLM(nn.Module):
 
             if (
                 _is_cuda
-                and _ENABLE_JIT_DEEPGEMM
                 and weight_block_size[0] == 128
                 and weight_block_size[1] == 128
                 and model_dtype == torch.bfloat16
             ):
-                block_scale = weight_scale
-                use_deep_gemm_bmm = True
+                if _ENABLE_JIT_DEEPGEMM and get_bool_env_var(
+                    "SGL_USE_DEEPGEMM_BMM", "false"
+                ):
+                    block_scale = weight_scale
+                    use_deep_gemm_bmm = True
+                else:
+                    w = block_quant_dequant(
+                        weight,
+                        weight_scale,
+                        weight_block_size,
+                        model_dtype,
+                    )
             else:
                 w, scale = block_quant_to_tensor_quant(
                     weight, weight_scale, weight_block_size
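
The deepseek_v2.py hunks replace rank-unaware logger.info calls with log_info_on_rank0 (imported from sglang.srt.utils), so only one rank reports these decisions, matching the should_log gating in model_runner.py. A minimal sketch of what such a helper typically looks like; the actual implementation in sglang/srt/utils.py may differ.

    import logging

    import torch.distributed as dist

    def log_info_on_rank0(logger: logging.Logger, msg: str) -> None:
        # Emit the message once per job instead of once per process.
        if not dist.is_initialized() or dist.get_rank() == 0:
            logger.info(msg)
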
sglang/srt/models/gemma3_mm.py CHANGED
@@ -281,7 +281,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         pixel_values = torch.stack(
             flatten_nested_list([item.pixel_values for item in items]), dim=0
         )
-        pixel_values = pixel_values.to("cuda")
+        pixel_values = pixel_values.to(device=self.vision_tower.device)
         pixel_values = pixel_values.to(dtype=self.language_model.dtype())
 
         vision_outputs = self.vision_tower(pixel_values=pixel_values).last_hidden_state
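
The gemma3_mm.py change drops the hard-coded "cuda" target and moves pixel values to whichever device the vision tower actually lives on. A hedged helper showing the same idea for modules that do not expose a .device attribute (illustrative, not part of sglang):

    import torch

    def move_to_module_device(x: torch.Tensor, module: torch.nn.Module) -> torch.Tensor:
        # Place the input wherever the module's parameters reside instead of
        # assuming a CUDA device.
        device = next(module.parameters()).device
        return x.to(device=device)
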
sglang/srt/models/internlm2.py CHANGED
@@ -290,6 +290,9 @@ class InternLM2ForCausalLM(nn.Module):
         )
         self.logits_processor = LogitsProcessor(config)
 
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.model.tok_embeddings
+
     @torch.no_grad()
     def forward(
         self,