sglang 0.4.3__py3-none-any.whl → 0.4.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,295 @@
1
+ # Copyright 2023-2024 SGLang Team
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ==============================================================================
14
+
15
+ """Inference-only DeepSeek NextN Speculative Decoding."""
16
+ from typing import Iterable, Optional, Tuple
17
+
18
+ import torch
19
+ from torch import nn
20
+ from transformers import PretrainedConfig
21
+ from vllm import _custom_ops as ops
22
+
23
+ from sglang.srt.layers.layernorm import RMSNorm
24
+ from sglang.srt.layers.linear import ReplicatedLinear
25
+ from sglang.srt.layers.logits_processor import LogitsProcessor
26
+ from sglang.srt.layers.moe.ep_moe.layer import EPMoE
27
+ from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
28
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
29
+ from sglang.srt.layers.quantization.fp8_utils import (
30
+ block_quant_to_tensor_quant,
31
+ normalize_e4m3fn_to_e4m3fnuz,
32
+ )
33
+ from sglang.srt.layers.vocab_parallel_embedding import (
34
+ ParallelLMHead,
35
+ VocabParallelEmbedding,
36
+ )
37
+ from sglang.srt.managers.schedule_batch import global_server_args_dict
38
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
39
+ from sglang.srt.model_loader.weight_utils import default_weight_loader
40
+ from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
41
+ from sglang.srt.utils import is_hip
42
+
43
# Evaluated once at import time. DeepseekV3ForCausalLMNextN.load_weights reads
# this flag when it post-processes the kv_b_proj weight and its scale.
is_hip_ = is_hip()
44
+
45
+
46
class DeepseekModelNextN(nn.Module):
    """NextN trunk: token embedding, two RMSNorms, a fusing projection and one decoder layer.

    ``forward`` normalizes the token embeddings and the hidden states found on
    ``forward_batch.spec_info``, concatenates both along the last dimension,
    projects the result with ``eh_proj`` and feeds it to a single
    ``DeepseekV2DecoderLayer`` constructed with ``is_nextn=True``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Create the submodules.

        Args:
            config: Model configuration; ``vocab_size``, ``hidden_size`` and
                ``rms_norm_eps`` are read here.
            quant_config: Optional quantization settings, forwarded only to
                the decoder layer.
        """
        super().__init__()

        def make_norm() -> RMSNorm:
            # All three norms of this module share the same width and epsilon.
            return RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.vocab_size = config.vocab_size

        # Tensor parallelism for the embedding is switched off when the
        # "enable_dp_attention" server argument is set.
        use_tp = not global_server_args_dict["enable_dp_attention"]
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            enable_tp=use_tp,
        )

        self.enorm = make_norm()
        self.hnorm = make_norm()

        # Maps the concatenated (embedding, hidden state) pair back to hidden_size.
        self.eh_proj = nn.Linear(
            2 * config.hidden_size,
            config.hidden_size,
            bias=False,
        )

        self.decoder = DeepseekV2DecoderLayer(
            config, 0, quant_config=quant_config, is_nextn=True
        )

        # Plain container module; only ``norm`` is attached to it here.
        head_container = nn.Module()
        head_container.norm = make_norm()
        self.shared_head = head_container

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
    ) -> torch.Tensor:
        """Run the NextN trunk and return its hidden states.

        Args:
            input_ids: Token ids; embedded only when ``input_embeds`` is None.
            positions: Passed through to the decoder layer.
            forward_batch: Batch metadata; ``spec_info.hidden_states`` and
                ``forward_mode`` are read from it.
            input_embeds: Optional pre-computed embeddings used in place of
                the embedding lookup.

        Returns:
            The decoder output, passed through ``shared_head.norm`` together
            with the residual unless the forward mode reports idle.
        """
        token_embeds = (
            self.embed_tokens(input_ids) if input_embeds is None else input_embeds
        )

        # Normalize both inputs separately, then fuse them on the last dim.
        fused_input = torch.cat(
            (
                self.enorm(token_embeds),
                self.hnorm(forward_batch.spec_info.hidden_states),
            ),
            dim=-1,
        )

        # The decoder layer starts without a residual.
        hidden_states, residual = self.decoder(
            positions, self.eh_proj(fused_input), forward_batch, None
        )

        if forward_batch.forward_mode.is_idle():
            return hidden_states

        normed_states, _ = self.shared_head.norm(hidden_states, residual)
        return normed_states
