sglang 0.4.9.post3__py3-none-any.whl → 0.4.9.post5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/lang/chat_template.py +21 -0
- sglang/srt/_custom_ops.py +29 -1
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/model_config.py +5 -1
- sglang/srt/constrained/base_grammar_backend.py +10 -2
- sglang/srt/constrained/xgrammar_backend.py +7 -5
- sglang/srt/conversation.py +17 -2
- sglang/srt/debug_utils/__init__.py +0 -0
- sglang/srt/debug_utils/dump_comparator.py +131 -0
- sglang/srt/debug_utils/dumper.py +108 -0
- sglang/srt/debug_utils/text_comparator.py +172 -0
- sglang/srt/disaggregation/common/conn.py +34 -6
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +13 -1
- sglang/srt/disaggregation/mini_lb.py +3 -2
- sglang/srt/disaggregation/mooncake/conn.py +65 -20
- sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
- sglang/srt/disaggregation/nixl/conn.py +17 -13
- sglang/srt/disaggregation/prefill.py +13 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
- sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
- sglang/srt/distributed/parallel_state.py +70 -15
- sglang/srt/entrypoints/engine.py +5 -9
- sglang/srt/entrypoints/http_server.py +20 -32
- sglang/srt/entrypoints/openai/protocol.py +3 -3
- sglang/srt/entrypoints/openai/serving_chat.py +148 -72
- sglang/srt/function_call/base_format_detector.py +74 -12
- sglang/srt/function_call/deepseekv3_detector.py +26 -11
- sglang/srt/function_call/ebnf_composer.py +105 -66
- sglang/srt/function_call/function_call_parser.py +6 -4
- sglang/srt/function_call/glm4_moe_detector.py +164 -0
- sglang/srt/function_call/kimik2_detector.py +41 -16
- sglang/srt/function_call/llama32_detector.py +6 -3
- sglang/srt/function_call/mistral_detector.py +11 -3
- sglang/srt/function_call/pythonic_detector.py +16 -14
- sglang/srt/function_call/qwen25_detector.py +12 -3
- sglang/srt/function_call/{qwen3_detector.py → qwen3_coder_detector.py} +11 -9
- sglang/srt/layers/activation.py +11 -3
- sglang/srt/layers/attention/base_attn_backend.py +3 -1
- sglang/srt/layers/attention/hybrid_attn_backend.py +100 -0
- sglang/srt/layers/attention/vision.py +56 -8
- sglang/srt/layers/communicator.py +12 -12
- sglang/srt/layers/dp_attention.py +72 -24
- sglang/srt/layers/layernorm.py +26 -1
- sglang/srt/layers/logits_processor.py +46 -25
- sglang/srt/layers/moe/ep_moe/layer.py +172 -206
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +25 -224
- sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -48
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +11 -8
- sglang/srt/layers/moe/topk.py +88 -34
- sglang/srt/layers/multimodal.py +11 -8
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -9
- sglang/srt/layers/quantization/fp8.py +25 -247
- sglang/srt/layers/quantization/fp8_kernel.py +78 -48
- sglang/srt/layers/quantization/modelopt_quant.py +33 -14
- sglang/srt/layers/quantization/unquant.py +24 -76
- sglang/srt/layers/quantization/utils.py +0 -9
- sglang/srt/layers/quantization/w4afp8.py +68 -17
- sglang/srt/layers/radix_attention.py +5 -3
- sglang/srt/lora/lora_manager.py +133 -169
- sglang/srt/lora/lora_registry.py +188 -0
- sglang/srt/lora/mem_pool.py +2 -2
- sglang/srt/managers/cache_controller.py +62 -13
- sglang/srt/managers/io_struct.py +19 -1
- sglang/srt/managers/mm_utils.py +154 -35
- sglang/srt/managers/multimodal_processor.py +3 -14
- sglang/srt/managers/schedule_batch.py +27 -11
- sglang/srt/managers/scheduler.py +48 -26
- sglang/srt/managers/tokenizer_manager.py +62 -28
- sglang/srt/managers/tp_worker.py +5 -4
- sglang/srt/mem_cache/allocator.py +67 -7
- sglang/srt/mem_cache/hicache_storage.py +17 -1
- sglang/srt/mem_cache/hiradix_cache.py +35 -18
- sglang/srt/mem_cache/memory_pool_host.py +3 -0
- sglang/srt/model_executor/cuda_graph_runner.py +61 -25
- sglang/srt/model_executor/forward_batch_info.py +201 -29
- sglang/srt/model_executor/model_runner.py +109 -37
- sglang/srt/models/deepseek_v2.py +63 -30
- sglang/srt/models/glm4_moe.py +1035 -0
- sglang/srt/models/glm4_moe_nextn.py +167 -0
- sglang/srt/models/interns1.py +328 -0
- sglang/srt/models/internvl.py +143 -47
- sglang/srt/models/llava.py +9 -5
- sglang/srt/models/minicpmo.py +4 -1
- sglang/srt/models/mllama4.py +10 -3
- sglang/srt/models/qwen2_moe.py +2 -6
- sglang/srt/models/qwen3_moe.py +6 -8
- sglang/srt/multimodal/processors/base_processor.py +20 -6
- sglang/srt/multimodal/processors/clip.py +2 -2
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +2 -2
- sglang/srt/multimodal/processors/gemma3.py +2 -2
- sglang/srt/multimodal/processors/gemma3n.py +2 -2
- sglang/srt/multimodal/processors/internvl.py +21 -8
- sglang/srt/multimodal/processors/janus_pro.py +2 -2
- sglang/srt/multimodal/processors/kimi_vl.py +2 -2
- sglang/srt/multimodal/processors/llava.py +4 -4
- sglang/srt/multimodal/processors/minicpm.py +2 -3
- sglang/srt/multimodal/processors/mlama.py +2 -2
- sglang/srt/multimodal/processors/mllama4.py +18 -111
- sglang/srt/multimodal/processors/phi4mm.py +2 -2
- sglang/srt/multimodal/processors/pixtral.py +2 -2
- sglang/srt/multimodal/processors/qwen_audio.py +2 -2
- sglang/srt/multimodal/processors/qwen_vl.py +2 -2
- sglang/srt/multimodal/processors/vila.py +3 -1
- sglang/srt/reasoning_parser.py +48 -5
- sglang/srt/sampling/sampling_batch_info.py +6 -5
- sglang/srt/server_args.py +132 -60
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +37 -36
- sglang/srt/speculative/eagle_utils.py +51 -23
- sglang/srt/speculative/eagle_worker.py +59 -44
- sglang/srt/two_batch_overlap.py +9 -5
- sglang/srt/utils.py +113 -69
- sglang/srt/weight_sync/utils.py +119 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_activation.py +50 -1
- sglang/test/test_utils.py +65 -5
- sglang/utils.py +19 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/METADATA +6 -6
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/RECORD +127 -114
- sglang/srt/debug_utils.py +0 -74
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/model_runner.py
CHANGED
@@ -68,6 +68,7 @@ from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
 from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.lora.lora_manager import LoRAManager
+from sglang.srt.lora.lora_registry import LoRARef
 from sglang.srt.managers.schedule_batch import (
     GLOBAL_SERVER_ARGS_KEYS,
     global_server_args_dict,
@@ -108,7 +109,6 @@ from sglang.srt.utils import (
     get_bool_env_var,
     get_cpu_ids_by_node,
     init_custom_process_group,
-    is_cuda,
     is_fa3_default_architecture,
     is_flashinfer_available,
     is_hip,
@@ -275,6 +275,7 @@ class ModelRunner:
         self.sampler = Sampler()
         self.load_model()

+        # Check if the model is using hybrid SWA
         if (
             not self.server_args.disable_hybrid_swa_memory
             and self.sliding_window_size is not None
@@ -377,6 +378,7 @@ class ModelRunner:
                 is_hopper_with_cuda_12_3()
                 and is_no_spec_infer_or_topk_one(server_args)
                 and is_fa3_default_architecture(self.model_config.hf_config)
+                and (not server_args.enable_hierarchical_cache)
             ):
                 server_args.attention_backend = "fa3"
             elif _is_hip:
@@ -389,7 +391,9 @@ class ModelRunner:
                 )
             else:
                 # MLA architecture
-                if is_hopper_with_cuda_12_3()
+                if is_hopper_with_cuda_12_3() and (
+                    not server_args.enable_hierarchical_cache
+                ):
                     server_args.attention_backend = "fa3"
                 elif is_sm100_supported():
                     server_args.attention_backend = "flashinfer"
@@ -890,44 +894,38 @@ class ModelRunner:
             tp_rank=self.tp_rank,
             max_lora_rank=self.server_args.max_lora_rank,
             target_modules=self.server_args.lora_target_modules,
+            lora_paths=self.server_args.lora_paths,
         )
-        result = self.lora_manager.load_lora_adapters(self.server_args.lora_paths or {})
-        if result.success:
-            logger.info(
-                f"LoRA manager ready. Loaded LoRA adapters: {', '.join(result.loaded_adapters)}"
-            )
-        else:
-            raise RuntimeError(f"Failed to load LoRA adapters: {result.error_message}")

-    def load_lora_adapter(self,
+    def load_lora_adapter(self, lora_ref: LoRARef):
         """Load a new lora adapter from disk or huggingface."""

         logger.info(
-            f"LoRA adapter loading starts:
+            f"LoRA adapter loading starts: {lora_ref}. "
             f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
         )

-        result = self.lora_manager.load_lora_adapter(
+        result = self.lora_manager.load_lora_adapter(lora_ref)

         logger.info(
-            f"LoRA adapter loading completes:
+            f"LoRA adapter loading completes: {lora_ref}. "
             f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
         )

         return result

-    def unload_lora_adapter(self,
+    def unload_lora_adapter(self, lora_ref: LoRARef):
         """Unload a lora adapter that was previously loaded during initialization or dynamic loading."""

         logger.info(
-            f"LoRA adapter unloading starts:
+            f"LoRA adapter unloading starts: {lora_ref}. "
             f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
         )

-        result = self.lora_manager.unload_lora_adapter(
+        result = self.lora_manager.unload_lora_adapter(lora_ref)

         logger.info(
-            f"LoRA adapter unloading completes:
+            f"LoRA adapter unloading completes: {lora_ref}. "
             f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
         )
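For orientation, a minimal sketch of the per-adapter flow this hunk introduces: initial `lora_paths` are handed to the LoRAManager constructor, and dynamic load/unload now take a single `LoRARef` argument. `LoRARefLike` below is a stand-in with assumed fields, since the contents of the new `lora_registry.py` are not shown in this diff.

```python
# Minimal sketch (not sglang's actual API surface): the per-adapter
# load/unload flow now exposed by ModelRunner. LoRARefLike stands in for
# sglang.srt.lora.lora_registry.LoRARef, whose exact fields are not visible here.
from dataclasses import dataclass


@dataclass
class LoRARefLike:
    lora_name: str  # assumed field
    lora_path: str  # assumed field


def swap_adapter(model_runner, old_ref: LoRARefLike, new_ref: LoRARefLike):
    """Unload one adapter and load another through the new single-ref methods."""
    model_runner.unload_lora_adapter(old_ref)  # new signature: one LoRARef-style object
    return model_runner.load_lora_adapter(new_ref)
```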
@@ -1010,8 +1008,11 @@ class ModelRunner:
         try:
             layers = self.model.language_model.model.layers
         except:
-
-
+            try:
+                layers = self.model.language_model.layers
+            except:
+                self.is_hybrid = False
+                return

         for layer in layers:
             if (
@@ -1307,9 +1308,58 @@ class ModelRunner:
         else:
             self.attn_backend = self._get_attention_backend()

-    # TODO unify with 6338
     def _get_attention_backend(self):
-
+        """Init attention kernel backend."""
+        self.decode_attention_backend_str = (
+            self.server_args.decode_attention_backend
+            if self.server_args.decode_attention_backend
+            else self.server_args.attention_backend
+        )
+        self.prefill_attention_backend_str = (
+            self.server_args.prefill_attention_backend
+            if self.server_args.prefill_attention_backend
+            else self.server_args.attention_backend
+        )
+        if self.decode_attention_backend_str != self.prefill_attention_backend_str:
+            assert (
+                self.server_args.speculative_algorithm is None
+            ), "Currently HybridAttentionBackend does not support speculative decoding."
+            from sglang.srt.layers.attention.hybrid_attn_backend import (
+                HybridAttnBackend,
+            )
+
+            attn_backend = HybridAttnBackend(
+                decode_backend=self._get_attention_backend_from_str(
+                    self.decode_attention_backend_str
+                ),
+                prefill_backend=self._get_attention_backend_from_str(
+                    self.prefill_attention_backend_str
+                ),
+            )
+            logger.info(
+                f"Using hybrid attention backend for decode and prefill: "
+                f"decode_backend={self.decode_attention_backend_str}, "
+                f"prefill_backend={self.prefill_attention_backend_str}."
+            )
+            logger.warning(
+                f"Warning: Attention backend specified by --attention-backend or default backend might be overridden."
+                f"The feature of hybrid attention backend is experimental and unstable. Please raise an issue if you encounter any problem."
+            )
+        else:
+            attn_backend = self._get_attention_backend_from_str(
+                self.server_args.attention_backend
+            )
+
+        global_server_args_dict.update(
+            {
+                "decode_attention_backend": self.decode_attention_backend_str,
+                "prefill_attention_backend": self.prefill_attention_backend_str,
+            }
+        )
+        return attn_backend
+
+    def _get_attention_backend_from_str(self, backend_str: str):
+        if backend_str == "flashinfer":
             if not self.use_mla_backend:
                 from sglang.srt.layers.attention.flashinfer_backend import (
                     FlashInferAttnBackend,
@@ -1317,7 +1367,11 @@ class ModelRunner:

                 # Init streams
                 if self.server_args.speculative_algorithm == "EAGLE":
-
+                    if (
+                        not hasattr(self, "plan_stream_for_flashinfer")
+                        or not self.plan_stream_for_flashinfer
+                    ):
+                        self.plan_stream_for_flashinfer = torch.cuda.Stream()
                 return FlashInferAttnBackend(self)
             else:
                 from sglang.srt.layers.attention.flashinfer_mla_backend import (
@@ -1325,15 +1379,15 @@ class ModelRunner:
                 )

                 return FlashInferMLAAttnBackend(self)
-        elif
+        elif backend_str == "aiter":
            from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend

            return AiterAttnBackend(self)
-        elif
+        elif backend_str == "ascend":
            from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend

            return AscendAttnBackend(self)
-        elif
+        elif backend_str == "triton":
            assert not self.model_config.is_encoder_decoder, (
                "Cross attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
@@ -1348,17 +1402,17 @@ class ModelRunner:
            from sglang.srt.layers.attention.triton_backend import TritonAttnBackend

            return TritonAttnBackend(self)
-        elif
+        elif backend_str == "torch_native":
            from sglang.srt.layers.attention.torch_native_backend import (
                TorchNativeAttnBackend,
            )

            return TorchNativeAttnBackend(self)
-        elif
+        elif backend_str == "flashmla":
            from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend

            return FlashMLABackend(self)
-        elif
+        elif backend_str == "fa3":
            assert (
                torch.cuda.get_device_capability()[0] == 8 and not self.use_mla_backend
            ) or torch.cuda.get_device_capability()[0] == 9, (
@@ -1370,7 +1424,7 @@ class ModelRunner:
            )

            return FlashAttentionBackend(self)
-        elif
+        elif backend_str == "cutlass_mla":
            from sglang.srt.layers.attention.cutlass_mla_backend import (
                CutlassMLABackend,
            )
@@ -1384,9 +1438,7 @@ class ModelRunner:
            logger.info(f"Intel AMX attention backend is enabled.")
            return IntelAMXAttnBackend(self)
        else:
-            raise ValueError(
-                f"Invalid attention backend: {self.server_args.attention_backend}"
-            )
+            raise ValueError(f"Invalid attention backend: {backend_str}")

    def init_double_sparsity_channel_config(self, selected_channel):
        selected_channel = "." + selected_channel + "_proj"
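The hunks above split attention-backend selection by phase: the decode and prefill backends each fall back to `--attention-backend`, and a `HybridAttnBackend` wrapper is used only when they differ. A minimal sketch of that resolution rule, restated outside ModelRunner; `ServerArgsLike` is a stand-in for sglang's ServerArgs with only the three fields used here.

```python
# Sketch of the backend-resolution rule from _get_attention_backend above.
# ServerArgsLike is an illustrative stand-in, not sglang's ServerArgs class.
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class ServerArgsLike:
    attention_backend: str = "flashinfer"
    decode_attention_backend: Optional[str] = None
    prefill_attention_backend: Optional[str] = None


def resolve_backends(args: ServerArgsLike) -> Tuple[str, str, bool]:
    """Return (decode_backend, prefill_backend, is_hybrid)."""
    decode = args.decode_attention_backend or args.attention_backend
    prefill = args.prefill_attention_backend or args.attention_backend
    # When the two differ, ModelRunner wraps them in HybridAttnBackend
    # (flagged as experimental by the warning in the diff).
    return decode, prefill, decode != prefill


if __name__ == "__main__":
    args = ServerArgsLike(attention_backend="fa3", decode_attention_backend="flashinfer")
    print(resolve_backends(args))  # ('flashinfer', 'fa3', True)
```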
@@ -1462,15 +1514,22 @@ class ModelRunner:
         tensor_parallel(self.model, device_mesh)

     def forward_decode(
-        self,
+        self,
+        forward_batch: ForwardBatch,
+        skip_attn_backend_init: bool = False,
+        pp_proxy_tensors=None,
     ) -> LogitsProcessorOutput:
-
+        if not skip_attn_backend_init:
+            self.attn_backend.init_forward_metadata(forward_batch)
         # FIXME: add pp_proxy_tensors arg to all models
         kwargs = {}
         if self.support_pp:
             kwargs["pp_proxy_tensors"] = pp_proxy_tensors
         return self.model.forward(
-            forward_batch.input_ids,
+            forward_batch.input_ids,
+            forward_batch.positions,
+            forward_batch,
+            **kwargs,
         )

     def forward_extend(
@@ -1576,8 +1635,18 @@ class ModelRunner:
                 skip_attn_backend_init=skip_attn_backend_init,
                 pp_proxy_tensors=pp_proxy_tensors,
             )
-
-
+            return ret, can_run_cuda_graph
+
+        # For MLP sync
+        if forward_batch.global_num_tokens_cpu is not None:
+            forward_batch.prepare_mlp_sync_batch(self)
+
+        if forward_batch.forward_mode.is_decode():
+            ret = self.forward_decode(
+                forward_batch,
+                skip_attn_backend_init=skip_attn_backend_init,
+                pp_proxy_tensors=pp_proxy_tensors,
+            )
         elif forward_batch.forward_mode.is_extend():
             ret = self.forward_extend(
                 forward_batch,
@@ -1595,6 +1664,9 @@ class ModelRunner:
         else:
             raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")

+        if forward_batch.global_num_tokens_cpu is not None:
+            forward_batch.post_forward_mlp_sync_batch(ret)
+
         return ret, can_run_cuda_graph

     def _preprocess_logits(
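The last two hunks move MLP-sync handling into `ModelRunner.forward`: when `global_num_tokens_cpu` is set, the batch runs `prepare_mlp_sync_batch` before the decode/extend dispatch and `post_forward_mlp_sync_batch` after it. A minimal sketch of that ordering with stub objects (the stubs are illustrative, not sglang APIs):

```python
# Sketch of the control flow shown above: pre-sync, model call, post-sync.
class _FakeBatch:
    def __init__(self, global_num_tokens_cpu=None):
        self.global_num_tokens_cpu = global_num_tokens_cpu
        self.trace = []

    def prepare_mlp_sync_batch(self, runner):
        self.trace.append("prepare_mlp_sync_batch")

    def post_forward_mlp_sync_batch(self, ret):
        self.trace.append("post_forward_mlp_sync_batch")


def forward_with_mlp_sync(runner, batch, run_model):
    """Mirror of ModelRunner.forward's new ordering (illustrative only)."""
    if batch.global_num_tokens_cpu is not None:
        batch.prepare_mlp_sync_batch(runner)
    ret = run_model(batch)  # stands in for forward_decode / forward_extend
    if batch.global_num_tokens_cpu is not None:
        batch.post_forward_mlp_sync_batch(ret)
    return ret


if __name__ == "__main__":
    b = _FakeBatch(global_num_tokens_cpu=[8])
    forward_with_mlp_sync(None, b, lambda batch: "logits")
    print(b.trace)  # ['prepare_mlp_sync_batch', 'post_forward_mlp_sync_batch']
```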
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -56,7 +56,11 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import
+from sglang.srt.layers.moe.ep_moe.layer import (
+    DeepEPMoE,
+    get_moe_impl_class,
+    use_flashinfer_trtllm_moe,
+)
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization import deep_gemm_wrapper
@@ -302,15 +306,19 @@ class DeepseekV2MoE(nn.Module):
             config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
         )

-        self.topk =
-
-
-
-
-
-
-
-
+        self.topk = (
+            TopK(
+                top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
+                renormalize=config.norm_topk_prob,
+                use_grouped_topk=True,
+                num_expert_group=config.n_group,
+                num_fused_shared_experts=self.num_fused_shared_experts,
+                topk_group=config.topk_group,
+                correction_bias=self.gate.e_score_correction_bias,
+                routed_scaling_factor=self.routed_scaling_factor,
+            )
+            if not use_flashinfer_trtllm_moe
+            else None
         )

         self.experts = get_moe_impl_class()(
@@ -332,10 +340,22 @@ class DeepseekV2MoE(nn.Module):
             # Additional args for FusedMoE
             **(
                 dict(
-
+                    enable_flashinfer_cutlass_moe=True,
                     enable_ep_moe=global_server_args_dict["enable_ep_moe"],
                 )
-                if global_server_args_dict["
+                if global_server_args_dict["enable_flashinfer_cutlass_moe"]
+                else {}
+            ),
+            **(
+                dict(
+                    renormalize=config.norm_topk_prob,
+                    use_grouped_topk=True,
+                    num_expert_group=config.n_group,
+                    num_fused_shared_experts=self.num_fused_shared_experts,
+                    topk_group=config.topk_group,
+                    correction_bias=self.gate.e_score_correction_bias,
+                )
+                if use_flashinfer_trtllm_moe
                 else {}
             ),
         )
@@ -455,10 +475,12 @@ class DeepseekV2MoE(nn.Module):
         with torch.cuda.stream(self.alt_stream):
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states)
-
-
-
-
+            kwargs = {"hidden_states": hidden_states}
+            if self.topk is not None:
+                kwargs["topk_output"] = self.topk(hidden_states, router_logits)
+            else:
+                kwargs["router_logits"] = router_logits
+            final_hidden_states = self.experts(**kwargs)
             if not _is_cuda:
                 final_hidden_states *= self.routed_scaling_factor
         current_stream.wait_stream(self.alt_stream)
@@ -478,10 +500,12 @@ class DeepseekV2MoE(nn.Module):
         shared_output = self._forward_shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
-
-
-
-
+        kwargs = {"hidden_states": hidden_states}
+        if self.topk is not None:
+            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
+        else:
+            kwargs["router_logits"] = router_logits
+        final_hidden_states = self.experts(**kwargs)
         if not _is_cuda and not _use_aiter:
             # fused in biased_grouped_topk so we can skip here
             final_hidden_states *= self.routed_scaling_factor
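The two forward hunks above route arguments to the experts module conditionally: when `self.topk` was constructed, a precomputed `topk_output` is passed; when it is `None` (the flashinfer TRT-LLM MoE path), raw `router_logits` are forwarded instead. A minimal sketch of that dispatch pattern; the callables here are stand-ins, not the real TopK or MoE kernels.

```python
# Sketch of the argument-routing pattern used by DeepseekV2MoE.forward above.
from typing import Callable, Optional


def run_experts(
    hidden_states,
    router_logits,
    experts: Callable[..., object],
    topk: Optional[Callable[..., object]] = None,
):
    kwargs = {"hidden_states": hidden_states}
    if topk is not None:
        # top-k routing is computed up front and handed to the MoE kernel
        kwargs["topk_output"] = topk(hidden_states, router_logits)
    else:
        # e.g. a backend that consumes router logits directly (TRT-LLM MoE path)
        kwargs["router_logits"] = router_logits
    return experts(**kwargs)
```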
@@ -550,9 +574,8 @@ class DeepseekV2MoE(nn.Module):
     def forward_deepep(
         self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
     ) -> torch.Tensor:
-        forward_mode = forward_batch.forward_mode
         shared_output = None
-        if
+        if hidden_states.shape[0] > 0:
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states)
             shared_output = self._forward_shared_experts(hidden_states)
@@ -902,7 +925,10 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.disable_chunked_prefix_cache = global_server_args_dict[
             "disable_chunked_prefix_cache"
         ]
-
+
+        self.current_attention_backend = (
+            None  # Attention backend used by current forward batch
+        )
         self.rocm_fused_decode_mla = get_bool_env_var(
             "SGLANG_ROCM_FUSED_DECODE_MLA", "false"
         )
@@ -986,9 +1012,16 @@ class DeepseekV2AttentionMLA(nn.Module):
             else:
                 return AttnForwardMethod.MLA

-
+        # Determine attention backend used by current forward batch
+        if forward_batch.forward_mode.is_decode_or_idle():
+            attention_backend = global_server_args_dict["decode_attention_backend"]
+        else:
+            attention_backend = global_server_args_dict["prefill_attention_backend"]
+        self.current_attention_backend = attention_backend
+
+        if attention_backend == "ascend":
             return AttnForwardMethod.MLA
-        elif
+        elif attention_backend == "flashinfer":
             # Flashinfer MLA: Do not absorb when enabling ragged prefill
             if (
                 not self.flashinfer_mla_disable_ragged
@@ -1000,7 +1033,7 @@ class DeepseekV2AttentionMLA(nn.Module):
                 return AttnForwardMethod.MHA
             else:
                 return _dispatch_mla_subtype()
-        elif
+        elif attention_backend == "fa3":
             # Flash Attention: Use MHA with chunked KV cache when prefilling on long sequences.
             if forward_batch.extend_prefix_lens_cpu is not None:
                 sum_extend_prefix_lens = sum(forward_batch.extend_prefix_lens_cpu)
@@ -1017,7 +1050,7 @@ class DeepseekV2AttentionMLA(nn.Module):
                 return AttnForwardMethod.MHA_CHUNKED_KV
             else:
                 return _dispatch_mla_subtype()
-        elif
+        elif attention_backend == "aiter":
             if (
                 forward_batch.forward_mode.is_extend()
                 and not forward_batch.forward_mode.is_target_verify()
@@ -1265,9 +1298,9 @@ class DeepseekV2AttentionMLA(nn.Module):
         self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
     ):
         if (
-            self.
-            or self.
-            or self.
+            self.current_attention_backend == "fa3"
+            or self.current_attention_backend == "flashinfer"
+            or self.current_attention_backend == "cutlass_mla"
         ):
             attn_output = self.attn_mqa(
                 q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe