sglang 0.4.9.post6__py3-none-any.whl → 0.4.10.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +20 -0
- sglang/bench_one_batch.py +3 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/model_config.py +4 -0
- sglang/srt/configs/step3_vl.py +172 -0
- sglang/srt/conversation.py +23 -0
- sglang/srt/disaggregation/decode.py +2 -8
- sglang/srt/disaggregation/launch_lb.py +5 -20
- sglang/srt/disaggregation/mooncake/conn.py +33 -15
- sglang/srt/disaggregation/prefill.py +2 -6
- sglang/srt/distributed/parallel_state.py +86 -1
- sglang/srt/entrypoints/engine.py +14 -18
- sglang/srt/entrypoints/http_server.py +10 -2
- sglang/srt/entrypoints/openai/serving_chat.py +2 -21
- sglang/srt/eplb/expert_distribution.py +5 -0
- sglang/srt/eplb/expert_location.py +17 -6
- sglang/srt/eplb/expert_location_dispatch.py +1 -0
- sglang/srt/eplb/expert_location_updater.py +2 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/step3_detector.py +436 -0
- sglang/srt/hf_transformers_utils.py +2 -0
- sglang/srt/jinja_template_utils.py +4 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +372 -0
- sglang/srt/layers/attention/utils.py +6 -1
- sglang/srt/layers/moe/cutlass_moe.py +2 -1
- sglang/srt/layers/moe/ep_moe/layer.py +39 -674
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +26 -13
- sglang/srt/layers/moe/fused_moe_triton/layer.py +152 -39
- sglang/srt/layers/quantization/fp8.py +52 -18
- sglang/srt/layers/quantization/unquant.py +0 -8
- sglang/srt/layers/quantization/w4afp8.py +1 -0
- sglang/srt/layers/quantization/w8a8_int8.py +4 -1
- sglang/srt/managers/cache_controller.py +165 -67
- sglang/srt/managers/data_parallel_controller.py +2 -0
- sglang/srt/managers/io_struct.py +0 -2
- sglang/srt/managers/scheduler.py +90 -671
- sglang/srt/managers/scheduler_metrics_mixin.py +229 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +279 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +142 -0
- sglang/srt/managers/template_manager.py +62 -19
- sglang/srt/managers/tokenizer_manager.py +123 -74
- sglang/srt/managers/tp_worker.py +4 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +2 -1
- sglang/srt/mem_cache/hicache_storage.py +60 -17
- sglang/srt/mem_cache/hiradix_cache.py +36 -8
- sglang/srt/mem_cache/memory_pool.py +15 -118
- sglang/srt/mem_cache/memory_pool_host.py +418 -29
- sglang/srt/mem_cache/mooncake_store/mooncake_store.py +264 -0
- sglang/srt/mem_cache/mooncake_store/unit_test.py +40 -0
- sglang/srt/mem_cache/nixl/hicache_nixl.py +163 -0
- sglang/srt/mem_cache/nixl/nixl_utils.py +238 -0
- sglang/srt/mem_cache/nixl/test_hicache_nixl_storage.py +216 -0
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +183 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +278 -0
- sglang/srt/mem_cache/storage/hf3fs/test_hf3fs_utils.py +43 -0
- sglang/srt/model_executor/cuda_graph_runner.py +25 -1
- sglang/srt/model_executor/model_runner.py +13 -1
- sglang/srt/model_loader/weight_utils.py +2 -0
- sglang/srt/models/arcee.py +532 -0
- sglang/srt/models/deepseek_v2.py +7 -6
- sglang/srt/models/glm4_moe.py +6 -4
- sglang/srt/models/granitemoe.py +3 -0
- sglang/srt/models/grok.py +3 -0
- sglang/srt/models/hunyuan.py +1 -0
- sglang/srt/models/llama4.py +3 -0
- sglang/srt/models/mixtral.py +3 -0
- sglang/srt/models/olmoe.py +3 -0
- sglang/srt/models/phimoe.py +1 -0
- sglang/srt/models/step3_vl.py +991 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -16
- sglang/srt/multimodal/processors/step3_vl.py +515 -0
- sglang/srt/reasoning_parser.py +2 -1
- sglang/srt/server_args.py +49 -18
- sglang/srt/speculative/eagle_worker.py +2 -0
- sglang/srt/utils.py +1 -0
- sglang/test/attention/test_trtllm_mla_backend.py +945 -0
- sglang/utils.py +0 -11
- sglang/version.py +1 -1
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/METADATA +3 -4
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/RECORD +83 -65
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,372 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
"""
|
4
|
+
Support attention backend for TRTLLM MLA kernels from flashinfer.
|
5
|
+
"""
|
6
|
+
|
7
|
+
import math
|
8
|
+
from dataclasses import dataclass
|
9
|
+
from typing import TYPE_CHECKING, Optional, Union
|
10
|
+
|
11
|
+
import torch
|
12
|
+
import triton
|
13
|
+
|
14
|
+
from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
|
15
|
+
from sglang.srt.layers.attention.utils import (
|
16
|
+
TRITON_PAD_NUM_PAGE_PER_BLOCK,
|
17
|
+
create_flashmla_kv_indices_triton,
|
18
|
+
)
|
19
|
+
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
20
|
+
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
21
|
+
from sglang.srt.utils import is_flashinfer_available
|
22
|
+
|
23
|
+
if is_flashinfer_available():
|
24
|
+
import flashinfer
|
25
|
+
|
26
|
+
if TYPE_CHECKING:
|
27
|
+
from sglang.srt.layers.radix_attention import RadixAttention
|
28
|
+
from sglang.srt.model_executor.model_runner import ModelRunner
|
29
|
+
from sglang.srt.speculative.spec_info import SpecInfo
|
30
|
+
|
31
|
+
# Size, in MiB, of each int8 workspace buffer handed to the TRT-LLM MLA kernel.
DEFAULT_WORKSPACE_SIZE_MB = 128