103
+
104
+
105
class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
    """Causal-LM wrapper around ``DeepseekModelNextN`` for NextN speculative decoding.

    Owns the NextN trunk (``self.model``), attaches the output head at
    ``self.model.shared_head.head`` and converts hidden states into logits
    with a ``LogitsProcessor``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Build the NextN trunk, the output head and the logits processor.

        Args:
            config: Model configuration; ``hidden_size`` and ``vocab_size``
                size the output head.
            quant_config: Optional quantization settings, passed to the trunk
                and to the ``ParallelLMHead`` branch.
        """
        # Only nn.Module.__init__ runs here; the parent class' __init__ is not
        # invoked, so every submodule this class uses is created below.
        nn.Module.__init__(self)
        self.config = config
        self.quant_config = quant_config

        self.model = DeepseekModelNextN(config, quant_config)

        if global_server_args_dict["enable_dp_attention"]:
            # Replicated head; the logits processor is told to skip its all-gather.
            self.model.shared_head.head = ReplicatedLinear(
                config.hidden_size,
                config.vocab_size,
                bias=False,
            )
            self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
        else:
            self.model.shared_head.head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
            )
            self.logits_processor = LogitsProcessor(config)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Run the trunk and hand its hidden states to the logits processor.

        Args:
            input_ids: Token ids for the batch.
            positions: Passed through to the trunk.
            forward_batch: Batch metadata forwarded to the trunk and the
                logits processor.

        Returns:
            Whatever ``self.logits_processor`` returns for this batch.
        """
        hidden_states = self.model(input_ids, positions, forward_batch)
        return self.logits_processor(
            input_ids, hidden_states, self.model.shared_head.head, forward_batch
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load the NextN layer from checkpoint tensors and prepare MLA weights.

        Only tensors whose name starts with ``model.layers.0`` are consumed.
        NextN-specific tensors (head, norms, ``eh_proj``, embedding) are mapped
        under ``model``; all remaining tensors of that layer are mapped under
        ``model.decoder``. Unless the ``disable_mla`` server argument is set,
        ``kv_b_proj`` is afterwards split into ``w_kc`` / ``w_vc`` on the
        attention module.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Raises:
            ValueError: If the config has no ``num_nextn_predict_layers``.
            AssertionError: If ``num_nextn_predict_layers`` is not 1 or does
                not equal ``num_hidden_layers``.
        """
        if not hasattr(self.config, "num_nextn_predict_layers"):
            raise ValueError("num_nextn_predict_layers is not in the config")
        num_nextn_layers = self.config.num_nextn_predict_layers
        # Kept as asserts so the exception type seen by callers is unchanged.
        assert num_nextn_layers == 1, "Only 1 nextn layer is supported"
        assert num_nextn_layers == self.config.num_hidden_layers, (
            "num_nextn_predict_layers must equal num_hidden_layers, got "
            f"{num_nextn_layers} and {self.config.num_hidden_layers}"
        )

        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
        expert_params_mapping = MoEImpl.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts,
        )

        nextn_layer_prefix = "model.layers.0"
        nextn_spec_weight_names = [
            "shared_head.head",
            "shared_head.norm",
            "eh_proj",
            "embed_tokens",
            "enorm",
            "hnorm",
        ]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if not name.startswith(nextn_layer_prefix):
                continue

            # NextN-specific tensors live directly under ``model``; every
            # other tensor of the layer belongs to ``model.decoder``.
            if any(spec_name in name for spec_name in nextn_spec_weight_names):
                name = name.replace(nextn_layer_prefix, "model")
            else:
                name = name.replace(nextn_layer_prefix, "model.decoder")

            if "rotary_emb.inv_freq" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if ("mlp.experts." in name) and name not in params_dict:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a stacked parameter: try the expert mapping next.
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param,
                        loaded_weight,
                        name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    break
                else:
                    # Neither stacked nor expert: load the tensor as-is.
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)

        if not global_server_args_dict["disable_mla"]:
            self_attn = self.model.decoder.self_attn
            if hasattr(self_attn.kv_b_proj, "qweight"):
                # AWQ compatible
                w = ops.awq_dequantize(
                    self_attn.kv_b_proj.qweight,
                    self_attn.kv_b_proj.scales,
                    self_attn.kv_b_proj.qzeros,
                    0,
                    0,
                    0,
                ).T
            else:
                w = self_attn.kv_b_proj.weight
            # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
            # This may affect the accuracy of fp8 model.
            if hasattr(self.quant_config, "weight_block_size") and w.dtype in (
                torch.float8_e4m3fn,
                torch.float8_e4m3fnuz,
            ):
                weight_block_size = self.quant_config.weight_block_size
                if weight_block_size is not None:
                    assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
                    if is_hip_:
                        weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                            weight=w,
                            weight_scale=self_attn.kv_b_proj.weight_scale_inv,
                            input_scale=None,
                        )
                    else:
                        weight = w
                        weight_scale = self_attn.kv_b_proj.weight_scale_inv

                    w, scale = block_quant_to_tensor_quant(
                        weight, weight_scale, weight_block_size
                    )
                    self_attn.w_scale = scale
            # Group rows per head, then split each head's rows into the
            # qk_nope part and the v part.
            w_kc, w_vc = w.unflatten(
                0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
            ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
            # Materialize a contiguous copy of the transposed tensor, then
            # transpose it back as a view.
            self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
            self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
            if (
                hasattr(self_attn.kv_b_proj, "weight_scale")
                and self_attn.w_scale is None
            ):
                self_attn.w_scale = self_attn.kv_b_proj.weight_scale
                if is_hip_:
                    # NOTE(review): w_scale is the same tensor object as
                    # kv_b_proj.weight_scale (assigned just above), so this
                    # in-place multiply also changes that attribute — confirm
                    # the aliasing is intended.
                    self_attn.w_scale *= 2.0
293
+
294
+
295
# Module-level list naming the model class defined in this file.
# NOTE(review): presumably read by the model loader to discover this class — confirm.
EntryClass = [DeepseekV3ForCausalLMNextN]
@@ -519,6 +519,8 @@ class DeepseekV2AttentionMLA(nn.Module):
519
519
  # Triton: Use normal computation for prefill and use weight absorption for extend/decode
520
520
  if (
521
521
  forward_batch.forward_mode.is_extend()
522
+ and not forward_batch.forward_mode.is_target_verify()
523
+ and not forward_batch.forward_mode.is_draft_extend()
522
524
  and forward_batch.extend_prefix_lens.sum() == 0
523
525
  ):
524
526
  return self.forward_normal(positions, hidden_states, forward_batch)
@@ -680,6 +682,7 @@ class DeepseekV2DecoderLayer(nn.Module):
680
682
  config: PretrainedConfig,
681
683
  layer_id: int,
682
684
  quant_config: Optional[QuantizationConfig] = None,
685
+ is_nextn: bool = False,
683
686
  ) -> None:
684
687
  super().__init__()
685
688
  self.hidden_size = config.hidden_size
@@ -731,7 +734,7 @@ class DeepseekV2DecoderLayer(nn.Module):
731
734
  quant_config=quant_config,
732
735
  layer_id=layer_id,
733
736
  )
734
- if (
737
+ if is_nextn or (
735
738
  config.n_routed_experts is not None
736
739
  and layer_id >= config.first_k_dense_replace
737
740
  and layer_id % config.moe_layer_freq == 0
@@ -449,7 +449,8 @@ class LlavaBaseForCausalLM(nn.Module):
449
449
  projector_weights = {
450
450
  "model.mm_projector.0": "multi_modal_projector.linear_1",
451
451
  "model.mm_projector.2": "multi_modal_projector.linear_2",
452
- "model.vision_tower.vision_tower": "vision_tower", # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
452
+ "model.vision_tower.vision_tower": "vision_tower",
453
+ # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
453
454
  "model.image_newline": "language_model.model.image_newline",
454
455
  }
455
456
  params_dict = dict(self.named_parameters())