sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl
This diff compares two publicly released versions of the sglang package as they appear in their public registry. It is provided for informational purposes only.
- sglang/__init__.py +8 -3
- sglang/bench_one_batch.py +119 -17
- sglang/lang/chat_template.py +18 -0
- sglang/srt/bench_utils.py +137 -0
- sglang/srt/configs/model_config.py +42 -7
- sglang/srt/conversation.py +9 -5
- sglang/srt/disaggregation/base/conn.py +5 -2
- sglang/srt/disaggregation/decode.py +14 -4
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
- sglang/srt/disaggregation/mooncake/conn.py +286 -160
- sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
- sglang/srt/disaggregation/prefill.py +2 -0
- sglang/srt/distributed/parallel_state.py +15 -11
- sglang/srt/entrypoints/context.py +227 -0
- sglang/srt/entrypoints/engine.py +15 -9
- sglang/srt/entrypoints/harmony_utils.py +372 -0
- sglang/srt/entrypoints/http_server.py +74 -4
- sglang/srt/entrypoints/openai/protocol.py +218 -1
- sglang/srt/entrypoints/openai/serving_chat.py +41 -11
- sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
- sglang/srt/entrypoints/openai/tool_server.py +175 -0
- sglang/srt/entrypoints/tool.py +87 -0
- sglang/srt/eplb/expert_location.py +5 -1
- sglang/srt/function_call/ebnf_composer.py +1 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +331 -0
- sglang/srt/function_call/kimik2_detector.py +3 -3
- sglang/srt/function_call/qwen3_coder_detector.py +219 -9
- sglang/srt/hf_transformers_utils.py +30 -3
- sglang/srt/jinja_template_utils.py +14 -1
- sglang/srt/layers/attention/aiter_backend.py +375 -115
- sglang/srt/layers/attention/ascend_backend.py +3 -0
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
- sglang/srt/layers/attention/flashattention_backend.py +18 -0
- sglang/srt/layers/attention/flashinfer_backend.py +52 -13
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +85 -14
- sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
- sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
- sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
- sglang/srt/layers/attention/vision.py +22 -6
- sglang/srt/layers/attention/wave_backend.py +627 -0
- sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
- sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
- sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
- sglang/srt/layers/communicator.py +29 -14
- sglang/srt/layers/dp_attention.py +12 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
- sglang/srt/layers/linear.py +3 -7
- sglang/srt/layers/moe/cutlass_moe.py +12 -3
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
- sglang/srt/layers/moe/ep_moe/layer.py +135 -73
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
- sglang/srt/layers/moe/fused_moe_triton/layer.py +412 -33
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
- sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
- sglang/srt/layers/moe/topk.py +16 -4
- sglang/srt/layers/moe/utils.py +16 -0
- sglang/srt/layers/quantization/__init__.py +27 -3
- sglang/srt/layers/quantization/fp4.py +557 -0
- sglang/srt/layers/quantization/fp8.py +3 -6
- sglang/srt/layers/quantization/fp8_kernel.py +277 -0
- sglang/srt/layers/quantization/fp8_utils.py +51 -10
- sglang/srt/layers/quantization/modelopt_quant.py +258 -68
- sglang/srt/layers/quantization/mxfp4.py +654 -0
- sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
- sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
- sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
- sglang/srt/layers/quantization/quark/utils.py +107 -0
- sglang/srt/layers/quantization/unquant.py +60 -6
- sglang/srt/layers/quantization/w4afp8.py +21 -12
- sglang/srt/layers/quantization/w8a8_int8.py +48 -34
- sglang/srt/layers/rotary_embedding.py +506 -3
- sglang/srt/layers/utils.py +9 -0
- sglang/srt/layers/vocab_parallel_embedding.py +8 -3
- sglang/srt/lora/backend/base_backend.py +3 -23
- sglang/srt/lora/layers.py +60 -114
- sglang/srt/lora/lora.py +17 -62
- sglang/srt/lora/lora_manager.py +82 -62
- sglang/srt/lora/lora_registry.py +23 -11
- sglang/srt/lora/mem_pool.py +63 -68
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/utils.py +25 -58
- sglang/srt/managers/cache_controller.py +75 -58
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +20 -8
- sglang/srt/managers/mm_utils.py +6 -13
- sglang/srt/managers/multimodal_processor.py +1 -1
- sglang/srt/managers/schedule_batch.py +61 -25
- sglang/srt/managers/schedule_policy.py +6 -6
- sglang/srt/managers/scheduler.py +41 -19
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
- sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
- sglang/srt/managers/scheduler_recv_skipper.py +37 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
- sglang/srt/managers/template_manager.py +35 -1
- sglang/srt/managers/tokenizer_manager.py +47 -30
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
- sglang/srt/mem_cache/allocator.py +61 -87
- sglang/srt/mem_cache/hicache_storage.py +1 -1
- sglang/srt/mem_cache/hiradix_cache.py +80 -22
- sglang/srt/mem_cache/lora_radix_cache.py +421 -0
- sglang/srt/mem_cache/memory_pool_host.py +34 -36
- sglang/srt/mem_cache/multimodal_cache.py +33 -13
- sglang/srt/mem_cache/radix_cache.py +2 -5
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
- sglang/srt/model_executor/cuda_graph_runner.py +29 -9
- sglang/srt/model_executor/forward_batch_info.py +61 -19
- sglang/srt/model_executor/model_runner.py +148 -37
- sglang/srt/model_loader/loader.py +18 -6
- sglang/srt/model_loader/weight_utils.py +10 -0
- sglang/srt/models/bailing_moe.py +425 -0
- sglang/srt/models/deepseek_v2.py +137 -59
- sglang/srt/models/ernie4.py +426 -0
- sglang/srt/models/ernie4_eagle.py +203 -0
- sglang/srt/models/gemma2.py +0 -34
- sglang/srt/models/gemma3n_mm.py +38 -0
- sglang/srt/models/glm4.py +6 -0
- sglang/srt/models/glm4_moe.py +28 -16
- sglang/srt/models/glm4v.py +589 -0
- sglang/srt/models/glm4v_moe.py +400 -0
- sglang/srt/models/gpt_oss.py +1251 -0
- sglang/srt/models/granite.py +0 -25
- sglang/srt/models/llama.py +0 -25
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_5_vl.py +7 -3
- sglang/srt/models/qwen2_audio.py +10 -9
- sglang/srt/models/qwen2_moe.py +6 -0
- sglang/srt/models/qwen3.py +0 -24
- sglang/srt/models/qwen3_moe.py +32 -6
- sglang/srt/models/registry.py +1 -1
- sglang/srt/models/step3_vl.py +9 -0
- sglang/srt/models/torch_native_llama.py +0 -24
- sglang/srt/models/transformers.py +2 -5
- sglang/srt/multimodal/processors/base_processor.py +23 -13
- sglang/srt/multimodal/processors/glm4v.py +132 -0
- sglang/srt/multimodal/processors/qwen_audio.py +4 -2
- sglang/srt/multimodal/processors/step3_vl.py +3 -1
- sglang/srt/reasoning_parser.py +332 -37
- sglang/srt/server_args.py +186 -75
- sglang/srt/speculative/eagle_worker.py +16 -0
- sglang/srt/two_batch_overlap.py +169 -9
- sglang/srt/utils.py +41 -5
- sglang/srt/weight_sync/tensor_bucket.py +106 -0
- sglang/test/attention/test_trtllm_mla_backend.py +186 -36
- sglang/test/doc_patch.py +59 -0
- sglang/test/few_shot_gsm8k.py +1 -1
- sglang/test/few_shot_gsm8k_engine.py +1 -1
- sglang/test/run_eval.py +4 -1
- sglang/test/runners.py +2 -2
- sglang/test/simple_eval_common.py +6 -0
- sglang/test/simple_eval_gpqa.py +2 -0
- sglang/test/test_fp4_moe.py +118 -36
- sglang/test/test_utils.py +1 -1
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +36 -38
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +174 -141
- sglang/srt/lora/backend/flashinfer_backend.py +0 -131
- /sglang/{api.py → lang/api.py} +0 -0
- /sglang/{lang/backend → srt/layers/quantization/quark}/__init__.py +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
@@ -121,6 +121,10 @@ from sglang.srt.utils import (
     set_cpu_offload_max_bytes,
     set_cuda_arch,
 )
+from sglang.srt.weight_sync.tensor_bucket import (
+    FlattenedTensorBucket,
+    FlattenedTensorMetadata,
+)
 
 _is_hip = is_hip()
 _is_npu = is_npu()
@@ -378,6 +382,25 @@ class ModelRunner:
             )
             server_args.attention_backend = "torch_native"
 
+        if server_args.prefill_attention_backend is not None and (
+            server_args.prefill_attention_backend
+            == server_args.decode_attention_backend
+        ):  # override the default attention backend
+            server_args.attention_backend = server_args.prefill_attention_backend
+
+        if (
+            getattr(self.model_config.hf_config, "dual_chunk_attention_config", None)
+            is not None
+        ):
+            if server_args.attention_backend is None:
+                server_args.attention_backend = "dual_chunk_flash_attn"
+                logger.info("Dual chunk attention is turned on by default.")
+            elif server_args.attention_backend != "dual_chunk_flash_attn":
+                raise ValueError(
+                    "Dual chunk attention is enabled, but attention backend is set to "
+                    f"{server_args.attention_backend}. Please set it to 'dual_chunk_flash_attn'."
+                )
+
         if server_args.attention_backend is None:
             """
             Auto select the fastest attention backend.
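The hunk above auto-selects the new dual-chunk FlashAttention backend whenever the model config carries `dual_chunk_attention_config`, and rejects a conflicting explicit backend choice. Below is a minimal, hedged sketch of opting in explicitly through the offline engine; the `attention_backend` keyword mirrors the `ServerArgs` field used in the hunk, and the model path is only a placeholder for a checkpoint whose config defines dual-chunk attention.

```python
# Sketch only: explicit opt-in to the dual-chunk backend via the offline engine.
# Assumes the checkpoint's config.json defines `dual_chunk_attention_config`;
# the model path is a placeholder, not a recommendation.
import sglang as sgl

if __name__ == "__main__":
    engine = sgl.Engine(
        model_path="/path/to/dual-chunk-model",
        attention_backend="dual_chunk_flash_attn",
    )
    out = engine.generate("Summarize the last section:", {"max_new_tokens": 32})
    print(out["text"])
    engine.shutdown()
```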
@@ -397,7 +420,6 @@ class ModelRunner:
                 is_hopper_with_cuda_12_3()
                 and is_no_spec_infer_or_topk_one(server_args)
                 and is_fa3_default_architecture(self.model_config.hf_config)
-                and (not server_args.enable_hierarchical_cache)
             ):
                 server_args.attention_backend = "fa3"
             elif _is_hip:
@@ -410,9 +432,7 @@ class ModelRunner:
                 )
         else:
             # MLA architecture
-            if is_hopper_with_cuda_12_3() and (
-                not server_args.enable_hierarchical_cache
-            ):
+            if is_hopper_with_cuda_12_3():
                 server_args.attention_backend = "fa3"
             elif is_sm100_supported():
                 server_args.attention_backend = "flashinfer"
@@ -500,6 +520,27 @@ class ModelRunner:
         if self.model_config.context_len > 8192:
             self.mem_fraction_static *= 0.85
 
+        if (
+            server_args.enable_hierarchical_cache
+            and server_args.hicache_io_backend == "kernel"
+        ):
+            # fix for the compatibility issue with FlashAttention3 decoding and HiCache kernel backend
+            if server_args.decode_attention_backend is None:
+                if not self.use_mla_backend:
+                    server_args.decode_attention_backend = (
+                        "flashinfer" if is_flashinfer_available() else "triton"
+                    )
+                else:
+                    server_args.decode_attention_backend = (
+                        "flashinfer" if is_sm100_supported() else "triton"
+                    )
+            elif server_args.decode_attention_backend == "fa3":
+                server_args.hicache_io_backend = "direct"
+                logger.warning(
+                    "FlashAttention3 decode backend is not compatible with hierarchical cache. "
+                    f"Setting hicache_io_backend to vanilla I/O, which may lead to suboptimal performance with small page sizes."
+                )
+
     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")
 
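The compatibility rule above can be read as a small pure function: when hierarchical cache uses the kernel IO backend, an unset decode backend is filled in with an FA3-free default, while an explicit `fa3` choice keeps the decode backend and downgrades the IO backend to `direct` instead. The helper below is an illustrative restatement under that reading, not the shipped code; only the flag names come from the hunk.

```python
# Hedged sketch: the decode-backend / hicache IO resolution rule, extracted as a
# standalone helper for readability. Illustrative only.
def resolve_decode_backend(enable_hicache, hicache_io_backend, decode_backend,
                           use_mla, flashinfer_ok, sm100_ok):
    """Return (decode_backend, hicache_io_backend) after the compatibility fixups."""
    if not (enable_hicache and hicache_io_backend == "kernel"):
        return decode_backend, hicache_io_backend
    if decode_backend is None:
        # FA3 decoding cannot be paired with the kernel IO backend, so pick an
        # alternative decode backend up front.
        if not use_mla:
            decode_backend = "flashinfer" if flashinfer_ok else "triton"
        else:
            decode_backend = "flashinfer" if sm100_ok else "triton"
    elif decode_backend == "fa3":
        # Keep the user's decode choice and fall back to direct (vanilla) IO instead.
        hicache_io_backend = "direct"
    return decode_backend, hicache_io_backend
```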
@@ -871,8 +912,18 @@ class ModelRunner:
         named_tensors: List[Tuple[str, Union[torch.Tensor, "LocalSerializedTensor"]]],
         load_format: Optional[str] = None,
     ):
+        monkey_patch_torch_reductions()
+        if load_format == "flattened_bucket":
+            # Handle flattened bucket format
+            return self._update_weights_from_flattened_bucket(
+                flattened_tensor_bucket_dict=named_tensors
+            )
+
+        # We need to get device after patch otherwise the device would be wrong
+        infered_device = torch.cuda.current_device()
+
         named_tensors = [
-            (name, _unwrap_tensor(tensor, tp_rank=self.tp_rank))
+            (name, _unwrap_tensor(tensor, tp_rank=self.tp_rank, device=infered_device))
             for name, tensor in named_tensors
         ]
         if load_format == "direct":
@@ -886,6 +937,38 @@ class ModelRunner:
             raise NotImplementedError(f"Unknown load_format={load_format}")
         return True, "Success"
 
+    def _update_weights_from_flattened_bucket(
+        self,
+        flattened_tensor_bucket_dict,
+    ):
+        """Handle flattened bucket format for weight updates"""
+        flattened_tensor = flattened_tensor_bucket_dict["flattened_tensor"]
+        metadata = flattened_tensor_bucket_dict["metadata"]
+
+        # Convert metadata dict to our format
+        converted_metadata = []
+        for meta in metadata:
+            converted_meta = FlattenedTensorMetadata(
+                name=meta.name,
+                shape=meta.shape,
+                dtype=meta.dtype,
+                start_idx=meta.start_idx,
+                end_idx=meta.end_idx,
+                numel=meta.numel,
+            )
+            converted_metadata.append(converted_meta)
+
+        # Create bucket and reconstruct tensors
+        bucket = FlattenedTensorBucket(
+            flattened_tensor=flattened_tensor, metadata=converted_metadata
+        )
+        reconstructed_tensors = bucket.reconstruct_tensors()
+
+        # Load the reconstructed tensors using the standard method
+        self.model.load_weights(reconstructed_tensors)
+
+        return True, "Success"
+
     def get_weights_by_name(
         self, name: str, truncate_size: int = 100
     ) -> Optional[torch.Tensor]:
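The receiving side above expects `named_tensors` to arrive as a dict with a single `flattened_tensor` plus per-tensor metadata exposing `name`, `shape`, `dtype`, `start_idx`, `end_idx`, and `numel`. A rough sketch of how a sender (e.g. a trainer process pushing updated weights) might assemble such a payload is shown below. The `_Meta` stand-in and the helper are illustrative only; a real sender would presumably use `FlattenedTensorBucket`/`FlattenedTensorMetadata` from `sglang.srt.weight_sync.tensor_bucket`, and this sketch assumes all tensors in one bucket share a dtype.

```python
# Hedged sketch of a sender-side payload for load_format="flattened_bucket".
from dataclasses import dataclass
from typing import List, Tuple
import torch


@dataclass
class _Meta:  # illustrative stand-in for FlattenedTensorMetadata
    name: str
    shape: torch.Size
    dtype: torch.dtype
    start_idx: int
    end_idx: int
    numel: int


def flatten_named_tensors(named_tensors: List[Tuple[str, torch.Tensor]]) -> dict:
    """Pack named tensors into one flat buffer plus offset metadata."""
    metadata, chunks, offset = [], [], 0
    for name, t in named_tensors:
        flat = t.contiguous().view(-1)
        metadata.append(
            _Meta(name, t.shape, t.dtype, offset, offset + flat.numel(), flat.numel())
        )
        chunks.append(flat)
        offset += flat.numel()
    return {"flattened_tensor": torch.cat(chunks), "metadata": metadata}
```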
@@ -1181,30 +1264,33 @@ class ModelRunner:
             # Draft worker shares req_to_token_pool with the target worker.
             assert self.is_draft_worker
 
-        if self.server_args.attention_backend == "ascend"
-            self.
-            self.
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if self.server_args.attention_backend == "ascend":
+            if self.use_mla_backend:
+                self.token_to_kv_pool = AscendMLAPagedTokenToKVPool(
+                    self.max_total_num_tokens,
+                    page_size=self.page_size,
+                    dtype=self.kv_cache_dtype,
+                    kv_lora_rank=self.model_config.kv_lora_rank,
+                    qk_rope_head_dim=self.model_config.qk_rope_head_dim,
+                    layer_num=self.num_effective_layers,
+                    device=self.device,
+                    enable_memory_saver=self.server_args.enable_memory_saver,
+                    start_layer=self.start_layer,
+                    end_layer=self.end_layer,
+                )
+            else:
+                self.token_to_kv_pool = AscendTokenToKVPool(
+                    self.max_total_num_tokens,
+                    page_size=self.page_size,
+                    dtype=self.kv_cache_dtype,
+                    head_num=self.model_config.get_num_kv_heads(
+                        get_attention_tp_size()
+                    ),
+                    head_dim=self.model_config.head_dim,
+                    layer_num=self.model_config.num_hidden_layers,
+                    device=self.device,
+                    enable_memory_saver=self.server_args.enable_memory_saver,
+                )
         elif self.use_mla_backend:
             self.token_to_kv_pool = MLATokenToKVPool(
                 self.max_total_num_tokens,
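The branch above wires Ascend devices into the same split the rest of the runner uses: MLA models cache one compressed latent (`kv_lora_rank`) plus a rope key (`qk_rope_head_dim`) per token per layer, while standard attention caches full K and V vectors per KV head. A back-of-the-envelope comparison is sketched below; the dimensions are made-up examples, not values taken from any shipped config.

```python
# Hedged sketch: rough per-token KV-cache footprint implied by the two pool types
# constructed above. All numbers are illustrative.
def mla_bytes_per_token(kv_lora_rank, qk_rope_head_dim, layers, dtype_bytes=2):
    # MLA caches one compressed latent plus a rope key per token per layer.
    return (kv_lora_rank + qk_rope_head_dim) * layers * dtype_bytes

def mha_bytes_per_token(kv_heads, head_dim, layers, dtype_bytes=2):
    # Standard attention caches a full K and a full V vector per KV head per layer.
    return 2 * kv_heads * head_dim * layers * dtype_bytes

print(mla_bytes_per_token(512, 64, 61))   # e.g. DeepSeek-style MLA dimensions
print(mha_bytes_per_token(8, 128, 61))    # e.g. a GQA model with 8 KV heads
```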
@@ -1263,6 +1349,7 @@ class ModelRunner:
                 end_layer=self.end_layer,
             )
 
+        need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
         if self.token_to_kv_pool_allocator is None:
             if self.page_size == 1:
                 if self.is_hybrid:
@@ -1272,6 +1359,7 @@ class ModelRunner:
                         dtype=self.kv_cache_dtype,
                         device=self.device,
                         kvcache=self.token_to_kv_pool,
+                        need_sort=need_sort,
                     )
                 else:
                     self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
@@ -1279,23 +1367,26 @@ class ModelRunner:
                         dtype=self.kv_cache_dtype,
                         device=self.device,
                         kvcache=self.token_to_kv_pool,
+                        need_sort=need_sort,
                     )
             else:
-                if _is_npu:
-                    self.token_to_kv_pool_allocator =
+                if not _is_npu:
+                    self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
                         self.max_total_num_tokens,
                         page_size=self.page_size,
                         dtype=self.kv_cache_dtype,
                         device=self.device,
                         kvcache=self.token_to_kv_pool,
+                        need_sort=need_sort,
                     )
                 else:
-                    self.token_to_kv_pool_allocator =
+                    self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
                         self.max_total_num_tokens,
                         page_size=self.page_size,
                         dtype=self.kv_cache_dtype,
                         device=self.device,
                         kvcache=self.token_to_kv_pool,
+                        need_sort=need_sort,
                     )
         else:
             assert self.is_draft_worker
@@ -1396,6 +1487,10 @@ class ModelRunner:
             from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
 
             return AiterAttnBackend(self)
+        elif self.server_args.attention_backend == "wave":
+            from sglang.srt.layers.attention.wave_backend import WaveAttnBackend
+
+            return WaveAttnBackend(self)
         elif backend_str == "ascend":
             from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend
 
@@ -1443,19 +1538,36 @@ class ModelRunner:
             )
 
             return CutlassMLABackend(self)
-        elif
+        elif backend_str == "trtllm_mla":
             if not self.use_mla_backend:
                 raise ValueError("trtllm_mla backend can only be used with MLA models.")
             from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend
 
             return TRTLLMMLABackend(self)
-        elif
+        elif backend_str == "trtllm_mha":
+            if self.use_mla_backend:
+                raise ValueError(
+                    "trtllm_mha backend can only be used with non-MLA models."
+                )
+            from sglang.srt.layers.attention.trtllm_mha_backend import (
+                TRTLLMHAAttnBackend,
+            )
+
+            return TRTLLMHAAttnBackend(self)
+
+        elif backend_str == "intel_amx":
             from sglang.srt.layers.attention.intel_amx_backend import (
                 IntelAMXAttnBackend,
             )
 
             logger.info(f"Intel AMX attention backend is enabled.")
             return IntelAMXAttnBackend(self)
+        elif self.server_args.attention_backend == "dual_chunk_flash_attn":
+            from sglang.srt.layers.attention.dual_chunk_flashattention_backend import (
+                DualChunkFlashAttentionBackend,
+            )
+
+            return DualChunkFlashAttentionBackend(self)
         else:
             raise ValueError(f"Invalid attention backend: {backend_str}")
 
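The dispatch above adds `trtllm_mla`, `trtllm_mha`, and `dual_chunk_flash_attn` branches, each guarded by an MLA/MHA compatibility check. Those constraints can be condensed into a single validation helper; the sketch below is illustrative only, with the backend names and error wording taken from the hunk.

```python
# Hedged sketch: the MLA/MHA compatibility checks from the backend dispatch,
# condensed into one standalone validator. Illustrative only.
def validate_backend_choice(backend_str: str, use_mla_backend: bool) -> None:
    if backend_str == "trtllm_mla" and not use_mla_backend:
        raise ValueError("trtllm_mla backend can only be used with MLA models.")
    if backend_str == "trtllm_mha" and use_mla_backend:
        raise ValueError("trtllm_mha backend can only be used with non-MLA models.")
```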
@@ -1768,11 +1880,10 @@ def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tenso
         default_weight_loader(params_dict[name], tensor)
 
 
-def _unwrap_tensor(tensor, tp_rank):
+def _unwrap_tensor(tensor, tp_rank, device):
     if isinstance(tensor, LocalSerializedTensor):
-        monkey_patch_torch_reductions()
         tensor = tensor.get(tp_rank)
-    return tensor.to(
+    return tensor.to(device)
 
 
 @dataclass
@@ -162,12 +162,24 @@ def _initialize_model(
     model_class, _ = get_model_architecture(model_config)
     packed_modules_mapping = getattr(model_class, "packed_modules_mapping", {})
     if _is_npu:
-        packed_modules_mapping
-
-
-
-
-
+        packed_modules_mapping.update(
+            {
+                "visual": {"qkv_proj": ["qkv"]},
+                "vision_model": {
+                    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+                    "proj": ["out_proj"],
+                },
+                "model": {
+                    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+                    "gate_up_proj": ["gate_proj", "up_proj"],
+                    "fused_qkv_a_proj_with_mqa": [
+                        "q_a_proj",
+                        "kv_a_proj_with_mqa",
+                    ],
+                },
+            }
+        )
+
     quant_config = _get_quantization_config(
         model_config, load_config, packed_modules_mapping
     )
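For readers unfamiliar with packed-module mappings: the dict added above tells the weight loader which unfused checkpoint projections feed each fused runtime parameter (for example, `q_proj`/`k_proj`/`v_proj` into `qkv_proj`). The loop below is a hedged illustration of how such a mapping is typically consumed during loading; it is not the loader's actual code, and the helper name is hypothetical.

```python
# Hedged sketch: routing checkpoint shard names onto fused parameters using a
# packed-modules mapping like the one added above. Illustrative only.
PACKED = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}

def remap_checkpoint_name(ckpt_name: str):
    """Map a checkpoint weight name to (fused_param_name, shard_id), or None."""
    for fused, shards in PACKED.items():
        for shard_id, shard in enumerate(shards):
            if shard in ckpt_name:
                return ckpt_name.replace(shard, fused), shard_id
    return None

print(remap_checkpoint_name("model.layers.0.self_attn.k_proj.weight"))
# -> ('model.layers.0.self_attn.qkv_proj.weight', 1)
```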
@@ -843,6 +843,16 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
         return None
     return remapped_name
 
+    quark_scale_names = {
+        ".q_proj.output_scale": ".attn.q_scale",
+        ".k_proj.output_scale": ".attn.k_scale",
+        ".v_proj.output_scale": ".attn.v_scale",
+        "self_attn.prob_output_scale": ".attn.prob_scale",
+    }
+    for quark_scale_name, sglang_scale_name in quark_scale_names.items():
+        if name.endswith(quark_scale_name):
+            return name.replace(quark_scale_name, sglang_scale_name)
+
     # If there were no matches, return the untouched param name
     return name
 
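The mapping added above renames Quark-style attention scale parameters to the names sglang expects. A standalone copy of that logic, shown below purely for illustration, makes the behavior easy to verify on a sample parameter name.

```python
# Hedged sketch: what the quark scale-name remapping does to a checkpoint
# parameter name. The mapping literal mirrors the hunk; this copy is illustrative.
quark_scale_names = {
    ".q_proj.output_scale": ".attn.q_scale",
    ".k_proj.output_scale": ".attn.k_scale",
    ".v_proj.output_scale": ".attn.v_scale",
    "self_attn.prob_output_scale": ".attn.prob_scale",
}

def remap(name: str) -> str:
    for src, dst in quark_scale_names.items():
        if name.endswith(src):
            return name.replace(src, dst)
    return name

assert remap("model.layers.0.self_attn.k_proj.output_scale") == (
    "model.layers.0.self_attn.attn.k_scale"
)
```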