# Page-table padding requirement coming from flashinfer.
# flashinfer.decode._check_trtllm_gen_mla_shape demands
#     block_num % (128 / block_size) == 0
# so the KV block count must be a multiple of 128 / block_size. Only the
# numerator is stored here: the backend divides it by its page size and then
# takes the LCM with the other padding requirements it must honour.
TRTLLM_BLOCK_CONSTRAINT = 128
|
41
|
+
|
42
|
+
|
43
|
+
@dataclass
class TRTLLMMLADecodeMetadata:
    """Per-batch inputs for one TRT-LLM MLA decode call.

    Fields are positional in the order ``(workspace, block_kv_indices)`` and
    both default to ``None``.
    """

    # Scratch buffer passed to the kernel as ``workspace_buffer``.
    workspace: Optional[torch.Tensor] = None
    # Page table with one row per request and one (padded) column per KV block.
    block_kv_indices: Optional[torch.Tensor] = None
|
49
|
+
|
50
|
+
|
51
|
+
class TRTLLMMLABackend(FlashInferMLAAttnBackend):
    """TRTLLM MLA attention kernel from flashinfer.

    Decode/idle batches without speculative info are served by
    ``flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla``. Every other
    forward mode, and any batch carrying ``spec_info``, is delegated to
    ``FlashInferMLAAttnBackend``.
    """

    def __init__(
        self,
        model_runner: ModelRunner,
        skip_prefill: bool = False,
        kv_indptr_buf: Optional[torch.Tensor] = None,
        q_indptr_decode_buf: Optional[torch.Tensor] = None,
    ):
        super().__init__(model_runner, skip_prefill, kv_indptr_buf, q_indptr_decode_buf)

        config = model_runner.model_config

        # Model parameters (per attention-TP rank).
        self.num_q_heads = config.num_attention_heads // get_attention_tp_size()
        self.num_kv_heads = config.get_num_kv_heads(get_attention_tp_size())
        self.num_local_heads = config.num_attention_heads // get_attention_tp_size()

        # MLA-specific dimensions. A KV-cache row is kv_lora_rank + qk_rope_head_dim
        # wide; forward_decode views the cache with this width.
        self.kv_lora_rank = config.kv_lora_rank
        self.qk_nope_head_dim = config.qk_nope_head_dim
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.v_head_dim = config.v_head_dim
        self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim

        # Runtime parameters
        self.scaling = config.scaling
        self.data_type = model_runner.kv_cache_dtype
        self.q_data_type = model_runner.dtype
        self.page_size = model_runner.page_size
        self.req_to_token = model_runner.req_to_token_pool.req_to_token

        # Workspace used by eager-mode decode calls.
        self.workspace_size = DEFAULT_WORKSPACE_SIZE_MB * 1024 * 1024
        self.workspace_buffer = torch.empty(
            self.workspace_size, dtype=torch.int8, device=self.device
        )

        # CUDA graph state for the TRT-LLM decode fast path. The buffers carry a
        # ``decode_`` prefix so they cannot clash with whatever attributes the
        # parent backend uses for its own CUDA-graph state (initialized through
        # super().init_cuda_graph_state for the modes delegated to it).
        self.decode_cuda_graph_metadata = {}
        self.decode_cuda_graph_kv_indices: Optional[torch.Tensor] = None
        self.decode_cuda_graph_workspace: Optional[torch.Tensor] = None
        self.forward_metadata: Union[TRTLLMMLADecodeMetadata, None] = None

    def _calc_padded_blocks(self, max_seq_len: int) -> int:
        """
        Calculate padded block count that satisfies both TRT-LLM and Triton constraints.

        Args:
            max_seq_len: Maximum sequence length in tokens

        Returns:
            ceil(max_seq_len / page_size) rounded up to a multiple of
            lcm(max(1, 128 // page_size), TRITON_PAD_NUM_PAGE_PER_BLOCK).
        """
        blocks = triton.cdiv(max_seq_len, self.page_size)

        # Apply dual constraints (take LCM to satisfy both):
        # 1. TRT-LLM: block_num % (128 / page_size) == 0. For page sizes >= 128
        #    the ratio is <= 1 and constrains nothing, so clamp it to 1; plain
        #    integer division would give 0, lcm(0, x) == 0, and the modulo
        #    below would raise ZeroDivisionError.
        # 2. Triton: the page-table builder writes TRITON_PAD_NUM_PAGE_PER_BLOCK
        #    pages per iteration, so pad to a multiple of that.
        trtllm_constraint = max(1, TRTLLM_BLOCK_CONSTRAINT // self.page_size)
        constraint_lcm = math.lcm(trtllm_constraint, TRITON_PAD_NUM_PAGE_PER_BLOCK)

        if blocks % constraint_lcm != 0:
            blocks = triton.cdiv(blocks, constraint_lcm) * constraint_lcm
        return blocks

    def _fill_block_kv_indices(
        self,
        block_kv_indices: torch.Tensor,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
    ) -> None:
        """
        Fill ``block_kv_indices`` in place with page indices using the Triton kernel.

        Args:
            block_kv_indices: 2-D page table ``(batch, max_blocks)`` to write into;
                row ``i`` receives the pages of request ``req_pool_indices[i]``.
            req_pool_indices: Request pool index of each row.
            seq_lens: Sequence length (tokens) of each row.
        """
        batch_size = block_kv_indices.shape[0]
        if batch_size == 0:
            # Nothing to fill (e.g. an empty idle batch); skip the kernel launch.
            return

        create_flashmla_kv_indices_triton[(batch_size,)](
            self.req_to_token,
            req_pool_indices,
            seq_lens,
            None,
            block_kv_indices,
            self.req_to_token.stride(0),
            # kv_indices_ptr_stride: the kernel offsets row `pid` by this value,
            # so pass the tensor's real row stride rather than its width.
            block_kv_indices.stride(0),
            TRITON_PAD_NUM_PAGE_PER_BLOCK,
            self.page_size,
        )

    def _create_block_kv_indices(
        self,
        batch_size: int,
        max_blocks: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        device: torch.device,
    ) -> torch.Tensor:
        """
        Create block KV indices tensor using Triton kernel.

        Args:
            batch_size: Batch size
            max_blocks: Maximum number of blocks per sequence
            req_pool_indices: Request pool indices
            seq_lens: Sequence lengths
            device: Target device

        Returns:
            ``(batch_size, max_blocks)`` int32 page table; entries not written
            by the kernel stay -1.
        """
        block_kv_indices = torch.full(
            (batch_size, max_blocks), -1, dtype=torch.int32, device=device
        )
        self._fill_block_kv_indices(block_kv_indices, req_pool_indices, seq_lens)
        return block_kv_indices

    def init_cuda_graph_state(
        self,
        max_bs: int,
        max_num_tokens: int,
        kv_indices_buf: Optional[torch.Tensor] = None,
    ):
        """Initialize CUDA graph state for TRTLLM MLA.

        The parent's CUDA-graph state is initialized too, because the capture
        and replay hooks below delegate non-decode modes and speculative
        decoding to the parent implementation.
        """
        super().init_cuda_graph_state(max_bs, max_num_tokens, kv_indices_buf)

        max_blocks_per_seq = self._calc_padded_blocks(self.max_context_len)
        self.decode_cuda_graph_kv_indices = torch.full(
            (max_bs, max_blocks_per_seq), -1, dtype=torch.int32, device=self.device
        )
        self.decode_cuda_graph_workspace = torch.empty(
            self.workspace_size, dtype=torch.int8, device=self.device
        )

    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_tokens: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[SpecInfo],
    ):
        """Initialize metadata for CUDA graph capture."""
        # Delegate to parent for non-decode modes or when speculative execution is used.
        if not (forward_mode.is_decode_or_idle() and spec_info is None):
            return super().init_forward_metadata_capture_cuda_graph(
                bs,
                num_tokens,
                req_pool_indices,
                seq_lens,
                encoder_lens,
                forward_mode,
                spec_info,
            )

        # Capture the page table at full width (enough blocks for max_context_len)
        # instead of the width implied by the dummy capture-time seq_lens. Replay
        # refills this same tensor for arbitrarily long sequences, and
        # forward_decode derives max_seq_len from its width, which is frozen into
        # the captured graph.
        block_kv_indices = self.decode_cuda_graph_kv_indices[:bs]
        self._fill_block_kv_indices(block_kv_indices, req_pool_indices, seq_lens)

        metadata = TRTLLMMLADecodeMetadata(
            self.decode_cuda_graph_workspace, block_kv_indices
        )
        self.decode_cuda_graph_metadata[bs] = metadata
        self.forward_metadata = metadata

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[SpecInfo],
        seq_lens_cpu: Optional[torch.Tensor],
    ):
        """Replay CUDA graph with new inputs."""
        # Delegate to parent for non-decode modes or when speculative execution is used.
        if not (forward_mode.is_decode_or_idle() and spec_info is None):
            return super().init_forward_metadata_replay_cuda_graph(
                bs,
                req_pool_indices,
                seq_lens,
                seq_lens_sum,
                encoder_lens,
                forward_mode,
                spec_info,
                seq_lens_cpu,
            )

        metadata = self.decode_cuda_graph_metadata[bs]

        # Refill the captured page table in place so the captured kernels read
        # the indices of the new sequences.
        self._fill_block_kv_indices(
            metadata.block_kv_indices, req_pool_indices[:bs], seq_lens[:bs]
        )

    def get_cuda_graph_seq_len_fill_value(self) -> int:
        """Get the fill value for sequence lengths in CUDA graph."""
        return 1

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Initialize the metadata for a forward pass."""
        # Delegate to parent for non-decode modes or when speculative execution is used.
        if not (
            forward_batch.forward_mode.is_decode_or_idle()
            and forward_batch.spec_info is None
        ):
            return super().init_forward_metadata(forward_batch)

        bs = forward_batch.batch_size

        # Maximum sequence length in the batch. Reading it from the CPU copy,
        # when present, avoids a device-to-host sync. torch's max() raises on an
        # empty tensor, so an empty (idle) batch is mapped to 0.
        seq_lens_cpu = getattr(forward_batch, "seq_lens_cpu", None)
        seq_lens_src = (
            seq_lens_cpu if seq_lens_cpu is not None else forward_batch.seq_lens
        )
        max_seq = seq_lens_src.max().item() if seq_lens_src.numel() > 0 else 0

        max_seqlen_pad = self._calc_padded_blocks(max_seq)
        block_kv_indices = self._create_block_kv_indices(
            bs,
            max_seqlen_pad,
            forward_batch.req_pool_indices,
            forward_batch.seq_lens,
            forward_batch.seq_lens.device,
        )

        self.forward_metadata = TRTLLMMLADecodeMetadata(
            self.workspace_buffer, block_kv_indices
        )
        forward_batch.decode_trtllm_mla_metadata = self.forward_metadata

    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
        q_rope: Optional[torch.Tensor] = None,
        k_rope: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run forward for decode using TRTLLM MLA kernel.

        Args:
            q: Query; the NOPE part only when ``q_rope`` is given, otherwise
                both parts with ``layer.head_dim`` per head.
            k: New key entries to store (skipped when None or ``save_kv_cache``
                is False).
            v: New value entries, stored via ``set_kv_buffer`` when ``k_rope``
                is None.
            layer: Attention layer (head counts, dims, layer id, scaling).
            forward_batch: Batch providing the KV pool, cache locations and
                sequence lengths.
            save_kv_cache: Whether to write ``k``/``k_rope``/``v`` to the cache.
            q_rope: Optional rope part of the query.
            k_rope: Optional rope part of the key.

        Returns:
            Tensor of shape ``(-1, layer.tp_q_head_num * layer.v_head_dim)``.
        """
        # Save KV cache if requested
        if k is not None and save_kv_cache:
            cache_loc = forward_batch.out_cache_loc
            if k_rope is not None:
                forward_batch.token_to_kv_pool.set_mla_kv_buffer(
                    layer, cache_loc, k, k_rope
                )
            elif v is not None:
                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)

        # Prepare query tensor inline
        if q_rope is not None:
            # q contains NOPE part (v_head_dim)
            q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
            q_rope_reshaped = q_rope.view(
                -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
            )
            query = torch.cat([q_nope, q_rope_reshaped], dim=-1)
        else:
            # q already has both parts
            query = q.view(-1, layer.tp_q_head_num, layer.head_dim)

        # Ensure query has shape [bs, acc_q_len, num_q_heads, head_dim] when seq_len 1
        if query.dim() == 3:
            query = query.unsqueeze(1)

        # Prepare KV cache inline: (num_pages, page_size, kv_cache_dim), then an
        # extra dim at axis 1 because TRT-LLM expects a single KV tensor with it.
        k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
        pages = k_cache.view(-1, self.page_size, self.kv_cache_dim)
        kv_cache = pages.unsqueeze(1)

        # Prefer metadata attached to this batch; fall back to the backend's
        # current metadata (the one set during CUDA graph capture).
        metadata = (
            getattr(forward_batch, "decode_trtllm_mla_metadata", None)
            or self.forward_metadata
        )

        # Scale computation for TRTLLM MLA kernel:
        # - BMM1 scale = q_scale * k_scale * softmax_scale
        # - For FP16 path we keep q_scale = 1.0, softmax_scale = 1/sqrt(head_dim) which is pre-computed as layer.scaling
        # - k_scale is read from model checkpoint if available
        # TODO: Change once fp8 path is supported
        q_scale = 1.0
        k_scale = (
            layer.k_scale_float
            if getattr(layer, "k_scale_float", None) is not None
            else 1.0
        )

        bmm1_scale = q_scale * k_scale * layer.scaling

        # Call TRT-LLM kernel; max_seq_len is the padded page-table capacity.
        raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
            query=query,
            kv_cache=kv_cache,
            workspace_buffer=metadata.workspace,
            qk_nope_head_dim=self.qk_nope_head_dim,
            kv_lora_rank=self.kv_lora_rank,
            qk_rope_head_dim=self.qk_rope_head_dim,
            block_tables=metadata.block_kv_indices,
            seq_lens=forward_batch.seq_lens.to(torch.int32),
            max_seq_len=int(metadata.block_kv_indices.shape[1] * self.page_size),
            bmm1_scale=bmm1_scale,
        )

        # Extract value projection part and reshape
        raw_out_v = raw_out[..., : layer.v_head_dim].contiguous()
        output = raw_out_v.view(-1, layer.tp_q_head_num * layer.v_head_dim)

        return output
