sglang 0.4.10__py3-none-any.whl → 0.4.10.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +20 -0
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/disaggregation/launch_lb.py +5 -20
- sglang/srt/disaggregation/mooncake/conn.py +33 -15
- sglang/srt/layers/attention/trtllm_mla_backend.py +372 -0
- sglang/srt/layers/attention/utils.py +6 -1
- sglang/srt/layers/moe/ep_moe/layer.py +19 -34
- sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -2
- sglang/srt/layers/quantization/fp8.py +52 -0
- sglang/srt/layers/quantization/w8a8_int8.py +4 -1
- sglang/srt/managers/cache_controller.py +35 -35
- sglang/srt/managers/scheduler.py +1 -0
- sglang/srt/mem_cache/hicache_storage.py +15 -6
- sglang/srt/mem_cache/hiradix_cache.py +21 -4
- sglang/srt/mem_cache/memory_pool.py +15 -118
- sglang/srt/mem_cache/memory_pool_host.py +350 -33
- sglang/srt/mem_cache/nixl/hicache_nixl.py +163 -0
- sglang/srt/mem_cache/nixl/nixl_utils.py +238 -0
- sglang/srt/mem_cache/nixl/test_hicache_nixl_storage.py +216 -0
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +8 -2
- sglang/srt/model_executor/cuda_graph_runner.py +25 -1
- sglang/srt/model_executor/model_runner.py +8 -1
- sglang/srt/model_loader/weight_utils.py +2 -0
- sglang/srt/models/deepseek_v2.py +5 -6
- sglang/srt/models/glm4_moe.py +3 -3
- sglang/srt/models/step3_vl.py +0 -3
- sglang/srt/server_args.py +40 -6
- sglang/srt/utils.py +1 -0
- sglang/test/attention/test_trtllm_mla_backend.py +945 -0
- sglang/version.py +1 -1
- {sglang-0.4.10.dist-info → sglang-0.4.10.post1.dist-info}/METADATA +1 -1
- {sglang-0.4.10.dist-info → sglang-0.4.10.post1.dist-info}/RECORD +35 -30
- {sglang-0.4.10.dist-info → sglang-0.4.10.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.dist-info → sglang-0.4.10.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.dist-info → sglang-0.4.10.post1.dist-info}/top_level.txt +0 -0
sglang/bench_offline_throughput.py
@@ -418,6 +418,26 @@ if __name__ == "__main__":
     ServerArgs.add_cli_args(parser)
     BenchArgs.add_cli_args(parser)
     args = parser.parse_args()
+
+    # handling ModelScope model downloads
+    if os.getenv("SGLANG_USE_MODELSCOPE", "false").lower() in ("true", "1"):
+        if os.path.exists(args.model_path):
+            print(f"Using local model path: {args.model_path}")
+        else:
+            try:
+                from modelscope import snapshot_download
+
+                print(f"Using ModelScope to download model: {args.model_path}")
+
+                # download the model and replace args.model_path
+                args.model_path = snapshot_download(
+                    args.model_path,
+                )
+                print(f"Model downloaded to: {args.model_path}")
+            except Exception as e:
+                print(f"ModelScope download failed: {str(e)}")
+                raise e
+
     server_args = ServerArgs.from_cli_args(args)
     bench_args = BenchArgs.from_cli_args(args)

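
Note on the hunk above: the download only happens when `SGLANG_USE_MODELSCOPE` is set and the path does not already exist locally. A minimal standalone sketch of the same fallback, assuming the `modelscope` package is installed; `resolve_model_path` is a hypothetical helper name, not part of sglang:

    import os

    def resolve_model_path(model_path: str) -> str:
        # Mirrors the new bench_offline_throughput.py logic: only consult
        # ModelScope when the opt-in env var is set and the path is not local.
        if os.getenv("SGLANG_USE_MODELSCOPE", "false").lower() not in ("true", "1"):
            return model_path
        if os.path.exists(model_path):
            return model_path
        from modelscope import snapshot_download  # assumes `pip install modelscope`

        return snapshot_download(model_path)
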
sglang/srt/configs/model_config.py
@@ -112,6 +112,7 @@ class ModelConfig:
         mm_disabled_models = [
             "Gemma3ForConditionalGeneration",
             "Llama4ForConditionalGeneration",
+            "Step3VLForConditionalGeneration",
         ]
         if self.hf_config.architectures[0] in mm_disabled_models:
             enable_multimodal = False
sglang/srt/disaggregation/launch_lb.py
@@ -1,6 +1,8 @@
 import argparse
 import dataclasses

+from sglang.srt.disaggregation.mini_lb import PrefillConfig, run
+

 @dataclasses.dataclass
 class LBArgs:
@@ -18,7 +20,7 @@ class LBArgs:
         parser.add_argument(
             "--rust-lb",
             action="store_true",
-            help="
+            help="Deprecated, please use SGLang Router instead, this argument will have no effect.",
         )
         parser.add_argument(
             "--host",
@@ -115,25 +117,8 @@ def main():
     args = parser.parse_args()
     lb_args = LBArgs.from_cli_args(args)

-
-
-
-        RustLB(
-            host=lb_args.host,
-            port=lb_args.port,
-            policy=lb_args.policy,
-            prefill_infos=lb_args.prefill_infos,
-            decode_infos=lb_args.decode_infos,
-            log_interval=lb_args.log_interval,
-            timeout=lb_args.timeout,
-        ).start()
-    else:
-        from sglang.srt.disaggregation.mini_lb import PrefillConfig, run
-
-        prefill_configs = [
-            PrefillConfig(url, port) for url, port in lb_args.prefill_infos
-        ]
-        run(prefill_configs, lb_args.decode_infos, lb_args.host, lb_args.port)
+    prefill_configs = [PrefillConfig(url, port) for url, port in lb_args.prefill_infos]
+    run(prefill_configs, lb_args.decode_infos, lb_args.host, lb_args.port)


 if __name__ == "__main__":
sglang/srt/disaggregation/mooncake/conn.py
@@ -37,6 +37,7 @@ from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     format_tcp_address,
+    get_bool_env_var,
     get_free_port,
     get_int_env_var,
     get_ip,
@@ -198,6 +199,10 @@ class MooncakeKVManager(BaseKVManager):
             self.bootstrap_timeout = get_int_env_var(
                 "SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT", 300
             )
+
+            self.enable_custom_mem_pool = get_bool_env_var(
+                "SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
+            )
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
             self.heartbeat_failures = {}
             self.session_pool = defaultdict(requests.Session)
@@ -258,6 +263,26 @@ class MooncakeKVManager(BaseKVManager):
             socket.connect(endpoint)
             return socket

+    def _transfer_data(self, mooncake_session_id, transfer_blocks):
+        if not transfer_blocks:
+            return 0
+
+        # TODO(shangming): Fix me when nvlink_transport of Mooncake is bug-free
+        if self.enable_custom_mem_pool:
+            # batch_transfer_sync has a higher chance to trigger an accuracy drop for MNNVL, fallback to transfer_sync temporarily
+            for src_addr, dst_addr, length in transfer_blocks:
+                status = self.engine.transfer_sync(
+                    mooncake_session_id, src_addr, dst_addr, length
+                )
+                if status != 0:
+                    return status
+            return 0
+        else:
+            src_addrs, dst_addrs, lengths = zip(*transfer_blocks)
+            return self.engine.batch_transfer_sync(
+                mooncake_session_id, list(src_addrs), list(dst_addrs), list(lengths)
+            )
+
     def send_kvcache(
         self,
         mooncake_session_id: str,
@@ -283,17 +308,14 @@ class MooncakeKVManager(BaseKVManager):

         # Worker function for processing a single layer
         def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int:
+            transfer_blocks = []
             for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
                 src_addr = src_ptr + int(prefill_index[0]) * item_len
                 dst_addr = dst_ptr + int(decode_index[0]) * item_len
                 length = item_len * len(prefill_index)
+                transfer_blocks.append((src_addr, dst_addr, length))

-
-                    mooncake_session_id, src_addr, dst_addr, length
-                )
-                if status != 0:
-                    return status
-            return 0
+            return self._transfer_data(mooncake_session_id, transfer_blocks)

         futures = [
             executor.submit(
@@ -465,21 +487,17 @@
         dst_aux_ptrs: list[int],
         dst_aux_index: int,
     ):
-
-        dst_addr_list = []
-        length_list = []
+        transfer_blocks = []
         prefill_aux_ptrs = self.kv_args.aux_data_ptrs
         prefill_aux_item_lens = self.kv_args.aux_item_lens
+
         for i, dst_aux_ptr in enumerate(dst_aux_ptrs):
             length = prefill_aux_item_lens[i]
             src_addr = prefill_aux_ptrs[i] + length * prefill_aux_index
             dst_addr = dst_aux_ptrs[i] + length * dst_aux_index
-
-
-
-        return self.engine.batch_transfer_sync(
-            mooncake_session_id, src_addr_list, dst_addr_list, length_list
-        )
+            transfer_blocks.append((src_addr, dst_addr, length))
+
+        return self._transfer_data(mooncake_session_id, transfer_blocks)

     def sync_status_to_decode_endpoint(
         self, remote: str, dst_port: int, room: int, status: int, prefill_rank: int
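
The hunks above fold the per-layer and aux-data copies into (src_addr, dst_addr, length) descriptors that `_transfer_data` either batches through `batch_transfer_sync` or, when `SGLANG_MOONCAKE_CUSTOM_MEM_POOL` is set, replays one by one via `transfer_sync`. A toy sketch (not sglang code) of how a run of contiguous page indices collapses into a single descriptor:

    def block_to_descriptor(src_ptr, dst_ptr, item_len, prefill_index, decode_index):
        # Same arithmetic as process_layer: a contiguous run of page indices
        # becomes one (src, dst, length) triple instead of many small copies.
        src_addr = src_ptr + int(prefill_index[0]) * item_len
        dst_addr = dst_ptr + int(decode_index[0]) * item_len
        length = item_len * len(prefill_index)
        return src_addr, dst_addr, length

    # e.g. pages 4..7 of a layer with 512-byte items start at offset 4 * 512,
    # and the whole run is sent as one 2048-byte transfer.
    print(block_to_descriptor(0x1000, 0x9000, 512, [4, 5, 6, 7], [10, 11, 12, 13]))
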
sglang/srt/layers/attention/trtllm_mla_backend.py
@@ -0,0 +1,372 @@
+from __future__ import annotations
+
+"""
+Support attention backend for TRTLLM MLA kernels from flashinfer.
+"""
+
+import math
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Optional, Union
+
+import torch
+import triton
+
+from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
+from sglang.srt.layers.attention.utils import (
+    TRITON_PAD_NUM_PAGE_PER_BLOCK,
+    create_flashmla_kv_indices_triton,
+)
+from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.utils import is_flashinfer_available
+
+if is_flashinfer_available():
+    import flashinfer
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.radix_attention import RadixAttention
+    from sglang.srt.model_executor.model_runner import ModelRunner
+    from sglang.srt.speculative.spec_info import SpecInfo
+
+# Constants
+DEFAULT_WORKSPACE_SIZE_MB = 128  # Memory workspace size in MB
+
+# Block constraint from flashinfer requirements
+# From flashinfer.decode._check_trtllm_gen_mla_shape:
+# block_num % (128 / block_size) == 0
+# This imposes that the total number of blocks must be divisible by
+# (128 / block_size). We capture the 128 constant here so we can
+# compute the LCM with other padding constraints.
+TRTLLM_BLOCK_CONSTRAINT = 128
+
+
+@dataclass
+class TRTLLMMLADecodeMetadata:
+    """Metadata for TRTLLM MLA decode operations."""
+
+    workspace: Optional[torch.Tensor] = None
+    block_kv_indices: Optional[torch.Tensor] = None
+
+
+class TRTLLMMLABackend(FlashInferMLAAttnBackend):
+    """TRTLLM MLA attention kernel from flashinfer."""
+
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        skip_prefill: bool = False,
+        kv_indptr_buf: Optional[torch.Tensor] = None,
+        q_indptr_decode_buf: Optional[torch.Tensor] = None,
+    ):
+        super().__init__(model_runner, skip_prefill, kv_indptr_buf, q_indptr_decode_buf)
+
+        config = model_runner.model_config
+
+        # Model parameters
+        self.num_q_heads = config.num_attention_heads // get_attention_tp_size()
+        self.num_kv_heads = config.get_num_kv_heads(get_attention_tp_size())
+        self.num_local_heads = config.num_attention_heads // get_attention_tp_size()
+
+        # MLA-specific dimensions
+        self.kv_lora_rank = config.kv_lora_rank
+        self.qk_nope_head_dim = config.qk_nope_head_dim
+        self.qk_rope_head_dim = config.qk_rope_head_dim
+        self.v_head_dim = config.v_head_dim
+        self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim
+
+        # Runtime parameters
+        self.scaling = config.scaling
+        self.data_type = model_runner.kv_cache_dtype
+        self.q_data_type = model_runner.dtype
+        self.page_size = model_runner.page_size
+        self.req_to_token = model_runner.req_to_token_pool.req_to_token
+
+        # Workspace allocation
+        self.workspace_size = DEFAULT_WORKSPACE_SIZE_MB * 1024 * 1024
+        self.workspace_buffer = torch.empty(
+            self.workspace_size, dtype=torch.int8, device=self.device
+        )
+
+        # CUDA graph state
+        self.decode_cuda_graph_metadata = {}
+        self.cuda_graph_kv_indices = None
+        self.forward_metadata: Union[TRTLLMMLADecodeMetadata, None] = None
+
+    def _calc_padded_blocks(self, max_seq_len: int) -> int:
+        """
+        Calculate padded block count that satisfies both TRT-LLM and Triton constraints.
+
+        Args:
+            max_seq_len: Maximum sequence length in tokens
+
+        Returns:
+            Number of blocks padded to satisfy all constraints
+        """
+        blocks = triton.cdiv(max_seq_len, self.page_size)
+
+        # Apply dual constraints (take LCM to satisfy both):
+        # 1. TRT-LLM: block_num % (128 / page_size) == 0
+        # 2. Triton: page table builder uses 64-index bursts, needs multiple of 64
+        trtllm_constraint = TRTLLM_BLOCK_CONSTRAINT // self.page_size
+        constraint_lcm = math.lcm(trtllm_constraint, TRITON_PAD_NUM_PAGE_PER_BLOCK)
+
+        if blocks % constraint_lcm != 0:
+            blocks = triton.cdiv(blocks, constraint_lcm) * constraint_lcm
+        return blocks
+
+    def _create_block_kv_indices(
+        self,
+        batch_size: int,
+        max_blocks: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        device: torch.device,
+    ) -> torch.Tensor:
+        """
+        Create block KV indices tensor using Triton kernel.
+
+        Args:
+            batch_size: Batch size
+            max_blocks: Maximum number of blocks per sequence
+            req_pool_indices: Request pool indices
+            seq_lens: Sequence lengths
+            device: Target device
+
+        Returns:
+            Block KV indices tensor
+        """
+        block_kv_indices = torch.full(
+            (batch_size, max_blocks), -1, dtype=torch.int32, device=device
+        )
+
+        create_flashmla_kv_indices_triton[(batch_size,)](
+            self.req_to_token,
+            req_pool_indices,
+            seq_lens,
+            None,
+            block_kv_indices,
+            self.req_to_token.stride(0),
+            max_blocks,
+            TRITON_PAD_NUM_PAGE_PER_BLOCK,
+            self.page_size,
+        )
+
+        return block_kv_indices
+
+    def init_cuda_graph_state(
+        self,
+        max_bs: int,
+        max_num_tokens: int,
+        kv_indices_buf: Optional[torch.Tensor] = None,
+    ):
+        """Initialize CUDA graph state for TRTLLM MLA."""
+        max_blocks_per_seq = self._calc_padded_blocks(self.max_context_len)
+
+        self.cuda_graph_kv_indices = torch.full(
+            (max_bs, max_blocks_per_seq), -1, dtype=torch.int32, device=self.device
+        )
+        self.cuda_graph_workspace = torch.empty(
+            self.workspace_size, dtype=torch.int8, device=self.device
+        )
+
+    def init_forward_metadata_capture_cuda_graph(
+        self,
+        bs: int,
+        num_tokens: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
+    ):
+        """Initialize metadata for CUDA graph capture."""
+        # Delegate to parent for non-decode modes or when speculative execution is used.
+        if not (forward_mode.is_decode_or_idle() and spec_info is None):
+            return super().init_forward_metadata_capture_cuda_graph(
+                bs,
+                num_tokens,
+                req_pool_indices,
+                seq_lens,
+                encoder_lens,
+                forward_mode,
+                spec_info,
+            )
+
+        # Custom fast-path for decode/idle without speculative execution.
+        max_seqlen_pad = self._calc_padded_blocks(seq_lens.max().item())
+        block_kv_indices = self.cuda_graph_kv_indices[:bs, :max_seqlen_pad]
+
+        create_flashmla_kv_indices_triton[(bs,)](
+            self.req_to_token,
+            req_pool_indices,
+            seq_lens,
+            None,
+            block_kv_indices,
+            self.req_to_token.stride(0),
+            max_seqlen_pad,
+            TRITON_PAD_NUM_PAGE_PER_BLOCK,
+            self.page_size,
+        )
+
+        metadata = TRTLLMMLADecodeMetadata(self.cuda_graph_workspace, block_kv_indices)
+        self.decode_cuda_graph_metadata[bs] = metadata
+        self.forward_metadata = metadata
+
+    def init_forward_metadata_replay_cuda_graph(
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
+        seq_lens_cpu: Optional[torch.Tensor],
+    ):
+        """Replay CUDA graph with new inputs."""
+        # Delegate to parent for non-decode modes or when speculative execution is used.
+        if not (forward_mode.is_decode_or_idle() and spec_info is None):
+            return super().init_forward_metadata_replay_cuda_graph(
+                bs,
+                req_pool_indices,
+                seq_lens,
+                seq_lens_sum,
+                encoder_lens,
+                forward_mode,
+                spec_info,
+                seq_lens_cpu,
+            )
+
+        metadata = self.decode_cuda_graph_metadata[bs]
+
+        # Update block indices for new sequences.
+        create_flashmla_kv_indices_triton[(bs,)](
+            self.req_to_token,
+            req_pool_indices[:bs],
+            seq_lens[:bs],
+            None,
+            metadata.block_kv_indices,
+            self.req_to_token.stride(0),
+            metadata.block_kv_indices.shape[1],
+            TRITON_PAD_NUM_PAGE_PER_BLOCK,
+            self.page_size,
+        )
+
+    def get_cuda_graph_seq_len_fill_value(self) -> int:
+        """Get the fill value for sequence lengths in CUDA graph."""
+        return 1
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        """Initialize the metadata for a forward pass."""
+        # Delegate to parent for non-decode modes or when speculative execution is used.
+        if not (
+            forward_batch.forward_mode.is_decode_or_idle()
+            and forward_batch.spec_info is None
+        ):
+            return super().init_forward_metadata(forward_batch)
+
+        bs = forward_batch.batch_size
+
+        # Get maximum sequence length.
+        if getattr(forward_batch, "seq_lens_cpu", None) is not None:
+            max_seq = forward_batch.seq_lens_cpu.max().item()
+        else:
+            max_seq = forward_batch.seq_lens.max().item()
+
+        max_seqlen_pad = self._calc_padded_blocks(max_seq)
+        block_kv_indices = self._create_block_kv_indices(
+            bs,
+            max_seqlen_pad,
+            forward_batch.req_pool_indices,
+            forward_batch.seq_lens,
+            forward_batch.seq_lens.device,
+        )
+
+        self.forward_metadata = TRTLLMMLADecodeMetadata(
+            self.workspace_buffer, block_kv_indices
+        )
+        forward_batch.decode_trtllm_mla_metadata = self.forward_metadata
+
+    def forward_decode(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Run forward for decode using TRTLLM MLA kernel."""
+        # Save KV cache if requested
+        if k is not None and save_kv_cache:
+            cache_loc = forward_batch.out_cache_loc
+            if k_rope is not None:
+                forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+                    layer, cache_loc, k, k_rope
+                )
+            elif v is not None:
+                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+
+        # Prepare query tensor inline
+        if q_rope is not None:
+            # q contains NOPE part (v_head_dim)
+            q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+            q_rope_reshaped = q_rope.view(
+                -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+            )
+            query = torch.cat([q_nope, q_rope_reshaped], dim=-1)
+        else:
+            # q already has both parts
+            query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+
+        # Ensure query has shape [bs, acc_q_len, num_q_heads, head_dim] when seq_len 1
+        if query.dim() == 3:
+            query = query.unsqueeze(1)
+
+        # Prepare KV cache inline
+        k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+        pages = k_cache.view(-1, self.page_size, self.kv_cache_dim)
+        # TRT-LLM expects single KV data with extra dimension
+        kv_cache = pages.unsqueeze(1)
+
+        # Get metadata
+        metadata = (
+            getattr(forward_batch, "decode_trtllm_mla_metadata", None)
+            or self.forward_metadata
+        )
+
+        # Scale computation for TRTLLM MLA kernel:
+        # - BMM1 scale = q_scale * k_scale * softmax_scale
+        # - For FP16 path we keep q_scale = 1.0, softmax_scale = 1/sqrt(head_dim) which is pre-computed as layer.scaling
+        # - k_scale is read from model checkpoint if available
+        # TODO: Change once fp8 path is supported
+        q_scale = 1.0
+        k_scale = (
+            layer.k_scale_float
+            if getattr(layer, "k_scale_float", None) is not None
+            else 1.0
+        )
+
+        bmm1_scale = q_scale * k_scale * layer.scaling
+
+        # Call TRT-LLM kernel
+        raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
+            query=query,
+            kv_cache=kv_cache,
+            workspace_buffer=metadata.workspace,
+            qk_nope_head_dim=self.qk_nope_head_dim,
+            kv_lora_rank=self.kv_lora_rank,
+            qk_rope_head_dim=self.qk_rope_head_dim,
+            block_tables=metadata.block_kv_indices,
+            seq_lens=forward_batch.seq_lens.to(torch.int32),
+            max_seq_len=int(metadata.block_kv_indices.shape[1] * self.page_size),
+            bmm1_scale=bmm1_scale,
+        )
+
+        # Extract value projection part and reshape
+        raw_out_v = raw_out[..., : layer.v_head_dim].contiguous()
+        output = raw_out_v.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+
+        return output
sglang/srt/layers/attention/utils.py
@@ -1,6 +1,11 @@
 import triton
 import triton.language as tl

+# Keep this in sync with the Triton kernel inside `create_flashmla_kv_indices_triton`.
+# Number of pages that the kernel writes per iteration.
+# Exposed here so other Python modules can import it instead of hard-coding 64.
+TRITON_PAD_NUM_PAGE_PER_BLOCK = 64
+

 @triton.jit
 def create_flashinfer_kv_indices_triton(
@@ -50,10 +55,10 @@ def create_flashmla_kv_indices_triton(
     kv_indices_ptr,
     req_to_token_ptr_stride: tl.constexpr,
     kv_indices_ptr_stride: tl.constexpr,
+    NUM_PAGE_PER_BLOCK: tl.constexpr = TRITON_PAD_NUM_PAGE_PER_BLOCK,
     PAGED_SIZE: tl.constexpr = 64,
 ):
     BLOCK_SIZE: tl.constexpr = 4096
-    NUM_PAGE_PER_BLOCK: tl.constexpr = 64
     pid = tl.program_id(axis=0)

     # find the req pool idx, this is for batch to token
sglang/srt/layers/moe/ep_moe/layer.py
@@ -25,14 +25,22 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
     silu_and_mul_triton_kernel,
     tma_align_input_scale,
 )
-from sglang.srt.layers.moe.fused_moe_triton.layer import
+from sglang.srt.layers.moe.fused_moe_triton.layer import (
+    FlashInferFusedMoE,
+    FusedMoE,
+    should_use_flashinfer_trtllm_moe,
+)
 from sglang.srt.layers.moe.topk import TopKOutput
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.layers.quantization.fp8 import
+from sglang.srt.layers.quantization.fp8 import (
+    Fp8Config,
+    Fp8MoEMethod,
+    get_tile_tokens_dim,
+)
 from sglang.srt.layers.quantization.fp8_kernel import (
     is_fp8_fnuz,
     sglang_per_token_group_quant_fp8,
@@ -49,7 +57,6 @@ from sglang.srt.utils import (
     get_bool_env_var,
     is_hip,
     is_npu,
-    next_power_of_2,
 )

 if TYPE_CHECKING:
@@ -63,10 +70,7 @@ _is_hip = is_hip()
 _is_npu = is_npu()
 _is_fp8_fnuz = is_fp8_fnuz()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
-
-    global_server_args_dict["enable_flashinfer_trtllm_moe"]
-    and global_server_args_dict["enable_ep_moe"]
-)
+

 if not (_is_npu or _is_hip):
     from sgl_kernel import silu_and_mul
@@ -76,26 +80,9 @@ if _use_aiter:
     from aiter.fused_moe import fused_moe
     from aiter.ops.shuffle import shuffle_weight

-if use_flashinfer_trtllm_moe:
-    try:
-        import flashinfer.fused_moe as fi_fused_moe
-    except ImportError:
-        fi_fused_moe = None
-        use_flashinfer_trtllm_moe = False
-
 logger = logging.getLogger(__name__)


-def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
-    # Guess tokens per expert assuming perfect expert distribution first.
-    num_tokens_per_expert = (num_tokens * top_k) // num_experts
-    # And pad the number to the next power of 2.
-    tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
-    # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
-    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
-    return tile_tokens_dim
-
-
 class EPMoE(FusedMoE):
     """
     MoE Expert Parallel Impl
@@ -731,10 +718,10 @@ class FlashInferEPMoE(EPMoE):
         self.num_expert_group = num_expert_group
         self.topk_group = topk_group
         self.correction_bias = correction_bias
-        self.use_flashinfer_trtllm_moe =
+        self.use_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()

     def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
-        assert use_flashinfer_trtllm_moe
+        assert self.use_flashinfer_trtllm_moe
         assert (
             self.activation == "silu"
         ), "Only silu is supported for flashinfer blockscale fp8 moe"
@@ -747,8 +734,9 @@
         a_q, a_sf = sglang_per_token_group_quant_fp8(hidden_states, self.block_shape[1])
         # NOTE: scales of hidden states have to be transposed!
         a_sf_t = a_sf.t().contiguous()
-
-
+        from flashinfer.fused_moe import trtllm_fp8_block_scale_moe
+
+        return trtllm_fp8_block_scale_moe(
             routing_logits=router_logits.to(torch.float32),
             routing_bias=self.correction_bias.to(hidden_states.dtype),
             hidden_states=a_q,
@@ -765,7 +753,7 @@
             local_expert_offset=self.start_expert_id,
             local_num_experts=self.num_local_experts,
             routed_scaling_factor=self.routed_scaling_factor,
-            tile_tokens_dim=
+            tile_tokens_dim=get_tile_tokens_dim(
                 hidden_states.shape[0], self.top_k, self.num_experts
             ),
             routing_method_type=2,  # DeepSeek-styled routing method
@@ -779,9 +767,6 @@ def get_moe_impl_class():
     if global_server_args_dict["enable_flashinfer_cutlass_moe"]:
         # Must come before EPMoE because FusedMoE also supports enable_ep_moe
         return FusedMoE
-    if use_flashinfer_trtllm_moe:
-        # Must come before EPMoE because FusedMoE also supports enable_ep_moe
-        return FlashInferEPMoE
     if global_server_args_dict["enable_ep_moe"]:
-        return EPMoE
-    return FusedMoE
+        return FlashInferEPMoE if should_use_flashinfer_trtllm_moe() else EPMoE
+    return FlashInferFusedMoE if should_use_flashinfer_trtllm_moe() else FusedMoE