sglang 0.3.0__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. sglang/bench_latency.py +10 -6
  2. sglang/bench_serving.py +33 -38
  3. sglang/global_config.py +0 -4
  4. sglang/lang/backend/runtime_endpoint.py +5 -2
  5. sglang/lang/interpreter.py +1 -1
  6. sglang/launch_server.py +3 -6
  7. sglang/launch_server_llavavid.py +7 -8
  8. sglang/srt/{model_config.py → configs/model_config.py} +5 -0
  9. sglang/srt/constrained/__init__.py +2 -0
  10. sglang/srt/constrained/fsm_cache.py +29 -38
  11. sglang/srt/constrained/jump_forward.py +0 -1
  12. sglang/srt/conversation.py +4 -1
  13. sglang/srt/hf_transformers_utils.py +1 -3
  14. sglang/srt/layers/attention_backend.py +480 -0
  15. sglang/srt/layers/flashinfer_utils.py +235 -0
  16. sglang/srt/layers/logits_processor.py +64 -77
  17. sglang/srt/layers/radix_attention.py +11 -161
  18. sglang/srt/layers/sampler.py +6 -25
  19. sglang/srt/layers/torchao_utils.py +75 -0
  20. sglang/srt/layers/{decode_attention.py → triton_attention/decode_attention.py} +67 -63
  21. sglang/srt/layers/{extend_attention.py → triton_attention/extend_attention.py} +40 -132
  22. sglang/srt/layers/{prefill_attention.py → triton_attention/prefill_attention.py} +13 -7
  23. sglang/srt/lora/lora.py +403 -0
  24. sglang/srt/lora/lora_config.py +43 -0
  25. sglang/srt/lora/lora_manager.py +256 -0
  26. sglang/srt/managers/controller_multi.py +1 -5
  27. sglang/srt/managers/controller_single.py +0 -5
  28. sglang/srt/managers/io_struct.py +16 -1
  29. sglang/srt/managers/policy_scheduler.py +122 -5
  30. sglang/srt/managers/schedule_batch.py +104 -71
  31. sglang/srt/managers/tokenizer_manager.py +17 -8
  32. sglang/srt/managers/tp_worker.py +181 -115
  33. sglang/srt/model_executor/cuda_graph_runner.py +58 -133
  34. sglang/srt/model_executor/forward_batch_info.py +35 -312
  35. sglang/srt/model_executor/model_runner.py +117 -131
  36. sglang/srt/models/baichuan.py +416 -0
  37. sglang/srt/models/chatglm.py +1 -5
  38. sglang/srt/models/commandr.py +1 -5
  39. sglang/srt/models/dbrx.py +1 -5
  40. sglang/srt/models/deepseek.py +1 -5
  41. sglang/srt/models/deepseek_v2.py +1 -5
  42. sglang/srt/models/exaone.py +1 -5
  43. sglang/srt/models/gemma.py +1 -5
  44. sglang/srt/models/gemma2.py +1 -5
  45. sglang/srt/models/gpt_bigcode.py +1 -5
  46. sglang/srt/models/grok.py +1 -5
  47. sglang/srt/models/internlm2.py +1 -5
  48. sglang/srt/models/llama.py +51 -5
  49. sglang/srt/models/llama_classification.py +1 -20
  50. sglang/srt/models/llava.py +30 -5
  51. sglang/srt/models/llavavid.py +2 -2
  52. sglang/srt/models/minicpm.py +1 -5
  53. sglang/srt/models/minicpm3.py +665 -0
  54. sglang/srt/models/mixtral.py +6 -5
  55. sglang/srt/models/mixtral_quant.py +1 -5
  56. sglang/srt/models/qwen.py +1 -5
  57. sglang/srt/models/qwen2.py +1 -5
  58. sglang/srt/models/qwen2_moe.py +6 -5
  59. sglang/srt/models/stablelm.py +1 -5
  60. sglang/srt/models/xverse.py +375 -0
  61. sglang/srt/models/xverse_moe.py +445 -0
  62. sglang/srt/openai_api/adapter.py +65 -46
  63. sglang/srt/openai_api/protocol.py +11 -3
  64. sglang/srt/sampling/sampling_batch_info.py +57 -44
  65. sglang/srt/server.py +24 -14
  66. sglang/srt/server_args.py +130 -28
  67. sglang/srt/utils.py +12 -0
  68. sglang/test/few_shot_gsm8k.py +132 -0
  69. sglang/test/runners.py +114 -22
  70. sglang/test/test_programs.py +7 -5
  71. sglang/test/test_utils.py +85 -1
  72. sglang/utils.py +32 -37
  73. sglang/version.py +1 -1
  74. {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/METADATA +30 -18
  75. sglang-0.3.1.dist-info/RECORD +129 -0
  76. {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/WHEEL +1 -1
  77. sglang-0.3.0.dist-info/RECORD +0 -118
  78. {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/LICENSE +0 -0
  79. {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/top_level.txt +0 -0
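The version bump itself is the one-line change to sglang/version.py listed above. After upgrading, it can be confirmed from Python with a trivial check (assuming, as is conventional, that version.py defines a __version__ string; its contents are not shown in this diff):

    from sglang.version import __version__

    print(__version__)  # expected to print "0.3.1" after the upgrade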
sglang/srt/models/minicpm3.py (new file)
@@ -0,0 +1,665 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+"""Inference-only MiniCPM3 model compatible with HuggingFace weights."""
+
+import math
+from typing import Any, Dict, Iterable, Optional, Tuple
+
+import torch
+from flashinfer import bmm_fp8
+from torch import nn
+from transformers import PretrainedConfig
+from vllm.config import CacheConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.model_executor.layers.linear import (
+    ColumnParallelLinear,
+    MergedColumnParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
+
+
+class MiniCPM3MLP(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+        )
+        if hidden_act != "silu":
+            raise ValueError(
+                f"Unsupported activation: {hidden_act}. "
+                "Only silu is supported for now."
+            )
+        self.act_fn = SiluAndMul()
+
+    def forward(self, x):
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(x)
+        return x
+
+
+def input_to_float8(x, dtype=torch.float8_e4m3fn):
+    finfo = torch.finfo(dtype)
+    min_val, max_val = x.aminmax()
+    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
+    scale = finfo.max / amax
+    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
+    return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
+
+
+class MiniCPM3Attention(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        hidden_size: int,
+        num_heads: int,
+        qk_nope_head_dim: int,
+        qk_rope_head_dim: int,
+        v_head_dim: int,
+        q_lora_rank: int,
+        kv_lora_rank: int,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        max_position_embeddings: int = 8192,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        layer_id=None,
+    ) -> None:
+        super().__init__()
+        self.layer_id = layer_id
+        self.hidden_size = hidden_size
+        self.qk_nope_head_dim = qk_nope_head_dim
+        self.qk_rope_head_dim = qk_rope_head_dim
+        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
+        self.v_head_dim = v_head_dim
+        self.q_lora_rank = q_lora_rank
+        self.kv_lora_rank = kv_lora_rank
+        self.num_heads = num_heads
+        tp_size = get_tensor_model_parallel_world_size()
+        assert num_heads % tp_size == 0
+        self.num_local_heads = num_heads // tp_size
+        self.scaling = self.qk_head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+
+        if self.q_lora_rank is not None:
+            self.q_a_proj = ReplicatedLinear(
+                self.hidden_size,
+                self.q_lora_rank,
+                bias=False,
+                quant_config=quant_config,
+            )
+            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
+            self.q_b_proj = ColumnParallelLinear(
+                q_lora_rank,
+                self.num_heads * self.qk_head_dim,
+                bias=False,
+                quant_config=quant_config,
+            )
+        else:
+            self.q_proj = ColumnParallelLinear(
+                self.hidden_size,
+                self.num_heads * self.qk_head_dim,
+                bias=False,
+                quant_config=quant_config,
+            )
+
+        self.kv_a_proj_with_mqa = ReplicatedLinear(
+            self.hidden_size,
+            self.kv_lora_rank + self.qk_rope_head_dim,
+            bias=False,
+            quant_config=quant_config,
+        )
+        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
+        self.kv_b_proj = ColumnParallelLinear(
+            self.kv_lora_rank,
+            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
+            bias=False,
+            quant_config=quant_config,
+        )
+        # O projection.
+        self.o_proj = RowParallelLinear(
+            self.num_heads * self.v_head_dim,
+            self.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+        )
+        self.rotary_emb = get_rope(
+            qk_rope_head_dim,
+            rotary_dim=qk_rope_head_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+        )
+
+        # TODO support head_size 96
+        self.attn = RadixAttention(
+            self.num_local_heads,
+            128,
+            self.scaling,
+            num_kv_heads=self.num_local_heads,
+            layer_id=layer_id,
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        if self.q_lora_rank is not None:
+            q = self.q_a_proj(hidden_states)[0]
+            q = self.q_a_layernorm(q)
+            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
+        else:
+            q = self.q_proj(hidden_states)[0].view(
+                -1, self.num_local_heads, self.qk_head_dim
+            )
+        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
+        kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+        latent_cache = latent_cache.unsqueeze(1)
+        kv_a = self.kv_a_layernorm(kv_a.contiguous())
+        kv = self.kv_b_proj(kv_a)[0]
+        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
+        k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+        k_pe = latent_cache[:, :, self.kv_lora_rank :]
+        original_shapes = [q_pe.shape, k_pe.shape]
+        q_pe, k_pe = self.rotary_emb(
+            positions, q_pe.reshape(q_pe.shape[0], -1), k_pe.reshape(k_pe.shape[0], -1)
+        )
+        q_pe, k_pe = q_pe.view(original_shapes[0]), k_pe.view(original_shapes[1])
+        q[..., self.qk_nope_head_dim :] = q_pe
+        k = torch.empty_like(q)
+        k[..., : self.qk_nope_head_dim] = k_nope
+        k[..., self.qk_nope_head_dim :] = k_pe
+        q = torch.nn.functional.pad(q, [0, 128 - self.qk_head_dim], value=0).view(
+            -1, self.num_local_heads * 128
+        )
+        k = torch.nn.functional.pad(k, [0, 128 - self.qk_head_dim], value=0).view(
+            -1, self.num_local_heads * 128
+        )
+        v = torch.nn.functional.pad(v, [0, 128 - self.v_head_dim], value=0).view(
+            -1, self.num_local_heads * 128
+        )
+        attn_output = self.attn(q, k, v, input_metadata)
+        attn_output = attn_output.view(-1, self.num_local_heads, 128)[
+            ..., : self.v_head_dim
+        ].reshape(-1, self.num_local_heads * self.v_head_dim)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class MiniCPM3AttentionMLA(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        hidden_size: int,
+        num_heads: int,
+        qk_nope_head_dim: int,
+        qk_rope_head_dim: int,
+        v_head_dim: int,
+        q_lora_rank: int,
+        kv_lora_rank: int,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        max_position_embeddings: int = 8192,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        layer_id=None,
+    ) -> None:
+        super().__init__()
+        self.layer_id = layer_id
+        self.hidden_size = hidden_size
+        self.qk_nope_head_dim = qk_nope_head_dim
+        self.qk_rope_head_dim = qk_rope_head_dim
+        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
+        self.v_head_dim = v_head_dim
+        self.q_lora_rank = q_lora_rank
+        self.kv_lora_rank = kv_lora_rank
+        self.num_heads = num_heads
+        tp_size = get_tensor_model_parallel_world_size()
+        assert num_heads % tp_size == 0
+        self.num_local_heads = num_heads // tp_size
+        self.scaling = self.qk_head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+
+        if self.q_lora_rank is not None:
+            self.q_a_proj = ReplicatedLinear(
+                self.hidden_size,
+                self.q_lora_rank,
+                bias=False,
+                quant_config=quant_config,
+            )
+            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
+            self.q_b_proj = ColumnParallelLinear(
+                q_lora_rank,
+                self.num_heads * self.qk_head_dim,
+                bias=False,
+                quant_config=quant_config,
+            )
+        else:
+            self.q_proj = ColumnParallelLinear(
+                self.hidden_size,
+                self.num_heads * self.qk_head_dim,
+                bias=False,
+                quant_config=quant_config,
+            )
+
+        self.kv_a_proj_with_mqa = ReplicatedLinear(
+            self.hidden_size,
+            self.kv_lora_rank + self.qk_rope_head_dim,
+            bias=False,
+            quant_config=quant_config,
+        )
+        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
+        self.kv_b_proj = ColumnParallelLinear(
+            self.kv_lora_rank,
+            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
+            bias=False,
+            quant_config=quant_config,
+        )
+        # O projection.
+        self.o_proj = RowParallelLinear(
+            self.num_heads * self.v_head_dim,
+            self.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+        )
+        self.rotary_emb = get_rope(
+            qk_rope_head_dim,
+            rotary_dim=qk_rope_head_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+        )
+
+        self.attn = RadixAttention(
+            self.num_local_heads,
+            self.kv_lora_rank + self.qk_rope_head_dim,
+            self.scaling,
+            num_kv_heads=1,
+            layer_id=layer_id,
+            v_head_dim=self.kv_lora_rank,
+        )
+
+        self.w_kc = None
+        self.w_vc = None
+        self.w_scale = None
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        q_len = hidden_states.shape[0]
+        q_input = hidden_states.new_empty(
+            q_len, self.num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim
+        )
+        if self.q_lora_rank is not None:
+            q = self.q_a_proj(hidden_states)[0]
+            q = self.q_a_layernorm(q)
+            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
+        else:
+            q = self.q_proj(hidden_states)[0].view(
+                -1, self.num_local_heads, self.qk_head_dim
+            )
+        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+
+        if self.w_kc.dtype == torch.float8_e4m3fn:
+            q_nope_val, q_nope_scale = input_to_float8(
+                q_nope.transpose(0, 1), torch.float8_e4m3fn
+            )
+            q_nope_out = bmm_fp8(
+                q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
+            )
+        else:
+            q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
+        q_input[..., : self.kv_lora_rank] = q_nope_out.transpose(0, 1)
+
+        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
+        v_input = latent_cache[..., : self.kv_lora_rank]
+        v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1)
+        k_input = latent_cache.unsqueeze(1)
+        k_input[..., : self.kv_lora_rank] = v_input
+        k_pe = k_input[..., self.kv_lora_rank :]
+
+        original_shapes = [q_pe.shape, k_pe.shape]
+        q_pe, k_pe = self.rotary_emb(
+            positions, q_pe.reshape(q_pe.shape[0], -1), k_pe.reshape(k_pe.shape[0], -1)
+        )
+        q_pe, k_pe = q_pe.view(original_shapes[0]), k_pe.view(original_shapes[1])
+        q_input[..., self.kv_lora_rank :] = q_pe
+        k_input[..., self.kv_lora_rank :] = k_pe
+
+        attn_output = self.attn(q_input, k_input, v_input, input_metadata)
+        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
+
+        if self.w_vc.dtype == torch.float8_e4m3fn:
+            attn_output_val, attn_output_scale = input_to_float8(
+                attn_output.transpose(0, 1), torch.float8_e4m3fn
+            )
+            attn_bmm_output = bmm_fp8(
+                attn_output_val,
+                self.w_vc,
+                attn_output_scale,
+                self.w_scale,
+                torch.bfloat16,
+            )
+        else:
+            attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc)
+        attn_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
+        output, _ = self.o_proj(attn_output)
+
+        return output
+
+
+class MiniCPM3DecoderLayer(nn.Module):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_id: int,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+        if global_server_args_dict["enable_mla"]:
+            self.self_attn = MiniCPM3AttentionMLA(
+                config=config,
+                hidden_size=self.hidden_size,
+                num_heads=config.num_attention_heads,
+                qk_nope_head_dim=config.qk_nope_head_dim,
+                qk_rope_head_dim=config.qk_rope_head_dim,
+                v_head_dim=self.hidden_size // config.num_attention_heads,
+                q_lora_rank=(
+                    config.q_lora_rank if hasattr(config, "q_lora_rank") else None
+                ),
+                kv_lora_rank=config.kv_lora_rank,
+                rope_theta=rope_theta,
+                rope_scaling=rope_scaling,
+                max_position_embeddings=max_position_embeddings,
+                cache_config=cache_config,
+                quant_config=quant_config,
+                layer_id=layer_id,
+            )
+        else:
+            self.self_attn = MiniCPM3Attention(
+                config=config,
+                hidden_size=self.hidden_size,
+                num_heads=config.num_attention_heads,
+                qk_nope_head_dim=config.qk_nope_head_dim,
+                qk_rope_head_dim=config.qk_rope_head_dim,
+                v_head_dim=self.hidden_size // config.num_attention_heads,
+                q_lora_rank=(
+                    config.q_lora_rank if hasattr(config, "q_lora_rank") else None
+                ),
+                kv_lora_rank=config.kv_lora_rank,
+                rope_theta=rope_theta,
+                rope_scaling=rope_scaling,
+                max_position_embeddings=max_position_embeddings,
+                cache_config=cache_config,
+                quant_config=quant_config,
+                layer_id=layer_id,
+            )
+        self.mlp = MiniCPM3MLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            quant_config=quant_config,
+        )
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        input_metadata: InputMetadata,
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Self Attention
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            input_metadata=input_metadata,
+        )
+        hidden_states = residual + hidden_states * (
+            self.config.scale_depth / math.sqrt(self.config.num_hidden_layers)
+        )
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states * (
+            self.config.scale_depth / math.sqrt(self.config.num_hidden_layers)
+        )
+
+        return hidden_states, None
+
+
+class MiniCPM3Model(nn.Module):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+        self.embed_tokens = VocabParallelEmbedding(
+            self.vocab_size,
+            config.hidden_size,
+            org_num_embeddings=config.vocab_size,
+        )
+        self.layers = nn.ModuleList(
+            [
+                MiniCPM3DecoderLayer(
+                    config, i, cache_config=cache_config, quant_config=quant_config
+                )
+                for i in range(config.num_hidden_layers)
+            ]
+        )
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        input_metadata: InputMetadata,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+        if input_embeds is None:
+            hidden_states = self.embed_tokens(input_ids) * self.config.scale_emb
+        else:
+            hidden_states = input_embeds
+        residual = None
+
+        for i in range(len(self.layers)):
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                input_metadata,
+                residual,
+            )
+        hidden_states = self.norm(hidden_states)
+        return hidden_states
+
+
+class MiniCPM3ForCausalLM(nn.Module):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+
+        self.num_experts = getattr(self.config, "num_experts", 0)
+        self.quant_config = quant_config
+        self.model = MiniCPM3Model(
+            config, cache_config=cache_config, quant_config=quant_config
+        )
+        # self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        if not self.config.tie_word_embeddings:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size,
+                config.hidden_size,
+                org_num_embeddings=config.vocab_size,
+            )
+
+        self.scale_width = self.config.hidden_size / self.config.dim_model_base
+
+        self.logits_processor = LogitsProcessor(config)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        input_metadata: InputMetadata,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+        if input_embeds is not None:
+            input_embeds = input_embeds * self.config.scale_emb
+        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        hidden_states = hidden_states / self.scale_width
+        if self.config.tie_word_embeddings:
+            lm_head_weight = self.model.embed_tokens.weight
+        else:
+            lm_head_weight = self.lm_head.weight
+        return self.logits_processor(
+            input_ids, hidden_states, lm_head_weight, input_metadata
+        )
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        expert_params_mapping = [
+            # (param_name, weight_name, expert_id)
+            (
+                "ws" if weight_name in ["w1", "w3"] else "w2s",
+                f"experts.{expert_id}.{weight_name}.weight",
+                expert_id,
+            )
+            for expert_id in range(self.num_experts)
+            for weight_name in ["w1", "w2", "w3"]
+        ]
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                # Models trained using ColossalAI may include these tensors in
+                # the checkpoint. Skip them.
+                continue
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                for param_name, weight_name, expert_id in expert_params_mapping:
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(
+                        param, loaded_weight, weight_name, expert_id=expert_id
+                    )
+                    break
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+
+        if global_server_args_dict["enable_mla"]:
+            for layer_id in range(self.config.num_hidden_layers):
+                self_attn = self.model.layers[layer_id].self_attn
+                w_kc, w_vc = self_attn.kv_b_proj.weight.unflatten(
+                    0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
+                ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+                self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
+                self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
+                if hasattr(self_attn.kv_b_proj, "weight_scale"):
+                    self_attn.w_scale = self_attn.kv_b_proj.weight_scale
+                del self_attn.kv_b_proj
+
+
+EntryClass = MiniCPM3ForCausalLM
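A note on the new file above: input_to_float8 performs per-tensor dynamic fp8 quantization, returning the saturated fp8 tensor together with the reciprocal scale, which MiniCPM3AttentionMLA.forward then passes to flashinfer's bmm_fp8 whenever the absorbed w_kc/w_vc weights are stored as float8_e4m3fn. A minimal standalone sketch of what the helper does, assuming a PyTorch build with float8_e4m3fn support (the shapes below are made up for illustration):

    import torch

    # Copy of the helper added in sglang/srt/models/minicpm3.py.
    def input_to_float8(x, dtype=torch.float8_e4m3fn):
        finfo = torch.finfo(dtype)
        min_val, max_val = x.aminmax()
        amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
        scale = finfo.max / amax
        x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
        return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()

    # Made-up shape (heads, tokens, qk_nope_head_dim), mirroring q_nope.transpose(0, 1).
    q_nope = torch.randn(2, 4, 64, dtype=torch.bfloat16)
    q_val, q_scale = input_to_float8(q_nope)

    # q_val holds fp8 values with one per-tensor scale; dequantize to gauge the rounding error.
    dequant = q_val.to(torch.float32) * q_scale
    print(q_val.dtype, q_scale.item())
    print((dequant - q_nope.float()).abs().max())  # small, dominated by fp8 rounding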
sglang/srt/models/mixtral.py
@@ -41,7 +41,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
+from sglang.srt.layers.torchao_utils import apply_torchao_config_
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -297,10 +298,10 @@ class MixtralForCausalLM(nn.Module):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
+        self.torchao_config = global_server_args_dict["torchao_config"]
         self.model = MixtralModel(config, quant_config=quant_config, prefix="model")
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()
 
     def forward(
         self,
@@ -310,11 +311,9 @@ class MixtralForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
            input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
@@ -380,5 +379,7 @@ class MixtralForCausalLM(nn.Module):
                     )
                     weight_loader(param, loaded_weight)
 
+        apply_torchao_config_(self, params_dict, set(["proj.weight"]))
+
 
 EntryClass = MixtralForCausalLM
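One further note on the MiniCPM3 addition earlier in this diff: when global_server_args_dict["enable_mla"] is set, load_weights absorbs kv_b_proj into per-head matrices w_kc and w_vc so that attention runs entirely in the kv_lora_rank latent space. A shape-only sketch of that split, with made-up dimensions standing in for the real config values:

    import torch

    # Hypothetical small dimensions standing in for config values.
    num_heads, qk_nope_head_dim, v_head_dim, kv_lora_rank = 2, 4, 3, 5

    # kv_b_proj maps the latent vector to per-head nope-keys and values, so its weight
    # has shape (num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank).
    kv_b_weight = torch.randn(num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank)

    w_kc, w_vc = kv_b_weight.unflatten(
        0, (-1, qk_nope_head_dim + v_head_dim)
    ).split([qk_nope_head_dim, v_head_dim], dim=1)

    w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)  # (heads, qk_nope, kv_lora_rank)
    w_vc = w_vc.contiguous().transpose(1, 2)                  # (heads, kv_lora_rank, v_head_dim)
    print(w_kc.shape, w_vc.shape)

    # Decode-time use, matching the torch.bmm calls in MiniCPM3AttentionMLA.forward:
    q_nope = torch.randn(6, num_heads, qk_nope_head_dim)  # (tokens, heads, qk_nope)
    q_latent = torch.bmm(q_nope.transpose(0, 1), w_kc)    # (heads, tokens, kv_lora_rank)
    attn_out = torch.randn(6, num_heads, kv_lora_rank)    # attention output in latent space
    o_input = torch.bmm(attn_out.transpose(0, 1), w_vc)   # (heads, tokens, v_head_dim)
    print(q_latent.shape, o_input.shape)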