sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -3
- sglang/bench_one_batch.py +6 -0
- sglang/lang/chat_template.py +18 -0
- sglang/srt/bench_utils.py +137 -0
- sglang/srt/configs/model_config.py +7 -7
- sglang/srt/disaggregation/decode.py +8 -3
- sglang/srt/disaggregation/mooncake/conn.py +43 -25
- sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
- sglang/srt/distributed/parallel_state.py +4 -2
- sglang/srt/entrypoints/context.py +3 -20
- sglang/srt/entrypoints/engine.py +13 -8
- sglang/srt/entrypoints/harmony_utils.py +2 -0
- sglang/srt/entrypoints/http_server.py +4 -5
- sglang/srt/entrypoints/openai/protocol.py +0 -9
- sglang/srt/entrypoints/openai/serving_chat.py +59 -265
- sglang/srt/entrypoints/openai/tool_server.py +4 -3
- sglang/srt/function_call/ebnf_composer.py +1 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +331 -0
- sglang/srt/function_call/kimik2_detector.py +3 -3
- sglang/srt/function_call/qwen3_coder_detector.py +219 -9
- sglang/srt/jinja_template_utils.py +6 -0
- sglang/srt/layers/attention/aiter_backend.py +370 -107
- sglang/srt/layers/attention/ascend_backend.py +3 -0
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +18 -0
- sglang/srt/layers/attention/flashinfer_backend.py +52 -13
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
- sglang/srt/layers/attention/vision.py +9 -1
- sglang/srt/layers/attention/wave_backend.py +627 -0
- sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
- sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
- sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
- sglang/srt/layers/communicator.py +8 -10
- sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/moe/cutlass_moe.py +11 -16
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
- sglang/srt/layers/moe/ep_moe/layer.py +60 -2
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
- sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
- sglang/srt/layers/moe/topk.py +4 -1
- sglang/srt/layers/quantization/__init__.py +5 -3
- sglang/srt/layers/quantization/fp8_kernel.py +277 -0
- sglang/srt/layers/quantization/fp8_utils.py +22 -10
- sglang/srt/layers/quantization/modelopt_quant.py +6 -11
- sglang/srt/layers/quantization/mxfp4.py +4 -1
- sglang/srt/layers/quantization/w4afp8.py +20 -11
- sglang/srt/layers/quantization/w8a8_int8.py +48 -34
- sglang/srt/layers/rotary_embedding.py +281 -2
- sglang/srt/lora/backend/base_backend.py +3 -23
- sglang/srt/lora/layers.py +60 -114
- sglang/srt/lora/lora.py +17 -62
- sglang/srt/lora/lora_manager.py +12 -48
- sglang/srt/lora/lora_registry.py +20 -9
- sglang/srt/lora/mem_pool.py +20 -63
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/utils.py +25 -58
- sglang/srt/managers/cache_controller.py +21 -29
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +6 -6
- sglang/srt/managers/mm_utils.py +1 -2
- sglang/srt/managers/multimodal_processor.py +1 -1
- sglang/srt/managers/schedule_batch.py +35 -20
- sglang/srt/managers/schedule_policy.py +6 -6
- sglang/srt/managers/scheduler.py +15 -7
- sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
- sglang/srt/managers/tokenizer_manager.py +25 -26
- sglang/srt/mem_cache/allocator.py +61 -87
- sglang/srt/mem_cache/hicache_storage.py +1 -1
- sglang/srt/mem_cache/hiradix_cache.py +34 -24
- sglang/srt/mem_cache/lora_radix_cache.py +421 -0
- sglang/srt/mem_cache/memory_pool_host.py +33 -35
- sglang/srt/mem_cache/radix_cache.py +2 -5
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
- sglang/srt/model_executor/cuda_graph_runner.py +22 -3
- sglang/srt/model_executor/forward_batch_info.py +26 -5
- sglang/srt/model_executor/model_runner.py +129 -35
- sglang/srt/model_loader/loader.py +18 -6
- sglang/srt/models/deepseek_v2.py +74 -35
- sglang/srt/models/gemma2.py +0 -34
- sglang/srt/models/gemma3n_mm.py +8 -9
- sglang/srt/models/glm4.py +6 -0
- sglang/srt/models/glm4_moe.py +9 -9
- sglang/srt/models/glm4v.py +589 -0
- sglang/srt/models/glm4v_moe.py +400 -0
- sglang/srt/models/gpt_oss.py +136 -19
- sglang/srt/models/granite.py +0 -25
- sglang/srt/models/llama.py +0 -25
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/qwen2_5_vl.py +7 -3
- sglang/srt/models/qwen2_audio.py +10 -9
- sglang/srt/models/qwen3.py +0 -24
- sglang/srt/models/registry.py +1 -1
- sglang/srt/models/torch_native_llama.py +0 -24
- sglang/srt/multimodal/processors/base_processor.py +23 -13
- sglang/srt/multimodal/processors/glm4v.py +132 -0
- sglang/srt/multimodal/processors/qwen_audio.py +4 -2
- sglang/srt/reasoning_parser.py +316 -0
- sglang/srt/server_args.py +115 -139
- sglang/srt/speculative/eagle_worker.py +16 -0
- sglang/srt/two_batch_overlap.py +12 -4
- sglang/srt/utils.py +3 -3
- sglang/srt/weight_sync/tensor_bucket.py +106 -0
- sglang/test/attention/test_trtllm_mla_backend.py +186 -36
- sglang/test/doc_patch.py +59 -0
- sglang/test/few_shot_gsm8k.py +1 -1
- sglang/test/few_shot_gsm8k_engine.py +1 -1
- sglang/test/run_eval.py +4 -1
- sglang/test/simple_eval_common.py +6 -0
- sglang/test/simple_eval_gpqa.py +2 -0
- sglang/test/test_fp4_moe.py +118 -36
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +26 -30
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +127 -115
- sglang/lang/backend/__init__.py +0 -0
- sglang/srt/function_call/harmony_tool_parser.py +0 -130
- sglang/srt/lora/backend/flashinfer_backend.py +0 -131
- /sglang/{api.py → lang/api.py} +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/model_runner.py
CHANGED
@@ -121,6 +121,10 @@ from sglang.srt.utils import (
     set_cpu_offload_max_bytes,
     set_cuda_arch,
 )
+from sglang.srt.weight_sync.tensor_bucket import (
+    FlattenedTensorBucket,
+    FlattenedTensorMetadata,
+)
 
 _is_hip = is_hip()
 _is_npu = is_npu()
@@ -378,6 +382,25 @@ class ModelRunner:
             )
             server_args.attention_backend = "torch_native"
 
+        if server_args.prefill_attention_backend is not None and (
+            server_args.prefill_attention_backend
+            == server_args.decode_attention_backend
+        ):  # override the default attention backend
+            server_args.attention_backend = server_args.prefill_attention_backend
+
+        if (
+            getattr(self.model_config.hf_config, "dual_chunk_attention_config", None)
+            is not None
+        ):
+            if server_args.attention_backend is None:
+                server_args.attention_backend = "dual_chunk_flash_attn"
+                logger.info("Dual chunk attention is turned on by default.")
+            elif server_args.attention_backend != "dual_chunk_flash_attn":
+                raise ValueError(
+                    "Dual chunk attention is enabled, but attention backend is set to "
+                    f"{server_args.attention_backend}. Please set it to 'dual_chunk_flash_attn'."
+                )
+
         if server_args.attention_backend is None:
             """
             Auto select the fastest attention backend.
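For orientation, a minimal sketch of the backend-resolution rule the hunk above adds, written as a standalone helper (the function name and signature are assumptions for illustration, not sglang API): a dual-chunk-attention config either fills in the backend when none was requested, or rejects any other explicit choice.

from typing import Optional


def resolve_attention_backend(
    requested: Optional[str], has_dual_chunk_config: bool
) -> Optional[str]:
    # Illustrative helper mirroring the hunk above (not sglang API):
    # dual chunk attention forces "dual_chunk_flash_attn".
    if has_dual_chunk_config:
        if requested is None:
            return "dual_chunk_flash_attn"
        if requested != "dual_chunk_flash_attn":
            raise ValueError(
                f"Dual chunk attention requires 'dual_chunk_flash_attn', got {requested}"
            )
    return requested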
@@ -397,7 +420,6 @@ class ModelRunner:
                 is_hopper_with_cuda_12_3()
                 and is_no_spec_infer_or_topk_one(server_args)
                 and is_fa3_default_architecture(self.model_config.hf_config)
-                and (not server_args.enable_hierarchical_cache)
             ):
                 server_args.attention_backend = "fa3"
             elif _is_hip:
@@ -410,9 +432,7 @@ class ModelRunner:
                 )
             else:
                 # MLA architecture
-                if is_hopper_with_cuda_12_3() and (
-                    not server_args.enable_hierarchical_cache
-                ):
+                if is_hopper_with_cuda_12_3():
                     server_args.attention_backend = "fa3"
                 elif is_sm100_supported():
                     server_args.attention_backend = "flashinfer"
@@ -500,6 +520,27 @@ class ModelRunner:
         if self.model_config.context_len > 8192:
             self.mem_fraction_static *= 0.85
 
+        if (
+            server_args.enable_hierarchical_cache
+            and server_args.hicache_io_backend == "kernel"
+        ):
+            # fix for the compatibility issue with FlashAttention3 decoding and HiCache kernel backend
+            if server_args.decode_attention_backend is None:
+                if not self.use_mla_backend:
+                    server_args.decode_attention_backend = (
+                        "flashinfer" if is_flashinfer_available() else "triton"
+                    )
+                else:
+                    server_args.decode_attention_backend = (
+                        "flashinfer" if is_sm100_supported() else "triton"
+                    )
+            elif server_args.decode_attention_backend == "fa3":
+                server_args.hicache_io_backend = "direct"
+                logger.warning(
+                    "FlashAttention3 decode backend is not compatible with hierarchical cache. "
+                    f"Setting hicache_io_backend to vanilla I/O, which may lead to suboptimal performance with small page sizes."
+                )
+
     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")
 
@@ -871,8 +912,18 @@ class ModelRunner:
         named_tensors: List[Tuple[str, Union[torch.Tensor, "LocalSerializedTensor"]]],
         load_format: Optional[str] = None,
     ):
+        monkey_patch_torch_reductions()
+        if load_format == "flattened_bucket":
+            # Handle flattened bucket format
+            return self._update_weights_from_flattened_bucket(
+                flattened_tensor_bucket_dict=named_tensors
+            )
+
+        # We need to get device after patch otherwise the device would be wrong
+        infered_device = torch.cuda.current_device()
+
         named_tensors = [
-            (name, _unwrap_tensor(tensor, tp_rank=self.tp_rank))
+            (name, _unwrap_tensor(tensor, tp_rank=self.tp_rank, device=infered_device))
             for name, tensor in named_tensors
         ]
         if load_format == "direct":
@@ -886,6 +937,38 @@ class ModelRunner:
             raise NotImplementedError(f"Unknown load_format={load_format}")
         return True, "Success"
 
+    def _update_weights_from_flattened_bucket(
+        self,
+        flattened_tensor_bucket_dict,
+    ):
+        """Handle flattened bucket format for weight updates"""
+        flattened_tensor = flattened_tensor_bucket_dict["flattened_tensor"]
+        metadata = flattened_tensor_bucket_dict["metadata"]
+
+        # Convert metadata dict to our format
+        converted_metadata = []
+        for meta in metadata:
+            converted_meta = FlattenedTensorMetadata(
+                name=meta.name,
+                shape=meta.shape,
+                dtype=meta.dtype,
+                start_idx=meta.start_idx,
+                end_idx=meta.end_idx,
+                numel=meta.numel,
+            )
+            converted_metadata.append(converted_meta)
+
+        # Create bucket and reconstruct tensors
+        bucket = FlattenedTensorBucket(
+            flattened_tensor=flattened_tensor, metadata=converted_metadata
+        )
+        reconstructed_tensors = bucket.reconstruct_tensors()
+
+        # Load the reconstructed tensors using the standard method
+        self.model.load_weights(reconstructed_tensors)
+
+        return True, "Success"
+
     def get_weights_by_name(
         self, name: str, truncate_size: int = 100
     ) -> Optional[torch.Tensor]:
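For context, a minimal, self-contained sketch of the flattened-bucket idea this new code path relies on. This is an independent toy implementation under assumed semantics, not the sglang.srt.weight_sync.tensor_bucket API: many tensors are packed into one flat buffer plus per-tensor metadata, and reconstruction is a slice-and-view per entry (the sketch assumes all tensors share a dtype and device).

from dataclasses import dataclass
from typing import List, Tuple

import torch


@dataclass
class ToyTensorMetadata:
    name: str
    shape: torch.Size
    dtype: torch.dtype
    start_idx: int
    end_idx: int
    numel: int


def flatten_bucket(named: List[Tuple[str, torch.Tensor]]):
    # Concatenate all tensors into one 1-D buffer and record where each lives.
    # Assumes a uniform dtype/device across tensors for simplicity.
    metadata, chunks, offset = [], [], 0
    for name, t in named:
        flat = t.reshape(-1)
        metadata.append(
            ToyTensorMetadata(name, t.shape, t.dtype, offset, offset + flat.numel(), flat.numel())
        )
        chunks.append(flat)
        offset += flat.numel()
    return torch.cat(chunks), metadata


def reconstruct(flattened: torch.Tensor, metadata: List[ToyTensorMetadata]):
    # Slice the flat buffer back into named tensors with their original shapes.
    return [(m.name, flattened[m.start_idx : m.end_idx].view(m.shape)) for m in metadata]


# Round-trip check on toy weights.
named = [("w1", torch.randn(4, 4)), ("b1", torch.randn(4))]
flat, meta = flatten_bucket(named)
restored = reconstruct(flat, meta)
assert all(torch.equal(a[1], b[1]) for a, b in zip(named, restored))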
@@ -1181,30 +1264,33 @@ class ModelRunner:
         # Draft worker shares req_to_token_pool with the target worker.
         assert self.is_draft_worker
 
-        if self.server_args.attention_backend == "ascend"
+        if self.server_args.attention_backend == "ascend":
+            if self.use_mla_backend:
+                self.token_to_kv_pool = AscendMLAPagedTokenToKVPool(
+                    self.max_total_num_tokens,
+                    page_size=self.page_size,
+                    dtype=self.kv_cache_dtype,
+                    kv_lora_rank=self.model_config.kv_lora_rank,
+                    qk_rope_head_dim=self.model_config.qk_rope_head_dim,
+                    layer_num=self.num_effective_layers,
+                    device=self.device,
+                    enable_memory_saver=self.server_args.enable_memory_saver,
+                    start_layer=self.start_layer,
+                    end_layer=self.end_layer,
+                )
+            else:
+                self.token_to_kv_pool = AscendTokenToKVPool(
+                    self.max_total_num_tokens,
+                    page_size=self.page_size,
+                    dtype=self.kv_cache_dtype,
+                    head_num=self.model_config.get_num_kv_heads(
+                        get_attention_tp_size()
+                    ),
+                    head_dim=self.model_config.head_dim,
+                    layer_num=self.model_config.num_hidden_layers,
+                    device=self.device,
+                    enable_memory_saver=self.server_args.enable_memory_saver,
+                )
         elif self.use_mla_backend:
             self.token_to_kv_pool = MLATokenToKVPool(
                 self.max_total_num_tokens,
@@ -1263,6 +1349,7 @@ class ModelRunner:
                 end_layer=self.end_layer,
             )
 
+        need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
         if self.token_to_kv_pool_allocator is None:
             if self.page_size == 1:
                 if self.is_hybrid:
@@ -1272,6 +1359,7 @@ class ModelRunner:
                         dtype=self.kv_cache_dtype,
                         device=self.device,
                         kvcache=self.token_to_kv_pool,
+                        need_sort=need_sort,
                     )
                 else:
                     self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
@@ -1279,23 +1367,26 @@ class ModelRunner:
                         dtype=self.kv_cache_dtype,
                         device=self.device,
                         kvcache=self.token_to_kv_pool,
+                        need_sort=need_sort,
                     )
             else:
-                if _is_npu:
-                    self.token_to_kv_pool_allocator =
+                if not _is_npu:
+                    self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
                         self.max_total_num_tokens,
                         page_size=self.page_size,
                         dtype=self.kv_cache_dtype,
                         device=self.device,
                         kvcache=self.token_to_kv_pool,
+                        need_sort=need_sort,
                     )
                 else:
-                    self.token_to_kv_pool_allocator =
+                    self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
                         self.max_total_num_tokens,
                         page_size=self.page_size,
                         dtype=self.kv_cache_dtype,
                         device=self.device,
                         kvcache=self.token_to_kv_pool,
+                        need_sort=need_sort,
                    )
         else:
             assert self.is_draft_worker
@@ -1396,6 +1487,10 @@ class ModelRunner:
             from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
 
             return AiterAttnBackend(self)
+        elif self.server_args.attention_backend == "wave":
+            from sglang.srt.layers.attention.wave_backend import WaveAttnBackend
+
+            return WaveAttnBackend(self)
         elif backend_str == "ascend":
             from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend
 
@@ -1785,11 +1880,10 @@ def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
         default_weight_loader(params_dict[name], tensor)
 
 
-def _unwrap_tensor(tensor, tp_rank):
+def _unwrap_tensor(tensor, tp_rank, device):
     if isinstance(tensor, LocalSerializedTensor):
-        monkey_patch_torch_reductions()
         tensor = tensor.get(tp_rank)
-    return tensor.to(
+    return tensor.to(device)
 
 
 @dataclass
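A side note on the _unwrap_tensor change above: the target device is now resolved once by the caller, after monkey_patch_torch_reductions() has run, and passed down, instead of being looked up per tensor. A rough standalone illustration of that pattern follows; the helper names here are hypothetical and not sglang API.

import torch


def unwrap_all(named_tensors, tp_rank, get_local_shard):
    # Hypothetical helper: resolve the device a single time, after any
    # process-wide patching, then reuse it for every tensor in the batch.
    device = (
        torch.device("cuda", torch.cuda.current_device())
        if torch.cuda.is_available()
        else torch.device("cpu")
    )
    return [
        (name, get_local_shard(tensor, tp_rank).to(device))
        for name, tensor in named_tensors
    ]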
sglang/srt/model_loader/loader.py
CHANGED
@@ -162,12 +162,24 @@ def _initialize_model(
     model_class, _ = get_model_architecture(model_config)
     packed_modules_mapping = getattr(model_class, "packed_modules_mapping", {})
     if _is_npu:
-        packed_modules_mapping
+        packed_modules_mapping.update(
+            {
+                "visual": {"qkv_proj": ["qkv"]},
+                "vision_model": {
+                    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+                    "proj": ["out_proj"],
+                },
+                "model": {
+                    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+                    "gate_up_proj": ["gate_proj", "up_proj"],
+                    "fused_qkv_a_proj_with_mqa": [
+                        "q_a_proj",
+                        "kv_a_proj_with_mqa",
+                    ],
+                },
+            }
+        )
+
     quant_config = _get_quantization_config(
         model_config, load_config, packed_modules_mapping
     )
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -212,7 +212,7 @@ class DeepseekV2MLP(nn.Module):
         self,
         x,
         forward_batch=None,
+        should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
     ):
         if (self.tp_size == 1) and x.shape[0] == 0:
@@ -221,7 +221,7 @@ class DeepseekV2MLP(nn.Module):
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(
-            x, skip_all_reduce=
+            x, skip_all_reduce=should_allreduce_fusion or use_reduce_scatter
         )
         return x
 
@@ -448,7 +448,7 @@ class DeepseekV2MoE(nn.Module):
         self,
         hidden_states: torch.Tensor,
         forward_batch: Optional[ForwardBatch] = None,
+        should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
         if not self._enable_deepep_moe:
@@ -459,11 +459,11 @@ class DeepseekV2MoE(nn.Module):
                 and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
             ):
                 return self.forward_normal_dual_stream(
-                    hidden_states,
+                    hidden_states, should_allreduce_fusion, use_reduce_scatter
                 )
             else:
                 return self.forward_normal(
-                    hidden_states,
+                    hidden_states, should_allreduce_fusion, use_reduce_scatter
                 )
         else:
             return self.forward_deepep(hidden_states, forward_batch)
@@ -471,7 +471,7 @@ class DeepseekV2MoE(nn.Module):
     def forward_normal_dual_stream(
         self,
         hidden_states: torch.Tensor,
+        should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
 
@@ -500,20 +500,20 @@ class DeepseekV2MoE(nn.Module):
         torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
         final_hidden_states = final_hidden_states_out
         sm.tag(final_hidden_states)
-        if self.tp_size > 1 and not
+        if self.tp_size > 1 and not should_allreduce_fusion and not use_reduce_scatter:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states
 
     def forward_normal(
         self,
         hidden_states: torch.Tensor,
+        should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
         if hasattr(self, "shared_experts") and use_intel_amx_backend(
             self.shared_experts.gate_up_proj
         ):
-            return self.forward_cpu(hidden_states,
+            return self.forward_cpu(hidden_states, should_allreduce_fusion)
 
         shared_output = self._forward_shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
@@ -537,12 +537,14 @@ class DeepseekV2MoE(nn.Module):
         torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
         final_hidden_states = final_hidden_states_out
         sm.tag(final_hidden_states)
-        if self.tp_size > 1 and not
+        if self.tp_size > 1 and not should_allreduce_fusion and not use_reduce_scatter:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states
 
     def forward_cpu(
-        self,
+        self,
+        hidden_states: torch.Tensor,
+        should_allreduce_fusion: bool = False,
     ) -> torch.Tensor:
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
@@ -593,7 +595,7 @@ class DeepseekV2MoE(nn.Module):
             None,  # a2_scale
             True,  # is_vnni
         )
-        if self.tp_size > 1 and not
+        if self.tp_size > 1 and not should_allreduce_fusion:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states
 
@@ -1194,6 +1196,16 @@ class DeepseekV2AttentionMLA(nn.Module):
         output, _ = self.o_proj(attn_output)
         return output
 
+    def _fuse_rope_for_trtllm_mla(self, forward_batch: ForwardBatch) -> bool:
+        """
+        Check if we should skip rope and do fused rope+quantize for TRTLLM MLA decode in fp8_e4m3 path.
+        """
+        return (
+            self.current_attention_backend == "trtllm_mla"
+            and forward_batch.forward_mode.is_decode_or_idle()
+            and forward_batch.attn_backend.data_type == torch.float8_e4m3fn
+        )
+
     def forward_absorb_prepare(
         self,
         positions: torch.Tensor,
@@ -1273,7 +1285,9 @@ class DeepseekV2AttentionMLA(nn.Module):
             q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
 
         q_nope_out = q_nope_out.transpose(0, 1)
-        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
+
+        if not self._fuse_rope_for_trtllm_mla(forward_batch):
+            q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
 
         return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
 
@@ -1286,8 +1300,20 @@ class DeepseekV2AttentionMLA(nn.Module):
             or self.current_attention_backend == "cutlass_mla"
             or self.current_attention_backend == "trtllm_mla"
         ):
+            extra_args = {}
+            if self._fuse_rope_for_trtllm_mla(forward_batch):
+                extra_args = {
+                    "cos_sin_cache": self.rotary_emb.cos_sin_cache,
+                    "is_neox": self.rotary_emb.is_neox_style,
+                }
             attn_output = self.attn_mqa(
-                q_nope_out,
+                q_nope_out,
+                k_nope,
+                k_nope,
+                forward_batch,
+                q_rope=q_pe,
+                k_rope=k_pe,
+                **extra_args,
             )
         else:
             q = torch.cat([q_nope_out, q_pe], dim=-1)
@@ -1842,6 +1868,8 @@ class DeepseekV2DecoderLayer(nn.Module):
             allow_reduce_scatter=True,
         )
 
+        self._fuse_allreduce_lookup_table = self._build_fuse_allreduce_lookup_table()
+
     def _is_layer_sparse(self, layer_id: int, is_nextn: bool) -> bool:
         return is_nextn or (
             self.config.n_routed_experts is not None
@@ -1850,27 +1878,18 @@ class DeepseekV2DecoderLayer(nn.Module):
         )
 
     def _should_fuse_mlp_allreduce_with_next_layer(self, forward_batch) -> bool:
-        """Check if MLP allreduce can be fused with next layer's
-        if (
-            self.layer_id == self.config.num_hidden_layers - 1
-            or get_tensor_model_parallel_world_size() <= 1
-        ):
-            return False
-
-        if not global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False):
-            return False
+        """Check if MLP allreduce can be fused with next layer's residual_rmsnorm"""
 
+        batch_size = (
+            forward_batch.input_ids.shape[0]
+            if hasattr(forward_batch, "input_ids")
+            else 0
+        )
 
-        if (
-            forward_batch.input_ids.shape[0] == 0
-            or forward_batch.input_ids.shape[0] > 128
-        ):
+        if batch_size > 128:
             return False
 
-        return
+        return self._fuse_allreduce_lookup_table.get(batch_size, False)
 
     def forward(
         self,
@@ -1896,7 +1915,7 @@ class DeepseekV2DecoderLayer(nn.Module):
             hidden_states, residual, forward_batch
         )
 
+        should_allreduce_fusion = (
             self._should_fuse_mlp_allreduce_with_next_layer(forward_batch)
             and not (self.enable_dp_attention and self.speculative_algorithm.is_eagle())
             and not self.is_nextn
@@ -1907,13 +1926,13 @@ class DeepseekV2DecoderLayer(nn.Module):
             forward_batch
         )
         hidden_states = self.mlp(
-            hidden_states, forward_batch,
+            hidden_states, forward_batch, should_allreduce_fusion, use_reduce_scatter
         )
 
-        if
+        if should_allreduce_fusion:
             hidden_states._sglang_needs_allreduce_fusion = True
 
-        if not
+        if not should_allreduce_fusion:
             hidden_states, residual = self.layer_communicator.postprocess_layer(
                 hidden_states, residual, forward_batch
            )
@@ -1990,6 +2009,26 @@ class DeepseekV2DecoderLayer(nn.Module):
         )
         return output
 
+    def _build_fuse_allreduce_lookup_table(self):
+        static_conditions_met = (
+            self.layer_id != self.config.num_hidden_layers - 1
+            and get_tensor_model_parallel_world_size() > 1
+            and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
+            and _is_sm100_supported
+            and _is_flashinfer_available
+        )
+
+        if not static_conditions_met:
+            return {}
+
+        lookup_table = {}
+        for batch_size in range(129):  # 0 to 128
+            is_last_layer = self.layer_id == self.config.num_hidden_layers - 1
+            should_fuse = batch_size > 0 and batch_size <= 128 and not is_last_layer
+            lookup_table[batch_size] = should_fuse
+
+        return lookup_table
+
 
 class DeepseekV2Model(nn.Module):
     fall_back_to_pt_during_load = False
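The _build_fuse_allreduce_lookup_table change above trades a handful of per-forward checks for a dict precomputed at layer construction time, keyed by batch size. A small standalone sketch of the same pattern (names and thresholds here are illustrative only, not the DeepseekV2 implementation):

from typing import Dict


def build_fuse_lookup(max_batch: int, static_ok: bool) -> Dict[int, bool]:
    # Evaluate everything that never changes at runtime once; per-step code
    # then reduces to a single dict lookup keyed by batch size.
    if not static_ok:
        return {}
    return {bs: bs > 0 for bs in range(max_batch + 1)}


def should_fuse(table: Dict[int, bool], batch_size: int, max_batch: int = 128) -> bool:
    if batch_size > max_batch:
        return False
    return table.get(batch_size, False)


# Example: with static conditions satisfied, batch sizes 1..128 fuse; 0 and >128 do not.
table = build_fuse_lookup(128, static_ok=True)
assert should_fuse(table, 16) and not should_fuse(table, 0) and not should_fuse(table, 512)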
sglang/srt/models/gemma2.py
CHANGED
@@ -432,40 +432,6 @@ class Gemma2ForCausalLM(nn.Module):
 
         return result
 
-    def get_hidden_dim(self, module_name):
-        # return input_dim, output_dim
-        if module_name in ["q_proj", "qkv_proj"]:
-            return (
-                self.config.hidden_size,
-                self.config.head_dim * self.config.num_attention_heads,
-            )
-        elif module_name in ["o_proj"]:
-            return (
-                self.config.head_dim * self.config.num_attention_heads,
-                self.config.hidden_size,
-            )
-        elif module_name in ["kv_proj"]:
-            return (
-                self.config.hidden_size,
-                self.config.head_dim * self.config.num_key_value_heads,
-            )
-        elif module_name == "gate_up_proj":
-            return self.config.hidden_size, self.config.intermediate_size
-        elif module_name == "down_proj":
-            return self.config.intermediate_size, self.config.hidden_size
-        else:
-            raise NotImplementedError()
-
-    def get_module_name(self, name):
-        params_mapping = {
-            "q_proj": "qkv_proj",
-            "k_proj": "qkv_proj",
-            "v_proj": "qkv_proj",
-            "gate_proj": "gate_up_proj",
-            "up_proj": "gate_up_proj",
-        }
-        return params_mapping.get(name, name)
-
     def get_attention_sliding_window_size(self):
         return get_attention_sliding_window_size(self.config)
 
sglang/srt/models/gemma3n_mm.py
CHANGED
@@ -501,27 +501,26 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
 
     def get_hidden_dim(self, module_name):
         # return input_dim, output_dim
-        if module_name
+        if module_name == "qkv_proj":
             return (
                 self.config.hidden_size,
-                self.config.head_dim
+                self.config.head_dim
+                * (
+                    self.config.num_attention_heads
+                    + self.config.num_key_value_heads * 2
+                ),
             )
-        elif module_name
+        elif module_name == "o_proj":
             return (
                 self.config.head_dim * self.config.num_attention_heads,
                 self.config.hidden_size,
             )
-        elif module_name in ["kv_proj"]:
-            return (
-                self.config.hidden_size,
-                self.config.head_dim * self.config.num_key_value_heads,
-            )
         elif module_name == "gate_up_proj":
             assert len(set(self.config.intermediate_size)) == 1, (
                 "Currently SGLang requires uniform intermediate size for all layers. "
                 "Please file an issue if you need support for non-uniform intermediate sizes."
             )
-            return self.config.hidden_size, self.config.intermediate_size[0]
+            return self.config.hidden_size, self.config.intermediate_size[0] * 2
         elif module_name == "down_proj":
             assert len(set(self.config.intermediate_size)) == 1, (
                 "Currently SGLang requires uniform intermediate size for all layers. "
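The corrected dimensions above reflect that qkv_proj and gate_up_proj are fused projections: the qkv_proj output covers Q plus both K and V heads, and gate_up_proj covers the gate and up halves. A quick arithmetic check with made-up config values (purely illustrative, not Gemma3n's real configuration):

# Hypothetical config values, only to show the arithmetic of the fused shapes.
hidden_size = 2048
head_dim = 256
num_attention_heads = 8
num_key_value_heads = 2
intermediate_size = 8192

# qkv_proj output covers Q (num_attention_heads) plus K and V (num_key_value_heads each).
qkv_out = head_dim * (num_attention_heads + num_key_value_heads * 2)
assert qkv_out == 256 * 12 == 3072

# gate_up_proj output covers both the gate and up halves of the MLP.
gate_up_out = intermediate_size * 2
assert gate_up_out == 16384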
sglang/srt/models/glm4.py
CHANGED
@@ -218,6 +218,12 @@ class Glm4Model(nn.Module):
 
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.embed_tokens
+
+    def dtype(self) -> torch.dtype:
+        return next(self.parameters()).dtype
+
     @torch.no_grad()
     def forward(
         self,