sglang 0.5.3__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +0 -2
- sglang/bench_serving.py +224 -127
- sglang/compile_deep_gemm.py +3 -0
- sglang/launch_server.py +0 -14
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/falcon_h1.py +12 -58
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +68 -31
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/qwen3_next.py +11 -43
- sglang/srt/disaggregation/decode.py +7 -18
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
- sglang/srt/disaggregation/nixl/conn.py +55 -23
- sglang/srt/disaggregation/prefill.py +17 -32
- sglang/srt/entrypoints/engine.py +2 -2
- sglang/srt/entrypoints/grpc_request_manager.py +10 -23
- sglang/srt/entrypoints/grpc_server.py +220 -80
- sglang/srt/entrypoints/http_server.py +49 -1
- sglang/srt/entrypoints/openai/protocol.py +159 -31
- sglang/srt/entrypoints/openai/serving_chat.py +13 -71
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +4 -0
- sglang/srt/function_call/function_call_parser.py +8 -6
- sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +64 -6
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +88 -0
- sglang/srt/layers/attention/attention_registry.py +31 -22
- sglang/srt/layers/attention/fla/layernorm_gated.py +47 -30
- sglang/srt/layers/attention/flashattention_backend.py +0 -1
- sglang/srt/layers/attention/flashinfer_backend.py +223 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -59
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -4
- sglang/srt/layers/attention/mamba/mamba.py +189 -241
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
- sglang/srt/layers/attention/triton_backend.py +1 -1
- sglang/srt/layers/logits_processor.py +136 -6
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +18 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +31 -452
- sglang/srt/layers/moe/ep_moe/layer.py +8 -286
- sglang/srt/layers/moe/fused_moe_triton/layer.py +6 -11
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/utils.py +7 -1
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/fp8.py +84 -18
- sglang/srt/layers/quantization/modelopt_quant.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/w4afp8.py +2 -16
- sglang/srt/lora/lora_manager.py +0 -8
- sglang/srt/managers/overlap_utils.py +18 -16
- sglang/srt/managers/schedule_batch.py +119 -90
- sglang/srt/managers/schedule_policy.py +1 -1
- sglang/srt/managers/scheduler.py +213 -126
- sglang/srt/managers/scheduler_metrics_mixin.py +1 -1
- sglang/srt/managers/scheduler_output_processor_mixin.py +180 -86
- sglang/srt/managers/tokenizer_manager.py +270 -53
- sglang/srt/managers/tp_worker.py +39 -28
- sglang/srt/mem_cache/allocator.py +7 -2
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +162 -68
- sglang/srt/mem_cache/radix_cache.py +8 -3
- sglang/srt/mem_cache/swa_radix_cache.py +70 -14
- sglang/srt/model_executor/cuda_graph_runner.py +1 -1
- sglang/srt/model_executor/forward_batch_info.py +4 -18
- sglang/srt/model_executor/model_runner.py +55 -51
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +187 -6
- sglang/srt/model_loader/weight_utils.py +3 -0
- sglang/srt/models/falcon_h1.py +11 -9
- sglang/srt/models/gemma3_mm.py +16 -0
- sglang/srt/models/grok.py +5 -13
- sglang/srt/models/mixtral.py +1 -3
- sglang/srt/models/mllama4.py +11 -1
- sglang/srt/models/nemotron_h.py +514 -0
- sglang/srt/models/utils.py +5 -1
- sglang/srt/sampling/sampling_batch_info.py +11 -9
- sglang/srt/server_args.py +100 -33
- sglang/srt/speculative/eagle_worker.py +11 -13
- sglang/srt/speculative/ngram_worker.py +12 -11
- sglang/srt/speculative/spec_utils.py +0 -1
- sglang/srt/two_batch_overlap.py +1 -0
- sglang/srt/utils/common.py +18 -0
- sglang/srt/utils/hf_transformers_utils.py +2 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +40 -0
- sglang/test/simple_eval_longbench_v2.py +332 -0
- sglang/test/test_cutlass_w4a8_moe.py +9 -19
- sglang/test/test_deterministic.py +18 -2
- sglang/test/test_deterministic_utils.py +81 -0
- sglang/test/test_disaggregation_utils.py +63 -0
- sglang/test/test_utils.py +32 -11
- sglang/version.py +1 -1
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +4 -4
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +109 -98
- sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +0 -0
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/mamba/mamba.py

@@ -1,23 +1,30 @@
-from typing import Callable, List, Optional, Tuple
+from typing import Callable, List, Optional, Tuple
 
 import torch
 import torch.nn as nn
 
-from sglang.srt.configs.
-
+from sglang.srt.configs.mamba_utils import (
+    Mamba2CacheParams,
+    extra_groups_for_head_shards,
+)
 from sglang.srt.distributed import (
+    divide,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
-    tensor_model_parallel_all_gather,
-    tensor_model_parallel_all_reduce,
 )
 from sglang.srt.distributed.utils import divide
-from sglang.srt.layers.attention.fla.layernorm_gated import layernorm_fn
 from sglang.srt.layers.attention.mamba.causal_conv1d import (
     causal_conv1d_fn,
     causal_conv1d_update,
 )
-from sglang.srt.layers.attention.mamba.
+from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
+    causal_conv1d_fn as causal_conv1d_fn_triton,
+)
+from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
+    causal_conv1d_update as causal_conv1d_update_triton,
+)
+from sglang.srt.layers.attention.mamba.mamba2_metadata import Mamba2Metadata
+from sglang.srt.layers.attention.mamba.mixer2_rms_norm_gated import Mixer2RMSNormGated
 from sglang.srt.layers.attention.mamba.ops import (
     mamba_chunk_scan_combined,
     selective_state_update,
@@ -28,7 +35,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.
+from sglang.srt.mem_cache.memory_pool import MambaPool
 from sglang.srt.model_loader.weight_utils import (
     composed_weight_loader,
     sharded_weight_loader,
@@ -97,110 +104,6 @@ def mamba_v2_sharded_weight_loader(
     return loader
 
 
-class Mixer2RMSNormGated(CustomOp):
-
-    def __init__(
-        self,
-        full_hidden_size: int,
-        full_n_groups: int,
-        use_rms_norm: bool = True,
-        eps: float = 1e-6,
-    ):
-        super().__init__()
-        self.tp_size = get_tensor_model_parallel_world_size()
-        self.tp_rank = get_tensor_model_parallel_rank()
-        self.full_hidden_size = full_hidden_size
-        self.group_size = full_hidden_size // full_n_groups
-        self.per_rank_hidden_size = full_hidden_size // self.tp_size
-        self.n_groups = full_hidden_size // self.group_size
-
-        self.variance_epsilon = eps
-        self.use_rms_norm = use_rms_norm
-        if self.use_rms_norm:
-            # Register norm weight only if we're actually applying RMSNorm
-            self.weight = nn.Parameter(torch.ones(self.per_rank_hidden_size))
-            set_weight_attrs(self.weight, {"weight_loader": sharded_weight_loader(0)})
-        else:
-            # Avoid checkpoint mismatch by skipping unused parameter
-            self.register_parameter("weight", None)
-        assert (
-            self.full_hidden_size % self.tp_size == 0
-        ), "Tensor parallel world size must divide hidden size."
-
-    def forward_native(
-        self,
-        x: torch.Tensor,
-        gate: torch.Tensor,
-    ):
-        # Three tensor-parallel cases:
-        #   1. n_groups is 1
-        #       In this case we parallelize along the reduction dim.
-        #       Each rank computes a local sum of squares followed by AllReduce
-        #   2. tp_size divides n_groups
-        #       Each rank only reduces within its local group(s).
-        #       No collective ops necessary.
-        #   3. The general case can be pretty complicated so we AllGather
-        #       the input and then redundantly compute the RMSNorm.
-        input_dtype = x.dtype
-        x = x * nn.functional.silu(gate.to(torch.float32))
-        if not self.use_rms_norm:
-            return x.to(input_dtype)
-
-        if self.n_groups == 1:
-            if self.tp_size > 1:
-                # Compute local sum and then reduce to obtain global sum
-                local_sums = x.pow(2).sum(dim=-1, keepdim=True)
-                global_sums = tensor_model_parallel_all_reduce(local_sums)
-                # Calculate the variance
-                count = self.tp_size * x.shape[-1]
-                variance = global_sums / count
-
-            else:
-                variance = x.pow(2).mean(-1, keepdim=True)
-            x = x * torch.rsqrt(variance + self.variance_epsilon)
-        else:
-            redundant_tp: bool = self.n_groups % self.tp_size != 0
-            if redundant_tp:
-                # To handle the general case, redundantly apply the variance
-                x = tensor_model_parallel_all_gather(x, -1)
-
-            *prefix_dims, hidden_dim = x.shape
-            group_count = hidden_dim // self.group_size
-            x_grouped = x.view(*prefix_dims, group_count, self.group_size)
-            variance = x_grouped.pow(2).mean(-1, keepdim=True)
-            x_grouped = x_grouped * torch.rsqrt(variance + self.variance_epsilon)
-            x = x_grouped.view(*prefix_dims, hidden_dim)
-
-            if redundant_tp:
-                start = self.per_rank_hidden_size * self.tp_rank
-                end = start + self.per_rank_hidden_size
-                x = x[..., start:end]
-
-        return self.weight * x.to(input_dtype)
-
-    def forward_cuda(
-        self,
-        x: torch.Tensor,
-        gate: torch.Tensor,
-    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
-        input_dtype = x.dtype
-        if not self.use_rms_norm:
-            # Keep gate in float32 for numerical stability during silu
-            return x * nn.functional.silu(gate.to(torch.float32)).to(input_dtype)
-
-        if ((self.n_groups % self.tp_size) != 0) or self.n_groups != 1:
-            return self.forward_native(x, gate)
-
-        return layernorm_fn(
-            x,
-            self.weight.data,
-            bias=None,
-            z=gate,
-            eps=self.variance_epsilon,
-            norm_before_gate=False,
-        )
-
-
 class MambaMixer2(torch.nn.Module):
     """
     Compute ∆, A, B, C, and D the state space parameters and compute
|
|
214
117
|
|
215
118
|
def __init__(
|
216
119
|
self,
|
120
|
+
cache_params: Mamba2CacheParams,
|
217
121
|
hidden_size: int,
|
218
|
-
ssm_state_size: int,
|
219
|
-
conv_kernel_size: int,
|
220
|
-
intermediate_size: int,
|
221
122
|
use_conv_bias: bool,
|
222
123
|
use_bias: bool,
|
223
|
-
chunk_size: int,
|
224
|
-
layer_id: int,
|
225
124
|
n_groups: int = 1,
|
226
|
-
num_heads: int = 128,
|
227
|
-
head_dim: int = 64,
|
228
125
|
rms_norm_eps: float = 1e-5,
|
229
126
|
activation: str = "silu",
|
230
127
|
use_rms_norm: bool = True,
|
231
|
-
model_config: Optional[ModelConfig] = None,
|
232
|
-
# cache_config: Optional[CacheConfig] = None,
|
233
128
|
quant_config: Optional[QuantizationConfig] = None,
|
234
129
|
prefix: str = "",
|
235
130
|
):
|
@@ -252,6 +147,9 @@ class MambaMixer2(torch.nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()
         self.tp_rank = get_tensor_model_parallel_rank()
 
+        self.num_heads = num_heads = cache_params.shape.num_heads
+        self.head_dim = cache_params.shape.head_dim
+
         assert (
             num_heads % self.tp_size == 0
         ), "Tensor parallel world size must divide num heads."
@@ -261,57 +159,76 @@ class MambaMixer2(torch.nn.Module):
             "then num_groups must equal 1."
         )
 
-
-
-
-
+        assert (
+            (n_groups % self.tp_size == 0) or self.tp_size == 1 or quant_config is None
+        ), (
+            "Tensor parallel currently supported for quantized models only "
+            "if tensor parallel world size divides num groups."
+        )
 
-        self.
-        self.
-        self.num_heads = num_heads
-        self.chunk_size = chunk_size
+        self.ssm_state_size = cache_params.shape.ssm_state_size
+        self.activation = activation
 
+        conv_kernel_size = cache_params.shape.conv_kernel
+        self.intermediate_size = intermediate_size = (
+            cache_params.shape.intermediate_size
+        )
         self.n_groups = n_groups
         if n_groups % self.tp_size != 0:
             # - for TP we shard conv_dim by sharding on n_groups,
             # - but if n_groups cannot divide tp_size, we need to
             #   extend some extra groups
-            groups =
-                n_groups, self.tp_size
-            )
+            groups = extra_groups_for_head_shards(n_groups, self.tp_size)
             self.n_groups = n_groups + groups
-
         self.groups_ssm_state_size = self.n_groups * self.ssm_state_size
-        self.conv_dim =
-
-        self.
-
-
-
-
-
-
-
-
-
-
+        self.conv_dim = cache_params.shape.conv_dim
+
+        if n_groups % self.tp_size == 0:
+            self.conv1d = MergedColumnParallelLinear(
+                input_size=conv_kernel_size,
+                output_sizes=[
+                    intermediate_size,
+                    self.groups_ssm_state_size,
+                    self.groups_ssm_state_size,
+                ],
+                bias=use_conv_bias,
+                quant_config=None,
+                prefix=f"{prefix}.conv1d",
+            )
 
-
-
-
-
-
-
-
-
-
-
-
-
-
+            self.in_proj = MergedColumnParallelLinear(
+                input_size=hidden_size,
+                output_sizes=[
+                    intermediate_size,
+                    intermediate_size,
+                    self.groups_ssm_state_size,
+                    self.groups_ssm_state_size,
+                    self.num_heads,
+                ],
+                bias=use_bias,
+                quant_config=quant_config,
+                prefix=f"{prefix}.in_proj",
+            )
+        else:
             # This is the n_groups == 1 case,
             # where we need to duplicate groups if TP>1.
 
+            self.conv1d = ColumnParallelLinear(
+                input_size=conv_kernel_size,
+                output_size=self.conv_dim,
+                bias=use_conv_bias,
+                quant_config=None,
+                prefix=f"{prefix}.conv1d",
+            )
+
+            self.in_proj = ColumnParallelLinear(
+                input_size=hidden_size,
+                output_size=intermediate_size + self.conv_dim + self.num_heads,
+                bias=use_bias,
+                quant_config=quant_config,
+                prefix=f"{prefix}.in_proj",
+            )
+
         # - because in_proj is a concatenation of 3 weights, we
         #   need to interleave them before sharding
         # - use the custom weight loader mamba_v2_sharded_weight_loader
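
With this change the mixer no longer receives `ssm_state_size`, `conv_kernel_size`, `intermediate_size`, `num_heads`, `head_dim`, `chunk_size`, or `layer_id` as separate arguments; it reads them from `cache_params.shape`. The stand-in dataclasses below only mirror the fields this diff shows being accessed (the real `Mamba2CacheParams` lives in the new `sglang/srt/configs/mamba_utils.py`, which is not part of this hunk), so treat this as a sketch of the expected shape object rather than the actual class:

```python
from dataclasses import dataclass


@dataclass
class ShapeStandIn:            # hypothetical stand-in, not the sglang class
    num_heads: int
    head_dim: int
    ssm_state_size: int
    conv_kernel: int
    intermediate_size: int
    conv_dim: int


@dataclass
class CacheParamsStandIn:      # hypothetical stand-in for Mamba2CacheParams
    shape: ShapeStandIn


# values reuse the old signature's defaults (num_heads=128, head_dim=64); with
# n_groups == 1, conv_dim = intermediate_size + 2 * n_groups * ssm_state_size,
# matching the conv1d output_sizes in the hunk above
params = CacheParamsStandIn(
    ShapeStandIn(
        num_heads=128,
        head_dim=64,
        ssm_state_size=128,
        conv_kernel=4,
        intermediate_size=128 * 64,
        conv_dim=128 * 64 + 2 * 128,
    )
)
# MambaMixer2(cache_params=params, hidden_size=..., use_conv_bias=..., use_bias=...)
```
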
@@ -421,47 +338,27 @@ class MambaMixer2(torch.nn.Module):
             intermediate_size, n_groups, self.use_rms_norm, eps=rms_norm_eps
         )
 
-        # The tuple is (conv_state, ssm_state)
-        self.kv_cache = (torch.tensor([]), torch.tensor([]))
-
-        self.model_config = model_config
         self.prefix = prefix
 
-    def forward_native(
-        self,
-        hidden_states: torch.Tensor,
-        output: torch.Tensor,
-        mup_vector: Optional[torch.Tensor] = None,
-    ):
-        pass
-
     def forward(
         self,
+        *,
         hidden_states: torch.Tensor,
        output: torch.Tensor,
-
+        layer_cache: MambaPool.State,
+        metadata: Mamba2Metadata,
         mup_vector: Optional[torch.Tensor] = None,
+        use_triton_causal_conv: bool = False,
     ):
-        #
-
-
-
-
-
-
-            self.layer_id
-        )
+        # metadata contains metadata necessary for the mamba2 triton
+        # kernels to operate in continuous batching and in chunked prefill
+        # modes; they are computed at top-level model forward since they
+        # stay the same and reused for all mamba layers in the same iteration
+        state_indices_tensor = metadata.mamba_cache_indices
+        conv_state = layer_cache.conv
+        ssm_state = layer_cache.temporal
 
-
-            ssm_state.size(1) == self.ssm_state_size
-        ), f"dstate must be {self.ssm_state_size}, got {ssm_state.size(1)}"
-
-        query_start_loc = attn_metadata.query_start_loc
-
-        chunk_size = self.chunk_size
-
-        # TODO: properly support this
-        prep_initial_states = False
+        query_start_loc = metadata.query_start_loc
 
         # 1. Gated MLP's linear projection
         projected_states, _ = self.in_proj(hidden_states)
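
`metadata.query_start_loc`, used throughout the new forward path, follows the usual varlen convention: cumulative per-request token offsets into the flattened prefill batch, so request `i` owns tokens `query_start_loc[i]:query_start_loc[i+1]`. A small self-contained illustration with made-up sequence lengths:

```python
import torch

extend_seq_lens = torch.tensor([3, 5, 2])           # three prefill requests
query_start_loc = torch.cat(
    [torch.zeros(1, dtype=torch.long), extend_seq_lens.cumsum(0)]
)                                                    # -> [0, 3, 8, 10]
tokens = torch.arange(10)
req1 = tokens[query_start_loc[1] : query_start_loc[2]]
assert req1.tolist() == [3, 4, 5, 6, 7]              # tokens owned by request 1
```
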
@@ -493,6 +390,38 @@ class MambaMixer2(torch.nn.Module):
             dim=-1,
         )
 
+        num_prefills = metadata.num_prefills  # request count
+        num_decodes = metadata.num_decodes  # token count (=request)
+        num_prefill_tokens = metadata.num_prefill_tokens  # token count
+        has_prefill = num_prefills > 0
+        has_decode = num_decodes > 0
+        num_actual_tokens = num_prefill_tokens + num_decodes
+        assert num_actual_tokens == projected_states.shape[0]
+
+        # NOTE: V0 put prefill before decode
+        # Separate prefill and decode by splitting varlen input
+        # Split along token dimension
+        hidden_states_B_C_p, hidden_states_B_C_d = torch.split(
+            hidden_states_B_C,
+            [num_prefill_tokens, num_decodes],
+            dim=0,
+        )
+        dt_p, dt_d = torch.split(
+            dt,
+            [num_prefill_tokens, num_decodes],
+            dim=0,
+        )
+        # Split along batch dimension
+        state_indices_tensor_p, state_indices_tensor_d = torch.split(
+            state_indices_tensor,
+            [num_prefills, num_decodes],
+            dim=0,
+        )
+        query_start_loc_p = query_start_loc[: num_prefills + 1] if has_prefill else None
+
+        # Preallocate output tensor to avoid memcpy cost for merging prefill
+        # and decode outputs
+
         preallocated_ssm_out = torch.empty(
             [
                 projected_states.shape[0],
@@ -501,128 +430,147 @@ class MambaMixer2(torch.nn.Module):
             dtype=hidden_states.dtype,
             device=hidden_states.device,
         )
+        preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split(
+            preallocated_ssm_out,
+            [num_prefill_tokens, num_decodes],
+            dim=0,
+        )
 
         # Process prefill requests
-        if
+        if has_prefill:
+            mixed_metadata = metadata.mixed_metadata
+            assert mixed_metadata is not None
             # 2. Convolution sequence transformation
             # - "cache_indices" updates the conv_state cache in positions
             #   pointed to by "state_indices_tensor"
-
-
-            cache_indices =
-
-            x = hidden_states_B_C.transpose(
+            has_initial_states_p = mixed_metadata.has_initial_states
+            prep_initial_states = mixed_metadata.prep_initial_states
+            cache_indices = state_indices_tensor_p
+            x = hidden_states_B_C_p.transpose(
                 0, 1
             )  # this is the form that causal-conv see
-
+            ccfn = (
+                causal_conv1d_fn
+                if not use_triton_causal_conv
+                else causal_conv1d_fn_triton
+            )
+            hidden_states_B_C_p = ccfn(
                 x,
                 conv_weights,
                 self.conv1d.bias,
                 activation=self.activation,
                 conv_states=conv_state,
-                has_initial_state=
+                has_initial_state=has_initial_states_p,
                 cache_indices=cache_indices,
-                query_start_loc=
-                seq_lens_cpu=
-            ).transpose(0, 1)
+                query_start_loc=query_start_loc_p,
+                seq_lens_cpu=mixed_metadata.extend_seq_lens_cpu,
+            ).transpose(0, 1)[:num_prefill_tokens]
 
-
+            hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn(hidden_states_B_C_p)
 
             # 3. State Space Model sequence transformation
             initial_states = None
-
-            if has_initial_states is not None and prep_initial_states:
+            if has_initial_states_p is not None and prep_initial_states:
                 initial_states = torch.where(
-
-                    ssm_state[
+                    has_initial_states_p[:, None, None, None],
+                    ssm_state[state_indices_tensor_p],
                     0,
                 )
 
             # NOTE: final output is an in-place update of out tensor
             varlen_state = mamba_chunk_scan_combined(
-
+                hidden_states_p.view(
                    1, num_prefill_tokens, self.num_heads // self.tp_size, self.head_dim
                ),
-
+                dt_p.unsqueeze(0),
                self.A,
-
-
-                chunk_size=chunk_size,
+                B_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size, -1),
+                C_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size, -1),
+                chunk_size=mixed_metadata.chunk_size,
                D=self.D,
                z=None,
                dt_bias=self.dt_bias,
-
+                seq_idx=mixed_metadata.seq_idx,
+                chunk_indices=mixed_metadata.chunk_indices,
+                chunk_offsets=mixed_metadata.chunk_offsets,
+                cu_seqlens=query_start_loc_p,
                initial_states=initial_states,
                return_varlen_states=True,
                return_final_states=False,
                dt_softplus=True,
                dt_limit=(0.0, float("inf")),
-                out=
+                out=preallocated_ssm_out_p.view(
+                    1, num_prefill_tokens, -1, self.head_dim
+                ),
                state_dtype=ssm_state.dtype,
            )

             # update ssm states
             # - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
-            ssm_state[
-
-
+            ssm_state[state_indices_tensor_p] = varlen_state
+
+        # Process decode requests
+        if has_decode:
             # 2. Convolution sequence transformation
-
-
+            ccu = (
+                causal_conv1d_update
+                if not use_triton_causal_conv
+                else causal_conv1d_update_triton
+            )
+            hidden_states_B_C_d = ccu(
+                hidden_states_B_C_d,
                 conv_state,
                 conv_weights,
                 self.conv1d.bias,
                 self.activation,
-                conv_state_indices=
+                conv_state_indices=state_indices_tensor_d,
             )
 
-
+            hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn(hidden_states_B_C_d)
 
             # 3. State Space Model sequence transformation
             n_groups = self.n_groups // self.tp_size
-
+            A_d = (
                 self.A[:, None, ...][:, :, None]
                 .expand(-1, self.head_dim, self.ssm_state_size)
                 .to(dtype=torch.float32)
             )
-
+            dt_d = dt_d[:, :, None].expand(-1, -1, self.head_dim)
             dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
-
-
-
-
+            D_d = self.D[:, None, ...].expand(-1, self.head_dim)
+            B_d = B_d.view(-1, n_groups, B_d.shape[1] // n_groups)
+            C_d = C_d.view(-1, n_groups, C_d.shape[1] // n_groups)
+            hidden_states_d = hidden_states_d.view(
                 -1, self.num_heads // self.tp_size, self.head_dim
             )
 
             # - the hidden is reshaped into (bs, num_heads, head_dim)
-            # -
+            # - layer_state.ssm_state's slots will be selected
             #   using state_indices_tensor_d
             # NOTE: final output is an in-place update of out tensor
             selective_state_update(
-                ssm_state
-
-
-
-
-
-
+                ssm_state,
+                hidden_states_d,
+                dt_d,
+                A_d,
+                B_d,
+                C_d,
+                D_d,
                 z=None,
                 dt_bias=dt_bias,
                 dt_softplus=True,
-                state_batch_indices=
-                out=
+                state_batch_indices=state_indices_tensor_d,
+                out=preallocated_ssm_out_d.view(num_decodes, -1, self.head_dim),
             )
-        elif forward_batch.forward_mode.is_idle():
-            preallocated_ssm_out = preallocated_ssm_out
 
         # 4. gated MLP
         # GatedRMSNorm internally applying SiLU to the gate
         # SiLU is applied internally before normalization, unlike standard
         # norm usage
-        hidden_states = self.norm(preallocated_ssm_out, gate)
+        hidden_states = self.norm(preallocated_ssm_out, gate[:num_actual_tokens])
 
         # 5. Final linear projection
-        output[:], _ = self.out_proj(hidden_states)
+        output[:num_actual_tokens], _ = self.out_proj(hidden_states)
 
     @property
     def mamba_type(self) -> str: