sglang 0.4.10__py3-none-any.whl → 0.4.10.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +20 -0
- sglang/compile_deep_gemm.py +8 -1
- sglang/global_config.py +5 -1
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/conversation.py +0 -112
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +1 -0
- sglang/srt/disaggregation/launch_lb.py +5 -20
- sglang/srt/disaggregation/mooncake/conn.py +33 -15
- sglang/srt/disaggregation/prefill.py +1 -0
- sglang/srt/distributed/device_communicators/pynccl.py +7 -0
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +133 -0
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +42 -3
- sglang/srt/distributed/parallel_state.py +11 -0
- sglang/srt/entrypoints/engine.py +4 -2
- sglang/srt/entrypoints/http_server.py +35 -15
- sglang/srt/eplb/expert_distribution.py +4 -2
- sglang/srt/hf_transformers_utils.py +25 -10
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/flashattention_backend.py +7 -11
- sglang/srt/layers/attention/trtllm_mla_backend.py +372 -0
- sglang/srt/layers/attention/utils.py +6 -1
- sglang/srt/layers/attention/vision.py +27 -10
- sglang/srt/layers/communicator.py +14 -4
- sglang/srt/layers/linear.py +7 -1
- sglang/srt/layers/logits_processor.py +9 -1
- sglang/srt/layers/moe/ep_moe/layer.py +29 -68
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +82 -25
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +0 -31
- sglang/srt/layers/moe/token_dispatcher/__init__.py +23 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +12 -1
- sglang/srt/layers/moe/{ep_moe/token_dispatcher.py → token_dispatcher/deepep.py} +8 -15
- sglang/srt/layers/moe/utils.py +43 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +3 -2
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
- sglang/srt/layers/quantization/fp8.py +57 -1
- sglang/srt/layers/quantization/fp8_kernel.py +0 -4
- sglang/srt/layers/quantization/w8a8_int8.py +4 -1
- sglang/srt/layers/vocab_parallel_embedding.py +7 -1
- sglang/srt/lora/lora_registry.py +7 -0
- sglang/srt/managers/cache_controller.py +43 -39
- sglang/srt/managers/data_parallel_controller.py +52 -2
- sglang/srt/managers/io_struct.py +6 -1
- sglang/srt/managers/schedule_batch.py +3 -2
- sglang/srt/managers/schedule_policy.py +3 -1
- sglang/srt/managers/scheduler.py +145 -6
- sglang/srt/managers/template_manager.py +25 -22
- sglang/srt/managers/tokenizer_manager.py +114 -62
- sglang/srt/managers/utils.py +45 -1
- sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py +182 -0
- sglang/srt/mem_cache/hicache_storage.py +13 -12
- sglang/srt/mem_cache/hiradix_cache.py +21 -4
- sglang/srt/mem_cache/memory_pool.py +15 -118
- sglang/srt/mem_cache/memory_pool_host.py +350 -33
- sglang/srt/mem_cache/radix_cache_cpp.py +229 -0
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +8 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp +35 -0
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +163 -0
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +238 -0
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +216 -0
- sglang/srt/model_executor/cuda_graph_runner.py +42 -4
- sglang/srt/model_executor/forward_batch_info.py +13 -3
- sglang/srt/model_executor/model_runner.py +13 -1
- sglang/srt/model_loader/weight_utils.py +2 -0
- sglang/srt/models/deepseek_v2.py +28 -23
- sglang/srt/models/glm4_moe.py +85 -22
- sglang/srt/models/grok.py +3 -3
- sglang/srt/models/llama4.py +13 -2
- sglang/srt/models/mixtral.py +3 -3
- sglang/srt/models/mllama4.py +428 -19
- sglang/srt/models/qwen2_moe.py +1 -4
- sglang/srt/models/qwen3_moe.py +7 -8
- sglang/srt/models/step3_vl.py +1 -4
- sglang/srt/multimodal/processors/base_processor.py +4 -3
- sglang/srt/multimodal/processors/gemma3n.py +0 -7
- sglang/srt/operations_strategy.py +1 -1
- sglang/srt/server_args.py +115 -21
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +18 -0
- sglang/srt/two_batch_overlap.py +6 -4
- sglang/srt/utils.py +4 -24
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +945 -0
- sglang/test/runners.py +2 -2
- sglang/test/test_utils.py +3 -3
- sglang/version.py +1 -1
- {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/METADATA +3 -2
- {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/RECORD +92 -81
- /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/mooncake_store.py +0 -0
- /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/unit_test.py +0 -0
- {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/top_level.txt +0 -0
--- a/sglang/srt/entrypoints/http_server.py
+++ b/sglang/srt/entrypoints/http_server.py
@@ -45,6 +45,7 @@ from fastapi.responses import ORJSONResponse, Response, StreamingResponse
 
 from sglang.srt.disaggregation.utils import (
     FAKE_BOOTSTRAP_HOST,
+    DisaggregationMode,
     register_disaggregation_server,
 )
 from sglang.srt.entrypoints.engine import _launch_subprocesses
@@ -88,7 +89,7 @@ from sglang.srt.managers.io_struct import (
     VertexGenerateReqInput,
 )
 from sglang.srt.managers.template_manager import TemplateManager
-from sglang.srt.managers.tokenizer_manager import TokenizerManager
+from sglang.srt.managers.tokenizer_manager import ServerStatus, TokenizerManager
 from sglang.srt.metrics.func_timer import enable_func_timer
 from sglang.srt.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import ServerArgs
@@ -230,23 +231,28 @@ async def validate_json_request(raw_request: Request):
 
 
 @app.get("/health")
-async def health() -> Response:
-    """Check the health of the http server."""
-    return Response(status_code=200)
-
-
 @app.get("/health_generate")
 async def health_generate(request: Request) -> Response:
-    """
+    """
+    Check the health of the inference server by sending a special request to generate one token.
+
+    If the server is running something, this request will be ignored, so it creates zero overhead.
+    If the server is not running anything, this request will be run, so we know whether the server is healthy.
+    """
+
     if _global_state.tokenizer_manager.gracefully_exit:
         logger.info("Health check request received during shutdown. Returning 503.")
         return Response(status_code=503)
 
+    if not _global_state.tokenizer_manager.server_status.is_healthy():
+        return Response(status_code=503)
+
     sampling_params = {"max_new_tokens": 1, "temperature": 0.0}
     rid = f"HEALTH_CHECK_{time.time()}"
 
     if _global_state.tokenizer_manager.is_image_gen:
-
+        # Keep this branch for some internal use cases.
+        raise NotImplementedError("Image generation is not supported yet.")
     elif _global_state.tokenizer_manager.is_generation:
         gri = GenerateReqInput(
             rid=rid,
@@ -254,6 +260,12 @@ async def health_generate(request: Request) -> Response:
             sampling_params=sampling_params,
             log_metrics=False,
         )
+        if (
+            _global_state.tokenizer_manager.server_args.disaggregation_mode
+            != DisaggregationMode.NULL
+        ):
+            gri.bootstrap_host = FAKE_BOOTSTRAP_HOST
+            gri.bootstrap_room = 0
     else:
         gri = EmbeddingReqInput(
             rid=rid, input_ids=[0], sampling_params=sampling_params, log_metrics=False
@@ -263,9 +275,6 @@
         async for _ in _global_state.tokenizer_manager.generate_request(gri, request):
             break
 
-    # This request is a special request.
-    # If the server already has something running, this request will be ignored, so it creates zero overhead.
-    # If the server is not running, this request will be run, so we know whether the server is healthy.
     task = asyncio.create_task(gen())
 
     # As long as we receive any response from the detokenizer/scheduler, we consider the server is healthy.
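Taken together, the hunks above make `/health` and `/health_generate` share the token-generation probe and additionally gate on the tokenizer manager's `server_status`. A minimal external probe of that endpoint, assuming a server already listening on `http://localhost:30000` (the base URL, timeout, and retry policy here are illustrative and not part of the diff), could look like:

```python
import time

import requests  # third-party HTTP client, assumed available in the environment


def wait_until_healthy(base_url: str = "http://localhost:30000", timeout_s: float = 120.0) -> bool:
    """Poll /health_generate until it returns 200 (healthy) or the timeout expires.

    A 503 means the server is shutting down, not yet warmed up, or marked unhealthy.
    """
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        try:
            resp = requests.get(f"{base_url}/health_generate", timeout=5)
            if resp.status_code == 200:
                return True
        except requests.RequestException:
            pass  # server not reachable yet; keep polling
        time.sleep(1.0)
    return False
```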
@@ -1032,8 +1041,10 @@ def _execute_server_warmup(
                 timeout=600,
             )
             assert res.status_code == 200, f"{res}"
+            _global_state.tokenizer_manager.server_status = ServerStatus.Up
+
         else:
-            logger.info(f"Start of
+            logger.info(f"Start of pd disaggregation warmup ...")
             json_data = {
                 "sampling_params": {
                     "temperature": 0.0,
@@ -1055,9 +1066,18 @@ def _execute_server_warmup(
                 headers=headers,
                 timeout=1800,  # because of deep gemm precache is very long if not precache.
             )
-
-
-
+            if res.status_code == 200:
+                logger.info(
+                    f"End of prefill disaggregation mode warmup with status {res.status_code}, resp: {res.json()}"
+                )
+                _global_state.tokenizer_manager.server_status = ServerStatus.Up
+            else:
+                logger.info(
+                    "Prefill disaggregation mode warm Up Failed, status code: {}".format(
+                        res.status_code
+                    )
+                )
+                _global_state.tokenizer_manager.server_status = ServerStatus.UnHealthy
 
     except Exception:
         last_traceback = get_exception_traceback()
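The warmup path now records a `ServerStatus` on the tokenizer manager, and the health endpoint reads it back through `server_status.is_healthy()`. Only the `Up` and `UnHealthy` members appear in these hunks; the sketch below is not the actual `tokenizer_manager` implementation, just a minimal illustration of how such an enum could be wired, with `Starting` as an assumed placeholder for the pre-warmup state:

```python
from enum import Enum


class ServerStatus(Enum):
    # Up and UnHealthy are taken from the hunks above; Starting is an assumption.
    Starting = "Starting"
    Up = "Up"
    UnHealthy = "UnHealthy"

    def is_healthy(self) -> bool:
        return self == ServerStatus.Up


status = ServerStatus.Starting  # before warmup, health checks would return 503
status = ServerStatus.Up        # warmup request returned 200
assert status.is_healthy()
```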
--- a/sglang/srt/eplb/expert_distribution.py
+++ b/sglang/srt/eplb/expert_distribution.py
@@ -288,12 +288,14 @@ class _SinglePassGatherer(ABC):
             )
 
         if server_args.expert_distribution_recorder_mode == "stat_approx":
-            if server_args.
+            if server_args.moe_a2a_backend is not None and (
+                server_args.deepep_mode == "normal"
+            ):
                 return _DeepepNormalSinglePassGatherer(expert_location_metadata, rank)
             else:
                 raise NotImplementedError
 
-        if server_args.
+        if server_args.moe_a2a_backend is not None:
            if server_args.deepep_mode == "normal":
                return _SelectExpertsSinglePassGatherer(expert_location_metadata, rank)
            elif server_args.deepep_mode == "low_latency":
--- a/sglang/srt/hf_transformers_utils.py
+++ b/sglang/srt/hf_transformers_utils.py
@@ -14,7 +14,6 @@
 """Utilities for Huggingface Transformers."""
 
 import contextlib
-import logging
 import os
 import warnings
 from pathlib import Path
@@ -45,7 +44,7 @@ from sglang.srt.configs import (
 )
 from sglang.srt.configs.internvl import InternVLChatConfig
 from sglang.srt.connector import create_remote_connector
-from sglang.srt.utils import is_remote_url, lru_cache_frozenset
+from sglang.srt.utils import is_remote_url, logger, lru_cache_frozenset
 
 _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     ChatGLMConfig.model_type: ChatGLMConfig,
@@ -317,15 +316,31 @@ def get_processor(
 
     if config.model_type not in {"llava", "clip"}:
         kwargs["use_fast"] = use_fast
+    try:
+        processor = AutoProcessor.from_pretrained(
+            tokenizer_name,
+            *args,
+            trust_remote_code=trust_remote_code,
+            revision=revision,
+            **kwargs,
+        )
 
-
-
-
-
-
-
-
-
+    except ValueError as e:
+        error_message = str(e)
+        if "does not have a slow version" in error_message:
+            logger.info(
+                f"Processor {tokenizer_name} does not have a slow version. Automatically use fast version"
+            )
+            kwargs["use_fast"] = True
+            processor = AutoProcessor.from_pretrained(
+                tokenizer_name,
+                *args,
+                trust_remote_code=trust_remote_code,
+                revision=revision,
+                **kwargs,
+            )
+        else:
+            raise e
     tokenizer = get_tokenizer_from_processor(processor)
 
     attach_additional_stop_token_ids(tokenizer)
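The new try/except retries `AutoProcessor.from_pretrained` with `use_fast=True` when a processor ships no slow implementation. The same fallback pattern can be reproduced standalone with Hugging Face transformers; this is a sketch of the pattern, not sglang's `get_processor`, and the model name passed in would be whatever checkpoint you are loading:

```python
from transformers import AutoProcessor


def load_processor(name: str, use_fast: bool = False):
    """Load a processor, falling back to the fast implementation when no slow one exists."""
    try:
        return AutoProcessor.from_pretrained(name, use_fast=use_fast)
    except ValueError as e:
        if "does not have a slow version" in str(e):
            # Some processors are only shipped as fast (tokenizers-based) classes.
            return AutoProcessor.from_pretrained(name, use_fast=True)
        raise
```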
--- a/sglang/srt/layers/attention/cutlass_mla_backend.py
+++ b/sglang/srt/layers/attention/cutlass_mla_backend.py
@@ -102,7 +102,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
                 block_kv_indices,
                 self.req_to_token.stride(0),
                 max_seqlen_pad,
-                PAGE_SIZE,
+                PAGED_SIZE=PAGE_SIZE,
             )
             workspace_size = cutlass_mla_get_workspace_size(
                 max_seqlen_pad * PAGE_SIZE, bs, num_kv_splits=1
@@ -165,7 +165,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
                 self.cuda_graph_kv_indices,
                 self.req_to_token.stride(0),
                 self.cuda_graph_kv_indices.stride(0),
-                PAGE_SIZE,
+                PAGED_SIZE=PAGE_SIZE,
             )
             self.forward_metadata = CutlassMLADecodeMetadata(
                 self.cuda_graph_mla_workspace,
@@ -206,7 +206,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
                 self.cuda_graph_kv_indices,
                 self.req_to_token.stride(0),
                 self.cuda_graph_kv_indices.stride(0),
-                PAGE_SIZE,
+                PAGED_SIZE=PAGE_SIZE,
             )
         else:
             super().init_forward_metadata_replay_cuda_graph(
--- a/sglang/srt/layers/attention/flashattention_backend.py
+++ b/sglang/srt/layers/attention/flashattention_backend.py
@@ -1406,7 +1406,7 @@ class FlashAttentionBackend(AttentionBackend):
                 )
                 metadata.page_table = self.decode_cuda_graph_metadata[
                     "page_table_draft_decode"
-                ][
+                ][:bs, :]
                 self.decode_cuda_graph_metadata[bs] = metadata
             else:
                 # When top k > 1, we need two specific draft decode metadata, and then merge states
@@ -1424,7 +1424,7 @@ class FlashAttentionBackend(AttentionBackend):
                 ][: bs + 1]
                 metadata.page_table = self.draft_decode_metadata_topk_normal[
                     "page_table"
-                ][
+                ][:bs, :]
 
                 # 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
                 metadata_expand.cache_seqlens_int32 = (
@@ -1461,7 +1461,7 @@ class FlashAttentionBackend(AttentionBackend):
             metadata.max_seq_len_k = seq_lens.max().item()
             # Precompute page table
             metadata.page_table = self.decode_cuda_graph_metadata["page_table"][
-
+                :bs, :
             ]
             # Precompute cumulative sequence lengths
             metadata.cu_seqlens_q = torch.arange(
@@ -1498,9 +1498,7 @@ class FlashAttentionBackend(AttentionBackend):
                     : (bs + 1)
                 ]
 
-                metadata.page_table = self.target_verify_metadata["page_table"][
-                    req_pool_indices, :
-                ]
+                metadata.page_table = self.target_verify_metadata["page_table"][:bs, :]
 
                 self.target_verify_metadata[bs] = metadata
             else:
@@ -1519,7 +1517,7 @@ class FlashAttentionBackend(AttentionBackend):
                 ][: bs + 1]
                 metadata.page_table = self.target_verify_metadata_topk_normal[
                     "page_table"
-                ][
+                ][:bs, :]
 
                 # 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
                 metadata_expand.cache_seqlens_int32 = (
@@ -1562,9 +1560,7 @@ class FlashAttentionBackend(AttentionBackend):
             metadata.cu_seqlens_k = self.draft_extend_metadata["cu_seqlens_k"][
                 : (bs + 1)
             ]
-            metadata.page_table = self.draft_extend_metadata["page_table"][
-                req_pool_indices, :
-            ]
+            metadata.page_table = self.draft_extend_metadata["page_table"][:bs, :]
 
             self.draft_extend_metadata[bs] = metadata
 
@@ -1578,7 +1574,7 @@ class FlashAttentionBackend(AttentionBackend):
             ][: (encoder_bs + 1)]
 
             metadata.encoder_page_table = self.encoder_metadata["encoder_page_table"][
-
+                :bs, :
             ]
 
             self.forward_metadata = metadata
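Each of the flashattention_backend.py hunks above replaces indexing of a preallocated CUDA-graph page-table buffer with `req_pool_indices` by a plain `[:bs, :]` slice. The diff does not state the motivation, so treat the following only as an illustration of one relevant PyTorch detail: advanced indexing with an index tensor materializes a copy, while a basic slice returns a view, so later in-place updates to the buffer stay visible through the slice.

```python
import torch

buf = torch.zeros(8, 4, dtype=torch.int32)
idx = torch.tensor([3, 1])

gathered = buf[idx, :]  # advanced indexing -> independent copy
sliced = buf[:2, :]     # basic slicing -> view of buf

buf.fill_(7)            # in-place update of the underlying buffer

print(gathered.unique())  # tensor([0], dtype=torch.int32) - the copy did not change
print(sliced.unique())    # tensor([7], dtype=torch.int32) - the view reflects the update
```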
--- /dev/null
+++ b/sglang/srt/layers/attention/trtllm_mla_backend.py
@@ -0,0 +1,372 @@
+from __future__ import annotations
+
+"""
+Support attention backend for TRTLLM MLA kernels from flashinfer.
+"""
+
+import math
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Optional, Union
+
+import torch
+import triton
+
+from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
+from sglang.srt.layers.attention.utils import (
+    TRITON_PAD_NUM_PAGE_PER_BLOCK,
+    create_flashmla_kv_indices_triton,
+)
+from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.utils import is_flashinfer_available
+
+if is_flashinfer_available():
+    import flashinfer
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.radix_attention import RadixAttention
+    from sglang.srt.model_executor.model_runner import ModelRunner
+    from sglang.srt.speculative.spec_info import SpecInfo
+
+# Constants
+DEFAULT_WORKSPACE_SIZE_MB = 128  # Memory workspace size in MB
+
+# Block constraint from flashinfer requirements
+# From flashinfer.decode._check_trtllm_gen_mla_shape:
+# block_num % (128 / block_size) == 0
+# This imposes that the total number of blocks must be divisible by
+# (128 / block_size). We capture the 128 constant here so we can
+# compute the LCM with other padding constraints.
+TRTLLM_BLOCK_CONSTRAINT = 128
+
+
+@dataclass
+class TRTLLMMLADecodeMetadata:
+    """Metadata for TRTLLM MLA decode operations."""
+
+    workspace: Optional[torch.Tensor] = None
+    block_kv_indices: Optional[torch.Tensor] = None
+
+
+class TRTLLMMLABackend(FlashInferMLAAttnBackend):
+    """TRTLLM MLA attention kernel from flashinfer."""
+
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        skip_prefill: bool = False,
+        kv_indptr_buf: Optional[torch.Tensor] = None,
+        q_indptr_decode_buf: Optional[torch.Tensor] = None,
+    ):
+        super().__init__(model_runner, skip_prefill, kv_indptr_buf, q_indptr_decode_buf)
+
+        config = model_runner.model_config
+
+        # Model parameters
+        self.num_q_heads = config.num_attention_heads // get_attention_tp_size()
+        self.num_kv_heads = config.get_num_kv_heads(get_attention_tp_size())
+        self.num_local_heads = config.num_attention_heads // get_attention_tp_size()
+
+        # MLA-specific dimensions
+        self.kv_lora_rank = config.kv_lora_rank
+        self.qk_nope_head_dim = config.qk_nope_head_dim
+        self.qk_rope_head_dim = config.qk_rope_head_dim
+        self.v_head_dim = config.v_head_dim
+        self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim
+
+        # Runtime parameters
+        self.scaling = config.scaling
+        self.data_type = model_runner.kv_cache_dtype
+        self.q_data_type = model_runner.dtype
+        self.page_size = model_runner.page_size
+        self.req_to_token = model_runner.req_to_token_pool.req_to_token
+
+        # Workspace allocation
+        self.workspace_size = DEFAULT_WORKSPACE_SIZE_MB * 1024 * 1024
+        self.workspace_buffer = torch.empty(
+            self.workspace_size, dtype=torch.int8, device=self.device
+        )
+
+        # CUDA graph state
+        self.decode_cuda_graph_metadata = {}
+        self.cuda_graph_kv_indices = None
+        self.forward_metadata: Union[TRTLLMMLADecodeMetadata, None] = None
+
+    def _calc_padded_blocks(self, max_seq_len: int) -> int:
+        """
+        Calculate padded block count that satisfies both TRT-LLM and Triton constraints.
+
+        Args:
+            max_seq_len: Maximum sequence length in tokens
+
+        Returns:
+            Number of blocks padded to satisfy all constraints
+        """
+        blocks = triton.cdiv(max_seq_len, self.page_size)
+
+        # Apply dual constraints (take LCM to satisfy both):
+        # 1. TRT-LLM: block_num % (128 / page_size) == 0
+        # 2. Triton: page table builder uses 64-index bursts, needs multiple of 64
+        trtllm_constraint = TRTLLM_BLOCK_CONSTRAINT // self.page_size
+        constraint_lcm = math.lcm(trtllm_constraint, TRITON_PAD_NUM_PAGE_PER_BLOCK)
+
+        if blocks % constraint_lcm != 0:
+            blocks = triton.cdiv(blocks, constraint_lcm) * constraint_lcm
+        return blocks
+
+    def _create_block_kv_indices(
+        self,
+        batch_size: int,
+        max_blocks: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        device: torch.device,
+    ) -> torch.Tensor:
+        """
+        Create block KV indices tensor using Triton kernel.
+
+        Args:
+            batch_size: Batch size
+            max_blocks: Maximum number of blocks per sequence
+            req_pool_indices: Request pool indices
+            seq_lens: Sequence lengths
+            device: Target device
+
+        Returns:
+            Block KV indices tensor
+        """
+        block_kv_indices = torch.full(
+            (batch_size, max_blocks), -1, dtype=torch.int32, device=device
+        )
+
+        create_flashmla_kv_indices_triton[(batch_size,)](
+            self.req_to_token,
+            req_pool_indices,
+            seq_lens,
+            None,
+            block_kv_indices,
+            self.req_to_token.stride(0),
+            max_blocks,
+            NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK,
+            PAGED_SIZE=self.page_size,
+        )
+
+        return block_kv_indices
+
+    def init_cuda_graph_state(
+        self,
+        max_bs: int,
+        max_num_tokens: int,
+        kv_indices_buf: Optional[torch.Tensor] = None,
+    ):
+        """Initialize CUDA graph state for TRTLLM MLA."""
+        max_blocks_per_seq = self._calc_padded_blocks(self.max_context_len)
+
+        self.cuda_graph_kv_indices = torch.full(
+            (max_bs, max_blocks_per_seq), -1, dtype=torch.int32, device=self.device
+        )
+        self.cuda_graph_workspace = torch.empty(
+            self.workspace_size, dtype=torch.int8, device=self.device
+        )
+
+    def init_forward_metadata_capture_cuda_graph(
+        self,
+        bs: int,
+        num_tokens: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
+    ):
+        """Initialize metadata for CUDA graph capture."""
+        # Delegate to parent for non-decode modes or when speculative execution is used.
+        if not (forward_mode.is_decode_or_idle() and spec_info is None):
+            return super().init_forward_metadata_capture_cuda_graph(
+                bs,
+                num_tokens,
+                req_pool_indices,
+                seq_lens,
+                encoder_lens,
+                forward_mode,
+                spec_info,
+            )
+
+        # Custom fast-path for decode/idle without speculative execution.
+        max_seqlen_pad = self._calc_padded_blocks(seq_lens.max().item())
+        block_kv_indices = self.cuda_graph_kv_indices[:bs, :max_seqlen_pad]
+
+        create_flashmla_kv_indices_triton[(bs,)](
+            self.req_to_token,
+            req_pool_indices,
+            seq_lens,
+            None,
+            block_kv_indices,
+            self.req_to_token.stride(0),
+            max_seqlen_pad,
+            NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK,
+            PAGED_SIZE=self.page_size,
+        )
+
+        metadata = TRTLLMMLADecodeMetadata(self.cuda_graph_workspace, block_kv_indices)
+        self.decode_cuda_graph_metadata[bs] = metadata
+        self.forward_metadata = metadata
+
+    def init_forward_metadata_replay_cuda_graph(
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
+        seq_lens_cpu: Optional[torch.Tensor],
+    ):
+        """Replay CUDA graph with new inputs."""
+        # Delegate to parent for non-decode modes or when speculative execution is used.
+        if not (forward_mode.is_decode_or_idle() and spec_info is None):
+            return super().init_forward_metadata_replay_cuda_graph(
+                bs,
+                req_pool_indices,
+                seq_lens,
+                seq_lens_sum,
+                encoder_lens,
+                forward_mode,
+                spec_info,
+                seq_lens_cpu,
+            )
+
+        metadata = self.decode_cuda_graph_metadata[bs]
+
+        # Update block indices for new sequences.
+        create_flashmla_kv_indices_triton[(bs,)](
+            self.req_to_token,
+            req_pool_indices[:bs],
+            seq_lens[:bs],
+            None,
+            metadata.block_kv_indices,
+            self.req_to_token.stride(0),
+            metadata.block_kv_indices.shape[1],
+            NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK,
+            PAGED_SIZE=self.page_size,
+        )
+
+    def get_cuda_graph_seq_len_fill_value(self) -> int:
+        """Get the fill value for sequence lengths in CUDA graph."""
+        return 1
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        """Initialize the metadata for a forward pass."""
+        # Delegate to parent for non-decode modes or when speculative execution is used.
+        if not (
+            forward_batch.forward_mode.is_decode_or_idle()
+            and forward_batch.spec_info is None
+        ):
+            return super().init_forward_metadata(forward_batch)
+
+        bs = forward_batch.batch_size
+
+        # Get maximum sequence length.
+        if getattr(forward_batch, "seq_lens_cpu", None) is not None:
+            max_seq = forward_batch.seq_lens_cpu.max().item()
+        else:
+            max_seq = forward_batch.seq_lens.max().item()
+
+        max_seqlen_pad = self._calc_padded_blocks(max_seq)
+        block_kv_indices = self._create_block_kv_indices(
+            bs,
+            max_seqlen_pad,
+            forward_batch.req_pool_indices,
+            forward_batch.seq_lens,
+            forward_batch.seq_lens.device,
+        )
+
+        self.forward_metadata = TRTLLMMLADecodeMetadata(
+            self.workspace_buffer, block_kv_indices
+        )
+        forward_batch.decode_trtllm_mla_metadata = self.forward_metadata
+
+    def forward_decode(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Run forward for decode using TRTLLM MLA kernel."""
+        # Save KV cache if requested
+        if k is not None and save_kv_cache:
+            cache_loc = forward_batch.out_cache_loc
+            if k_rope is not None:
+                forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+                    layer, cache_loc, k, k_rope
+                )
+            elif v is not None:
+                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+
+        # Prepare query tensor inline
+        if q_rope is not None:
+            # q contains NOPE part (v_head_dim)
+            q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+            q_rope_reshaped = q_rope.view(
+                -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+            )
+            query = torch.cat([q_nope, q_rope_reshaped], dim=-1)
+        else:
+            # q already has both parts
+            query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+
+        # Ensure query has shape [bs, acc_q_len, num_q_heads, head_dim] when seq_len 1
+        if query.dim() == 3:
+            query = query.unsqueeze(1)
+
+        # Prepare KV cache inline
+        k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+        pages = k_cache.view(-1, self.page_size, self.kv_cache_dim)
+        # TRT-LLM expects single KV data with extra dimension
+        kv_cache = pages.unsqueeze(1)
+
+        # Get metadata
+        metadata = (
+            getattr(forward_batch, "decode_trtllm_mla_metadata", None)
+            or self.forward_metadata
+        )
+
+        # Scale computation for TRTLLM MLA kernel:
+        # - BMM1 scale = q_scale * k_scale * softmax_scale
+        # - For FP16 path we keep q_scale = 1.0, softmax_scale = 1/sqrt(head_dim) which is pre-computed as layer.scaling
+        # - k_scale is read from model checkpoint if available
+        # TODO: Change once fp8 path is supported
+        q_scale = 1.0
+        k_scale = (
+            layer.k_scale_float
+            if getattr(layer, "k_scale_float", None) is not None
+            else 1.0
+        )
+
+        bmm1_scale = q_scale * k_scale * layer.scaling
+
+        # Call TRT-LLM kernel
+        raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
+            query=query,
+            kv_cache=kv_cache,
+            workspace_buffer=metadata.workspace,
+            qk_nope_head_dim=self.qk_nope_head_dim,
+            kv_lora_rank=self.kv_lora_rank,
+            qk_rope_head_dim=self.qk_rope_head_dim,
+            block_tables=metadata.block_kv_indices,
+            seq_lens=forward_batch.seq_lens.to(torch.int32),
+            max_seq_len=int(metadata.block_kv_indices.shape[1] * self.page_size),
+            bmm1_scale=bmm1_scale,
+        )
+
+        # Extract value projection part and reshape
+        raw_out_v = raw_out[..., : layer.v_head_dim].contiguous()
+        output = raw_out_v.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+
+        return output
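The `_calc_padded_blocks` helper in the new backend pads the per-sequence block count up to a multiple of `lcm(128 // page_size, 64)`, combining the flashinfer shape constraint with the Triton kernel's 64-page write burst. A small worked example of that arithmetic, using assumed page sizes rather than any particular model configuration (and `math.ceil` standing in for `triton.cdiv` so the snippet runs without Triton installed):

```python
import math

TRTLLM_BLOCK_CONSTRAINT = 128       # from flashinfer's shape check, as in the hunk above
TRITON_PAD_NUM_PAGE_PER_BLOCK = 64  # pages written per iteration by the Triton kernel


def calc_padded_blocks(max_seq_len: int, page_size: int) -> int:
    # math.ceil(a / b) mirrors triton.cdiv(a, b) used in the backend.
    blocks = math.ceil(max_seq_len / page_size)
    constraint_lcm = math.lcm(TRTLLM_BLOCK_CONSTRAINT // page_size, TRITON_PAD_NUM_PAGE_PER_BLOCK)
    if blocks % constraint_lcm != 0:
        blocks = math.ceil(blocks / constraint_lcm) * constraint_lcm
    return blocks


print(calc_padded_blocks(1000, 64))    # 16 raw blocks -> padded to 64 (lcm(2, 64) = 64)
print(calc_padded_blocks(10_000, 32))  # 313 raw blocks -> padded to 320 (lcm(4, 64) = 64)
```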
--- a/sglang/srt/layers/attention/utils.py
+++ b/sglang/srt/layers/attention/utils.py
@@ -1,6 +1,11 @@
 import triton
 import triton.language as tl
 
+# Keep this in sync with the Triton kernel inside `create_flashmla_kv_indices_triton`.
+# Number of pages that the kernel writes per iteration.
+# Exposed here so other Python modules can import it instead of hard-coding 64.
+TRITON_PAD_NUM_PAGE_PER_BLOCK = 64
+
 
 @triton.jit
 def create_flashinfer_kv_indices_triton(
@@ -50,10 +55,10 @@ def create_flashmla_kv_indices_triton(
     kv_indices_ptr,
     req_to_token_ptr_stride: tl.constexpr,
     kv_indices_ptr_stride: tl.constexpr,
+    NUM_PAGE_PER_BLOCK: tl.constexpr = TRITON_PAD_NUM_PAGE_PER_BLOCK,
     PAGED_SIZE: tl.constexpr = 64,
 ):
     BLOCK_SIZE: tl.constexpr = 4096
-    NUM_PAGE_PER_BLOCK: tl.constexpr = 64
     pid = tl.program_id(axis=0)
 
     # find the req pool idx, this is for batch to token
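This last hunk inserts the `NUM_PAGE_PER_BLOCK` parameter ahead of `PAGED_SIZE` in `create_flashmla_kv_indices_triton`, which lines up with the cutlass_mla_backend.py call sites above switching from a trailing positional `PAGE_SIZE` to the keyword form `PAGED_SIZE=PAGE_SIZE`. The mis-binding risk is ordinary Python argument binding; the toy functions below are hypothetical and unrelated to the Triton kernel, they only illustrate the effect of inserting a parameter before a value previously passed positionally:

```python
def kernel_v1(stride, paged_size=64):
    return {"stride": stride, "paged_size": paged_size}


def kernel_v2(stride, num_page_per_block=64, paged_size=64):
    return {"stride": stride, "num_page_per_block": num_page_per_block, "paged_size": paged_size}


PAGE_SIZE = 128
print(kernel_v1(1, PAGE_SIZE))             # positional value lands on paged_size, as intended
print(kernel_v2(1, PAGE_SIZE))             # same positional value now lands on num_page_per_block (wrong)
print(kernel_v2(1, paged_size=PAGE_SIZE))  # keyword form stays correct as parameters are added
```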
|