sglang 0.4.10.post1__py3-none-any.whl → 0.5.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (143)
  1. sglang/bench_one_batch.py +113 -17
  2. sglang/compile_deep_gemm.py +8 -1
  3. sglang/global_config.py +5 -1
  4. sglang/srt/configs/model_config.py +35 -0
  5. sglang/srt/conversation.py +9 -117
  6. sglang/srt/disaggregation/base/conn.py +5 -2
  7. sglang/srt/disaggregation/decode.py +6 -1
  8. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -0
  9. sglang/srt/disaggregation/mooncake/conn.py +243 -135
  10. sglang/srt/disaggregation/prefill.py +3 -0
  11. sglang/srt/distributed/device_communicators/pynccl.py +7 -0
  12. sglang/srt/distributed/device_communicators/pynccl_allocator.py +133 -0
  13. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +42 -3
  14. sglang/srt/distributed/parallel_state.py +22 -9
  15. sglang/srt/entrypoints/context.py +244 -0
  16. sglang/srt/entrypoints/engine.py +8 -5
  17. sglang/srt/entrypoints/harmony_utils.py +370 -0
  18. sglang/srt/entrypoints/http_server.py +106 -15
  19. sglang/srt/entrypoints/openai/protocol.py +227 -1
  20. sglang/srt/entrypoints/openai/serving_chat.py +278 -42
  21. sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
  22. sglang/srt/entrypoints/openai/tool_server.py +174 -0
  23. sglang/srt/entrypoints/tool.py +87 -0
  24. sglang/srt/eplb/expert_distribution.py +4 -2
  25. sglang/srt/eplb/expert_location.py +5 -1
  26. sglang/srt/function_call/harmony_tool_parser.py +130 -0
  27. sglang/srt/hf_transformers_utils.py +55 -13
  28. sglang/srt/jinja_template_utils.py +8 -1
  29. sglang/srt/layers/attention/aiter_backend.py +5 -8
  30. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  31. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
  32. sglang/srt/layers/attention/flashattention_backend.py +7 -11
  33. sglang/srt/layers/attention/triton_backend.py +85 -14
  34. sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
  35. sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
  36. sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
  37. sglang/srt/layers/attention/trtllm_mla_backend.py +6 -6
  38. sglang/srt/layers/attention/vision.py +40 -15
  39. sglang/srt/layers/communicator.py +35 -8
  40. sglang/srt/layers/dp_attention.py +12 -0
  41. sglang/srt/layers/linear.py +9 -8
  42. sglang/srt/layers/logits_processor.py +9 -1
  43. sglang/srt/layers/moe/cutlass_moe.py +20 -6
  44. sglang/srt/layers/moe/ep_moe/layer.py +87 -107
  45. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  46. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
  47. sglang/srt/layers/moe/fused_moe_triton/layer.py +442 -58
  48. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +169 -15
  49. sglang/srt/layers/moe/token_dispatcher/__init__.py +23 -0
  50. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +12 -1
  51. sglang/srt/layers/moe/{ep_moe/token_dispatcher.py → token_dispatcher/deepep.py} +8 -15
  52. sglang/srt/layers/moe/topk.py +12 -3
  53. sglang/srt/layers/moe/utils.py +59 -0
  54. sglang/srt/layers/quantization/__init__.py +22 -0
  55. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +3 -2
  56. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
  57. sglang/srt/layers/quantization/fp4.py +557 -0
  58. sglang/srt/layers/quantization/fp8.py +8 -7
  59. sglang/srt/layers/quantization/fp8_kernel.py +0 -4
  60. sglang/srt/layers/quantization/fp8_utils.py +29 -0
  61. sglang/srt/layers/quantization/modelopt_quant.py +259 -64
  62. sglang/srt/layers/quantization/mxfp4.py +651 -0
  63. sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
  64. sglang/srt/layers/quantization/quark/__init__.py +0 -0
  65. sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
  66. sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  67. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
  68. sglang/srt/layers/quantization/quark/utils.py +107 -0
  69. sglang/srt/layers/quantization/unquant.py +60 -6
  70. sglang/srt/layers/quantization/w4afp8.py +1 -1
  71. sglang/srt/layers/rotary_embedding.py +225 -1
  72. sglang/srt/layers/utils.py +9 -0
  73. sglang/srt/layers/vocab_parallel_embedding.py +15 -4
  74. sglang/srt/lora/lora_manager.py +70 -14
  75. sglang/srt/lora/lora_registry.py +10 -2
  76. sglang/srt/lora/mem_pool.py +43 -5
  77. sglang/srt/managers/cache_controller.py +61 -32
  78. sglang/srt/managers/data_parallel_controller.py +52 -2
  79. sglang/srt/managers/detokenizer_manager.py +1 -1
  80. sglang/srt/managers/io_struct.py +21 -4
  81. sglang/srt/managers/mm_utils.py +5 -11
  82. sglang/srt/managers/schedule_batch.py +30 -8
  83. sglang/srt/managers/schedule_policy.py +3 -1
  84. sglang/srt/managers/scheduler.py +170 -18
  85. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
  86. sglang/srt/managers/scheduler_recv_skipper.py +37 -0
  87. sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
  88. sglang/srt/managers/template_manager.py +59 -22
  89. sglang/srt/managers/tokenizer_manager.py +137 -67
  90. sglang/srt/managers/tp_worker.py +3 -0
  91. sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
  92. sglang/srt/managers/utils.py +45 -1
  93. sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py +182 -0
  94. sglang/srt/mem_cache/hicache_storage.py +13 -21
  95. sglang/srt/mem_cache/hiradix_cache.py +53 -5
  96. sglang/srt/mem_cache/memory_pool_host.py +1 -1
  97. sglang/srt/mem_cache/multimodal_cache.py +33 -13
  98. sglang/srt/mem_cache/radix_cache_cpp.py +229 -0
  99. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
  100. sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp +35 -0
  101. sglang/srt/model_executor/cuda_graph_runner.py +24 -9
  102. sglang/srt/model_executor/forward_batch_info.py +48 -17
  103. sglang/srt/model_executor/model_runner.py +24 -2
  104. sglang/srt/model_loader/weight_utils.py +10 -0
  105. sglang/srt/models/bailing_moe.py +425 -0
  106. sglang/srt/models/deepseek_v2.py +95 -50
  107. sglang/srt/models/ernie4.py +426 -0
  108. sglang/srt/models/ernie4_eagle.py +203 -0
  109. sglang/srt/models/gemma3n_mm.py +39 -0
  110. sglang/srt/models/glm4_moe.py +102 -27
  111. sglang/srt/models/gpt_oss.py +1134 -0
  112. sglang/srt/models/grok.py +3 -3
  113. sglang/srt/models/llama4.py +13 -2
  114. sglang/srt/models/mixtral.py +3 -3
  115. sglang/srt/models/mllama4.py +428 -19
  116. sglang/srt/models/qwen2.py +6 -0
  117. sglang/srt/models/qwen2_moe.py +7 -4
  118. sglang/srt/models/qwen3_moe.py +39 -14
  119. sglang/srt/models/step3_vl.py +10 -1
  120. sglang/srt/models/transformers.py +2 -5
  121. sglang/srt/multimodal/processors/base_processor.py +4 -3
  122. sglang/srt/multimodal/processors/gemma3n.py +0 -7
  123. sglang/srt/multimodal/processors/step3_vl.py +3 -1
  124. sglang/srt/operations_strategy.py +1 -1
  125. sglang/srt/reasoning_parser.py +18 -39
  126. sglang/srt/server_args.py +218 -23
  127. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +18 -0
  128. sglang/srt/two_batch_overlap.py +163 -9
  129. sglang/srt/utils.py +41 -26
  130. sglang/srt/weight_sync/utils.py +1 -1
  131. sglang/test/runners.py +4 -4
  132. sglang/test/test_utils.py +4 -4
  133. sglang/version.py +1 -1
  134. {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/METADATA +18 -15
  135. {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/RECORD +143 -116
  136. /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/mooncake_store.py +0 -0
  137. /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/unit_test.py +0 -0
  138. /sglang/srt/mem_cache/{nixl → storage/nixl}/hicache_nixl.py +0 -0
  139. /sglang/srt/mem_cache/{nixl → storage/nixl}/nixl_utils.py +0 -0
  140. /sglang/srt/mem_cache/{nixl → storage/nixl}/test_hicache_nixl_storage.py +0 -0
  141. {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/WHEEL +0 -0
  142. {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/licenses/LICENSE +0 -0
  143. {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/top_level.txt +0 -0
sglang/srt/models/gpt_oss.py (new file)
@@ -0,0 +1,1134 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+
+"""Inference-only GptOss model compatible with HuggingFace weights."""
+
+import logging
+from collections.abc import Iterable
+from functools import partial
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from transformers import PretrainedConfig
+
+from sglang.srt.distributed import (
+    get_moe_expert_parallel_rank,
+    get_moe_expert_parallel_world_size,
+    get_moe_tensor_parallel_rank,
+    get_moe_tensor_parallel_world_size,
+    get_pp_group,
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
+)
+from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
+from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
+from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
+from sglang.srt.layers.dp_attention import (
+    get_attention_tp_rank,
+    get_attention_tp_size,
+    get_local_attention_dp_size,
+)
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+    QKVParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+from sglang.srt.layers.moe.topk import TopK
+from sglang.srt.layers.moe.utils import DeepEPMode
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.quantization.fp8_utils import dequant_mxfp4
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix, make_layers
+
+
+class GptOssConfig(PretrainedConfig):
+    model_type = "gpt_oss"
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+
+logger = logging.getLogger(__name__)
+
+
+# Aligned with HF's implementation, using sliding window inclusive with the last token
+# SGLang assumes exclusive
+def get_attention_sliding_window_size(config):
+    return config.sliding_window - 1
+
+
+class GptOssSparseMoeBlock(nn.Module):
+    def __init__(
+        self,
+        layer_id: int,
+        config: GptOssConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.layer_id = layer_id
+        self.activation = config.hidden_act
+        self.activation_alpha = getattr(config, "hidden_act_alpha", 1.702)
+        self.swiglu_limit = config.swiglu_limit
+
+        if global_server_args_dict["enable_flashinfer_mxfp4_moe"]:
+            self.topk = None
+        else:
+            self.topk = TopK(
+                top_k=config.num_experts_per_tok,
+                renormalize=True,
+            )
+
+        self.top_k = config.num_experts_per_tok
+        experts_type = get_moe_impl_class()
+        extra_kwargs = {}
+        if experts_type.__name__ == "FusedMoE":
+            quant_config_name = (
+                quant_config.get_name() if quant_config is not None else None
+            )
+            extra_kwargs = {
+                "enable_flashinfer_cutlass_moe": global_server_args_dict[
+                    "enable_flashinfer_cutlass_moe"
+                ],
+                # for moe gate_up_proj and down_proj and their bias loading
+                "use_weight_loader_fused": quant_config_name != "mxfp4",
+            }
+        self.experts = experts_type(
+            num_experts=config.num_local_experts
+            + global_server_args_dict["ep_num_redundant_experts"],
+            top_k=config.num_experts_per_tok,
+            layer_id=layer_id,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.intermediate_size,
+            quant_config=quant_config,
+            activation=self.activation,
+            activation_alpha=self.activation_alpha,
+            swiglu_limit=self.swiglu_limit,
+            with_bias=True,
+            prefix=add_prefix("experts", prefix),
+            **(
+                dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
+                if global_server_args_dict["moe_a2a_backend"].is_deepep()
+                else {}
+            ),
+            **extra_kwargs,
+        )
+
+        self.router = ReplicatedLinear(
+            config.hidden_size,
+            config.num_local_experts,
+            bias=True,
+            quant_config=None,
+            prefix=add_prefix("gate", prefix),
+            params_dtype=config.torch_dtype,
+        )
+
+    def forward(
+        self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
+    ) -> torch.Tensor:
+        if not global_server_args_dict["moe_a2a_backend"].is_deepep():
+            return self.forward_normal(hidden_states)
+        else:
+            raise Exception("forward_deepep branch not implemented yet")
+
+    def get_moe_weights(self):
+        return [
+            x.data
+            for name, x in self.experts.named_parameters()
+            if name not in ["correction_bias"]
+        ]
+
+    def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        num_tokens, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+
+        # router_logits: (num_tokens, n_experts)
+        router_logits, _ = self.router(hidden_states)
+
+        kwargs = {"hidden_states": hidden_states}
+        if self.topk is not None:
+            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
+        else:
+            kwargs["topk_output"] = (self.top_k, router_logits)
+        final_hidden_states = self.experts(**kwargs)
+
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+
+        ans = final_hidden_states.view(num_tokens, hidden_dim)
+        return ans
+
+
+class GptOssAttention(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        layer_id: int = 0,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        max_position_embeddings: int = 8192,
+        head_dim: Optional[int] = None,
+        rms_norm_eps: float = 1e-06,
+        attention_bias: bool = False,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        sliding_window_size: int = -1,  # if -1, normal attention, else, window attention.
+        layer_type: str = "",
+        params_dtype: torch.dtype = torch.bfloat16,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.sliding_window_size = sliding_window_size
+
+        attn_tp_rank = get_attention_tp_rank()
+        attn_tp_size = get_attention_tp_size()
+
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % attn_tp_size == 0
+        self.num_heads = self.total_num_heads // attn_tp_size
+        self.total_num_kv_heads = num_kv_heads
+        if self.total_num_kv_heads >= attn_tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % attn_tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert attn_tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // attn_tp_size)
+        self.head_dim = head_dim or hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+        self.tp_rank = get_tensor_model_parallel_rank()
+
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=attention_bias,
+            params_dtype=params_dtype,
+            quant_config=quant_config,
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
+            prefix=add_prefix("qkv_proj", prefix),
+        )
+
+        self.sinks = nn.Parameter(
+            torch.empty(self.num_heads, dtype=params_dtype), requires_grad=False
+        )
+
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=attention_bias,
+            quant_config=quant_config,
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
+            reduce_results=False,
+            params_dtype=params_dtype,
+            prefix=add_prefix("o_proj", prefix),
+        )
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+        )
+
+        assert layer_type in {"sliding_attention", "full_attention"}
+        use_sliding_window = layer_type == "sliding_attention"
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+            prefix=add_prefix("attn", prefix),
+            sliding_window_size=(sliding_window_size if use_sliding_window else -1),
+        )
+        self.layer_id = layer_id
+
+    def forward_prepare(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ):
+        if hidden_states.shape[0] == 0:
+            return hidden_states, forward_batch, None
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
+        inner_state = q, k, v, forward_batch
+        return None, forward_batch, inner_state
+
+    def forward_core(self, intermediate_state):
+        hidden_states, forward_batch, inner_state = intermediate_state
+        if inner_state is None:
+            return hidden_states
+        attn_output = self.attn(*inner_state, sinks=self.sinks.to(torch.float32))
+        output, _ = self.o_proj(attn_output)
+        return output
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        s = self.forward_prepare(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+        )
+        return self.forward_core(s)
+
+
+class GptOssDecoderLayer(nn.Module):
+    def __init__(
+        self,
+        config: GptOssConfig,
+        layer_id: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        sliding_window_size: int | None = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+        head_dim = getattr(
+            config, "head_dim", config.hidden_size // config.num_attention_heads
+        )
+        rms_norm_eps = config.rms_norm_eps
+        attention_bias = config.attention_bias
+
+        if sliding_window_size is None:
+            self.sliding_window_size = get_attention_sliding_window_size(self.config)
+        else:
+            self.sliding_window_size = sliding_window_size
+
+        self.self_attn = GptOssAttention(
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            num_kv_heads=config.num_key_value_heads,
+            layer_id=layer_id,
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            max_position_embeddings=max_position_embeddings,
+            head_dim=head_dim,
+            rms_norm_eps=rms_norm_eps,
+            attention_bias=attention_bias,
+            prefix=add_prefix("self_attn", prefix),
+            sliding_window_size=self.sliding_window_size,
+            layer_type=config.layer_types[layer_id],
+            params_dtype=config.torch_dtype,
+        )
+
+        self.layer_id = layer_id
+
+        self.attn_tp_size = get_attention_tp_size()
+        self.attn_tp_rank = get_attention_tp_rank()
+        self.local_dp_size = get_local_attention_dp_size()
+
+        # GptOss all layers are sparse and have no nextn now
+        self.is_layer_sparse = True
+        is_previous_layer_sparse = True
+
+        self.layer_scatter_modes = LayerScatterModes.init_new(
+            layer_id=layer_id,
+            num_layers=config.num_hidden_layers,
+            is_layer_sparse=self.is_layer_sparse,
+            is_previous_layer_sparse=is_previous_layer_sparse,
+        )
+
+        if self.is_layer_sparse:
+            self.mlp = GptOssSparseMoeBlock(
+                layer_id=self.layer_id,
+                config=config,
+                quant_config=quant_config,
+                prefix=add_prefix("mlp", prefix),
+            )
+        else:
+            raise NotImplementedError(
+                "Dense MLP is not implemented for GptOssDecoderLayer. "
+                "Please use GptOssSparseMoeBlock instead."
+            )
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+        self.layer_communicator = LayerCommunicator(
+            layer_scatter_modes=self.layer_scatter_modes,
+            input_layernorm=self.input_layernorm,
+            post_attention_layernorm=self.post_attention_layernorm,
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        hidden_states, residual = self.layer_communicator.prepare_attn(
+            hidden_states, residual, forward_batch
+        )
+
+        if hidden_states.shape[0] != 0:
+            hidden_states = self.self_attn(
+                positions=positions,
+                hidden_states=hidden_states,
+                forward_batch=forward_batch,
+            )
+
+        hidden_states, residual = self.layer_communicator.prepare_mlp(
+            hidden_states, residual, forward_batch
+        )
+
+        hidden_states = self.mlp(hidden_states, forward_batch)
+
+        hidden_states, residual = self.layer_communicator.postprocess_layer(
+            hidden_states, residual, forward_batch
+        )
+
+        return hidden_states, residual
+
+
+class GptOssModel(nn.Module):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        decoder_layer_type: type[nn.Module] = GptOssDecoderLayer,
+    ) -> None:
+        super().__init__()
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+        self.pp_group = get_pp_group()
+
+        if self.pp_group.is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                enable_tp=not global_server_args_dict["enable_dp_attention"],
+                prefix=add_prefix("embed_tokens", prefix),
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+
+        # Use the provided decoder layer type or default to GptOssDecoderLayer
+        decoder_layer_type = decoder_layer_type or GptOssDecoderLayer
+        self.layers, self.start_layer, self.end_layer = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: decoder_layer_type(
+                layer_id=idx,
+                config=config,
+                quant_config=quant_config,
+                prefix=prefix,
+            ),
+            pp_rank=self.pp_group.rank_in_group,
+            pp_size=self.pp_group.world_size,
+            prefix=add_prefix("layers", prefix),
+        )
+        if self.pp_group.is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer(return_tuple=True)
+
+        self.layers_to_capture = []
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[torch.Tensor, PPProxyTensors]:
+        if self.pp_group.is_first_rank:
+            if input_embeds is None:
+                hidden_states = self.embed_tokens(input_ids)
+            else:
+                hidden_states = input_embeds
+            residual = None
+        else:
+            assert pp_proxy_tensors is not None
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]
+
+        aux_hidden_states = []
+        for i in range(self.start_layer, self.end_layer):
+            with get_global_expert_distribution_recorder().with_current_layer(i):
+                if i in self.layers_to_capture:
+                    aux_hidden_states.append(hidden_states + residual)
+                layer = self.layers[i]
+                hidden_states, residual = layer(
+                    positions, hidden_states, forward_batch, residual
+                )
+        if not self.pp_group.is_last_rank:
+            return PPProxyTensors(
+                {
+                    "hidden_states": hidden_states,
+                    "residual": residual,
+                }
+            )
+        else:
+            if hidden_states.shape[0] != 0:
+                if residual is None:
+                    hidden_states = self.norm(hidden_states)
+                else:
+                    hidden_states, _ = self.norm(hidden_states, residual)
+        if len(aux_hidden_states) == 0:
+            return hidden_states
+
+        return hidden_states, aux_hidden_states
+
+
+class GptOssForCausalLM(nn.Module):
+    fall_back_to_pt_during_load = False
+
+    def __init__(
+        self,
+        config: GptOssConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.pp_group = get_pp_group()
+        self.config = config
+        self.quant_config = quant_config
+        self.model = GptOssModel(
+            config, quant_config, prefix=add_prefix("model", prefix)
+        )
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            # quant_config=quant_config,
+            prefix=add_prefix("lm_head", prefix),
+            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+        )
+        self.logits_processor = LogitsProcessor(config)
+        self.capture_aux_hidden_states = False
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> torch.Tensor:
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            input_embeds,
+            pp_proxy_tensors=pp_proxy_tensors,
+        )
+
+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
+
+        if self.pp_group.is_last_rank:
+            return self.logits_processor(
+                input_ids,
+                hidden_states,
+                self.lm_head,
+                forward_batch,
+                aux_hidden_states,
+            )
+        else:
+            return hidden_states
+
+    @property
+    def start_layer(self):
+        return self.model.start_layer
+
+    @property
+    def end_layer(self):
+        return self.model.end_layer
+
+    def _get_default_weight_mapping(self):
+        """Generate default weight name mapping for GptOss safetensors."""
+        weight_mapping = {}
+
+        # Map router weights to gate
+        weight_mapping["embedding.weight"] = "model.embed_tokens.weight"
+        weight_mapping["unembedding.weight"] = "lm_head.weight"
+        weight_mapping["norm.scale"] = "model.norm.weight"
+        for layer_id in range(self.config.num_hidden_layers):
+            weight_mapping[f"block.{layer_id}.attn.q_proj.weight"] = (
+                f"model.layers.{layer_id}.self_attn.q_proj.weight"
+            )
+            weight_mapping[f"block.{layer_id}.attn.q_proj.bias"] = (
+                f"model.layers.{layer_id}.self_attn.q_proj.bias"
+            )
+
+            weight_mapping[f"block.{layer_id}.attn.k_proj.weight"] = (
+                f"model.layers.{layer_id}.self_attn.k_proj.weight"
+            )
+            weight_mapping[f"block.{layer_id}.attn.k_proj.bias"] = (
+                f"model.layers.{layer_id}.self_attn.k_proj.bias"
+            )
+
+            weight_mapping[f"block.{layer_id}.attn.v_proj.weight"] = (
+                f"model.layers.{layer_id}.self_attn.v_proj.weight"
+            )
+            weight_mapping[f"block.{layer_id}.attn.v_proj.bias"] = (
+                f"model.layers.{layer_id}.self_attn.v_proj.bias"
+            )
+
+            weight_mapping[f"block.{layer_id}.attn.out.weight"] = (
+                f"model.layers.{layer_id}.self_attn.o_proj.weight"
+            )
+            weight_mapping[f"block.{layer_id}.attn.out.bias"] = (
+                f"model.layers.{layer_id}.self_attn.o_proj.bias"
+            )
+            weight_mapping[f"block.{layer_id}.attn.sinks"] = (
+                f"model.layers.{layer_id}.self_attn.sinks"
+            )
+            weight_mapping[f"block.{layer_id}.attn.norm.scale"] = (
+                f"model.layers.{layer_id}.input_layernorm.weight"
+            )
+
+            weight_mapping[f"block.{layer_id}.mlp.gate.weight"] = (
+                f"model.layers.{layer_id}.mlp.router.weight"
+            )
+            weight_mapping[f"block.{layer_id}.mlp.gate.bias"] = (
+                f"model.layers.{layer_id}.mlp.router.bias"
+            )
+            weight_mapping[f"block.{layer_id}.mlp.norm.scale"] = (
+                f"model.layers.{layer_id}.post_attention_layernorm.weight"
+            )
+            weight_mapping[f"block.{layer_id}.mlp.experts.gate_up_proj"] = (
+                f"model.layers.{layer_id}.mlp.experts.gate_up_proj"
+            )
+            weight_mapping[f"block.{layer_id}.mlp.gate_up_proj_bias"] = (
+                f"model.layers.{layer_id}.mlp.experts.gate_up_proj_bias"
+            )
+            weight_mapping[f"block.{layer_id}.mlp.down_proj"] = (
+                f"model.layers.{layer_id}.mlp.experts.mlp2_weight"
+            )
+            weight_mapping[f"block.{layer_id}.mlp.down_proj_bias"] = (
+                f"model.layers.{layer_id}.mlp.experts.mlp2_bias"
+            )
+
+        return weight_mapping
+
+    # TODO beautify code
+    def load_weights(
+        self,
+        weights: Iterable[Tuple[str, torch.Tensor]],
+        is_nextn: bool = False,
+        weight_name_mapping: dict = None,
+    ):
+        quant_config_name = (
+            self.quant_config.get_name() if self.quant_config is not None else None
+        )
+        if quant_config_name != "mxfp4":
+            self._load_normal_weights(
+                weights, is_nextn=is_nextn, weight_name_mapping=weight_name_mapping
+            )
+        else:
+            self._load_weights_mxfp4(
+                weights, is_nextn=is_nextn, weight_name_mapping=weight_name_mapping
+            )
+
+    def _load_weights_mxfp4(self, weights, is_nextn, weight_name_mapping):
+        mxfp4_weights = []
+        normal_weights = []
+
+        for name, weight in weights:
+            if (
+                ".experts" in name
+                and self.quant_config is not None
+                and self.quant_config.get_name() == "mxfp4"
+            ):
+                mxfp4_weights.append((name, weight))
+            else:
+                normal_weights.append((name, weight))
+
+        mxfp4_loaded_params = self._load_mxfp4_experts_weights(mxfp4_weights)
+        self._load_normal_weights(
+            normal_weights,
+            is_nextn=is_nextn,
+            weight_name_mapping=weight_name_mapping,
+            other_loaded_param_names=mxfp4_loaded_params,
+        )
+
+    def _load_mxfp4_experts_weights(self, weights):
+
+        params_dict = dict(self.named_parameters())
+        loaded_params: set[str] = set()
+        mxfp4_block = 32
+
+        moe_tp_rank = get_moe_tensor_parallel_rank()
+        moe_tp_size = get_moe_tensor_parallel_world_size()
+        moe_ep_rank = get_moe_expert_parallel_rank()
+        moe_ep_size = get_moe_expert_parallel_world_size()
+
+        intermediate_size = self.config.intermediate_size
+        intermediate_size_block = intermediate_size // mxfp4_block
+        per_rank_intermediate_size_block = intermediate_size_block // moe_tp_size
+        per_rank_intermediate_size = per_rank_intermediate_size_block * mxfp4_block
+
+        # Calculate common slicing bounds for current rank
+        assert self.config.num_local_experts % moe_ep_size == 0
+        moe_num_global_experts = self.config.num_local_experts
+        moe_num_local_experts = self.config.num_local_experts // moe_ep_size
+        moe_tp_rank_start = moe_tp_rank * per_rank_intermediate_size
+        moe_tp_rank_end = min(
+            (moe_tp_rank + 1) * per_rank_intermediate_size, intermediate_size
+        )
+        moe_ep_rank_start = moe_ep_rank * moe_num_local_experts
+        moe_ep_rank_end = (moe_ep_rank + 1) * moe_num_local_experts
+
+        for name, weight in weights:
+            weight = weight.cuda()
+
+            if "gate_up_proj_blocks" in name:
+                # Handle MLP gate and up projection weights
+                new_name = name.replace("gate_up_proj_blocks", "w13_weight")
+
+                # flat weight from (E, 2 * N, block_size, entry_per_block)
+                # to (E, 2 * N, -1), shouldn't trigger copy for contiguous
+                weight = weight.view(
+                    moe_num_global_experts, 2 * intermediate_size, -1
+                ).contiguous()
+
+                narrow_weight = weight[
+                    moe_ep_rank_start:moe_ep_rank_end,
+                    2 * moe_tp_rank_start : 2 * moe_tp_rank_end,
+                    ...,
+                ]
+
+                param = params_dict[new_name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(
+                    param,
+                    narrow_weight,
+                    weight_name=new_name,
+                    shard_id=None,
+                    expert_id=None,
+                )
+                loaded_params.add(new_name)
+
+            elif "down_proj_blocks" in name:
+                # Handle MLP down projection weights
+                new_name = name.replace("down_proj_blocks", "w2_weight")
+                # same flatten here, but since 2 mx4 value are packed in 1
+                # uint8, divide by 2
+                weight = weight.view(
+                    moe_num_global_experts, -1, intermediate_size // 2
+                ).contiguous()
+                narrow_weight = weight[
+                    moe_ep_rank_start:moe_ep_rank_end,
+                    ...,
+                    moe_tp_rank_start // 2 : moe_tp_rank_end // 2,
+                ]
+
+                param = params_dict[new_name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(
+                    param,
+                    narrow_weight,
+                    weight_name=new_name,
+                    shard_id=None,
+                    expert_id=None,
+                )
+                loaded_params.add(new_name)
+
+            elif "gate_up_proj_scales" in name:
+                # Handle MLP gate and up projection weights scale
+                new_name = name.replace("gate_up_proj_scales", "w13_weight_scale")
+                narrow_weight = weight[
+                    moe_ep_rank_start:moe_ep_rank_end,
+                    2 * moe_tp_rank_start : 2 * moe_tp_rank_end,
+                    ...,
+                ]
+
+                param = params_dict[new_name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(
+                    param,
+                    narrow_weight,
+                    weight_name=new_name,
+                    shard_id=None,
+                    expert_id=None,
+                )
+                loaded_params.add(new_name)
+
+            elif "down_proj_scales" in name:
+                # Handle MLP down projection weights
+                new_name = name.replace("down_proj_scales", "w2_weight_scale")
+                narrow_weight = weight[
+                    moe_ep_rank_start:moe_ep_rank_end,
+                    ...,
+                    moe_tp_rank_start // mxfp4_block : moe_tp_rank_end // mxfp4_block,
+                ]
+
+                param = params_dict[new_name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(
+                    param,
+                    narrow_weight,
+                    weight_name=new_name,
+                    shard_id=None,
+                    expert_id=None,
+                )
+                loaded_params.add(new_name)
+            elif "gate_up_proj_bias" in name:
+                # Handle MLP gate and up projection biases
+                new_name = name.replace("gate_up_proj_bias", "w13_weight_bias")
+
+                narrow_weight = weight[
+                    moe_ep_rank_start:moe_ep_rank_end,
+                    2 * moe_tp_rank_start : 2 * moe_tp_rank_end,
+                ]
+
+                param = params_dict[new_name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(
+                    param,
+                    narrow_weight,
+                    weight_name=new_name,
+                    shard_id=None,
+                    expert_id=None,
+                )
+                loaded_params.add(new_name)
+
+            elif "down_proj_bias" in name:
+                narrow_weight = weight[moe_ep_rank_start:moe_ep_rank_end, ...]
+                if moe_tp_rank != 0:
+                    narrow_weight = torch.zeros_like(narrow_weight)
+
+                # Handle MLP down projection bias
+                new_name = name.replace("down_proj_bias", "w2_weight_bias")
+                param = params_dict[new_name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(
+                    param,
+                    narrow_weight,
+                    weight_name=new_name,
+                    shard_id=None,
+                    expert_id=None,
+                )
+                loaded_params.add(new_name)
+
+        return loaded_params
+
+    def _load_normal_weights(
+        self,
+        weights,
+        is_nextn: bool,
+        weight_name_mapping: dict,
+        other_loaded_param_names=[],
+    ):
+        tp_rank = get_tensor_model_parallel_rank()
+        if is_nextn:
+            logging.warning(
+                "Loading weights for nextn is currently not supported in GptOssForCausalLM. "
+            )
+            return
+        weights = _canonicalize_weights(self.config, weights)
+        weights = sorted(weights, key=lambda x: x[0])  # Sort by name for consistency
+
+        new_weights = []
+        for name, p in weights:
+            if "qkv.weight" in name:
+                q_proj, k_proj, v_proj = p.split(
+                    [
+                        self.config.num_attention_heads * self.config.head_dim,
+                        self.config.num_key_value_heads * self.config.head_dim,
+                        self.config.num_key_value_heads * self.config.head_dim,
+                    ],
+                    dim=0,
+                )
+                new_weights.append(
+                    (f"{name.replace('qkv.weight', 'q_proj.weight')}", q_proj)
+                )
+                new_weights.append(
+                    (f"{name.replace('qkv.weight', 'k_proj.weight')}", k_proj)
+                )
+                new_weights.append(
+                    (f"{name.replace('qkv.weight', 'v_proj.weight')}", v_proj)
+                )
+            elif "qkv.bias" in name:
+                q_bias, k_bias, v_bias = p.split(
+                    [
+                        self.config.num_attention_heads * self.config.head_dim,
+                        self.config.num_key_value_heads * self.config.head_dim,
+                        self.config.num_key_value_heads * self.config.head_dim,
+                    ],
+                    dim=0,
+                )
+                new_weights.append(
+                    (f"{name.replace('qkv.bias', 'q_proj.bias')}", q_bias)
+                )
+                new_weights.append(
+                    (f"{name.replace('qkv.bias', 'k_proj.bias')}", k_bias)
+                )
+                new_weights.append(
+                    (f"{name.replace('qkv.bias', 'v_proj.bias')}", v_bias)
+                )
+            else:
+                new_weights.append((name, p))
+        weights = new_weights
+
+        # Use provided weight name mapping if available, otherwise use default
+        if weight_name_mapping is None:
+            weight_name_mapping = self._get_default_weight_mapping()
+        else:
+            # Merge with default mapping
+            default_mapping = self._get_default_weight_mapping()
+            default_mapping.update(weight_name_mapping)
+            weight_name_mapping = default_mapping
+
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ]
+        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping_fused(
+            ckpt_gate_up_proj_name="gate_up_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_gate_up_proj_bias_name="gate_up_proj_bias",
+            ckpt_down_proj_bias_name="down_proj_bias",
+        )
+
+        params_dict = dict(self.named_parameters())
+        params_checker = {k: False for k, v in params_dict.items()}
+
+        for other_loaded_param_name in other_loaded_param_names:
+            params_checker[other_loaded_param_name] = True
+
+        for name, loaded_weight in weights:
+            loaded_weight = _WeightCreator.maybe_materialize(loaded_weight)
+
+            # Apply weight name mapping if provided
+            if weight_name_mapping and name in weight_name_mapping:
+                name = weight_name_mapping[name]
+
+            layer_id = get_layer_id(name)
+            if (
+                layer_id is not None
+                and hasattr(self.model, "start_layer")
+                and (
+                    layer_id < self.model.start_layer
+                    or layer_id >= self.model.end_layer
+                )
+            ):
+                continue
+
+            if "rotary_emb.inv_freq" in name:
+                continue
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                if "mlp.experts" in name:
+                    continue
+
+                name = name.replace(weight_name, param_name)
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if name not in params_dict:
+                    continue
+
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                params_checker[name] = True
+                break
+            else:
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, shard_id = mapping
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    if name not in params_dict:
+                        continue
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    if "bias" not in name:
+                        loaded_weight = loaded_weight.transpose(-2, -1)
+                    if "w2_weight_bias" in name and get_moe_tensor_parallel_rank() != 0:
+                        loaded_weight = loaded_weight.zero_()
+
+                    weight_loader(
+                        param,
+                        loaded_weight,
+                        name,
+                        shard_id=shard_id,
+                    )
+                    params_checker[name] = True
+                    break
+                else:
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    if name not in params_dict:
+                        continue
+                    if name in params_dict.keys():
+                        param = params_dict[name]
+                        if "sinks" in name:
+                            start = tp_rank * param.numel()
+                            param.data.copy_(
+                                loaded_weight[start : start + param.numel()]
+                            )
+                        else:
+                            weight_loader = getattr(
+                                param, "weight_loader", default_weight_loader
+                            )
+                            weight_loader(param, loaded_weight)
+                        params_checker[name] = True
+                    else:
+                        logger.warning(f"Parameter {name} not found in params_dict")
+
+        not_loaded_params = [k for k, v in params_checker.items() if not v]
+        if tp_rank == 0:
+            if len(not_loaded_params) > 0:
+                raise Exception(f"Not all parameters loaded: {not_loaded_params}")
+            else:
+                logging.info("All parameters loaded successfully.")
+
+        self.routed_experts_weights_of_layer = {
+            layer_id: self.model.layers[layer_id].mlp.get_moe_weights()
+            for layer_id in range(self.start_layer, self.end_layer)
+            if isinstance(self.model.layers[layer_id].mlp, GptOssSparseMoeBlock)
+        }
+
+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_embed_and_head(self, embed, head):
+        del self.model.embed_tokens.weight
+        del self.lm_head.weight
+        self.model.embed_tokens.weight = embed
+        self.lm_head.weight = head
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+
+    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
+        if not self.pp_group.is_last_rank:
+            return
+
+        if layer_ids is None:
+            self.capture_aux_hidden_states = True
+            num_layers = self.config.num_hidden_layers
+            self.model.layers_to_capture = [2, num_layers // 2, num_layers - 3]
+        else:
+            self.capture_aux_hidden_states = True
+            # we plus 1 here because in sglang, for the ith layer, it takes the output
+            # of the (i-1)th layer as aux hidden state
+            self.model.layers_to_capture = [val + 1 for val in layer_ids]
+
+    @classmethod
+    def get_model_config_for_expert_location(cls, config):
+        return ModelConfigForExpertLocation(
+            num_layers=config.num_hidden_layers,
+            num_logical_experts=config.num_local_experts,
+            num_groups=None,
+        )
+
+    def get_attention_sliding_window_size(self):
+        return get_attention_sliding_window_size(self.config)
+
+
+def _canonicalize_weights(config, weights_in: Iterable[Tuple[str, torch.Tensor]]):
+    weights_out_dict = dict(weights_in)
+
+    for layer_id in range(config.num_hidden_layers):
+        for name_chunk in ["mlp1_weight", "mlp2_weight"]:
+            name_prefix = f"block.{layer_id}.mlp.{name_chunk}"
+            w_blocks = weights_out_dict.pop(f"{name_prefix}.blocks", None)
+            w_scales = weights_out_dict.pop(f"{name_prefix}.scales", None)
+            if w_blocks is not None:
+                weights_out_dict[name_prefix] = _WeightCreator(
+                    partial(
+                        _dequant_mlp_weight,
+                        debug_name=name_prefix,
+                        w_blocks=w_blocks,
+                        w_scales=w_scales,
+                    )
+                )
+
+    return list(weights_out_dict.items())
+
+
+def _dequant_mlp_weight(debug_name, w_blocks, w_scales):
+    if get_tensor_model_parallel_rank() == 0:
+        logger.info(f"Dequantize {debug_name} start")
+
+    original_device = w_blocks.device
+
+    w_blocks = w_blocks.cuda()
+    w_scales = w_scales.cuda()
+
+    w_bf16 = dequant_mxfp4(w_block=w_blocks, w_scale=w_scales, out_dtype=torch.bfloat16)
+    w_bf16 = w_bf16.transpose(-2, -1).contiguous()
+
+    if get_tensor_model_parallel_rank() == 0:
+        logger.info(
+            f"Dequantize {debug_name} end {w_blocks.shape=} {w_scales.shape=} {w_bf16.shape=}"
+        )
+
+    return w_bf16.to(original_device)
+
+
+class _WeightCreator:
+    def __init__(self, fn):
+        self._fn = fn
+
+    @staticmethod
+    def maybe_materialize(obj):
+        if isinstance(obj, _WeightCreator):
+            output = obj._fn()
+            obj._fn = None
+            return output
+
+        return obj
+
+
+EntryClass = GptOssForCausalLM
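
Reader's note (not part of the packaged file): the sketch below restates, in isolation, two conventions the new gpt_oss.py relies on. First, the sliding window handed to RadixAttention is config.sliding_window - 1, because HuggingFace counts the window inclusive of the last token while SGLang treats it as exclusive. Second, _load_mxfp4_experts_weights derives its per-rank slices from 32-value MXFP4 blocks, doubling the bounds for the fused gate/up rows and halving the down-projection byte columns (two FP4 values per uint8). The helper names and the example sizes (2880, 32, 128) are illustrative assumptions, not identifiers from the package.

# Illustrative sketch; mirrors get_attention_sliding_window_size and the slicing
# arithmetic in GptOssForCausalLM._load_mxfp4_experts_weights from the diff above.

MXFP4_BLOCK = 32  # one shared scale per 32 FP4 values


def sliding_window_for_sglang(hf_sliding_window: int) -> int:
    # HF: window includes the last token; SGLang: window excludes it.
    return hf_sliding_window - 1


def mxfp4_slice_bounds(
    intermediate_size: int,
    num_local_experts: int,
    moe_tp_rank: int,
    moe_tp_size: int,
    moe_ep_rank: int,
    moe_ep_size: int,
):
    # Per-TP-rank slice of the intermediate dimension, rounded down to whole MXFP4 blocks.
    per_rank = (intermediate_size // MXFP4_BLOCK // moe_tp_size) * MXFP4_BLOCK
    tp_start = moe_tp_rank * per_rank
    tp_end = min((moe_tp_rank + 1) * per_rank, intermediate_size)

    # Per-EP-rank slice of the expert dimension.
    experts_per_rank = num_local_experts // moe_ep_size
    ep_start = moe_ep_rank * experts_per_rank
    ep_end = (moe_ep_rank + 1) * experts_per_rank

    # In the loader, gate_up rows are sliced as [2*tp_start : 2*tp_end] because the
    # checkpoint stores gate and up together as 2*N rows, and down-projection byte
    # columns are sliced as [tp_start // 2 : tp_end // 2] since two FP4 values share
    # one uint8.
    return (ep_start, ep_end), (tp_start, tp_end)


if __name__ == "__main__":
    print(sliding_window_for_sglang(128))             # 127
    print(mxfp4_slice_bounds(2880, 32, 1, 2, 0, 1))   # ((0, 32), (1440, 2880))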