sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +1 -11
- sglang/bench_serving.py +149 -1
- sglang/lang/chat_template.py +44 -0
- sglang/srt/configs/deepseekvl2.py +3 -0
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/internvl.py +696 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +17 -0
- sglang/srt/constrained/xgrammar_backend.py +11 -19
- sglang/srt/conversation.py +30 -3
- sglang/srt/disaggregation/decode.py +4 -1
- sglang/srt/disaggregation/mini_lb.py +74 -23
- sglang/srt/disaggregation/mooncake/conn.py +9 -18
- sglang/srt/disaggregation/nixl/conn.py +241 -71
- sglang/srt/disaggregation/utils.py +44 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
- sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
- sglang/srt/distributed/device_communicators/pynccl.py +2 -1
- sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
- sglang/srt/distributed/parallel_state.py +22 -1
- sglang/srt/entrypoints/engine.py +14 -2
- sglang/srt/entrypoints/http_server.py +28 -1
- sglang/srt/entrypoints/verl_engine.py +3 -2
- sglang/srt/hf_transformers_utils.py +20 -1
- sglang/srt/layers/attention/flashattention_backend.py +146 -50
- sglang/srt/layers/attention/flashinfer_backend.py +23 -13
- sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
- sglang/srt/layers/attention/merge_state.py +46 -0
- sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
- sglang/srt/layers/attention/vision.py +290 -163
- sglang/srt/layers/moe/ep_moe/kernels.py +342 -7
- sglang/srt/layers/moe/ep_moe/layer.py +120 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +4 -1
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
- sglang/srt/layers/quantization/deep_gemm.py +5 -0
- sglang/srt/layers/quantization/fp8.py +108 -95
- sglang/srt/layers/quantization/fp8_kernel.py +79 -60
- sglang/srt/layers/quantization/fp8_utils.py +71 -23
- sglang/srt/layers/quantization/kv_cache.py +3 -10
- sglang/srt/layers/quantization/utils.py +0 -5
- sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
- sglang/srt/lora/lora_manager.py +10 -13
- sglang/srt/managers/cache_controller.py +115 -119
- sglang/srt/managers/io_struct.py +10 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
- sglang/srt/managers/multimodal_processors/internvl.py +232 -0
- sglang/srt/managers/schedule_batch.py +19 -1
- sglang/srt/managers/schedule_policy.py +11 -5
- sglang/srt/managers/scheduler.py +28 -13
- sglang/srt/managers/tokenizer_manager.py +24 -13
- sglang/srt/managers/tp_worker.py +9 -12
- sglang/srt/mem_cache/chunk_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +2 -2
- sglang/srt/model_executor/model_runner.py +44 -33
- sglang/srt/model_loader/loader.py +18 -11
- sglang/srt/models/clip.py +4 -4
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_nextn.py +1 -20
- sglang/srt/models/deepseek_v2.py +55 -20
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/internlm2.py +3 -0
- sglang/srt/models/internvl.py +670 -0
- sglang/srt/models/llama.py +1 -1
- sglang/srt/models/llama4.py +53 -7
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/phi3_small.py +16 -2
- sglang/srt/models/qwen2_5_vl.py +8 -4
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/xiaomi_mimo.py +171 -0
- sglang/srt/openai_api/adapter.py +24 -40
- sglang/srt/openai_api/protocol.py +28 -16
- sglang/srt/reasoning_parser.py +2 -2
- sglang/srt/sampling/sampling_batch_info.py +54 -2
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +30 -6
- sglang/srt/utils.py +35 -1
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deepep_utils.py +219 -0
- sglang/test/test_utils.py +3 -1
- sglang/version.py +1 -1
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/METADATA +14 -6
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/RECORD +90 -80
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/WHEEL +1 -1
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/top_level.txt +0 -0
sglang/srt/distributed/parallel_state.py
CHANGED
@@ -42,6 +42,7 @@ from torch.distributed import Backend, ProcessGroup
 from sglang.srt.utils import (
     direct_register_custom_op,
     is_cuda_alike,
+    is_npu,
     supports_custom_op,
 )

@@ -206,6 +207,7 @@ class GroupCoordinator:
         use_custom_allreduce: bool,
         use_hpu_communicator: bool,
         use_xpu_communicator: bool,
+        use_npu_communicator: bool,
         use_message_queue_broadcaster: bool = False,
         group_name: Optional[str] = None,
     ):
@@ -244,6 +246,7 @@ class GroupCoordinator:
         self.use_custom_allreduce = use_custom_allreduce
         self.use_hpu_communicator = use_hpu_communicator
         self.use_xpu_communicator = use_xpu_communicator
+        self.use_npu_communicator = use_npu_communicator
         self.use_message_queue_broadcaster = use_message_queue_broadcaster

         # lazy import to avoid documentation build error
@@ -291,6 +294,14 @@ class GroupCoordinator:
         if use_xpu_communicator and self.world_size > 1:
             self.xpu_communicator = XpuCommunicator(group=self.device_group)

+        from sglang.srt.distributed.device_communicators.npu_communicator import (
+            NpuCommunicator,
+        )
+
+        self.npu_communicator: Optional[NpuCommunicator] = None
+        if use_npu_communicator and self.world_size > 1:
+            self.npu_communicator = NpuCommunicator(group=self.device_group)
+
         from sglang.srt.distributed.device_communicators.shm_broadcast import (
             MessageQueue,
         )
@@ -418,6 +429,9 @@ class GroupCoordinator:
         if self.xpu_communicator is not None and not self.xpu_communicator.disabled:
             return self.xpu_communicator.all_reduce(input_)

+        if self.npu_communicator is not None and not self.npu_communicator.disabled:
+            return self.npu_communicator.all_reduce(input_)
+
         if (
             self.ca_comm is not None
             and not self.ca_comm.disabled
@@ -497,6 +511,11 @@ class GroupCoordinator:
         if hpu_comm is not None and not hpu_comm.disabled:
             return hpu_comm.all_gather(input_, dim)

+        # For NPUs, use NPU communicator.
+        npu_comm = self.npu_communicator
+        if npu_comm is not None and not npu_comm.disabled:
+            return npu_comm.all_gather(input_, dim)
+
         if dim < 0:
             # Convert negative dim to positive.
             dim += input_.dim()
@@ -941,6 +960,7 @@ def init_world_group(
         use_custom_allreduce=False,
         use_hpu_communicator=False,
         use_xpu_communicator=False,
+        use_npu_communicator=False,
         group_name="world",
     )

@@ -959,10 +979,11 @@ def init_model_parallel_group(
         group_ranks=group_ranks,
         local_rank=local_rank,
         torch_distributed_backend=backend,
-        use_pynccl=
+        use_pynccl=not is_npu(),
         use_custom_allreduce=use_custom_allreduce,
         use_hpu_communicator=True,
         use_xpu_communicator=True,
+        use_npu_communicator=True,
         use_message_queue_broadcaster=use_message_queue_broadcaster,
         group_name=group_name,
     )
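The NPU path follows the same shape as the existing HPU/XPU hooks: GroupCoordinator holds one optional communicator per device family and tries each before falling back to the generic path. A minimal standalone sketch of that dispatch pattern (class and field names here are illustrative, not the sglang API):

from typing import Optional

import torch


class _NoOpCommunicator:
    """Illustrative stand-in for a device communicator such as NpuCommunicator."""

    disabled = False

    def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
        # A real communicator would reduce across ranks on the device fabric.
        return x


class TinyCoordinator:
    """Sketch of the dispatch order, not the real GroupCoordinator."""

    def __init__(self, npu_communicator: Optional[_NoOpCommunicator] = None):
        self.npu_communicator = npu_communicator

    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        # Same guard this release adds: take the NPU path only when the
        # communicator exists and is not disabled, otherwise fall through.
        if self.npu_communicator is not None and not self.npu_communicator.disabled:
            return self.npu_communicator.all_reduce(input_)
        return input_  # placeholder for custom all-reduce / pynccl / torch.distributed


print(TinyCoordinator(_NoOpCommunicator()).all_reduce(torch.ones(4)))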
sglang/srt/entrypoints/engine.py
CHANGED
@@ -163,6 +163,9 @@ class Engine(EngineBase):
         custom_logit_processor: Optional[Union[List[str], str]] = None,
         return_hidden_states: bool = False,
         stream: bool = False,
+        bootstrap_host: Optional[Union[List[str], str]] = None,
+        bootstrap_port: Optional[Union[List[int], int]] = None,
+        bootstrap_room: Optional[Union[List[int], int]] = None,
     ) -> Union[Dict, Iterator[Dict]]:
         """
         The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
@@ -181,6 +184,9 @@ class Engine(EngineBase):
             custom_logit_processor=custom_logit_processor,
             return_hidden_states=return_hidden_states,
             stream=stream,
+            bootstrap_host=bootstrap_host,
+            bootstrap_port=bootstrap_port,
+            bootstrap_room=bootstrap_room,
         )
         loop = asyncio.get_event_loop()
         generator = self.tokenizer_manager.generate_request(obj, None)
@@ -227,6 +233,9 @@ class Engine(EngineBase):
         lora_path: Optional[List[Optional[str]]] = None,
         custom_logit_processor: Optional[Union[List[str], str]] = None,
         stream: bool = False,
+        bootstrap_host: Optional[Union[List[str], str]] = None,
+        bootstrap_port: Optional[Union[List[int], int]] = None,
+        bootstrap_room: Optional[Union[List[int], int]] = None,
     ) -> Union[Dict, AsyncIterator[Dict]]:
         """
         The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
@@ -244,6 +253,9 @@ class Engine(EngineBase):
             lora_path=lora_path,
             stream=stream,
             custom_logit_processor=custom_logit_processor,
+            bootstrap_host=bootstrap_host,
+            bootstrap_port=bootstrap_port,
+            bootstrap_room=bootstrap_room,
         )
         generator = self.tokenizer_manager.generate_request(obj, None)

@@ -348,8 +360,8 @@ class Engine(EngineBase):
         load_format: Optional[str] = None,
         flush_cache: bool = True,
     ):
-        """Update weights from distributed source. If there are going to be more updates, set `flush_cache` to be
-        to avoid duplicated
+        """Update weights from distributed source. If there are going to be more updates, set `flush_cache` to be false
+        to avoid duplicated cache cleaning operation."""
         obj = UpdateWeightsFromTensorReqInput(
             serialized_named_tensors=[
                 MultiprocessingSerializer.serialize(named_tensors)
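The new bootstrap_host / bootstrap_port / bootstrap_room arguments are forwarded verbatim into GenerateReqInput, so an offline Engine caller can tag a request with its prefill/decode-disaggregation bootstrap endpoint. A rough usage sketch; the model path and bootstrap values below are placeholders, not defaults shipped with the package:

import sglang as sgl

# Placeholder model and endpoint values, for illustration only.
engine = sgl.Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")

out = engine.generate(
    prompt="The capital of France is",
    sampling_params={"temperature": 0.0, "max_new_tokens": 16},
    # New in 0.4.6.post3: forwarded to GenerateReqInput for PD disaggregation.
    bootstrap_host="10.0.0.1",
    bootstrap_port=8998,
    bootstrap_room=1234,
)
print(out)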
sglang/srt/entrypoints/http_server.py
CHANGED
@@ -42,7 +42,10 @@ from fastapi import FastAPI, File, Form, Request, UploadFile
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import ORJSONResponse, Response, StreamingResponse

-from sglang.srt.disaggregation.utils import
+from sglang.srt.disaggregation.utils import (
+    FakeBootstrapHost,
+    register_disaggregation_server,
+)
 from sglang.srt.entrypoints.engine import _launch_subprocesses
 from sglang.srt.function_call_parser import FunctionCallParser
 from sglang.srt.managers.io_struct import (
@@ -59,6 +62,7 @@ from sglang.srt.managers.io_struct import (
     ResumeMemoryOccupationReqInput,
     SeparateReasoningReqInput,
     SetInternalStateReq,
+    SlowDownReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromTensorReqInput,
@@ -491,6 +495,19 @@ async def resume_memory_occupation(
         return _create_error_response(e)


+@app.api_route("/slow_down", methods=["GET", "POST"])
+async def slow_down(obj: SlowDownReqInput, request: Request):
+    """Slow down the system deliberately. Only for testing. Example scenario:
+    when we want to test performance of D in large-scale PD disaggregation and have no enough nodes for P,
+    we can use this to slow down D to let it have enough running sequences, and then disable slowdown
+    to let it run in full batch size.
+    """
+    try:
+        await _global_state.tokenizer_manager.slow_down(obj, request)
+    except Exception as e:
+        return _create_error_response(e)
+
+
 @app.api_route("/open_session", methods=["GET", "POST"])
 async def open_session(obj: OpenSessionReqInput, request: Request):
     """Open a session, and return its unique session id."""
@@ -675,6 +692,8 @@ async def vertex_generate(vertex_req: VertexGenerateReqInput, raw_request: Reque
         **(vertex_req.parameters or {}),
     )
     ret = await generate_request(req, raw_request)
+    if isinstance(ret, Response):
+        return ret
     return ORJSONResponse({"predictions": ret})


@@ -869,5 +888,13 @@ def _wait_and_warmup(
     if server_args.debug_tensor_dump_input_file:
         kill_process_tree(os.getpid())

+    if server_args.pdlb_url is not None:
+        register_disaggregation_server(
+            server_args.disaggregation_mode,
+            server_args.port,
+            server_args.disaggregation_bootstrap_port,
+            server_args.pdlb_url,
+        )
+
     if launch_callback is not None:
         launch_callback()
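The /slow_down route accepts GET or POST with a SlowDownReqInput body. A hedged example of calling it with requests; the JSON field name is an assumption about SlowDownReqInput's schema and the port is a placeholder:

import requests

BASE_URL = "http://localhost:30000"  # placeholder server address

# Assumed payload field (not confirmed by this diff): seconds to sleep per forward pass.
requests.post(f"{BASE_URL}/slow_down", json={"forward_sleep_time": 0.1}, timeout=10)

# ... run the decode-side measurement, then clear the artificial slowdown.
requests.post(f"{BASE_URL}/slow_down", json={"forward_sleep_time": None}, timeout=10)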
sglang/srt/entrypoints/verl_engine.py
CHANGED
@@ -37,6 +37,7 @@ class VerlEngine:
         monkey_patch_torch_reductions()
         self._device_mesh_cpu = device_mesh_cpu
         self._tp_rank = device_mesh_cpu.get_local_rank()
+        self._rank = device_mesh_cpu.get_rank()
         self._tp_size = device_mesh_cpu.size()
         tp_size_per_node = self._tp_size // nnodes
         node_rank = self._tp_rank // tp_size_per_node
@@ -114,7 +115,7 @@ class VerlEngine:
         # Most naive implementation, can extract tensor and send via gloo if too slow
         [output] = broadcast_pyobj(
             data=[output],
-            rank=self.
+            rank=self._rank,
             dist_group=self._device_mesh_cpu.get_group(),
             src=self._device_mesh_cpu.mesh[0].item(),
             force_cpu_device=False,
@@ -157,7 +158,7 @@ class VerlEngine:
         )

         if self._tp_rank == 0:
-            self._engine.
+            self._engine.flush_cache()

     def release_memory_occupation(self):
         if self._tp_rank == 0:
sglang/srt/hf_transformers_utils.py
CHANGED
@@ -19,6 +19,7 @@ import warnings
 from pathlib import Path
 from typing import Dict, Optional, Type, Union

+import transformers
 from huggingface_hub import snapshot_download
 from transformers import (
     AutoConfig,
@@ -26,6 +27,7 @@ from transformers import (
     AutoTokenizer,
     PretrainedConfig,
     PreTrainedTokenizer,
+    PreTrainedTokenizerBase,
     PreTrainedTokenizerFast,
 )
 from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
@@ -38,6 +40,7 @@ from sglang.srt.configs import (
     KimiVLConfig,
     MultiModalityConfig,
 )
+from sglang.srt.configs.internvl import InternVLChatConfig
 from sglang.srt.connector import create_remote_connector
 from sglang.srt.utils import is_remote_url

@@ -48,6 +51,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     DeepseekVL2Config.model_type: DeepseekVL2Config,
     MultiModalityConfig.model_type: MultiModalityConfig,
     KimiVLConfig.model_type: KimiVLConfig,
+    InternVLChatConfig.model_type: InternVLChatConfig,
 }

 for name, cls in _CONFIG_REGISTRY.items():
@@ -90,6 +94,12 @@ def get_config(
         config = config_class.from_pretrained(model, revision=revision)
         # NOTE(HandH1998): Qwen2VL requires `_name_or_path` attribute in `config`.
         setattr(config, "_name_or_path", model)
+
+    if isinstance(model, str) and config.model_type == "internvl_chat":
+        for key, val in config.llm_config.__dict__.items():
+            if not hasattr(config, key):
+                setattr(config, key, val)
+
     if model_override_args:
         config.update(model_override_args)

@@ -211,6 +221,13 @@ def get_tokenizer(
     return tokenizer


+# Some models doesn't have an available processor, e.g.: InternVL
+def get_tokenizer_from_processor(processor):
+    if isinstance(processor, PreTrainedTokenizerBase):
+        return processor
+    return processor.tokenizer
+
+
 def get_processor(
     tokenizer_name: str,
     *args,
@@ -246,7 +263,9 @@ def get_processor(
         **kwargs,
     )

-
+    tokenizer = get_tokenizer_from_processor(processor)
+
+    attach_additional_stop_token_ids(tokenizer)
     return processor

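InternVL-style checkpoints keep the language-model settings nested under config.llm_config, so get_config now copies any attribute missing from the top level up from that sub-config; downstream code can then read hidden_size, vocab_size, and similar fields without knowing about the nesting. A self-contained sketch of the same promotion pattern (the namespace objects are stand-ins, not real configs):

from types import SimpleNamespace


def promote_nested_config(config, nested_attr: str = "llm_config"):
    """Copy attributes of a nested sub-config onto the parent without overwriting."""
    nested = getattr(config, nested_attr, None)
    if nested is None:
        return config
    for key, val in nested.__dict__.items():
        if not hasattr(config, key):
            setattr(config, key, val)
    return config


cfg = SimpleNamespace(
    model_type="internvl_chat",
    llm_config=SimpleNamespace(hidden_size=4096, num_hidden_layers=32),
)
promote_nested_config(cfg)
print(cfg.hidden_size, cfg.num_hidden_layers)  # 4096 32, now visible at the top level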
sglang/srt/layers/attention/flashattention_backend.py
CHANGED
@@ -338,7 +338,7 @@ class FlashAttentionBackend(AttentionBackend):
         """Initialize forward metadata hence all layers in the forward pass can reuse it."""
         metadata = FlashAttentionMetadata()
         seqlens_in_batch = forward_batch.seq_lens
-        batch_size =
+        batch_size = forward_batch.batch_size
         device = seqlens_in_batch.device

         if forward_batch.forward_mode.is_decode_or_idle():
@@ -913,8 +913,10 @@ class FlashAttentionBackend(AttentionBackend):
         # Use precomputed metadata across all layers
         metadata = self.forward_metadata
         local_attn_metadata = getattr(metadata, "local_attn_metadata", None)
-
-            self.attention_chunk_size is not None
+        use_local_attn = (
+            self.attention_chunk_size is not None
+            and local_attn_metadata is not None
+            and (hasattr(layer, "use_irope") and layer.use_irope)
         )
         # We do cascade attention for Draft Decode with topk > 1
         use_cascade_attn = self.topk > 1
@@ -970,7 +972,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
             )
-        elif
+        elif use_local_attn:
             # Use chunked (local) attention batching for self-attention
             o = flash_attn_with_kvcache(
                 q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
@@ -979,7 +981,7 @@ class FlashAttentionBackend(AttentionBackend):
                 page_table=local_attn_metadata.local_block_table,
                 cache_seqlens=local_attn_metadata.local_seqused_k,
                 cu_seqlens_q=local_attn_metadata.local_query_start_loc,
-                cu_seqlens_k_new=
+                cu_seqlens_k_new=None,
                 max_seqlen_q=local_attn_metadata.local_max_query_len,
                 softmax_scale=layer.scaling,
                 causal=True,
@@ -1127,7 +1129,6 @@ class FlashAttentionBackend(AttentionBackend):
         This creates fixed-size tensors that will be reused during CUDA graph replay
         to avoid memory allocations.
         """
-
         # This is being used by normal decode and draft decode when topk == 1
         self.decode_cuda_graph_metadata = {
             "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
@@ -1154,6 +1155,34 @@ class FlashAttentionBackend(AttentionBackend):
             ),
         }

+        # Only allocate local attention buffers if local attention is enabled
+        # This prevents OOM errors when local attention is not being used
+        if self.attention_chunk_size is not None:
+            # Estimate maximum sizes for local attention metadata
+            max_seq_len = self.max_context_len
+            page_size = self.page_size or 1
+            attn_chunk_size = self.attention_chunk_size
+            max_virtual_batches = max_bs * (
+                (max_seq_len + attn_chunk_size - 1) // attn_chunk_size
+            )
+            max_blocks_per_seq = (max_seq_len + attn_chunk_size - 1) // attn_chunk_size
+            max_pages_per_block = (attn_chunk_size + page_size - 1) // page_size
+
+            self.decode_cuda_graph_local_attn_metadata = {
+                "local_query_start_loc": torch.zeros(
+                    max_virtual_batches + 1, dtype=torch.int32, device=self.device
+                ),
+                "local_seqused_k": torch.zeros(
+                    max_virtual_batches, dtype=torch.int32, device=self.device
+                ),
+                "local_block_table": torch.zeros(
+                    max_virtual_batches,
+                    max_blocks_per_seq * max_pages_per_block,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+            }
+
         # This is used by draft decode's first half of metadata when topk > 1
         if self.topk > 1:
             self.draft_decode_metadata_topk_normal = {
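The buffer sizes above come straight from the chunking arithmetic: each request contributes ceil(max_seq_len / attention_chunk_size) virtual batches, and each chunk spans at most ceil(attention_chunk_size / page_size) KV pages. A quick worked example with assumed numbers (chosen only to make the arithmetic concrete, not taken from the package):

# Assumed example values.
max_bs = 8              # CUDA graph max batch size
max_seq_len = 8192      # max_context_len
attn_chunk_size = 2048  # attention_chunk_size (local/chunked attention window)
page_size = 16

chunks_per_seq = (max_seq_len + attn_chunk_size - 1) // attn_chunk_size   # 4
max_virtual_batches = max_bs * chunks_per_seq                             # 32
max_blocks_per_seq = chunks_per_seq                                       # 4
max_pages_per_block = (attn_chunk_size + page_size - 1) // page_size      # 128

# Resulting preallocated shapes:
#   local_query_start_loc -> (33,), local_seqused_k -> (32,), local_block_table -> (32, 512)
print(max_virtual_batches + 1, max_virtual_batches,
      (max_virtual_batches, max_blocks_per_seq * max_pages_per_block))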
@@ -1405,6 +1434,21 @@ class FlashAttentionBackend(AttentionBackend):
             )
             self.decode_cuda_graph_metadata[bs] = metadata

+            if self.attention_chunk_size is not None:
+                metadata.local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
+                    local_query_start_loc=self.decode_cuda_graph_local_attn_metadata[
+                        "local_query_start_loc"
+                    ],
+                    local_seqused_k=self.decode_cuda_graph_local_attn_metadata[
+                        "local_seqused_k"
+                    ],
+                    local_block_table=self.decode_cuda_graph_local_attn_metadata[
+                        "local_block_table"
+                    ],
+                    local_max_query_len=1,
+                    local_max_seq_len=1,
+                )
+
         elif forward_mode.is_target_verify():
             if self.topk <= 1:
                 metadata.cache_seqlens_int32 = self.target_verify_metadata[
@@ -1525,12 +1569,9 @@ class FlashAttentionBackend(AttentionBackend):
                 metadata.max_seq_len_k = seq_lens_cpu.max().item() + (
                     self.speculative_step_id + 1
                 )
-                metadata.cu_seqlens_k.copy_(
-                    torch.
-                    torch.
-                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
-                    ),
-                    (1, 0),
+                metadata.cu_seqlens_k[1:].copy_(
+                    torch.cumsum(
+                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
                     )
                 )

@@ -1554,12 +1595,9 @@ class FlashAttentionBackend(AttentionBackend):
                 # metadata.max_seq_len_q = self.topk, already set in capture
                 metadata.max_seq_len_k = seq_lens_cpu.max().item()
                 # metadata.cu_seqlens_q already set in capture
-                metadata.cu_seqlens_k.copy_(
-                    torch.
-                    torch.
-                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
-                    ),
-                    (1, 0),
+                metadata.cu_seqlens_k[1:].copy_(
+                    torch.cumsum(
+                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
                     )
                 )

@@ -1578,8 +1616,7 @@ class FlashAttentionBackend(AttentionBackend):
                 metadata_expand.page_table[: cache_loc.shape[0]].copy_(
                     cache_loc[:, :decode_length].contiguous().to(torch.int32)
                 )
-                # TODO:
-                self._init_local_attn_metadata(metadata, device)
+                # TODO: Handle local attention metadata for draft decode when llama4 eagle is supported
             else:
                 metadata = self.decode_cuda_graph_metadata[bs]
                 # Normal Decode
@@ -1587,8 +1624,9 @@ class FlashAttentionBackend(AttentionBackend):
                 metadata.max_seq_len_k = max_len

                 metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
-
-
+                # Optimize cumulative sequence length calculation
+                metadata.cu_seqlens_k[1:].copy_(
+                    torch.cumsum(seq_lens, dim=0, dtype=torch.int32)
                 )

                 max_seq_pages = (
@@ -1604,7 +1642,7 @@ class FlashAttentionBackend(AttentionBackend):
                 metadata.page_table[:, :max_seq_pages].copy_(page_indices)
                 metadata.page_table[:, max_seq_pages:].fill_(0)

-                self.
+                self._update_local_attn_metadata_for_replay(metadata, bs)
         elif forward_mode.is_target_verify():
             if self.topk <= 1:
                 metadata = self.target_verify_metadata[bs]
@@ -1615,13 +1653,8 @@ class FlashAttentionBackend(AttentionBackend):
                 metadata.max_seq_len_k = (
                     seq_lens_cpu.max().item() + self.speculative_num_draft_tokens
                 )
-                metadata.cu_seqlens_k.copy_(
-                    torch.
-                    torch.cumsum(
-                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
-                    ),
-                    (1, 0),
-                )
+                metadata.cu_seqlens_k[1:].copy_(
+                    torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
                 )
                 max_seq_pages = (
                     metadata.max_seq_len_k + self.page_size - 1
@@ -1640,13 +1673,8 @@ class FlashAttentionBackend(AttentionBackend):
                 # metadata.max_seq_len_q = self.speculative_num_draft_tokens, already set in capture
                 metadata.max_seq_len_k = seq_lens_cpu.max().item()
                 # metadata.cu_seqlens_q already set in capture
-                metadata.cu_seqlens_k.copy_(
-                    torch.
-                    torch.cumsum(
-                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
-                    ),
-                    (1, 0),
-                )
+                metadata.cu_seqlens_k[1:].copy_(
+                    torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
                 )
                 page_table = self.req_to_token[
                     req_pool_indices, : metadata.max_seq_len_k
@@ -1704,14 +1732,11 @@ class FlashAttentionBackend(AttentionBackend):
                 metadata_expand.cache_seqlens_int32.copy_(
                     mask.sum(dim=1).to(torch.int32)
                 )
-                metadata_expand.cu_seqlens_k.copy_(
-                    torch.
-
-
-
-                    dtype=torch.int32,
-                    ),
-                    (1, 0),
+                metadata_expand.cu_seqlens_k[1:].copy_(
+                    torch.cumsum(
+                        metadata_expand.cache_seqlens_int32,
+                        dim=0,
+                        dtype=torch.int32,
                     )
                 )
                 metadata_expand.max_seq_len_k = (
@@ -1722,11 +1747,8 @@ class FlashAttentionBackend(AttentionBackend):
                 # Only support encoder size 1 for now
                 metadata.encoder_max_seq_len_k = encoder_lens[0]
                 metadata.encoder_lens_int32.copy_(encoder_lens[:1])
-                metadata.encoder_cu_seqlens_k.copy_(
-                    torch.
-                    torch.cumsum(metadata.encoder_lens_int32, dim=0, dtype=torch.int32),
-                    (1, 0),
-                )
+                metadata.encoder_cu_seqlens_k[1:].copy_(
+                    torch.cumsum(metadata.encoder_lens_int32, dim=0, dtype=torch.int32)
                 )

                 metadata.encoder_page_table[:, : metadata.encoder_max_seq_len_k].copy_(
@@ -1776,6 +1798,7 @@ class FlashAttentionBackend(AttentionBackend):
             page_table,
             self.page_size,
         )
+
         local_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
             local_query_start_loc=torch.from_numpy(cu_seqlens_q_local_np).to(device),
             local_seqused_k=torch.from_numpy(seqlens_k_local_np).to(device),
@@ -1785,6 +1808,79 @@ class FlashAttentionBackend(AttentionBackend):
         )
         metadata.local_attn_metadata = local_metadata

+    def _update_local_attn_metadata_for_replay(
+        self, metadata: FlashAttentionMetadata, bs: int
+    ):
+        """Update preallocated local attention metadata in-place before CUDA graph replay."""
+        if self.attention_chunk_size is None:
+            return
+
+        # Access preallocated buffers
+        local_q_buf = self.decode_cuda_graph_local_attn_metadata[
+            "local_query_start_loc"
+        ]
+        local_k_buf = self.decode_cuda_graph_local_attn_metadata["local_seqused_k"]
+        local_block_buf = self.decode_cuda_graph_local_attn_metadata[
+            "local_block_table"
+        ]
+        cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"]
+
+        # Create a modified version for local attention that only processes the last token
+        # This mimics the normal decode pattern
+        cu_seqlens_q = torch.arange(
+            bs + 1, device=cu_seqlens_q.device, dtype=cu_seqlens_q.dtype
+        )
+        seqlens = metadata.cache_seqlens_int32[:bs]
+        # Slice the page_table to match the batch size and actual sequence length
+        # This serves three important purposes:
+        # 1. Ensures we only process the actual batch size (bs) and not the maximum batch size
+        # 2. Limits the sequence length to prevent processing padding tokens or garbage values
+        # 3. Prevents zeros in the block table which can cause garbage output during replay
+        #
+        # Without this slicing, the pre-allocated page_table may contain zeros or invalid indices
+        # beyond the actual sequence length, leading to incorrect attention calculations
+        max_seq_len = int(seqlens.max().item())
+        sliced_page_table = metadata.page_table[:bs, :max_seq_len]
+
+        cu_seqlens_q_np = cu_seqlens_q.cpu().numpy()
+        seqlens_np = seqlens.cpu().numpy()
+        (
+            seqlens_q_local_np,
+            cu_seqlens_q_local_np,
+            seqlens_k_local_np,
+            block_table_local,
+        ) = make_local_attention_virtual_batches(
+            self.attention_chunk_size,
+            cu_seqlens_q_np,
+            seqlens_np,
+            sliced_page_table,
+            self.page_size,
+        )
+
+        # Convert back to tensors
+        device = local_q_buf.device
+        cu_seqlens_q_local = torch.from_numpy(cu_seqlens_q_local_np).to(device)
+        seqlens_k_local = torch.from_numpy(seqlens_k_local_np).to(device)
+        block_table_local = block_table_local.to(device)
+        # Get sizes
+        q_len = cu_seqlens_q_local.shape[0]
+        k_len = seqlens_k_local.shape[0]
+        b0, b1 = block_table_local.shape
+
+        # In-place updates into preallocated tensors and zero out the unused space
+        local_q_buf[:q_len].copy_(cu_seqlens_q_local)
+        local_q_buf[q_len:].fill_(0)
+        local_k_buf[:k_len].copy_(seqlens_k_local)
+        local_k_buf[k_len:].fill_(0)
+        local_block_buf[:b0, :b1].copy_(block_table_local)
+        local_block_buf[b0:, :].fill_(0)
+        local_block_buf[:b0, b1:].fill_(0)
+
+        if metadata.local_attn_metadata is not None:
+            lam = metadata.local_attn_metadata
+            lam.local_max_query_len = int(seqlens_q_local_np.max())
+            lam.local_max_seq_len = int(seqlens_k_local_np.max())
+

 class FlashAttentionMultiStepBackend:

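Because the LocalAttentionMetadata captured into the graph points at these preallocated buffers, replay can only mutate them in place: copy the freshly computed values into the leading region and zero everything past it so stale entries from a previous, larger batch cannot leak into the kernel. A minimal sketch of that update discipline with made-up shapes:

import torch

# Preallocated at capture time; their storage must never be reassigned.
seqused_k_buf = torch.zeros(32, dtype=torch.int32)
block_table_buf = torch.zeros(32, 512, dtype=torch.int32)


def replay_update(seqused_k: torch.Tensor, block_table: torch.Tensor) -> None:
    k = seqused_k.shape[0]
    b0, b1 = block_table.shape
    # In-place writes only: fill the valid region, then clear the unused space.
    seqused_k_buf[:k].copy_(seqused_k)
    seqused_k_buf[k:].fill_(0)
    block_table_buf[:b0, :b1].copy_(block_table)
    block_table_buf[b0:, :].fill_(0)
    block_table_buf[:b0, b1:].fill_(0)


replay_update(
    torch.tensor([7, 9], dtype=torch.int32),
    torch.arange(6, dtype=torch.int32).reshape(2, 3),
)
print(seqused_k_buf[:4])
print(block_table_buf[:2, :4])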
sglang/srt/layers/attention/flashinfer_backend.py
CHANGED
@@ -16,8 +16,9 @@ from typing import TYPE_CHECKING, Callable, List, Optional, Union
 import torch

 if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
-    import
+    import logging

+    torch._logging.set_logs(dynamo=logging.ERROR)
     torch._dynamo.config.suppress_errors = True

 from sglang.global_config import global_config
@@ -107,6 +108,7 @@ class FlashInferAttnBackend(AttentionBackend):
         if (
             "Qwen2ForCausalLM" in model_runner.model_config.hf_config.architectures
             or "Qwen3ForCausalLM" in model_runner.model_config.hf_config.architectures
+            or "MiMoForCausalLM" in model_runner.model_config.hf_config.architectures
         ):
             global_config.flashinfer_workspace_size = 512 * 1024 * 1024

@@ -416,6 +418,7 @@ class FlashInferAttnBackend(AttentionBackend):

         logits_soft_cap = layer.logit_cap

+        q = q.contiguous()
         if not self.forward_metadata.use_ragged:
             if k is not None:
                 assert v is not None
@@ -425,7 +428,7 @@ class FlashInferAttnBackend(AttentionBackend):
                 )

             o = prefill_wrapper_paged.forward(
-                q.
+                q.view(-1, layer.tp_q_head_num, layer.head_dim),
                 forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
                 causal=not layer.is_cross_attention,
                 sm_scale=layer.scaling,
@@ -435,20 +438,27 @@ class FlashInferAttnBackend(AttentionBackend):
                 v_scale=layer.v_scale,
             )
         else:
-            o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
-                q.view(-1, layer.tp_q_head_num, layer.head_dim),
-                k.view(-1, layer.tp_k_head_num, layer.head_dim),
-                v.view(-1, layer.tp_v_head_num, layer.head_dim),
-                causal=True,
-                sm_scale=layer.scaling,
-                logits_soft_cap=logits_soft_cap,
-            )
-
             if self.forward_metadata.extend_no_prefix:
-                o =
+                o = self.prefill_wrapper_ragged.forward(
+                    q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k.view(-1, layer.tp_k_head_num, layer.head_dim),
+                    v.view(-1, layer.tp_v_head_num, layer.head_dim),
+                    causal=True,
+                    sm_scale=layer.scaling,
+                    logits_soft_cap=logits_soft_cap,
+                )
+
             else:
+                o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
+                    q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k.view(-1, layer.tp_k_head_num, layer.head_dim),
+                    v.view(-1, layer.tp_v_head_num, layer.head_dim),
+                    causal=True,
+                    sm_scale=layer.scaling,
+                    logits_soft_cap=logits_soft_cap,
+                )
                 o2, s2 = prefill_wrapper_paged.forward_return_lse(
-                    q.
+                    q.view(-1, layer.tp_q_head_num, layer.head_dim),
                     forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
                     causal=False,
                     sm_scale=layer.scaling,