sglang 0.4.9.post4__py3-none-any.whl → 0.4.9.post6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/lang/chat_template.py +21 -0
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/model_config.py +7 -0
- sglang/srt/constrained/base_grammar_backend.py +10 -2
- sglang/srt/constrained/xgrammar_backend.py +7 -5
- sglang/srt/conversation.py +16 -1
- sglang/srt/debug_utils/__init__.py +0 -0
- sglang/srt/debug_utils/dump_comparator.py +131 -0
- sglang/srt/debug_utils/dumper.py +108 -0
- sglang/srt/debug_utils/text_comparator.py +172 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +13 -1
- sglang/srt/disaggregation/mooncake/conn.py +16 -0
- sglang/srt/disaggregation/prefill.py +13 -1
- sglang/srt/entrypoints/engine.py +4 -2
- sglang/srt/entrypoints/http_server.py +13 -1
- sglang/srt/entrypoints/openai/protocol.py +3 -1
- sglang/srt/entrypoints/openai/serving_base.py +5 -2
- sglang/srt/entrypoints/openai/serving_chat.py +132 -79
- sglang/srt/function_call/ebnf_composer.py +10 -3
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +164 -0
- sglang/srt/function_call/qwen3_coder_detector.py +1 -0
- sglang/srt/layers/attention/hybrid_attn_backend.py +100 -0
- sglang/srt/layers/attention/vision.py +56 -8
- sglang/srt/layers/layernorm.py +26 -1
- sglang/srt/layers/logits_processor.py +14 -3
- sglang/srt/layers/moe/ep_moe/layer.py +323 -242
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +83 -118
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -48
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +11 -8
- sglang/srt/layers/moe/token_dispatcher/__init__.py +0 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +48 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +19 -0
- sglang/srt/layers/moe/topk.py +90 -24
- sglang/srt/layers/multimodal.py +11 -8
- sglang/srt/layers/quantization/fp8.py +25 -247
- sglang/srt/layers/quantization/fp8_kernel.py +78 -48
- sglang/srt/layers/quantization/modelopt_quant.py +27 -10
- sglang/srt/layers/quantization/unquant.py +24 -76
- sglang/srt/layers/quantization/w4afp8.py +68 -17
- sglang/srt/lora/lora_registry.py +93 -29
- sglang/srt/managers/cache_controller.py +9 -7
- sglang/srt/managers/data_parallel_controller.py +4 -0
- sglang/srt/managers/io_struct.py +12 -0
- sglang/srt/managers/mm_utils.py +154 -35
- sglang/srt/managers/multimodal_processor.py +3 -14
- sglang/srt/managers/schedule_batch.py +14 -8
- sglang/srt/managers/scheduler.py +64 -1
- sglang/srt/managers/scheduler_input_blocker.py +106 -0
- sglang/srt/managers/tokenizer_manager.py +80 -15
- sglang/srt/managers/tp_worker.py +8 -0
- sglang/srt/mem_cache/hiradix_cache.py +5 -2
- sglang/srt/model_executor/model_runner.py +83 -27
- sglang/srt/models/deepseek_v2.py +75 -84
- sglang/srt/models/glm4_moe.py +1035 -0
- sglang/srt/models/glm4_moe_nextn.py +167 -0
- sglang/srt/models/interns1.py +328 -0
- sglang/srt/models/internvl.py +143 -47
- sglang/srt/models/llava.py +9 -5
- sglang/srt/models/minicpmo.py +4 -1
- sglang/srt/models/qwen2_moe.py +2 -2
- sglang/srt/models/qwen3_moe.py +17 -71
- sglang/srt/multimodal/processors/base_processor.py +20 -6
- sglang/srt/multimodal/processors/clip.py +2 -2
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +2 -2
- sglang/srt/multimodal/processors/gemma3.py +2 -2
- sglang/srt/multimodal/processors/gemma3n.py +2 -2
- sglang/srt/multimodal/processors/internvl.py +21 -8
- sglang/srt/multimodal/processors/janus_pro.py +2 -2
- sglang/srt/multimodal/processors/kimi_vl.py +2 -2
- sglang/srt/multimodal/processors/llava.py +4 -4
- sglang/srt/multimodal/processors/minicpm.py +2 -3
- sglang/srt/multimodal/processors/mlama.py +2 -2
- sglang/srt/multimodal/processors/mllama4.py +18 -111
- sglang/srt/multimodal/processors/phi4mm.py +2 -2
- sglang/srt/multimodal/processors/pixtral.py +2 -2
- sglang/srt/multimodal/processors/qwen_audio.py +2 -2
- sglang/srt/multimodal/processors/qwen_vl.py +2 -2
- sglang/srt/multimodal/processors/vila.py +3 -1
- sglang/srt/poll_based_barrier.py +31 -0
- sglang/srt/reasoning_parser.py +2 -1
- sglang/srt/server_args.py +65 -6
- sglang/srt/two_batch_overlap.py +8 -3
- sglang/srt/utils.py +96 -1
- sglang/srt/weight_sync/utils.py +119 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_utils.py +118 -5
- sglang/utils.py +19 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/METADATA +5 -4
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/RECORD +97 -80
- sglang/srt/debug_utils.py +0 -74
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -27,6 +27,7 @@ import threading
 import time
 import uuid
 from collections import deque
+from contextlib import nullcontext
 from datetime import datetime
 from http import HTTPStatus
 from typing import (
@@ -69,6 +70,7 @@ from sglang.srt.managers.io_struct import (
     BatchMultimodalOut,
     BatchStrOut,
     BatchTokenIDOut,
+    BlockReqType,
     CloseSessionReqInput,
     ConfigureLoggingReq,
     EmbeddingReqInput,
|
|
112
114
|
UpdateWeightsFromTensorReqInput,
|
113
115
|
UpdateWeightsFromTensorReqOutput,
|
114
116
|
)
|
117
|
+
from sglang.srt.managers.mm_utils import TensorTransportMode
|
115
118
|
from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
|
119
|
+
from sglang.srt.managers.scheduler_input_blocker import input_blocker_guard_region
|
116
120
|
from sglang.srt.metrics.collector import TokenizerMetricsCollector
|
117
121
|
from sglang.srt.sampling.sampling_params import SamplingParams
|
118
122
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
@@ -166,6 +170,16 @@ class ReqState:
     output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)
 
 
+def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode:
+    is_cross_node = server_args.dist_init_addr
+
+    if is_cross_node:
+        # Fallback to default CPU transport for multi-node
+        return "default"
+    else:
+        return "cuda_ipc"
+
+
 class TokenizerManager:
     """TokenizerManager is a process that tokenizes the text."""
 
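The helper added in the hunk above keys the multimodal tensor transport purely off `dist_init_addr`: any cross-node deployment falls back to the default (CPU) transport, while single-node setups use CUDA IPC. A minimal standalone sketch of that decision, using a hypothetical `FakeServerArgs` stand-in rather than sglang's real `ServerArgs`:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeServerArgs:
    # Stand-in for the one field the helper reads from sglang's ServerArgs.
    dist_init_addr: Optional[str] = None


def determine_transport_mode(server_args: FakeServerArgs) -> str:
    # A non-empty dist_init_addr implies a multi-node deployment, where CUDA IPC
    # handles cannot be shared across hosts, so fall back to the default transport.
    return "default" if server_args.dist_init_addr else "cuda_ipc"


assert determine_transport_mode(FakeServerArgs()) == "cuda_ipc"
assert determine_transport_mode(FakeServerArgs(dist_init_addr="10.0.0.1:5000")) == "default"
```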
@@ -216,12 +230,13 @@ class TokenizerManager:
                 revision=server_args.revision,
                 use_fast=not server_args.disable_fast_image_processor,
             )
+            transport_mode = _determine_tensor_transport_mode(self.server_args)
 
             # We want to parallelize the image pre-processing so we create an executor for it
             # We create mm_processor for any skip_tokenizer_init to make sure we still encode
             # images even with skip_tokenizer_init=False.
             self.mm_processor = get_mm_processor(
-                self.model_config.hf_config, server_args, _processor
+                self.model_config.hf_config, server_args, _processor, transport_mode
             )
 
         if server_args.skip_tokenizer_init:
@@ -270,6 +285,11 @@ class TokenizerManager:
             None
         )
 
+        # Lock to serialize LoRA update operations.
+        # Please note that, unlike `model_update_lock`, this does not block inference, allowing
+        # LoRA updates and inference to overlap.
+        self.lora_update_lock = asyncio.Lock()
+
         # For pd disaggregtion
         self.disaggregation_mode = DisaggregationMode(
             self.server_args.disaggregation_mode
@@ -525,7 +545,8 @@ class TokenizerManager:
             mm_inputs = None
 
         if self.server_args.enable_lora and obj.lora_path:
-            #
+            # Start tracking ongoing requests for LoRA adapters and replace the user-friendly LoRA names in
+            # `lora_path` with their corresponding unique LoRA IDs, as required for internal processing.
             obj.lora_path = await self.lora_registry.acquire(obj.lora_path)
 
         self._validate_one_request(obj, input_ids)
@@ -735,6 +756,10 @@ class TokenizerManager:
                     msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}, out={dataclass_to_string_truncated(out, max_length, skip_names=out_skip_names)}"
                     logger.info(msg)
 
+                # Mark ongoing LoRA request as finished.
+                if self.server_args.enable_lora and obj.lora_path:
+                    await self.lora_registry.release(obj.lora_path)
+
                 # Check if this was an abort/error created by scheduler
                 if isinstance(out["meta_info"].get("finish_reason"), dict):
                     finish_reason = out["meta_info"]["finish_reason"]
@@ -744,6 +769,19 @@ class TokenizerManager:
                     ):
                         raise ValueError(finish_reason["message"])
 
+                    if (
+                        finish_reason.get("type") == "abort"
+                        and finish_reason.get("status_code")
+                        == HTTPStatus.SERVICE_UNAVAILABLE
+                    ):
+                        # This is an abort request initiated by scheduler.
+                        # Delete the key to prevent resending abort request to the scheduler and
+                        # to ensure aborted request state is cleaned up.
+                        del self.rid_to_state[state.obj.rid]
+                        raise fastapi.HTTPException(
+                            status_code=finish_reason["status_code"],
+                            detail=finish_reason["message"],
+                        )
                 yield out
                 break
 
@@ -784,12 +822,21 @@ class TokenizerManager:
                     rids.append(tmp_obj.rid)
             else:
                 # Sequential tokenization and processing
-                for i in range(batch_size):
-                    tmp_obj = obj[i]
-                    tokenized_obj = await self._tokenize_one_request(tmp_obj)
-                    state = self._send_one_request(tmp_obj, tokenized_obj, created_time)
-                    generators.append(self._wait_one_response(tmp_obj, state, request))
-                    rids.append(tmp_obj.rid)
+                with (
+                    input_blocker_guard_region(send_to_scheduler=self.send_to_scheduler)
+                    if get_bool_env_var("SGLANG_ENABLE_COLOCATED_BATCH_GEN")
+                    else nullcontext()
+                ):
+                    for i in range(batch_size):
+                        tmp_obj = obj[i]
+                        tokenized_obj = await self._tokenize_one_request(tmp_obj)
+                        state = self._send_one_request(
+                            tmp_obj, tokenized_obj, created_time
+                        )
+                        generators.append(
+                            self._wait_one_response(tmp_obj, state, request)
+                        )
+                        rids.append(tmp_obj.rid)
         else:
             # FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
             if batch_size > 128:
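The hunk above wraps the sequential tokenization loop in `input_blocker_guard_region` only when `SGLANG_ENABLE_COLOCATED_BATCH_GEN` is set, and otherwise falls through to `contextlib.nullcontext()`. A small sketch of that conditional context-manager idiom, with a hypothetical `fake_guard_region` standing in for sglang's real input blocker:

```python
import os
from contextlib import contextmanager, nullcontext


@contextmanager
def fake_guard_region():
    # Hypothetical stand-in for sglang's input_blocker_guard_region: block new
    # scheduler inputs on entry, unblock on exit. Here it only prints.
    print("inputs blocked")
    try:
        yield
    finally:
        print("inputs unblocked")


def process_batch(items):
    # Choose the guard only when the feature flag is on; nullcontext() keeps the
    # with-statement shape identical in the common (disabled) case.
    guard = (
        fake_guard_region()
        if os.environ.get("SGLANG_ENABLE_COLOCATED_BATCH_GEN") == "1"
        else nullcontext()
    )
    with guard:
        for item in items:
            print("processing", item)


process_batch(["a", "b"])
```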
@@ -1041,16 +1088,18 @@ class TokenizerManager:
             obj.lora_path,
         )
 
-        async with self.
+        async with self.lora_update_lock:
             # Generate new uniquely identifiable LoRARef object.
             new_adapter = LoRARef(
                 lora_name=obj.lora_name,
                 lora_path=obj.lora_path,
             )
 
-            #
+            # Trigger the actual loading operation at the backend processes.
             obj.lora_id = new_adapter.lora_id
             result = (await self.update_lora_adapter_communicator(obj))[0]
+
+            # Register the LoRA adapter only after loading is successful.
             if result.success:
                 await self.lora_registry.register(new_adapter)
 
@@ -1081,8 +1130,15 @@ class TokenizerManager:
             obj.lora_name,
         )
 
-        async with self.
-
+        async with self.lora_update_lock:
+            # Unregister the LoRA adapter from the registry to stop new requests for this adapter
+            # from being started.
+            lora_id = await self.lora_registry.unregister(obj.lora_name)
+            obj.lora_id = lora_id
+
+            # Initiate the actual unloading operation at the backend processes only after all
+            # ongoing requests using this LoRA adapter are finished.
+            await self.lora_registry.wait_for_unload(lora_id)
             result = (await self.update_lora_adapter_communicator(obj))[0]
 
         return result
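Taken together, the LoRA hunks above follow one lifecycle: an adapter is registered only after it loads successfully, each request acquires it on entry and releases it when it finishes, and unloading unregisters the name first and then waits for in-flight requests to drain before the backend unload is triggered. A toy asyncio registry sketching that refcount-and-drain pattern (illustrative only; this is not sglang's `LoRARegistry` implementation):

```python
import asyncio
import uuid


class ToyLoRARegistry:
    """Minimal refcount-and-drain registry; illustrative only."""

    def __init__(self):
        self._name_to_id = {}   # user-facing adapter name -> internal id
        self._counters = {}     # internal id -> number of in-flight requests
        self._drained = {}      # internal id -> event set when the count hits zero

    async def register(self, name: str) -> str:
        lora_id = uuid.uuid4().hex
        self._name_to_id[name] = lora_id
        self._counters[lora_id] = 0
        self._drained[lora_id] = asyncio.Event()
        self._drained[lora_id].set()  # nothing in flight yet
        return lora_id

    async def acquire(self, name: str) -> str:
        lora_id = self._name_to_id[name]
        self._counters[lora_id] += 1
        self._drained[lora_id].clear()
        return lora_id  # requests carry the internal id from here on

    async def release(self, lora_id: str) -> None:
        self._counters[lora_id] -= 1
        if self._counters[lora_id] == 0:
            self._drained[lora_id].set()

    async def unregister(self, name: str) -> str:
        # Drop the name first so no new request can acquire this adapter.
        return self._name_to_id.pop(name)

    async def wait_for_unload(self, lora_id: str) -> None:
        # Block until every in-flight request using the adapter has released it.
        await self._drained[lora_id].wait()


async def main():
    reg = ToyLoRARegistry()
    await reg.register("my-adapter")

    rid = await reg.acquire("my-adapter")         # a request starts using the adapter
    lora_id = await reg.unregister("my-adapter")  # stop new acquisitions

    async def finish_request():
        await asyncio.sleep(0.1)
        await reg.release(rid)                    # the in-flight request completes

    asyncio.create_task(finish_request())
    await reg.wait_for_unload(lora_id)            # drains before the backend unload
    print("safe to unload")


asyncio.run(main())
```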
@@ -1674,8 +1730,15 @@ class TokenizerManager:
     def _handle_abort_req(self, recv_obj):
         state = self.rid_to_state[recv_obj.rid]
         state.finished = True
-        state.out_list.append(
-            {
+        if recv_obj.finished_reason:
+            out = {
+                "meta_info": {
+                    "id": recv_obj.rid,
+                    "finish_reason": recv_obj.finished_reason,
+                },
+            }
+        else:
+            out = {
                 "text": "",
                 "meta_info": {
                     "id": recv_obj.rid,
@@ -1687,7 +1750,7 @@ class TokenizerManager:
                     "completion_tokens": 0,
                 },
             }
-        )
+        state.out_list.append(out)
         state.event.set()
 
     def _handle_open_session_req_output(self, recv_obj):
@@ -1879,8 +1942,10 @@ class _Communicator(Generic[T]):
 #
 # | entrypoint | is_streaming | status | abort engine | cancel asyncio task | rid_to_state |
 # | ---------- | ------------ | --------------- | --------------- | --------------------- | --------------------------- |
+# | http | yes | validation | background task | fast api | del in _handle_abort_req |
 # | http | yes | waiting queue | background task | fast api | del in _handle_abort_req |
 # | http | yes | running | background task | fast api | del in _handle_batch_output |
+# | http | no | validation | http exception | http exception | del in _handle_abort_req |
 # | http | no | waiting queue | type 1 | type 1 exception | del in _handle_abort_req |
 # | http | no | running | type 3 | type 3 exception | del in _handle_batch_output |
 #
sglang/srt/managers/tp_worker.py
CHANGED
@@ -41,6 +41,7 @@ from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_executor.model_runner import ModelRunner
+from sglang.srt.patch_torch import monkey_patch_torch_reductions
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed
 
@@ -129,6 +130,10 @@ class TpModelWorker:
             self.model_runner.req_to_token_pool.size,
         )
         assert self.max_running_requests > 0, "max_running_request is zero"
+        self.max_queued_requests = server_args.max_queued_requests
+        assert (
+            self.max_running_requests > 0
+        ), "max_queued_requests is zero. We need to be at least 1 to schedule a request."
         self.max_req_len = min(
             self.model_config.context_len - 1,
             self.max_total_num_tokens - 1,
@@ -164,6 +169,7 @@ class TpModelWorker:
             self.max_total_num_tokens,
             self.max_prefill_tokens,
             self.max_running_requests,
+            self.max_queued_requests,
             self.max_req_len,
             self.max_req_input_len,
             self.random_seed,
@@ -278,6 +284,8 @@ class TpModelWorker:
         return success, message
 
     def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
+
+        monkey_patch_torch_reductions()
         success, message = self.model_runner.update_weights_from_tensor(
             named_tensors=MultiprocessingSerializer.deserialize(
                 recv_req.serialized_named_tensors[self.tp_rank]
sglang/srt/mem_cache/hiradix_cache.py
CHANGED
@@ -365,10 +365,12 @@ class HiRadixCache(RadixCache):
         for _ in range(queue_size.item()):
             req_id = self.cache_controller.prefetch_revoke_queue.get()
             if req_id in self.ongoing_prefetch:
-                last_host_node, _, host_indices, _ = self.ongoing_prefetch[req_id]
+                last_host_node, _, _, _ = self.ongoing_prefetch[req_id]
                 last_host_node.release_host()
-                self.cache_controller.mem_pool_host.free(host_indices)
                 del self.ongoing_prefetch[req_id]
+            else:
+                # the revoked operation already got terminated
+                pass
 
     def check_backup_progress(self):
         queue_size = torch.tensor(
|
|
403
405
|
last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[
|
404
406
|
req_id
|
405
407
|
]
|
408
|
+
|
406
409
|
completed_tokens, hash_value = self.cache_controller.terminate_prefetch(
|
407
410
|
operation
|
408
411
|
)
|
sglang/srt/model_executor/model_runner.py
CHANGED
@@ -285,11 +285,21 @@ class ModelRunner:
             if architectures and not any("Llama4" in arch for arch in architectures):
                 self.is_hybrid = self.model_config.is_hybrid = True
 
-        self.start_layer = getattr(self.model, "start_layer", 0)
-        self.end_layer = getattr(
-            self.model, "end_layer", self.model_config.num_hidden_layers
+        # For MTP models like DeepSeek-V3 or GLM-4.5, the MTP layer(s) are used separately as draft
+        # models for speculative decoding. In those cases, `num_nextn_predict_layers` is used to
+        # determine the number of layers.
+        model_has_mtp_layers = self.model_config.num_nextn_predict_layers is not None
+        model_num_layers = (
+            self.model_config.num_nextn_predict_layers
+            if self.is_draft_worker and model_has_mtp_layers
+            else self.model_config.num_hidden_layers
         )
+        self.start_layer = getattr(self.model, "start_layer", 0)
+        self.end_layer = getattr(self.model, "end_layer", model_num_layers)
         self.num_effective_layers = self.end_layer - self.start_layer
+        assert (not model_has_mtp_layers) or (
+            self.num_effective_layers == model_num_layers
+        ), "PP is not compatible with MTP models."
 
         # Apply torchao quantization
         torchao_applied = getattr(self.model, "torchao_applied", False)
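The first model_runner.py hunk above sizes the model by `num_nextn_predict_layers` only when the worker is a draft worker for an MTP model, and by `num_hidden_layers` otherwise. A pure-function sketch of that selection, using a hypothetical `FakeModelConfig` in place of sglang's real model config:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeModelConfig:
    num_hidden_layers: int
    num_nextn_predict_layers: Optional[int] = None  # set only for MTP models


def effective_layer_count(cfg: FakeModelConfig, is_draft_worker: bool) -> int:
    # Draft workers of MTP models (e.g. DeepSeek-V3, GLM-4.5) only run the small
    # next-n prediction head, so they are sized by that layer count; every other
    # worker uses the full hidden-layer count.
    has_mtp = cfg.num_nextn_predict_layers is not None
    if is_draft_worker and has_mtp:
        return cfg.num_nextn_predict_layers
    return cfg.num_hidden_layers


cfg = FakeModelConfig(num_hidden_layers=61, num_nextn_predict_layers=1)
print(effective_layer_count(cfg, is_draft_worker=True))   # 1
print(effective_layer_count(cfg, is_draft_worker=False))  # 61
```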
@@ -1178,11 +1188,7 @@ class ModelRunner:
                 dtype=self.kv_cache_dtype,
                 kv_lora_rank=self.model_config.kv_lora_rank,
                 qk_rope_head_dim=self.model_config.qk_rope_head_dim,
-                layer_num=(
-                    self.model_config.num_hidden_layers
-                    if not self.is_draft_worker
-                    else self.model_config.hf_config.num_nextn_predict_layers
-                ), # PP is not compatible with mla backend
+                layer_num=self.num_effective_layers,
                 device=self.device,
                 enable_memory_saver=self.server_args.enable_memory_saver,
                 start_layer=self.start_layer,
@@ -1195,11 +1201,7 @@ class ModelRunner:
                 dtype=self.kv_cache_dtype,
                 kv_lora_rank=self.model_config.kv_lora_rank,
                 qk_rope_head_dim=self.model_config.qk_rope_head_dim,
-                layer_num=(
-                    self.model_config.num_hidden_layers
-                    if not self.is_draft_worker
-                    else self.model_config.hf_config.num_nextn_predict_layers
-                ), # PP is not compatible with mla backend
+                layer_num=self.num_effective_layers,
                 device=self.device,
                 enable_memory_saver=self.server_args.enable_memory_saver,
                 start_layer=self.start_layer,
@@ -1308,9 +1310,58 @@ class ModelRunner:
         else:
             self.attn_backend = self._get_attention_backend()
 
-    # TODO unify with 6338
     def _get_attention_backend(self):
-        if self.server_args.attention_backend == "flashinfer":
+        """Init attention kernel backend."""
+        self.decode_attention_backend_str = (
+            self.server_args.decode_attention_backend
+            if self.server_args.decode_attention_backend
+            else self.server_args.attention_backend
+        )
+        self.prefill_attention_backend_str = (
+            self.server_args.prefill_attention_backend
+            if self.server_args.prefill_attention_backend
+            else self.server_args.attention_backend
+        )
+        if self.decode_attention_backend_str != self.prefill_attention_backend_str:
+            assert (
+                self.server_args.speculative_algorithm is None
+            ), "Currently HybridAttentionBackend does not support speculative decoding."
+            from sglang.srt.layers.attention.hybrid_attn_backend import (
+                HybridAttnBackend,
+            )
+
+            attn_backend = HybridAttnBackend(
+                decode_backend=self._get_attention_backend_from_str(
+                    self.decode_attention_backend_str
+                ),
+                prefill_backend=self._get_attention_backend_from_str(
+                    self.prefill_attention_backend_str
+                ),
+            )
+            logger.info(
+                f"Using hybrid attention backend for decode and prefill: "
+                f"decode_backend={self.decode_attention_backend_str}, "
+                f"prefill_backend={self.prefill_attention_backend_str}."
+            )
+            logger.warning(
+                f"Warning: Attention backend specified by --attention-backend or default backend might be overridden."
+                f"The feature of hybrid attention backend is experimental and unstable. Please raise an issue if you encounter any problem."
+            )
+        else:
+            attn_backend = self._get_attention_backend_from_str(
+                self.server_args.attention_backend
+            )
+
+        global_server_args_dict.update(
+            {
+                "decode_attention_backend": self.decode_attention_backend_str,
+                "prefill_attention_backend": self.prefill_attention_backend_str,
+            }
+        )
+        return attn_backend
+
+    def _get_attention_backend_from_str(self, backend_str: str):
+        if backend_str == "flashinfer":
             if not self.use_mla_backend:
                 from sglang.srt.layers.attention.flashinfer_backend import (
                     FlashInferAttnBackend,
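The new `_get_attention_backend` above builds a `HybridAttnBackend` whenever the prefill and decode backend strings differ. The core idea is dispatch-by-phase: hold two backend objects and route each call to the one matching the current forward mode. A stripped-down sketch of that pattern (toy classes, not sglang's actual `HybridAttnBackend` interface):

```python
from dataclasses import dataclass


class ToyBackend:
    def __init__(self, name: str):
        self.name = name

    def run(self, phase: str) -> str:
        return f"{self.name} handled {phase}"


@dataclass
class ToyHybridBackend:
    prefill_backend: ToyBackend
    decode_backend: ToyBackend

    def run(self, phase: str) -> str:
        # Route by phase: decode batches go to one kernel backend,
        # prefill/extend batches to the other.
        backend = self.decode_backend if phase == "decode" else self.prefill_backend
        return backend.run(phase)


hybrid = ToyHybridBackend(
    prefill_backend=ToyBackend("fa3"),
    decode_backend=ToyBackend("flashinfer"),
)
print(hybrid.run("prefill"))  # fa3 handled prefill
print(hybrid.run("decode"))   # flashinfer handled decode
```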
@@ -1318,7 +1369,11 @@ class ModelRunner:
 
                 # Init streams
                 if self.server_args.speculative_algorithm == "EAGLE":
-                    self.plan_stream_for_flashinfer = torch.cuda.Stream()
+                    if (
+                        not hasattr(self, "plan_stream_for_flashinfer")
+                        or not self.plan_stream_for_flashinfer
+                    ):
+                        self.plan_stream_for_flashinfer = torch.cuda.Stream()
                 return FlashInferAttnBackend(self)
             else:
                 from sglang.srt.layers.attention.flashinfer_mla_backend import (
@@ -1326,15 +1381,15 @@ class ModelRunner:
                 )
 
                 return FlashInferMLAAttnBackend(self)
-        elif self.server_args.attention_backend == "aiter":
+        elif backend_str == "aiter":
             from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
 
             return AiterAttnBackend(self)
-        elif self.server_args.attention_backend == "ascend":
+        elif backend_str == "ascend":
             from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend
 
             return AscendAttnBackend(self)
-        elif self.server_args.attention_backend == "triton":
+        elif backend_str == "triton":
             assert not self.model_config.is_encoder_decoder, (
                 "Cross attention is not supported in the triton attention backend. "
                 "Please use `--attention-backend flashinfer`."
@@ -1349,17 +1404,17 @@ class ModelRunner:
             from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
 
             return TritonAttnBackend(self)
-        elif self.server_args.attention_backend == "torch_native":
+        elif backend_str == "torch_native":
             from sglang.srt.layers.attention.torch_native_backend import (
                 TorchNativeAttnBackend,
             )
 
             return TorchNativeAttnBackend(self)
-        elif self.server_args.attention_backend == "flashmla":
+        elif backend_str == "flashmla":
             from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend
 
             return FlashMLABackend(self)
-        elif self.server_args.attention_backend == "fa3":
+        elif backend_str == "fa3":
             assert (
                 torch.cuda.get_device_capability()[0] == 8 and not self.use_mla_backend
             ) or torch.cuda.get_device_capability()[0] == 9, (
@@ -1371,7 +1426,7 @@ class ModelRunner:
             )
 
             return FlashAttentionBackend(self)
-        elif self.server_args.attention_backend == "cutlass_mla":
+        elif backend_str == "cutlass_mla":
             from sglang.srt.layers.attention.cutlass_mla_backend import (
                 CutlassMLABackend,
             )
@@ -1385,9 +1440,7 @@ class ModelRunner:
             logger.info(f"Intel AMX attention backend is enabled.")
             return IntelAMXAttnBackend(self)
         else:
-            raise ValueError(
-                f"Invalid attention backend: {self.server_args.attention_backend}"
-            )
+            raise ValueError(f"Invalid attention backend: {backend_str}")
 
     def init_double_sparsity_channel_config(self, selected_channel):
         selected_channel = "." + selected_channel + "_proj"
@@ -1475,7 +1528,10 @@ class ModelRunner:
         if self.support_pp:
             kwargs["pp_proxy_tensors"] = pp_proxy_tensors
         return self.model.forward(
-            forward_batch.input_ids,
+            forward_batch.input_ids,
+            forward_batch.positions,
+            forward_batch,
+            **kwargs,
         )
 
     def forward_extend(
|