sglang 0.4.6.post1__py3-none-any.whl → 0.4.6.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -11
- sglang/bench_serving.py +149 -1
- sglang/check_env.py +3 -3
- sglang/lang/chat_template.py +44 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/deepseekvl2.py +3 -0
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/internvl.py +696 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/kimi_vl.py +38 -0
- sglang/srt/configs/kimi_vl_moonvit.py +32 -0
- sglang/srt/configs/model_config.py +32 -0
- sglang/srt/constrained/xgrammar_backend.py +11 -19
- sglang/srt/conversation.py +151 -3
- sglang/srt/disaggregation/decode.py +4 -1
- sglang/srt/disaggregation/mini_lb.py +74 -23
- sglang/srt/disaggregation/mooncake/conn.py +9 -18
- sglang/srt/disaggregation/nixl/conn.py +241 -71
- sglang/srt/disaggregation/utils.py +44 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
- sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
- sglang/srt/distributed/device_communicators/pynccl.py +2 -1
- sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
- sglang/srt/distributed/parallel_state.py +22 -1
- sglang/srt/entrypoints/engine.py +58 -24
- sglang/srt/entrypoints/http_server.py +28 -1
- sglang/srt/entrypoints/verl_engine.py +3 -2
- sglang/srt/function_call_parser.py +97 -0
- sglang/srt/hf_transformers_utils.py +22 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +146 -50
- sglang/srt/layers/attention/flashinfer_backend.py +129 -94
- sglang/srt/layers/attention/flashinfer_mla_backend.py +88 -30
- sglang/srt/layers/attention/flashmla_backend.py +3 -0
- sglang/srt/layers/attention/merge_state.py +46 -0
- sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
- sglang/srt/layers/attention/vision.py +290 -163
- sglang/srt/layers/dp_attention.py +5 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +342 -7
- sglang/srt/layers/moe/ep_moe/layer.py +120 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +98 -57
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +10 -5
- sglang/srt/layers/quantization/__init__.py +2 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
- sglang/srt/layers/quantization/deep_gemm.py +6 -1
- sglang/srt/layers/quantization/fp8.py +108 -95
- sglang/srt/layers/quantization/fp8_kernel.py +79 -60
- sglang/srt/layers/quantization/fp8_utils.py +71 -23
- sglang/srt/layers/quantization/kv_cache.py +3 -10
- sglang/srt/layers/quantization/utils.py +0 -5
- sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
- sglang/srt/layers/utils.py +35 -0
- sglang/srt/lora/layers.py +35 -9
- sglang/srt/lora/lora_manager.py +81 -35
- sglang/srt/managers/cache_controller.py +115 -119
- sglang/srt/managers/data_parallel_controller.py +52 -34
- sglang/srt/managers/io_struct.py +10 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
- sglang/srt/managers/multimodal_processors/internvl.py +232 -0
- sglang/srt/managers/multimodal_processors/kimi_vl.py +73 -0
- sglang/srt/managers/schedule_batch.py +44 -16
- sglang/srt/managers/schedule_policy.py +11 -5
- sglang/srt/managers/scheduler.py +291 -72
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +24 -13
- sglang/srt/managers/tp_worker.py +60 -28
- sglang/srt/managers/tp_worker_overlap_thread.py +9 -3
- sglang/srt/mem_cache/chunk_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +70 -36
- sglang/srt/model_executor/cuda_graph_runner.py +82 -19
- sglang/srt/model_executor/forward_batch_info.py +31 -1
- sglang/srt/model_executor/model_runner.py +159 -90
- sglang/srt/model_loader/loader.py +18 -11
- sglang/srt/models/clip.py +4 -4
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_nextn.py +2 -277
- sglang/srt/models/deepseek_v2.py +132 -37
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/internlm2.py +3 -0
- sglang/srt/models/internvl.py +670 -0
- sglang/srt/models/kimi_vl.py +308 -0
- sglang/srt/models/kimi_vl_moonvit.py +639 -0
- sglang/srt/models/llama.py +93 -31
- sglang/srt/models/llama4.py +54 -7
- sglang/srt/models/llama_eagle.py +4 -1
- sglang/srt/models/llama_eagle3.py +4 -1
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/phi3_small.py +16 -2
- sglang/srt/models/qwen2_5_vl.py +8 -4
- sglang/srt/models/qwen2_moe.py +8 -3
- sglang/srt/models/qwen2_vl.py +4 -16
- sglang/srt/models/qwen3_moe.py +8 -3
- sglang/srt/models/xiaomi_mimo.py +171 -0
- sglang/srt/openai_api/adapter.py +58 -62
- sglang/srt/openai_api/protocol.py +38 -16
- sglang/srt/reasoning_parser.py +2 -2
- sglang/srt/sampling/sampling_batch_info.py +54 -2
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +93 -24
- sglang/srt/speculative/eagle_worker.py +3 -2
- sglang/srt/utils.py +123 -10
- sglang/test/runners.py +4 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deepep_utils.py +219 -0
- sglang/test/test_utils.py +32 -1
- sglang/version.py +1 -1
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/METADATA +18 -9
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/RECORD +119 -99
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/WHEEL +1 -1
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/top_level.txt +0 -0
@@ -24,34 +24,15 @@ from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import ReplicatedLinear
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import EPMoE
-from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.layers.quantization.fp8_utils import (
-    block_quant_to_tensor_quant,
-    normalize_e4m3fn_to_e4m3fnuz,
-)
-from sglang.srt.layers.quantization.int8_utils import (
-    block_dequant as int8_block_dequant,
-)
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
-from sglang.srt.utils import BumpAllocator, add_prefix
-
-_is_hip = is_hip()
-_is_cuda = is_cuda()
-
-if _is_cuda:
-    from sgl_kernel import awq_dequantize
-else:
-    from vllm._custom_ops import awq_dequantize
-
+from sglang.srt.utils import BumpAllocator, add_prefix
 
 logger = logging.getLogger(__name__)
 
@@ -177,263 +158,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
         )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        if hasattr(self.config, "num_nextn_predict_layers"):
-            num_nextn_layers = self.config.num_nextn_predict_layers
-            assert num_nextn_layers == 1, "Only 1 nextn layer is supportted"
-            assert num_nextn_layers == self.config.num_hidden_layers
-        else:
-            raise ValueError("num_nextn_predict_layers is not in the config")
-
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
-        ]
-        if self.n_share_experts_fusion > 0:
-            logger.info(
-                f"Cloning {self.n_share_experts_fusion} "
-                "replicas of the shared expert into MoE for DeepseekV3ForCausalLMNextN"
-            )
-            weights_list = list(weights)
-            weights_dict = dict(weights_list)
-            if self.quant_config is None or self.quant_config.get_name() == "w8a8_int8":
-                suffix_list = [
-                    "down_proj.weight",
-                    "down_proj.weight_scale",
-                    "gate_proj.weight",
-                    "gate_proj.weight_scale",
-                    "up_proj.weight",
-                    "up_proj.weight_scale",
-                ]
-            else:
-                suffix_list = [
-                    "down_proj.weight",
-                    "down_proj.weight_scale_inv",
-                    "gate_proj.weight",
-                    "gate_proj.weight_scale_inv",
-                    "up_proj.weight",
-                    "up_proj.weight_scale_inv",
-                ]
-            names_to_remove = []
-            for suffix in suffix_list:
-                shared_expert_weight_name = (
-                    f"model.layers.0.mlp.shared_experts.{suffix}"
-                )
-                for num_repeat in range(self.n_share_experts_fusion):
-                    weights_list.append(
-                        (
-                            f"model.layers.0."
-                            f"mlp.experts."
-                            f"{self.config.n_routed_experts + num_repeat}"
-                            f".{suffix}",
-                            weights_dict[shared_expert_weight_name],
-                        )
-                    )
-                names_to_remove += [shared_expert_weight_name]
-            weights = [w for w in weights_list if w[0] not in names_to_remove]
-
-        # Params for weights, fp8 weight scales, fp8 activation scales
-        # (param_name, weight_name, expert_id, shard_id)
-        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
-        expert_params_mapping = MoEImpl.make_expert_params_mapping(
-            ckpt_gate_proj_name="gate_proj",
-            ckpt_down_proj_name="down_proj",
-            ckpt_up_proj_name="up_proj",
-            num_experts=self.config.n_routed_experts + self.n_share_experts_fusion,
-        )
-
-        # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
-        fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
-            self.config.q_lora_rank is not None
-        )
-        cached_a_proj = {} if fuse_qkv_a_proj else None
-
-        nextn_layer_prefix = "model.layers.0"
-        nextn_spec_weight_names = [
-            "shared_head.norm",
-            "eh_proj",
-            "enorm",
-            "hnorm",
-        ]
-
-        params_dict = dict(self.named_parameters())
-        for name, loaded_weight in weights:
-            if not name.startswith(nextn_layer_prefix):
-                continue
-
-            # Use shared head and embed weights from target model
-            if "shared_head.head" in name or "embed_tokens" in name:
-                continue
-
-            is_decoder = True
-            # For nextn specific weights
-            for weight_name in nextn_spec_weight_names:
-                if weight_name in name:
-                    name = name.replace(nextn_layer_prefix, "model")
-                    is_decoder = False
-                    break
-            # For decoder layer weights
-            if is_decoder:
-                name = name.replace(nextn_layer_prefix, "model.decoder")
-
-            if "rotary_emb.inv_freq" in name:
-                continue
-            for param_name, weight_name, shard_id in stacked_params_mapping:
-                # Skip non-stacked layers and experts (experts handled below).
-                if weight_name not in name:
-                    continue
-                # We have mlp.experts[0].gate_proj in the checkpoint.
-                # Since we handle the experts below in expert_params_mapping,
-                # we need to skip here BEFORE we update the name, otherwise
-                # name will be updated to mlp.experts[0].gate_up_proj, which
-                # will then be updated below in expert_params_mapping
-                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
-                if ("mlp.experts." in name) and name not in params_dict:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                for mapping in expert_params_mapping:
-                    param_name, weight_name, expert_id, shard_id = mapping
-                    if weight_name not in name:
-                        continue
-                    name = name.replace(weight_name, param_name)
-                    param = params_dict[name]
-                    weight_loader = param.weight_loader
-                    weight_loader(
-                        param,
-                        loaded_weight,
-                        name,
-                        shard_id=shard_id,
-                        expert_id=expert_id,
-                    )
-                    break
-                else:
-                    # Skip loading extra bias for GPTQ models.
-                    if name.endswith(".bias") and name not in params_dict:
-                        continue
-
-                    # Handle fused_qkv_a_proj
-                    if fuse_qkv_a_proj and (
-                        "q_a_proj" in name or "kv_a_proj_with_mqa" in name
-                    ):
-                        cached_a_proj[name] = loaded_weight
-                        q_a_proj_name = (
-                            name
-                            if "q_a_proj" in name
-                            else name.replace("kv_a_proj_with_mqa", "q_a_proj")
-                        )
-                        kv_a_proj_name = (
-                            name
-                            if "kv_a_proj_with_mqa" in name
-                            else name.replace("q_a_proj", "kv_a_proj_with_mqa")
-                        )
-
-                        # When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter
-                        if (
-                            q_a_proj_name in cached_a_proj
-                            and kv_a_proj_name in cached_a_proj
-                        ):
-
-                            q_a_proj_weight = cached_a_proj[q_a_proj_name]
-                            kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
-                            fused_weight = torch.cat(
-                                [q_a_proj_weight, kv_a_proj_weight], dim=0
-                            )
-
-                            param_name = name.replace(
-                                "q_a_proj", "fused_qkv_a_proj_with_mqa"
-                            )
-                            param = params_dict[param_name]
-
-                            weight_loader = getattr(
-                                param, "weight_loader", default_weight_loader
-                            )
-                            weight_loader(param, fused_weight)
-                            cached_a_proj.pop(q_a_proj_name)
-                            cached_a_proj.pop(kv_a_proj_name)
-                    else:
-                        param = params_dict[name]
-                        weight_loader = getattr(
-                            param, "weight_loader", default_weight_loader
-                        )
-                        weight_loader(param, loaded_weight)
-
-        self_attn = self.model.decoder.self_attn
-        if hasattr(self_attn.kv_b_proj, "qweight"):
-            # AWQ compatible
-            if _is_cuda:
-                w = awq_dequantize(
-                    self_attn.kv_b_proj.qweight,
-                    self_attn.kv_b_proj.scales,
-                    self_attn.kv_b_proj.qzeros,
-                ).T
-            else:
-                w = awq_dequantize(
-                    self_attn.kv_b_proj.qweight,
-                    self_attn.kv_b_proj.scales,
-                    self_attn.kv_b_proj.qzeros,
-                    0,
-                    0,
-                    0,
-                ).T
-        else:
-            w = self_attn.kv_b_proj.weight
-        # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
-        # This may affect the accuracy of fp8 model.
-        if hasattr(self.quant_config, "weight_block_size") and w.dtype in (
-            torch.float8_e4m3fn,
-            torch.float8_e4m3fnuz,
-        ):
-            weight_block_size = self.quant_config.weight_block_size
-            if weight_block_size is not None:
-                assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
-                if _is_hip:
-                    weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
-                        weight=w,
-                        weight_scale=self_attn.kv_b_proj.weight_scale_inv,
-                        input_scale=None,
-                    )
-                else:
-                    weight = w
-                    weight_scale = self_attn.kv_b_proj.weight_scale_inv
-
-                w, scale = block_quant_to_tensor_quant(
-                    weight, weight_scale, weight_block_size
-                )
-                self_attn.w_scale = scale
-        if w.dtype == torch.int8:
-            if hasattr(self.quant_config, "weight_block_size"):
-                # block-wise int8 need it
-                weight_block_size = self.quant_config.weight_block_size
-                if weight_block_size is not None:
-                    assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
-                    weight = w
-                    weight_scale = self_attn.kv_b_proj.weight_scale_inv
-                    w = int8_block_dequant(weight, weight_scale, weight_block_size).to(
-                        torch.bfloat16
-                    )
-            else:
-                # channel-wise int8 need it
-                assert hasattr(self_attn.kv_b_proj, "weight_scale")
-                w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
-                    torch.bfloat16
-                )
-        w_kc, w_vc = w.unflatten(
-            0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
-        ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
-        self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
-        self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
-        if hasattr(self_attn.kv_b_proj, "weight_scale") and self_attn.w_scale is None:
-            self_attn.w_scale = self_attn.kv_b_proj.weight_scale
-            if _is_hip:
-                self_attn.w_scale *= 2.0
+        super().load_weights(weights, is_nextn=True)
 
 
 EntryClass = [DeepseekV3ForCausalLMNextN]
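Note: the net effect of the deepseek_nextn.py change above is that DeepseekV3ForCausalLMNextN drops its own weight loader and simply forwards to the base class with an is_nextn flag. A minimal sketch of that delegation pattern follows; BaseCausalLM and NextNModel are illustrative names, not classes from the package, and only load_weights / is_nextn mirror the actual change.

from typing import Iterable, Tuple

import torch


class BaseCausalLM:
    def load_weights(
        self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn: bool = False
    ):
        # One shared implementation; the flag selects the main decoder stack
        # or the single NextN (multi-token prediction) layer.
        target = "nextn layer" if is_nextn else "decoder layers"
        for name, tensor in weights:
            print(f"loading {name} {tuple(tensor.shape)} into {target}")


class NextNModel(BaseCausalLM):
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # The subclass no longer re-implements loading; it only sets the flag.
        super().load_weights(weights, is_nextn=True)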
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -59,10 +59,11 @@ from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
 from sglang.srt.layers.quantization.fp8_kernel import (
-    per_tensor_quant_mla_deep_gemm_masked_fp8,
     per_tensor_quant_mla_fp8,
+    per_token_group_quant_mla_deep_gemm_masked_fp8,
 )
 from sglang.srt.layers.quantization.fp8_utils import (
+    block_quant_dequant,
     block_quant_to_tensor_quant,
     channel_quant_to_tensor_quant,
     normalize_e4m3fn_to_e4m3fnuz,
@@ -88,6 +89,7 @@ from sglang.srt.utils import (
     get_int_env_var,
     is_cuda,
     is_hip,
+    log_info_on_rank0,
 )
 
 _is_hip = is_hip()
@@ -356,6 +358,7 @@ class DeepseekV2MoE(nn.Module):
                 topk_idx,
                 topk_weights,
                 reorder_topk_ids,
+                num_recv_tokens_per_expert,
                 seg_indptr,
                 masked_m,
                 expected_m,
@@ -367,10 +370,13 @@ class DeepseekV2MoE(nn.Module):
             )
         final_hidden_states = self.experts(
             hidden_states=hidden_states,
+            topk_idx=topk_idx,
+            topk_weights=topk_weights,
             reorder_topk_ids=reorder_topk_ids,
             seg_indptr=seg_indptr,
             masked_m=masked_m,
             expected_m=expected_m,
+            num_recv_tokens_per_expert=num_recv_tokens_per_expert,
             forward_mode=forward_mode,
         )
         if self.ep_size > 1:
@@ -421,6 +427,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         reduce_results: bool = True,
         layer_id: int = None,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.layer_id = layer_id
@@ -543,6 +550,8 @@ class DeepseekV2AttentionMLA(nn.Module):
             prefix=add_prefix("attn_mha", prefix),
         )
 
+        self.alt_stream = alt_stream
+
         self.w_kc = None
         self.w_vc = None
         self.w_scale = None
@@ -706,20 +715,36 @@ class DeepseekV2AttentionMLA(nn.Module):
             q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
                 [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
             )
-
+            k_nope = latent_cache[..., : self.kv_lora_rank]
+
+            # overlap qk norm
+            if self.alt_stream is not None and torch.cuda.is_current_stream_capturing():
+                current_stream = torch.cuda.current_stream()
+                self.alt_stream.wait_stream(current_stream)
+                q = self.q_a_layernorm(q)
+                with torch.cuda.stream(self.alt_stream):
+                    k_nope = self.kv_a_layernorm(k_nope)
+                current_stream.wait_stream(self.alt_stream)
+            else:
+                q = self.q_a_layernorm(q)
+                k_nope = self.kv_a_layernorm(k_nope)
+
+            k_nope = k_nope.unsqueeze(1)
             q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
         else:
             q = self.q_proj(hidden_states)[0].view(
                 -1, self.num_local_heads, self.qk_head_dim
            )
             latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
+            k_nope = latent_cache[..., : self.kv_lora_rank]
+            k_nope = self.kv_a_layernorm(k_nope).unsqueeze(1)
+
         q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+        k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)
 
         if self.use_deep_gemm_bmm:
             q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = (
-                per_tensor_quant_mla_deep_gemm_masked_fp8(
-                    q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn
-                )
+                per_token_group_quant_mla_deep_gemm_masked_fp8(q_nope.transpose(0, 1))
             )
             q_nope_out = q_nope.new_empty(
                 (self.num_local_heads, aligned_m, self.kv_lora_rank)
@@ -750,14 +775,9 @@ class DeepseekV2AttentionMLA(nn.Module):
             q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
 
         q_nope_out = q_nope_out.transpose(0, 1)
-
-        k_nope = latent_cache[..., : self.kv_lora_rank]
-        k_nope = self.kv_a_layernorm(k_nope).unsqueeze(1)
-        k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)
-
         q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
 
-        if self.attention_backend == "fa3":
+        if self.attention_backend == "fa3" or self.attention_backend == "flashinfer":
             attn_output = self.attn_mqa(
                 q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
             )
@@ -769,8 +789,8 @@ class DeepseekV2AttentionMLA(nn.Module):
 
         if self.use_deep_gemm_bmm:
             attn_output_val, attn_output_scale, masked_m, expected_m, aligned_m = (
-                per_tensor_quant_mla_deep_gemm_masked_fp8(
-                    attn_output.transpose(0, 1)
+                per_token_group_quant_mla_deep_gemm_masked_fp8(
+                    attn_output.transpose(0, 1)
                 )
             )
             attn_bmm_output = attn_output.new_empty(
@@ -1104,6 +1124,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         is_nextn: bool = False,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -1133,6 +1154,7 @@ class DeepseekV2DecoderLayer(nn.Module):
             layer_id=layer_id,
             reduce_results=False,
             prefix=add_prefix("self_attn", prefix),
+            alt_stream=alt_stream,
         )
 
         self.info = self._compute_info(config, layer_id=layer_id, is_nextn=is_nextn)
@@ -1376,6 +1398,7 @@ class DeepseekV2Model(nn.Module):
             config.hidden_size,
             enable_tp=not global_server_args_dict["enable_dp_attention"],
         )
+        self.alt_stream = torch.cuda.Stream()
         self.layers = nn.ModuleList(
             [
                 DeepseekV2DecoderLayer(
@@ -1383,6 +1406,7 @@ class DeepseekV2Model(nn.Module):
                     layer_id,
                     quant_config=quant_config,
                     prefix=add_prefix(f"layers.{layer_id}", prefix),
+                    alt_stream=self.alt_stream,
                 )
                 for layer_id in range(config.num_hidden_layers)
             ]
@@ -1391,6 +1415,9 @@ class DeepseekV2Model(nn.Module):
 
         self.dp_size = get_attention_dp_size()
 
+    def get_input_embeddings(self) -> torch.Tensor:
+        return self.embed_tokens
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -1464,8 +1491,9 @@ class DeepseekV2ForCausalLM(nn.Module):
             ):
                 self.n_share_experts_fusion = 0
                 global_server_args_dict["n_share_experts_fusion"] = 0
-                logger.info(
-                    "Only Deepseek V3/R1 can use shared experts fusion optimization. Shared experts fusion optimization is disabled."
+                log_info_on_rank0(
+                    logger,
+                    "Only Deepseek V3/R1 can use shared experts fusion optimization. Shared experts fusion optimization is disabled.",
                 )
             else:
                 assert (
@@ -1480,8 +1508,9 @@ class DeepseekV2ForCausalLM(nn.Module):
             ):
                 self.n_share_experts_fusion = self.tp_size
                 global_server_args_dict["n_share_experts_fusion"] = self.tp_size
-                logger.info(
-                    "Deepseek V3/R1 with fp8 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled."
+                log_info_on_rank0(
+                    logger,
+                    "Deepseek V3/R1 with fp8 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled.",
                 )
 
     def get_input_embeddings(self) -> nn.Embedding:
@@ -1502,11 +1531,20 @@ class DeepseekV2ForCausalLM(nn.Module):
             input_ids, hidden_states, self.lm_head, forward_batch
         )
 
-    def post_load_weights(self):
+    def post_load_weights(self, is_nextn=False):
 
         # Perform post-processing after loading weights
-        for layer_id in range(self.config.num_hidden_layers):
-            self_attn = self.model.layers[layer_id].self_attn
+        layer_ids = (
+            range(self.config.num_hidden_layers)
+            if not is_nextn
+            else [self.config.num_hidden_layers]
+        )
+        for layer_id in layer_ids:
+            self_attn = (
+                self.model.layers[layer_id].self_attn
+                if not is_nextn
+                else self.model.decoder.self_attn
+            )
             if hasattr(self_attn.kv_b_proj, "qweight"):
                 # AWQ compatible
                 if _is_cuda:
@@ -1552,13 +1590,22 @@ class DeepseekV2ForCausalLM(nn.Module):
 
                     if (
                         _is_cuda
-                        and _ENABLE_JIT_DEEPGEMM
                        and weight_block_size[0] == 128
                         and weight_block_size[1] == 128
                         and model_dtype == torch.bfloat16
                     ):
-                        block_scale = weight_scale
-                        use_deep_gemm_bmm = True
+                        if _ENABLE_JIT_DEEPGEMM and get_bool_env_var(
+                            "SGL_USE_DEEPGEMM_BMM", "false"
+                        ):
+                            block_scale = weight_scale
+                            use_deep_gemm_bmm = True
+                        else:
+                            w = block_quant_dequant(
+                                weight,
+                                weight_scale,
+                                weight_block_size,
+                                model_dtype,
+                            )
                     else:
                         w, scale = block_quant_to_tensor_quant(
                             weight, weight_scale, weight_block_size
@@ -1612,7 +1659,20 @@ class DeepseekV2ForCausalLM(nn.Module):
                 self_attn.w_vc = w_vc.contiguous()
                 self_attn.use_deep_gemm_bmm = True
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
+        if is_nextn:
+            if hasattr(self.config, "num_nextn_predict_layers"):
+                num_nextn_layers = self.config.num_nextn_predict_layers
+                assert num_nextn_layers == 1, "Only 1 nextn layer is supportted"
+                # compatible with old design
+                nextn_layer_id = (
+                    0
+                    if self.config.num_hidden_layers == 1
+                    else self.config.num_hidden_layers
+                )
+            else:
+                raise ValueError("num_nextn_predict_layers is not in the config")
+
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("gate_up_proj", "gate_proj", 0),
@@ -1640,12 +1700,19 @@ class DeepseekV2ForCausalLM(nn.Module):
                     "up_proj.weight_scale_inv",
                 ]
             names_to_remove = []
-            for moe_layer in tqdm(
+
+            moe_layers = (
                 range(
                     self.config.first_k_dense_replace,
                     self.config.num_hidden_layers,
                     self.config.moe_layer_freq,
-                ),
+                )
+                if not is_nextn
+                else [nextn_layer_id]
+            )
+
+            for moe_layer in tqdm(
+                moe_layers,
                 desc=f"Cloning {self.n_share_experts_fusion} "
                 "replicas of the shared expert into MoE",
             ):
@@ -1686,18 +1753,46 @@ class DeepseekV2ForCausalLM(nn.Module):
         )
         cached_a_proj = {} if fuse_qkv_a_proj else None
 
+        if is_nextn:
+            nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
+            nextn_spec_weight_names = [
+                "shared_head.norm",
+                "eh_proj",
+                "enorm",
+                "hnorm",
+            ]
+
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
-            # TODO(HandH1998): Modify it when nextn is supported.
-            if hasattr(self.config, "num_nextn_predict_layers"):
-                num_nextn_layers = self.config.num_nextn_predict_layers
-                if num_nextn_layers > 0 and name.startswith("model.layers"):
-                    name_list = name.split(".")
-                    if (
-                        len(name_list) >= 3
-                        and int(name_list[2]) >= self.config.num_hidden_layers
-                    ):
-                        continue
+            if not is_nextn:
+                if hasattr(self.config, "num_nextn_predict_layers"):
+                    num_nextn_layers = self.config.num_nextn_predict_layers
+                    if num_nextn_layers > 0 and name.startswith("model.layers"):
+                        name_list = name.split(".")
+                        if (
+                            len(name_list) >= 3
+                            and int(name_list[2]) >= self.config.num_hidden_layers
+                        ):
+                            continue
+            else:
+                if not name.startswith(nextn_layer_prefix):
+                    continue
+
+                # Use shared head and embed weights from target model
+                if "shared_head.head" in name or "embed_tokens" in name:
+                    continue
+
+                is_decoder = True
+                # For nextn specific weights
+                for weight_name in nextn_spec_weight_names:
+                    if weight_name in name:
+                        name = name.replace(nextn_layer_prefix, "model")
+                        is_decoder = False
+                        break
+                # For decoder layer weights
+                if is_decoder:
+                    name = name.replace(nextn_layer_prefix, "model.decoder")
+
             if "rotary_emb.inv_freq" in name:
                 continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
@@ -1786,7 +1881,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                         )
                         weight_loader(param, loaded_weight)
 
-        self.post_load_weights()
+        self.post_load_weights(is_nextn=is_nextn)
 
     def get_embed_and_head(self):
         return self.model.embed_tokens.weight, self.lm_head.weight
sglang/srt/models/gemma3_mm.py
CHANGED
@@ -281,7 +281,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         pixel_values = torch.stack(
             flatten_nested_list([item.pixel_values for item in items]), dim=0
         )
-        pixel_values = pixel_values.to(
+        pixel_values = pixel_values.to(device=self.vision_tower.device)
         pixel_values = pixel_values.to(dtype=self.language_model.dtype())
 
         vision_outputs = self.vision_tower(pixel_values=pixel_values).last_hidden_state
sglang/srt/models/internlm2.py
CHANGED