|
@@ -1,6 +1,11 @@
|
|
1
1
|
import triton
|
2
2
|
import triton.language as tl
|
3
3
|
|
4
|
+
# Number of pages `create_flashmla_kv_indices_triton` writes per loop
# iteration; it is also the default of that kernel's NUM_PAGE_PER_BLOCK
# parameter. Other modules import this constant instead of hard-coding 64 so
# their padding logic stays consistent with the kernel.
TRITON_PAD_NUM_PAGE_PER_BLOCK = 64
|
8
|
+
|
4
9
|
|
5
10
|
@triton.jit
|
6
11
|
def create_flashinfer_kv_indices_triton(
|
@@ -50,10 +55,10 @@ def create_flashmla_kv_indices_triton(
|
|
50
55
|
kv_indices_ptr,
|
51
56
|
req_to_token_ptr_stride: tl.constexpr,
|
52
57
|
kv_indices_ptr_stride: tl.constexpr,
|
58
|
+
NUM_PAGE_PER_BLOCK: tl.constexpr = TRITON_PAD_NUM_PAGE_PER_BLOCK,
|
53
59
|
PAGED_SIZE: tl.constexpr = 64,
|
54
60
|
):
|
55
61
|
BLOCK_SIZE: tl.constexpr = 4096
|
56
|
-
NUM_PAGE_PER_BLOCK: tl.constexpr = 64
|
57
62
|
pid = tl.program_id(axis=0)
|
58
63
|
|
59
64
|
# find the req pool idx, this is for batch to token
|
@@ -209,7 +209,8 @@ def cutlass_fused_experts_fp8(
|
|
209
209
|
)
|
210
210
|
|
211
211
|
result = torch.empty((m, k), device=device, dtype=out_dtype)
|
212
|
-
|
212
|
+
apply_shuffle_mul_sum(c2, result, c_map, topk_weights.to(out_dtype))
|
213
|
+
return result
|
213
214
|
|
214
215
|
|
215
216
|
# Largest finite magnitude representable in the FP4 E2M1 format.
FLOAT4_E2M1_MAX = 6.0
|