sglang 0.5.3__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. sglang/bench_one_batch.py +0 -2
  2. sglang/bench_serving.py +224 -127
  3. sglang/compile_deep_gemm.py +3 -0
  4. sglang/launch_server.py +0 -14
  5. sglang/srt/configs/__init__.py +2 -0
  6. sglang/srt/configs/falcon_h1.py +12 -58
  7. sglang/srt/configs/mamba_utils.py +117 -0
  8. sglang/srt/configs/model_config.py +68 -31
  9. sglang/srt/configs/nemotron_h.py +286 -0
  10. sglang/srt/configs/qwen3_next.py +11 -43
  11. sglang/srt/disaggregation/decode.py +7 -18
  12. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  13. sglang/srt/disaggregation/nixl/conn.py +55 -23
  14. sglang/srt/disaggregation/prefill.py +17 -32
  15. sglang/srt/entrypoints/engine.py +2 -2
  16. sglang/srt/entrypoints/grpc_request_manager.py +10 -23
  17. sglang/srt/entrypoints/grpc_server.py +220 -80
  18. sglang/srt/entrypoints/http_server.py +49 -1
  19. sglang/srt/entrypoints/openai/protocol.py +159 -31
  20. sglang/srt/entrypoints/openai/serving_chat.py +13 -71
  21. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  22. sglang/srt/environ.py +4 -0
  23. sglang/srt/function_call/function_call_parser.py +8 -6
  24. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  25. sglang/srt/grpc/sglang_scheduler_pb2.pyi +64 -6
  26. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +88 -0
  27. sglang/srt/layers/attention/attention_registry.py +31 -22
  28. sglang/srt/layers/attention/fla/layernorm_gated.py +47 -30
  29. sglang/srt/layers/attention/flashattention_backend.py +0 -1
  30. sglang/srt/layers/attention/flashinfer_backend.py +223 -6
  31. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -1
  32. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -59
  33. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  34. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -4
  35. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  36. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  37. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  38. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  39. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  40. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  41. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  42. sglang/srt/layers/attention/triton_backend.py +1 -1
  43. sglang/srt/layers/logits_processor.py +136 -6
  44. sglang/srt/layers/modelopt_utils.py +11 -0
  45. sglang/srt/layers/moe/cutlass_w4a8_moe.py +18 -21
  46. sglang/srt/layers/moe/ep_moe/kernels.py +31 -452
  47. sglang/srt/layers/moe/ep_moe/layer.py +8 -286
  48. sglang/srt/layers/moe/fused_moe_triton/layer.py +6 -11
  49. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  50. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  51. sglang/srt/layers/moe/utils.py +7 -1
  52. sglang/srt/layers/quantization/__init__.py +1 -1
  53. sglang/srt/layers/quantization/fp8.py +84 -18
  54. sglang/srt/layers/quantization/modelopt_quant.py +1 -1
  55. sglang/srt/layers/quantization/quark/quark.py +3 -1
  56. sglang/srt/layers/quantization/w4afp8.py +2 -16
  57. sglang/srt/lora/lora_manager.py +0 -8
  58. sglang/srt/managers/overlap_utils.py +18 -16
  59. sglang/srt/managers/schedule_batch.py +119 -90
  60. sglang/srt/managers/schedule_policy.py +1 -1
  61. sglang/srt/managers/scheduler.py +213 -126
  62. sglang/srt/managers/scheduler_metrics_mixin.py +1 -1
  63. sglang/srt/managers/scheduler_output_processor_mixin.py +180 -86
  64. sglang/srt/managers/tokenizer_manager.py +270 -53
  65. sglang/srt/managers/tp_worker.py +39 -28
  66. sglang/srt/mem_cache/allocator.py +7 -2
  67. sglang/srt/mem_cache/chunk_cache.py +1 -1
  68. sglang/srt/mem_cache/memory_pool.py +162 -68
  69. sglang/srt/mem_cache/radix_cache.py +8 -3
  70. sglang/srt/mem_cache/swa_radix_cache.py +70 -14
  71. sglang/srt/model_executor/cuda_graph_runner.py +1 -1
  72. sglang/srt/model_executor/forward_batch_info.py +4 -18
  73. sglang/srt/model_executor/model_runner.py +55 -51
  74. sglang/srt/model_loader/__init__.py +1 -1
  75. sglang/srt/model_loader/loader.py +187 -6
  76. sglang/srt/model_loader/weight_utils.py +3 -0
  77. sglang/srt/models/falcon_h1.py +11 -9
  78. sglang/srt/models/gemma3_mm.py +16 -0
  79. sglang/srt/models/grok.py +5 -13
  80. sglang/srt/models/mixtral.py +1 -3
  81. sglang/srt/models/mllama4.py +11 -1
  82. sglang/srt/models/nemotron_h.py +514 -0
  83. sglang/srt/models/utils.py +5 -1
  84. sglang/srt/sampling/sampling_batch_info.py +11 -9
  85. sglang/srt/server_args.py +100 -33
  86. sglang/srt/speculative/eagle_worker.py +11 -13
  87. sglang/srt/speculative/ngram_worker.py +12 -11
  88. sglang/srt/speculative/spec_utils.py +0 -1
  89. sglang/srt/two_batch_overlap.py +1 -0
  90. sglang/srt/utils/common.py +18 -0
  91. sglang/srt/utils/hf_transformers_utils.py +2 -0
  92. sglang/test/longbench_v2/__init__.py +1 -0
  93. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  94. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  95. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  96. sglang/test/run_eval.py +40 -0
  97. sglang/test/simple_eval_longbench_v2.py +332 -0
  98. sglang/test/test_cutlass_w4a8_moe.py +9 -19
  99. sglang/test/test_deterministic.py +18 -2
  100. sglang/test/test_deterministic_utils.py +81 -0
  101. sglang/test/test_disaggregation_utils.py +63 -0
  102. sglang/test/test_utils.py +32 -11
  103. sglang/version.py +1 -1
  104. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +4 -4
  105. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +109 -98
  106. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  107. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  108. sglang/test/test_block_fp8_ep.py +0 -358
  109. /sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +0 -0
  110. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
  111. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
  112. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/hybrid_linear_attn_backend.py
@@ -14,14 +14,21 @@ from sglang.srt.layers.attention.fla.fused_sigmoid_gating_recurrent import (
     fused_sigmoid_gating_delta_rule_update,
 )
 from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
+    PAD_SLOT_ID,
     causal_conv1d_fn,
     causal_conv1d_update,
 )
+from sglang.srt.layers.attention.mamba.mamba import MambaMixer2
+from sglang.srt.layers.attention.mamba.mamba2_metadata import (
+    ForwardMetadata,
+    Mamba2Metadata,
+)
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool
+from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, MambaPool
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.models.qwen3_next import fused_gdn_gating
+from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput
 from sglang.srt.speculative.spec_info import SpecInput
 from sglang.srt.utils import is_cuda, is_npu
 
@@ -47,18 +54,10 @@ elif is_npu():
     causal_conv1d_update = causal_conv1d_update_npu
 
 
-@dataclass
-class ForwardMetadata:
-    query_start_loc: Optional[torch.Tensor]
-    mamba_cache_indices: torch.Tensor
-
-
-class MambaAttnBackend(AttentionBackend):
-    """Attention backend using Mamba kernel."""
-
+class MambaAttnBackendBase(AttentionBackend):
     def __init__(self, model_runner: ModelRunner):
         super().__init__()
-        self.pad_slot_id = -1  # Default pad slot id
+        self.pad_slot_id = PAD_SLOT_ID
         self.device = model_runner.device
         self.req_to_token_pool: HybridReqToTokenPool = model_runner.req_to_token_pool
         self.forward_metadata: ForwardMetadata = None
@@ -67,7 +66,7 @@ class MambaAttnBackend(AttentionBackend):
         self.cached_cuda_graph_decode_query_start_loc: torch.Tensor = None
         self.cached_cuda_graph_verify_query_start_loc: torch.Tensor = None
 
-    def init_forward_metadata(self, forward_batch: ForwardBatch):
+    def _forward_metadata(self, forward_batch: ForwardBatch):
         bs = forward_batch.batch_size
 
         if forward_batch.forward_mode.is_decode_or_idle():
@@ -97,11 +96,43 @@ class MambaAttnBackend(AttentionBackend):
         mamba_cache_indices = self.req_to_token_pool.get_mamba_indices(
             forward_batch.req_pool_indices
         )
-        self.forward_metadata = ForwardMetadata(
+        return ForwardMetadata(
             query_start_loc=query_start_loc,
             mamba_cache_indices=mamba_cache_indices,
         )
 
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        self.forward_metadata = self._forward_metadata(forward_batch)
+
+    def init_forward_metadata_capture_cuda_graph(
+        self,
+        bs: int,
+        num_tokens: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+    ):
+        self.forward_metadata = self._capture_metadata(
+            bs, req_pool_indices, forward_mode
+        )
+
+    def init_forward_metadata_replay_cuda_graph(
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        seq_lens_cpu: Optional[torch.Tensor],
+    ):
+        self.forward_metadata = self._replay_metadata(
+            bs, req_pool_indices, forward_mode, spec_info, seq_lens_cpu
+        )
+
     def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         assert (
             max_num_tokens % max_bs == 0
@@ -127,15 +158,8 @@ class MambaAttnBackend(AttentionBackend):
             device=self.device,
         )
 
-    def init_forward_metadata_capture_cuda_graph(
-        self,
-        bs: int,
-        num_tokens: int,
-        req_pool_indices: torch.Tensor,
-        seq_lens: torch.Tensor,
-        encoder_lens: Optional[torch.Tensor],
-        forward_mode: ForwardMode,
-        spec_info: Optional[SpecInput],
+    def _capture_metadata(
+        self, bs: int, req_pool_indices: torch.Tensor, forward_mode: ForwardMode
     ):
         if forward_mode.is_decode_or_idle():
             self.query_start_loc_list[bs - 1].copy_(
@@ -149,18 +173,15 @@ class MambaAttnBackend(AttentionBackend):
             raise ValueError(f"Invalid forward mode: {forward_mode=}")
         mamba_indices = self.req_to_token_pool.get_mamba_indices(req_pool_indices)
         self.state_indices_list[bs - 1][: len(mamba_indices)].copy_(mamba_indices)
-        self.forward_metadata = ForwardMetadata(
+        return ForwardMetadata(
             query_start_loc=self.query_start_loc_list[bs - 1],
             mamba_cache_indices=self.state_indices_list[bs - 1],
         )
 
-    def init_forward_metadata_replay_cuda_graph(
+    def _replay_metadata(
         self,
         bs: int,
         req_pool_indices: torch.Tensor,
-        seq_lens: torch.Tensor,
-        seq_lens_sum: int,
-        encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
         spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
@@ -200,7 +221,7 @@ class MambaAttnBackend(AttentionBackend):
         else:
             raise ValueError(f"Invalid forward mode: {forward_mode=}")
 
-        self.forward_metadata = ForwardMetadata(
+        return ForwardMetadata(
             query_start_loc=self.query_start_loc_list[bs - 1],
             mamba_cache_indices=self.state_indices_list[bs - 1],
         )
@@ -208,6 +229,10 @@ class MambaAttnBackend(AttentionBackend):
     def get_cuda_graph_seq_len_fill_value(self):
         return 1  # Mamba attn does not use seq lens to index kv cache
 
+
+class GDNAttnBackend(MambaAttnBackendBase):
+    """Attention backend using Mamba kernel."""
+
     def forward_decode(
         self,
         q: torch.Tensor,
@@ -233,9 +258,9 @@ class MambaAttnBackend(AttentionBackend):
         dt_bias = kwargs["dt_bias"]
         layer_id = kwargs["layer_id"]
 
-        conv_states, ssm_states, *rest = self.req_to_token_pool.get_mamba_params(
-            layer_id
-        )
+        layer_cache = self.req_to_token_pool.mamba2_layer_cache(layer_id)
+        conv_states = layer_cache.conv
+        ssm_states = layer_cache.temporal
         query_start_loc = self.forward_metadata.query_start_loc
         cache_indices = self.forward_metadata.mamba_cache_indices
 
@@ -313,13 +338,13 @@ class MambaAttnBackend(AttentionBackend):
         query_start_loc = self.forward_metadata.query_start_loc
         cache_indices = self.forward_metadata.mamba_cache_indices
 
+        mamba_cache_params = self.req_to_token_pool.mamba2_layer_cache(layer_id)
+        conv_states = mamba_cache_params.conv
+        ssm_states = mamba_cache_params.temporal
         if is_target_verify:
-            (
-                conv_states,
-                ssm_states,
-                intermediate_state_cache,
-                intermediate_conv_window_cache,
-            ) = self.req_to_token_pool.get_mamba_params(layer_id)
+            assert isinstance(mamba_cache_params, MambaPool.SpeculativeState)
+            intermediate_state_cache = mamba_cache_params.intermediate_ssm
+            intermediate_conv_window_cache = mamba_cache_params.intermediate_conv_window
             has_initial_states = torch.ones(
                 seq_len // forward_batch.spec_info.draft_token_num,
                 dtype=torch.bool,
@@ -327,9 +352,6 @@ class MambaAttnBackend(AttentionBackend):
             )
             conv_states_to_use = conv_states.clone()
         else:
-            conv_states, ssm_states, *rest = self.req_to_token_pool.get_mamba_params(
-                layer_id
-            )
             has_initial_states = forward_batch.extend_prefix_lens > 0
             conv_states_to_use = conv_states
 
@@ -424,16 +446,100 @@ class MambaAttnBackend(AttentionBackend):
         return core_attn_out
 
 
+class Mamba2AttnBackend(MambaAttnBackendBase):
+    """Attention backend wrapper for Mamba2Mixer kernels."""
+
+    def __init__(self, model_runner: ModelRunner):
+        super().__init__(model_runner)
+        config = model_runner.mamba2_config
+        assert config is not None
+        self.mamba_chunk_size = config.mamba_chunk_size
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        metadata = self._forward_metadata(forward_batch)
+        self.forward_metadata = Mamba2Metadata.prepare_mixed(
+            metadata.query_start_loc,
+            metadata.mamba_cache_indices,
+            self.mamba_chunk_size,
+            forward_batch,
+        )
+
+    def init_forward_metadata_capture_cuda_graph(
+        self,
+        bs: int,
+        num_tokens: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+    ):
+        metadata = self._capture_metadata(bs, req_pool_indices, forward_mode)
+        self.forward_metadata = Mamba2Metadata.prepare_decode(
+            metadata.query_start_loc, metadata.mamba_cache_indices, seq_lens
+        )
+
+    def init_forward_metadata_replay_cuda_graph(
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        seq_lens_cpu: Optional[torch.Tensor],
+    ):
+        metadata = self._replay_metadata(
+            bs, req_pool_indices, forward_mode, spec_info, seq_lens_cpu
+        )
+        self.forward_metadata = Mamba2Metadata.prepare_decode(
+            metadata.query_start_loc, metadata.mamba_cache_indices, seq_lens
+        )
+
+    def forward(
+        self,
+        mixer: MambaMixer2,
+        hidden_states: torch.Tensor,
+        output: torch.Tensor,
+        layer_id: int,
+        mup_vector: Optional[torch.Tensor] = None,
+        use_triton_causal_conv: bool = False,
+    ):
+        assert isinstance(self.forward_metadata, Mamba2Metadata)
+        layer_cache = self.req_to_token_pool.mamba2_layer_cache(layer_id)
+        return mixer.forward(
+            hidden_states=hidden_states,
+            output=output,
+            layer_cache=layer_cache,
+            metadata=self.forward_metadata,
+            mup_vector=mup_vector,
+            use_triton_causal_conv=use_triton_causal_conv,
+        )
+
+    def forward_decode(self, *args, **kwargs):
+        raise NotImplementedError(
+            "Mamba2AttnBackend's forward is called directly instead of through HybridLinearAttnBackend, as it supports mixed prefill and decode"
+        )
+
+    def forward_extend(self, *args, **kwargs):
+        raise NotImplementedError(
+            "Mamba2AttnBackend's forward is called directly instead of through HybridLinearAttnBackend, as it supports mixed prefill and decode"
+        )
+
+
 class HybridLinearAttnBackend(AttentionBackend):
-    """Support different backends for prefill and decode."""
+    """Manages a full and linear attention backend"""
 
     def __init__(
         self,
         full_attn_backend: AttentionBackend,
-        linear_attn_backend: AttentionBackend,
+        linear_attn_backend: MambaAttnBackendBase,
         full_attn_layers: list[int],
     ):
         self.full_attn_layers = full_attn_layers
+        self.full_attn_backend = full_attn_backend
+        self.linear_attn_backend = linear_attn_backend
        self.attn_backend_list = [full_attn_backend, linear_attn_backend]
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
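The NotImplementedError messages in the new Mamba2AttnBackend spell out its call pattern: Mamba2 layers invoke the linear backend's forward() directly instead of routing through HybridLinearAttnBackend.forward_decode/forward_extend, because one call can serve a mixed prefill-and-decode batch. A minimal sketch of that pattern, assuming a hypothetical model-side call site; only the linear_attn_backend attribute and the forward() keyword arguments are taken from this diff, everything else (function name, variables) is illustrative:

    import torch

    def run_mamba2_layer(attn_backend, mixer, hidden_states, layer_id):
        # Hypothetical call site: the mixer result is written into a
        # preallocated buffer by the linear backend's forward(), bypassing
        # forward_decode/forward_extend entirely.
        output = torch.empty_like(hidden_states)
        attn_backend.linear_attn_backend.forward(
            mixer=mixer,
            hidden_states=hidden_states,
            output=output,
            layer_id=layer_id,
        )
        return output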
@@ -489,7 +595,7 @@ class HybridLinearAttnBackend(AttentionBackend):
         )
 
     def get_cuda_graph_seq_len_fill_value(self):
-        return self.attn_backend_list[0].get_cuda_graph_seq_len_fill_value()
+        return self.full_attn_backend.get_cuda_graph_seq_len_fill_value()
 
     def forward_decode(
         self,
@@ -503,10 +609,10 @@ class HybridLinearAttnBackend(AttentionBackend):
     ):
         layer_id = layer.layer_id if layer else kwargs["layer_id"]
         if layer_id in self.full_attn_layers:
-            return self.attn_backend_list[0].forward_decode(
+            return self.full_attn_backend.forward_decode(
                 q, k, v, layer, forward_batch, save_kv_cache, **kwargs
             )
-        return self.attn_backend_list[1].forward_decode(
+        return self.linear_attn_backend.forward_decode(
             q, k, v, layer, forward_batch, save_kv_cache, **kwargs
         )
 
@@ -522,10 +628,10 @@ class HybridLinearAttnBackend(AttentionBackend):
     ):
         layer_id = layer.layer_id if layer else kwargs["layer_id"]
         if layer_id in self.full_attn_layers:
-            return self.attn_backend_list[0].forward_extend(
+            return self.full_attn_backend.forward_extend(
                 q, k, v, layer, forward_batch, save_kv_cache, **kwargs
             )
-        return self.attn_backend_list[1].forward_extend(
+        return self.linear_attn_backend.forward_extend(
             q, k, v, layer, forward_batch, save_kv_cache, **kwargs
         )
 
@@ -568,20 +674,20 @@ class HybridLinearAttnBackend(AttentionBackend):
     def update_mamba_state_after_mtp_verify(self, accepted_length, model):
         request_number = accepted_length.shape[0]
 
-        state_indices_tensor = self.attn_backend_list[
-            1
-        ].forward_metadata.mamba_cache_indices[:request_number]
+        state_indices_tensor = (
+            self.linear_attn_backend.forward_metadata.mamba_cache_indices[
+                :request_number
+            ]
+        )
 
-        mamba_caches = self.attn_backend_list[
-            1
-        ].req_to_token_pool.get_mamba_params_all_layers()
+        mamba_caches = (
+            self.linear_attn_backend.req_to_token_pool.get_speculative_mamba2_params_all_layers()
+        )
 
-        (
-            conv_states,
-            ssm_states,
-            intermediate_state_cache,
-            intermediate_conv_window_cache,
-        ) = mamba_caches
+        conv_states = mamba_caches.conv
+        ssm_states = mamba_caches.temporal
+        intermediate_state_cache = mamba_caches.intermediate_ssm
+        intermediate_conv_window_cache = mamba_caches.intermediate_conv_window
 
         # SSM state updates (chunked to reduce peak memory)
         valid_mask = accepted_length > 0
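Across these hunks, positional tuple unpacking of get_mamba_params / get_mamba_params_all_layers is replaced by named attributes on the cache objects (conv, temporal, and, for speculative decoding, intermediate_ssm and intermediate_conv_window on MambaPool.SpeculativeState). A minimal sketch of reading a per-layer cache under the new accessors; the helper name and control flow are illustrative, while the method and attribute names come from this diff:

    from sglang.srt.mem_cache.memory_pool import MambaPool

    def read_mamba2_layer_cache(req_to_token_pool, layer_id: int, is_target_verify: bool):
        # Named-field access replaces positional tuple unpacking.
        cache = req_to_token_pool.mamba2_layer_cache(layer_id)
        conv_states, ssm_states = cache.conv, cache.temporal
        if is_target_verify:
            # Speculative caches carry two extra buffers used during draft verification.
            assert isinstance(cache, MambaPool.SpeculativeState)
            return conv_states, ssm_states, cache.intermediate_ssm, cache.intermediate_conv_window
        return conv_states, ssm_states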
sglang/srt/layers/attention/mamba/causal_conv1d.py
@@ -10,7 +10,7 @@ import torch
 from sgl_kernel import causal_conv1d_fwd
 from sgl_kernel import causal_conv1d_update as causal_conv1d_update_kernel
 
-PAD_SLOT_ID = -1
+from .causal_conv1d_triton import PAD_SLOT_ID
 
 
 def causal_conv1d_fn(
sglang/srt/layers/attention/mamba/causal_conv1d_triton.py
@@ -6,11 +6,11 @@ from typing import List, Optional, Union
 
 import numpy as np
 import torch
-
-PAD_SLOT_ID = -1
 import triton
 import triton.language as tl
 
+PAD_SLOT_ID = -1
+
 
 @triton.jit()
 def _causal_conv1d_fwd_kernel(  # continuous batching
@@ -672,7 +672,9 @@ def _causal_conv1d_update_kernel(
         + (conv_state_batch_coord * stride_conv_state_seq)
         + conv_state_token_offset * stride_conv_state_tok
         + (idx_feats * stride_conv_state_dim)[None, :]
-        + ((idx_tokens + 1) * stride_conv_state_tok)[:, None]
+        + ((idx_tokens + (1 if IS_SPEC_DECODING else seqlen)) * stride_conv_state_tok)[
+            :, None
+        ]
     )  # [BLOCK_M, BLOCK_N]
     mask = (
         (conv_state_batch_coord < num_cache_lines)
@@ -897,7 +899,10 @@ def causal_conv1d_update(
     stride_state_indices = (
         conv_state_indices.stride(0) if conv_state_indices is not None else 0
    )
-    state_len = width - 1 + (seqlen - 1)  # effective state_len needed
+    if num_accepted_tokens is not None:
+        state_len = width - 1 + (seqlen - 1)  # effective state_len needed
+    else:
+        state_len = width - 1
     np2_statelen = triton.next_power_of_2(state_len)
 
     def grid(META):
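The branch above only keeps the longer effective state length on the speculative-decoding path (num_accepted_tokens is not None); the regular update path falls back to the convolution window alone. A quick numeric check of the two branches, using illustrative values not taken from the diff:

    width, seqlen = 4, 3  # illustrative convolution width and tokens per update

    # Speculative-decoding path: the state must also cover seqlen - 1 extra tokens.
    state_len_spec = width - 1 + (seqlen - 1)  # 3 + 2 = 5

    # Regular update path: only the convolution window minus the current token.
    state_len_regular = width - 1              # 3

    assert (state_len_spec, state_len_regular) == (5, 3)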