sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +119 -17
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +42 -7
  6. sglang/srt/conversation.py +9 -5
  7. sglang/srt/disaggregation/base/conn.py +5 -2
  8. sglang/srt/disaggregation/decode.py +14 -4
  9. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
  10. sglang/srt/disaggregation/mooncake/conn.py +286 -160
  11. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  12. sglang/srt/disaggregation/prefill.py +2 -0
  13. sglang/srt/distributed/parallel_state.py +15 -11
  14. sglang/srt/entrypoints/context.py +227 -0
  15. sglang/srt/entrypoints/engine.py +15 -9
  16. sglang/srt/entrypoints/harmony_utils.py +372 -0
  17. sglang/srt/entrypoints/http_server.py +74 -4
  18. sglang/srt/entrypoints/openai/protocol.py +218 -1
  19. sglang/srt/entrypoints/openai/serving_chat.py +41 -11
  20. sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
  21. sglang/srt/entrypoints/openai/tool_server.py +175 -0
  22. sglang/srt/entrypoints/tool.py +87 -0
  23. sglang/srt/eplb/expert_location.py +5 -1
  24. sglang/srt/function_call/ebnf_composer.py +1 -0
  25. sglang/srt/function_call/function_call_parser.py +2 -0
  26. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  27. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  28. sglang/srt/function_call/kimik2_detector.py +3 -3
  29. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  30. sglang/srt/hf_transformers_utils.py +30 -3
  31. sglang/srt/jinja_template_utils.py +14 -1
  32. sglang/srt/layers/attention/aiter_backend.py +375 -115
  33. sglang/srt/layers/attention/ascend_backend.py +3 -0
  34. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
  35. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  36. sglang/srt/layers/attention/flashinfer_backend.py +52 -13
  37. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  38. sglang/srt/layers/attention/triton_backend.py +85 -14
  39. sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
  40. sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
  41. sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
  42. sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
  43. sglang/srt/layers/attention/vision.py +22 -6
  44. sglang/srt/layers/attention/wave_backend.py +627 -0
  45. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  46. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  47. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  48. sglang/srt/layers/communicator.py +29 -14
  49. sglang/srt/layers/dp_attention.py +12 -0
  50. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  51. sglang/srt/layers/linear.py +3 -7
  52. sglang/srt/layers/moe/cutlass_moe.py +12 -3
  53. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  54. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  55. sglang/srt/layers/moe/ep_moe/layer.py +135 -73
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
  59. sglang/srt/layers/moe/fused_moe_triton/layer.py +412 -33
  60. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
  61. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  62. sglang/srt/layers/moe/topk.py +16 -4
  63. sglang/srt/layers/moe/utils.py +16 -0
  64. sglang/srt/layers/quantization/__init__.py +27 -3
  65. sglang/srt/layers/quantization/fp4.py +557 -0
  66. sglang/srt/layers/quantization/fp8.py +3 -6
  67. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  68. sglang/srt/layers/quantization/fp8_utils.py +51 -10
  69. sglang/srt/layers/quantization/modelopt_quant.py +258 -68
  70. sglang/srt/layers/quantization/mxfp4.py +654 -0
  71. sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
  72. sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
  73. sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  74. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
  75. sglang/srt/layers/quantization/quark/utils.py +107 -0
  76. sglang/srt/layers/quantization/unquant.py +60 -6
  77. sglang/srt/layers/quantization/w4afp8.py +21 -12
  78. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  79. sglang/srt/layers/rotary_embedding.py +506 -3
  80. sglang/srt/layers/utils.py +9 -0
  81. sglang/srt/layers/vocab_parallel_embedding.py +8 -3
  82. sglang/srt/lora/backend/base_backend.py +3 -23
  83. sglang/srt/lora/layers.py +60 -114
  84. sglang/srt/lora/lora.py +17 -62
  85. sglang/srt/lora/lora_manager.py +82 -62
  86. sglang/srt/lora/lora_registry.py +23 -11
  87. sglang/srt/lora/mem_pool.py +63 -68
  88. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  89. sglang/srt/lora/utils.py +25 -58
  90. sglang/srt/managers/cache_controller.py +75 -58
  91. sglang/srt/managers/detokenizer_manager.py +1 -1
  92. sglang/srt/managers/io_struct.py +20 -8
  93. sglang/srt/managers/mm_utils.py +6 -13
  94. sglang/srt/managers/multimodal_processor.py +1 -1
  95. sglang/srt/managers/schedule_batch.py +61 -25
  96. sglang/srt/managers/schedule_policy.py +6 -6
  97. sglang/srt/managers/scheduler.py +41 -19
  98. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
  99. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  100. sglang/srt/managers/scheduler_recv_skipper.py +37 -0
  101. sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
  102. sglang/srt/managers/template_manager.py +35 -1
  103. sglang/srt/managers/tokenizer_manager.py +47 -30
  104. sglang/srt/managers/tp_worker.py +3 -0
  105. sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
  106. sglang/srt/mem_cache/allocator.py +61 -87
  107. sglang/srt/mem_cache/hicache_storage.py +1 -1
  108. sglang/srt/mem_cache/hiradix_cache.py +80 -22
  109. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  110. sglang/srt/mem_cache/memory_pool_host.py +34 -36
  111. sglang/srt/mem_cache/multimodal_cache.py +33 -13
  112. sglang/srt/mem_cache/radix_cache.py +2 -5
  113. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
  114. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  115. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  116. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  117. sglang/srt/model_executor/cuda_graph_runner.py +29 -9
  118. sglang/srt/model_executor/forward_batch_info.py +61 -19
  119. sglang/srt/model_executor/model_runner.py +148 -37
  120. sglang/srt/model_loader/loader.py +18 -6
  121. sglang/srt/model_loader/weight_utils.py +10 -0
  122. sglang/srt/models/bailing_moe.py +425 -0
  123. sglang/srt/models/deepseek_v2.py +137 -59
  124. sglang/srt/models/ernie4.py +426 -0
  125. sglang/srt/models/ernie4_eagle.py +203 -0
  126. sglang/srt/models/gemma2.py +0 -34
  127. sglang/srt/models/gemma3n_mm.py +38 -0
  128. sglang/srt/models/glm4.py +6 -0
  129. sglang/srt/models/glm4_moe.py +28 -16
  130. sglang/srt/models/glm4v.py +589 -0
  131. sglang/srt/models/glm4v_moe.py +400 -0
  132. sglang/srt/models/gpt_oss.py +1251 -0
  133. sglang/srt/models/granite.py +0 -25
  134. sglang/srt/models/llama.py +0 -25
  135. sglang/srt/models/llama4.py +1 -1
  136. sglang/srt/models/qwen2.py +6 -0
  137. sglang/srt/models/qwen2_5_vl.py +7 -3
  138. sglang/srt/models/qwen2_audio.py +10 -9
  139. sglang/srt/models/qwen2_moe.py +6 -0
  140. sglang/srt/models/qwen3.py +0 -24
  141. sglang/srt/models/qwen3_moe.py +32 -6
  142. sglang/srt/models/registry.py +1 -1
  143. sglang/srt/models/step3_vl.py +9 -0
  144. sglang/srt/models/torch_native_llama.py +0 -24
  145. sglang/srt/models/transformers.py +2 -5
  146. sglang/srt/multimodal/processors/base_processor.py +23 -13
  147. sglang/srt/multimodal/processors/glm4v.py +132 -0
  148. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  149. sglang/srt/multimodal/processors/step3_vl.py +3 -1
  150. sglang/srt/reasoning_parser.py +332 -37
  151. sglang/srt/server_args.py +186 -75
  152. sglang/srt/speculative/eagle_worker.py +16 -0
  153. sglang/srt/two_batch_overlap.py +169 -9
  154. sglang/srt/utils.py +41 -5
  155. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  156. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  157. sglang/test/doc_patch.py +59 -0
  158. sglang/test/few_shot_gsm8k.py +1 -1
  159. sglang/test/few_shot_gsm8k_engine.py +1 -1
  160. sglang/test/run_eval.py +4 -1
  161. sglang/test/runners.py +2 -2
  162. sglang/test/simple_eval_common.py +6 -0
  163. sglang/test/simple_eval_gpqa.py +2 -0
  164. sglang/test/test_fp4_moe.py +118 -36
  165. sglang/test/test_utils.py +1 -1
  166. sglang/utils.py +1 -1
  167. sglang/version.py +1 -1
  168. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +36 -38
  169. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +174 -141
  170. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  171. /sglang/{api.py → lang/api.py} +0 -0
  172. /sglang/{lang/backend → srt/layers/quantization/quark}/__init__.py +0 -0
  173. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
  174. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
  175. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
sglang/srt/models/bailing_moe.py (new file)
@@ -0,0 +1,425 @@
+# Copyright 2023-2024 SGLang Team
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/bailing_moe.py
+
+from collections.abc import Iterable
+from typing import Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from transformers.configuration_utils import PretrainedConfig
+
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
+)
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+from sglang.srt.layers.moe.topk import TopK
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix, make_layers
+
+
+class BailingAttention(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_id: int = 0,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+
+        self.total_num_heads = config.num_attention_heads
+        self.total_num_kv_heads = config.num_key_value_heads
+
+        assert self.total_num_heads % tp_size == 0
+        assert self.total_num_kv_heads % tp_size == 0
+
+        self.num_heads = self.total_num_heads // tp_size
+        self.head_dim = config.head_dim or (self.hidden_size // self.total_num_heads)
+        self.q_size = self.num_heads * self.head_dim
+
+        self.num_kv_heads = self.total_num_kv_heads // tp_size
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scale = self.head_dim**-0.5
+
+        self.query_key_value = QKVParallelLinear(
+            self.hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=(config.use_bias or config.use_qkv_bias),
+            quant_config=quant_config,
+            prefix=add_prefix("query_key_value", prefix),
+        )
+
+        self.dense = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            self.hidden_size,
+            bias=config.use_bias,
+            quant_config=quant_config,
+            prefix=add_prefix("dense", prefix),
+        )
+
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scale,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+            quant_config=quant_config,
+            prefix=add_prefix("attn", prefix),
+        )
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=config.max_position_embeddings,
+            base=config.rope_theta,
+            is_neox_style=True,
+            rope_scaling=config.rope_scaling,
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_ids: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        qkv, _ = self.query_key_value(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+
+        q, k = self.rotary_emb(position_ids, q, k)
+        context_layer = self.attn(q, k, v, forward_batch)
+        attn_output, _ = self.dense(context_layer)
+        return attn_output
+
+
+class BailingMLP(nn.Module):
+    def __init__(
+        self,
+        intermediate_size: int,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        reduce_results: Optional[bool] = True,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.gate_up_proj = MergedColumnParallelLinear(
+            config.hidden_size,
+            [intermediate_size] * 2,
+            bias=config.use_bias,
+            quant_config=quant_config,
+            prefix=add_prefix("gate_up_proj", prefix),
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            config.hidden_size,
+            bias=config.use_bias,
+            quant_config=quant_config,
+            reduce_results=reduce_results,
+            prefix=add_prefix("down_proj", prefix),
+        )
+        self.act_fn = SiluAndMul()
+
+    def forward(self, x):
+        x, _ = self.gate_up_proj(x)
+        x = self.act_fn(x)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class BailingMoE(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_id: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.num_experts = config.num_experts
+        self.top_k = config.num_experts_per_tok
+        self.hidden_size = config.hidden_size
+        self.num_shared_experts = config.num_shared_experts
+        self.norm_expert_prob = config.norm_topk_prob
+        self.moe_intermediate_size = config.moe_intermediate_size
+
+        self.gate = ReplicatedLinear(
+            self.hidden_size, self.num_experts, bias=False, quant_config=None
+        )
+
+        self.topk = TopK(top_k=self.top_k, renormalize=self.norm_expert_prob)
+
+        self.experts = FusedMoE(
+            num_experts=self.num_experts,
+            top_k=self.top_k,
+            layer_id=layer_id,
+            hidden_size=self.hidden_size,
+            intermediate_size=self.moe_intermediate_size,
+            reduce_results=False,
+            quant_config=quant_config,
+            prefix=add_prefix("experts", prefix),
+        )
+
+        if self.num_shared_experts > 0:
+            shared_intermediate_size = (
+                self.moe_intermediate_size * self.num_shared_experts
+            )
+            self.shared_experts = BailingMLP(
+                intermediate_size=shared_intermediate_size,
+                config=config,
+                quant_config=quant_config,
+                reduce_results=False,
+                prefix=add_prefix("shared_experts", prefix),
+            )
+        else:
+            self.shared_experts = None
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        orig_shape = hidden_states.shape
+        hidden_states_flat = hidden_states.view(-1, self.hidden_size)
+
+        shared_output = None
+        if self.shared_experts is not None:
+            shared_output = self.shared_experts(hidden_states_flat)
+
+        router_logits, _ = self.gate(hidden_states_flat)
+        topk_output = self.topk(hidden_states_flat, router_logits)
+        final_hidden_states = self.experts(hidden_states_flat, topk_output)
+
+        if shared_output is not None:
+            final_hidden_states = final_hidden_states + shared_output
+
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+
+        return final_hidden_states.view(orig_shape)
+
+
+class BailingMoeBlock(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_id: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.attention = BailingAttention(
+            config, layer_id, quant_config, prefix=add_prefix("attention", prefix)
+        )
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+        self.mlp = BailingMoE(
+            config=config,
+            layer_id=layer_id,
+            quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_ids: torch.Tensor,
+        residual: Optional[torch.Tensor],
+        forward_batch: ForwardBatch,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Pre-normalization and residual connection for the attention block
+        if residual is None:
+            residual = hidden_states
+            normed_hidden_states = self.input_layernorm(hidden_states)
+        else:
+            normed_hidden_states, residual = self.input_layernorm(
+                hidden_states, residual
+            )
+
+        attn_output = self.attention(
+            hidden_states=normed_hidden_states,
+            position_ids=position_ids,
+            forward_batch=forward_batch,
+        )
+
+        # Pre-normalization and residual connection for the MLP block
+        normed_hidden_states, residual = self.post_attention_layernorm(
+            attn_output, residual
+        )
+        mlp_output = self.mlp(normed_hidden_states)
+
+        return mlp_output, residual
+
+
+class BailingMoeModel(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.config = config
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+        self.embed_dim = config.hidden_size
+
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+            prefix=add_prefix("embed_tokens", prefix),
+        )
+        self.embedding_dropout = torch.nn.Dropout(config.embedding_dropout)
+
+        self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: BailingMoeBlock(
+                config=config,
+                layer_id=idx,
+                quant_config=quant_config,
+                prefix=prefix,
+            ),
+            prefix=add_prefix("layers", prefix),
+        )
+
+        self.norm = RMSNorm(self.embed_dim, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if input_embeds is None:
+            hidden_states = self.embed_tokens(input_ids)
+        else:
+            hidden_states = input_embeds
+
+        residual = None
+        for layer in self.layers:
+            hidden_states, residual = layer(
+                hidden_states,
+                position_ids,
+                residual,
+                forward_batch,
+            )
+
+        hidden_states, _ = self.norm(hidden_states, residual)
+        return hidden_states
+
+
+class BailingMoeForCausalLM(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.model = BailingMoeModel(config=config, quant_config=quant_config)
+        self.lm_head = ParallelLMHead(
+            num_embeddings=config.vocab_size,
+            embedding_dim=config.hidden_size,
+            quant_config=quant_config,
+        )
+        if config.tie_word_embeddings:
+            self.lm_head.weight = self.model.embed_tokens.weight
+
+        self.logits_processor = LogitsProcessor(config)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, forward_batch, inputs_embeds)
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head, forward_batch
+        )
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+
+        stacked_params_mapping = [
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.num_experts,
+        )
+
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+
+            if (
+                hasattr(self.config, "norm_head")
+                and self.config.norm_head
+                and "lm_head.weight" in name
+            ):
+                loaded_weight = F.normalize(loaded_weight, dim=0, p=2, eps=1e-7)
+
+            if "model.word_embeddings.weight" == name:
+                name = "model.embed_tokens.weight"
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name in name and "mlp.experts" not in name:
+                    full_param_name = name.replace(weight_name, param_name)
+                    param = params_dict[full_param_name]
+                    param.weight_loader(param, loaded_weight, shard_id)
+                    break
+            else:
+                for p_name, w_name, e_id, s_id in expert_params_mapping:
+                    if w_name in name and "mlp.experts" in name:
+                        full_param_name = name.replace(w_name, p_name)
+                        param = params_dict[full_param_name]
+                        param.weight_loader(
+                            param,
+                            loaded_weight,
+                            full_param_name,
+                            shard_id=s_id,
+                            expert_id=e_id,
+                        )
+                        break
+                else:
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+
+
+EntryClass = BailingMoeForCausalLM
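
For orientation, BailingMoE.forward in the file above combines top-k routed experts with an always-on shared expert before the tensor-parallel all-reduce. Below is a minimal plain-PyTorch sketch of that routing pattern, not part of the diff: the names (ToyMoE), dimensions, and per-expert loops are illustrative only, and the real path goes through sglang's TopK and FusedMoE kernels with gated SwiGLU experts rather than the simple MLP stand-ins used here.

# Conceptual sketch only; assumes toy dimensions and un-fused routing loops.
import torch
import torch.nn.functional as F
from torch import nn


class ToyMoE(nn.Module):
    def __init__(self, hidden=16, inter=32, num_experts=4, top_k=2):
        super().__init__()
        self.top_k = top_k
        self.gate = nn.Linear(hidden, num_experts, bias=False)
        # Routed experts plus one shared expert (simple MLP stand-ins).
        self.experts = nn.ModuleList(
            [
                nn.Sequential(nn.Linear(hidden, inter), nn.SiLU(), nn.Linear(inter, hidden))
                for _ in range(num_experts)
            ]
        )
        self.shared_expert = nn.Sequential(
            nn.Linear(hidden, inter), nn.SiLU(), nn.Linear(inter, hidden)
        )

    def forward(self, x):  # x: [num_tokens, hidden]
        # Shared expert runs on every token, mirroring shared_experts above.
        shared_out = self.shared_expert(x)

        # Top-k routing with renormalized probabilities (norm_topk_prob=True).
        probs = F.softmax(self.gate(x), dim=-1)
        topk_probs, topk_ids = probs.topk(self.top_k, dim=-1)
        topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True)

        routed_out = torch.zeros_like(x)
        for slot in range(self.top_k):
            for e, expert in enumerate(self.experts):
                mask = topk_ids[:, slot] == e
                if mask.any():
                    routed_out[mask] += topk_probs[mask, slot : slot + 1] * expert(x[mask])

        # As in BailingMoE.forward: routed output plus shared-expert output;
        # the tensor-parallel all-reduce is omitted in this single-rank sketch.
        return routed_out + shared_out


tokens = torch.randn(8, 16)
print(ToyMoE()(tokens).shape)  # torch.Size([8, 16])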