sglang 0.5.3__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. sglang/bench_one_batch.py +0 -2
  2. sglang/bench_serving.py +224 -127
  3. sglang/compile_deep_gemm.py +3 -0
  4. sglang/launch_server.py +0 -14
  5. sglang/srt/configs/__init__.py +2 -0
  6. sglang/srt/configs/falcon_h1.py +12 -58
  7. sglang/srt/configs/mamba_utils.py +117 -0
  8. sglang/srt/configs/model_config.py +68 -31
  9. sglang/srt/configs/nemotron_h.py +286 -0
  10. sglang/srt/configs/qwen3_next.py +11 -43
  11. sglang/srt/disaggregation/decode.py +7 -18
  12. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  13. sglang/srt/disaggregation/nixl/conn.py +55 -23
  14. sglang/srt/disaggregation/prefill.py +17 -32
  15. sglang/srt/entrypoints/engine.py +2 -2
  16. sglang/srt/entrypoints/grpc_request_manager.py +10 -23
  17. sglang/srt/entrypoints/grpc_server.py +220 -80
  18. sglang/srt/entrypoints/http_server.py +49 -1
  19. sglang/srt/entrypoints/openai/protocol.py +159 -31
  20. sglang/srt/entrypoints/openai/serving_chat.py +13 -71
  21. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  22. sglang/srt/environ.py +4 -0
  23. sglang/srt/function_call/function_call_parser.py +8 -6
  24. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  25. sglang/srt/grpc/sglang_scheduler_pb2.pyi +64 -6
  26. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +88 -0
  27. sglang/srt/layers/attention/attention_registry.py +31 -22
  28. sglang/srt/layers/attention/fla/layernorm_gated.py +47 -30
  29. sglang/srt/layers/attention/flashattention_backend.py +0 -1
  30. sglang/srt/layers/attention/flashinfer_backend.py +223 -6
  31. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -1
  32. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -59
  33. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  34. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -4
  35. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  36. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  37. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  38. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  39. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  40. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  41. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  42. sglang/srt/layers/attention/triton_backend.py +1 -1
  43. sglang/srt/layers/logits_processor.py +136 -6
  44. sglang/srt/layers/modelopt_utils.py +11 -0
  45. sglang/srt/layers/moe/cutlass_w4a8_moe.py +18 -21
  46. sglang/srt/layers/moe/ep_moe/kernels.py +31 -452
  47. sglang/srt/layers/moe/ep_moe/layer.py +8 -286
  48. sglang/srt/layers/moe/fused_moe_triton/layer.py +6 -11
  49. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  50. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  51. sglang/srt/layers/moe/utils.py +7 -1
  52. sglang/srt/layers/quantization/__init__.py +1 -1
  53. sglang/srt/layers/quantization/fp8.py +84 -18
  54. sglang/srt/layers/quantization/modelopt_quant.py +1 -1
  55. sglang/srt/layers/quantization/quark/quark.py +3 -1
  56. sglang/srt/layers/quantization/w4afp8.py +2 -16
  57. sglang/srt/lora/lora_manager.py +0 -8
  58. sglang/srt/managers/overlap_utils.py +18 -16
  59. sglang/srt/managers/schedule_batch.py +119 -90
  60. sglang/srt/managers/schedule_policy.py +1 -1
  61. sglang/srt/managers/scheduler.py +213 -126
  62. sglang/srt/managers/scheduler_metrics_mixin.py +1 -1
  63. sglang/srt/managers/scheduler_output_processor_mixin.py +180 -86
  64. sglang/srt/managers/tokenizer_manager.py +270 -53
  65. sglang/srt/managers/tp_worker.py +39 -28
  66. sglang/srt/mem_cache/allocator.py +7 -2
  67. sglang/srt/mem_cache/chunk_cache.py +1 -1
  68. sglang/srt/mem_cache/memory_pool.py +162 -68
  69. sglang/srt/mem_cache/radix_cache.py +8 -3
  70. sglang/srt/mem_cache/swa_radix_cache.py +70 -14
  71. sglang/srt/model_executor/cuda_graph_runner.py +1 -1
  72. sglang/srt/model_executor/forward_batch_info.py +4 -18
  73. sglang/srt/model_executor/model_runner.py +55 -51
  74. sglang/srt/model_loader/__init__.py +1 -1
  75. sglang/srt/model_loader/loader.py +187 -6
  76. sglang/srt/model_loader/weight_utils.py +3 -0
  77. sglang/srt/models/falcon_h1.py +11 -9
  78. sglang/srt/models/gemma3_mm.py +16 -0
  79. sglang/srt/models/grok.py +5 -13
  80. sglang/srt/models/mixtral.py +1 -3
  81. sglang/srt/models/mllama4.py +11 -1
  82. sglang/srt/models/nemotron_h.py +514 -0
  83. sglang/srt/models/utils.py +5 -1
  84. sglang/srt/sampling/sampling_batch_info.py +11 -9
  85. sglang/srt/server_args.py +100 -33
  86. sglang/srt/speculative/eagle_worker.py +11 -13
  87. sglang/srt/speculative/ngram_worker.py +12 -11
  88. sglang/srt/speculative/spec_utils.py +0 -1
  89. sglang/srt/two_batch_overlap.py +1 -0
  90. sglang/srt/utils/common.py +18 -0
  91. sglang/srt/utils/hf_transformers_utils.py +2 -0
  92. sglang/test/longbench_v2/__init__.py +1 -0
  93. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  94. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  95. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  96. sglang/test/run_eval.py +40 -0
  97. sglang/test/simple_eval_longbench_v2.py +332 -0
  98. sglang/test/test_cutlass_w4a8_moe.py +9 -19
  99. sglang/test/test_deterministic.py +18 -2
  100. sglang/test/test_deterministic_utils.py +81 -0
  101. sglang/test/test_disaggregation_utils.py +63 -0
  102. sglang/test/test_utils.py +32 -11
  103. sglang/version.py +1 -1
  104. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +4 -4
  105. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +109 -98
  106. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  107. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  108. sglang/test/test_block_fp8_ep.py +0 -358
  109. /sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +0 -0
  110. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
  111. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
  112. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/mamba/mamba2_metadata.py (new file)
@@ -0,0 +1,211 @@
+# Copyright 2025 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# Adapted from https://github.com/vllm-project/vllm/blob/2c58742dff8613a3bd7496f2008ce927e18d38d1/vllm/model_executor/layers/mamba/mamba2_metadata.py
+
+
+import math
+from dataclasses import dataclass
+
+import torch
+
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+
+
+@dataclass(kw_only=True)
+class ForwardMetadata:
+    query_start_loc: torch.Tensor
+    mamba_cache_indices: torch.Tensor
+
+
+@dataclass(kw_only=True)
+class Mamba2Metadata(ForwardMetadata):
+    """stable metadata across all mamba2 layers in the forward pass"""
+
+    num_prefills: int
+    num_prefill_tokens: int
+    num_decodes: int
+
+    @dataclass(kw_only=True, frozen=True)
+    class MixedMetadata:
+        has_initial_states: torch.Tensor
+        prep_initial_states: bool
+
+        chunk_size: int
+        seq_idx: torch.Tensor
+        chunk_indices: torch.Tensor
+        chunk_offsets: torch.Tensor
+
+        extend_seq_lens_cpu: list[int]
+
+    mixed_metadata: MixedMetadata | None = None
+    """`mixed_metadata` is used for extend/mixed requests"""
+
+    @staticmethod
+    def _query_start_loc_to_chunk_indices_offsets(
+        query_start_loc: torch.Tensor, chunk_size: int, total_seqlens: int
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Args:
+            query_start_loc (torch.Tensor): 1D tensor of cumulative sequence
+                lengths, shape (num_seqs + 1,).
+                The first element should be 0. Each entry represents the starting
+                index of a sequence in the flattened token array.
+            chunk_size (int): The size of each physical mamba chunk
+                (number of tokens per chunk).
+            total_seqlens (int): The total number of tokens in the batch.
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+                - chunk_indices (torch.Tensor): 1D tensor of indices
+                  indicating the physical chunk for each logical chunk.
+                - chunk_offsets (torch.Tensor): 1D tensor of offsets
+                  indicating the starting index of each logical chunk within
+                  its physical chunk.
+
+        This function computes the chunk indices and offsets for the given
+        query_start_loc and chunk_size. Both are tensors of integers with length N,
+        where N is the number of logical (pseudo) chunks.
+        A logical chunk is a sequence of tokens that are all part of the same
+        sequence and are all in the same physical mamba chunk.
+        In other words, a logical chunk changes every time we cross a sequence
+        boundary or a physical mamba chunk boundary.
+        Logical chunks are needed to handle batched requests with initial states
+        (see _state_passing_fwd and _chunk_scan_fwd).
+        The chunk_indices tensor contains the index of the physical chunk for each
+        logical chunk.
+        The chunk_offsets tensor contains the offset (AKA starting index) of the
+        logical chunk in the physical chunk.
+
+        Example:
+            query_start_loc = [0, 5, 10]
+            chunk_size = 8
+            total_seqlens = 10
+            -> chunk_indices = [0, 0, 1]
+            -> chunk_offsets = [0, 5, 0]
+
+        In this example, we have 2 sequences, each with 5 tokens. The physical
+        chunk size is 8 tokens.
+        We have three logical chunks:
+        - the first logical chunk starts at token 0 in the first physical chunk
+          and contains all 5 tokens from the first sequence
+        - the second logical chunk starts at token 5 in the first physical chunk
+          and contains first 3 tokens from the second sequence
+        - the third logical chunk starts at token 0 in the second physical chunk
+          and contains the remaining 2 tokens from the second sequence
+        """
+
+        cu_seqlens = query_start_loc[1:]  # remove prepended 0
+
+        # outputs will have length expansion of chunks that do not divide
+        # chunk_size
+        N = (
+            math.ceil(total_seqlens / chunk_size)
+            + (cu_seqlens[:-1] % chunk_size > 0).sum()
+        )
+        chunk_indices = torch.arange(N, dtype=torch.int, device=query_start_loc.device)
+        chunk_offsets = torch.zeros(
+            (N,), dtype=torch.int, device=query_start_loc.device
+        )
+
+        p = 0  # num of insertions
+        for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):
+
+            # if does not divide chunk_size, then there is one chunk insertion
+            p += s % chunk_size > 0
+
+            # get the dimensions
+            # - the + 1 for _e is to shift the boundary by one chunk
+            # - this shifting is not needed if chunk_size divides e
+            _s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size > 0)
+
+            # adjust indices and offsets
+            chunk_indices[_s:_e] -= p
+            chunk_offsets[_s] = s % chunk_size
+
+        return chunk_indices, chunk_offsets
+
+    @staticmethod
+    def prepare_decode(
+        query_start_loc: torch.Tensor,
+        mamba_cache_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+    ) -> "Mamba2Metadata":
+        """This path is run during CUDA graph capture, i.e. decode only, so `num_prefills` is 0"""
+        return Mamba2Metadata(
+            query_start_loc=query_start_loc,
+            mamba_cache_indices=mamba_cache_indices,
+            num_decodes=len(seq_lens),
+            num_prefills=0,
+            num_prefill_tokens=0,
+        )
+
+    @classmethod
+    def prepare_mixed(
+        cls,
+        query_start_loc: torch.Tensor,
+        mamba_cache_indices: torch.Tensor,
+        chunk_size: int,
+        forward_batch: ForwardBatch,
+    ) -> "Mamba2Metadata":
+        """This path cannot run with CUDA graph, as it contains extend requests."""
+        if forward_batch.extend_num_tokens is None:
+            return cls.prepare_decode(
+                query_start_loc, mamba_cache_indices, forward_batch.seq_lens
+            )
+        num_prefills = len(forward_batch.extend_seq_lens)
+        num_prefill_tokens = forward_batch.extend_num_tokens
+        num_decodes = len(forward_batch.seq_lens) - num_prefills
+        context_lens_tensor = forward_batch.extend_prefix_lens
+        assert context_lens_tensor is not None
+        # precompute flag to avoid device syncs later
+        has_initial_states = context_lens_tensor > 0
+        prep_initial_states = torch.any(has_initial_states[:num_prefills]).item()
+
+        query_start_loc = query_start_loc[: num_prefills + 1]
+        seq_idx = torch.repeat_interleave(
+            torch.arange(
+                num_prefills, dtype=torch.int32, device=query_start_loc.device
+            ),
+            query_start_loc.diff(),
+            output_size=num_prefill_tokens,
+        )
+        seq_idx.unsqueeze_(0)
+
+        # We compute metadata for chunked prefill once at the top level model
+        # forward and reuse them in mamba layers. If not needed, they will be
+        # ignored inside mamba kernels.
+        chunk_offsets, chunk_indices = None, None
+        if prep_initial_states:
+            chunk_indices, chunk_offsets = (
+                cls._query_start_loc_to_chunk_indices_offsets(
+                    query_start_loc, chunk_size, num_prefill_tokens
+                )
+            )
+
+        return Mamba2Metadata(
+            query_start_loc=query_start_loc,
+            mamba_cache_indices=mamba_cache_indices,
+            num_prefills=num_prefills,
+            num_prefill_tokens=num_prefill_tokens,
+            num_decodes=num_decodes,
+            mixed_metadata=cls.MixedMetadata(
+                has_initial_states=has_initial_states,
+                prep_initial_states=prep_initial_states,
+                chunk_size=chunk_size,
+                seq_idx=seq_idx,
+                chunk_indices=chunk_indices,
+                chunk_offsets=chunk_offsets,
+                extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
+            ),
+        )
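The docstring example in _query_start_loc_to_chunk_indices_offsets can be reproduced directly. The snippet below is an illustrative check only (not part of the wheel); it assumes the module path given in the file list above and that the released package is installed.

import torch

from sglang.srt.layers.attention.mamba.mamba2_metadata import Mamba2Metadata

# Two sequences of 5 tokens each, physical mamba chunk size of 8 tokens,
# matching the docstring example above.
query_start_loc = torch.tensor([0, 5, 10], dtype=torch.int32)
chunk_indices, chunk_offsets = (
    Mamba2Metadata._query_start_loc_to_chunk_indices_offsets(
        query_start_loc, chunk_size=8, total_seqlens=10
    )
)
print(chunk_indices.tolist())  # expected per the docstring: [0, 0, 1]
print(chunk_offsets.tolist())  # expected per the docstring: [0, 5, 0]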
sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py (new file)
@@ -0,0 +1,120 @@
+from typing import Union
+
+import torch
+
+from sglang.srt.custom_op import CustomOp
+from sglang.srt.distributed.communication_op import (
+    tensor_model_parallel_all_gather,
+    tensor_model_parallel_all_reduce,
+)
+from sglang.srt.distributed.parallel_state import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
+from sglang.srt.layers.attention.fla.layernorm_gated import rms_norm_gated
+from sglang.srt.model_loader.weight_utils import sharded_weight_loader
+from sglang.srt.utils.common import set_weight_attrs
+
+
+class Mixer2RMSNormGated(CustomOp):
+    def __init__(
+        self,
+        full_hidden_size: int,
+        full_n_groups: int,
+        use_rms_norm: bool = True,
+        eps: float = 1e-6,
+    ):
+        super().__init__()
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.full_hidden_size = full_hidden_size
+        self.group_size = full_hidden_size // full_n_groups
+        self.per_rank_hidden_size = full_hidden_size // self.tp_size
+        self.n_groups = full_hidden_size // self.group_size
+
+        self.variance_epsilon = eps
+        self.use_rms_norm = use_rms_norm
+        if self.use_rms_norm:
+            # Register norm weight only if we're actually applying RMSNorm
+            self.weight = torch.nn.Parameter(torch.ones(self.per_rank_hidden_size))
+            set_weight_attrs(self.weight, {"weight_loader": sharded_weight_loader(0)})
+        else:
+            # Avoid checkpoint mismatch by skipping unused parameter
+            self.register_parameter("weight", None)
+        assert (
+            self.full_hidden_size % self.tp_size == 0
+        ), "Tensor parallel world size must divide hidden size."
+
+    def forward_native(
+        self,
+        x: torch.Tensor,
+        gate: torch.Tensor,
+    ):
+        # Three tensor-parallel cases:
+        #   1. n_groups is 1
+        #      In this case we parallelize along the reduction dim.
+        #      Each rank computes a local sum of squares followed by AllReduce
+        #   2. tp_size divides n_groups
+        #      Each rank only reduces within its local group(s).
+        #      No collective ops necessary.
+        #   3. The general case can be pretty complicated so we AllGather
+        #      the input and then redundantly compute the RMSNorm.
+        input_dtype = x.dtype
+        x = x * torch.nn.functional.silu(gate.to(torch.float32))
+        if not self.use_rms_norm:
+            return x.to(input_dtype)
+
+        if self.n_groups == 1:
+            if self.tp_size > 1:
+                # Compute local sum and then reduce to obtain global sum
+                local_sums = x.pow(2).sum(dim=-1, keepdim=True)
+                global_sums = tensor_model_parallel_all_reduce(local_sums)
+                # Calculate the variance
+                count = self.tp_size * x.shape[-1]
+                variance = global_sums / count
+
+            else:
+                variance = x.pow(2).mean(-1, keepdim=True)
+            x = x * torch.rsqrt(variance + self.variance_epsilon)
+        else:
+            redundant_tp: bool = self.n_groups % self.tp_size != 0
+            if redundant_tp:
+                # To handle the general case, redundantly apply the variance
+                x = tensor_model_parallel_all_gather(x, -1)
+
+            *prefix_dims, hidden_dim = x.shape
+            group_count = hidden_dim // self.group_size
+            x_grouped = x.view(*prefix_dims, group_count, self.group_size)
+            variance = x_grouped.pow(2).mean(-1, keepdim=True)
+            x_grouped = x_grouped * torch.rsqrt(variance + self.variance_epsilon)
+            x = x_grouped.view(*prefix_dims, hidden_dim)
+
+            if redundant_tp:
+                start = self.per_rank_hidden_size * self.tp_rank
+                end = start + self.per_rank_hidden_size
+                x = x[..., start:end]
+
+        return self.weight * x.to(input_dtype)
+
+    def forward_cuda(
+        self,
+        x: torch.Tensor,
+        gate: torch.Tensor,
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        input_dtype = x.dtype
+        if not self.use_rms_norm:
+            # Keep gate in float32 for numerical stability during silu
+            return x * torch.nn.functional.silu(gate.to(torch.float32)).to(input_dtype)
+
+        if ((self.n_groups % self.tp_size) != 0) or self.n_groups != 1:
+            return self.forward_native(x, gate)
+
+        return rms_norm_gated(
+            x=x,
+            weight=self.weight.data,
+            bias=None,
+            z=gate,
+            eps=self.variance_epsilon,
+            norm_before_gate=False,
+            is_rms_norm=True,
+        )
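Instantiating Mixer2RMSNormGated requires an initialized tensor-parallel group, so as a rough, framework-free reference the gated, grouped RMSNorm that forward_native computes on a single rank reduces to the sketch below. This is my own illustration under that single-rank assumption, not code from this wheel; the function and variable names are hypothetical.

import torch

def gated_group_rms_norm(x, gate, weight, group_size, eps=1e-6):
    # SiLU-gate in float32, then RMS-normalize within each contiguous group
    # of `group_size` channels, mirroring the single-rank branch above.
    input_dtype = x.dtype
    x = x * torch.nn.functional.silu(gate.to(torch.float32))
    *prefix_dims, hidden_dim = x.shape
    x_grouped = x.view(*prefix_dims, hidden_dim // group_size, group_size)
    variance = x_grouped.pow(2).mean(-1, keepdim=True)
    x_grouped = x_grouped * torch.rsqrt(variance + eps)
    return weight * x_grouped.view(*prefix_dims, hidden_dim).to(input_dtype)

# Example: hidden size 16 split into 4 groups of 4 channels.
x = torch.randn(2, 16)
gate = torch.randn(2, 16)
out = gated_group_rms_norm(x, gate, torch.ones(16), group_size=4)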
sglang/srt/layers/attention/mamba/ops/ssd_bmm.py
@@ -15,56 +15,6 @@ import triton
 import triton.language as tl


-# @triton.autotune(
-#     configs=[
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
-#             num_stages=3,
-#             num_warps=8,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
-#             num_stages=4,
-#             num_warps=4,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
-#             num_stages=4,
-#             num_warps=4,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
-#             num_stages=4,
-#             num_warps=4,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
-#             num_stages=4,
-#             num_warps=4,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
-#             num_stages=4,
-#             num_warps=4,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
-#             num_stages=5,
-#             num_warps=2,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
-#             num_stages=5,
-#             num_warps=2,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
-#             num_stages=4,
-#             num_warps=2,
-#         ),
-#     ],
-#     key=["chunk_size", "K", "IS_CAUSAL"],
-# )
 @triton.jit
 def _bmm_chunk_fwd_kernel(
     # Pointers to matrices
sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py
@@ -16,66 +16,6 @@ from packaging import version
 TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0")


-# @triton.autotune(
-#     configs=[
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
-#             num_stages=3,
-#             num_warps=8,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
-#             num_stages=4,
-#             num_warps=4,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
-#             num_stages=4,
-#             num_warps=4,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
-#             num_stages=4,
-#             num_warps=4,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
-#             num_stages=4,
-#             num_warps=4,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64},
-#             num_stages=4,
-#             num_warps=4,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
-#             num_stages=4,
-#             num_warps=4,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
-#             num_stages=4,
-#             num_warps=4,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
-#             num_stages=5,
-#             num_warps=2,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
-#             num_stages=5,
-#             num_warps=2,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
-#             num_stages=4,
-#             num_warps=2,
-#         ),
-#     ],
-#     key=["chunk_size", "hdim", "dstate", "IS_CAUSAL"],
-# )
 @triton.jit
 def _chunk_scan_fwd_kernel(
     # Pointers to matrices
sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py
@@ -17,17 +17,6 @@ import triton.language as tl
 from .mamba_ssm import softplus


-# @triton.autotune(
-#     configs=[
-#         triton.Config({"BLOCK_SIZE_H": 2}),
-#         triton.Config({"BLOCK_SIZE_H": 4}),
-#         triton.Config({"BLOCK_SIZE_H": 8}),
-#         triton.Config({"BLOCK_SIZE_H": 16}),
-#         triton.Config({"BLOCK_SIZE_H": 32}),
-#         triton.Config({"BLOCK_SIZE_H": 64}),
-#     ],
-#     key=["chunk_size", "nheads"],
-# )
 @triton.jit
 def _chunk_cumsum_fwd_kernel(
     # Pointers to matrices
@@ -120,56 +109,6 @@ def _chunk_cumsum_fwd_kernel(
     )


-# @triton.autotune(
-#     configs=[
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
-#             num_stages=3,
-#             num_warps=8,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
-#             num_stages=4,
-#             num_warps=4,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
-#             num_stages=4,
-#             num_warps=4,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
-#             num_stages=4,
-#             num_warps=4,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
-#             num_stages=4,
-#             num_warps=4,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
-#             num_stages=4,
-#             num_warps=4,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
-#             num_stages=5,
-#             num_warps=2,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
-#             num_stages=5,
-#             num_warps=2,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
-#             num_stages=4,
-#             num_warps=2,
-#         ),
-#     ],
-#     key=["hdim", "dstate", "chunk_size"],
-# )
 @triton.jit
 def _chunk_state_fwd_kernel(
     # Pointers to matrices
@@ -320,56 +259,6 @@ def _chunk_state_fwd_kernel(
     tl.store(states_ptrs, states, mask=c_mask)


-# @triton.autotune(
-#     configs=[
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
-#             num_stages=3,
-#             num_warps=8,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
-#             num_stages=4,
-#             num_warps=4,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
-#             num_stages=4,
-#             num_warps=4,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
-#             num_stages=4,
-#             num_warps=4,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
-#             num_stages=4,
-#             num_warps=4,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
-#             num_stages=4,
-#             num_warps=4,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
-#             num_stages=5,
-#             num_warps=2,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
-#             num_stages=5,
-#             num_warps=2,
-#         ),
-#         triton.Config(
-#             {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
-#             num_stages=4,
-#             num_warps=2,
-#         ),
-#     ],
-#     key=["hdim", "dstate", "chunk_size"],
-# )
 @triton.jit
 def _chunk_state_varlen_kernel(
     # Pointers to matrices
sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py
@@ -13,17 +13,6 @@ import triton
 import triton.language as tl


-# @triton.autotune(
-#     configs=[
-#         triton.Config({"BLOCK_SIZE": 64}),
-#         triton.Config({"BLOCK_SIZE": 128}),
-#         triton.Config({"BLOCK_SIZE": 256}),
-#         triton.Config({"BLOCK_SIZE": 512}),
-#         triton.Config({"BLOCK_SIZE": 1024}),
-#         triton.Config({"BLOCK_SIZE": 2048}),
-#     ],
-#     key=["dim"],
-# )
 @triton.jit
 def _state_passing_fwd_kernel(
     # Pointers to matrices
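The ssd_* hunks above only delete @triton.autotune decorators that were already commented out, so kernel behavior is unchanged. For readers unfamiliar with the pattern, here is a generic sketch of what such a decorator does when active; it is an illustrative toy kernel of my own, not code from this package, and launching it requires a GPU.

import triton
import triton.language as tl

# Autotune benchmarks each config once per unique value of the `key` arguments
# and caches the fastest choice for subsequent launches.
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 128}, num_warps=4),
        triton.Config({"BLOCK_SIZE": 256}, num_warps=8),
    ],
    key=["n_elements"],
)
@triton.jit
def _scale_kernel(x_ptr, out_ptr, n_elements, scale, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x * scale, mask=mask)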
sglang/srt/layers/attention/triton_backend.py
@@ -85,7 +85,7 @@ class TritonAttnBackend(AttentionBackend):
         self.num_kv_head = model_runner.model_config.get_num_kv_heads(
             get_attention_tp_size()
         )
-        if model_runner.is_hybrid_gdn:
+        if model_runner.hybrid_gdn_config is not None:
             # For hybrid linear models, layer_id = 0 may not be full attention
             self.v_head_dim = model_runner.token_to_kv_pool.get_v_head_dim()
         else: