sglang 0.4.9.post3__py3-none-any.whl → 0.4.9.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (128)
  1. sglang/lang/chat_template.py +21 -0
  2. sglang/srt/_custom_ops.py +29 -1
  3. sglang/srt/configs/internvl.py +3 -0
  4. sglang/srt/configs/model_config.py +5 -1
  5. sglang/srt/constrained/base_grammar_backend.py +10 -2
  6. sglang/srt/constrained/xgrammar_backend.py +7 -5
  7. sglang/srt/conversation.py +17 -2
  8. sglang/srt/debug_utils/__init__.py +0 -0
  9. sglang/srt/debug_utils/dump_comparator.py +131 -0
  10. sglang/srt/debug_utils/dumper.py +108 -0
  11. sglang/srt/debug_utils/text_comparator.py +172 -0
  12. sglang/srt/disaggregation/common/conn.py +34 -6
  13. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +13 -1
  14. sglang/srt/disaggregation/mini_lb.py +3 -2
  15. sglang/srt/disaggregation/mooncake/conn.py +65 -20
  16. sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
  17. sglang/srt/disaggregation/nixl/conn.py +17 -13
  18. sglang/srt/disaggregation/prefill.py +13 -1
  19. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
  20. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
  21. sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
  22. sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
  23. sglang/srt/distributed/parallel_state.py +70 -15
  24. sglang/srt/entrypoints/engine.py +5 -9
  25. sglang/srt/entrypoints/http_server.py +20 -32
  26. sglang/srt/entrypoints/openai/protocol.py +3 -3
  27. sglang/srt/entrypoints/openai/serving_chat.py +148 -72
  28. sglang/srt/function_call/base_format_detector.py +74 -12
  29. sglang/srt/function_call/deepseekv3_detector.py +26 -11
  30. sglang/srt/function_call/ebnf_composer.py +105 -66
  31. sglang/srt/function_call/function_call_parser.py +6 -4
  32. sglang/srt/function_call/glm4_moe_detector.py +164 -0
  33. sglang/srt/function_call/kimik2_detector.py +41 -16
  34. sglang/srt/function_call/llama32_detector.py +6 -3
  35. sglang/srt/function_call/mistral_detector.py +11 -3
  36. sglang/srt/function_call/pythonic_detector.py +16 -14
  37. sglang/srt/function_call/qwen25_detector.py +12 -3
  38. sglang/srt/function_call/{qwen3_detector.py → qwen3_coder_detector.py} +11 -9
  39. sglang/srt/layers/activation.py +11 -3
  40. sglang/srt/layers/attention/base_attn_backend.py +3 -1
  41. sglang/srt/layers/attention/hybrid_attn_backend.py +100 -0
  42. sglang/srt/layers/attention/vision.py +56 -8
  43. sglang/srt/layers/communicator.py +12 -12
  44. sglang/srt/layers/dp_attention.py +72 -24
  45. sglang/srt/layers/layernorm.py +26 -1
  46. sglang/srt/layers/logits_processor.py +46 -25
  47. sglang/srt/layers/moe/ep_moe/layer.py +172 -206
  48. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  49. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +25 -224
  51. sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -48
  52. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +11 -8
  53. sglang/srt/layers/moe/topk.py +88 -34
  54. sglang/srt/layers/multimodal.py +11 -8
  55. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -9
  56. sglang/srt/layers/quantization/fp8.py +25 -247
  57. sglang/srt/layers/quantization/fp8_kernel.py +78 -48
  58. sglang/srt/layers/quantization/modelopt_quant.py +33 -14
  59. sglang/srt/layers/quantization/unquant.py +24 -76
  60. sglang/srt/layers/quantization/utils.py +0 -9
  61. sglang/srt/layers/quantization/w4afp8.py +68 -17
  62. sglang/srt/layers/radix_attention.py +5 -3
  63. sglang/srt/lora/lora_manager.py +133 -169
  64. sglang/srt/lora/lora_registry.py +188 -0
  65. sglang/srt/lora/mem_pool.py +2 -2
  66. sglang/srt/managers/cache_controller.py +62 -13
  67. sglang/srt/managers/io_struct.py +19 -1
  68. sglang/srt/managers/mm_utils.py +154 -35
  69. sglang/srt/managers/multimodal_processor.py +3 -14
  70. sglang/srt/managers/schedule_batch.py +27 -11
  71. sglang/srt/managers/scheduler.py +48 -26
  72. sglang/srt/managers/tokenizer_manager.py +62 -28
  73. sglang/srt/managers/tp_worker.py +5 -4
  74. sglang/srt/mem_cache/allocator.py +67 -7
  75. sglang/srt/mem_cache/hicache_storage.py +17 -1
  76. sglang/srt/mem_cache/hiradix_cache.py +35 -18
  77. sglang/srt/mem_cache/memory_pool_host.py +3 -0
  78. sglang/srt/model_executor/cuda_graph_runner.py +61 -25
  79. sglang/srt/model_executor/forward_batch_info.py +201 -29
  80. sglang/srt/model_executor/model_runner.py +109 -37
  81. sglang/srt/models/deepseek_v2.py +63 -30
  82. sglang/srt/models/glm4_moe.py +1035 -0
  83. sglang/srt/models/glm4_moe_nextn.py +167 -0
  84. sglang/srt/models/interns1.py +328 -0
  85. sglang/srt/models/internvl.py +143 -47
  86. sglang/srt/models/llava.py +9 -5
  87. sglang/srt/models/minicpmo.py +4 -1
  88. sglang/srt/models/mllama4.py +10 -3
  89. sglang/srt/models/qwen2_moe.py +2 -6
  90. sglang/srt/models/qwen3_moe.py +6 -8
  91. sglang/srt/multimodal/processors/base_processor.py +20 -6
  92. sglang/srt/multimodal/processors/clip.py +2 -2
  93. sglang/srt/multimodal/processors/deepseek_vl_v2.py +2 -2
  94. sglang/srt/multimodal/processors/gemma3.py +2 -2
  95. sglang/srt/multimodal/processors/gemma3n.py +2 -2
  96. sglang/srt/multimodal/processors/internvl.py +21 -8
  97. sglang/srt/multimodal/processors/janus_pro.py +2 -2
  98. sglang/srt/multimodal/processors/kimi_vl.py +2 -2
  99. sglang/srt/multimodal/processors/llava.py +4 -4
  100. sglang/srt/multimodal/processors/minicpm.py +2 -3
  101. sglang/srt/multimodal/processors/mlama.py +2 -2
  102. sglang/srt/multimodal/processors/mllama4.py +18 -111
  103. sglang/srt/multimodal/processors/phi4mm.py +2 -2
  104. sglang/srt/multimodal/processors/pixtral.py +2 -2
  105. sglang/srt/multimodal/processors/qwen_audio.py +2 -2
  106. sglang/srt/multimodal/processors/qwen_vl.py +2 -2
  107. sglang/srt/multimodal/processors/vila.py +3 -1
  108. sglang/srt/reasoning_parser.py +48 -5
  109. sglang/srt/sampling/sampling_batch_info.py +6 -5
  110. sglang/srt/server_args.py +132 -60
  111. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
  112. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +37 -36
  113. sglang/srt/speculative/eagle_utils.py +51 -23
  114. sglang/srt/speculative/eagle_worker.py +59 -44
  115. sglang/srt/two_batch_overlap.py +9 -5
  116. sglang/srt/utils.py +113 -69
  117. sglang/srt/weight_sync/utils.py +119 -0
  118. sglang/test/runners.py +4 -0
  119. sglang/test/test_activation.py +50 -1
  120. sglang/test/test_utils.py +65 -5
  121. sglang/utils.py +19 -0
  122. sglang/version.py +1 -1
  123. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/METADATA +6 -6
  124. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/RECORD +127 -114
  125. sglang/srt/debug_utils.py +0 -74
  126. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/WHEEL +0 -0
  127. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/licenses/LICENSE +0 -0
  128. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/model_runner.py

@@ -68,6 +68,7 @@ from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
 from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.lora.lora_manager import LoRAManager
+from sglang.srt.lora.lora_registry import LoRARef
 from sglang.srt.managers.schedule_batch import (
     GLOBAL_SERVER_ARGS_KEYS,
     global_server_args_dict,
@@ -108,7 +109,6 @@ from sglang.srt.utils import (
     get_bool_env_var,
     get_cpu_ids_by_node,
     init_custom_process_group,
-    is_cuda,
     is_fa3_default_architecture,
     is_flashinfer_available,
     is_hip,
@@ -275,6 +275,7 @@ class ModelRunner:
         self.sampler = Sampler()
         self.load_model()
 
+        # Check if the model is using hybrid SWA
         if (
             not self.server_args.disable_hybrid_swa_memory
             and self.sliding_window_size is not None
@@ -377,6 +378,7 @@ class ModelRunner:
                     is_hopper_with_cuda_12_3()
                     and is_no_spec_infer_or_topk_one(server_args)
                     and is_fa3_default_architecture(self.model_config.hf_config)
+                    and (not server_args.enable_hierarchical_cache)
                 ):
                     server_args.attention_backend = "fa3"
                 elif _is_hip:
@@ -389,7 +391,9 @@ class ModelRunner:
                 )
             else:
                 # MLA architecture
-                if is_hopper_with_cuda_12_3():
+                if is_hopper_with_cuda_12_3() and (
+                    not server_args.enable_hierarchical_cache
+                ):
                     server_args.attention_backend = "fa3"
                 elif is_sm100_supported():
                     server_args.attention_backend = "flashinfer"
@@ -890,44 +894,38 @@ class ModelRunner:
             tp_rank=self.tp_rank,
             max_lora_rank=self.server_args.max_lora_rank,
             target_modules=self.server_args.lora_target_modules,
+            lora_paths=self.server_args.lora_paths,
         )
-        result = self.lora_manager.load_lora_adapters(self.server_args.lora_paths or {})
-        if result.success:
-            logger.info(
-                f"LoRA manager ready. Loaded LoRA adapters: {', '.join(result.loaded_adapters)}"
-            )
-        else:
-            raise RuntimeError(f"Failed to load LoRA adapters: {result.error_message}")
 
-    def load_lora_adapter(self, lora_name: str, lora_path: str):
+    def load_lora_adapter(self, lora_ref: LoRARef):
         """Load a new lora adapter from disk or huggingface."""
 
         logger.info(
-            f"LoRA adapter loading starts: name={lora_name}, path={lora_path}. "
+            f"LoRA adapter loading starts: {lora_ref}. "
             f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
         )
 
-        result = self.lora_manager.load_lora_adapter(lora_name, lora_path)
+        result = self.lora_manager.load_lora_adapter(lora_ref)
 
         logger.info(
-            f"LoRA adapter loading completes: name={lora_name}, path={lora_path}. "
+            f"LoRA adapter loading completes: {lora_ref}. "
             f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )
 
         return result
 
-    def unload_lora_adapter(self, lora_name: str):
+    def unload_lora_adapter(self, lora_ref: LoRARef):
         """Unload a lora adapter that was previously loaded during initialization or dynamic loading."""
 
         logger.info(
-            f"LoRA adapter unloading starts: name={lora_name}. "
+            f"LoRA adapter unloading starts: {lora_ref}. "
             f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
         )
 
-        result = self.lora_manager.unload_lora_adapter(lora_name)
+        result = self.lora_manager.unload_lora_adapter(lora_ref)
 
        logger.info(
-            f"LoRA adapter unloading completes: name={lora_name}. "
+            f"LoRA adapter unloading completes: {lora_ref}. "
             f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
         )
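A note on the LoRA refactor in this hunk: the name/path pair that used to flow through `load_lora_adapter` / `unload_lora_adapter` is now carried by a single `LoRARef` object from `sglang.srt.lora.lora_registry` (see the new `lora_registry.py` in the file list above). The hunk does not show the fields of `LoRARef`, so the sketch below is illustrative only; it assumes the reference bundles at least an adapter name and a path, and the field names are hypothetical rather than taken from the library.

    # Hypothetical stand-in for sglang.srt.lora.lora_registry.LoRARef.
    # Field names are assumptions used only to illustrate the call-shape change.
    from dataclasses import dataclass


    @dataclass
    class LoRARefSketch:
        lora_name: str
        lora_path: str


    def load_adapter(model_runner, name: str, path: str):
        # Old API: model_runner.load_lora_adapter(name, path)
        # New API: a single reference object travels end to end.
        ref = LoRARefSketch(lora_name=name, lora_path=path)
        return model_runner.load_lora_adapter(ref)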
 
@@ -1010,8 +1008,11 @@ class ModelRunner:
         try:
             layers = self.model.language_model.model.layers
         except:
-            self.is_hybrid = False
-            return
+            try:
+                layers = self.model.language_model.layers
+            except:
+                self.is_hybrid = False
+                return
 
         for layer in layers:
             if (
@@ -1307,9 +1308,58 @@ class ModelRunner:
         else:
             self.attn_backend = self._get_attention_backend()
 
-    # TODO unify with 6338
     def _get_attention_backend(self):
-        if self.server_args.attention_backend == "flashinfer":
+        """Init attention kernel backend."""
+        self.decode_attention_backend_str = (
+            self.server_args.decode_attention_backend
+            if self.server_args.decode_attention_backend
+            else self.server_args.attention_backend
+        )
+        self.prefill_attention_backend_str = (
+            self.server_args.prefill_attention_backend
+            if self.server_args.prefill_attention_backend
+            else self.server_args.attention_backend
+        )
+        if self.decode_attention_backend_str != self.prefill_attention_backend_str:
+            assert (
+                self.server_args.speculative_algorithm is None
+            ), "Currently HybridAttentionBackend does not support speculative decoding."
+            from sglang.srt.layers.attention.hybrid_attn_backend import (
+                HybridAttnBackend,
+            )
+
+            attn_backend = HybridAttnBackend(
+                decode_backend=self._get_attention_backend_from_str(
+                    self.decode_attention_backend_str
+                ),
+                prefill_backend=self._get_attention_backend_from_str(
+                    self.prefill_attention_backend_str
+                ),
+            )
+            logger.info(
+                f"Using hybrid attention backend for decode and prefill: "
+                f"decode_backend={self.decode_attention_backend_str}, "
+                f"prefill_backend={self.prefill_attention_backend_str}."
+            )
+            logger.warning(
+                f"Warning: Attention backend specified by --attention-backend or default backend might be overridden."
+                f"The feature of hybrid attention backend is experimental and unstable. Please raise an issue if you encounter any problem."
+            )
+        else:
+            attn_backend = self._get_attention_backend_from_str(
+                self.server_args.attention_backend
+            )
+
+        global_server_args_dict.update(
+            {
+                "decode_attention_backend": self.decode_attention_backend_str,
+                "prefill_attention_backend": self.prefill_attention_backend_str,
+            }
+        )
+        return attn_backend
+
+    def _get_attention_backend_from_str(self, backend_str: str):
+        if backend_str == "flashinfer":
             if not self.use_mla_backend:
                 from sglang.srt.layers.attention.flashinfer_backend import (
                     FlashInferAttnBackend,
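The rewritten `_get_attention_backend` above lets decode and prefill use different attention backends (`server_args.decode_attention_backend` / `server_args.prefill_attention_backend`), wraps them in `HybridAttnBackend`, and publishes the chosen names through `global_server_args_dict`. A minimal sketch of the routing idea, with stand-in classes rather than sglang's real backends:

    # Minimal sketch of per-mode backend routing; this mirrors the shape of
    # HybridAttnBackend but is not its implementation. The forward_mode check
    # matches the one used later in this diff; everything else is a stand-in.
    class HybridBackendSketch:
        def __init__(self, decode_backend, prefill_backend):
            self.decode_backend = decode_backend
            self.prefill_backend = prefill_backend

        def _select(self, forward_batch):
            if forward_batch.forward_mode.is_decode_or_idle():
                return self.decode_backend
            return self.prefill_backend

        def init_forward_metadata(self, forward_batch):
            # Delegate metadata preparation to whichever backend owns this batch.
            return self._select(forward_batch).init_forward_metadata(forward_batch)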
@@ -1317,7 +1367,11 @@ class ModelRunner:
 
                 # Init streams
                 if self.server_args.speculative_algorithm == "EAGLE":
-                    self.plan_stream_for_flashinfer = torch.cuda.Stream()
+                    if (
+                        not hasattr(self, "plan_stream_for_flashinfer")
+                        or not self.plan_stream_for_flashinfer
+                    ):
+                        self.plan_stream_for_flashinfer = torch.cuda.Stream()
                 return FlashInferAttnBackend(self)
             else:
                 from sglang.srt.layers.attention.flashinfer_mla_backend import (
@@ -1325,15 +1379,15 @@ class ModelRunner:
                 )
 
                 return FlashInferMLAAttnBackend(self)
-        elif self.server_args.attention_backend == "aiter":
+        elif backend_str == "aiter":
            from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
 
            return AiterAttnBackend(self)
-        elif self.server_args.attention_backend == "ascend":
+        elif backend_str == "ascend":
            from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend
 
            return AscendAttnBackend(self)
-        elif self.server_args.attention_backend == "triton":
+        elif backend_str == "triton":
            assert not self.model_config.is_encoder_decoder, (
                "Cross attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
@@ -1348,17 +1402,17 @@ class ModelRunner:
            from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
 
            return TritonAttnBackend(self)
-        elif self.server_args.attention_backend == "torch_native":
+        elif backend_str == "torch_native":
            from sglang.srt.layers.attention.torch_native_backend import (
                TorchNativeAttnBackend,
            )
 
            return TorchNativeAttnBackend(self)
-        elif self.server_args.attention_backend == "flashmla":
+        elif backend_str == "flashmla":
            from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend
 
            return FlashMLABackend(self)
-        elif self.server_args.attention_backend == "fa3":
+        elif backend_str == "fa3":
            assert (
                torch.cuda.get_device_capability()[0] == 8 and not self.use_mla_backend
            ) or torch.cuda.get_device_capability()[0] == 9, (
@@ -1370,7 +1424,7 @@ class ModelRunner:
            )
 
            return FlashAttentionBackend(self)
-        elif self.server_args.attention_backend == "cutlass_mla":
+        elif backend_str == "cutlass_mla":
            from sglang.srt.layers.attention.cutlass_mla_backend import (
                CutlassMLABackend,
            )
@@ -1384,9 +1438,7 @@ class ModelRunner:
            logger.info(f"Intel AMX attention backend is enabled.")
            return IntelAMXAttnBackend(self)
         else:
-            raise ValueError(
-                f"Invalid attention backend: {self.server_args.attention_backend}"
-            )
+            raise ValueError(f"Invalid attention backend: {backend_str}")
 
     def init_double_sparsity_channel_config(self, selected_channel):
         selected_channel = "." + selected_channel + "_proj"
@@ -1462,15 +1514,22 @@ class ModelRunner:
             tensor_parallel(self.model, device_mesh)
 
     def forward_decode(
-        self, forward_batch: ForwardBatch, pp_proxy_tensors=None
+        self,
+        forward_batch: ForwardBatch,
+        skip_attn_backend_init: bool = False,
+        pp_proxy_tensors=None,
     ) -> LogitsProcessorOutput:
-        self.attn_backend.init_forward_metadata(forward_batch)
+        if not skip_attn_backend_init:
+            self.attn_backend.init_forward_metadata(forward_batch)
         # FIXME: add pp_proxy_tensors arg to all models
         kwargs = {}
         if self.support_pp:
             kwargs["pp_proxy_tensors"] = pp_proxy_tensors
         return self.model.forward(
-            forward_batch.input_ids, forward_batch.positions, forward_batch, **kwargs
+            forward_batch.input_ids,
+            forward_batch.positions,
+            forward_batch,
+            **kwargs,
         )
 
     def forward_extend(
@@ -1576,8 +1635,18 @@ class ModelRunner:
                 skip_attn_backend_init=skip_attn_backend_init,
                 pp_proxy_tensors=pp_proxy_tensors,
             )
-        elif forward_batch.forward_mode.is_decode():
-            ret = self.forward_decode(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
+            return ret, can_run_cuda_graph
+
+        # For MLP sync
+        if forward_batch.global_num_tokens_cpu is not None:
+            forward_batch.prepare_mlp_sync_batch(self)
+
+        if forward_batch.forward_mode.is_decode():
+            ret = self.forward_decode(
+                forward_batch,
+                skip_attn_backend_init=skip_attn_backend_init,
+                pp_proxy_tensors=pp_proxy_tensors,
+            )
         elif forward_batch.forward_mode.is_extend():
             ret = self.forward_extend(
                 forward_batch,
@@ -1595,6 +1664,9 @@ class ModelRunner:
         else:
             raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")
 
+        if forward_batch.global_num_tokens_cpu is not None:
+            forward_batch.post_forward_mlp_sync_batch(ret)
+
         return ret, can_run_cuda_graph
 
     def _preprocess_logits(
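Taken together, the last three hunks restructure `ModelRunner.forward`: CUDA-graph replays now return immediately, and any batch with `global_num_tokens_cpu` set is bracketed by the new MLP-sync hooks. A condensed sketch of the resulting flow for the decode path (simplified; the extend path and the CUDA-graph branch are elided):

    # Condensed shape of ModelRunner.forward after this change (decode path only).
    def forward_sketch(runner, forward_batch, skip_attn_backend_init=False):
        needs_mlp_sync = forward_batch.global_num_tokens_cpu is not None
        if needs_mlp_sync:
            forward_batch.prepare_mlp_sync_batch(runner)

        ret = runner.forward_decode(
            forward_batch, skip_attn_backend_init=skip_attn_backend_init
        )

        if needs_mlp_sync:
            forward_batch.post_forward_mlp_sync_batch(ret)
        return ret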
sglang/srt/models/deepseek_v2.py

@@ -56,7 +56,11 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
+from sglang.srt.layers.moe.ep_moe.layer import (
+    DeepEPMoE,
+    get_moe_impl_class,
+    use_flashinfer_trtllm_moe,
+)
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization import deep_gemm_wrapper
@@ -302,15 +306,19 @@ class DeepseekV2MoE(nn.Module):
             config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
         )
 
-        self.topk = TopK(
-            top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
-            renormalize=config.norm_topk_prob,
-            use_grouped_topk=True,
-            num_expert_group=config.n_group,
-            num_fused_shared_experts=self.num_fused_shared_experts,
-            topk_group=config.topk_group,
-            correction_bias=self.gate.e_score_correction_bias,
-            routed_scaling_factor=self.routed_scaling_factor,
+        self.topk = (
+            TopK(
+                top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
+                renormalize=config.norm_topk_prob,
+                use_grouped_topk=True,
+                num_expert_group=config.n_group,
+                num_fused_shared_experts=self.num_fused_shared_experts,
+                topk_group=config.topk_group,
+                correction_bias=self.gate.e_score_correction_bias,
+                routed_scaling_factor=self.routed_scaling_factor,
+            )
+            if not use_flashinfer_trtllm_moe
+            else None
         )
 
         self.experts = get_moe_impl_class()(
@@ -332,10 +340,22 @@ class DeepseekV2MoE(nn.Module):
             # Additional args for FusedMoE
             **(
                 dict(
-                    enable_flashinfer_moe=True,
+                    enable_flashinfer_cutlass_moe=True,
                     enable_ep_moe=global_server_args_dict["enable_ep_moe"],
                 )
-                if global_server_args_dict["enable_flashinfer_moe"]
+                if global_server_args_dict["enable_flashinfer_cutlass_moe"]
+                else {}
+            ),
+            **(
+                dict(
+                    renormalize=config.norm_topk_prob,
+                    use_grouped_topk=True,
+                    num_expert_group=config.n_group,
+                    num_fused_shared_experts=self.num_fused_shared_experts,
+                    topk_group=config.topk_group,
+                    correction_bias=self.gate.e_score_correction_bias,
+                )
+                if use_flashinfer_trtllm_moe
                 else {}
             ),
         )
@@ -455,10 +475,12 @@ class DeepseekV2MoE(nn.Module):
         with torch.cuda.stream(self.alt_stream):
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states)
-            topk_output = self.topk(hidden_states, router_logits)
-            final_hidden_states = self.experts(
-                hidden_states=hidden_states, topk_output=topk_output
-            )
+            kwargs = {"hidden_states": hidden_states}
+            if self.topk is not None:
+                kwargs["topk_output"] = self.topk(hidden_states, router_logits)
+            else:
+                kwargs["router_logits"] = router_logits
+            final_hidden_states = self.experts(**kwargs)
             if not _is_cuda:
                 final_hidden_states *= self.routed_scaling_factor
         current_stream.wait_stream(self.alt_stream)
@@ -478,10 +500,12 @@ class DeepseekV2MoE(nn.Module):
         shared_output = self._forward_shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
-        topk_output = self.topk(hidden_states, router_logits)
-        final_hidden_states = self.experts(
-            hidden_states=hidden_states, topk_output=topk_output
-        )
+        kwargs = {"hidden_states": hidden_states}
+        if self.topk is not None:
+            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
+        else:
+            kwargs["router_logits"] = router_logits
+        final_hidden_states = self.experts(**kwargs)
         if not _is_cuda and not _use_aiter:
             # fused in biased_grouped_topk so we can skip here
             final_hidden_states *= self.routed_scaling_factor
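The two hunks above apply the same pattern in both forward paths: when `self.topk` was constructed (the non-TRT-LLM path), the experts module receives a precomputed `topk_output`; when it is `None` (flashinfer TRT-LLM MoE), the raw `router_logits` are passed through instead. A generic sketch of that keyword dispatch, with stand-in callables:

    # Generic sketch of the conditional-kwargs dispatch used above; the callables
    # are stand-ins, and only the shape of the branching matches the diff.
    def run_experts(experts, hidden_states, router_logits, topk=None):
        kwargs = {"hidden_states": hidden_states}
        if topk is not None:
            # Routing is computed up front and handed to the experts module.
            kwargs["topk_output"] = topk(hidden_states, router_logits)
        else:
            # The experts implementation consumes the raw logits itself.
            kwargs["router_logits"] = router_logits
        return experts(**kwargs)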
@@ -550,9 +574,8 @@ class DeepseekV2MoE(nn.Module):
     def forward_deepep(
         self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
     ) -> torch.Tensor:
-        forward_mode = forward_batch.forward_mode
         shared_output = None
-        if is_non_idle_and_non_empty(forward_mode, hidden_states):
+        if hidden_states.shape[0] > 0:
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states)
             shared_output = self._forward_shared_experts(hidden_states)
@@ -902,7 +925,10 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.disable_chunked_prefix_cache = global_server_args_dict[
             "disable_chunked_prefix_cache"
         ]
-        self.attention_backend = global_server_args_dict["attention_backend"]
+
+        self.current_attention_backend = (
+            None  # Attention backend used by current forward batch
+        )
         self.rocm_fused_decode_mla = get_bool_env_var(
             "SGLANG_ROCM_FUSED_DECODE_MLA", "false"
         )
@@ -986,9 +1012,16 @@ class DeepseekV2AttentionMLA(nn.Module):
             else:
                 return AttnForwardMethod.MLA
 
-        if self.attention_backend == "ascend":
+        # Determine attention backend used by current forward batch
+        if forward_batch.forward_mode.is_decode_or_idle():
+            attention_backend = global_server_args_dict["decode_attention_backend"]
+        else:
+            attention_backend = global_server_args_dict["prefill_attention_backend"]
+        self.current_attention_backend = attention_backend
+
+        if attention_backend == "ascend":
             return AttnForwardMethod.MLA
-        elif self.attention_backend == "flashinfer":
+        elif attention_backend == "flashinfer":
             # Flashinfer MLA: Do not absorb when enabling ragged prefill
             if (
                 not self.flashinfer_mla_disable_ragged
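With hybrid backends, a single `attention_backend` string no longer describes the running batch, so the MLA dispatch above reads the mode-specific entry from `global_server_args_dict` instead of a fixed `self.attention_backend`. A small sketch of the selection rule (the dictionary keys match the diff; the surrounding objects are stand-ins):

    # Pick the backend name for the current batch based on its forward mode.
    def backend_for_batch(server_args_dict, forward_mode):
        if forward_mode.is_decode_or_idle():
            return server_args_dict["decode_attention_backend"]
        return server_args_dict["prefill_attention_backend"]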
@@ -1000,7 +1033,7 @@ class DeepseekV2AttentionMLA(nn.Module):
                 return AttnForwardMethod.MHA
             else:
                 return _dispatch_mla_subtype()
-        elif self.attention_backend == "fa3":
+        elif attention_backend == "fa3":
             # Flash Attention: Use MHA with chunked KV cache when prefilling on long sequences.
             if forward_batch.extend_prefix_lens_cpu is not None:
                 sum_extend_prefix_lens = sum(forward_batch.extend_prefix_lens_cpu)
@@ -1017,7 +1050,7 @@ class DeepseekV2AttentionMLA(nn.Module):
                 return AttnForwardMethod.MHA_CHUNKED_KV
             else:
                 return _dispatch_mla_subtype()
-        elif self.attention_backend == "aiter":
+        elif attention_backend == "aiter":
             if (
                 forward_batch.forward_mode.is_extend()
                 and not forward_batch.forward_mode.is_target_verify()
@@ -1265,9 +1298,9 @@ class DeepseekV2AttentionMLA(nn.Module):
         self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
     ):
         if (
-            self.attention_backend == "fa3"
-            or self.attention_backend == "flashinfer"
-            or self.attention_backend == "cutlass_mla"
+            self.current_attention_backend == "fa3"
+            or self.current_attention_backend == "flashinfer"
+            or self.current_attention_backend == "cutlass_mla"
         ):
             attn_output = self.attn_mqa(
                 q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe