sglang-0.5.4-py3-none-any.whl → sglang-0.5.4.post1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. sglang/bench_serving.py +56 -12
  2. sglang/launch_server.py +2 -0
  3. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +101 -4
  4. sglang/srt/compilation/backend.py +1 -1
  5. sglang/srt/configs/model_config.py +5 -5
  6. sglang/srt/distributed/parallel_state.py +0 -7
  7. sglang/srt/entrypoints/engine.py +18 -15
  8. sglang/srt/entrypoints/grpc_server.py +0 -1
  9. sglang/srt/entrypoints/http_server.py +75 -94
  10. sglang/srt/environ.py +16 -2
  11. sglang/srt/eplb/expert_distribution.py +30 -0
  12. sglang/srt/function_call/function_call_parser.py +2 -0
  13. sglang/srt/function_call/minimax_m2.py +367 -0
  14. sglang/srt/layers/activation.py +6 -0
  15. sglang/srt/layers/attention/flashattention_backend.py +12 -2
  16. sglang/srt/layers/attention/flashinfer_backend.py +10 -1
  17. sglang/srt/layers/attention/flashinfer_mla_backend.py +18 -10
  18. sglang/srt/layers/attention/trtllm_mla_backend.py +1 -13
  19. sglang/srt/layers/attention/utils.py +78 -0
  20. sglang/srt/layers/communicator.py +1 -0
  21. sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
  22. sglang/srt/layers/layernorm.py +19 -4
  23. sglang/srt/layers/logits_processor.py +5 -0
  24. sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
  25. sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
  26. sglang/srt/layers/moe/ep_moe/layer.py +79 -272
  27. sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
  28. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
  29. sglang/srt/layers/moe/moe_runner/deep_gemm.py +287 -22
  30. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  31. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  32. sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
  33. sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
  34. sglang/srt/layers/moe/token_dispatcher/deepep.py +18 -14
  35. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  36. sglang/srt/layers/moe/topk.py +4 -4
  37. sglang/srt/layers/moe/utils.py +3 -4
  38. sglang/srt/layers/quantization/__init__.py +3 -5
  39. sglang/srt/layers/quantization/awq.py +0 -3
  40. sglang/srt/layers/quantization/base_config.py +7 -0
  41. sglang/srt/layers/quantization/fp8.py +68 -63
  42. sglang/srt/layers/quantization/gguf.py +566 -0
  43. sglang/srt/layers/quantization/mxfp4.py +30 -38
  44. sglang/srt/layers/quantization/unquant.py +23 -45
  45. sglang/srt/layers/quantization/w4afp8.py +38 -2
  46. sglang/srt/layers/radix_attention.py +5 -2
  47. sglang/srt/layers/rotary_embedding.py +13 -1
  48. sglang/srt/layers/sampler.py +12 -1
  49. sglang/srt/managers/io_struct.py +3 -0
  50. sglang/srt/managers/multi_tokenizer_mixin.py +17 -1
  51. sglang/srt/managers/scheduler.py +21 -15
  52. sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
  53. sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
  54. sglang/srt/managers/tokenizer_manager.py +11 -19
  55. sglang/srt/mem_cache/hicache_storage.py +7 -1
  56. sglang/srt/mem_cache/memory_pool.py +82 -0
  57. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  58. sglang/srt/model_executor/forward_batch_info.py +44 -3
  59. sglang/srt/model_executor/model_runner.py +1 -149
  60. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
  61. sglang/srt/models/deepseek_v2.py +147 -44
  62. sglang/srt/models/glm4_moe.py +322 -354
  63. sglang/srt/models/glm4_moe_nextn.py +4 -14
  64. sglang/srt/models/glm4v_moe.py +29 -196
  65. sglang/srt/models/minimax_m2.py +922 -0
  66. sglang/srt/models/nvila.py +355 -0
  67. sglang/srt/models/nvila_lite.py +184 -0
  68. sglang/srt/models/qwen2.py +22 -1
  69. sglang/srt/models/qwen3.py +34 -4
  70. sglang/srt/models/qwen3_moe.py +2 -4
  71. sglang/srt/multimodal/processors/base_processor.py +1 -0
  72. sglang/srt/multimodal/processors/glm4v.py +1 -1
  73. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  74. sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
  75. sglang/srt/parser/reasoning_parser.py +28 -1
  76. sglang/srt/server_args.py +365 -186
  77. sglang/srt/single_batch_overlap.py +2 -7
  78. sglang/srt/utils/common.py +87 -42
  79. sglang/srt/utils/hf_transformers_utils.py +7 -3
  80. sglang/test/test_deterministic.py +235 -12
  81. sglang/test/test_deterministic_utils.py +2 -1
  82. sglang/version.py +1 -1
  83. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +7 -6
  84. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +87 -82
  85. sglang/srt/models/vila.py +0 -306
  86. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  87. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  88. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
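The headline change in this post-release is the new MiniMax M2 model (item 65 above), registered via EntryClass in sglang/srt/models/minimax_m2.py and shown in full below. As a rough orientation only, the new architecture should be reachable through sglang's offline Engine API like any other registered model; the checkpoint id and parallelism settings in this sketch are assumptions, not taken from this diff.

import sglang as sgl

if __name__ == "__main__":
    llm = sgl.Engine(
        model_path="MiniMaxAI/MiniMax-M2",  # assumed HF repo id, not part of this diff
        tp_size=8,                          # assumed tensor-parallel degree
        trust_remote_code=True,
    )
    outputs = llm.generate(
        ["Explain what a mixture-of-experts layer does."],
        {"temperature": 0.6, "max_new_tokens": 128},
    )
    print(outputs[0]["text"])
    llm.shutdown()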
sglang/srt/models/minimax_m2.py (new file)
@@ -0,0 +1,922 @@
+ # Copyright 2023-2024 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ # Adapted from DeepSeek and Mixtral implementation
+ """Inference-only MiniMax M2 model compatible with HuggingFace weights."""
+
+ import logging
+ from typing import Iterable, Optional, Set, Tuple, Union
+
+ import torch
+ from torch import nn
+ from transformers import PretrainedConfig
+
+ from sglang.srt.distributed import (
+     get_moe_expert_parallel_world_size,
+     get_pp_group,
+     get_tensor_model_parallel_rank,
+     get_tensor_model_parallel_world_size,
+     tensor_model_parallel_all_reduce,
+ )
+ from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
+ from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
+ from sglang.srt.layers.activation import SiluAndMul
+ from sglang.srt.layers.communicator import (
+     LayerCommunicator,
+     LayerScatterModes,
+     ScatterMode,
+ )
+ from sglang.srt.layers.layernorm import RMSNorm
+ from sglang.srt.layers.linear import (
+     MergedColumnParallelLinear,
+     QKVParallelLinear,
+     ReplicatedLinear,
+     RowParallelLinear,
+ )
+ from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+ from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
+ from sglang.srt.layers.moe.topk import TopK
+ from sglang.srt.layers.moe.utils import get_moe_a2a_backend
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
+ from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.rotary_embedding import get_rope
+ from sglang.srt.layers.utils import PPMissingLayer
+ from sglang.srt.layers.vocab_parallel_embedding import (
+     ParallelLMHead,
+     VocabParallelEmbedding,
+ )
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+ from sglang.srt.model_loader.weight_utils import (
+     default_weight_loader,
+     maybe_remap_kv_scale_name,
+ )
+ from sglang.srt.server_args import get_global_server_args
+ from sglang.srt.two_batch_overlap import model_forward_maybe_tbo
+ from sglang.srt.utils import (
+     BumpAllocator,
+     add_prefix,
+     get_compiler_backend,
+     is_non_idle_and_non_empty,
+     make_layers,
+ )
+
+ logger = logging.getLogger(__name__)
+
+
+ class MiniMaxM2RMSNormTP(nn.Module):
+     """RMSNorm with Tensor Parallel support for QK normalization."""
+
+     def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
+         super().__init__()
+         self.tp_world = get_tensor_model_parallel_world_size()
+         self.tp_rank = get_tensor_model_parallel_rank()
+
+         # Weight parameter is sharded across TP ranks
+         self.weight = nn.Parameter(torch.ones(int(hidden_size / self.tp_world)))
+         self.weight.weight_loader = self.weight_loader
+         self.variance_epsilon = eps
+
+     @staticmethod
+     def weight_loader(
+         param: nn.Parameter,
+         loaded_weight: torch.Tensor,
+     ) -> None:
+         """Custom weight loader that handles TP sharding."""
+         tp_world = get_tensor_model_parallel_world_size()
+         tp_rank = get_tensor_model_parallel_rank()
+
+         shard_size = loaded_weight.shape[0] // tp_world
+         shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
+         param.data.copy_(loaded_weight[shard])
+
+     @torch.compile(dynamic=True, backend=get_compiler_backend())
+     def forward(
+         self,
+         x: torch.Tensor,
+         residual: Optional[torch.Tensor] = None,
+     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+         """Forward pass with TP-aware variance computation."""
+         assert residual is None, "RMSNormTP does not support residual connection."
+
+         orig_dtype = x.dtype
+         x = x.to(torch.float32)
+
+         # Compute variance across the full dimension (not just local shard)
+         variance = x.pow(2).mean(dim=-1, keepdim=True, dtype=torch.float32)
+
+         if self.tp_world > 1:
+             # All-reduce variance across TP ranks to get global variance
+             variance = tensor_model_parallel_all_reduce(variance) / self.tp_world
+
+         # Normalize and apply local weight shard
+         x = x * torch.rsqrt(variance + self.variance_epsilon)
+         x = x.to(orig_dtype) * self.weight
+
+         return x
+
+
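A minimal single-process sketch of the TP-aware variance trick used by MiniMaxM2RMSNormTP above, assuming plain torch tensors instead of a real tensor-parallel group: averaging the per-shard mean of squares (what the all-reduce followed by division by tp_world does) recovers the mean of squares over the full hidden dimension.

import torch

hidden, tp_world = 8, 2
x = torch.randn(3, hidden)

full_variance = x.pow(2).mean(dim=-1, keepdim=True)

# What each TP rank would hold and compute locally
shards = x.chunk(tp_world, dim=-1)
partial = [s.pow(2).mean(dim=-1, keepdim=True) for s in shards]

# Stand-in for tensor_model_parallel_all_reduce(...) / tp_world
reduced = torch.stack(partial).sum(dim=0) / tp_world

assert torch.allclose(full_variance, reduced)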
+ class MiniMaxM2MLP(nn.Module):
+     def __init__(
+         self,
+         hidden_size: int,
+         intermediate_size: int,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "mlp",
+     ) -> None:
+         super().__init__()
+
+         self.gate_up_proj = MergedColumnParallelLinear(
+             hidden_size,
+             [intermediate_size] * 2,
+             bias=False,
+             quant_config=quant_config,
+             prefix=add_prefix("gate_up_proj", prefix),
+         )
+         self.down_proj = RowParallelLinear(
+             intermediate_size,
+             hidden_size,
+             bias=False,
+             quant_config=quant_config,
+             prefix=add_prefix("down_proj", prefix),
+         )
+         self.act_fn = SiluAndMul()
+         return
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         gate_up, _ = self.gate_up_proj(x)
+         x = self.act_fn(gate_up)
+         x, _ = self.down_proj(x)
+         return x
+
+
+ class MiniMaxM2MoE(nn.Module):
+     """MiniMax MoE implementation using DeepEP for Expert Parallel support."""
+
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         layer_id: int,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ):
+         super().__init__()
+         self.tp_size = get_tensor_model_parallel_world_size()
+         if self.tp_size > config.num_local_experts:
+             raise ValueError(
+                 f"Tensor parallel size {self.tp_size} is greater than "
+                 f"the number of experts {config.num_local_experts}."
+             )
+         self.use_routing_bias = getattr(config, "use_routing_bias", False)
+         if self.use_routing_bias:
+             self.e_score_correction_bias = nn.Parameter(
+                 torch.empty(config.num_local_experts, dtype=torch.float32)
+             )
+             self.e_score_correction_bias.weight_loader = (
+                 MiniMaxM2MoE.ebias_weight_loader
+             )
+         else:
+             self.e_score_correction_bias = None
+
+         self.experts = get_moe_impl_class(quant_config)(
+             num_experts=config.num_local_experts
+             + get_global_server_args().ep_num_redundant_experts,
+             top_k=config.num_experts_per_tok,
+             hidden_size=config.hidden_size,
+             intermediate_size=config.intermediate_size,
+             layer_id=layer_id,
+             quant_config=quant_config,
+             prefix=add_prefix("experts", prefix),
+         )
+         self.topk = TopK(
+             top_k=config.num_experts_per_tok,
+             renormalize=True,
+             scoring_func=config.scoring_func,
+             use_grouped_topk=True,  # TODO: Use "grouped top-k" flag only for hardcoded sigmoid scoring
+             num_expert_group=1,
+             topk_group=1,
+             correction_bias=self.e_score_correction_bias,
+             routed_scaling_factor=1.0,
+         )
+
+         self.gate = ReplicatedLinear(
+             config.hidden_size,
+             config.num_local_experts,
+             bias=False,
+             params_dtype=torch.float32,
+             quant_config=None,
+             prefix=add_prefix("gate", prefix),
+         )
+
+         self.layer_id = layer_id
+
+         if get_moe_a2a_backend().is_deepep():
+             self.ep_size = get_moe_expert_parallel_world_size()
+             self.top_k = config.num_experts_per_tok
+
+     @staticmethod
+     def ebias_weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor) -> None:
+         assert param.size() == loaded_weight.size()
+         param.data.copy_(loaded_weight.to(torch.float32))
+
+     def forward(
+         self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
+     ) -> torch.Tensor:
+         if get_moe_a2a_backend().is_deepep():
+             return self.forward_deepep(hidden_states, forward_batch)
+         else:
+             return self.forward_normal(hidden_states)
+
+     def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         num_tokens, hidden_dim = hidden_states.shape
+         hidden_states = hidden_states.view(-1, hidden_dim)
+
+         # router_logits: (num_tokens, n_experts)
+         router_logits, _ = self.gate(hidden_states.to(torch.float32))
+         topk_output = self.topk(hidden_states, router_logits)
+
+         final_hidden_states = self.experts(hidden_states, topk_output)
+         if self.tp_size > 1:
+             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+
+         return final_hidden_states.view(num_tokens, hidden_dim)
+
+     def forward_deepep(
+         self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
+     ) -> torch.Tensor:
+         if hidden_states.shape[0] > 0:
+             # router_logits: (num_tokens, n_experts)
+             router_logits, _ = self.gate(hidden_states.to(torch.float32))
+             topk_weights, topk_idx, _ = self.topk(
+                 hidden_states,
+                 router_logits,
+                 num_token_non_padded=forward_batch.num_token_non_padded,
+                 expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
+                     layer_id=self.layer_id,
+                 ),
+             )
+         else:
+             topk_weights, topk_idx, _ = self.topk.empty_topk_output(
+                 hidden_states.shape[0], self.top_k
+             )
+         final_hidden_states = self.experts(
+             hidden_states=hidden_states,
+             topk_idx=topk_idx,
+             topk_weights=topk_weights,
+             forward_batch=forward_batch,
+         )
+
+         return final_hidden_states
+
+     # TBO Operations for MiniMax MoE
+     def op_gate(self, state):
+         """Gate operation for TBO - compute router logits"""
+         if is_non_idle_and_non_empty(
+             state.forward_batch.forward_mode, state.hidden_states_mlp_input
+         ):  # router_logits: (num_tokens, num_experts)
+             state.router_logits, _ = self.gate(state.hidden_states_mlp_input)
+         else:
+             state.router_logits = None
+
+     def op_select_experts(self, state):
+         """Expert selection operation for TBO"""
+         router_logits = state.pop("router_logits")
+         hidden_states = state.hidden_states_mlp_input
+
+         if router_logits is not None:
+             with get_global_expert_distribution_recorder().with_current_layer(
+                 self.layer_id
+             ):
+                 state.topk_weights_local, state.topk_idx_local, _ = self.topk(
+                     hidden_states=hidden_states,
+                     router_logits=router_logits,
+                     num_token_non_padded=state.forward_batch.num_token_non_padded,
+                     expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
+                         layer_id=self.layer_id,
+                     ),
+                 )
+         else:
+             state.topk_idx_local = torch.full(
+                 (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
+             )
+             state.topk_weights_local = torch.empty(
+                 (0, self.top_k), dtype=torch.float32, device=hidden_states.device
+             )
+
+     def op_dispatch_a(self, state):
+         """Dispatch A operation for TBO - start async dispatch"""
+         if self.ep_size > 1:
+             self.experts.deepep_dispatcher.dispatch_a(
+                 hidden_states=state.pop("hidden_states_mlp_input"),
+                 topk_idx=state.pop("topk_idx_local"),
+                 topk_weights=state.pop("topk_weights_local"),
+                 forward_batch=state.forward_batch,
+                 tbo_subbatch_index=state.get("tbo_subbatch_index"),
+             )
+
+     def op_dispatch_b(self, state):
+         """Dispatch B operation for TBO - complete async dispatch"""
+         if self.ep_size > 1:
+             with get_global_expert_distribution_recorder().with_current_layer(
+                 self.layer_id
+             ):
+                 state.dispatch_output = self.experts.deepep_dispatcher.dispatch_b(
+                     tbo_subbatch_index=state.get("tbo_subbatch_index"),
+                 )
+
+     def op_experts(self, state):
+         """Expert computation for TBO"""
+         state.hidden_states_experts_output = self.experts.moe_impl(
+             dispatch_output=state.dispatch_output,
+         )
+
+     def op_combine_a(self, state):
+         """Combine A operation for TBO - start async combine"""
+         if self.ep_size > 1:
+             self.experts.deepep_dispatcher.combine_a(
+                 hidden_states=state.pop("hidden_states_experts_output"),
+                 topk_idx=state.dispatch_output.topk_idx,
+                 topk_weights=state.dispatch_output.topk_weights,
+                 forward_batch=state.forward_batch,
+                 tbo_subbatch_index=state.get("tbo_subbatch_index"),
+             )
+             state.pop("dispatch_output")
+
+     def op_combine_b(self, state):
+         """Combine B operation for TBO - complete async combine"""
+         if self.ep_size > 1:
+             state.hidden_states_after_combine = (
+                 self.experts.deepep_dispatcher.combine_b(
+                     tbo_subbatch_index=state.get("tbo_subbatch_index"),
+                 )
+             )
+
+     def op_output(self, state):
+         """Output operation for TBO - final MLP output"""
+         final_hidden_states = state.pop("hidden_states_after_combine")
+         # MiniMax doesn't have shared experts like DeepSeek, so no need to add them
+         state.hidden_states_mlp_output = final_hidden_states
+
+
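The forward_normal path above boils down to: score tokens with the float32 gate, pick top_k experts per token, renormalize the kept weights, and hand (topk_idx, topk_weights) to the expert kernels. A plain-torch sketch of that routing math follows, assuming sigmoid scoring and a selection-only correction bias for illustration; the real TopK in sglang.srt.layers.moe.topk also handles grouping and expert-location dispatch.

import torch

num_tokens, hidden, num_experts, top_k = 4, 16, 8, 2
hidden_states = torch.randn(num_tokens, hidden)
gate_weight = torch.randn(num_experts, hidden)           # stands in for self.gate
correction_bias = torch.zeros(num_experts)               # assumed zero here

router_logits = hidden_states.float() @ gate_weight.t()  # (num_tokens, num_experts)
scores = router_logits.sigmoid()
_, topk_idx = (scores + correction_bias).topk(top_k, dim=-1)  # bias affects selection only
topk_weights = scores.gather(-1, topk_idx)
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)  # renormalize=True
# topk_idx / topk_weights are what get handed to the expert kernels.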
+ class MiniMaxM2Attention(nn.Module):
+     """MiniMax Attention implementation with QK normalization and partial RoPE."""
+
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         layer_id: int = 0,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         tp_size = get_tensor_model_parallel_world_size()
+
+         # Get dimensions from config
+         self.total_num_heads = config.num_attention_heads
+         assert self.total_num_heads % tp_size == 0
+         self.num_heads = self.total_num_heads // tp_size
+         self.total_num_kv_heads = config.num_key_value_heads
+
+         if self.total_num_kv_heads >= tp_size:
+             # Number of KV heads is greater than TP size, so we partition
+             # the KV heads across multiple tensor parallel GPUs.
+             assert self.total_num_kv_heads % tp_size == 0
+         else:
+             # Number of KV heads is less than TP size, so we replicate
+             # the KV heads across multiple tensor parallel GPUs.
+             assert tp_size % self.total_num_kv_heads == 0
+         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+
+         # Use head_dim from config if available, otherwise calculate
+         self.head_dim = getattr(
+             config, "head_dim", self.hidden_size // self.total_num_heads
+         )
+         self.q_size = self.num_heads * self.head_dim
+         self.kv_size = self.num_kv_heads * self.head_dim
+         self.scaling = self.head_dim**-0.5
+
+         # RoPE settings - support partial RoPE
+         self.rope_theta = getattr(config, "rope_theta", 10000)
+         self.max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+         self.rotary_dim = getattr(
+             config, "rotary_dim", self.head_dim
+         )  # MiniMax uses rotary_dim=64
+
+         # QK Normalization settings
+         self.use_qk_norm = getattr(config, "use_qk_norm", False)
+         self.qk_norm_type = getattr(config, "qk_norm_type", "per_layer")
+
+         self.qkv_proj = QKVParallelLinear(
+             self.hidden_size,
+             self.head_dim,
+             self.total_num_heads,
+             self.total_num_kv_heads,
+             bias=False,
+             quant_config=quant_config,
+             prefix=add_prefix("qkv_proj", prefix),
+         )
+
+         self.o_proj = RowParallelLinear(
+             self.total_num_heads * self.head_dim,
+             self.hidden_size,
+             bias=False,
+             reduce_results=False,
+             quant_config=quant_config,
+             prefix=add_prefix("o_proj", prefix),
+         )
+
+         # Setup RoPE with partial rotary dimension
+         rope_scaling = getattr(config, "rope_scaling", None)
+         self.rotary_emb = get_rope(
+             self.head_dim,
+             rotary_dim=self.rotary_dim,  # Use partial rotary dimension
+             max_position=self.max_position_embeddings,
+             base=self.rope_theta,
+             rope_scaling=rope_scaling,
+         )
+
+         # QK Normalization layers
+         if self.use_qk_norm:
+             if self.qk_norm_type == "per_layer":
+                 # Use RMSNormTP for proper tensor parallel support
+                 # Use total dimensions (before TP sharding) for correct normalization
+                 self.q_norm = MiniMaxM2RMSNormTP(
+                     self.total_num_heads * self.head_dim, eps=config.rms_norm_eps
+                 )
+                 self.k_norm = MiniMaxM2RMSNormTP(
+                     self.total_num_kv_heads * self.head_dim, eps=config.rms_norm_eps
+                 )
+             else:
+                 raise ValueError(f"Unsupported qk_norm_type: {self.qk_norm_type}")
+
+         self.attn = RadixAttention(
+             self.num_heads,
+             self.head_dim,
+             self.scaling,
+             num_kv_heads=self.num_kv_heads,
+             layer_id=layer_id,
+             quant_config=quant_config,
+             prefix=add_prefix("attn", prefix),
+         )
+
+     def forward_prepare(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         forward_batch: ForwardBatch,
+     ):
+         qkv, _ = self.qkv_proj(hidden_states)
+         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+         if self.use_qk_norm:
+             q = self.q_norm(q.contiguous())
+             k = self.k_norm(k.contiguous())
+         else:
+             q, k = q.contiguous(), k.contiguous()
+         q, k = self.rotary_emb(positions, q, k)
+         inner_state = q, k, v, forward_batch
+         return None, forward_batch, inner_state
+
+     def forward_core(self, intermediate_state):
+         _, _, inner_state = intermediate_state
+         attn_output = self.attn(*inner_state)
+         output, _ = self.o_proj(attn_output)
+         return output
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         forward_batch: ForwardBatch,
+     ) -> torch.Tensor:
+         s = self.forward_prepare(
+             positions=positions,
+             hidden_states=hidden_states,
+             forward_batch=forward_batch,
+         )
+         return self.forward_core(s)
+
+     def op_prepare(self, state):
+         state.attn_intermediate_state = self.forward_prepare(
+             positions=state.positions,
+             hidden_states=state.pop("hidden_states_after_comm_pre_attn"),
+             forward_batch=state.forward_batch,
+         )
+
+     def op_core(self, state):
+         state.hidden_states_after_attn = self.forward_core(
+             state.pop("attn_intermediate_state")
+         )
+
+
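Because rotary_dim can be smaller than head_dim (the comment above notes MiniMax uses rotary_dim=64), only the leading rotary_dim channels of each head are rotated and the remainder passes through untouched. A simplified sketch of that partial RoPE follows, assuming neox-style half-split pairing for illustration; the production path is whatever get_rope returns.

import torch

def partial_rope(x, positions, rotary_dim, base=10000.0):
    # x: (num_tokens, head_dim) for a single head; only x[..., :rotary_dim] is rotated
    rot, rest = x[..., :rotary_dim], x[..., rotary_dim:]
    inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2) / rotary_dim))
    freqs = positions[:, None].float() * inv_freq[None, :]
    cos, sin = freqs.cos(), freqs.sin()
    x1, x2 = rot[..., : rotary_dim // 2], rot[..., rotary_dim // 2 :]
    rotated = torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
    return torch.cat([rotated, rest], dim=-1)

q = torch.randn(5, 128)                      # head_dim = 128
q_rot = partial_rope(q, torch.arange(5), rotary_dim=64)
assert q_rot.shape == q.shape and torch.equal(q_rot[..., 64:], q[..., 64:])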
+ class MiniMaxM2DecoderLayer(nn.Module):
+     """MiniMax Decoder Layer implementation with MoE support."""
+
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         layer_id: int,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         self.layer_id = layer_id
+
+         # TBO support: All MiniMax layers are sparse (MoE)
+         self.is_layer_sparse = True
+
+         self.self_attn = MiniMaxM2Attention(
+             config=config,
+             layer_id=layer_id,
+             quant_config=quant_config,
+             prefix=add_prefix("self_attn", prefix),
+         )
+
+         self.block_sparse_moe = MiniMaxM2MoE(
+             config=config,
+             layer_id=layer_id,
+             quant_config=quant_config,
+             prefix=add_prefix("mlp", prefix),
+         )
+
+         self.input_layernorm = RMSNorm(
+             config.hidden_size, eps=getattr(config, "rms_norm_eps", 1e-6)
+         )
+         self.post_attention_layernorm = RMSNorm(
+             config.hidden_size, eps=getattr(config, "rms_norm_eps", 1e-6)
+         )
+
+         is_previous_layer_sparse = True
+         self.layer_scatter_modes = LayerScatterModes.init_new(
+             layer_id=layer_id,
+             num_layers=config.num_hidden_layers,
+             is_layer_sparse=self.is_layer_sparse,
+             is_previous_layer_sparse=is_previous_layer_sparse,
+         )
+
+         self.layer_communicator = LayerCommunicator(
+             layer_scatter_modes=self.layer_scatter_modes,
+             input_layernorm=self.input_layernorm,
+             post_attention_layernorm=self.post_attention_layernorm,
+             allow_reduce_scatter=True,
+         )
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         forward_batch: ForwardBatch,
+         residual: Optional[torch.Tensor],
+     ) -> torch.Tensor:
+         # Self Attention
+         hidden_states, residual = self.layer_communicator.prepare_attn(
+             hidden_states, residual, forward_batch
+         )
+
+         hidden_states = self.self_attn(
+             positions=positions,
+             hidden_states=hidden_states,
+             forward_batch=forward_batch,
+         )
+
+         # Fully Connected (MLP or MoE)
+
+         hidden_states, residual = self.layer_communicator.prepare_mlp(
+             hidden_states, residual, forward_batch
+         )
+
+         hidden_states = self.block_sparse_moe(hidden_states, forward_batch)
+
+         hidden_states, residual = self.layer_communicator.postprocess_layer(
+             hidden_states, residual, forward_batch
+         )
+
+         return hidden_states, residual
+
+     # TBO Operations for MiniMax Decoder Layer
+     def op_comm_prepare_attn(
+         self,
+         state,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         forward_batch: ForwardBatch,
+         residual: Optional[torch.Tensor],
+         zero_allocator: BumpAllocator,
+         tbo_subbatch_index: Optional[int] = None,
+     ):
+         """Communication prepare for attention - TBO operation"""
+         state.hidden_states_after_comm_pre_attn, state.residual_after_input_ln = (
+             self.layer_communicator.prepare_attn(hidden_states, residual, forward_batch)
+         )
+         state.update(
+             dict(
+                 forward_batch=forward_batch,
+                 positions=positions,
+                 zero_allocator=zero_allocator,
+                 tbo_subbatch_index=tbo_subbatch_index,
+             )
+         )
+
+     def op_comm_prepare_mlp(self, state):
+         """Communication prepare for MLP - TBO operation"""
+         state.hidden_states_mlp_input, state.residual_after_comm_pre_mlp = (
+             self.layer_communicator.prepare_mlp(
+                 state.pop("hidden_states_after_attn"),
+                 state.pop("residual_after_input_ln"),
+                 state.forward_batch,
+             )
+         )
+
+     def op_mlp(self, state):
+         hidden_states = state.pop("hidden_states_mlp_input")
+         state.hidden_states_mlp_output = self.block_sparse_moe(
+             hidden_states, state.forward_batch
+         )
+
+     def op_comm_postprocess_layer(self, state):
+         """Communication postprocess for layer - TBO operation"""
+         hidden_states, residual = self.layer_communicator.postprocess_layer(
+             state.pop("hidden_states_mlp_output"),
+             state.pop("residual_after_comm_pre_mlp"),
+             state.forward_batch,
+         )
+
+         output = dict(
+             positions=state.positions,
+             hidden_states=hidden_states,
+             residual=residual,
+             forward_batch=state.forward_batch,
+             zero_allocator=state.zero_allocator,
+             tbo_subbatch_index=state.tbo_subbatch_index,
+         )
+         return output
+
+
+ class MiniMaxM2Model(nn.Module):
+     """MiniMax Model implementation."""
+
+     fall_back_to_pt_during_load = False
+
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+
+         self.padding_idx = getattr(config, "pad_token_id", 0)
+         self.vocab_size = config.vocab_size
+         self.pp_group = get_pp_group()
+
+         self.embed_tokens = VocabParallelEmbedding(
+             config.vocab_size,
+             config.hidden_size,
+         )
+
+         def layer_fn(idx, prefix: str) -> nn.Module:
+             return MiniMaxM2DecoderLayer(
+                 config=config,
+                 layer_id=idx,
+                 quant_config=quant_config,
+                 prefix=prefix,
+             )
+
+         self.layers, self.start_layer, self.end_layer = make_layers(
+             config.num_hidden_layers,
+             layer_fn,
+             pp_rank=self.pp_group.rank_in_group,
+             pp_size=self.pp_group.world_size,
+             prefix=add_prefix("layers", prefix),
+         )
+         if self.pp_group.is_last_rank:
+             self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         else:
+             self.norm = PPMissingLayer(return_tuple=True)
+
+     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+         return self.embed_tokens(input_ids)
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         forward_batch: ForwardBatch,
+         input_embeds: torch.Tensor = None,
+         pp_proxy_tensors: Optional[PPProxyTensors] = None,
+     ) -> Union[torch.Tensor, PPProxyTensors]:
+         if self.pp_group.is_first_rank:
+             if input_embeds is None:
+                 hidden_states = self.get_input_embeddings(input_ids)
+             else:
+                 hidden_states = input_embeds
+             residual = None
+         else:
+             assert pp_proxy_tensors is not None
+             hidden_states = pp_proxy_tensors["hidden_states"]
+             residual = pp_proxy_tensors["residual"]
+
+         if forward_batch.can_run_tbo:
+             hidden_states, residual = model_forward_maybe_tbo(
+                 layers=self.layers,
+                 enable_tbo=True,
+                 input_data_scatter_mode=ScatterMode.model_input_output(),
+                 positions=positions,
+                 forward_batch=forward_batch,
+                 hidden_states=hidden_states,
+                 residual=residual,
+             )
+         else:
+             for i in range(self.start_layer, self.end_layer):
+                 with get_global_expert_distribution_recorder().with_current_layer(i):
+                     layer = self.layers[i]
+                     hidden_states, residual = layer(
+                         positions=positions,
+                         forward_batch=forward_batch,
+                         hidden_states=hidden_states,
+                         residual=residual,
+                     )
+
+         if not self.pp_group.is_last_rank:
+             return PPProxyTensors(
+                 {"hidden_states": hidden_states, "residual": residual}
+             )
+
+         if residual is not None:
+             hidden_states, _ = self.norm(hidden_states, residual)
+         else:
+             hidden_states = self.norm(hidden_states)
+
+         return hidden_states
+
+
+ class MiniMaxM2ForCausalLM(nn.Module):
+     """MiniMax M2 model for causal language modeling."""
+
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+
+         self.config = config
+         self.quant_config = quant_config
+
+         self.model = MiniMaxM2Model(
+             config, quant_config, prefix=add_prefix("model", prefix)
+         )
+
+         if get_pp_group().is_last_rank:
+             self.lm_head = ParallelLMHead(
+                 config.vocab_size,
+                 config.hidden_size,
+                 quant_config=None,
+                 prefix=add_prefix("lm_head", prefix),
+             )
+         else:
+             self.lm_head = PPMissingLayer()
+
+         self.logits_processor = LogitsProcessor(config)
+
+     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+         return self.model.get_input_embeddings(input_ids)
+
+     @torch.no_grad()
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         forward_batch: ForwardBatch,
+         input_embeds: torch.Tensor = None,
+     ) -> torch.Tensor:
+         # _print_tensor_info(input_ids, "input_ids")
+         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
+         return self.logits_processor(
+             input_ids, hidden_states, self.lm_head, forward_batch
+         )
+
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+         """Load model weights with proper mapping for MiniMax architecture."""
+
+         stacked_params_mapping = [
+             # (param_name, shard_name, shard_id)
+             ("qkv_proj", "q_proj", "q"),
+             ("qkv_proj", "k_proj", "k"),
+             ("qkv_proj", "v_proj", "v"),
+             ("gate_up_proj", "gate_proj", 0),
+             ("gate_up_proj", "up_proj", 1),
+         ]
+
+         # Params for weights, fp8 weight scales, fp8 activation scales
+         # (param_name, weight_name, expert_id, shard_id)
+         expert_params_mapping = FusedMoE.make_expert_params_mapping(
+             ckpt_gate_proj_name="w1",
+             ckpt_down_proj_name="w2",
+             ckpt_up_proj_name="w3",
+             num_experts=self.config.num_local_experts,
+         )
+
+         params_dict = dict(self.named_parameters())
+         loaded_params: Set[str] = set()
+         for name, loaded_weight in weights:
+             if "rotary_emb.inv_freq" in name:
+                 continue
+
+             spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
+             if spec_layer is not None:
+                 continue  # skip spec decode layers for main model
+
+             for param_name, weight_name, shard_id in stacked_params_mapping:
+                 # Skip non-stacked layers and experts (experts handled below).
+                 if weight_name not in name:
+                     continue
+                 # We have mlp.experts[0].gate_proj in the checkpoint.
+                 # Since we handle the experts below in expert_params_mapping,
+                 # we need to skip here BEFORE we update the name, otherwise
+                 # name will be updated to mlp.experts[0].gate_up_proj, which
+                 # will then be updated below in expert_params_mapping
+                 # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+                 if ("mlp.experts." in name) and name not in params_dict:
+                     continue
+                 name = name.replace(weight_name, param_name)
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+
+                 param = params_dict[name]
+                 weight_loader = param.weight_loader
+                 weight_loader(param, loaded_weight, shard_id)
+                 break
+             else:
+                 for mapping in expert_params_mapping:
+                     param_name, weight_name, expert_id, shard_id = mapping
+                     if weight_name not in name:
+                         continue
+                     name = name.replace(weight_name, param_name)
+
+                     param = params_dict[name]
+                     weight_loader = param.weight_loader
+                     weight_loader(
+                         param,
+                         loaded_weight,
+                         name,
+                         shard_id=shard_id,
+                         expert_id=expert_id,
+                     )
+                     break
+                 else:
+                     # Skip loading extra bias for GPTQ models.
+                     if name.endswith(".bias") and name not in params_dict:
+                         continue
+
+                     # Remapping the name of FP8 kv-scale.
+                     name = maybe_remap_kv_scale_name(name, params_dict)
+                     if name is None:
+                         continue
+
+                     param = params_dict[name]
+                     weight_loader = getattr(
+                         param, "weight_loader", default_weight_loader
+                     )
+                     weight_loader(param, loaded_weight)
+             loaded_params.add(name)
+         return loaded_params
+
+     @classmethod
+     def get_model_config_for_expert_location(cls, config):
+         from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
+
+         return ModelConfigForExpertLocation(
+             num_layers=config.num_hidden_layers,
+             num_logical_experts=config.num_local_experts,
+             num_groups=None,
+         )
+
+
+ def get_spec_layer_idx_from_weight_name(
+     config: PretrainedConfig, weight_name: str
+ ) -> Optional[int]:
+     if hasattr(config, "num_mtp_modules") and (config.num_mtp_modules > 0):
+         layer_idx = config.num_hidden_layers
+         for i in range(config.num_mtp_modules):
+             if weight_name.startswith(f"model.layers.{layer_idx + i}."):
+                 return layer_idx + i
+     return None
+
+
+ # Entry class for model registration
+ EntryClass = MiniMaxM2ForCausalLM
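For reference, the stacked-parameter remapping performed by load_weights above can be illustrated in isolation: per-projection checkpoint tensors are folded into the stacked qkv_proj / gate_up_proj parameters, while per-expert w1/w2/w3 tensors go through FusedMoE.make_expert_params_mapping instead. The checkpoint names below are hypothetical examples, not taken from an actual MiniMax M2 checkpoint.

stacked_params_mapping = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]

def remap(ckpt_name: str):
    # Mirrors the first loop in load_weights: rename and report the shard id
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in ckpt_name:
            return ckpt_name.replace(weight_name, param_name), shard_id
    return ckpt_name, None  # falls through to expert / default loading

print(remap("model.layers.0.self_attn.k_proj.weight"))
# -> ('model.layers.0.self_attn.qkv_proj.weight', 'k')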