sglang-0.5.3-py3-none-any.whl → sglang-0.5.3.post1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. sglang/bench_one_batch.py +0 -2
  2. sglang/bench_serving.py +224 -127
  3. sglang/compile_deep_gemm.py +3 -0
  4. sglang/launch_server.py +0 -14
  5. sglang/srt/configs/__init__.py +2 -0
  6. sglang/srt/configs/falcon_h1.py +12 -58
  7. sglang/srt/configs/mamba_utils.py +117 -0
  8. sglang/srt/configs/model_config.py +68 -31
  9. sglang/srt/configs/nemotron_h.py +286 -0
  10. sglang/srt/configs/qwen3_next.py +11 -43
  11. sglang/srt/disaggregation/decode.py +7 -18
  12. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  13. sglang/srt/disaggregation/nixl/conn.py +55 -23
  14. sglang/srt/disaggregation/prefill.py +17 -32
  15. sglang/srt/entrypoints/engine.py +2 -2
  16. sglang/srt/entrypoints/grpc_request_manager.py +10 -23
  17. sglang/srt/entrypoints/grpc_server.py +220 -80
  18. sglang/srt/entrypoints/http_server.py +49 -1
  19. sglang/srt/entrypoints/openai/protocol.py +159 -31
  20. sglang/srt/entrypoints/openai/serving_chat.py +13 -71
  21. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  22. sglang/srt/environ.py +4 -0
  23. sglang/srt/function_call/function_call_parser.py +8 -6
  24. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  25. sglang/srt/grpc/sglang_scheduler_pb2.pyi +64 -6
  26. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +88 -0
  27. sglang/srt/layers/attention/attention_registry.py +31 -22
  28. sglang/srt/layers/attention/fla/layernorm_gated.py +47 -30
  29. sglang/srt/layers/attention/flashattention_backend.py +0 -1
  30. sglang/srt/layers/attention/flashinfer_backend.py +223 -6
  31. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -1
  32. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -59
  33. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  34. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -4
  35. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  36. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  37. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  38. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  39. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  40. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  41. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  42. sglang/srt/layers/attention/triton_backend.py +1 -1
  43. sglang/srt/layers/logits_processor.py +136 -6
  44. sglang/srt/layers/modelopt_utils.py +11 -0
  45. sglang/srt/layers/moe/cutlass_w4a8_moe.py +18 -21
  46. sglang/srt/layers/moe/ep_moe/kernels.py +31 -452
  47. sglang/srt/layers/moe/ep_moe/layer.py +8 -286
  48. sglang/srt/layers/moe/fused_moe_triton/layer.py +6 -11
  49. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  50. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  51. sglang/srt/layers/moe/utils.py +7 -1
  52. sglang/srt/layers/quantization/__init__.py +1 -1
  53. sglang/srt/layers/quantization/fp8.py +84 -18
  54. sglang/srt/layers/quantization/modelopt_quant.py +1 -1
  55. sglang/srt/layers/quantization/quark/quark.py +3 -1
  56. sglang/srt/layers/quantization/w4afp8.py +2 -16
  57. sglang/srt/lora/lora_manager.py +0 -8
  58. sglang/srt/managers/overlap_utils.py +18 -16
  59. sglang/srt/managers/schedule_batch.py +119 -90
  60. sglang/srt/managers/schedule_policy.py +1 -1
  61. sglang/srt/managers/scheduler.py +213 -126
  62. sglang/srt/managers/scheduler_metrics_mixin.py +1 -1
  63. sglang/srt/managers/scheduler_output_processor_mixin.py +180 -86
  64. sglang/srt/managers/tokenizer_manager.py +270 -53
  65. sglang/srt/managers/tp_worker.py +39 -28
  66. sglang/srt/mem_cache/allocator.py +7 -2
  67. sglang/srt/mem_cache/chunk_cache.py +1 -1
  68. sglang/srt/mem_cache/memory_pool.py +162 -68
  69. sglang/srt/mem_cache/radix_cache.py +8 -3
  70. sglang/srt/mem_cache/swa_radix_cache.py +70 -14
  71. sglang/srt/model_executor/cuda_graph_runner.py +1 -1
  72. sglang/srt/model_executor/forward_batch_info.py +4 -18
  73. sglang/srt/model_executor/model_runner.py +55 -51
  74. sglang/srt/model_loader/__init__.py +1 -1
  75. sglang/srt/model_loader/loader.py +187 -6
  76. sglang/srt/model_loader/weight_utils.py +3 -0
  77. sglang/srt/models/falcon_h1.py +11 -9
  78. sglang/srt/models/gemma3_mm.py +16 -0
  79. sglang/srt/models/grok.py +5 -13
  80. sglang/srt/models/mixtral.py +1 -3
  81. sglang/srt/models/mllama4.py +11 -1
  82. sglang/srt/models/nemotron_h.py +514 -0
  83. sglang/srt/models/utils.py +5 -1
  84. sglang/srt/sampling/sampling_batch_info.py +11 -9
  85. sglang/srt/server_args.py +100 -33
  86. sglang/srt/speculative/eagle_worker.py +11 -13
  87. sglang/srt/speculative/ngram_worker.py +12 -11
  88. sglang/srt/speculative/spec_utils.py +0 -1
  89. sglang/srt/two_batch_overlap.py +1 -0
  90. sglang/srt/utils/common.py +18 -0
  91. sglang/srt/utils/hf_transformers_utils.py +2 -0
  92. sglang/test/longbench_v2/__init__.py +1 -0
  93. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  94. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  95. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  96. sglang/test/run_eval.py +40 -0
  97. sglang/test/simple_eval_longbench_v2.py +332 -0
  98. sglang/test/test_cutlass_w4a8_moe.py +9 -19
  99. sglang/test/test_deterministic.py +18 -2
  100. sglang/test/test_deterministic_utils.py +81 -0
  101. sglang/test/test_disaggregation_utils.py +63 -0
  102. sglang/test/test_utils.py +32 -11
  103. sglang/version.py +1 -1
  104. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +4 -4
  105. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +109 -98
  106. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  107. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  108. sglang/test/test_block_fp8_ep.py +0 -358
  109. /sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +0 -0
  110. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
  111. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
  112. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
@@ -1,23 +1,30 @@
- from typing import Callable, List, Optional, Tuple, Union
+ from typing import Callable, List, Optional, Tuple
 
  import torch
  import torch.nn as nn
 
- from sglang.srt.configs.model_config import ModelConfig
- from sglang.srt.custom_op import CustomOp
+ from sglang.srt.configs.mamba_utils import (
+     Mamba2CacheParams,
+     extra_groups_for_head_shards,
+ )
  from sglang.srt.distributed import (
+     divide,
      get_tensor_model_parallel_rank,
      get_tensor_model_parallel_world_size,
-     tensor_model_parallel_all_gather,
-     tensor_model_parallel_all_reduce,
  )
  from sglang.srt.distributed.utils import divide
- from sglang.srt.layers.attention.fla.layernorm_gated import layernorm_fn
  from sglang.srt.layers.attention.mamba.causal_conv1d import (
      causal_conv1d_fn,
      causal_conv1d_update,
  )
- from sglang.srt.layers.attention.mamba.mamba_utils import MambaStateShapeCalculator
+ from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
+     causal_conv1d_fn as causal_conv1d_fn_triton,
+ )
+ from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
+     causal_conv1d_update as causal_conv1d_update_triton,
+ )
+ from sglang.srt.layers.attention.mamba.mamba2_metadata import Mamba2Metadata
+ from sglang.srt.layers.attention.mamba.mixer2_rms_norm_gated import Mixer2RMSNormGated
  from sglang.srt.layers.attention.mamba.ops import (
      mamba_chunk_scan_combined,
      selective_state_update,
@@ -28,7 +35,7 @@ from sglang.srt.layers.linear (
      RowParallelLinear,
  )
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
- from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.mem_cache.memory_pool import MambaPool
  from sglang.srt.model_loader.weight_utils import (
      composed_weight_loader,
      sharded_weight_loader,
@@ -97,110 +104,6 @@ def mamba_v2_sharded_weight_loader(
      return loader
 
 
- class Mixer2RMSNormGated(CustomOp):
-
-     def __init__(
-         self,
-         full_hidden_size: int,
-         full_n_groups: int,
-         use_rms_norm: bool = True,
-         eps: float = 1e-6,
-     ):
-         super().__init__()
-         self.tp_size = get_tensor_model_parallel_world_size()
-         self.tp_rank = get_tensor_model_parallel_rank()
-         self.full_hidden_size = full_hidden_size
-         self.group_size = full_hidden_size // full_n_groups
-         self.per_rank_hidden_size = full_hidden_size // self.tp_size
-         self.n_groups = full_hidden_size // self.group_size
-
-         self.variance_epsilon = eps
-         self.use_rms_norm = use_rms_norm
-         if self.use_rms_norm:
-             # Register norm weight only if we're actually applying RMSNorm
-             self.weight = nn.Parameter(torch.ones(self.per_rank_hidden_size))
-             set_weight_attrs(self.weight, {"weight_loader": sharded_weight_loader(0)})
-         else:
-             # Avoid checkpoint mismatch by skipping unused parameter
-             self.register_parameter("weight", None)
-         assert (
-             self.full_hidden_size % self.tp_size == 0
-         ), "Tensor parallel world size must divide hidden size."
-
-     def forward_native(
-         self,
-         x: torch.Tensor,
-         gate: torch.Tensor,
-     ):
-         # Three tensor-parallel cases:
-         # 1. n_groups is 1
-         #    In this case we parallelize along the reduction dim.
-         #    Each rank computes a local sum of squares followed by AllReduce
-         # 2. tp_size divides n_groups
-         #    Each rank only reduces within its local group(s).
-         #    No collective ops necessary.
-         # 3. The general case can be pretty complicated so we AllGather
-         #    the input and then redundantly compute the RMSNorm.
-         input_dtype = x.dtype
-         x = x * nn.functional.silu(gate.to(torch.float32))
-         if not self.use_rms_norm:
-             return x.to(input_dtype)
-
-         if self.n_groups == 1:
-             if self.tp_size > 1:
-                 # Compute local sum and then reduce to obtain global sum
-                 local_sums = x.pow(2).sum(dim=-1, keepdim=True)
-                 global_sums = tensor_model_parallel_all_reduce(local_sums)
-                 # Calculate the variance
-                 count = self.tp_size * x.shape[-1]
-                 variance = global_sums / count
-
-             else:
-                 variance = x.pow(2).mean(-1, keepdim=True)
-             x = x * torch.rsqrt(variance + self.variance_epsilon)
-         else:
-             redundant_tp: bool = self.n_groups % self.tp_size != 0
-             if redundant_tp:
-                 # To handle the general case, redundantly apply the variance
-                 x = tensor_model_parallel_all_gather(x, -1)
-
-             *prefix_dims, hidden_dim = x.shape
-             group_count = hidden_dim // self.group_size
-             x_grouped = x.view(*prefix_dims, group_count, self.group_size)
-             variance = x_grouped.pow(2).mean(-1, keepdim=True)
-             x_grouped = x_grouped * torch.rsqrt(variance + self.variance_epsilon)
-             x = x_grouped.view(*prefix_dims, hidden_dim)
-
-             if redundant_tp:
-                 start = self.per_rank_hidden_size * self.tp_rank
-                 end = start + self.per_rank_hidden_size
-                 x = x[..., start:end]
-
-         return self.weight * x.to(input_dtype)
-
-     def forward_cuda(
-         self,
-         x: torch.Tensor,
-         gate: torch.Tensor,
-     ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
-         input_dtype = x.dtype
-         if not self.use_rms_norm:
-             # Keep gate in float32 for numerical stability during silu
-             return x * nn.functional.silu(gate.to(torch.float32)).to(input_dtype)
-
-         if ((self.n_groups % self.tp_size) != 0) or self.n_groups != 1:
-             return self.forward_native(x, gate)
-
-         return layernorm_fn(
-             x,
-             self.weight.data,
-             bias=None,
-             z=gate,
-             eps=self.variance_epsilon,
-             norm_before_gate=False,
-         )
-
-
  class MambaMixer2(torch.nn.Module):
      """
      Compute ∆, A, B, C, and D the state space parameters and compute
@@ -214,22 +117,14 @@ class MambaMixer2(torch.nn.Module):
 
      def __init__(
          self,
+         cache_params: Mamba2CacheParams,
          hidden_size: int,
-         ssm_state_size: int,
-         conv_kernel_size: int,
-         intermediate_size: int,
          use_conv_bias: bool,
          use_bias: bool,
-         chunk_size: int,
-         layer_id: int,
          n_groups: int = 1,
-         num_heads: int = 128,
-         head_dim: int = 64,
          rms_norm_eps: float = 1e-5,
          activation: str = "silu",
          use_rms_norm: bool = True,
-         model_config: Optional[ModelConfig] = None,
-         # cache_config: Optional[CacheConfig] = None,
          quant_config: Optional[QuantizationConfig] = None,
          prefix: str = "",
      ):
@@ -252,6 +147,9 @@ class MambaMixer2(torch.nn.Module):
          self.tp_size = get_tensor_model_parallel_world_size()
          self.tp_rank = get_tensor_model_parallel_rank()
 
+         self.num_heads = num_heads = cache_params.shape.num_heads
+         self.head_dim = cache_params.shape.head_dim
+
          assert (
              num_heads % self.tp_size == 0
          ), "Tensor parallel world size must divide num heads."
@@ -261,57 +159,76 @@ class MambaMixer2(torch.nn.Module):
              "then num_groups must equal 1."
          )
 
-         self.ssm_state_size = ssm_state_size
-         self.conv_kernel_size = conv_kernel_size
-         self.activation = activation
-         self.layer_id = layer_id
+         assert (
+             (n_groups % self.tp_size == 0) or self.tp_size == 1 or quant_config is None
+         ), (
+             "Tensor parallel currently supported for quantized models only "
+             "if tensor parallel world size divides num groups."
+         )
 
-         self.intermediate_size = intermediate_size
-         self.head_dim = head_dim
-         self.num_heads = num_heads
-         self.chunk_size = chunk_size
+         self.ssm_state_size = cache_params.shape.ssm_state_size
+         self.activation = activation
 
+         conv_kernel_size = cache_params.shape.conv_kernel
+         self.intermediate_size = intermediate_size = (
+             cache_params.shape.intermediate_size
+         )
          self.n_groups = n_groups
          if n_groups % self.tp_size != 0:
              # - for TP we shard conv_dim by sharding on n_groups,
              # - but if n_groups cannot divide tp_size, we need to
              #   extend some extra groups
-             groups = MambaStateShapeCalculator.extra_groups_for_head_shards(
-                 n_groups, self.tp_size
-             )
+             groups = extra_groups_for_head_shards(n_groups, self.tp_size)
              self.n_groups = n_groups + groups
-
          self.groups_ssm_state_size = self.n_groups * self.ssm_state_size
-         self.conv_dim = intermediate_size + 2 * self.groups_ssm_state_size
-
-         self.conv1d = MergedColumnParallelLinear(
-             input_size=conv_kernel_size,
-             output_sizes=[
-                 intermediate_size,
-                 self.groups_ssm_state_size,
-                 self.groups_ssm_state_size,
-             ],
-             bias=use_conv_bias,
-             quant_config=None,
-             prefix=f"{prefix}.conv1d",
-         )
+         self.conv_dim = cache_params.shape.conv_dim
+
+         if n_groups % self.tp_size == 0:
+             self.conv1d = MergedColumnParallelLinear(
+                 input_size=conv_kernel_size,
+                 output_sizes=[
+                     intermediate_size,
+                     self.groups_ssm_state_size,
+                     self.groups_ssm_state_size,
+                 ],
+                 bias=use_conv_bias,
+                 quant_config=None,
+                 prefix=f"{prefix}.conv1d",
+             )
 
-         self.in_proj = MergedColumnParallelLinear(
-             input_size=hidden_size,
-             output_sizes=[
-                 intermediate_size,
-                 intermediate_size,
-                 self.groups_ssm_state_size,
-                 self.groups_ssm_state_size,
-                 self.num_heads,
-             ],
-             bias=use_bias,
-             prefix=f"{prefix}.in_proj",
-         )
-         if n_groups % self.tp_size != 0:
+             self.in_proj = MergedColumnParallelLinear(
+                 input_size=hidden_size,
+                 output_sizes=[
+                     intermediate_size,
+                     intermediate_size,
+                     self.groups_ssm_state_size,
+                     self.groups_ssm_state_size,
+                     self.num_heads,
+                 ],
+                 bias=use_bias,
+                 quant_config=quant_config,
+                 prefix=f"{prefix}.in_proj",
+             )
+         else:
              # This is the n_groups == 1 case,
              # where we need to duplicate groups if TP>1.
 
+             self.conv1d = ColumnParallelLinear(
+                 input_size=conv_kernel_size,
+                 output_size=self.conv_dim,
+                 bias=use_conv_bias,
+                 quant_config=None,
+                 prefix=f"{prefix}.conv1d",
+             )
+
+             self.in_proj = ColumnParallelLinear(
+                 input_size=hidden_size,
+                 output_size=intermediate_size + self.conv_dim + self.num_heads,
+                 bias=use_bias,
+                 quant_config=quant_config,
+                 prefix=f"{prefix}.in_proj",
+             )
+
              # - because in_proj is a concatenation of 3 weights, we
              #   need to interleave them before sharding
              # - use the custom weight loader mamba_v2_sharded_weight_loader
@@ -421,47 +338,27 @@ class MambaMixer2(torch.nn.Module):
              intermediate_size, n_groups, self.use_rms_norm, eps=rms_norm_eps
          )
 
-         # The tuple is (conv_state, ssm_state)
-         self.kv_cache = (torch.tensor([]), torch.tensor([]))
-
-         self.model_config = model_config
          self.prefix = prefix
 
-     def forward_native(
-         self,
-         hidden_states: torch.Tensor,
-         output: torch.Tensor,
-         mup_vector: Optional[torch.Tensor] = None,
-     ):
-         pass
-
      def forward(
          self,
+         *,
          hidden_states: torch.Tensor,
          output: torch.Tensor,
-         forward_batch: ForwardBatch,
+         layer_cache: MambaPool.State,
+         metadata: Mamba2Metadata,
          mup_vector: Optional[torch.Tensor] = None,
+         use_triton_causal_conv: bool = False,
      ):
-         # attn_backend_list[-1] gives access to MambaAttnBackend
-         mamba_backend = forward_batch.attn_backend.attn_backend_list[-1]
-         attn_metadata = mamba_backend.forward_metadata
-         state_indices_tensor = attn_metadata.mamba_cache_indices
-         chunk_size = self.chunk_size
-
-         conv_state, ssm_state, *rest = mamba_backend.req_to_token_pool.get_mamba_params(
-             self.layer_id
-         )
+         # metadata contains metadata necessary for the mamba2 triton
+         # kernels to operate in continuous batching and in chunked prefill
+         # modes; they are computed at top-level model forward since they
+         # stay the same and reused for all mamba layers in the same iteration
+         state_indices_tensor = metadata.mamba_cache_indices
+         conv_state = layer_cache.conv
+         ssm_state = layer_cache.temporal
 
-         assert (
-             ssm_state.size(1) == self.ssm_state_size
-         ), f"dstate must be {self.ssm_state_size}, got {ssm_state.size(1)}"
-
-         query_start_loc = attn_metadata.query_start_loc
-
-         chunk_size = self.chunk_size
-
-         # TODO: properly support this
-         prep_initial_states = False
+         query_start_loc = metadata.query_start_loc
 
          # 1. Gated MLP's linear projection
          projected_states, _ = self.in_proj(hidden_states)
@@ -493,6 +390,38 @@ class MambaMixer2(torch.nn.Module):
              dim=-1,
          )
 
+         num_prefills = metadata.num_prefills  # request count
+         num_decodes = metadata.num_decodes  # token count (=request)
+         num_prefill_tokens = metadata.num_prefill_tokens  # token count
+         has_prefill = num_prefills > 0
+         has_decode = num_decodes > 0
+         num_actual_tokens = num_prefill_tokens + num_decodes
+         assert num_actual_tokens == projected_states.shape[0]
+
+         # NOTE: V0 put prefill before decode
+         # Separate prefill and decode by splitting varlen input
+         # Split along token dimension
+         hidden_states_B_C_p, hidden_states_B_C_d = torch.split(
+             hidden_states_B_C,
+             [num_prefill_tokens, num_decodes],
+             dim=0,
+         )
+         dt_p, dt_d = torch.split(
+             dt,
+             [num_prefill_tokens, num_decodes],
+             dim=0,
+         )
+         # Split along batch dimension
+         state_indices_tensor_p, state_indices_tensor_d = torch.split(
+             state_indices_tensor,
+             [num_prefills, num_decodes],
+             dim=0,
+         )
+         query_start_loc_p = query_start_loc[: num_prefills + 1] if has_prefill else None
+
+         # Preallocate output tensor to avoid memcpy cost for merging prefill
+         # and decode outputs
+
          preallocated_ssm_out = torch.empty(
              [
                  projected_states.shape[0],
@@ -501,128 +430,147 @@ class MambaMixer2(torch.nn.Module):
              dtype=hidden_states.dtype,
              device=hidden_states.device,
          )
+         preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split(
+             preallocated_ssm_out,
+             [num_prefill_tokens, num_decodes],
+             dim=0,
+         )
 
          # Process prefill requests
-         if forward_batch.forward_mode.is_extend():
+         if has_prefill:
+             mixed_metadata = metadata.mixed_metadata
+             assert mixed_metadata is not None
              # 2. Convolution sequence transformation
              # - "cache_indices" updates the conv_state cache in positions
              #   pointed to by "state_indices_tensor"
-             num_prefill_tokens = forward_batch.extend_num_tokens or 0
-             has_initial_states = forward_batch.extend_prefix_lens > 0
-             cache_indices = attn_metadata.mamba_cache_indices
-
-             x = hidden_states_B_C.transpose(
+             has_initial_states_p = mixed_metadata.has_initial_states
+             prep_initial_states = mixed_metadata.prep_initial_states
+             cache_indices = state_indices_tensor_p
+             x = hidden_states_B_C_p.transpose(
                  0, 1
              )  # this is the form that causal-conv see
-             hidden_states_B_C = causal_conv1d_fn(
+             ccfn = (
+                 causal_conv1d_fn
+                 if not use_triton_causal_conv
+                 else causal_conv1d_fn_triton
+             )
+             hidden_states_B_C_p = ccfn(
                  x,
                  conv_weights,
                  self.conv1d.bias,
                  activation=self.activation,
                  conv_states=conv_state,
-                 has_initial_state=has_initial_states,
+                 has_initial_state=has_initial_states_p,
                  cache_indices=cache_indices,
-                 query_start_loc=query_start_loc,
-                 seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
-             ).transpose(0, 1)
+                 query_start_loc=query_start_loc_p,
+                 seq_lens_cpu=mixed_metadata.extend_seq_lens_cpu,
+             ).transpose(0, 1)[:num_prefill_tokens]
 
-             hidden_states, B, C = split_hidden_states_B_C_fn(hidden_states_B_C)
+             hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn(hidden_states_B_C_p)
 
              # 3. State Space Model sequence transformation
              initial_states = None
-
-             if has_initial_states is not None and prep_initial_states:
+             if has_initial_states_p is not None and prep_initial_states:
                  initial_states = torch.where(
-                     has_initial_states[:, None, None, None],
-                     ssm_state[state_indices_tensor],
+                     has_initial_states_p[:, None, None, None],
+                     ssm_state[state_indices_tensor_p],
                      0,
                  )
 
              # NOTE: final output is an in-place update of out tensor
              varlen_state = mamba_chunk_scan_combined(
-                 hidden_states.view(
+                 hidden_states_p.view(
                      1, num_prefill_tokens, self.num_heads // self.tp_size, self.head_dim
                  ),
-                 dt.unsqueeze(0),
+                 dt_p.unsqueeze(0),
                  self.A,
-                 B.view(1, num_prefill_tokens, self.n_groups // self.tp_size, -1),
-                 C.view(1, num_prefill_tokens, self.n_groups // self.tp_size, -1),
-                 chunk_size=chunk_size,
+                 B_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size, -1),
+                 C_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size, -1),
+                 chunk_size=mixed_metadata.chunk_size,
                  D=self.D,
                  z=None,
                  dt_bias=self.dt_bias,
-                 cu_seqlens=query_start_loc,
+                 seq_idx=mixed_metadata.seq_idx,
+                 chunk_indices=mixed_metadata.chunk_indices,
+                 chunk_offsets=mixed_metadata.chunk_offsets,
+                 cu_seqlens=query_start_loc_p,
                  initial_states=initial_states,
                  return_varlen_states=True,
                  return_final_states=False,
                  dt_softplus=True,
                  dt_limit=(0.0, float("inf")),
-                 out=preallocated_ssm_out.view(1, num_prefill_tokens, -1, self.head_dim),
+                 out=preallocated_ssm_out_p.view(
+                     1, num_prefill_tokens, -1, self.head_dim
+                 ),
                  state_dtype=ssm_state.dtype,
              )
 
              # update ssm states
             # - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
-             ssm_state[state_indices_tensor] = varlen_state.permute(0, 3, 2, 1)
-         elif forward_batch.forward_mode.is_decode():
-             num_decodes = len(query_start_loc) - 1
+             ssm_state[state_indices_tensor_p] = varlen_state
+
+         # Process decode requests
+         if has_decode:
              # 2. Convolution sequence transformation
-             hidden_states_B_C = causal_conv1d_update(
-                 hidden_states_B_C,
+             ccu = (
+                 causal_conv1d_update
+                 if not use_triton_causal_conv
+                 else causal_conv1d_update_triton
+             )
+             hidden_states_B_C_d = ccu(
+                 hidden_states_B_C_d,
                  conv_state,
                  conv_weights,
                  self.conv1d.bias,
                  self.activation,
-                 conv_state_indices=state_indices_tensor,
+                 conv_state_indices=state_indices_tensor_d,
              )
 
-             hidden_states, B, C = split_hidden_states_B_C_fn(hidden_states_B_C)
+             hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn(hidden_states_B_C_d)
 
              # 3. State Space Model sequence transformation
              n_groups = self.n_groups // self.tp_size
-             A = (
+             A_d = (
                  self.A[:, None, ...][:, :, None]
                  .expand(-1, self.head_dim, self.ssm_state_size)
                  .to(dtype=torch.float32)
             )
-             dt = dt[:, :, None].expand(-1, -1, self.head_dim)
+             dt_d = dt_d[:, :, None].expand(-1, -1, self.head_dim)
              dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
-             D = self.D[:, None, ...].expand(-1, self.head_dim)
-             B = B.view(-1, n_groups, B.shape[1] // n_groups)
-             C = C.view(-1, n_groups, C.shape[1] // n_groups)
-             hidden_states = hidden_states.view(
+             D_d = self.D[:, None, ...].expand(-1, self.head_dim)
+             B_d = B_d.view(-1, n_groups, B_d.shape[1] // n_groups)
+             C_d = C_d.view(-1, n_groups, C_d.shape[1] // n_groups)
+             hidden_states_d = hidden_states_d.view(
                  -1, self.num_heads // self.tp_size, self.head_dim
              )
 
             # - the hidden is reshaped into (bs, num_heads, head_dim)
-             # - mamba_cache_params.ssm_state's slots will be selected
+             # - layer_state.ssm_state's slots will be selected
              #   using state_indices_tensor_d
             # NOTE: final output is an in-place update of out tensor
              selective_state_update(
-                 ssm_state.permute(0, 3, 2, 1),
-                 hidden_states,
-                 dt,
-                 A,
-                 B,
-                 C,
-                 D,
+                 ssm_state,
+                 hidden_states_d,
+                 dt_d,
+                 A_d,
+                 B_d,
+                 C_d,
+                 D_d,
                  z=None,
                  dt_bias=dt_bias,
                  dt_softplus=True,
-                 state_batch_indices=state_indices_tensor,
-                 out=preallocated_ssm_out.view(num_decodes, -1, self.head_dim),
+                 state_batch_indices=state_indices_tensor_d,
+                 out=preallocated_ssm_out_d.view(num_decodes, -1, self.head_dim),
              )
-         elif forward_batch.forward_mode.is_idle():
-             preallocated_ssm_out = preallocated_ssm_out
 
          # 4. gated MLP
          # GatedRMSNorm internally applying SiLU to the gate
          # SiLU is applied internally before normalization, unlike standard
          # norm usage
-         hidden_states = self.norm(preallocated_ssm_out, gate)
+         hidden_states = self.norm(preallocated_ssm_out, gate[:num_actual_tokens])
 
          # 5. Final linear projection
-         output[:], _ = self.out_proj(hidden_states)
+         output[:num_actual_tokens], _ = self.out_proj(hidden_states)
 
      @property
      def mamba_type(self) -> str:
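
For orientation, the rewritten forward above takes one varlen batch laid out as prefill tokens followed by one token per decode request, and separates the two with torch.split: [num_prefill_tokens, num_decodes] along the token dimension and [num_prefills, num_decodes] along the request dimension. The sketch below only illustrates that splitting pattern; the sizes and variable names are made up for illustration and are not taken from the package.

import torch

# Hypothetical batch layout mirroring the new forward(): prefill tokens first,
# then one token per decode request.
num_prefills = 2          # prefill requests
num_prefill_tokens = 7    # e.g. prefill lengths 3 + 4
num_decodes = 3           # decode requests, one token each
hidden_dim = 16           # made-up hidden size

hidden_states = torch.randn(num_prefill_tokens + num_decodes, hidden_dim)
cache_slots = torch.arange(num_prefills + num_decodes)

# Token-dimension split, as done for hidden_states_B_C and dt in the diff.
hs_prefill, hs_decode = torch.split(
    hidden_states, [num_prefill_tokens, num_decodes], dim=0
)
# Request-dimension split, as done for the cache slot indices.
slots_prefill, slots_decode = torch.split(
    cache_slots, [num_prefills, num_decodes], dim=0
)

assert hs_prefill.shape == (num_prefill_tokens, hidden_dim)
assert hs_decode.shape == (num_decodes, hidden_dim)
assert slots_prefill.shape == (num_prefills,) and slots_decode.shape == (num_decodes,)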