sglang 0.3.0__py3-none-any.whl → 0.3.1.post1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
Files changed (83)
  1. sglang/bench_latency.py +17 -8
  2. sglang/bench_serving.py +33 -38
  3. sglang/global_config.py +5 -17
  4. sglang/lang/backend/runtime_endpoint.py +5 -2
  5. sglang/lang/interpreter.py +1 -4
  6. sglang/launch_server.py +3 -6
  7. sglang/launch_server_llavavid.py +7 -8
  8. sglang/srt/{model_config.py → configs/model_config.py} +5 -0
  9. sglang/srt/constrained/__init__.py +2 -0
  10. sglang/srt/constrained/fsm_cache.py +33 -38
  11. sglang/srt/constrained/jump_forward.py +0 -1
  12. sglang/srt/conversation.py +4 -1
  13. sglang/srt/hf_transformers_utils.py +1 -3
  14. sglang/srt/layers/activation.py +12 -0
  15. sglang/srt/layers/attention_backend.py +480 -0
  16. sglang/srt/layers/flashinfer_utils.py +235 -0
  17. sglang/srt/layers/fused_moe/layer.py +27 -7
  18. sglang/srt/layers/layernorm.py +12 -0
  19. sglang/srt/layers/logits_processor.py +64 -77
  20. sglang/srt/layers/radix_attention.py +11 -161
  21. sglang/srt/layers/sampler.py +38 -122
  22. sglang/srt/layers/torchao_utils.py +75 -0
  23. sglang/srt/layers/{decode_attention.py → triton_attention/decode_attention.py} +67 -63
  24. sglang/srt/layers/{extend_attention.py → triton_attention/extend_attention.py} +40 -132
  25. sglang/srt/layers/{prefill_attention.py → triton_attention/prefill_attention.py} +13 -7
  26. sglang/srt/lora/lora.py +403 -0
  27. sglang/srt/lora/lora_config.py +43 -0
  28. sglang/srt/lora/lora_manager.py +259 -0
  29. sglang/srt/managers/controller_multi.py +1 -5
  30. sglang/srt/managers/controller_single.py +0 -5
  31. sglang/srt/managers/io_struct.py +16 -1
  32. sglang/srt/managers/policy_scheduler.py +122 -5
  33. sglang/srt/managers/schedule_batch.py +105 -71
  34. sglang/srt/managers/tokenizer_manager.py +17 -8
  35. sglang/srt/managers/tp_worker.py +188 -121
  36. sglang/srt/model_executor/cuda_graph_runner.py +69 -133
  37. sglang/srt/model_executor/forward_batch_info.py +35 -312
  38. sglang/srt/model_executor/model_runner.py +123 -154
  39. sglang/srt/models/baichuan.py +416 -0
  40. sglang/srt/models/chatglm.py +1 -5
  41. sglang/srt/models/commandr.py +1 -5
  42. sglang/srt/models/dbrx.py +1 -5
  43. sglang/srt/models/deepseek.py +1 -5
  44. sglang/srt/models/deepseek_v2.py +7 -6
  45. sglang/srt/models/exaone.py +1 -5
  46. sglang/srt/models/gemma.py +1 -5
  47. sglang/srt/models/gemma2.py +1 -5
  48. sglang/srt/models/gpt_bigcode.py +1 -5
  49. sglang/srt/models/grok.py +1 -5
  50. sglang/srt/models/internlm2.py +1 -5
  51. sglang/srt/models/llama.py +51 -5
  52. sglang/srt/models/llama_classification.py +1 -20
  53. sglang/srt/models/llava.py +30 -5
  54. sglang/srt/models/llavavid.py +2 -2
  55. sglang/srt/models/minicpm.py +1 -5
  56. sglang/srt/models/minicpm3.py +669 -0
  57. sglang/srt/models/mixtral.py +6 -5
  58. sglang/srt/models/mixtral_quant.py +1 -5
  59. sglang/srt/models/olmoe.py +415 -0
  60. sglang/srt/models/qwen.py +1 -5
  61. sglang/srt/models/qwen2.py +1 -5
  62. sglang/srt/models/qwen2_moe.py +6 -5
  63. sglang/srt/models/stablelm.py +1 -5
  64. sglang/srt/models/xverse.py +375 -0
  65. sglang/srt/models/xverse_moe.py +445 -0
  66. sglang/srt/openai_api/adapter.py +65 -46
  67. sglang/srt/openai_api/protocol.py +11 -3
  68. sglang/srt/sampling/sampling_batch_info.py +46 -80
  69. sglang/srt/server.py +30 -15
  70. sglang/srt/server_args.py +163 -28
  71. sglang/srt/utils.py +19 -51
  72. sglang/test/few_shot_gsm8k.py +132 -0
  73. sglang/test/runners.py +114 -22
  74. sglang/test/test_programs.py +7 -5
  75. sglang/test/test_utils.py +85 -2
  76. sglang/utils.py +32 -37
  77. sglang/version.py +1 -1
  78. {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/METADATA +30 -18
  79. sglang-0.3.1.post1.dist-info/RECORD +130 -0
  80. {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/WHEEL +1 -1
  81. sglang-0.3.0.dist-info/RECORD +0 -118
  82. {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/LICENSE +0 -0
  83. {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/top_level.txt +0 -0
sglang/srt/models/minicpm3.py (new file)
@@ -0,0 +1,669 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+"""Inference-only MiniCPM3 model compatible with HuggingFace weights."""
+
+import math
+from typing import Any, Dict, Iterable, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import PretrainedConfig
+from vllm.config import CacheConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.model_executor.layers.linear import (
+    ColumnParallelLinear,
+    MergedColumnParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.utils import is_hip
+
+# ROCm: flashinfer available later
+if not is_hip():
+    from flashinfer import bmm_fp8
+
+
+class MiniCPM3MLP(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+        )
+        if hidden_act != "silu":
+            raise ValueError(
+                f"Unsupported activation: {hidden_act}. "
+                "Only silu is supported for now."
+            )
+        self.act_fn = SiluAndMul()
+
+    def forward(self, x):
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(x)
+        return x
+
+
+def input_to_float8(x, dtype=torch.float8_e4m3fn):
+    finfo = torch.finfo(dtype)
+    min_val, max_val = x.aminmax()
+    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
+    scale = finfo.max / amax
+    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
+    return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
+
+
+class MiniCPM3Attention(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        hidden_size: int,
+        num_heads: int,
+        qk_nope_head_dim: int,
+        qk_rope_head_dim: int,
+        v_head_dim: int,
+        q_lora_rank: int,
+        kv_lora_rank: int,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        max_position_embeddings: int = 8192,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        layer_id=None,
+    ) -> None:
+        super().__init__()
+        self.layer_id = layer_id
+        self.hidden_size = hidden_size
+        self.qk_nope_head_dim = qk_nope_head_dim
+        self.qk_rope_head_dim = qk_rope_head_dim
+        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
+        self.v_head_dim = v_head_dim
+        self.q_lora_rank = q_lora_rank
+        self.kv_lora_rank = kv_lora_rank
+        self.num_heads = num_heads
+        tp_size = get_tensor_model_parallel_world_size()
+        assert num_heads % tp_size == 0
+        self.num_local_heads = num_heads // tp_size
+        self.scaling = self.qk_head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+
+        if self.q_lora_rank is not None:
+            self.q_a_proj = ReplicatedLinear(
+                self.hidden_size,
+                self.q_lora_rank,
+                bias=False,
+                quant_config=quant_config,
+            )
+            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
+            self.q_b_proj = ColumnParallelLinear(
+                q_lora_rank,
+                self.num_heads * self.qk_head_dim,
+                bias=False,
+                quant_config=quant_config,
+            )
+        else:
+            self.q_proj = ColumnParallelLinear(
+                self.hidden_size,
+                self.num_heads * self.qk_head_dim,
+                bias=False,
+                quant_config=quant_config,
+            )
+
+        self.kv_a_proj_with_mqa = ReplicatedLinear(
+            self.hidden_size,
+            self.kv_lora_rank + self.qk_rope_head_dim,
+            bias=False,
+            quant_config=quant_config,
+        )
+        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
+        self.kv_b_proj = ColumnParallelLinear(
+            self.kv_lora_rank,
+            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
+            bias=False,
+            quant_config=quant_config,
+        )
+        # O projection.
+        self.o_proj = RowParallelLinear(
+            self.num_heads * self.v_head_dim,
+            self.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+        )
+        self.rotary_emb = get_rope(
+            qk_rope_head_dim,
+            rotary_dim=qk_rope_head_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+        )
+
+        # TODO support head_size 96
+        self.attn = RadixAttention(
+            self.num_local_heads,
+            128,
+            self.scaling,
+            num_kv_heads=self.num_local_heads,
+            layer_id=layer_id,
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        if self.q_lora_rank is not None:
+            q = self.q_a_proj(hidden_states)[0]
+            q = self.q_a_layernorm(q)
+            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
+        else:
+            q = self.q_proj(hidden_states)[0].view(
+                -1, self.num_local_heads, self.qk_head_dim
+            )
+        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
+        kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+        latent_cache = latent_cache.unsqueeze(1)
+        kv_a = self.kv_a_layernorm(kv_a.contiguous())
+        kv = self.kv_b_proj(kv_a)[0]
+        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
+        k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+        k_pe = latent_cache[:, :, self.kv_lora_rank :]
+        original_shapes = [q_pe.shape, k_pe.shape]
+        q_pe, k_pe = self.rotary_emb(
+            positions, q_pe.reshape(q_pe.shape[0], -1), k_pe.reshape(k_pe.shape[0], -1)
+        )
+        q_pe, k_pe = q_pe.view(original_shapes[0]), k_pe.view(original_shapes[1])
+        q[..., self.qk_nope_head_dim :] = q_pe
+        k = torch.empty_like(q)
+        k[..., : self.qk_nope_head_dim] = k_nope
+        k[..., self.qk_nope_head_dim :] = k_pe
+        q = torch.nn.functional.pad(q, [0, 128 - self.qk_head_dim], value=0).view(
+            -1, self.num_local_heads * 128
+        )
+        k = torch.nn.functional.pad(k, [0, 128 - self.qk_head_dim], value=0).view(
+            -1, self.num_local_heads * 128
+        )
+        v = torch.nn.functional.pad(v, [0, 128 - self.v_head_dim], value=0).view(
+            -1, self.num_local_heads * 128
+        )
+        attn_output = self.attn(q, k, v, input_metadata)
+        attn_output = attn_output.view(-1, self.num_local_heads, 128)[
+            ..., : self.v_head_dim
+        ].reshape(-1, self.num_local_heads * self.v_head_dim)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class MiniCPM3AttentionMLA(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        hidden_size: int,
+        num_heads: int,
+        qk_nope_head_dim: int,
+        qk_rope_head_dim: int,
+        v_head_dim: int,
+        q_lora_rank: int,
+        kv_lora_rank: int,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        max_position_embeddings: int = 8192,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        layer_id=None,
+    ) -> None:
+        super().__init__()
+        self.layer_id = layer_id
+        self.hidden_size = hidden_size
+        self.qk_nope_head_dim = qk_nope_head_dim
+        self.qk_rope_head_dim = qk_rope_head_dim
+        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
+        self.v_head_dim = v_head_dim
+        self.q_lora_rank = q_lora_rank
+        self.kv_lora_rank = kv_lora_rank
+        self.num_heads = num_heads
+        tp_size = get_tensor_model_parallel_world_size()
+        assert num_heads % tp_size == 0
+        self.num_local_heads = num_heads // tp_size
+        self.scaling = self.qk_head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+
+        if self.q_lora_rank is not None:
+            self.q_a_proj = ReplicatedLinear(
+                self.hidden_size,
+                self.q_lora_rank,
+                bias=False,
+                quant_config=quant_config,
+            )
+            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
+            self.q_b_proj = ColumnParallelLinear(
+                q_lora_rank,
+                self.num_heads * self.qk_head_dim,
+                bias=False,
+                quant_config=quant_config,
+            )
+        else:
+            self.q_proj = ColumnParallelLinear(
+                self.hidden_size,
+                self.num_heads * self.qk_head_dim,
+                bias=False,
+                quant_config=quant_config,
+            )
+
+        self.kv_a_proj_with_mqa = ReplicatedLinear(
+            self.hidden_size,
+            self.kv_lora_rank + self.qk_rope_head_dim,
+            bias=False,
+            quant_config=quant_config,
+        )
+        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
+        self.kv_b_proj = ColumnParallelLinear(
+            self.kv_lora_rank,
+            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
+            bias=False,
+            quant_config=quant_config,
+        )
+        # O projection.
+        self.o_proj = RowParallelLinear(
+            self.num_heads * self.v_head_dim,
+            self.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+        )
+        self.rotary_emb = get_rope(
+            qk_rope_head_dim,
+            rotary_dim=qk_rope_head_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+        )
+
+        self.attn = RadixAttention(
+            self.num_local_heads,
+            self.kv_lora_rank + self.qk_rope_head_dim,
+            self.scaling,
+            num_kv_heads=1,
+            layer_id=layer_id,
+            v_head_dim=self.kv_lora_rank,
+        )
+
+        self.w_kc = None
+        self.w_vc = None
+        self.w_scale = None
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        q_len = hidden_states.shape[0]
+        q_input = hidden_states.new_empty(
+            q_len, self.num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim
+        )
+        if self.q_lora_rank is not None:
+            q = self.q_a_proj(hidden_states)[0]
+            q = self.q_a_layernorm(q)
+            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
+        else:
+            q = self.q_proj(hidden_states)[0].view(
+                -1, self.num_local_heads, self.qk_head_dim
+            )
+        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+
+        if self.w_kc.dtype == torch.float8_e4m3fn:
+            q_nope_val, q_nope_scale = input_to_float8(
+                q_nope.transpose(0, 1), torch.float8_e4m3fn
+            )
+            q_nope_out = bmm_fp8(
+                q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
+            )
+        else:
+            q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
+        q_input[..., : self.kv_lora_rank] = q_nope_out.transpose(0, 1)
+
+        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
+        v_input = latent_cache[..., : self.kv_lora_rank]
+        v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1)
+        k_input = latent_cache.unsqueeze(1)
+        k_input[..., : self.kv_lora_rank] = v_input
+        k_pe = k_input[..., self.kv_lora_rank :]
+
+        original_shapes = [q_pe.shape, k_pe.shape]
+        q_pe, k_pe = self.rotary_emb(
+            positions, q_pe.reshape(q_pe.shape[0], -1), k_pe.reshape(k_pe.shape[0], -1)
+        )
+        q_pe, k_pe = q_pe.view(original_shapes[0]), k_pe.view(original_shapes[1])
+        q_input[..., self.kv_lora_rank :] = q_pe
+        k_input[..., self.kv_lora_rank :] = k_pe
+
+        attn_output = self.attn(q_input, k_input, v_input, input_metadata)
+        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
+
+        if self.w_vc.dtype == torch.float8_e4m3fn:
+            attn_output_val, attn_output_scale = input_to_float8(
+                attn_output.transpose(0, 1), torch.float8_e4m3fn
+            )
+            attn_bmm_output = bmm_fp8(
+                attn_output_val,
+                self.w_vc,
+                attn_output_scale,
+                self.w_scale,
+                torch.bfloat16,
+            )
+        else:
+            attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc)
+        attn_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
+        output, _ = self.o_proj(attn_output)
+
+        return output
+
+
+class MiniCPM3DecoderLayer(nn.Module):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_id: int,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+        if global_server_args_dict["enable_mla"]:
+            self.self_attn = MiniCPM3AttentionMLA(
+                config=config,
+                hidden_size=self.hidden_size,
+                num_heads=config.num_attention_heads,
+                qk_nope_head_dim=config.qk_nope_head_dim,
+                qk_rope_head_dim=config.qk_rope_head_dim,
+                v_head_dim=self.hidden_size // config.num_attention_heads,
+                q_lora_rank=(
+                    config.q_lora_rank if hasattr(config, "q_lora_rank") else None
+                ),
+                kv_lora_rank=config.kv_lora_rank,
+                rope_theta=rope_theta,
+                rope_scaling=rope_scaling,
+                max_position_embeddings=max_position_embeddings,
+                cache_config=cache_config,
+                quant_config=quant_config,
+                layer_id=layer_id,
+            )
+        else:
+            self.self_attn = MiniCPM3Attention(
+                config=config,
+                hidden_size=self.hidden_size,
+                num_heads=config.num_attention_heads,
+                qk_nope_head_dim=config.qk_nope_head_dim,
+                qk_rope_head_dim=config.qk_rope_head_dim,
+                v_head_dim=self.hidden_size // config.num_attention_heads,
+                q_lora_rank=(
+                    config.q_lora_rank if hasattr(config, "q_lora_rank") else None
+                ),
+                kv_lora_rank=config.kv_lora_rank,
+                rope_theta=rope_theta,
+                rope_scaling=rope_scaling,
+                max_position_embeddings=max_position_embeddings,
+                cache_config=cache_config,
+                quant_config=quant_config,
+                layer_id=layer_id,
+            )
+        self.mlp = MiniCPM3MLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            quant_config=quant_config,
+        )
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        input_metadata: InputMetadata,
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Self Attention
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            input_metadata=input_metadata,
+        )
+        hidden_states = residual + hidden_states * (
+            self.config.scale_depth / math.sqrt(self.config.num_hidden_layers)
+        )
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states * (
+            self.config.scale_depth / math.sqrt(self.config.num_hidden_layers)
+        )
+
+        return hidden_states, None
+
+
+class MiniCPM3Model(nn.Module):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+        self.embed_tokens = VocabParallelEmbedding(
+            self.vocab_size,
+            config.hidden_size,
+            org_num_embeddings=config.vocab_size,
+        )
+        self.layers = nn.ModuleList(
+            [
+                MiniCPM3DecoderLayer(
+                    config, i, cache_config=cache_config, quant_config=quant_config
+                )
+                for i in range(config.num_hidden_layers)
+            ]
+        )
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        input_metadata: InputMetadata,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+        if input_embeds is None:
+            hidden_states = self.embed_tokens(input_ids) * self.config.scale_emb
+        else:
+            hidden_states = input_embeds
+        residual = None
+
+        for i in range(len(self.layers)):
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                input_metadata,
+                residual,
+            )
+        hidden_states = self.norm(hidden_states)
+        return hidden_states
+
+
+class MiniCPM3ForCausalLM(nn.Module):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+
+        self.num_experts = getattr(self.config, "num_experts", 0)
+        self.quant_config = quant_config
+        self.model = MiniCPM3Model(
+            config, cache_config=cache_config, quant_config=quant_config
+        )
+        # self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        if not self.config.tie_word_embeddings:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size,
+                config.hidden_size,
+                org_num_embeddings=config.vocab_size,
+            )
+
+        self.scale_width = self.config.hidden_size / self.config.dim_model_base
+
+        self.logits_processor = LogitsProcessor(config)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        input_metadata: InputMetadata,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+        if input_embeds is not None:
+            input_embeds = input_embeds * self.config.scale_emb
+        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        hidden_states = hidden_states / self.scale_width
+        if self.config.tie_word_embeddings:
+            lm_head_weight = self.model.embed_tokens.weight
+        else:
+            lm_head_weight = self.lm_head.weight
+        return self.logits_processor(
+            input_ids, hidden_states, lm_head_weight, input_metadata
+        )
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        expert_params_mapping = [
+            # (param_name, weight_name, expert_id)
+            (
+                "ws" if weight_name in ["w1", "w3"] else "w2s",
+                f"experts.{expert_id}.{weight_name}.weight",
+                expert_id,
+            )
+            for expert_id in range(self.num_experts)
+            for weight_name in ["w1", "w2", "w3"]
+        ]
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                # Models trained using ColossalAI may include these tensors in
+                # the checkpoint. Skip them.
+                continue
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                for param_name, weight_name, expert_id in expert_params_mapping:
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(
+                        param, loaded_weight, weight_name, expert_id=expert_id
+                    )
+                    break
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+
+        if global_server_args_dict["enable_mla"]:
+            for layer_id in range(self.config.num_hidden_layers):
+                self_attn = self.model.layers[layer_id].self_attn
+                w_kc, w_vc = self_attn.kv_b_proj.weight.unflatten(
+                    0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
+                ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+                self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
+                self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
+                if hasattr(self_attn.kv_b_proj, "weight_scale"):
+                    self_attn.w_scale = self_attn.kv_b_proj.weight_scale
+                del self_attn.kv_b_proj
+
+
+EntryClass = MiniCPM3ForCausalLM
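
For orientation, a minimal usage sketch of the MiniCPM3 entry point added above, run through sglang's frontend API after installing 0.3.1.post1. This sketch is not part of the diff: the model path openbmb/MiniCPM3-4B and the keyword arguments shown are assumptions to be checked against the sglang documentation.

# Hypothetical sketch (assumed model path and arguments); not part of the package diff.
import sglang as sgl

@sgl.function
def ask(s, question):
    s += sgl.user(question)
    s += sgl.assistant(sgl.gen("answer", max_tokens=64))

if __name__ == "__main__":
    # Start an in-process runtime that loads the MiniCPM3 checkpoint,
    # which is served by the new MiniCPM3ForCausalLM model class.
    runtime = sgl.Runtime(model_path="openbmb/MiniCPM3-4B", trust_remote_code=True)
    sgl.set_default_backend(runtime)
    state = ask.run(question="Summarize multi-head latent attention in one sentence.")
    print(state["answer"])
    runtime.shutdown()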