sglang 0.5.3__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- sglang/bench_one_batch.py +0 -2
- sglang/bench_serving.py +224 -127
- sglang/compile_deep_gemm.py +3 -0
- sglang/launch_server.py +0 -14
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/falcon_h1.py +12 -58
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +68 -31
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/qwen3_next.py +11 -43
- sglang/srt/disaggregation/decode.py +7 -18
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
- sglang/srt/disaggregation/nixl/conn.py +55 -23
- sglang/srt/disaggregation/prefill.py +17 -32
- sglang/srt/entrypoints/engine.py +2 -2
- sglang/srt/entrypoints/grpc_request_manager.py +10 -23
- sglang/srt/entrypoints/grpc_server.py +220 -80
- sglang/srt/entrypoints/http_server.py +49 -1
- sglang/srt/entrypoints/openai/protocol.py +159 -31
- sglang/srt/entrypoints/openai/serving_chat.py +13 -71
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +4 -0
- sglang/srt/function_call/function_call_parser.py +8 -6
- sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +64 -6
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +88 -0
- sglang/srt/layers/attention/attention_registry.py +31 -22
- sglang/srt/layers/attention/fla/layernorm_gated.py +47 -30
- sglang/srt/layers/attention/flashattention_backend.py +0 -1
- sglang/srt/layers/attention/flashinfer_backend.py +223 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -59
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -4
- sglang/srt/layers/attention/mamba/mamba.py +189 -241
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
- sglang/srt/layers/attention/triton_backend.py +1 -1
- sglang/srt/layers/logits_processor.py +136 -6
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +18 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +31 -452
- sglang/srt/layers/moe/ep_moe/layer.py +8 -286
- sglang/srt/layers/moe/fused_moe_triton/layer.py +6 -11
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/utils.py +7 -1
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/fp8.py +84 -18
- sglang/srt/layers/quantization/modelopt_quant.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/w4afp8.py +2 -16
- sglang/srt/lora/lora_manager.py +0 -8
- sglang/srt/managers/overlap_utils.py +18 -16
- sglang/srt/managers/schedule_batch.py +119 -90
- sglang/srt/managers/schedule_policy.py +1 -1
- sglang/srt/managers/scheduler.py +213 -126
- sglang/srt/managers/scheduler_metrics_mixin.py +1 -1
- sglang/srt/managers/scheduler_output_processor_mixin.py +180 -86
- sglang/srt/managers/tokenizer_manager.py +270 -53
- sglang/srt/managers/tp_worker.py +39 -28
- sglang/srt/mem_cache/allocator.py +7 -2
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +162 -68
- sglang/srt/mem_cache/radix_cache.py +8 -3
- sglang/srt/mem_cache/swa_radix_cache.py +70 -14
- sglang/srt/model_executor/cuda_graph_runner.py +1 -1
- sglang/srt/model_executor/forward_batch_info.py +4 -18
- sglang/srt/model_executor/model_runner.py +55 -51
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +187 -6
- sglang/srt/model_loader/weight_utils.py +3 -0
- sglang/srt/models/falcon_h1.py +11 -9
- sglang/srt/models/gemma3_mm.py +16 -0
- sglang/srt/models/grok.py +5 -13
- sglang/srt/models/mixtral.py +1 -3
- sglang/srt/models/mllama4.py +11 -1
- sglang/srt/models/nemotron_h.py +514 -0
- sglang/srt/models/utils.py +5 -1
- sglang/srt/sampling/sampling_batch_info.py +11 -9
- sglang/srt/server_args.py +100 -33
- sglang/srt/speculative/eagle_worker.py +11 -13
- sglang/srt/speculative/ngram_worker.py +12 -11
- sglang/srt/speculative/spec_utils.py +0 -1
- sglang/srt/two_batch_overlap.py +1 -0
- sglang/srt/utils/common.py +18 -0
- sglang/srt/utils/hf_transformers_utils.py +2 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +40 -0
- sglang/test/simple_eval_longbench_v2.py +332 -0
- sglang/test/test_cutlass_w4a8_moe.py +9 -19
- sglang/test/test_deterministic.py +18 -2
- sglang/test/test_deterministic_utils.py +81 -0
- sglang/test/test_disaggregation_utils.py +63 -0
- sglang/test/test_utils.py +32 -11
- sglang/version.py +1 -1
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +4 -4
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +109 -98
- sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +0 -0
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/hybrid_linear_attn_backend.py
@@ -14,14 +14,21 @@ from sglang.srt.layers.attention.fla.fused_sigmoid_gating_recurrent import (
     fused_sigmoid_gating_delta_rule_update,
 )
 from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
+    PAD_SLOT_ID,
     causal_conv1d_fn,
     causal_conv1d_update,
 )
+from sglang.srt.layers.attention.mamba.mamba import MambaMixer2
+from sglang.srt.layers.attention.mamba.mamba2_metadata import (
+    ForwardMetadata,
+    Mamba2Metadata,
+)
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool
+from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, MambaPool
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.models.qwen3_next import fused_gdn_gating
+from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput
 from sglang.srt.speculative.spec_info import SpecInput
 from sglang.srt.utils import is_cuda, is_npu
 
@@ -47,18 +54,10 @@ elif is_npu():
     causal_conv1d_update = causal_conv1d_update_npu
 
 
-
-class ForwardMetadata:
-    query_start_loc: Optional[torch.Tensor]
-    mamba_cache_indices: torch.Tensor
-
-
-class MambaAttnBackend(AttentionBackend):
-    """Attention backend using Mamba kernel."""
-
+class MambaAttnBackendBase(AttentionBackend):
     def __init__(self, model_runner: ModelRunner):
         super().__init__()
-        self.pad_slot_id =
+        self.pad_slot_id = PAD_SLOT_ID
         self.device = model_runner.device
         self.req_to_token_pool: HybridReqToTokenPool = model_runner.req_to_token_pool
         self.forward_metadata: ForwardMetadata = None
@@ -67,7 +66,7 @@ class MambaAttnBackend(AttentionBackend):
         self.cached_cuda_graph_decode_query_start_loc: torch.Tensor = None
         self.cached_cuda_graph_verify_query_start_loc: torch.Tensor = None
 
-    def
+    def _forward_metadata(self, forward_batch: ForwardBatch):
         bs = forward_batch.batch_size
 
         if forward_batch.forward_mode.is_decode_or_idle():
@@ -97,11 +96,43 @@ class MambaAttnBackend(AttentionBackend):
         mamba_cache_indices = self.req_to_token_pool.get_mamba_indices(
             forward_batch.req_pool_indices
         )
-
+        return ForwardMetadata(
             query_start_loc=query_start_loc,
             mamba_cache_indices=mamba_cache_indices,
         )
 
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        self.forward_metadata = self._forward_metadata(forward_batch)
+
+    def init_forward_metadata_capture_cuda_graph(
+        self,
+        bs: int,
+        num_tokens: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+    ):
+        self.forward_metadata = self._capture_metadata(
+            bs, req_pool_indices, forward_mode
+        )
+
+    def init_forward_metadata_replay_cuda_graph(
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        seq_lens_cpu: Optional[torch.Tensor],
+    ):
+        self.forward_metadata = self._replay_metadata(
+            bs, req_pool_indices, forward_mode, spec_info, seq_lens_cpu
+        )
+
     def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         assert (
             max_num_tokens % max_bs == 0
@@ -127,15 +158,8 @@ class MambaAttnBackend(AttentionBackend):
             device=self.device,
         )
 
-    def
-        self,
-        bs: int,
-        num_tokens: int,
-        req_pool_indices: torch.Tensor,
-        seq_lens: torch.Tensor,
-        encoder_lens: Optional[torch.Tensor],
-        forward_mode: ForwardMode,
-        spec_info: Optional[SpecInput],
+    def _capture_metadata(
+        self, bs: int, req_pool_indices: torch.Tensor, forward_mode: ForwardMode
     ):
         if forward_mode.is_decode_or_idle():
             self.query_start_loc_list[bs - 1].copy_(
@@ -149,18 +173,15 @@ class MambaAttnBackend(AttentionBackend):
             raise ValueError(f"Invalid forward mode: {forward_mode=}")
         mamba_indices = self.req_to_token_pool.get_mamba_indices(req_pool_indices)
         self.state_indices_list[bs - 1][: len(mamba_indices)].copy_(mamba_indices)
-
+        return ForwardMetadata(
             query_start_loc=self.query_start_loc_list[bs - 1],
             mamba_cache_indices=self.state_indices_list[bs - 1],
         )
 
-    def
+    def _replay_metadata(
         self,
         bs: int,
         req_pool_indices: torch.Tensor,
-        seq_lens: torch.Tensor,
-        seq_lens_sum: int,
-        encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
         spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
@@ -200,7 +221,7 @@ class MambaAttnBackend(AttentionBackend):
         else:
             raise ValueError(f"Invalid forward mode: {forward_mode=}")
 
-
+        return ForwardMetadata(
             query_start_loc=self.query_start_loc_list[bs - 1],
             mamba_cache_indices=self.state_indices_list[bs - 1],
         )
@@ -208,6 +229,10 @@ class MambaAttnBackend(AttentionBackend):
     def get_cuda_graph_seq_len_fill_value(self):
         return 1 # Mamba attn does not use seq lens to index kv cache
 
+
+class GDNAttnBackend(MambaAttnBackendBase):
+    """Attention backend using Mamba kernel."""
+
     def forward_decode(
         self,
         q: torch.Tensor,
@@ -233,9 +258,9 @@ class MambaAttnBackend(AttentionBackend):
         dt_bias = kwargs["dt_bias"]
         layer_id = kwargs["layer_id"]
 
-
-
-
+        layer_cache = self.req_to_token_pool.mamba2_layer_cache(layer_id)
+        conv_states = layer_cache.conv
+        ssm_states = layer_cache.temporal
         query_start_loc = self.forward_metadata.query_start_loc
         cache_indices = self.forward_metadata.mamba_cache_indices
 
@@ -313,13 +338,13 @@ class MambaAttnBackend(AttentionBackend):
         query_start_loc = self.forward_metadata.query_start_loc
         cache_indices = self.forward_metadata.mamba_cache_indices
 
+        mamba_cache_params = self.req_to_token_pool.mamba2_layer_cache(layer_id)
+        conv_states = mamba_cache_params.conv
+        ssm_states = mamba_cache_params.temporal
         if is_target_verify:
-            (
-
-
-                intermediate_state_cache,
-                intermediate_conv_window_cache,
-            ) = self.req_to_token_pool.get_mamba_params(layer_id)
+            assert isinstance(mamba_cache_params, MambaPool.SpeculativeState)
+            intermediate_state_cache = mamba_cache_params.intermediate_ssm
+            intermediate_conv_window_cache = mamba_cache_params.intermediate_conv_window
             has_initial_states = torch.ones(
                 seq_len // forward_batch.spec_info.draft_token_num,
                 dtype=torch.bool,
@@ -327,9 +352,6 @@ class MambaAttnBackend(AttentionBackend):
             )
             conv_states_to_use = conv_states.clone()
         else:
-            conv_states, ssm_states, *rest = self.req_to_token_pool.get_mamba_params(
-                layer_id
-            )
             has_initial_states = forward_batch.extend_prefix_lens > 0
             conv_states_to_use = conv_states
 
@@ -424,16 +446,100 @@ class MambaAttnBackend(AttentionBackend):
         return core_attn_out
 
 
+class Mamba2AttnBackend(MambaAttnBackendBase):
+    """Attention backend wrapper for Mamba2Mixer kernels."""
+
+    def __init__(self, model_runner: ModelRunner):
+        super().__init__(model_runner)
+        config = model_runner.mamba2_config
+        assert config is not None
+        self.mamba_chunk_size = config.mamba_chunk_size
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        metadata = self._forward_metadata(forward_batch)
+        self.forward_metadata = Mamba2Metadata.prepare_mixed(
+            metadata.query_start_loc,
+            metadata.mamba_cache_indices,
+            self.mamba_chunk_size,
+            forward_batch,
+        )
+
+    def init_forward_metadata_capture_cuda_graph(
+        self,
+        bs: int,
+        num_tokens: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+    ):
+        metadata = self._capture_metadata(bs, req_pool_indices, forward_mode)
+        self.forward_metadata = Mamba2Metadata.prepare_decode(
+            metadata.query_start_loc, metadata.mamba_cache_indices, seq_lens
+        )
+
+    def init_forward_metadata_replay_cuda_graph(
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        seq_lens_cpu: Optional[torch.Tensor],
+    ):
+        metadata = self._replay_metadata(
+            bs, req_pool_indices, forward_mode, spec_info, seq_lens_cpu
+        )
+        self.forward_metadata = Mamba2Metadata.prepare_decode(
+            metadata.query_start_loc, metadata.mamba_cache_indices, seq_lens
+        )
+
+    def forward(
+        self,
+        mixer: MambaMixer2,
+        hidden_states: torch.Tensor,
+        output: torch.Tensor,
+        layer_id: int,
+        mup_vector: Optional[torch.Tensor] = None,
+        use_triton_causal_conv: bool = False,
+    ):
+        assert isinstance(self.forward_metadata, Mamba2Metadata)
+        layer_cache = self.req_to_token_pool.mamba2_layer_cache(layer_id)
+        return mixer.forward(
+            hidden_states=hidden_states,
+            output=output,
+            layer_cache=layer_cache,
+            metadata=self.forward_metadata,
+            mup_vector=mup_vector,
+            use_triton_causal_conv=use_triton_causal_conv,
+        )
+
+    def forward_decode(self, *args, **kwargs):
+        raise NotImplementedError(
+            "Mamba2AttnBackend's forward is called directly instead of through HybridLinearAttnBackend, as it supports mixed prefill and decode"
+        )
+
+    def forward_extend(self, *args, **kwargs):
+        raise NotImplementedError(
+            "Mamba2AttnBackend's forward is called directly instead of through HybridLinearAttnBackend, as it supports mixed prefill and decode"
+        )
+
+
 class HybridLinearAttnBackend(AttentionBackend):
-    """
+    """Manages a full and linear attention backend"""
 
     def __init__(
         self,
         full_attn_backend: AttentionBackend,
-        linear_attn_backend:
+        linear_attn_backend: MambaAttnBackendBase,
         full_attn_layers: list[int],
     ):
         self.full_attn_layers = full_attn_layers
+        self.full_attn_backend = full_attn_backend
+        self.linear_attn_backend = linear_attn_backend
         self.attn_backend_list = [full_attn_backend, linear_attn_backend]
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
@@ -489,7 +595,7 @@ class HybridLinearAttnBackend(AttentionBackend):
         )
 
     def get_cuda_graph_seq_len_fill_value(self):
-        return self.
+        return self.full_attn_backend.get_cuda_graph_seq_len_fill_value()
 
     def forward_decode(
         self,
@@ -503,10 +609,10 @@ class HybridLinearAttnBackend(AttentionBackend):
     ):
         layer_id = layer.layer_id if layer else kwargs["layer_id"]
         if layer_id in self.full_attn_layers:
-            return self.
+            return self.full_attn_backend.forward_decode(
                 q, k, v, layer, forward_batch, save_kv_cache, **kwargs
             )
-        return self.
+        return self.linear_attn_backend.forward_decode(
             q, k, v, layer, forward_batch, save_kv_cache, **kwargs
         )
 
@@ -522,10 +628,10 @@ class HybridLinearAttnBackend(AttentionBackend):
    ):
         layer_id = layer.layer_id if layer else kwargs["layer_id"]
         if layer_id in self.full_attn_layers:
-            return self.
+            return self.full_attn_backend.forward_extend(
                 q, k, v, layer, forward_batch, save_kv_cache, **kwargs
             )
-        return self.
+        return self.linear_attn_backend.forward_extend(
             q, k, v, layer, forward_batch, save_kv_cache, **kwargs
         )
 
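The two hunks above route forward_decode and forward_extend through the named full_attn_backend / linear_attn_backend references added in __init__, choosing a backend by layer id. Below is a minimal, self-contained sketch of that dispatch pattern; the class and stub names are illustrative stand-ins, not the sglang API.

# Hypothetical sketch of per-layer dispatch between a full-attention and a
# linear-attention (Mamba) backend, mirroring the layer_id check in the diff.
class _StubBackend:
    def __init__(self, name: str) -> None:
        self.name = name

    def forward_decode(self, *args, **kwargs) -> str:
        return f"{self.name}.forward_decode"


class HybridDispatchSketch:
    def __init__(self, full_attn_backend, linear_attn_backend, full_attn_layers):
        self.full_attn_layers = full_attn_layers
        self.full_attn_backend = full_attn_backend      # named reference, as in the diff
        self.linear_attn_backend = linear_attn_backend  # named reference, as in the diff

    def forward_decode(self, layer_id: int, *args, **kwargs):
        # Full-attention layers go to the full backend; all other layers
        # (the linear/Mamba layers) go to the linear backend.
        if layer_id in self.full_attn_layers:
            return self.full_attn_backend.forward_decode(*args, **kwargs)
        return self.linear_attn_backend.forward_decode(*args, **kwargs)


if __name__ == "__main__":
    hybrid = HybridDispatchSketch(_StubBackend("full"), _StubBackend("mamba"), [0, 12, 24])
    print(hybrid.forward_decode(0))   # -> full.forward_decode
    print(hybrid.forward_decode(5))   # -> mamba.forward_decode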
|
@@ -568,20 +674,20 @@ class HybridLinearAttnBackend(AttentionBackend):
     def update_mamba_state_after_mtp_verify(self, accepted_length, model):
         request_number = accepted_length.shape[0]
 
-        state_indices_tensor =
-
-
+        state_indices_tensor = (
+            self.linear_attn_backend.forward_metadata.mamba_cache_indices[
+                :request_number
+            ]
+        )
 
-        mamba_caches =
-
-
+        mamba_caches = (
+            self.linear_attn_backend.req_to_token_pool.get_speculative_mamba2_params_all_layers()
+        )
 
-
-
-
-
-            intermediate_conv_window_cache,
-        ) = mamba_caches
+        conv_states = mamba_caches.conv
+        ssm_states = mamba_caches.temporal
+        intermediate_state_cache = mamba_caches.intermediate_ssm
+        intermediate_conv_window_cache = mamba_caches.intermediate_conv_window
 
         # SSM state updates (chunked to reduce peak memory)
         valid_mask = accepted_length > 0
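In the hunk above, the Mamba cache state is read through named attributes (conv, temporal, intermediate_ssm, intermediate_conv_window) instead of being unpacked positionally from a tuple. A small sketch of that container pattern follows; the dataclass and the tensor shapes are hypothetical stand-ins for sglang's MambaPool state objects.

# Hypothetical named container for Mamba cache tensors; field names follow the
# attributes used in the diff, shapes are arbitrary example values.
from dataclasses import dataclass

import torch


@dataclass
class SpeculativeMambaCaches:
    conv: torch.Tensor                      # convolution state per cache slot
    temporal: torch.Tensor                  # SSM (temporal) state per cache slot
    intermediate_ssm: torch.Tensor          # intermediate SSM states kept for verification
    intermediate_conv_window: torch.Tensor  # intermediate conv windows kept for verification


caches = SpeculativeMambaCaches(
    conv=torch.zeros(8, 4, 7),
    temporal=torch.zeros(8, 16, 32),
    intermediate_ssm=torch.zeros(8, 4, 16, 32),
    intermediate_conv_window=torch.zeros(8, 4, 4, 7),
)

# Named access replaces positional unpacking such as
# `(conv, ssm, ..., intermediate_conv_window) = mamba_caches`, so reordering or
# extending the container cannot silently swap tensors.
conv_states = caches.conv
ssm_states = caches.temporal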
sglang/srt/layers/attention/mamba/causal_conv1d_triton.py
@@ -6,11 +6,11 @@ from typing import List, Optional, Union
 
 import numpy as np
 import torch
-
-PAD_SLOT_ID = -1
 import triton
 import triton.language as tl
 
+PAD_SLOT_ID = -1
+
 
 @triton.jit()
 def _causal_conv1d_fwd_kernel( # continuous batching
@@ -672,7 +672,9 @@ def _causal_conv1d_update_kernel(
         + (conv_state_batch_coord * stride_conv_state_seq)
         + conv_state_token_offset * stride_conv_state_tok
         + (idx_feats * stride_conv_state_dim)[None, :]
-        + ((idx_tokens + 1) * stride_conv_state_tok)[
+        + ((idx_tokens + (1 if IS_SPEC_DECODING else seqlen)) * stride_conv_state_tok)[
+            :, None
+        ]
     )  # [BLOCK_M, BLOCK_N]
     mask = (
         (conv_state_batch_coord < num_cache_lines)
@@ -897,7 +899,10 @@ def causal_conv1d_update(
     stride_state_indices = (
         conv_state_indices.stride(0) if conv_state_indices is not None else 0
     )
-
+    if num_accepted_tokens is not None:
+        state_len = width - 1 + (seqlen - 1) # effective state_len needed
+    else:
+        state_len = width - 1
     np2_statelen = triton.next_power_of_2(state_len)
 
     def grid(META):