sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +2 -1
- sglang/eval/loogle_eval.py +7 -0
- sglang/srt/configs/deepseekvl2.py +11 -2
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +9 -7
- sglang/srt/configs/update_config.py +3 -1
- sglang/srt/conversation.py +1 -0
- sglang/srt/custom_op.py +5 -2
- sglang/srt/disaggregation/decode.py +9 -1
- sglang/srt/disaggregation/mooncake/conn.py +44 -56
- sglang/srt/distributed/parallel_state.py +33 -0
- sglang/srt/entrypoints/engine.py +30 -26
- sglang/srt/entrypoints/openai/serving_chat.py +21 -2
- sglang/srt/eplb/expert_location_dispatch.py +1 -1
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/qwen3_detector.py +150 -0
- sglang/srt/hf_transformers_utils.py +0 -1
- sglang/srt/layers/activation.py +13 -0
- sglang/srt/layers/attention/flashattention_backend.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +40 -1
- sglang/srt/layers/linear.py +13 -102
- sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
- sglang/srt/layers/moe/ep_moe/layer.py +23 -402
- sglang/srt/layers/moe/fused_moe_native.py +7 -47
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +35 -45
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
- sglang/srt/layers/moe/topk.py +187 -12
- sglang/srt/layers/quantization/__init__.py +20 -134
- sglang/srt/layers/quantization/awq.py +578 -11
- sglang/srt/layers/quantization/awq_triton.py +339 -0
- sglang/srt/layers/quantization/base_config.py +85 -10
- sglang/srt/layers/quantization/blockwise_int8.py +17 -55
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +24 -73
- sglang/srt/layers/quantization/fp8.py +273 -62
- sglang/srt/layers/quantization/fp8_kernel.py +210 -46
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +501 -143
- sglang/srt/layers/quantization/marlin_utils.py +790 -0
- sglang/srt/layers/quantization/modelopt_quant.py +26 -108
- sglang/srt/layers/quantization/moe_wna16.py +45 -49
- sglang/srt/layers/quantization/petit.py +252 -0
- sglang/srt/layers/quantization/petit_utils.py +104 -0
- sglang/srt/layers/quantization/qoq.py +7 -6
- sglang/srt/layers/quantization/scalar_type.py +352 -0
- sglang/srt/layers/quantization/unquant.py +422 -0
- sglang/srt/layers/quantization/utils.py +343 -3
- sglang/srt/layers/quantization/w4afp8.py +8 -4
- sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
- sglang/srt/layers/quantization/w8a8_int8.py +51 -115
- sglang/srt/layers/vocab_parallel_embedding.py +1 -41
- sglang/srt/lora/lora.py +0 -4
- sglang/srt/lora/lora_manager.py +87 -53
- sglang/srt/lora/mem_pool.py +81 -33
- sglang/srt/lora/utils.py +12 -5
- sglang/srt/managers/cache_controller.py +241 -0
- sglang/srt/managers/io_struct.py +41 -29
- sglang/srt/managers/mm_utils.py +7 -8
- sglang/srt/managers/schedule_batch.py +150 -110
- sglang/srt/managers/schedule_policy.py +68 -27
- sglang/srt/managers/scheduler.py +243 -61
- sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
- sglang/srt/managers/tokenizer_manager.py +11 -3
- sglang/srt/managers/tp_worker.py +14 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/allocator.py +7 -16
- sglang/srt/mem_cache/base_prefix_cache.py +14 -2
- sglang/srt/mem_cache/chunk_cache.py +5 -2
- sglang/srt/mem_cache/hicache_storage.py +152 -0
- sglang/srt/mem_cache/hiradix_cache.py +179 -4
- sglang/srt/mem_cache/memory_pool.py +16 -1
- sglang/srt/mem_cache/memory_pool_host.py +41 -2
- sglang/srt/mem_cache/radix_cache.py +26 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +5 -6
- sglang/srt/model_executor/forward_batch_info.py +14 -1
- sglang/srt/model_executor/model_runner.py +109 -22
- sglang/srt/model_loader/loader.py +7 -1
- sglang/srt/model_loader/utils.py +4 -4
- sglang/srt/models/clip.py +1 -1
- sglang/srt/models/deepseek.py +9 -6
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_v2.py +191 -171
- sglang/srt/models/deepseek_vl2.py +5 -5
- sglang/srt/models/gemma.py +48 -0
- sglang/srt/models/gemma2.py +52 -0
- sglang/srt/models/gemma3_causal.py +63 -0
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -4
- sglang/srt/models/granitemoe.py +385 -0
- sglang/srt/models/grok.py +9 -3
- sglang/srt/models/hunyuan.py +63 -16
- sglang/srt/models/internvl.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -1
- sglang/srt/models/llama.py +41 -0
- sglang/srt/models/llama4.py +11 -11
- sglang/srt/models/llava.py +2 -2
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +0 -2
- sglang/srt/models/minicpmo.py +3 -7
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mixtral.py +9 -2
- sglang/srt/models/mllama.py +3 -5
- sglang/srt/models/mllama4.py +3 -3
- sglang/srt/models/olmoe.py +8 -5
- sglang/srt/models/persimmon.py +330 -0
- sglang/srt/models/phi.py +321 -0
- sglang/srt/models/phi4mm.py +44 -4
- sglang/srt/models/phi4mm_audio.py +1260 -0
- sglang/srt/models/phi4mm_utils.py +1917 -0
- sglang/srt/models/phimoe.py +9 -3
- sglang/srt/models/qwen.py +37 -0
- sglang/srt/models/qwen2.py +41 -0
- sglang/srt/models/qwen2_5_vl.py +4 -4
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +53 -5
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/qwen3.py +65 -1
- sglang/srt/models/qwen3_moe.py +56 -18
- sglang/srt/models/vila.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +91 -97
- sglang/srt/multimodal/processors/clip.py +21 -19
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
- sglang/srt/multimodal/processors/gemma3.py +13 -17
- sglang/srt/multimodal/processors/gemma3n.py +19 -23
- sglang/srt/multimodal/processors/internvl.py +9 -10
- sglang/srt/multimodal/processors/janus_pro.py +12 -27
- sglang/srt/multimodal/processors/kimi_vl.py +12 -14
- sglang/srt/multimodal/processors/llava.py +4 -2
- sglang/srt/multimodal/processors/minicpm.py +35 -44
- sglang/srt/multimodal/processors/mlama.py +21 -18
- sglang/srt/multimodal/processors/mllama4.py +4 -5
- sglang/srt/multimodal/processors/phi4mm.py +63 -39
- sglang/srt/multimodal/processors/pixtral.py +14 -35
- sglang/srt/multimodal/processors/qwen_audio.py +65 -0
- sglang/srt/multimodal/processors/qwen_vl.py +16 -21
- sglang/srt/multimodal/processors/vila.py +14 -14
- sglang/srt/sampling/sampling_params.py +8 -1
- sglang/srt/server_args.py +393 -230
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +9 -1
- sglang/srt/two_batch_overlap.py +1 -0
- sglang/srt/utils.py +27 -1
- sglang/test/runners.py +14 -3
- sglang/test/test_block_fp8.py +8 -3
- sglang/test/test_block_fp8_ep.py +1 -1
- sglang/test/test_custom_ops.py +12 -7
- sglang/test/test_cutlass_w4a8_moe.py +1 -3
- sglang/test/test_fp4_moe.py +1 -3
- sglang/test/test_marlin_moe.py +286 -0
- sglang/test/test_marlin_utils.py +171 -0
- sglang/test/test_utils.py +35 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/METADATA +8 -8
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/RECORD +166 -146
- sglang/srt/layers/quantization/quant_utils.py +0 -166
- sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -16,6 +16,7 @@
 # https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py
 """Inference-only DeepseekV2 model."""
 
+import concurrent.futures
 import logging
 import os
 from enum import IntEnum, auto
@@ -57,7 +58,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
-from sglang.srt.layers.moe.topk import
+from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_kernel import (
@@ -126,6 +127,10 @@ if _is_cuda:
     )
 elif _is_cpu and _is_cpu_amx_available:
     pass
+elif _is_hip:
+    from sglang.srt.layers.quantization.awq_triton import (
+        awq_dequantize_triton as awq_dequantize,
+    )
 else:
     from vllm._custom_ops import awq_dequantize
 
@@ -224,7 +229,7 @@ class MoEGate(nn.Module):
         )
         if config.topk_method == "noaux_tc":
             self.e_score_correction_bias = nn.Parameter(
-                torch.empty((config.n_routed_experts))
+                torch.empty((config.n_routed_experts), dtype=torch.float32)
             )
         else:
             self.e_score_correction_bias = None
@@ -249,9 +254,8 @@ class MoEGate(nn.Module):
             and self.weight.shape[0] == 256
             and _device_sm >= 90
         ):
-
-
-            )
+            # router gemm output float32
+            logits = dsv3_router_gemm(hidden_states, self.weight)
         else:
             logits = F.linear(hidden_states, self.weight, None)
 
@@ -298,6 +302,17 @@ class DeepseekV2MoE(nn.Module):
             config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
         )
 
+        self.topk = TopK(
+            top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
+            renormalize=config.norm_topk_prob,
+            use_grouped_topk=True,
+            num_expert_group=config.n_group,
+            num_fused_shared_experts=self.num_fused_shared_experts,
+            topk_group=config.topk_group,
+            correction_bias=self.gate.e_score_correction_bias,
+            routed_scaling_factor=self.routed_scaling_factor,
+        )
+
         self.experts = get_moe_impl_class()(
             num_experts=config.n_routed_experts
             + self.num_fused_shared_experts
@@ -306,13 +321,7 @@ class DeepseekV2MoE(nn.Module):
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
             layer_id=self.layer_id,
-            renormalize=config.norm_topk_prob,
             quant_config=quant_config,
-            use_grouped_topk=True,
-            num_expert_group=config.n_group,
-            num_fused_shared_experts=self.num_fused_shared_experts,
-            topk_group=config.topk_group,
-            correction_bias=self.gate.e_score_correction_bias,
             routed_scaling_factor=self.routed_scaling_factor,
             prefix=add_prefix("experts", prefix),
             **(
@@ -354,6 +363,7 @@ class DeepseekV2MoE(nn.Module):
                 self.shared_experts.gate_up_proj.quant_method, "quant_config"
             ) and self.shared_experts.gate_up_proj.quant_method.quant_config.get_name() in {
                 "awq",
+                "awq_marlin",
                 "moe_wna16",
             }
             self.shared_experts_is_int8 = (
@@ -437,21 +447,22 @@ class DeepseekV2MoE(nn.Module):
     def forward_normal_dual_stream(
         self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
     ) -> torch.Tensor:
-        # router_logits: (num_tokens, n_experts)
-        router_logits = self.gate(hidden_states)
 
         current_stream = torch.cuda.current_stream()
         self.alt_stream.wait_stream(current_stream)
         shared_output = self._forward_shared_experts(hidden_states)
 
         with torch.cuda.stream(self.alt_stream):
+            # router_logits: (num_tokens, n_experts)
+            router_logits = self.gate(hidden_states)
+            topk_output = self.topk(hidden_states, router_logits)
             final_hidden_states = self.experts(
-                hidden_states=hidden_states,
+                hidden_states=hidden_states, topk_output=topk_output
             )
             if not _is_cuda:
                 final_hidden_states *= self.routed_scaling_factor
         current_stream.wait_stream(self.alt_stream)
-        final_hidden_states
+        final_hidden_states += shared_output
         if self.tp_size > 1 and not can_fuse_mlp_allreduce:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states
@@ -462,13 +473,14 @@ class DeepseekV2MoE(nn.Module):
         if hasattr(self, "shared_experts") and use_intel_amx_backend(
             self.shared_experts.gate_up_proj
         ):
-            return self.forward_cpu(hidden_states)
+            return self.forward_cpu(hidden_states, can_fuse_mlp_allreduce)
 
         shared_output = self._forward_shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
+        topk_output = self.topk(hidden_states, router_logits)
         final_hidden_states = self.experts(
-            hidden_states=hidden_states,
+            hidden_states=hidden_states, topk_output=topk_output
         )
         if not _is_cuda and not _use_aiter:
             # fused in biased_grouped_topk so we can skip here
@@ -479,11 +491,14 @@ class DeepseekV2MoE(nn.Module):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states
 
-    def forward_cpu(
+    def forward_cpu(
+        self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
+    ) -> torch.Tensor:
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
+        topk_output = self.topk(hidden_states, router_logits)
         fused_experts_out = self.experts(
-            hidden_states=hidden_states,
+            hidden_states=hidden_states, topk_output=topk_output
         )
 
         assert use_intel_amx_backend(
@@ -528,7 +543,7 @@ class DeepseekV2MoE(nn.Module):
             None, # a2_scale
             True, # is_vnni
         )
-        if self.tp_size > 1 and not
+        if self.tp_size > 1 and not can_fuse_mlp_allreduce:
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states
 
@@ -541,17 +556,9 @@ class DeepseekV2MoE(nn.Module):
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
         shared_output = self._forward_shared_experts(hidden_states)
-        topk_weights, topk_idx =
-            hidden_states
-            router_logits
-            top_k=self.top_k,
-            use_grouped_topk=True,
-            renormalize=self.renormalize,
-            topk_group=self.topk_group,
-            num_expert_group=self.num_expert_group,
-            num_fused_shared_experts=self.num_fused_shared_experts,
-            correction_bias=self.correction_bias,
-            routed_scaling_factor=self.routed_scaling_factor,
+        topk_weights, topk_idx, _ = self.topk(
+            hidden_states,
+            router_logits,
             num_token_non_padded=forward_batch.num_token_non_padded,
             expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
                 layer_id=self.layer_id,
@@ -641,17 +648,9 @@ class DeepseekV2MoE(nn.Module):
         with get_global_expert_distribution_recorder().with_current_layer(
             self.layer_id
         ):
-            state.topk_weights_local, state.topk_idx_local =
+            state.topk_weights_local, state.topk_idx_local, _ = self.topk(
                 hidden_states=hidden_states,
                 router_logits=router_logits,
-                top_k=self.top_k,
-                use_grouped_topk=True,
-                renormalize=self.renormalize,
-                topk_group=self.topk_group,
-                num_expert_group=self.num_expert_group,
-                num_fused_shared_experts=self.num_fused_shared_experts,
-                correction_bias=self.correction_bias,
-                routed_scaling_factor=self.routed_scaling_factor,
                 num_token_non_padded=state.forward_batch.num_token_non_padded,
                 expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
                     layer_id=self.layer_id,
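Note: the hunks above decouple DeepSeek-V2's top-k routing from the fused MoE layer; the gate now feeds a standalone TopK module, and its output is what the experts consume. The sketch below is a minimal, hypothetical illustration of that call pattern, not sglang code: SimpleTopK and toy_experts are stand-ins with made-up signatures.

```python
# Illustrative sketch (assumed names, dense reference implementation):
# gate -> standalone top-k module -> experts that consume the top-k output.
import torch


class SimpleTopK(torch.nn.Module):
    def __init__(self, top_k: int, renormalize: bool = True):
        super().__init__()
        self.top_k = top_k
        self.renormalize = renormalize

    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
        # hidden_states is unused in this toy; the real module needs it for some modes.
        probs = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
        topk_weights, topk_idx = torch.topk(probs, self.top_k, dim=-1)
        if self.renormalize:
            topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
        return topk_weights, topk_idx


def toy_experts(hidden_states, topk_weights, topk_idx, expert_mlps):
    # Dense reference: run each selected expert per token and mix by weight.
    out = torch.zeros_like(hidden_states)
    for slot in range(topk_idx.shape[-1]):
        for token, expert_id in enumerate(topk_idx[:, slot].tolist()):
            out[token] += topk_weights[token, slot] * expert_mlps[expert_id](
                hidden_states[token]
            )
    return out


num_tokens, hidden, num_experts, top_k = 4, 8, 6, 2
gate = torch.nn.Linear(hidden, num_experts, bias=False)
experts = [torch.nn.Linear(hidden, hidden) for _ in range(num_experts)]
topk = SimpleTopK(top_k=top_k)

x = torch.randn(num_tokens, hidden)
router_logits = gate(x)                      # gate
w, idx = topk(x, router_logits)              # standalone top-k (previously inside the MoE layer)
y = toy_experts(x, w.to(x.dtype), idx, experts)  # experts only see the top-k output
print(y.shape)  # torch.Size([4, 8])
```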
@@ -926,7 +925,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             has_fused_proj
             and hasattr(self.fused_qkv_a_proj_with_mqa.quant_method, "quant_config")
             and self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.get_name()
-            in {"awq", "moe_wna16"}
+            in {"awq", "awq_marlin", "moe_wna16"}
         )
         self.use_min_latency_fused_a_gemm = (
             has_fused_proj
@@ -1151,7 +1150,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
         kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
         latent_cache = latent_cache.unsqueeze(1)
-        kv_a = self.kv_a_layernorm(kv_a
+        kv_a = self.kv_a_layernorm(kv_a)
         kv = self.kv_b_proj(kv_a)[0]
         kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
         k_nope = kv[..., : self.qk_nope_head_dim]
@@ -1690,7 +1689,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
         kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
         latent_cache = latent_cache.unsqueeze(1)
-        kv_a = self.kv_a_layernorm(kv_a
+        kv_a = self.kv_a_layernorm(kv_a)
         kv = self.kv_b_proj(kv_a)[0]
         kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
         k_nope = kv[..., : self.qk_nope_head_dim]
@@ -2172,7 +2171,7 @@ class DeepseekV2ForCausalLM(nn.Module):
             )
             if hasattr(self_attn.kv_b_proj, "qweight"):
                 # AWQ compatible
-                if _is_cuda:
+                if _is_cuda or _is_hip:
                     w = awq_dequantize(
                         self_attn.kv_b_proj.qweight,
                         self_attn.kv_b_proj.scales,
@@ -2434,154 +2433,175 @@ class DeepseekV2ForCausalLM(nn.Module):
             assert self.num_fused_shared_experts == 1
             log_info_on_rank0(logger, "Shared experts fusion optimization enabled.")
 
-
-
-
-
-
-
-
-
+        with concurrent.futures.ThreadPoolExecutor() as executor:
+            futures = []
+            params_dict = dict(self.named_parameters())
+            weight_names = []
+            for name, loaded_weight in weights:
+                if self.num_fused_shared_experts > 0 and "mlp.shared_experts" in name:
+                    name = name.replace(
+                        "mlp.shared_experts",
+                        f"mlp.experts.{self.config.n_routed_experts}",
+                    )
 
-
+                weight_names.append(name)
 
-
-
-
-
-
-
-
-
-
-
-
-
-
+                if not is_nextn:
+                    if hasattr(self.config, "num_nextn_predict_layers"):
+                        num_nextn_layers = self.config.num_nextn_predict_layers
+                        if num_nextn_layers > 0 and name.startswith("model.layers"):
+                            name_list = name.split(".")
+                            if (
+                                len(name_list) >= 3
+                                and int(name_list[2]) >= self.config.num_hidden_layers
+                            ):
+                                continue
+                else:
+                    if not name.startswith(nextn_layer_prefix):
+                        continue
 
-
-
-
+                    # Use shared head and embed weights from target model
+                    if "shared_head.head" in name or "embed_tokens" in name:
+                        continue
 
-
-
-
-
-
-
-
-
-
-
-
-
-                    continue
-            for param_name, weight_name, shard_id in stacked_params_mapping:
-                # Skip non-stacked layers and experts (experts handled below).
-                if weight_name not in name:
-                    continue
-                # We have mlp.experts[0].gate_proj in the checkpoint.
-                # Since we handle the experts below in expert_params_mapping,
-                # we need to skip here BEFORE we update the name, otherwise
-                # name will be updated to mlp.experts[0].gate_up_proj, which
-                # will then be updated below in expert_params_mapping
-                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
-                if ("mlp.experts." in name) and name not in params_dict:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
+                    is_decoder = True
+                    # For nextn specific weights
+                    for weight_name in nextn_spec_weight_names:
+                        if weight_name in name:
+                            name = name.replace(nextn_layer_prefix, "model")
+                            is_decoder = False
+                            break
+                    # For decoder layer weights
+                    if is_decoder:
+                        name = name.replace(nextn_layer_prefix, "model.decoder")
+
+                if "rotary_emb.inv_freq" in name:
                     continue
-
-
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                for mapping in expert_params_mapping:
-                    param_name, weight_name, expert_id, shard_id = mapping
+                for param_name, weight_name, shard_id in stacked_params_mapping:
+                    # Skip non-stacked layers and experts (experts handled below).
                     if weight_name not in name:
                         continue
+                    # We have mlp.experts[0].gate_proj in the checkpoint.
+                    # Since we handle the experts below in expert_params_mapping,
+                    # we need to skip here BEFORE we update the name, otherwise
+                    # name will be updated to mlp.experts[0].gate_up_proj, which
+                    # will then be updated below in expert_params_mapping
+                    # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+                    if ("mlp.experts." in name) and name not in params_dict:
+                        continue
                     name = name.replace(weight_name, param_name)
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
                     param = params_dict[name]
                     weight_loader = param.weight_loader
-
-                        param,
-                        loaded_weight,
-                        name,
-                        shard_id=shard_id,
-                        expert_id=expert_id,
+                    futures.append(
+                        executor.submit(weight_loader, param, loaded_weight, shard_id)
                     )
                     break
                 else:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                    for mapping in expert_params_mapping:
+                        param_name, weight_name, expert_id, shard_id = mapping
+                        if weight_name not in name:
+                            continue
+                        name = name.replace(weight_name, param_name)
+                        param = params_dict[name]
+                        weight_loader = param.weight_loader
+                        futures.append(
+                            executor.submit(
+                                weight_loader,
+                                param,
+                                loaded_weight,
+                                name,
+                                shard_id=shard_id,
+                                expert_id=expert_id,
+                            )
                         )
-
-
-
-
-
+                        break
+                    else:
+                        # Skip loading extra bias for GPTQ models.
+                        if name.endswith(".bias") and name not in params_dict:
+                            continue
+                        if fuse_qkv_a_proj and (
+                            "q_a_proj" in name or "kv_a_proj_with_mqa" in name
                         ):
-
-
-
-                        if self.quant_config is not None and (
-                            self.quant_config.get_name() == "awq"
-                            or self.quant_config.get_name() == "moe_wna16"
-                        ):
-                            cat_dim = 1
-                        fused_weight = torch.cat(
-                            [q_a_proj_weight, kv_a_proj_weight], dim=cat_dim
-                        )
-                        param_name = (
-                            name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
+                            cached_a_proj[name] = loaded_weight
+                            q_a_proj_name = (
+                                name
                                 if "q_a_proj" in name
-                            else name.replace(
-
-
+                                else name.replace("kv_a_proj_with_mqa", "q_a_proj")
+                            )
+                            kv_a_proj_name = (
+                                name
+                                if "kv_a_proj_with_mqa" in name
+                                else name.replace("q_a_proj", "kv_a_proj_with_mqa")
                             )
-                        param = params_dict[param_name]
 
+                            # When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter
+                            if (
+                                q_a_proj_name in cached_a_proj
+                                and kv_a_proj_name in cached_a_proj
+                            ):
+                                q_a_proj_weight = cached_a_proj[q_a_proj_name]
+                                kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
+                                cat_dim = 0
+                                if self.quant_config is not None and (
+                                    self.quant_config.get_name() == "awq"
+                                    or self.quant_config.get_name() == "awq_marlin"
+                                    or self.quant_config.get_name() == "moe_wna16"
+                                ):
+                                    cat_dim = 1
+                                fused_weight = torch.cat(
+                                    [q_a_proj_weight, kv_a_proj_weight], dim=cat_dim
+                                )
+                                param_name = (
+                                    name.replace(
+                                        "q_a_proj", "fused_qkv_a_proj_with_mqa"
+                                    )
+                                    if "q_a_proj" in name
+                                    else name.replace(
+                                        "kv_a_proj_with_mqa",
+                                        "fused_qkv_a_proj_with_mqa",
+                                    )
+                                )
+                                param = params_dict[param_name]
+
+                                weight_loader = getattr(
+                                    param, "weight_loader", default_weight_loader
+                                )
+                                futures.append(
+                                    executor.submit(weight_loader, param, fused_weight)
+                                )
+                                cached_a_proj.pop(q_a_proj_name)
+                                cached_a_proj.pop(kv_a_proj_name)
+                        else:
+                            if (
+                                "k_scale" in name or "v_scale" in name
+                            ) and name not in params_dict:
+                                # modelopt attn kv scale is named differently
+                                for scale in ["k_scale", "v_scale"]:
+                                    if scale in name:
+                                        name = name.replace(
+                                            f"{scale[0]}_proj", "attn_mqa"
+                                        )
+                                        break
+                            if name not in params_dict:
+                                # modelopt ckpt contains not needed weights for MTP module:
+                                # model.decoder.self_attn.attn_mqa.v_scale and
+                                # model.decoder.self_attn.attn_mqa.k_scale
+                                logger.warning(f"{name} not found in params_dict.")
+                                continue
+                            param = params_dict[name]
                             weight_loader = getattr(
                                 param, "weight_loader", default_weight_loader
                             )
-
-
-
-
-
-
-
-                        # modelopt attn kv scale is named differently
-                        for scale in ["k_scale", "v_scale"]:
-                            if scale in name:
-                                name = name.replace(f"{scale[0]}_proj", "attn_mqa")
-                                break
-                        if name not in params_dict:
-                            # modelopt ckpt contains not needed weights for MTP module:
-                            # model.decoder.self_attn.attn_mqa.v_scale and
-                            # model.decoder.self_attn.attn_mqa.k_scale
-                            logger.warning(f"{name} not found in params_dict.")
-                            continue
-                        param = params_dict[name]
-                        weight_loader = getattr(
-                            param, "weight_loader", default_weight_loader
-                        )
-                        weight_loader(param, loaded_weight)
+                            futures.append(
+                                executor.submit(weight_loader, param, loaded_weight)
+                            )
+
+            # Wait for all tasks to complete and raise any exceptions.
+            for future in concurrent.futures.as_completed(futures):
+                future.result()
 
         self.post_load_weights(is_nextn=is_nextn, weight_names=weight_names)
 
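Note: the load_weights hunk above submits every per-parameter weight_loader call to a concurrent.futures.ThreadPoolExecutor and synchronizes once at the end, so checkpoint tensors are copied into parameters concurrently and any loader exception is re-raised. Below is a minimal sketch of that pattern with a stand-in loader and fabricated parameter names; it is not the sglang implementation.

```python
# Sketch of the parallel weight-loading pattern: submit each
# weight_loader(param, tensor) call to a thread pool, then wait on all
# futures so exceptions surface. All names here are illustrative.
import concurrent.futures
import torch


def fake_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
    # Stand-in for the per-parameter weight_loader callables.
    param.copy_(loaded_weight)


params = {f"layers.{i}.weight": torch.empty(64, 64) for i in range(8)}
checkpoint = {name: torch.randn(64, 64) for name in params}

with concurrent.futures.ThreadPoolExecutor() as executor:
    futures = []
    for name, loaded_weight in checkpoint.items():
        param = params[name]
        futures.append(executor.submit(fake_weight_loader, param, loaded_weight))

    # Wait for all tasks to complete and raise any loader exceptions.
    for future in concurrent.futures.as_completed(futures):
        future.result()

assert all(torch.equal(params[n], checkpoint[n]) for n in params)
```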
sglang/srt/models/deepseek_vl2.py
CHANGED
@@ -260,7 +260,7 @@ class DeepseekVL2ForCausalLM(nn.Module):
     def get_image_feature(self, items: List[MultimodalDataItem]):
 
         images_spatial_crop = torch.cat(
-            [item.
+            [item.images_spatial_crop for item in items], dim=0
         )
 
         assert images_spatial_crop.dim() == 3
@@ -268,9 +268,9 @@ class DeepseekVL2ForCausalLM(nn.Module):
         # TODO: can it be batched ?
         images_in_this_batch = []
         for item in items:
-            assert item.
+            assert item.feature.dim() == 4
             image_feature = self.vision.forward_features(
-                item.
+                item.feature.type(next(self.vision.parameters()).dtype).to(
                     device=next(self.vision.parameters()).device
                 )
             )
@@ -278,8 +278,8 @@ class DeepseekVL2ForCausalLM(nn.Module):
             _, hw, n_dim = images_embeds.shape
             h = w = int(hw**0.5)
             tile_index = 0
-            for jdx in range(item.
-                num_width_tiles, num_height_tiles = item.
+            for jdx in range(item.images_spatial_crop.shape[1]):
+                num_width_tiles, num_height_tiles = item.images_spatial_crop[0, jdx]
                 if num_width_tiles == 0 or num_height_tiles == 0:
                     break
                 num_tiles_in_image = num_width_tiles * num_height_tiles
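Note: the DeepSeek-VL2 hunks above read the per-item images_spatial_crop and feature fields directly. The toy below, built on fabricated tensors rather than real multimodal items, illustrates how the (num_items, max_images_per_item, 2) crop tensor is concatenated and walked until the first empty slot.

```python
# Toy illustration of the images_spatial_crop handling: concatenate the
# per-item (1, max_images, 2) crop tensors and stop at the first (0, 0) slot.
import torch

items_spatial_crop = [
    torch.tensor([[[2, 3], [1, 1], [0, 0]]]),  # an item carrying two images
    torch.tensor([[[1, 2], [0, 0], [0, 0]]]),  # an item carrying one image
]
images_spatial_crop = torch.cat(items_spatial_crop, dim=0)
assert images_spatial_crop.dim() == 3  # (num_items, max_images_per_item, 2)

for item_idx in range(images_spatial_crop.shape[0]):
    for jdx in range(images_spatial_crop.shape[1]):
        num_width_tiles, num_height_tiles = images_spatial_crop[item_idx, jdx]
        if num_width_tiles == 0 or num_height_tiles == 0:
            break  # remaining slots are padding
        print(item_idx, jdx, int(num_width_tiles * num_height_tiles), "tiles")
```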
sglang/srt/models/gemma.py
CHANGED
@@ -318,6 +318,54 @@ class GemmaForCausalLM(nn.Module):
             input_ids, hidden_states, self.model.embed_tokens, forward_batch
         )
 
+    @torch.no_grad()
+    def forward_split_prefill(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        split_interval: Tuple[int, int],  # [start, end) 0-based
+        input_embeds: torch.Tensor = None,
+    ):
+        start, end = split_interval
+        # embed
+        if start == 0:
+            if input_embeds is None:
+                forward_batch.hidden_states = self.model.embed_tokens(input_ids)
+            else:
+                forward_batch.hidden_states = input_embeds
+
+            # Normalize the embedding by sqrt(hidden_size)
+            forward_batch.hidden_states *= self.model.config.hidden_size**0.5
+
+        # decoder layer
+        for i in range(start, end):
+            layer = self.model.layers[i]
+            forward_batch.hidden_states, forward_batch.residual = layer(
+                positions,
+                forward_batch.hidden_states,
+                forward_batch,
+                forward_batch.residual,
+            )
+
+        if end == self.model.config.num_hidden_layers:
+            # norm
+            forward_batch.hidden_states, _ = self.model.norm(
+                forward_batch.hidden_states, forward_batch.residual
+            )
+
+            # logits process
+            result = self.logits_processor(
+                input_ids,
+                forward_batch.hidden_states,
+                self.model.embed_tokens,
+                forward_batch,
+            )
+        else:
+            result = None
+
+        return result
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)