sglang 0.4.5.post1__py3-none-any.whl → 0.4.5.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -4
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +0 -4
- sglang/lang/backend/anthropic.py +0 -4
- sglang/lang/backend/base_backend.py +1 -1
- sglang/lang/backend/openai.py +1 -1
- sglang/lang/backend/vertexai.py +0 -1
- sglang/lang/compiler.py +1 -7
- sglang/lang/tracer.py +3 -7
- sglang/srt/_custom_ops.py +0 -2
- sglang/srt/constrained/outlines_jump_forward.py +14 -1
- sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
- sglang/srt/constrained/xgrammar_backend.py +26 -4
- sglang/srt/custom_op.py +0 -62
- sglang/srt/disaggregation/decode.py +62 -6
- sglang/srt/disaggregation/mini_lb.py +5 -1
- sglang/srt/disaggregation/mooncake/conn.py +32 -62
- sglang/srt/disaggregation/mooncake/transfer_engine.py +30 -61
- sglang/srt/disaggregation/prefill.py +40 -4
- sglang/srt/disaggregation/utils.py +15 -0
- sglang/srt/entrypoints/verl_engine.py +7 -5
- sglang/srt/layers/activation.py +6 -8
- sglang/srt/layers/attention/flashattention_backend.py +114 -71
- sglang/srt/layers/attention/flashinfer_backend.py +5 -2
- sglang/srt/layers/attention/torch_native_backend.py +6 -1
- sglang/srt/layers/attention/triton_backend.py +6 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +13 -2
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +17 -3
- sglang/srt/layers/moe/ep_moe/layer.py +15 -29
- sglang/srt/layers/moe/fused_moe_native.py +4 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +14 -19
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/topk.py +27 -30
- sglang/srt/layers/parameter.py +0 -2
- sglang/srt/layers/quantization/__init__.py +1 -0
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +8 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +16 -44
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
- sglang/srt/layers/quantization/fp8.py +115 -132
- sglang/srt/layers/quantization/fp8_kernel.py +213 -57
- sglang/srt/layers/quantization/fp8_utils.py +187 -262
- sglang/srt/layers/quantization/moe_wna16.py +2 -0
- sglang/srt/layers/quantization/utils.py +5 -11
- sglang/srt/layers/quantization/w8a8_fp8.py +2 -0
- sglang/srt/layers/quantization/w8a8_int8.py +7 -7
- sglang/srt/layers/radix_attention.py +15 -0
- sglang/srt/layers/rotary_embedding.py +3 -2
- sglang/srt/layers/sampler.py +5 -10
- sglang/srt/lora/backend/base_backend.py +18 -2
- sglang/srt/lora/backend/flashinfer_backend.py +1 -1
- sglang/srt/lora/backend/triton_backend.py +1 -1
- sglang/srt/lora/layers.py +1 -1
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/lora/lora_manager.py +1 -1
- sglang/srt/managers/detokenizer_manager.py +0 -1
- sglang/srt/managers/io_struct.py +1 -0
- sglang/srt/managers/mm_utils.py +4 -3
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +3 -2
- sglang/srt/managers/schedule_batch.py +2 -4
- sglang/srt/managers/scheduler.py +12 -71
- sglang/srt/managers/tokenizer_manager.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +5 -1
- sglang/srt/mem_cache/memory_pool.py +7 -2
- sglang/srt/model_executor/cuda_graph_runner.py +2 -2
- sglang/srt/model_executor/model_runner.py +20 -27
- sglang/srt/models/bert.py +398 -0
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_nextn.py +74 -70
- sglang/srt/models/deepseek_v2.py +289 -348
- sglang/srt/models/llama.py +5 -5
- sglang/srt/models/minicpm3.py +29 -201
- sglang/srt/models/qwen2.py +4 -1
- sglang/srt/models/qwen2_moe.py +14 -13
- sglang/srt/models/qwen3.py +335 -0
- sglang/srt/models/qwen3_moe.py +423 -0
- sglang/srt/reasoning_parser.py +0 -1
- sglang/srt/sampling/sampling_batch_info.py +2 -3
- sglang/srt/server_args.py +34 -32
- sglang/srt/speculative/eagle_worker.py +4 -7
- sglang/srt/utils.py +16 -1
- sglang/test/runners.py +5 -1
- sglang/test/test_block_fp8.py +167 -0
- sglang/test/test_custom_ops.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/METADATA +3 -3
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/RECORD +92 -91
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/WHEEL +1 -1
- sglang/lang/__init__.py +0 -0
- sglang/srt/lora/backend/__init__.py +0 -25
- sglang/srt/server.py +0 -18
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/top_level.txt +0 -0
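The `sglang/version.py` bump in the list above is the one-line marker of this release. A quick way to confirm which side of this diff an installed environment is on (assuming the conventional `__version__` attribute that `sglang/__init__.py` re-exports):

```python
import sglang

# Prints "0.4.5.post2" after upgrading, "0.4.5.post1" before.
print(sglang.__version__)
```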
sglang/srt/models/bert.py
ADDED
@@ -0,0 +1,398 @@
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any, Dict, Iterable, Optional, Set, Tuple
+
+import torch
+from torch import nn
+
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.layers.activation import get_act_fn
+from sglang.srt.layers.linear import (
+    ColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.radix_attention import AttentionType, RadixAttention
+from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+
+BertConfig = None
+
+
+class BertEmbedding(nn.Module):
+
+    def __init__(self, config: BertConfig):
+
+        super().__init__()
+        self.size = config.hidden_size
+        self.word_embeddings = VocabParallelEmbedding(
+            config.vocab_size, config.hidden_size
+        )
+        self.position_embeddings = VocabParallelEmbedding(
+            config.max_position_embeddings, config.hidden_size
+        )
+        self.token_type_embeddings = VocabParallelEmbedding(
+            config.type_vocab_size, config.hidden_size
+        )
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.position_ids = nn.Parameter(
+            torch.empty((1, config.max_position_embeddings)),
+        )
+
+        self.position_embedding_type = config.position_embedding_type
+        if self.position_embedding_type != "absolute":
+            raise ValueError(
+                "Only 'absolute' position_embedding_type" + " is supported"
+            )
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+    ) -> torch.Tensor:
+        input_shape = input_ids.size()
+
+        # Input embeddings.
+        inputs_embeds = self.word_embeddings(input_ids)
+
+        # Position embeddings.
+        position_embeddings = self.position_embeddings(position_ids)
+
+        token_type_ids = torch.zeros(
+            input_shape, dtype=torch.long, device=inputs_embeds.device
+        )
+
+        token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
+        embeddings = inputs_embeds + token_type_embeddings + position_embeddings
+        embeddings = self.LayerNorm(embeddings)
+        return embeddings
+
+
+class BertEncoder(nn.Module):
+
+    def __init__(
+        self,
+        config: BertConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.config = config
+        self.quant_config = quant_config
+        self.layer = nn.ModuleList(
+            [
+                BertLayer(
+                    config=config,
+                    layer_id=layer_idx,
+                    quant_config=quant_config,
+                    prefix=f"{prefix}.layer.{layer_idx}",
+                )
+                for layer_idx in range(config.num_hidden_layers)
+            ]
+        )
+
+    def forward(
+        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
+    ) -> torch.Tensor:
+        for layer in self.layer:
+            hidden_states = layer(hidden_states, forward_batch)
+        return hidden_states
+
+
+class BertLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: BertConfig,
+        layer_id: int = 0,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+
+        self.attention = BertAttention(
+            hidden_size=config.hidden_size,
+            num_attention_heads=config.num_attention_heads,
+            layer_id=layer_id,
+            layer_norm_eps=config.layer_norm_eps,
+            quant_config=quant_config,
+            prefix=f"{prefix}.attention",
+        )
+
+        self.intermediate = BertIntermediate(
+            hidden_size=config.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            quant_config=quant_config,
+            prefix=f"{prefix}.intermediate",
+        )
+
+        self.output = BertOutput(
+            hidden_size=config.hidden_size,
+            intermediate_size=config.intermediate_size,
+            layer_norm_eps=config.layer_norm_eps,
+            quant_config=quant_config,
+            prefix=f"{prefix}.output",
+        )
+
+    def forward(self, hidden_states: torch.Tensor, forward_batch: ForwardBatch):
+        attn_output = self.attention(hidden_states, forward_batch)
+        intermediate_output = self.intermediate(attn_output)
+        output = self.output(intermediate_output, attn_output)
+        return output
+
+
+class BertAttention(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_attention_heads: int,
+        layer_norm_eps: float,
+        layer_id: int = 0,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+
+        self.self_attn = BertSelfAttention(
+            hidden_size=hidden_size,
+            num_attention_heads=num_attention_heads,
+            layer_id=layer_id,
+            quant_config=quant_config,
+            prefix=f"{prefix}.output",
+        )
+
+        self.output = BertSelfOutput(
+            hidden_size=hidden_size,
+            layer_norm_eps=layer_norm_eps,
+            quant_config=quant_config,
+            prefix=f"{prefix}.output",
+        )
+
+    def forward(
+        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
+    ) -> torch.Tensor:
+        self_output = self.self_attn(hidden_states, forward_batch)
+        return self.output(self_output, hidden_states)
+
+
+class BertSelfAttention(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_attention_heads: int,
+        layer_id: int = 0,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.hidden_size = hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+
+        self.total_num_heads = num_attention_heads
+        assert self.total_num_heads % tp_size == 0
+
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = self.total_num_heads
+        self.head_dim = self.hidden_size // self.total_num_heads
+        assert self.head_dim * self.total_num_heads == self.hidden_size
+
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size=self.hidden_size,
+            head_size=self.head_dim,
+            total_num_heads=self.total_num_heads,
+            total_num_kv_heads=self.total_num_kv_heads,
+            bias=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
+        )
+
+        self.attn = RadixAttention(
+            num_heads=self.num_heads,
+            head_dim=self.head_dim,
+            scaling=self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+            prefix=f"{prefix}.attn",
+            attn_type=AttentionType.ENCODER_ONLY,
+        )
+
+    def forward(
+        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        output = self.attn(q, k, v, forward_batch)
+        return output
+
+
+class BertSelfOutput(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        layer_norm_eps: float,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.dense = RowParallelLinear(
+            input_size=hidden_size,
+            output_size=hidden_size,
+            bias=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.dense",
+        )
+        self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
+
+    def forward(
+        self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
+    ) -> torch.Tensor:
+        hidden_states, _ = self.dense(hidden_states)
+        hidden_states = self.LayerNorm(hidden_states + input_tensor)
+        return hidden_states
+
+
+class BertIntermediate(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.dense = ColumnParallelLinear(
+            input_size=hidden_size,
+            output_size=intermediate_size,
+            bias=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.dense",
+        )
+        self.intermediate_act_fn = get_act_fn(hidden_act)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states, _ = self.dense(hidden_states)
+        hidden_states = self.intermediate_act_fn(hidden_states)
+        return hidden_states
+
+
+class BertOutput(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        layer_norm_eps: float,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+
+        self.dense = RowParallelLinear(
+            input_size=intermediate_size,
+            output_size=hidden_size,
+            bias=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.dense",
+        )
+
+        self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
+
+    def forward(
+        self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
+    ) -> torch.Tensor:
+        hidden_states, _ = self.dense(hidden_states)
+        hidden_states = self.LayerNorm(hidden_states + input_tensor)
+        return hidden_states
+
+
+class BertModel(nn.Module):
+
+    def __init__(
+        self,
+        *,
+        config: BertConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.config = config
+        self.embeddings = BertEmbedding(config)
+        self.encoder = BertEncoder(
+            config=config, quant_config=quant_config, prefix=f"encoder"
+        )
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+        # self.pooler = BertPooler(config)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        get_embedding: bool = False,
+    ) -> torch.Tensor:
+        assert get_embedding == True
+        # Your tokenized IDs
+
+        hidden_states = self.embeddings(
+            input_ids=input_ids,
+            position_ids=positions,
+        )
+
+        hidden_states = self.encoder(hidden_states, forward_batch=forward_batch)
+        return self.pooler(hidden_states, forward_batch)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "query", "q"),
+            ("qkv_proj", "key", "k"),
+            ("qkv_proj", "value", "v"),
+        ]
+
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            name = name.replace("self", "self_attn")
+            if "pooler" in name:
+                continue
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+
+class Contriever(BertModel):
+    pass
+
+
+EntryClass = [BertModel, Contriever]
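The new `bert.py` above registers `BertModel` and `Contriever` as encoder-only embedding models (note the `Pooler` and the `get_embedding` assertion in `forward`). A minimal sketch of how such a model is typically exercised once this wheel is installed, assuming sglang's usual embedding workflow (`--is-embedding` server flag and the OpenAI-compatible `/v1/embeddings` route); the model name, port, and flags here are illustrative, so check the docs for this release:

```python
# Assumed to be running in another terminal:
#   python -m sglang.launch_server --model-path BAAI/bge-base-en-v1.5 \
#       --is-embedding --port 30000
import requests

resp = requests.post(
    "http://localhost:30000/v1/embeddings",
    json={"model": "BAAI/bge-base-en-v1.5", "input": "What is the capital of France?"},
    timeout=30,
)
resp.raise_for_status()
embedding = resp.json()["data"][0]["embedding"]
print(len(embedding))  # hidden size of the encoder, e.g. 768 for a BERT-base model
```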
sglang/srt/models/deepseek.py
CHANGED
@@ -170,7 +170,7 @@ class DeepseekMoE(nn.Module):
         shared_output = self.shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = fused_moe(
+        final_hidden_states = fused_moe.fused_moe(
             hidden_states,
             self.w1,
             self.w2,
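The one-line change above reads as if `fused_moe` is bound as a module in this file and the Triton kernel entry point is now reached as an attribute of it. A sketch of the two import styles the before/after call sites imply (the module path comes from the file list above; treat the exact import line as an assumption rather than a quote from the release):

```python
# Old call-site style: the function itself is bound to the local name.
# from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
# out = fused_moe(hidden_states, w1, w2, ...)

# New call-site style: the module is bound, and the kernel is an attribute of it.
from sglang.srt.layers.moe.fused_moe_triton import fused_moe

# out = fused_moe.fused_moe(hidden_states, w1, w2, ...)
```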
sglang/srt/models/deepseek_nextn.py
CHANGED
@@ -40,7 +40,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
-from sglang.srt.utils import add_prefix, is_cuda, is_hip
+from sglang.srt.utils import BumpAllocator, add_prefix, is_cuda, is_hip
 
 _is_hip = is_hip()
 _is_cuda = is_cuda()
@@ -48,7 +48,7 @@ _is_cuda = is_cuda()
 if _is_cuda:
     from sgl_kernel import awq_dequantize
 else:
-    from vllm import
+    from vllm._custom_ops import awq_dequantize
 
 
 class DeepseekModelNextN(nn.Module):
@@ -91,6 +91,14 @@ class DeepseekModelNextN(nn.Module):
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
+        zero_allocator = BumpAllocator(
+            buffer_size=2,
+            dtype=torch.float32,
+            device=(
+                input_embeds.device if input_embeds is not None else input_ids.device
+            ),
+        )
+
         if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)
         else:
@@ -108,7 +116,7 @@ class DeepseekModelNextN(nn.Module):
 
         residual = None
         hidden_states, residual = self.decoder(
-            positions, hidden_states, forward_batch, residual
+            positions, hidden_states, forward_batch, residual, zero_allocator
        )
 
         if not forward_batch.forward_mode.is_idle():
@@ -262,79 +270,75 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
         )
         weight_loader(param, loaded_weight)
 
-
-
-
-
-
-
-
-
-
-                ).T
-            else:
-                w = ops.awq_dequantize(
-                    self_attn.kv_b_proj.qweight,
-                    self_attn.kv_b_proj.scales,
-                    self_attn.kv_b_proj.qzeros,
-                    0,
-                    0,
-                    0,
-                ).T
+        self_attn = self.model.decoder.self_attn
+        if hasattr(self_attn.kv_b_proj, "qweight"):
+            # AWQ compatible
+            if _is_cuda:
+                w = awq_dequantize(
+                    self_attn.kv_b_proj.qweight,
+                    self_attn.kv_b_proj.scales,
+                    self_attn.kv_b_proj.qzeros,
+                ).T
             else:
-                w =
-
-
-
-
-
-
+                w = awq_dequantize(
+                    self_attn.kv_b_proj.qweight,
+                    self_attn.kv_b_proj.scales,
+                    self_attn.kv_b_proj.qzeros,
+                    0,
+                    0,
+                    0,
+                ).T
+        else:
+            w = self_attn.kv_b_proj.weight
+        # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
+        # This may affect the accuracy of fp8 model.
+        if hasattr(self.quant_config, "weight_block_size") and w.dtype in (
+            torch.float8_e4m3fn,
+            torch.float8_e4m3fnuz,
+        ):
+            weight_block_size = self.quant_config.weight_block_size
+            if weight_block_size is not None:
+                assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
+                if _is_hip:
+                    weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                        weight=w,
+                        weight_scale=self_attn.kv_b_proj.weight_scale_inv,
+                        input_scale=None,
+                    )
+                else:
+                    weight = w
+                    weight_scale = self_attn.kv_b_proj.weight_scale_inv
+
+                w, scale = block_quant_to_tensor_quant(
+                    weight, weight_scale, weight_block_size
+                )
+                self_attn.w_scale = scale
+        if w.dtype == torch.int8:
+            if hasattr(self.quant_config, "weight_block_size"):
+                # block-wise int8 need it
                 weight_block_size = self.quant_config.weight_block_size
                 if weight_block_size is not None:
                     assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
-
-
-
-                        weight_scale=self_attn.kv_b_proj.weight_scale_inv,
-                        input_scale=None,
-                    )
-                else:
-                    weight = w
-                    weight_scale = self_attn.kv_b_proj.weight_scale_inv
-
-                w, scale = block_quant_to_tensor_quant(
-                    weight, weight_scale, weight_block_size
-                )
-                self_attn.w_scale = scale
-        if w.dtype == torch.int8:
-            if hasattr(self.quant_config, "weight_block_size"):
-                # block-wise int8 need it
-                weight_block_size = self.quant_config.weight_block_size
-                if weight_block_size is not None:
-                    assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
-                    weight = w
-                    weight_scale = self_attn.kv_b_proj.weight_scale_inv
-                    w = int8_block_dequant(
-                        weight, weight_scale, weight_block_size
-                    ).to(torch.bfloat16)
-            else:
-                # channel-wise int8 need it
-                assert hasattr(self_attn.kv_b_proj, "weight_scale")
-                w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
+                    weight = w
+                    weight_scale = self_attn.kv_b_proj.weight_scale_inv
+                    w = int8_block_dequant(weight, weight_scale, weight_block_size).to(
                         torch.bfloat16
                     )
-
-
-
-
-
-
-
-
-
-
-
-
+            else:
+                # channel-wise int8 need it
+                assert hasattr(self_attn.kv_b_proj, "weight_scale")
+                w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
+                    torch.bfloat16
+                )
+        w_kc, w_vc = w.unflatten(
+            0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
+        ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+        self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
+        self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
+        if hasattr(self_attn.kv_b_proj, "weight_scale") and self_attn.w_scale is None:
+            self_attn.w_scale = self_attn.kv_b_proj.weight_scale
+            if _is_hip:
+                self_attn.w_scale *= 2.0
 
 
 EntryClass = [DeepseekV3ForCausalLMNextN]
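The `@@ -91,6 +91,14 @@` hunk above threads a `BumpAllocator` (newly imported from `sglang.srt.utils`) through `DeepseekModelNextN.forward` into the decoder layer. The real implementation lives in `sglang/srt/utils.py` and is not shown in this diff; the following is only a hypothetical sketch of what a bump allocator over a preallocated zero buffer usually looks like, to make the new call site easier to read:

```python
import torch


class BumpAllocator:
    """Hypothetical sketch, not the sglang implementation: hands out slices of a
    preallocated zero-filled buffer so hot paths avoid per-call allocations."""

    def __init__(self, buffer_size: int, dtype: torch.dtype, device: torch.device):
        self._buffer = torch.zeros((buffer_size,), dtype=dtype, device=device)
        self._index = 0

    def allocate(self, size: int) -> torch.Tensor:
        # Bump the cursor; there is no free() -- the buffer lives for one forward pass.
        assert self._index + size <= self._buffer.numel(), "allocator exhausted"
        out = self._buffer[self._index : self._index + size]
        self._index += size
        return out


# Mirrors the constructor call added in DeepseekModelNextN.forward
# (the real call picks the device from input_embeds / input_ids):
zero_allocator = BumpAllocator(
    buffer_size=2, dtype=torch.float32, device=torch.device("cpu")
)
```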