sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries, as they appear in those registries. It is provided for informational purposes only.
Files changed (175)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +119 -17
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +42 -7
  6. sglang/srt/conversation.py +9 -5
  7. sglang/srt/disaggregation/base/conn.py +5 -2
  8. sglang/srt/disaggregation/decode.py +14 -4
  9. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
  10. sglang/srt/disaggregation/mooncake/conn.py +286 -160
  11. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  12. sglang/srt/disaggregation/prefill.py +2 -0
  13. sglang/srt/distributed/parallel_state.py +15 -11
  14. sglang/srt/entrypoints/context.py +227 -0
  15. sglang/srt/entrypoints/engine.py +15 -9
  16. sglang/srt/entrypoints/harmony_utils.py +372 -0
  17. sglang/srt/entrypoints/http_server.py +74 -4
  18. sglang/srt/entrypoints/openai/protocol.py +218 -1
  19. sglang/srt/entrypoints/openai/serving_chat.py +41 -11
  20. sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
  21. sglang/srt/entrypoints/openai/tool_server.py +175 -0
  22. sglang/srt/entrypoints/tool.py +87 -0
  23. sglang/srt/eplb/expert_location.py +5 -1
  24. sglang/srt/function_call/ebnf_composer.py +1 -0
  25. sglang/srt/function_call/function_call_parser.py +2 -0
  26. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  27. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  28. sglang/srt/function_call/kimik2_detector.py +3 -3
  29. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  30. sglang/srt/hf_transformers_utils.py +30 -3
  31. sglang/srt/jinja_template_utils.py +14 -1
  32. sglang/srt/layers/attention/aiter_backend.py +375 -115
  33. sglang/srt/layers/attention/ascend_backend.py +3 -0
  34. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
  35. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  36. sglang/srt/layers/attention/flashinfer_backend.py +52 -13
  37. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  38. sglang/srt/layers/attention/triton_backend.py +85 -14
  39. sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
  40. sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
  41. sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
  42. sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
  43. sglang/srt/layers/attention/vision.py +22 -6
  44. sglang/srt/layers/attention/wave_backend.py +627 -0
  45. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  46. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  47. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  48. sglang/srt/layers/communicator.py +29 -14
  49. sglang/srt/layers/dp_attention.py +12 -0
  50. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  51. sglang/srt/layers/linear.py +3 -7
  52. sglang/srt/layers/moe/cutlass_moe.py +12 -3
  53. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  54. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  55. sglang/srt/layers/moe/ep_moe/layer.py +135 -73
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
  59. sglang/srt/layers/moe/fused_moe_triton/layer.py +412 -33
  60. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
  61. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  62. sglang/srt/layers/moe/topk.py +16 -4
  63. sglang/srt/layers/moe/utils.py +16 -0
  64. sglang/srt/layers/quantization/__init__.py +27 -3
  65. sglang/srt/layers/quantization/fp4.py +557 -0
  66. sglang/srt/layers/quantization/fp8.py +3 -6
  67. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  68. sglang/srt/layers/quantization/fp8_utils.py +51 -10
  69. sglang/srt/layers/quantization/modelopt_quant.py +258 -68
  70. sglang/srt/layers/quantization/mxfp4.py +654 -0
  71. sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
  72. sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
  73. sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  74. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
  75. sglang/srt/layers/quantization/quark/utils.py +107 -0
  76. sglang/srt/layers/quantization/unquant.py +60 -6
  77. sglang/srt/layers/quantization/w4afp8.py +21 -12
  78. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  79. sglang/srt/layers/rotary_embedding.py +506 -3
  80. sglang/srt/layers/utils.py +9 -0
  81. sglang/srt/layers/vocab_parallel_embedding.py +8 -3
  82. sglang/srt/lora/backend/base_backend.py +3 -23
  83. sglang/srt/lora/layers.py +60 -114
  84. sglang/srt/lora/lora.py +17 -62
  85. sglang/srt/lora/lora_manager.py +82 -62
  86. sglang/srt/lora/lora_registry.py +23 -11
  87. sglang/srt/lora/mem_pool.py +63 -68
  88. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  89. sglang/srt/lora/utils.py +25 -58
  90. sglang/srt/managers/cache_controller.py +75 -58
  91. sglang/srt/managers/detokenizer_manager.py +1 -1
  92. sglang/srt/managers/io_struct.py +20 -8
  93. sglang/srt/managers/mm_utils.py +6 -13
  94. sglang/srt/managers/multimodal_processor.py +1 -1
  95. sglang/srt/managers/schedule_batch.py +61 -25
  96. sglang/srt/managers/schedule_policy.py +6 -6
  97. sglang/srt/managers/scheduler.py +41 -19
  98. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
  99. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  100. sglang/srt/managers/scheduler_recv_skipper.py +37 -0
  101. sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
  102. sglang/srt/managers/template_manager.py +35 -1
  103. sglang/srt/managers/tokenizer_manager.py +47 -30
  104. sglang/srt/managers/tp_worker.py +3 -0
  105. sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
  106. sglang/srt/mem_cache/allocator.py +61 -87
  107. sglang/srt/mem_cache/hicache_storage.py +1 -1
  108. sglang/srt/mem_cache/hiradix_cache.py +80 -22
  109. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  110. sglang/srt/mem_cache/memory_pool_host.py +34 -36
  111. sglang/srt/mem_cache/multimodal_cache.py +33 -13
  112. sglang/srt/mem_cache/radix_cache.py +2 -5
  113. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
  114. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  115. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  116. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  117. sglang/srt/model_executor/cuda_graph_runner.py +29 -9
  118. sglang/srt/model_executor/forward_batch_info.py +61 -19
  119. sglang/srt/model_executor/model_runner.py +148 -37
  120. sglang/srt/model_loader/loader.py +18 -6
  121. sglang/srt/model_loader/weight_utils.py +10 -0
  122. sglang/srt/models/bailing_moe.py +425 -0
  123. sglang/srt/models/deepseek_v2.py +137 -59
  124. sglang/srt/models/ernie4.py +426 -0
  125. sglang/srt/models/ernie4_eagle.py +203 -0
  126. sglang/srt/models/gemma2.py +0 -34
  127. sglang/srt/models/gemma3n_mm.py +38 -0
  128. sglang/srt/models/glm4.py +6 -0
  129. sglang/srt/models/glm4_moe.py +28 -16
  130. sglang/srt/models/glm4v.py +589 -0
  131. sglang/srt/models/glm4v_moe.py +400 -0
  132. sglang/srt/models/gpt_oss.py +1251 -0
  133. sglang/srt/models/granite.py +0 -25
  134. sglang/srt/models/llama.py +0 -25
  135. sglang/srt/models/llama4.py +1 -1
  136. sglang/srt/models/qwen2.py +6 -0
  137. sglang/srt/models/qwen2_5_vl.py +7 -3
  138. sglang/srt/models/qwen2_audio.py +10 -9
  139. sglang/srt/models/qwen2_moe.py +6 -0
  140. sglang/srt/models/qwen3.py +0 -24
  141. sglang/srt/models/qwen3_moe.py +32 -6
  142. sglang/srt/models/registry.py +1 -1
  143. sglang/srt/models/step3_vl.py +9 -0
  144. sglang/srt/models/torch_native_llama.py +0 -24
  145. sglang/srt/models/transformers.py +2 -5
  146. sglang/srt/multimodal/processors/base_processor.py +23 -13
  147. sglang/srt/multimodal/processors/glm4v.py +132 -0
  148. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  149. sglang/srt/multimodal/processors/step3_vl.py +3 -1
  150. sglang/srt/reasoning_parser.py +332 -37
  151. sglang/srt/server_args.py +186 -75
  152. sglang/srt/speculative/eagle_worker.py +16 -0
  153. sglang/srt/two_batch_overlap.py +169 -9
  154. sglang/srt/utils.py +41 -5
  155. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  156. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  157. sglang/test/doc_patch.py +59 -0
  158. sglang/test/few_shot_gsm8k.py +1 -1
  159. sglang/test/few_shot_gsm8k_engine.py +1 -1
  160. sglang/test/run_eval.py +4 -1
  161. sglang/test/runners.py +2 -2
  162. sglang/test/simple_eval_common.py +6 -0
  163. sglang/test/simple_eval_gpqa.py +2 -0
  164. sglang/test/test_fp4_moe.py +118 -36
  165. sglang/test/test_utils.py +1 -1
  166. sglang/utils.py +1 -1
  167. sglang/version.py +1 -1
  168. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +36 -38
  169. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +174 -141
  170. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  171. /sglang/{api.py → lang/api.py} +0 -0
  172. /sglang/{lang/backend → srt/layers/quantization/quark}/__init__.py +0 -0
  173. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
  174. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
  175. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
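The largest single addition in this release is the new GPT-OSS model implementation (sglang/srt/models/gpt_oss.py, +1251 lines), whose full diff is shown below. For orientation only, here is a minimal offline-engine sketch for loading a model served by this new class; the model path and sampling parameters are illustrative assumptions and are not taken from the diff:

    # Minimal sketch, assuming sglang >= 0.5.0rc1 and a GPT-OSS checkpoint
    # (the path "openai/gpt-oss-20b" is an example; substitute your own).
    import sglang as sgl

    if __name__ == "__main__":
        llm = sgl.Engine(model_path="openai/gpt-oss-20b")
        outputs = llm.generate(
            ["Briefly explain tensor parallelism."],
            {"temperature": 0.6, "max_new_tokens": 64},
        )
        print(outputs[0]["text"])
        llm.shutdown()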
sglang/srt/models/gpt_oss.py (new file)
@@ -0,0 +1,1251 @@
1
+ # Copyright 2023-2024 SGLang Team
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ==============================================================================
14
+
15
+
16
+ """Inference-only GptOss model compatible with HuggingFace weights."""
17
+
18
+ import logging
19
+ from collections.abc import Iterable
20
+ from functools import partial
21
+ from typing import Any, Dict, List, Optional, Tuple, Union
22
+
23
+ import torch
24
+ from torch import nn
25
+ from transformers import PretrainedConfig
26
+
27
+ from sglang.srt.distributed import (
28
+ get_moe_expert_parallel_rank,
29
+ get_moe_expert_parallel_world_size,
30
+ get_moe_tensor_parallel_rank,
31
+ get_moe_tensor_parallel_world_size,
32
+ get_pp_group,
33
+ get_tensor_model_parallel_rank,
34
+ get_tensor_model_parallel_world_size,
35
+ tensor_model_parallel_all_reduce,
36
+ )
37
+ from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
38
+ from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
39
+ from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
40
+ from sglang.srt.layers.dp_attention import (
41
+ get_attention_tp_rank,
42
+ get_attention_tp_size,
43
+ get_local_attention_dp_size,
44
+ )
45
+ from sglang.srt.layers.layernorm import RMSNorm
46
+ from sglang.srt.layers.linear import (
47
+ QKVParallelLinear,
48
+ ReplicatedLinear,
49
+ RowParallelLinear,
50
+ )
51
+ from sglang.srt.layers.logits_processor import LogitsProcessor
52
+ from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
53
+ from sglang.srt.layers.moe.topk import TopK
54
+ from sglang.srt.layers.moe.utils import DeepEPMode
55
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
56
+ from sglang.srt.layers.quantization.fp8_utils import dequant_mxfp4
57
+ from sglang.srt.layers.radix_attention import RadixAttention
58
+ from sglang.srt.layers.rotary_embedding import get_rope
59
+ from sglang.srt.layers.utils import PPMissingLayer, get_layer_id, is_sm100_supported
60
+ from sglang.srt.layers.vocab_parallel_embedding import (
61
+ ParallelLMHead,
62
+ VocabParallelEmbedding,
63
+ )
64
+ from sglang.srt.managers.schedule_batch import global_server_args_dict
65
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
66
+ from sglang.srt.model_loader.weight_utils import default_weight_loader
67
+ from sglang.srt.utils import (
68
+ LazyValue,
69
+ add_prefix,
70
+ is_cuda,
71
+ is_flashinfer_available,
72
+ make_layers,
73
+ )
74
+
75
+ _is_cuda = is_cuda()
76
+ _is_flashinfer_available = is_flashinfer_available()
77
+ _is_sm100_supported = is_cuda() and is_sm100_supported()
78
+
79
+
80
+ if _is_cuda:
81
+ from sgl_kernel import FusedSetKVBufferArg
82
+
83
+
84
+ class GptOssConfig(PretrainedConfig):
85
+ model_type = "gpt_oss"
86
+
87
+ def __init__(self, **kwargs):
88
+ super().__init__(**kwargs)
89
+
90
+
91
+ logger = logging.getLogger(__name__)
92
+
93
+
94
+ # Aligned with HF's implementation, using sliding window inclusive with the last token
95
+ # SGLang assumes exclusive
96
+ def get_attention_sliding_window_size(config):
97
+ return config.sliding_window - 1
98
+
99
+
100
+ class GptOssSparseMoeBlock(nn.Module):
101
+ def __init__(
102
+ self,
103
+ layer_id: int,
104
+ config: GptOssConfig,
105
+ quant_config: Optional[QuantizationConfig] = None,
106
+ prefix: str = "",
107
+ ):
108
+ super().__init__()
109
+ self.tp_size = get_tensor_model_parallel_world_size()
110
+ self.layer_id = layer_id
111
+ self.activation = config.hidden_act
112
+ self.activation_alpha = getattr(config, "hidden_act_alpha", 1.702)
113
+ self.swiglu_limit = config.swiglu_limit
114
+
115
+ if global_server_args_dict["enable_flashinfer_mxfp4_moe"]:
116
+ self.topk = None
117
+ else:
118
+ self.topk = TopK(
119
+ top_k=config.num_experts_per_tok,
120
+ renormalize=True,
121
+ )
122
+
123
+ self.top_k = config.num_experts_per_tok
124
+ experts_type = get_moe_impl_class()
125
+ extra_kwargs = {}
126
+ if experts_type.__name__ == "FusedMoE":
127
+ quant_config_name = (
128
+ quant_config.get_name() if quant_config is not None else None
129
+ )
130
+ extra_kwargs = {
131
+ "enable_flashinfer_cutlass_moe": global_server_args_dict[
132
+ "enable_flashinfer_cutlass_moe"
133
+ ],
134
+ # for moe gate_up_proj and down_proj and their bias loading
135
+ "use_weight_loader_fused": quant_config_name != "mxfp4",
136
+ }
137
+ self.experts = experts_type(
138
+ num_experts=config.num_local_experts
139
+ + global_server_args_dict["ep_num_redundant_experts"],
140
+ top_k=config.num_experts_per_tok,
141
+ layer_id=layer_id,
142
+ hidden_size=config.hidden_size,
143
+ intermediate_size=config.intermediate_size,
144
+ quant_config=quant_config,
145
+ activation=self.activation,
146
+ activation_alpha=self.activation_alpha,
147
+ swiglu_limit=self.swiglu_limit,
148
+ with_bias=True,
149
+ prefix=add_prefix("experts", prefix),
150
+ **(
151
+ dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
152
+ if global_server_args_dict["moe_a2a_backend"].is_deepep()
153
+ else {}
154
+ ),
155
+ **extra_kwargs,
156
+ )
157
+
158
+ self.router = ReplicatedLinear(
159
+ config.hidden_size,
160
+ config.num_local_experts,
161
+ bias=True,
162
+ quant_config=None,
163
+ prefix=add_prefix("gate", prefix),
164
+ params_dtype=config.torch_dtype,
165
+ )
166
+
167
+ def forward(
168
+ self,
169
+ hidden_states: torch.Tensor,
170
+ forward_batch: Optional[ForwardBatch] = None,
171
+ should_allreduce_fusion: bool = False,
172
+ ) -> torch.Tensor:
173
+ if not global_server_args_dict["moe_a2a_backend"].is_deepep():
174
+ return self.forward_normal(hidden_states, should_allreduce_fusion)
175
+ else:
176
+ raise Exception("forward_deepep branch not implemented yet")
177
+
178
+ def get_moe_weights(self):
179
+ return [
180
+ x.data
181
+ for name, x in self.experts.named_parameters()
182
+ if name not in ["correction_bias"]
183
+ ]
184
+
185
+ def forward_normal(
186
+ self,
187
+ hidden_states: torch.Tensor,
188
+ should_allreduce_fusion: bool = False,
189
+ ) -> torch.Tensor:
190
+ num_tokens, hidden_dim = hidden_states.shape
191
+ hidden_states = hidden_states.view(-1, hidden_dim)
192
+
193
+ # router_logits: (num_tokens, n_experts)
194
+ router_logits, _ = self.router(hidden_states)
195
+
196
+ kwargs = {"hidden_states": hidden_states}
197
+ if self.topk is not None:
198
+ kwargs["topk_output"] = self.topk(hidden_states, router_logits)
199
+ else:
200
+ kwargs["topk_output"] = (self.top_k, router_logits)
201
+ final_hidden_states = self.experts(**kwargs)
202
+
203
+ if self.tp_size > 1 and not should_allreduce_fusion:
204
+ final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
205
+
206
+ ans = final_hidden_states.view(num_tokens, hidden_dim)
207
+ return ans
208
+
209
+
210
+ def _enable_fused_set_kv_buffer():
211
+ return _is_cuda
212
+
213
+
214
+ # TODO maybe move to a model-common utils
215
+ def _create_fused_set_kv_buffer_arg(
216
+ value: torch.Tensor,
217
+ layer: RadixAttention,
218
+ forward_batch: ForwardBatch,
219
+ ):
220
+ layer_id = layer.layer_id
221
+ token_to_kv_pool = forward_batch.token_to_kv_pool
222
+
223
+ k_buffer = token_to_kv_pool.get_key_buffer(layer_id)
224
+ v_buffer = token_to_kv_pool.get_value_buffer(layer_id)
225
+
226
+ return FusedSetKVBufferArg(
227
+ value=value,
228
+ k_buffer=k_buffer.view(k_buffer.shape[0], -1),
229
+ v_buffer=v_buffer.view(v_buffer.shape[0], -1),
230
+ k_scale=layer.k_scale,
231
+ v_scale=layer.v_scale,
232
+ cache_loc=forward_batch.out_cache_loc,
233
+ )
234
+
235
+
236
+ class GptOssAttention(nn.Module):
237
+ def __init__(
238
+ self,
239
+ hidden_size: int,
240
+ num_heads: int,
241
+ num_kv_heads: int,
242
+ layer_id: int = 0,
243
+ rope_theta: float = 10000,
244
+ rope_scaling: Optional[Dict[str, Any]] = None,
245
+ max_position_embeddings: int = 8192,
246
+ head_dim: Optional[int] = None,
247
+ rms_norm_eps: float = 1e-06,
248
+ attention_bias: bool = False,
249
+ quant_config: Optional[QuantizationConfig] = None,
250
+ prefix: str = "",
251
+ sliding_window_size: int = -1, # if -1, normal attention, else, window attention.
252
+ layer_type: str = "",
253
+ params_dtype: torch.dtype = torch.bfloat16,
254
+ ) -> None:
255
+ super().__init__()
256
+ self.hidden_size = hidden_size
257
+ self.sliding_window_size = sliding_window_size
258
+
259
+ attn_tp_rank = get_attention_tp_rank()
260
+ attn_tp_size = get_attention_tp_size()
261
+
262
+ self.total_num_heads = num_heads
263
+ assert self.total_num_heads % attn_tp_size == 0
264
+ self.num_heads = self.total_num_heads // attn_tp_size
265
+ self.total_num_kv_heads = num_kv_heads
266
+ if self.total_num_kv_heads >= attn_tp_size:
267
+ # Number of KV heads is greater than TP size, so we partition
268
+ # the KV heads across multiple tensor parallel GPUs.
269
+ assert self.total_num_kv_heads % attn_tp_size == 0
270
+ else:
271
+ # Number of KV heads is less than TP size, so we replicate
272
+ # the KV heads across multiple tensor parallel GPUs.
273
+ assert attn_tp_size % self.total_num_kv_heads == 0
274
+ self.num_kv_heads = max(1, self.total_num_kv_heads // attn_tp_size)
275
+ self.head_dim = head_dim or hidden_size // self.total_num_heads
276
+ self.q_size = self.num_heads * self.head_dim
277
+ self.kv_size = self.num_kv_heads * self.head_dim
278
+ self.scaling = self.head_dim**-0.5
279
+ self.rope_theta = rope_theta
280
+ self.max_position_embeddings = max_position_embeddings
281
+ self.tp_rank = get_tensor_model_parallel_rank()
282
+
283
+ self.qkv_proj = QKVParallelLinear(
284
+ hidden_size,
285
+ self.head_dim,
286
+ self.total_num_heads,
287
+ self.total_num_kv_heads,
288
+ bias=attention_bias,
289
+ params_dtype=params_dtype,
290
+ quant_config=quant_config,
291
+ tp_rank=attn_tp_rank,
292
+ tp_size=attn_tp_size,
293
+ prefix=add_prefix("qkv_proj", prefix),
294
+ )
295
+
296
+ self.sinks = nn.Parameter(
297
+ torch.empty(self.num_heads, dtype=torch.bfloat16), requires_grad=False
298
+ )
299
+
300
+ self.o_proj = RowParallelLinear(
301
+ self.total_num_heads * self.head_dim,
302
+ hidden_size,
303
+ bias=attention_bias,
304
+ quant_config=quant_config,
305
+ tp_rank=attn_tp_rank,
306
+ tp_size=attn_tp_size,
307
+ reduce_results=False,
308
+ params_dtype=params_dtype,
309
+ prefix=add_prefix("o_proj", prefix),
310
+ )
311
+
312
+ self.rotary_emb = get_rope(
313
+ self.head_dim,
314
+ rotary_dim=self.head_dim,
315
+ max_position=max_position_embeddings,
316
+ base=rope_theta,
317
+ rope_scaling=rope_scaling,
318
+ )
319
+
320
+ assert layer_type in {"sliding_attention", "full_attention"}
321
+ use_sliding_window = layer_type == "sliding_attention"
322
+ self.attn = RadixAttention(
323
+ self.num_heads,
324
+ self.head_dim,
325
+ self.scaling,
326
+ num_kv_heads=self.num_kv_heads,
327
+ layer_id=layer_id,
328
+ prefix=add_prefix("attn", prefix),
329
+ sliding_window_size=(sliding_window_size if use_sliding_window else -1),
330
+ )
331
+ self.layer_id = layer_id
332
+
333
+ def forward_prepare(
334
+ self,
335
+ positions: torch.Tensor,
336
+ hidden_states: torch.Tensor,
337
+ forward_batch: ForwardBatch,
338
+ ):
339
+ if hidden_states.shape[0] == 0:
340
+ return hidden_states, forward_batch, None
341
+ qkv, _ = self.qkv_proj(hidden_states)
342
+ q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
343
+
344
+ q, k = self.rotary_emb(
345
+ positions,
346
+ q,
347
+ k,
348
+ fused_set_kv_buffer_arg=(
349
+ _create_fused_set_kv_buffer_arg(
350
+ value=v,
351
+ layer=self.attn,
352
+ forward_batch=forward_batch,
353
+ )
354
+ if _enable_fused_set_kv_buffer()
355
+ else None
356
+ ),
357
+ )
358
+ inner_state = q, k, v, forward_batch
359
+ return None, forward_batch, inner_state
360
+
361
+ def forward_core(self, intermediate_state):
362
+ hidden_states, forward_batch, inner_state = intermediate_state
363
+ if inner_state is None:
364
+ return hidden_states
365
+ attn_output = self.attn(
366
+ *inner_state,
367
+ sinks=self.sinks,
368
+ save_kv_cache=not _enable_fused_set_kv_buffer(),
369
+ )
370
+ output, _ = self.o_proj(attn_output)
371
+ return output
372
+
373
+ def forward(
374
+ self,
375
+ positions: torch.Tensor,
376
+ hidden_states: torch.Tensor,
377
+ forward_batch: ForwardBatch,
378
+ ) -> torch.Tensor:
379
+ s = self.forward_prepare(
380
+ positions=positions,
381
+ hidden_states=hidden_states,
382
+ forward_batch=forward_batch,
383
+ )
384
+ return self.forward_core(s)
385
+
386
+
387
+ class GptOssDecoderLayer(nn.Module):
388
+ def __init__(
389
+ self,
390
+ config: GptOssConfig,
391
+ layer_id: int,
392
+ quant_config: Optional[QuantizationConfig] = None,
393
+ prefix: str = "",
394
+ sliding_window_size: int | None = None,
395
+ ) -> None:
396
+ super().__init__()
397
+ self.config = config
398
+ self.hidden_size = config.hidden_size
399
+ rope_theta = getattr(config, "rope_theta", 10000)
400
+ rope_scaling = getattr(config, "rope_scaling", None)
401
+ max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
402
+ head_dim = getattr(
403
+ config, "head_dim", config.hidden_size // config.num_attention_heads
404
+ )
405
+ rms_norm_eps = config.rms_norm_eps
406
+ attention_bias = config.attention_bias
407
+
408
+ if sliding_window_size is None:
409
+ self.sliding_window_size = get_attention_sliding_window_size(self.config)
410
+ else:
411
+ self.sliding_window_size = sliding_window_size
412
+
413
+ self.self_attn = GptOssAttention(
414
+ hidden_size=self.hidden_size,
415
+ num_heads=config.num_attention_heads,
416
+ num_kv_heads=config.num_key_value_heads,
417
+ layer_id=layer_id,
418
+ rope_theta=rope_theta,
419
+ rope_scaling=rope_scaling,
420
+ max_position_embeddings=max_position_embeddings,
421
+ head_dim=head_dim,
422
+ rms_norm_eps=rms_norm_eps,
423
+ attention_bias=attention_bias,
424
+ prefix=add_prefix("self_attn", prefix),
425
+ sliding_window_size=self.sliding_window_size,
426
+ layer_type=config.layer_types[layer_id],
427
+ params_dtype=config.torch_dtype,
428
+ )
429
+
430
+ self.layer_id = layer_id
431
+
432
+ self.attn_tp_size = get_attention_tp_size()
433
+ self.attn_tp_rank = get_attention_tp_rank()
434
+ self.local_dp_size = get_local_attention_dp_size()
435
+
436
+ # GptOss all layers are sparse and have no nextn now
437
+ self.is_layer_sparse = True
438
+ self.is_nextn = False
439
+ is_previous_layer_sparse = True
440
+
441
+ self.layer_scatter_modes = LayerScatterModes.init_new(
442
+ layer_id=layer_id,
443
+ num_layers=config.num_hidden_layers,
444
+ is_layer_sparse=self.is_layer_sparse,
445
+ is_previous_layer_sparse=is_previous_layer_sparse,
446
+ )
447
+
448
+ if self.is_layer_sparse:
449
+ self.mlp = GptOssSparseMoeBlock(
450
+ layer_id=self.layer_id,
451
+ config=config,
452
+ quant_config=quant_config,
453
+ prefix=add_prefix("mlp", prefix),
454
+ )
455
+ else:
456
+ raise NotImplementedError(
457
+ "Dense MLP is not implemented for GptOssDecoderLayer. "
458
+ "Please use GptOssSparseMoeBlock instead."
459
+ )
460
+ self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
461
+ self.post_attention_layernorm = RMSNorm(
462
+ config.hidden_size, eps=config.rms_norm_eps
463
+ )
464
+
465
+ self.layer_communicator = LayerCommunicator(
466
+ layer_scatter_modes=self.layer_scatter_modes,
467
+ input_layernorm=self.input_layernorm,
468
+ post_attention_layernorm=self.post_attention_layernorm,
469
+ )
470
+
471
+ self._fuse_allreduce_lookup_table = self._build_fuse_allreduce_lookup_table()
472
+
473
+ def _should_fuse_mlp_allreduce_with_next_layer(self, forward_batch) -> bool:
474
+ """Check if MLP allreduce can be fused with next layer's residual_rmsnorm"""
475
+
476
+ batch_size = (
477
+ forward_batch.input_ids.shape[0]
478
+ if hasattr(forward_batch, "input_ids")
479
+ else 0
480
+ )
481
+
482
+ if batch_size > 128:
483
+ return False
484
+
485
+ return self._fuse_allreduce_lookup_table.get(batch_size, False)
486
+
487
+ def _build_fuse_allreduce_lookup_table(self):
488
+ static_conditions_met = (
489
+ self.layer_id != self.config.num_hidden_layers - 1
490
+ and get_tensor_model_parallel_world_size() > 1
491
+ and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
492
+ and _is_sm100_supported
493
+ and _is_flashinfer_available
494
+ )
495
+
496
+ if not static_conditions_met:
497
+ return {}
498
+
499
+ lookup_table = {}
500
+ for batch_size in range(129): # 0 to 128
501
+ is_last_layer = self.layer_id == self.config.num_hidden_layers - 1
502
+ should_fuse = batch_size > 0 and batch_size <= 128 and not is_last_layer
503
+ lookup_table[batch_size] = should_fuse
504
+
505
+ return lookup_table
506
+
507
+ def forward(
508
+ self,
509
+ positions: torch.Tensor,
510
+ hidden_states: torch.Tensor,
511
+ forward_batch: ForwardBatch,
512
+ residual: Optional[torch.Tensor],
513
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
514
+ hidden_states, residual = self.layer_communicator.prepare_attn(
515
+ hidden_states, residual, forward_batch
516
+ )
517
+
518
+ if hidden_states.shape[0] != 0:
519
+ hidden_states = self.self_attn(
520
+ positions=positions,
521
+ hidden_states=hidden_states,
522
+ forward_batch=forward_batch,
523
+ )
524
+
525
+ hidden_states, residual = self.layer_communicator.prepare_mlp(
526
+ hidden_states, residual, forward_batch
527
+ )
528
+
529
+ should_allreduce_fusion = (
530
+ self._should_fuse_mlp_allreduce_with_next_layer(forward_batch)
531
+ and not self.is_nextn
532
+ )
533
+
534
+ hidden_states = self.mlp(hidden_states, forward_batch, should_allreduce_fusion)
535
+
536
+ if should_allreduce_fusion:
537
+ hidden_states._sglang_needs_allreduce_fusion = True
538
+
539
+ if not should_allreduce_fusion:
540
+ hidden_states, residual = self.layer_communicator.postprocess_layer(
541
+ hidden_states, residual, forward_batch
542
+ )
543
+
544
+ return hidden_states, residual
545
+
546
+
547
+ class GptOssModel(nn.Module):
548
+ def __init__(
549
+ self,
550
+ config: PretrainedConfig,
551
+ quant_config: Optional[QuantizationConfig] = None,
552
+ prefix: str = "",
553
+ decoder_layer_type: type[nn.Module] = GptOssDecoderLayer,
554
+ ) -> None:
555
+ super().__init__()
556
+ self.padding_idx = config.pad_token_id
557
+ self.vocab_size = config.vocab_size
558
+ self.pp_group = get_pp_group()
559
+
560
+ if self.pp_group.is_first_rank:
561
+ self.embed_tokens = VocabParallelEmbedding(
562
+ config.vocab_size,
563
+ config.hidden_size,
564
+ enable_tp=not global_server_args_dict["enable_dp_attention"],
565
+ prefix=add_prefix("embed_tokens", prefix),
566
+ )
567
+ else:
568
+ self.embed_tokens = PPMissingLayer()
569
+
570
+ # Use the provided decoder layer type or default to GptOssDecoderLayer
571
+ decoder_layer_type = decoder_layer_type or GptOssDecoderLayer
572
+ self.layers, self.start_layer, self.end_layer = make_layers(
573
+ config.num_hidden_layers,
574
+ lambda idx, prefix: decoder_layer_type(
575
+ layer_id=idx,
576
+ config=config,
577
+ quant_config=quant_config,
578
+ prefix=prefix,
579
+ ),
580
+ pp_rank=self.pp_group.rank_in_group,
581
+ pp_size=self.pp_group.world_size,
582
+ prefix=add_prefix("layers", prefix),
583
+ )
584
+ if self.pp_group.is_last_rank:
585
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
586
+ else:
587
+ self.norm = PPMissingLayer(return_tuple=True)
588
+
589
+ self.layers_to_capture = []
590
+
591
+ def forward(
592
+ self,
593
+ input_ids: torch.Tensor,
594
+ positions: torch.Tensor,
595
+ forward_batch: ForwardBatch,
596
+ input_embeds: torch.Tensor = None,
597
+ pp_proxy_tensors: Optional[PPProxyTensors] = None,
598
+ ) -> Union[torch.Tensor, PPProxyTensors]:
599
+ if self.pp_group.is_first_rank:
600
+ if input_embeds is None:
601
+ hidden_states = self.embed_tokens(input_ids)
602
+ else:
603
+ hidden_states = input_embeds
604
+ residual = None
605
+ else:
606
+ assert pp_proxy_tensors is not None
607
+ hidden_states = pp_proxy_tensors["hidden_states"]
608
+ residual = pp_proxy_tensors["residual"]
609
+
610
+ aux_hidden_states = []
611
+ for i in range(self.start_layer, self.end_layer):
612
+ with get_global_expert_distribution_recorder().with_current_layer(i):
613
+ if i in self.layers_to_capture:
614
+ aux_hidden_states.append(hidden_states + residual)
615
+ layer = self.layers[i]
616
+ hidden_states, residual = layer(
617
+ positions, hidden_states, forward_batch, residual
618
+ )
619
+ if not self.pp_group.is_last_rank:
620
+ return PPProxyTensors(
621
+ {
622
+ "hidden_states": hidden_states,
623
+ "residual": residual,
624
+ }
625
+ )
626
+ else:
627
+ if hidden_states.shape[0] != 0:
628
+ if residual is None:
629
+ hidden_states = self.norm(hidden_states)
630
+ else:
631
+ hidden_states, _ = self.norm(hidden_states, residual)
632
+ if len(aux_hidden_states) == 0:
633
+ return hidden_states
634
+
635
+ return hidden_states, aux_hidden_states
636
+
637
+
638
+ class GptOssForCausalLM(nn.Module):
639
+ fall_back_to_pt_during_load = False
640
+
641
+ def __init__(
642
+ self,
643
+ config: GptOssConfig,
644
+ quant_config: Optional[QuantizationConfig] = None,
645
+ prefix: str = "",
646
+ ) -> None:
647
+ super().__init__()
648
+ self.pp_group = get_pp_group()
649
+ self.config = config
650
+ self.quant_config = quant_config
651
+ self.model = GptOssModel(
652
+ config, quant_config, prefix=add_prefix("model", prefix)
653
+ )
654
+ self.lm_head = ParallelLMHead(
655
+ config.vocab_size,
656
+ config.hidden_size,
657
+ # quant_config=quant_config,
658
+ prefix=add_prefix("lm_head", prefix),
659
+ use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
660
+ )
661
+ self.logits_processor = LogitsProcessor(config)
662
+ self.capture_aux_hidden_states = False
663
+
664
+ self._routed_experts_weights_of_layer = LazyValue(
665
+ lambda: {
666
+ layer_id: self.model.layers[layer_id].mlp.get_moe_weights()
667
+ for layer_id in range(self.start_layer, self.end_layer)
668
+ if isinstance(self.model.layers[layer_id].mlp, GptOssSparseMoeBlock)
669
+ }
670
+ )
671
+
672
+ @property
673
+ def routed_experts_weights_of_layer(self):
674
+ return self._routed_experts_weights_of_layer.value
675
+
676
+ @torch.no_grad()
677
+ def forward(
678
+ self,
679
+ input_ids: torch.Tensor,
680
+ positions: torch.Tensor,
681
+ forward_batch: ForwardBatch,
682
+ input_embeds: torch.Tensor = None,
683
+ pp_proxy_tensors: Optional[PPProxyTensors] = None,
684
+ ) -> torch.Tensor:
685
+ hidden_states = self.model(
686
+ input_ids,
687
+ positions,
688
+ forward_batch,
689
+ input_embeds,
690
+ pp_proxy_tensors=pp_proxy_tensors,
691
+ )
692
+
693
+ aux_hidden_states = None
694
+ if self.capture_aux_hidden_states:
695
+ hidden_states, aux_hidden_states = hidden_states
696
+
697
+ if self.pp_group.is_last_rank:
698
+ return self.logits_processor(
699
+ input_ids,
700
+ hidden_states,
701
+ self.lm_head,
702
+ forward_batch,
703
+ aux_hidden_states,
704
+ )
705
+ else:
706
+ return hidden_states
707
+
708
+ @property
709
+ def start_layer(self):
710
+ return self.model.start_layer
711
+
712
+ @property
713
+ def end_layer(self):
714
+ return self.model.end_layer
715
+
716
+ def _get_default_weight_mapping(self):
717
+ """Generate default weight name mapping for GptOss safetensors."""
718
+ weight_mapping = {}
719
+
720
+ # Map router weights to gate
721
+ weight_mapping["embedding.weight"] = "model.embed_tokens.weight"
722
+ weight_mapping["unembedding.weight"] = "lm_head.weight"
723
+ weight_mapping["norm.scale"] = "model.norm.weight"
724
+ for layer_id in range(self.config.num_hidden_layers):
725
+ weight_mapping[f"block.{layer_id}.attn.q_proj.weight"] = (
726
+ f"model.layers.{layer_id}.self_attn.q_proj.weight"
727
+ )
728
+ weight_mapping[f"block.{layer_id}.attn.q_proj.bias"] = (
729
+ f"model.layers.{layer_id}.self_attn.q_proj.bias"
730
+ )
731
+
732
+ weight_mapping[f"block.{layer_id}.attn.k_proj.weight"] = (
733
+ f"model.layers.{layer_id}.self_attn.k_proj.weight"
734
+ )
735
+ weight_mapping[f"block.{layer_id}.attn.k_proj.bias"] = (
736
+ f"model.layers.{layer_id}.self_attn.k_proj.bias"
737
+ )
738
+
739
+ weight_mapping[f"block.{layer_id}.attn.v_proj.weight"] = (
740
+ f"model.layers.{layer_id}.self_attn.v_proj.weight"
741
+ )
742
+ weight_mapping[f"block.{layer_id}.attn.v_proj.bias"] = (
743
+ f"model.layers.{layer_id}.self_attn.v_proj.bias"
744
+ )
745
+
746
+ weight_mapping[f"block.{layer_id}.attn.out.weight"] = (
747
+ f"model.layers.{layer_id}.self_attn.o_proj.weight"
748
+ )
749
+ weight_mapping[f"block.{layer_id}.attn.out.bias"] = (
750
+ f"model.layers.{layer_id}.self_attn.o_proj.bias"
751
+ )
752
+ weight_mapping[f"block.{layer_id}.attn.sinks"] = (
753
+ f"model.layers.{layer_id}.self_attn.sinks"
754
+ )
755
+ weight_mapping[f"block.{layer_id}.attn.norm.scale"] = (
756
+ f"model.layers.{layer_id}.input_layernorm.weight"
757
+ )
758
+
759
+ weight_mapping[f"block.{layer_id}.mlp.gate.weight"] = (
760
+ f"model.layers.{layer_id}.mlp.router.weight"
761
+ )
762
+ weight_mapping[f"block.{layer_id}.mlp.gate.bias"] = (
763
+ f"model.layers.{layer_id}.mlp.router.bias"
764
+ )
765
+ weight_mapping[f"block.{layer_id}.mlp.norm.scale"] = (
766
+ f"model.layers.{layer_id}.post_attention_layernorm.weight"
767
+ )
768
+ weight_mapping[f"block.{layer_id}.mlp.experts.gate_up_proj"] = (
769
+ f"model.layers.{layer_id}.mlp.experts.gate_up_proj"
770
+ )
771
+ weight_mapping[f"block.{layer_id}.mlp.gate_up_proj_bias"] = (
772
+ f"model.layers.{layer_id}.mlp.experts.gate_up_proj_bias"
773
+ )
774
+ weight_mapping[f"block.{layer_id}.mlp.down_proj"] = (
775
+ f"model.layers.{layer_id}.mlp.experts.mlp2_weight"
776
+ )
777
+ weight_mapping[f"block.{layer_id}.mlp.down_proj_bias"] = (
778
+ f"model.layers.{layer_id}.mlp.experts.mlp2_bias"
779
+ )
780
+
781
+ return weight_mapping
782
+
783
+ # TODO beautify code
784
+ def load_weights(
785
+ self,
786
+ weights: Iterable[Tuple[str, torch.Tensor]],
787
+ is_nextn: bool = False,
788
+ weight_name_mapping: dict = None,
789
+ ):
790
+ quant_config_name = (
791
+ self.quant_config.get_name() if self.quant_config is not None else None
792
+ )
793
+ if quant_config_name != "mxfp4":
794
+ self._load_normal_weights(
795
+ weights, is_nextn=is_nextn, weight_name_mapping=weight_name_mapping
796
+ )
797
+ else:
798
+ self._load_weights_mxfp4(
799
+ weights, is_nextn=is_nextn, weight_name_mapping=weight_name_mapping
800
+ )
801
+
802
+ def _load_weights_mxfp4(self, weights, is_nextn, weight_name_mapping):
803
+ mxfp4_weights = []
804
+ normal_weights = []
805
+
806
+ for name, weight in weights:
807
+ if (
808
+ ".experts" in name
809
+ and self.quant_config is not None
810
+ and self.quant_config.get_name() == "mxfp4"
811
+ ):
812
+ mxfp4_weights.append((name, weight))
813
+ else:
814
+ normal_weights.append((name, weight))
815
+
816
+ mxfp4_loaded_params = self._load_mxfp4_experts_weights(mxfp4_weights)
817
+ self._load_normal_weights(
818
+ normal_weights,
819
+ is_nextn=is_nextn,
820
+ weight_name_mapping=weight_name_mapping,
821
+ other_loaded_param_names=mxfp4_loaded_params,
822
+ )
823
+
824
+ def _load_mxfp4_experts_weights(self, weights):
825
+
826
+ params_dict = dict(self.named_parameters())
827
+ loaded_params: set[str] = set()
828
+ mxfp4_block = 32
829
+
830
+ moe_tp_rank = get_moe_tensor_parallel_rank()
831
+ moe_tp_size = get_moe_tensor_parallel_world_size()
832
+ moe_ep_rank = get_moe_expert_parallel_rank()
833
+ moe_ep_size = get_moe_expert_parallel_world_size()
834
+
835
+ intermediate_size = self.config.intermediate_size
836
+ intermediate_size_block = intermediate_size // mxfp4_block
837
+ per_rank_intermediate_size_block = intermediate_size_block // moe_tp_size
838
+ per_rank_intermediate_size = per_rank_intermediate_size_block * mxfp4_block
839
+
840
+ # Calculate common slicing bounds for current rank
841
+ assert self.config.num_local_experts % moe_ep_size == 0
842
+ moe_num_global_experts = self.config.num_local_experts
843
+ moe_num_local_experts = self.config.num_local_experts // moe_ep_size
844
+ moe_tp_rank_start = moe_tp_rank * per_rank_intermediate_size
845
+ moe_tp_rank_end = min(
846
+ (moe_tp_rank + 1) * per_rank_intermediate_size, intermediate_size
847
+ )
848
+ moe_ep_rank_start = moe_ep_rank * moe_num_local_experts
849
+ moe_ep_rank_end = (moe_ep_rank + 1) * moe_num_local_experts
850
+
851
+ for name, weight in weights:
852
+ weight = weight.cuda()
853
+
854
+ if "gate_up_proj_blocks" in name:
855
+ # Handle MLP gate and up projection weights
856
+ new_name = name.replace("gate_up_proj_blocks", "w13_weight")
857
+
858
+ # flat weight from (E, 2 * N, block_size, entry_per_block)
859
+ # to (E, 2 * N, -1), shouldn't trigger copy for contiguous
860
+ weight = weight.view(
861
+ moe_num_global_experts, 2 * intermediate_size, -1
862
+ ).contiguous()
863
+
864
+ narrow_weight = weight[
865
+ moe_ep_rank_start:moe_ep_rank_end,
866
+ 2 * moe_tp_rank_start : 2 * moe_tp_rank_end,
867
+ ...,
868
+ ]
869
+
870
+ param = params_dict[new_name]
871
+ weight_loader = getattr(param, "weight_loader", default_weight_loader)
872
+ weight_loader(
873
+ param,
874
+ narrow_weight,
875
+ weight_name=new_name,
876
+ shard_id=None,
877
+ expert_id=None,
878
+ )
879
+ loaded_params.add(new_name)
880
+
881
+ elif "down_proj_blocks" in name:
882
+ # Handle MLP down projection weights
883
+ new_name = name.replace("down_proj_blocks", "w2_weight")
884
+ # same flatten here, but since 2 mx4 value are packed in 1
885
+ # uint8, divide by 2
886
+ weight = weight.view(
887
+ moe_num_global_experts, -1, intermediate_size // 2
888
+ ).contiguous()
889
+ narrow_weight = weight[
890
+ moe_ep_rank_start:moe_ep_rank_end,
891
+ ...,
892
+ moe_tp_rank_start // 2 : moe_tp_rank_end // 2,
893
+ ]
894
+
895
+ param = params_dict[new_name]
896
+ weight_loader = getattr(param, "weight_loader", default_weight_loader)
897
+ weight_loader(
898
+ param,
899
+ narrow_weight,
900
+ weight_name=new_name,
901
+ shard_id=None,
902
+ expert_id=None,
903
+ )
904
+ loaded_params.add(new_name)
905
+
906
+ elif "gate_up_proj_scales" in name:
907
+ # Handle MLP gate and up projection weights scale
908
+ new_name = name.replace("gate_up_proj_scales", "w13_weight_scale")
909
+ narrow_weight = weight[
910
+ moe_ep_rank_start:moe_ep_rank_end,
911
+ 2 * moe_tp_rank_start : 2 * moe_tp_rank_end,
912
+ ...,
913
+ ]
914
+
915
+ param = params_dict[new_name]
916
+ weight_loader = getattr(param, "weight_loader", default_weight_loader)
917
+ weight_loader(
918
+ param,
919
+ narrow_weight,
920
+ weight_name=new_name,
921
+ shard_id=None,
922
+ expert_id=None,
923
+ )
924
+ loaded_params.add(new_name)
925
+
926
+ elif "down_proj_scales" in name:
927
+ # Handle MLP down projection weights
928
+ new_name = name.replace("down_proj_scales", "w2_weight_scale")
929
+ narrow_weight = weight[
930
+ moe_ep_rank_start:moe_ep_rank_end,
931
+ ...,
932
+ moe_tp_rank_start // mxfp4_block : moe_tp_rank_end // mxfp4_block,
933
+ ]
934
+
935
+ param = params_dict[new_name]
936
+ weight_loader = getattr(param, "weight_loader", default_weight_loader)
937
+ weight_loader(
938
+ param,
939
+ narrow_weight,
940
+ weight_name=new_name,
941
+ shard_id=None,
942
+ expert_id=None,
943
+ )
944
+ loaded_params.add(new_name)
945
+ elif "gate_up_proj_bias" in name:
946
+ # Handle MLP gate and up projection biases
947
+ new_name = name.replace("gate_up_proj_bias", "w13_weight_bias")
948
+
949
+ narrow_weight = weight[
950
+ moe_ep_rank_start:moe_ep_rank_end,
951
+ 2 * moe_tp_rank_start : 2 * moe_tp_rank_end,
952
+ ]
953
+
954
+ param = params_dict[new_name]
955
+ weight_loader = getattr(param, "weight_loader", default_weight_loader)
956
+ weight_loader(
957
+ param,
958
+ narrow_weight,
959
+ weight_name=new_name,
960
+ shard_id=None,
961
+ expert_id=None,
962
+ )
963
+ loaded_params.add(new_name)
964
+
965
+ elif "down_proj_bias" in name:
966
+ narrow_weight = weight[moe_ep_rank_start:moe_ep_rank_end, ...]
967
+ if moe_tp_rank != 0:
968
+ narrow_weight = torch.zeros_like(narrow_weight)
969
+
970
+ # Handle MLP down projection bias
971
+ new_name = name.replace("down_proj_bias", "w2_weight_bias")
972
+ param = params_dict[new_name]
973
+ weight_loader = getattr(param, "weight_loader", default_weight_loader)
974
+ weight_loader(
975
+ param,
976
+ narrow_weight,
977
+ weight_name=new_name,
978
+ shard_id=None,
979
+ expert_id=None,
980
+ )
981
+ loaded_params.add(new_name)
982
+
983
+ return loaded_params
984
+
985
+ def _load_normal_weights(
986
+ self,
987
+ weights,
988
+ is_nextn: bool,
989
+ weight_name_mapping: dict,
990
+ other_loaded_param_names=[],
991
+ ):
992
+ tp_rank = get_tensor_model_parallel_rank()
993
+ if is_nextn:
994
+ logging.warning(
995
+ "Loading weights for nextn is currently not supported in GptOssForCausalLM. "
996
+ )
997
+ return
998
+ weights = _canonicalize_weights(self.config, weights)
999
+ weights = sorted(weights, key=lambda x: x[0]) # Sort by name for consistency
1000
+
1001
+ new_weights = []
1002
+ for name, p in weights:
1003
+ if "qkv.weight" in name:
1004
+ q_proj, k_proj, v_proj = p.split(
1005
+ [
1006
+ self.config.num_attention_heads * self.config.head_dim,
1007
+ self.config.num_key_value_heads * self.config.head_dim,
1008
+ self.config.num_key_value_heads * self.config.head_dim,
1009
+ ],
1010
+ dim=0,
1011
+ )
1012
+ new_weights.append(
1013
+ (f"{name.replace('qkv.weight', 'q_proj.weight')}", q_proj)
1014
+ )
1015
+ new_weights.append(
1016
+ (f"{name.replace('qkv.weight', 'k_proj.weight')}", k_proj)
1017
+ )
1018
+ new_weights.append(
1019
+ (f"{name.replace('qkv.weight', 'v_proj.weight')}", v_proj)
1020
+ )
1021
+ elif "qkv.bias" in name:
1022
+ q_bias, k_bias, v_bias = p.split(
1023
+ [
1024
+ self.config.num_attention_heads * self.config.head_dim,
1025
+ self.config.num_key_value_heads * self.config.head_dim,
1026
+ self.config.num_key_value_heads * self.config.head_dim,
1027
+ ],
1028
+ dim=0,
1029
+ )
1030
+ new_weights.append(
1031
+ (f"{name.replace('qkv.bias', 'q_proj.bias')}", q_bias)
1032
+ )
1033
+ new_weights.append(
1034
+ (f"{name.replace('qkv.bias', 'k_proj.bias')}", k_bias)
1035
+ )
1036
+ new_weights.append(
1037
+ (f"{name.replace('qkv.bias', 'v_proj.bias')}", v_bias)
1038
+ )
1039
+ else:
1040
+ new_weights.append((name, p))
1041
+ weights = new_weights
1042
+
1043
+ # Use provided weight name mapping if available, otherwise use default
1044
+ if weight_name_mapping is None:
1045
+ weight_name_mapping = self._get_default_weight_mapping()
1046
+ else:
1047
+ # Merge with default mapping
1048
+ default_mapping = self._get_default_weight_mapping()
1049
+ default_mapping.update(weight_name_mapping)
1050
+ weight_name_mapping = default_mapping
1051
+
1052
+ stacked_params_mapping = [
1053
+ # (param_name, shard_name, shard_id)
1054
+ ("qkv_proj", "q_proj", "q"),
1055
+ ("qkv_proj", "k_proj", "k"),
1056
+ ("qkv_proj", "v_proj", "v"),
1057
+ ]
1058
+ expert_params_mapping = get_moe_impl_class().make_expert_params_mapping_fused(
1059
+ ckpt_gate_up_proj_name="gate_up_proj",
1060
+ ckpt_down_proj_name="down_proj",
1061
+ ckpt_gate_up_proj_bias_name="gate_up_proj_bias",
1062
+ ckpt_down_proj_bias_name="down_proj_bias",
1063
+ )
1064
+
1065
+ params_dict = dict(self.named_parameters())
1066
+ params_checker = {k: False for k, v in params_dict.items()}
1067
+
1068
+ for other_loaded_param_name in other_loaded_param_names:
1069
+ params_checker[other_loaded_param_name] = True
1070
+
1071
+ for name, loaded_weight in weights:
1072
+ loaded_weight = _WeightCreator.maybe_materialize(loaded_weight)
1073
+
1074
+ # Apply weight name mapping if provided
1075
+ if weight_name_mapping and name in weight_name_mapping:
1076
+ name = weight_name_mapping[name]
1077
+
1078
+ layer_id = get_layer_id(name)
1079
+ if (
1080
+ layer_id is not None
1081
+ and hasattr(self.model, "start_layer")
1082
+ and (
1083
+ layer_id < self.model.start_layer
1084
+ or layer_id >= self.model.end_layer
1085
+ )
1086
+ ):
1087
+ continue
1088
+
1089
+ if "rotary_emb.inv_freq" in name:
1090
+ continue
1091
+ for param_name, weight_name, shard_id in stacked_params_mapping:
1092
+ if weight_name not in name:
1093
+ continue
1094
+ if "mlp.experts" in name:
1095
+ continue
1096
+
1097
+ name = name.replace(weight_name, param_name)
1098
+ if name.endswith(".bias") and name not in params_dict:
1099
+ continue
1100
+ if name not in params_dict:
1101
+ continue
1102
+
1103
+ param = params_dict[name]
1104
+ weight_loader = param.weight_loader
1105
+ weight_loader(param, loaded_weight, shard_id)
1106
+ params_checker[name] = True
1107
+ break
1108
+ else:
1109
+ for mapping in expert_params_mapping:
1110
+ param_name, weight_name, shard_id = mapping
1111
+ if weight_name not in name:
1112
+ continue
1113
+ name = name.replace(weight_name, param_name)
1114
+ if name not in params_dict:
1115
+ continue
1116
+ param = params_dict[name]
1117
+ weight_loader = param.weight_loader
1118
+ if "bias" not in name:
1119
+ loaded_weight = loaded_weight.transpose(-2, -1)
1120
+ if "w2_weight_bias" in name and get_moe_tensor_parallel_rank() != 0:
1121
+ loaded_weight = loaded_weight.zero_()
1122
+
1123
+ weight_loader(
1124
+ param,
1125
+ loaded_weight,
1126
+ name,
1127
+ shard_id=shard_id,
1128
+ )
1129
+ params_checker[name] = True
1130
+ break
1131
+ else:
1132
+ if name.endswith(".bias") and name not in params_dict:
1133
+ continue
1134
+ if name not in params_dict:
1135
+ continue
1136
+ if name in params_dict.keys():
1137
+ param = params_dict[name]
1138
+ if "sinks" in name:
1139
+ start = tp_rank * param.numel()
1140
+ param.data.copy_(
1141
+ loaded_weight[start : start + param.numel()]
1142
+ )
1143
+ else:
1144
+ weight_loader = getattr(
1145
+ param, "weight_loader", default_weight_loader
1146
+ )
1147
+ weight_loader(param, loaded_weight)
1148
+ params_checker[name] = True
1149
+ else:
1150
+ logger.warning(f"Parameter {name} not found in params_dict")
1151
+
1152
+ not_loaded_params = [k for k, v in params_checker.items() if not v]
1153
+ if tp_rank == 0:
1154
+ if len(not_loaded_params) > 0:
1155
+ raise Exception(f"Not all parameters loaded: {not_loaded_params}")
1156
+ else:
1157
+ logging.info("All parameters loaded successfully.")
1158
+
1159
+ def get_embed_and_head(self):
1160
+ return self.model.embed_tokens.weight, self.lm_head.weight
1161
+
1162
+ def set_embed_and_head(self, embed, head):
1163
+ del self.model.embed_tokens.weight
1164
+ del self.lm_head.weight
1165
+ self.model.embed_tokens.weight = embed
1166
+ self.lm_head.weight = head
1167
+ torch.cuda.empty_cache()
1168
+ torch.cuda.synchronize()
1169
+
1170
+ def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
1171
+ if not self.pp_group.is_last_rank:
1172
+ return
1173
+
1174
+ if layer_ids is None:
1175
+ self.capture_aux_hidden_states = True
1176
+ num_layers = self.config.num_hidden_layers
1177
+ self.model.layers_to_capture = [2, num_layers // 2, num_layers - 3]
1178
+ else:
1179
+ self.capture_aux_hidden_states = True
1180
+ # we plus 1 here because in sglang, for the ith layer, it takes the output
1181
+ # of the (i-1)th layer as aux hidden state
1182
+ self.model.layers_to_capture = [val + 1 for val in layer_ids]
1183
+
1184
+ @classmethod
1185
+ def get_model_config_for_expert_location(cls, config):
1186
+ return ModelConfigForExpertLocation(
1187
+ num_layers=config.num_hidden_layers,
1188
+ num_logical_experts=config.num_local_experts,
1189
+ num_groups=None,
1190
+ )
1191
+
1192
+ def get_attention_sliding_window_size(self):
1193
+ return get_attention_sliding_window_size(self.config)
1194
+
1195
+
1196
+ def _canonicalize_weights(config, weights_in: Iterable[Tuple[str, torch.Tensor]]):
1197
+ weights_out_dict = dict(weights_in)
1198
+
1199
+ for layer_id in range(config.num_hidden_layers):
1200
+ for name_chunk in ["mlp1_weight", "mlp2_weight"]:
1201
+ name_prefix = f"block.{layer_id}.mlp.{name_chunk}"
1202
+ w_blocks = weights_out_dict.pop(f"{name_prefix}.blocks", None)
1203
+ w_scales = weights_out_dict.pop(f"{name_prefix}.scales", None)
1204
+ if w_blocks is not None:
1205
+ weights_out_dict[name_prefix] = _WeightCreator(
1206
+ partial(
1207
+ _dequant_mlp_weight,
1208
+ debug_name=name_prefix,
1209
+ w_blocks=w_blocks,
1210
+ w_scales=w_scales,
1211
+ )
1212
+ )
1213
+
1214
+ return list(weights_out_dict.items())
1215
+
1216
+
1217
+ def _dequant_mlp_weight(debug_name, w_blocks, w_scales):
1218
+ if get_tensor_model_parallel_rank() == 0:
1219
+ logger.info(f"Dequantize {debug_name} start")
1220
+
1221
+ original_device = w_blocks.device
1222
+
1223
+ w_blocks = w_blocks.cuda()
1224
+ w_scales = w_scales.cuda()
1225
+
1226
+ w_bf16 = dequant_mxfp4(w_block=w_blocks, w_scale=w_scales, out_dtype=torch.bfloat16)
1227
+ w_bf16 = w_bf16.transpose(-2, -1).contiguous()
1228
+
1229
+ if get_tensor_model_parallel_rank() == 0:
1230
+ logger.info(
1231
+ f"Dequantize {debug_name} end {w_blocks.shape=} {w_scales.shape=} {w_bf16.shape=}"
1232
+ )
1233
+
1234
+ return w_bf16.to(original_device)
1235
+
1236
+
1237
+ class _WeightCreator:
1238
+ def __init__(self, fn):
1239
+ self._fn = fn
1240
+
1241
+ @staticmethod
1242
+ def maybe_materialize(obj):
1243
+ if isinstance(obj, _WeightCreator):
1244
+ output = obj._fn()
1245
+ obj._fn = None
1246
+ return output
1247
+
1248
+ return obj
1249
+
1250
+
1251
+ EntryClass = GptOssForCausalLM
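
For orientation, the tensor-parallel and expert-parallel slicing bounds computed in _load_mxfp4_experts_weights above reduce to simple block arithmetic. A short sketch with purely hypothetical sizes (none of these numbers come from the diff; they are chosen only to make the arithmetic come out evenly):

    # Hypothetical sizes, for illustration only.
    intermediate_size = 4096
    mxfp4_block = 32
    moe_tp_size, moe_tp_rank = 4, 1
    moe_ep_size, moe_ep_rank = 2, 1
    num_local_experts = 32

    # The intermediate dimension is split across TP ranks in whole mxfp4 blocks.
    per_rank = (intermediate_size // mxfp4_block // moe_tp_size) * mxfp4_block   # 1024
    tp_start = moe_tp_rank * per_rank                                            # 1024
    tp_end = min((moe_tp_rank + 1) * per_rank, intermediate_size)                # 2048

    # Experts are split across EP ranks.
    local_experts = num_local_experts // moe_ep_size                             # 16
    ep_start = moe_ep_rank * local_experts                                       # 16
    ep_end = (moe_ep_rank + 1) * local_experts                                   # 32

The gate_up slices double these bounds because gate and up projections share a fused 2 x intermediate_size dimension, and the packed down_proj block slices halve them because two mxfp4 values are packed into one uint8, as noted in the code's own comments.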