sglang 0.5.3__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +0 -2
- sglang/bench_serving.py +224 -127
- sglang/compile_deep_gemm.py +3 -0
- sglang/launch_server.py +0 -14
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/falcon_h1.py +12 -58
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +68 -31
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/qwen3_next.py +11 -43
- sglang/srt/disaggregation/decode.py +7 -18
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
- sglang/srt/disaggregation/nixl/conn.py +55 -23
- sglang/srt/disaggregation/prefill.py +17 -32
- sglang/srt/entrypoints/engine.py +2 -2
- sglang/srt/entrypoints/grpc_request_manager.py +10 -23
- sglang/srt/entrypoints/grpc_server.py +220 -80
- sglang/srt/entrypoints/http_server.py +49 -1
- sglang/srt/entrypoints/openai/protocol.py +159 -31
- sglang/srt/entrypoints/openai/serving_chat.py +13 -71
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +4 -0
- sglang/srt/function_call/function_call_parser.py +8 -6
- sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +64 -6
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +88 -0
- sglang/srt/layers/attention/attention_registry.py +31 -22
- sglang/srt/layers/attention/fla/layernorm_gated.py +47 -30
- sglang/srt/layers/attention/flashattention_backend.py +0 -1
- sglang/srt/layers/attention/flashinfer_backend.py +223 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -59
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -4
- sglang/srt/layers/attention/mamba/mamba.py +189 -241
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
- sglang/srt/layers/attention/triton_backend.py +1 -1
- sglang/srt/layers/logits_processor.py +136 -6
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +18 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +31 -452
- sglang/srt/layers/moe/ep_moe/layer.py +8 -286
- sglang/srt/layers/moe/fused_moe_triton/layer.py +6 -11
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/utils.py +7 -1
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/fp8.py +84 -18
- sglang/srt/layers/quantization/modelopt_quant.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/w4afp8.py +2 -16
- sglang/srt/lora/lora_manager.py +0 -8
- sglang/srt/managers/overlap_utils.py +18 -16
- sglang/srt/managers/schedule_batch.py +119 -90
- sglang/srt/managers/schedule_policy.py +1 -1
- sglang/srt/managers/scheduler.py +213 -126
- sglang/srt/managers/scheduler_metrics_mixin.py +1 -1
- sglang/srt/managers/scheduler_output_processor_mixin.py +180 -86
- sglang/srt/managers/tokenizer_manager.py +270 -53
- sglang/srt/managers/tp_worker.py +39 -28
- sglang/srt/mem_cache/allocator.py +7 -2
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +162 -68
- sglang/srt/mem_cache/radix_cache.py +8 -3
- sglang/srt/mem_cache/swa_radix_cache.py +70 -14
- sglang/srt/model_executor/cuda_graph_runner.py +1 -1
- sglang/srt/model_executor/forward_batch_info.py +4 -18
- sglang/srt/model_executor/model_runner.py +55 -51
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +187 -6
- sglang/srt/model_loader/weight_utils.py +3 -0
- sglang/srt/models/falcon_h1.py +11 -9
- sglang/srt/models/gemma3_mm.py +16 -0
- sglang/srt/models/grok.py +5 -13
- sglang/srt/models/mixtral.py +1 -3
- sglang/srt/models/mllama4.py +11 -1
- sglang/srt/models/nemotron_h.py +514 -0
- sglang/srt/models/utils.py +5 -1
- sglang/srt/sampling/sampling_batch_info.py +11 -9
- sglang/srt/server_args.py +100 -33
- sglang/srt/speculative/eagle_worker.py +11 -13
- sglang/srt/speculative/ngram_worker.py +12 -11
- sglang/srt/speculative/spec_utils.py +0 -1
- sglang/srt/two_batch_overlap.py +1 -0
- sglang/srt/utils/common.py +18 -0
- sglang/srt/utils/hf_transformers_utils.py +2 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +40 -0
- sglang/test/simple_eval_longbench_v2.py +332 -0
- sglang/test/test_cutlass_w4a8_moe.py +9 -19
- sglang/test/test_deterministic.py +18 -2
- sglang/test/test_deterministic_utils.py +81 -0
- sglang/test/test_disaggregation_utils.py +63 -0
- sglang/test/test_utils.py +32 -11
- sglang/version.py +1 -1
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +4 -4
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +109 -98
- sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +0 -0
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0

sglang/srt/layers/attention/mamba/mamba2_metadata.py (new file)

@@ -0,0 +1,211 @@
+# Copyright 2025 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# Adapted from https://github.com/vllm-project/vllm/blob/2c58742dff8613a3bd7496f2008ce927e18d38d1/vllm/model_executor/layers/mamba/mamba2_metadata.py
+
+
+import math
+from dataclasses import dataclass
+
+import torch
+
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+
+
+@dataclass(kw_only=True)
+class ForwardMetadata:
+    query_start_loc: torch.Tensor
+    mamba_cache_indices: torch.Tensor
+
+
+@dataclass(kw_only=True)
+class Mamba2Metadata(ForwardMetadata):
+    """stable metadata across all mamba2 layers in the forward pass"""
+
+    num_prefills: int
+    num_prefill_tokens: int
+    num_decodes: int
+
+    @dataclass(kw_only=True, frozen=True)
+    class MixedMetadata:
+        has_initial_states: torch.Tensor
+        prep_initial_states: bool
+
+        chunk_size: int
+        seq_idx: torch.Tensor
+        chunk_indices: torch.Tensor
+        chunk_offsets: torch.Tensor
+
+        extend_seq_lens_cpu: list[int]
+
+    mixed_metadata: MixedMetadata | None = None
+    """`mixed_metadata` is used for extend/mixed requests"""
+
+    @staticmethod
+    def _query_start_loc_to_chunk_indices_offsets(
+        query_start_loc: torch.Tensor, chunk_size: int, total_seqlens: int
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Args:
+            query_start_loc (torch.Tensor): 1D tensor of cumulative sequence
+                lengths, shape (num_seqs + 1,).
+                The first element should be 0. Each entry represents the starting
+                index of a sequence in the flattened token array.
+            chunk_size (int): The size of each physical mamba chunk
+                (number of tokens per chunk).
+            total_seqlens (int): The total number of tokens in the batch.
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+            - chunk_indices (torch.Tensor): 1D tensor of indices
+                indicating the physical chunk for each logical chunk.
+            - chunk_offsets (torch.Tensor): 1D tensor of offsets
+                indicating the starting index of each logical chunk within
+                its physical chunk.
+
+        This function computes the chunk indices and offsets for the given
+        query_start_loc and chunk_size. Both are tensors of integers with length N,
+        where N is the number of logical (pseudo) chunks.
+        A logical chunk is a sequence of tokens that are all part of the same
+        sequence and are all in the same physical mamba chunk.
+        In other words, a logical chunk changes every time we cross a sequence
+        boundary or a physical mamba chunk boundary.
+        Logical chunks are needed to handle batched requests with initial states
+        (see _state_passing_fwd and _chunk_scan_fwd).
+        The chunk_indices tensor contains the index of the physical chunk for each
+        logical chunk.
+        The chunk_offsets tensor contains the offset (AKA starting index) of the
+        logical chunk in the physical chunk.
+
+        Example:
+            query_start_loc = [0, 5, 10]
+            chunk_size = 8
+            total_seqlens = 10
+            -> chunk_indices = [0, 0, 1]
+            -> chunk_offsets = [0, 5, 0]
+
+        In this example, we have 2 sequences, each with 5 tokens. The physical
+        chunk size is 8 tokens.
+        We have three logical chunks:
+        - the first logical chunk starts at token 0 in the first physical chunk
+          and contains all 5 tokens from the first sequence
+        - the second logical chunk starts at token 5 in the first physical chunk
+          and contains first 3 tokens from the second sequence
+        - the third logical chunk starts at token 0 in the second physical chunk
+          and contains the remaining 2 tokens from the second sequence
+        """
+
+        cu_seqlens = query_start_loc[1:]  # remove prepended 0
+
+        # outputs will have length expansion of chunks that do not divide
+        # chunk_size
+        N = (
+            math.ceil(total_seqlens / chunk_size)
+            + (cu_seqlens[:-1] % chunk_size > 0).sum()
+        )
+        chunk_indices = torch.arange(N, dtype=torch.int, device=query_start_loc.device)
+        chunk_offsets = torch.zeros(
+            (N,), dtype=torch.int, device=query_start_loc.device
+        )
+
+        p = 0  # num of insertions
+        for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):
+
+            # if does not divide chunk_size, then there is one chunk insertion
+            p += s % chunk_size > 0
+
+            # get the dimensions
+            # - the + 1 for _e is to shift the boundary by one chunk
+            # - this shifting is not needed if chunk_size divides e
+            _s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size > 0)
+
+            # adjust indices and offsets
+            chunk_indices[_s:_e] -= p
+            chunk_offsets[_s] = s % chunk_size
+
+        return chunk_indices, chunk_offsets
+
+    @staticmethod
+    def prepare_decode(
+        query_start_loc: torch.Tensor,
+        mamba_cache_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+    ) -> "Mamba2Metadata":
+        """This path is run during CUDA graph capture, i.e. decode only, so `num_prefills` is 0"""
+        return Mamba2Metadata(
+            query_start_loc=query_start_loc,
+            mamba_cache_indices=mamba_cache_indices,
+            num_decodes=len(seq_lens),
+            num_prefills=0,
+            num_prefill_tokens=0,
+        )
+
+    @classmethod
+    def prepare_mixed(
+        cls,
+        query_start_loc: torch.Tensor,
+        mamba_cache_indices: torch.Tensor,
+        chunk_size: int,
+        forward_batch: ForwardBatch,
+    ) -> "Mamba2Metadata":
+        """This path cannot run with CUDA graph, as it contains extend requests."""
+        if forward_batch.extend_num_tokens is None:
+            return cls.prepare_decode(
+                query_start_loc, mamba_cache_indices, forward_batch.seq_lens
+            )
+        num_prefills = len(forward_batch.extend_seq_lens)
+        num_prefill_tokens = forward_batch.extend_num_tokens
+        num_decodes = len(forward_batch.seq_lens) - num_prefills
+        context_lens_tensor = forward_batch.extend_prefix_lens
+        assert context_lens_tensor is not None
+        # precompute flag to avoid device syncs later
+        has_initial_states = context_lens_tensor > 0
+        prep_initial_states = torch.any(has_initial_states[:num_prefills]).item()
+
+        query_start_loc = query_start_loc[: num_prefills + 1]
+        seq_idx = torch.repeat_interleave(
+            torch.arange(
+                num_prefills, dtype=torch.int32, device=query_start_loc.device
+            ),
+            query_start_loc.diff(),
+            output_size=num_prefill_tokens,
+        )
+        seq_idx.unsqueeze_(0)
+
+        # We compute metadata for chunked prefill once at the top level model
+        # forward and reuse them in mamba layers. If not needed, they will be
+        # ignored inside mamba kernels.
+        chunk_offsets, chunk_indices = None, None
+        if prep_initial_states:
+            chunk_indices, chunk_offsets = (
+                cls._query_start_loc_to_chunk_indices_offsets(
+                    query_start_loc, chunk_size, num_prefill_tokens
+                )
+            )
+
+        return Mamba2Metadata(
+            query_start_loc=query_start_loc,
+            mamba_cache_indices=mamba_cache_indices,
+            num_prefills=num_prefills,
+            num_prefill_tokens=num_prefill_tokens,
+            num_decodes=num_decodes,
+            mixed_metadata=cls.MixedMetadata(
+                has_initial_states=has_initial_states,
+                prep_initial_states=prep_initial_states,
+                chunk_size=chunk_size,
+                seq_idx=seq_idx,
+                chunk_indices=chunk_indices,
+                chunk_offsets=chunk_offsets,
+                extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
+            ),
+        )

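The worked example in the docstring above can be checked with a small standalone sketch of the same chunk-index/offset computation. This is plain PyTorch outside of SGLang; the function name and the driver values are taken only from the docstring example and are illustrative, not an SGLang entry point.

```python
import math

import torch


def chunk_indices_offsets(query_start_loc, chunk_size, total_seqlens):
    """Standalone re-implementation of the logic above, for checking the example."""
    cu_seqlens = query_start_loc[1:]  # drop the leading 0
    # One extra logical chunk for every sequence start that splits a physical chunk.
    n = math.ceil(total_seqlens / chunk_size) + int(
        (cu_seqlens[:-1] % chunk_size > 0).sum()
    )
    chunk_indices = torch.arange(n, dtype=torch.int)
    chunk_offsets = torch.zeros(n, dtype=torch.int)
    p = 0  # number of inserted logical chunks so far
    for s, e in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
        p += s % chunk_size > 0
        _s = s // chunk_size + p
        _e = e // chunk_size + p + (e % chunk_size > 0)
        chunk_indices[_s:_e] -= p  # map logical chunks back to their physical chunk
        chunk_offsets[_s] = s % chunk_size  # where this sequence starts inside it
    return chunk_indices, chunk_offsets


# Example from the docstring: two 5-token sequences, physical chunk size 8.
idx, off = chunk_indices_offsets(torch.tensor([0, 5, 10]), chunk_size=8, total_seqlens=10)
print(idx.tolist(), off.tolist())  # [0, 0, 1] [0, 5, 0]
```

The offsets record where each sequence starts inside a physical chunk, which is what lets the chunk-scan kernels carry per-sequence initial states across chunk boundaries in a batched prefill.
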
sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py (new file)

@@ -0,0 +1,120 @@
+from typing import Union
+
+import torch
+
+from sglang.srt.custom_op import CustomOp
+from sglang.srt.distributed.communication_op import (
+    tensor_model_parallel_all_gather,
+    tensor_model_parallel_all_reduce,
+)
+from sglang.srt.distributed.parallel_state import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
+from sglang.srt.layers.attention.fla.layernorm_gated import rms_norm_gated
+from sglang.srt.model_loader.weight_utils import sharded_weight_loader
+from sglang.srt.utils.common import set_weight_attrs
+
+
+class Mixer2RMSNormGated(CustomOp):
+    def __init__(
+        self,
+        full_hidden_size: int,
+        full_n_groups: int,
+        use_rms_norm: bool = True,
+        eps: float = 1e-6,
+    ):
+        super().__init__()
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.full_hidden_size = full_hidden_size
+        self.group_size = full_hidden_size // full_n_groups
+        self.per_rank_hidden_size = full_hidden_size // self.tp_size
+        self.n_groups = full_hidden_size // self.group_size
+
+        self.variance_epsilon = eps
+        self.use_rms_norm = use_rms_norm
+        if self.use_rms_norm:
+            # Register norm weight only if we're actually applying RMSNorm
+            self.weight = torch.nn.Parameter(torch.ones(self.per_rank_hidden_size))
+            set_weight_attrs(self.weight, {"weight_loader": sharded_weight_loader(0)})
+        else:
+            # Avoid checkpoint mismatch by skipping unused parameter
+            self.register_parameter("weight", None)
+        assert (
+            self.full_hidden_size % self.tp_size == 0
+        ), "Tensor parallel world size must divide hidden size."
+
+    def forward_native(
+        self,
+        x: torch.Tensor,
+        gate: torch.Tensor,
+    ):
+        # Three tensor-parallel cases:
+        #   1. n_groups is 1
+        #      In this case we parallelize along the reduction dim.
+        #      Each rank computes a local sum of squares followed by AllReduce
+        #   2. tp_size divides n_groups
+        #      Each rank only reduces within its local group(s).
+        #      No collective ops necessary.
+        #   3. The general case can be pretty complicated so we AllGather
+        #      the input and then redundantly compute the RMSNorm.
+        input_dtype = x.dtype
+        x = x * torch.nn.functional.silu(gate.to(torch.float32))
+        if not self.use_rms_norm:
+            return x.to(input_dtype)
+
+        if self.n_groups == 1:
+            if self.tp_size > 1:
+                # Compute local sum and then reduce to obtain global sum
+                local_sums = x.pow(2).sum(dim=-1, keepdim=True)
+                global_sums = tensor_model_parallel_all_reduce(local_sums)
+                # Calculate the variance
+                count = self.tp_size * x.shape[-1]
+                variance = global_sums / count
+
+            else:
+                variance = x.pow(2).mean(-1, keepdim=True)
+            x = x * torch.rsqrt(variance + self.variance_epsilon)
+        else:
+            redundant_tp: bool = self.n_groups % self.tp_size != 0
+            if redundant_tp:
+                # To handle the general case, redundantly apply the variance
+                x = tensor_model_parallel_all_gather(x, -1)
+
+            *prefix_dims, hidden_dim = x.shape
+            group_count = hidden_dim // self.group_size
+            x_grouped = x.view(*prefix_dims, group_count, self.group_size)
+            variance = x_grouped.pow(2).mean(-1, keepdim=True)
+            x_grouped = x_grouped * torch.rsqrt(variance + self.variance_epsilon)
+            x = x_grouped.view(*prefix_dims, hidden_dim)
+
+            if redundant_tp:
+                start = self.per_rank_hidden_size * self.tp_rank
+                end = start + self.per_rank_hidden_size
+                x = x[..., start:end]
+
+        return self.weight * x.to(input_dtype)
+
+    def forward_cuda(
+        self,
+        x: torch.Tensor,
+        gate: torch.Tensor,
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        input_dtype = x.dtype
+        if not self.use_rms_norm:
+            # Keep gate in float32 for numerical stability during silu
+            return x * torch.nn.functional.silu(gate.to(torch.float32)).to(input_dtype)
+
+        if ((self.n_groups % self.tp_size) != 0) or self.n_groups != 1:
+            return self.forward_native(x, gate)
+
+        return rms_norm_gated(
+            x=x,
+            weight=self.weight.data,
+            bias=None,
+            z=gate,
+            eps=self.variance_epsilon,
+            norm_before_gate=False,
+            is_rms_norm=True,
+        )

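As a rough illustration of what `forward_native` computes on a single device (no collectives, `use_rms_norm=True`), here is a minimal sketch of the gated, group-wise RMSNorm. The sizes and the reference function are made up for the example and are not part of SGLang.

```python
import torch

# Assumed example sizes; the real values come from the model config.
hidden_size, n_groups, eps = 256, 4, 1e-6
group_size = hidden_size // n_groups
weight = torch.ones(hidden_size)  # stands in for the learned norm weight


def mixer2_rms_norm_gated_ref(x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
    input_dtype = x.dtype
    # Gate is kept in float32 for numerical stability, as in the module above.
    x = x * torch.nn.functional.silu(gate.to(torch.float32))
    *prefix_dims, hidden_dim = x.shape
    # Group-wise RMS normalization over chunks of size group_size.
    xg = x.view(*prefix_dims, n_groups, group_size)
    variance = xg.pow(2).mean(-1, keepdim=True)
    xg = xg * torch.rsqrt(variance + eps)
    return weight * xg.view(*prefix_dims, hidden_dim).to(input_dtype)


x = torch.randn(2, 16, hidden_size, dtype=torch.float16)
gate = torch.randn_like(x)
print(mixer2_rms_norm_gated_ref(x, gate).shape)  # torch.Size([2, 16, 256])
```

In the `forward_cuda` path shown above, the fused `rms_norm_gated` kernel is used only when `n_groups` is 1 and divisible by the TP size; every other layout falls back to `forward_native`.
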
sglang/srt/layers/attention/mamba/ops/ssd_bmm.py

@@ -15,56 +15,6 @@ import triton
 import triton.language as tl


-# @triton.autotune(
-#     configs=[
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
-#             num_stages=3,
-#             num_warps=8,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
-#             num_stages=4,
-#             num_warps=4,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
-#             num_stages=4,
-#             num_warps=4,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
-#             num_stages=4,
-#             num_warps=4,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
-#             num_stages=4,
-#             num_warps=4,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
-#             num_stages=4,
-#             num_warps=4,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
-#             num_stages=5,
-#             num_warps=2,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
-#             num_stages=5,
-#             num_warps=2,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
-#             num_stages=4,
-#             num_warps=2,
-#         ),
-#     ],
-#     key=["chunk_size", "K", "IS_CAUSAL"],
-# )
 @triton.jit
 def _bmm_chunk_fwd_kernel(
     # Pointers to matrices

sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py

@@ -16,66 +16,6 @@ from packaging import version
 TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0")


-# @triton.autotune(
-#     configs=[
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
-#             num_stages=3,
-#             num_warps=8,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
-#             num_stages=4,
-#             num_warps=4,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
-#             num_stages=4,
-#             num_warps=4,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
-#             num_stages=4,
-#             num_warps=4,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
-#             num_stages=4,
-#             num_warps=4,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64},
-#             num_stages=4,
-#             num_warps=4,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
-#             num_stages=4,
-#             num_warps=4,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
-#             num_stages=4,
-#             num_warps=4,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
-#             num_stages=5,
-#             num_warps=2,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
-#             num_stages=5,
-#             num_warps=2,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
-#             num_stages=4,
-#             num_warps=2,
-#         ),
-#     ],
-#     key=["chunk_size", "hdim", "dstate", "IS_CAUSAL"],
-# )
 @triton.jit
 def _chunk_scan_fwd_kernel(
     # Pointers to matrices

sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py

@@ -17,17 +17,6 @@ import triton.language as tl
 from .mamba_ssm import softplus


-# @triton.autotune(
-#     configs=[
-#         triton.Config({"BLOCK_SIZE_H": 2}),
-#         triton.Config({"BLOCK_SIZE_H": 4}),
-#         triton.Config({"BLOCK_SIZE_H": 8}),
-#         triton.Config({"BLOCK_SIZE_H": 16}),
-#         triton.Config({"BLOCK_SIZE_H": 32}),
-#         triton.Config({"BLOCK_SIZE_H": 64}),
-#     ],
-#     key=["chunk_size", "nheads"],
-# )
 @triton.jit
 def _chunk_cumsum_fwd_kernel(
     # Pointers to matrices

@@ -120,56 +109,6 @@ def _chunk_cumsum_fwd_kernel(
     )


-# @triton.autotune(
-#     configs=[
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
-#             num_stages=3,
-#             num_warps=8,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
-#             num_stages=4,
-#             num_warps=4,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
-#             num_stages=4,
-#             num_warps=4,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
-#             num_stages=4,
-#             num_warps=4,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
-#             num_stages=4,
-#             num_warps=4,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
-#             num_stages=4,
-#             num_warps=4,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
-#             num_stages=5,
-#             num_warps=2,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
-#             num_stages=5,
-#             num_warps=2,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
-#             num_stages=4,
-#             num_warps=2,
-#         ),
-#     ],
-#     key=["hdim", "dstate", "chunk_size"],
-# )
 @triton.jit
 def _chunk_state_fwd_kernel(
     # Pointers to matrices

@@ -320,56 +259,6 @@ def _chunk_state_fwd_kernel(
     tl.store(states_ptrs, states, mask=c_mask)


-# @triton.autotune(
-#     configs=[
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
-#             num_stages=3,
-#             num_warps=8,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
-#             num_stages=4,
-#             num_warps=4,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
-#             num_stages=4,
-#             num_warps=4,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
-#             num_stages=4,
-#             num_warps=4,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
-#             num_stages=4,
-#             num_warps=4,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
-#             num_stages=4,
-#             num_warps=4,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
-#             num_stages=5,
-#             num_warps=2,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
-#             num_stages=5,
-#             num_warps=2,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
-#             num_stages=4,
-#             num_warps=2,
-#         ),
-#     ],
-#     key=["hdim", "dstate", "chunk_size"],
-# )
 @triton.jit
 def _chunk_state_varlen_kernel(
     # Pointers to matrices

sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py

@@ -13,17 +13,6 @@ import triton
 import triton.language as tl


-# @triton.autotune(
-#     configs=[
-#         triton.Config({"BLOCK_SIZE": 64}),
-#         triton.Config({"BLOCK_SIZE": 128}),
-#         triton.Config({"BLOCK_SIZE": 256}),
-#         triton.Config({"BLOCK_SIZE": 512}),
-#         triton.Config({"BLOCK_SIZE": 1024}),
-#         triton.Config({"BLOCK_SIZE": 2048}),
-#     ],
-#     key=["dim"],
-# )
 @triton.jit
 def _state_passing_fwd_kernel(
     # Pointers to matrices

sglang/srt/layers/attention/triton_backend.py

@@ -85,7 +85,7 @@ class TritonAttnBackend(AttentionBackend):
         self.num_kv_head = model_runner.model_config.get_num_kv_heads(
             get_attention_tp_size()
         )
-        if model_runner.
+        if model_runner.hybrid_gdn_config is not None:
             # For hybrid linear models, layer_id = 0 may not be full attention
             self.v_head_dim = model_runner.token_to_kv_pool.get_v_head_dim()
         else: