sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +2 -1
- sglang/eval/loogle_eval.py +7 -0
- sglang/srt/_custom_ops.py +29 -1
- sglang/srt/configs/deepseekvl2.py +11 -2
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +10 -8
- sglang/srt/configs/update_config.py +3 -1
- sglang/srt/conversation.py +2 -1
- sglang/srt/custom_op.py +5 -2
- sglang/srt/disaggregation/common/conn.py +34 -6
- sglang/srt/disaggregation/decode.py +9 -1
- sglang/srt/disaggregation/mini_lb.py +3 -2
- sglang/srt/disaggregation/mooncake/conn.py +93 -76
- sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
- sglang/srt/disaggregation/nixl/conn.py +17 -13
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
- sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
- sglang/srt/distributed/parallel_state.py +103 -15
- sglang/srt/entrypoints/engine.py +31 -33
- sglang/srt/entrypoints/http_server.py +20 -32
- sglang/srt/entrypoints/openai/protocol.py +3 -3
- sglang/srt/entrypoints/openai/serving_chat.py +48 -6
- sglang/srt/eplb/expert_location_dispatch.py +1 -1
- sglang/srt/function_call/base_format_detector.py +74 -12
- sglang/srt/function_call/deepseekv3_detector.py +26 -11
- sglang/srt/function_call/ebnf_composer.py +95 -63
- sglang/srt/function_call/function_call_parser.py +4 -2
- sglang/srt/function_call/kimik2_detector.py +41 -16
- sglang/srt/function_call/llama32_detector.py +6 -3
- sglang/srt/function_call/mistral_detector.py +11 -3
- sglang/srt/function_call/pythonic_detector.py +16 -14
- sglang/srt/function_call/qwen25_detector.py +12 -3
- sglang/srt/function_call/qwen3_coder_detector.py +151 -0
- sglang/srt/hf_transformers_utils.py +0 -1
- sglang/srt/layers/activation.py +24 -3
- sglang/srt/layers/attention/base_attn_backend.py +3 -1
- sglang/srt/layers/attention/flashattention_backend.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +40 -1
- sglang/srt/layers/communicator.py +12 -12
- sglang/srt/layers/dp_attention.py +72 -24
- sglang/srt/layers/linear.py +13 -102
- sglang/srt/layers/logits_processor.py +34 -24
- sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
- sglang/srt/layers/moe/ep_moe/layer.py +23 -402
- sglang/srt/layers/moe/fused_moe_native.py +7 -47
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
- sglang/srt/layers/moe/topk.py +190 -23
- sglang/srt/layers/quantization/__init__.py +20 -134
- sglang/srt/layers/quantization/awq.py +578 -11
- sglang/srt/layers/quantization/awq_triton.py +339 -0
- sglang/srt/layers/quantization/base_config.py +85 -10
- sglang/srt/layers/quantization/blockwise_int8.py +17 -55
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
- sglang/srt/layers/quantization/fp8.py +273 -62
- sglang/srt/layers/quantization/fp8_kernel.py +210 -46
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +501 -143
- sglang/srt/layers/quantization/marlin_utils.py +790 -0
- sglang/srt/layers/quantization/modelopt_quant.py +34 -112
- sglang/srt/layers/quantization/moe_wna16.py +45 -49
- sglang/srt/layers/quantization/petit.py +252 -0
- sglang/srt/layers/quantization/petit_utils.py +104 -0
- sglang/srt/layers/quantization/qoq.py +7 -6
- sglang/srt/layers/quantization/scalar_type.py +352 -0
- sglang/srt/layers/quantization/unquant.py +422 -0
- sglang/srt/layers/quantization/utils.py +340 -9
- sglang/srt/layers/quantization/w4afp8.py +8 -4
- sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
- sglang/srt/layers/quantization/w8a8_int8.py +51 -115
- sglang/srt/layers/radix_attention.py +5 -3
- sglang/srt/layers/vocab_parallel_embedding.py +1 -41
- sglang/srt/lora/lora.py +0 -4
- sglang/srt/lora/lora_manager.py +162 -164
- sglang/srt/lora/lora_registry.py +124 -0
- sglang/srt/lora/mem_pool.py +83 -35
- sglang/srt/lora/utils.py +12 -5
- sglang/srt/managers/cache_controller.py +288 -0
- sglang/srt/managers/io_struct.py +60 -30
- sglang/srt/managers/mm_utils.py +7 -8
- sglang/srt/managers/schedule_batch.py +163 -113
- sglang/srt/managers/schedule_policy.py +68 -27
- sglang/srt/managers/scheduler.py +256 -86
- sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
- sglang/srt/managers/tokenizer_manager.py +38 -27
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/allocator.py +74 -23
- sglang/srt/mem_cache/base_prefix_cache.py +14 -2
- sglang/srt/mem_cache/chunk_cache.py +5 -2
- sglang/srt/mem_cache/hicache_storage.py +168 -0
- sglang/srt/mem_cache/hiradix_cache.py +194 -5
- sglang/srt/mem_cache/memory_pool.py +16 -1
- sglang/srt/mem_cache/memory_pool_host.py +44 -2
- sglang/srt/mem_cache/radix_cache.py +26 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +66 -31
- sglang/srt/model_executor/forward_batch_info.py +210 -25
- sglang/srt/model_executor/model_runner.py +147 -42
- sglang/srt/model_loader/loader.py +7 -1
- sglang/srt/model_loader/utils.py +4 -4
- sglang/srt/models/clip.py +1 -1
- sglang/srt/models/deepseek.py +9 -6
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_v2.py +192 -173
- sglang/srt/models/deepseek_vl2.py +5 -5
- sglang/srt/models/gemma.py +48 -0
- sglang/srt/models/gemma2.py +52 -0
- sglang/srt/models/gemma3_causal.py +63 -0
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -4
- sglang/srt/models/granitemoe.py +385 -0
- sglang/srt/models/grok.py +9 -3
- sglang/srt/models/hunyuan.py +63 -16
- sglang/srt/models/internvl.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -1
- sglang/srt/models/llama.py +41 -0
- sglang/srt/models/llama4.py +11 -11
- sglang/srt/models/llava.py +2 -2
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +0 -2
- sglang/srt/models/minicpmo.py +3 -7
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mixtral.py +9 -2
- sglang/srt/models/mllama.py +3 -5
- sglang/srt/models/mllama4.py +13 -6
- sglang/srt/models/olmoe.py +8 -5
- sglang/srt/models/persimmon.py +330 -0
- sglang/srt/models/phi.py +321 -0
- sglang/srt/models/phi4mm.py +44 -4
- sglang/srt/models/phi4mm_audio.py +1260 -0
- sglang/srt/models/phi4mm_utils.py +1917 -0
- sglang/srt/models/phimoe.py +9 -3
- sglang/srt/models/qwen.py +37 -0
- sglang/srt/models/qwen2.py +41 -0
- sglang/srt/models/qwen2_5_vl.py +4 -4
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +53 -9
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/qwen3.py +65 -1
- sglang/srt/models/qwen3_moe.py +57 -24
- sglang/srt/models/vila.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +91 -97
- sglang/srt/multimodal/processors/clip.py +21 -19
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
- sglang/srt/multimodal/processors/gemma3.py +13 -17
- sglang/srt/multimodal/processors/gemma3n.py +19 -23
- sglang/srt/multimodal/processors/internvl.py +9 -10
- sglang/srt/multimodal/processors/janus_pro.py +12 -27
- sglang/srt/multimodal/processors/kimi_vl.py +12 -14
- sglang/srt/multimodal/processors/llava.py +4 -2
- sglang/srt/multimodal/processors/minicpm.py +35 -44
- sglang/srt/multimodal/processors/mlama.py +21 -18
- sglang/srt/multimodal/processors/mllama4.py +4 -5
- sglang/srt/multimodal/processors/phi4mm.py +63 -39
- sglang/srt/multimodal/processors/pixtral.py +14 -35
- sglang/srt/multimodal/processors/qwen_audio.py +65 -0
- sglang/srt/multimodal/processors/qwen_vl.py +16 -21
- sglang/srt/multimodal/processors/vila.py +14 -14
- sglang/srt/reasoning_parser.py +46 -4
- sglang/srt/sampling/sampling_batch_info.py +6 -5
- sglang/srt/sampling/sampling_params.py +8 -1
- sglang/srt/server_args.py +454 -270
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
- sglang/srt/speculative/eagle_utils.py +51 -23
- sglang/srt/speculative/eagle_worker.py +59 -44
- sglang/srt/two_batch_overlap.py +10 -5
- sglang/srt/utils.py +44 -69
- sglang/test/runners.py +14 -3
- sglang/test/test_activation.py +50 -1
- sglang/test/test_block_fp8.py +8 -3
- sglang/test/test_block_fp8_ep.py +1 -1
- sglang/test/test_custom_ops.py +12 -7
- sglang/test/test_cutlass_w4a8_moe.py +1 -3
- sglang/test/test_fp4_moe.py +1 -3
- sglang/test/test_marlin_moe.py +286 -0
- sglang/test/test_marlin_utils.py +171 -0
- sglang/test/test_utils.py +35 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
- sglang/srt/layers/quantization/quant_utils.py +0 -166
- sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -16,6 +16,7 @@
 # https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py
 """Inference-only DeepseekV2 model."""

+import concurrent.futures
 import logging
 import os
 from enum import IntEnum, auto
@@ -57,7 +58,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
-from sglang.srt.layers.moe.topk import
+from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_kernel import (
@@ -126,6 +127,10 @@ if _is_cuda:
     )
 elif _is_cpu and _is_cpu_amx_available:
     pass
+elif _is_hip:
+    from sglang.srt.layers.quantization.awq_triton import (
+        awq_dequantize_triton as awq_dequantize,
+    )
 else:
     from vllm._custom_ops import awq_dequantize

@@ -224,7 +229,7 @@ class MoEGate(nn.Module):
         )
         if config.topk_method == "noaux_tc":
             self.e_score_correction_bias = nn.Parameter(
-                torch.empty((config.n_routed_experts))
+                torch.empty((config.n_routed_experts), dtype=torch.float32)
             )
         else:
             self.e_score_correction_bias = None
@@ -249,9 +254,8 @@ class MoEGate(nn.Module):
             and self.weight.shape[0] == 256
             and _device_sm >= 90
         ):
-
-
-            )
+            # router gemm output float32
+            logits = dsv3_router_gemm(hidden_states, self.weight)
         else:
             logits = F.linear(hidden_states, self.weight, None)

@@ -298,6 +302,17 @@ class DeepseekV2MoE(nn.Module):
             config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
         )

+        self.topk = TopK(
+            top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
+            renormalize=config.norm_topk_prob,
+            use_grouped_topk=True,
+            num_expert_group=config.n_group,
+            num_fused_shared_experts=self.num_fused_shared_experts,
+            topk_group=config.topk_group,
+            correction_bias=self.gate.e_score_correction_bias,
+            routed_scaling_factor=self.routed_scaling_factor,
+        )
+
         self.experts = get_moe_impl_class()(
             num_experts=config.n_routed_experts
             + self.num_fused_shared_experts
@@ -306,13 +321,7 @@ class DeepseekV2MoE(nn.Module):
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
             layer_id=self.layer_id,
-            renormalize=config.norm_topk_prob,
             quant_config=quant_config,
-            use_grouped_topk=True,
-            num_expert_group=config.n_group,
-            num_fused_shared_experts=self.num_fused_shared_experts,
-            topk_group=config.topk_group,
-            correction_bias=self.gate.e_score_correction_bias,
             routed_scaling_factor=self.routed_scaling_factor,
             prefix=add_prefix("experts", prefix),
             **(
@@ -354,6 +363,7 @@ class DeepseekV2MoE(nn.Module):
             self.shared_experts.gate_up_proj.quant_method, "quant_config"
         ) and self.shared_experts.gate_up_proj.quant_method.quant_config.get_name() in {
             "awq",
+            "awq_marlin",
             "moe_wna16",
         }
         self.shared_experts_is_int8 = (
@@ -437,21 +447,22 @@ class DeepseekV2MoE(nn.Module):
     def forward_normal_dual_stream(
         self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
     ) -> torch.Tensor:
-        # router_logits: (num_tokens, n_experts)
-        router_logits = self.gate(hidden_states)

         current_stream = torch.cuda.current_stream()
         self.alt_stream.wait_stream(current_stream)
         shared_output = self._forward_shared_experts(hidden_states)

         with torch.cuda.stream(self.alt_stream):
+            # router_logits: (num_tokens, n_experts)
+            router_logits = self.gate(hidden_states)
+            topk_output = self.topk(hidden_states, router_logits)
             final_hidden_states = self.experts(
-                hidden_states=hidden_states,
+                hidden_states=hidden_states, topk_output=topk_output
             )
             if not _is_cuda:
                 final_hidden_states *= self.routed_scaling_factor
         current_stream.wait_stream(self.alt_stream)
-        final_hidden_states
+        final_hidden_states += shared_output
         if self.tp_size > 1 and not can_fuse_mlp_allreduce:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states
@@ -462,13 +473,14 @@ class DeepseekV2MoE(nn.Module):
         if hasattr(self, "shared_experts") and use_intel_amx_backend(
             self.shared_experts.gate_up_proj
         ):
-            return self.forward_cpu(hidden_states)
+            return self.forward_cpu(hidden_states, can_fuse_mlp_allreduce)

         shared_output = self._forward_shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
+        topk_output = self.topk(hidden_states, router_logits)
         final_hidden_states = self.experts(
-            hidden_states=hidden_states,
+            hidden_states=hidden_states, topk_output=topk_output
         )
         if not _is_cuda and not _use_aiter:
             # fused in biased_grouped_topk so we can skip here
@@ -479,11 +491,14 @@ class DeepseekV2MoE(nn.Module):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states

-    def forward_cpu(
+    def forward_cpu(
+        self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
+    ) -> torch.Tensor:
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
+        topk_output = self.topk(hidden_states, router_logits)
         fused_experts_out = self.experts(
-            hidden_states=hidden_states,
+            hidden_states=hidden_states, topk_output=topk_output
         )

         assert use_intel_amx_backend(
@@ -528,30 +543,21 @@ class DeepseekV2MoE(nn.Module):
             None,  # a2_scale
             True,  # is_vnni
         )
-        if self.tp_size > 1 and not
+        if self.tp_size > 1 and not can_fuse_mlp_allreduce:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states

     def forward_deepep(
         self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
     ) -> torch.Tensor:
-        forward_mode = forward_batch.forward_mode
         shared_output = None
-        if
+        if hidden_states.shape[0] > 0:
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states)
             shared_output = self._forward_shared_experts(hidden_states)
-            topk_weights, topk_idx =
-                hidden_states
-                router_logits
-                top_k=self.top_k,
-                use_grouped_topk=True,
-                renormalize=self.renormalize,
-                topk_group=self.topk_group,
-                num_expert_group=self.num_expert_group,
-                num_fused_shared_experts=self.num_fused_shared_experts,
-                correction_bias=self.correction_bias,
-                routed_scaling_factor=self.routed_scaling_factor,
+            topk_weights, topk_idx, _ = self.topk(
+                hidden_states,
+                router_logits,
                 num_token_non_padded=forward_batch.num_token_non_padded,
                 expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
                     layer_id=self.layer_id,
@@ -641,17 +647,9 @@ class DeepseekV2MoE(nn.Module):
         with get_global_expert_distribution_recorder().with_current_layer(
             self.layer_id
         ):
-            state.topk_weights_local, state.topk_idx_local =
+            state.topk_weights_local, state.topk_idx_local, _ = self.topk(
                 hidden_states=hidden_states,
                 router_logits=router_logits,
-                top_k=self.top_k,
-                use_grouped_topk=True,
-                renormalize=self.renormalize,
-                topk_group=self.topk_group,
-                num_expert_group=self.num_expert_group,
-                num_fused_shared_experts=self.num_fused_shared_experts,
-                correction_bias=self.correction_bias,
-                routed_scaling_factor=self.routed_scaling_factor,
                 num_token_non_padded=state.forward_batch.num_token_non_padded,
                 expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
                     layer_id=self.layer_id,
@@ -926,7 +924,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             has_fused_proj
             and hasattr(self.fused_qkv_a_proj_with_mqa.quant_method, "quant_config")
             and self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.get_name()
-            in {"awq", "moe_wna16"}
+            in {"awq", "awq_marlin", "moe_wna16"}
         )
         self.use_min_latency_fused_a_gemm = (
             has_fused_proj
@@ -1151,7 +1149,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
         kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
         latent_cache = latent_cache.unsqueeze(1)
-        kv_a = self.kv_a_layernorm(kv_a
+        kv_a = self.kv_a_layernorm(kv_a)
         kv = self.kv_b_proj(kv_a)[0]
         kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
         k_nope = kv[..., : self.qk_nope_head_dim]
@@ -1690,7 +1688,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
         kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
         latent_cache = latent_cache.unsqueeze(1)
-        kv_a = self.kv_a_layernorm(kv_a
+        kv_a = self.kv_a_layernorm(kv_a)
         kv = self.kv_b_proj(kv_a)[0]
         kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
         k_nope = kv[..., : self.qk_nope_head_dim]
@@ -2172,7 +2170,7 @@ class DeepseekV2ForCausalLM(nn.Module):
             )
             if hasattr(self_attn.kv_b_proj, "qweight"):
                 # AWQ compatible
-                if _is_cuda:
+                if _is_cuda or _is_hip:
                     w = awq_dequantize(
                         self_attn.kv_b_proj.qweight,
                         self_attn.kv_b_proj.scales,
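The hunks above move top-k routing out of the fused-MoE layer and into a standalone TopK module: the gate produces router logits, self.topk(hidden_states, router_logits) turns them into a top-k selection, and the experts layer now receives that selection as topk_output. Below is a minimal, self-contained sketch of that routing/dispatch split; the toy SimpleTopK and SimpleExperts classes, their shapes, and the return format are illustrative assumptions, not sglang's actual implementations.

# Toy sketch of the routing/dispatch split (illustrative only, not sglang's TopK API).
import torch
import torch.nn.functional as F


class SimpleTopK(torch.nn.Module):
    def __init__(self, top_k: int, renormalize: bool = True):
        super().__init__()
        self.top_k = top_k
        self.renormalize = renormalize

    def forward(self, hidden_states, router_logits):
        # Pick the top-k experts per token from the router logits.
        probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)
        topk_weights, topk_idx = torch.topk(probs, self.top_k, dim=-1)
        if self.renormalize:
            topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
        return topk_weights, topk_idx


class SimpleExperts(torch.nn.Module):
    def __init__(self, num_experts: int, hidden_size: int):
        super().__init__()
        self.w = torch.nn.Parameter(
            torch.randn(num_experts, hidden_size, hidden_size) * 0.02
        )

    def forward(self, hidden_states, topk_output):
        # Dispatch each token to its selected experts and combine with routing weights.
        topk_weights, topk_idx = topk_output
        out = torch.zeros_like(hidden_states)
        for k in range(topk_idx.shape[-1]):
            expert_w = self.w[topk_idx[:, k]]  # (num_tokens, H, H)
            expert_out = torch.bmm(hidden_states.unsqueeze(1), expert_w).squeeze(1)
            out += topk_weights[:, k].unsqueeze(-1).to(out.dtype) * expert_out
        return out


# Usage mirroring the new call pattern: gate -> topk -> experts(..., topk_output=...)
hidden = torch.randn(4, 16)
gate = torch.nn.Linear(16, 8, bias=False)
topk = SimpleTopK(top_k=2)
experts = SimpleExperts(num_experts=8, hidden_size=16)
router_logits = gate(hidden)
topk_output = topk(hidden, router_logits)
out = experts(hidden_states=hidden, topk_output=topk_output)

Keeping the routing decision separate from expert dispatch is what lets the experts layer take a precomputed topk_output, as the constructor and forward changes in this file show.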
@@ -2434,154 +2432,175 @@ class DeepseekV2ForCausalLM(nn.Module):
             assert self.num_fused_shared_experts == 1
             log_info_on_rank0(logger, "Shared experts fusion optimization enabled.")

-
-
-
-
-
-
-
-
+        with concurrent.futures.ThreadPoolExecutor() as executor:
+            futures = []
+            params_dict = dict(self.named_parameters())
+            weight_names = []
+            for name, loaded_weight in weights:
+                if self.num_fused_shared_experts > 0 and "mlp.shared_experts" in name:
+                    name = name.replace(
+                        "mlp.shared_experts",
+                        f"mlp.experts.{self.config.n_routed_experts}",
+                    )

-
+                weight_names.append(name)

-
-
-
-
-
-
-
-
-
-
-
-
-
+                if not is_nextn:
+                    if hasattr(self.config, "num_nextn_predict_layers"):
+                        num_nextn_layers = self.config.num_nextn_predict_layers
+                        if num_nextn_layers > 0 and name.startswith("model.layers"):
+                            name_list = name.split(".")
+                            if (
+                                len(name_list) >= 3
+                                and int(name_list[2]) >= self.config.num_hidden_layers
+                            ):
+                                continue
+                else:
+                    if not name.startswith(nextn_layer_prefix):
+                        continue

-
-
-
+                    # Use shared head and embed weights from target model
+                    if "shared_head.head" in name or "embed_tokens" in name:
+                        continue

-
-
-
-
-
-
-
-
-
-
-
-
-                continue
-            for param_name, weight_name, shard_id in stacked_params_mapping:
-                # Skip non-stacked layers and experts (experts handled below).
-                if weight_name not in name:
-                    continue
-                # We have mlp.experts[0].gate_proj in the checkpoint.
-                # Since we handle the experts below in expert_params_mapping,
-                # we need to skip here BEFORE we update the name, otherwise
-                # name will be updated to mlp.experts[0].gate_up_proj, which
-                # will then be updated below in expert_params_mapping
-                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
-                if ("mlp.experts." in name) and name not in params_dict:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
+                    is_decoder = True
+                    # For nextn specific weights
+                    for weight_name in nextn_spec_weight_names:
+                        if weight_name in name:
+                            name = name.replace(nextn_layer_prefix, "model")
+                            is_decoder = False
+                            break
+                    # For decoder layer weights
+                    if is_decoder:
+                        name = name.replace(nextn_layer_prefix, "model.decoder")
+
+                if "rotary_emb.inv_freq" in name:
                     continue
-
-
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                for mapping in expert_params_mapping:
-                    param_name, weight_name, expert_id, shard_id = mapping
+                for param_name, weight_name, shard_id in stacked_params_mapping:
+                    # Skip non-stacked layers and experts (experts handled below).
                     if weight_name not in name:
                         continue
+                    # We have mlp.experts[0].gate_proj in the checkpoint.
+                    # Since we handle the experts below in expert_params_mapping,
+                    # we need to skip here BEFORE we update the name, otherwise
+                    # name will be updated to mlp.experts[0].gate_up_proj, which
+                    # will then be updated below in expert_params_mapping
+                    # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+                    if ("mlp.experts." in name) and name not in params_dict:
+                        continue
                     name = name.replace(weight_name, param_name)
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
                     param = params_dict[name]
                     weight_loader = param.weight_loader
-
-                        param,
-                        loaded_weight,
-                        name,
-                        shard_id=shard_id,
-                        expert_id=expert_id,
+                    futures.append(
+                        executor.submit(weight_loader, param, loaded_weight, shard_id)
                     )
                     break
                 else:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                    for mapping in expert_params_mapping:
+                        param_name, weight_name, expert_id, shard_id = mapping
+                        if weight_name not in name:
+                            continue
+                        name = name.replace(weight_name, param_name)
+                        param = params_dict[name]
+                        weight_loader = param.weight_loader
+                        futures.append(
+                            executor.submit(
+                                weight_loader,
+                                param,
+                                loaded_weight,
+                                name,
+                                shard_id=shard_id,
+                                expert_id=expert_id,
+                            )
                         )
-
-
-
-
-
+                        break
+                    else:
+                        # Skip loading extra bias for GPTQ models.
+                        if name.endswith(".bias") and name not in params_dict:
+                            continue
+                        if fuse_qkv_a_proj and (
+                            "q_a_proj" in name or "kv_a_proj_with_mqa" in name
                         ):
-
-
-
-                        if self.quant_config is not None and (
-                            self.quant_config.get_name() == "awq"
-                            or self.quant_config.get_name() == "moe_wna16"
-                        ):
-                            cat_dim = 1
-                        fused_weight = torch.cat(
-                            [q_a_proj_weight, kv_a_proj_weight], dim=cat_dim
-                        )
-                        param_name = (
-                            name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
+                            cached_a_proj[name] = loaded_weight
+                            q_a_proj_name = (
+                                name
                                 if "q_a_proj" in name
-                            else name.replace(
-
-
+                                else name.replace("kv_a_proj_with_mqa", "q_a_proj")
+                            )
+                            kv_a_proj_name = (
+                                name
+                                if "kv_a_proj_with_mqa" in name
+                                else name.replace("q_a_proj", "kv_a_proj_with_mqa")
                             )
-                        param = params_dict[param_name]

+                            # When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter
+                            if (
+                                q_a_proj_name in cached_a_proj
+                                and kv_a_proj_name in cached_a_proj
+                            ):
+                                q_a_proj_weight = cached_a_proj[q_a_proj_name]
+                                kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
+                                cat_dim = 0
+                                if self.quant_config is not None and (
+                                    self.quant_config.get_name() == "awq"
+                                    or self.quant_config.get_name() == "awq_marlin"
+                                    or self.quant_config.get_name() == "moe_wna16"
+                                ):
+                                    cat_dim = 1
+                                fused_weight = torch.cat(
+                                    [q_a_proj_weight, kv_a_proj_weight], dim=cat_dim
+                                )
+                                param_name = (
+                                    name.replace(
+                                        "q_a_proj", "fused_qkv_a_proj_with_mqa"
+                                    )
+                                    if "q_a_proj" in name
+                                    else name.replace(
+                                        "kv_a_proj_with_mqa",
+                                        "fused_qkv_a_proj_with_mqa",
+                                    )
+                                )
+                                param = params_dict[param_name]
+
+                                weight_loader = getattr(
+                                    param, "weight_loader", default_weight_loader
+                                )
+                                futures.append(
+                                    executor.submit(weight_loader, param, fused_weight)
+                                )
+                                cached_a_proj.pop(q_a_proj_name)
+                                cached_a_proj.pop(kv_a_proj_name)
+                        else:
+                            if (
+                                "k_scale" in name or "v_scale" in name
+                            ) and name not in params_dict:
+                                # modelopt attn kv scale is named differently
+                                for scale in ["k_scale", "v_scale"]:
+                                    if scale in name:
+                                        name = name.replace(
+                                            f"{scale[0]}_proj", "attn_mqa"
+                                        )
+                                        break
+                            if name not in params_dict:
+                                # modelopt ckpt contains not needed weights for MTP module:
+                                # model.decoder.self_attn.attn_mqa.v_scale and
+                                # model.decoder.self_attn.attn_mqa.k_scale
+                                logger.warning(f"{name} not found in params_dict.")
+                                continue
+                            param = params_dict[name]
                             weight_loader = getattr(
                                 param, "weight_loader", default_weight_loader
                             )
-
-
-
-
-
-
-
-                        # modelopt attn kv scale is named differently
-                        for scale in ["k_scale", "v_scale"]:
-                            if scale in name:
-                                name = name.replace(f"{scale[0]}_proj", "attn_mqa")
-                                break
-                        if name not in params_dict:
-                            # modelopt ckpt contains not needed weights for MTP module:
-                            # model.decoder.self_attn.attn_mqa.v_scale and
-                            # model.decoder.self_attn.attn_mqa.k_scale
-                            logger.warning(f"{name} not found in params_dict.")
-                            continue
-                        param = params_dict[name]
-                        weight_loader = getattr(
-                            param, "weight_loader", default_weight_loader
-                        )
-                        weight_loader(param, loaded_weight)
+                            futures.append(
+                                executor.submit(weight_loader, param, loaded_weight)
+                            )
+
+            # Wait for all tasks to complete and raise any exceptions.
+            for future in concurrent.futures.as_completed(futures):
+                future.result()

         self.post_load_weights(is_nextn=is_nextn, weight_names=weight_names)

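The load_weights hunk above switches from calling each weight_loader inline to submitting the calls to a concurrent.futures.ThreadPoolExecutor and then draining the futures so any loader exception still surfaces. A minimal sketch of that pattern follows; the copy_weight helper and the toy parameter/checkpoint dictionaries are made-up stand-ins, not sglang's weight_loader signatures.

# Generic parallel-loading sketch (toy loader; not the sglang weight_loader API).
import concurrent.futures
import torch


def copy_weight(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
    # Stand-in for a per-parameter weight_loader: copy data into the param in place.
    param.copy_(loaded_weight)


params = {f"layer.{i}.weight": torch.empty(64, 64) for i in range(8)}
checkpoint = ((name, torch.randn(64, 64)) for name in params)

with concurrent.futures.ThreadPoolExecutor() as executor:
    futures = []
    for name, loaded_weight in checkpoint:
        # Submit each load instead of running it inline so the copies overlap.
        futures.append(executor.submit(copy_weight, params[name], loaded_weight))

    # Wait for all tasks to complete and re-raise any exception from a loader.
    for future in concurrent.futures.as_completed(futures):
        future.result()

Draining the futures with as_completed and calling result() is what preserves the old error behavior: a failing loader still raises instead of being silently dropped.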
sglang/srt/models/deepseek_vl2.py
CHANGED
@@ -260,7 +260,7 @@ class DeepseekVL2ForCausalLM(nn.Module):
     def get_image_feature(self, items: List[MultimodalDataItem]):

         images_spatial_crop = torch.cat(
-            [item.
+            [item.images_spatial_crop for item in items], dim=0
         )

         assert images_spatial_crop.dim() == 3
@@ -268,9 +268,9 @@ class DeepseekVL2ForCausalLM(nn.Module):
         # TODO: can it be batched ?
         images_in_this_batch = []
         for item in items:
-            assert item.
+            assert item.feature.dim() == 4
             image_feature = self.vision.forward_features(
-                item.
+                item.feature.type(next(self.vision.parameters()).dtype).to(
                     device=next(self.vision.parameters()).device
                 )
             )
@@ -278,8 +278,8 @@ class DeepseekVL2ForCausalLM(nn.Module):
             _, hw, n_dim = images_embeds.shape
             h = w = int(hw**0.5)
             tile_index = 0
-            for jdx in range(item.
-                num_width_tiles, num_height_tiles = item.
+            for jdx in range(item.images_spatial_crop.shape[1]):
+                num_width_tiles, num_height_tiles = item.images_spatial_crop[0, jdx]
                 if num_width_tiles == 0 or num_height_tiles == 0:
                     break
                 num_tiles_in_image = num_width_tiles * num_height_tiles
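The hunks above read the crop grid and image features from the per-item attributes item.images_spatial_crop and item.feature. A small toy walk over such a crop grid, mirroring the loop in the diff, is sketched below; the tensor values and shape (one item, three candidate crop slots) are illustrative assumptions.

# Toy walk over per-image tile grids (shapes and values are illustrative).
import torch

# One item, three crop slots: (num_width_tiles, num_height_tiles) per slot,
# zero-padded once the real crops run out.
images_spatial_crop = torch.tensor([[[2, 2], [1, 3], [0, 0]]])
assert images_spatial_crop.dim() == 3

tile_index = 0
for jdx in range(images_spatial_crop.shape[1]):
    num_width_tiles, num_height_tiles = images_spatial_crop[0, jdx]
    if num_width_tiles == 0 or num_height_tiles == 0:
        break  # a zero entry marks the end of the valid crops
    num_tiles_in_image = num_width_tiles * num_height_tiles
    tile_index += int(num_tiles_in_image)

print(tile_index)  # 4 + 3 = 7 tiles consumed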
sglang/srt/models/gemma.py
CHANGED
@@ -318,6 +318,54 @@ class GemmaForCausalLM(nn.Module):
             input_ids, hidden_states, self.model.embed_tokens, forward_batch
         )

+    @torch.no_grad()
+    def forward_split_prefill(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        split_interval: Tuple[int, int],  # [start, end) 0-based
+        input_embeds: torch.Tensor = None,
+    ):
+        start, end = split_interval
+        # embed
+        if start == 0:
+            if input_embeds is None:
+                forward_batch.hidden_states = self.model.embed_tokens(input_ids)
+            else:
+                forward_batch.hidden_states = input_embeds
+
+            # Normalize the embedding by sqrt(hidden_size)
+            forward_batch.hidden_states *= self.model.config.hidden_size**0.5
+
+        # decoder layer
+        for i in range(start, end):
+            layer = self.model.layers[i]
+            forward_batch.hidden_states, forward_batch.residual = layer(
+                positions,
+                forward_batch.hidden_states,
+                forward_batch,
+                forward_batch.residual,
+            )
+
+        if end == self.model.config.num_hidden_layers:
+            # norm
+            forward_batch.hidden_states, _ = self.model.norm(
+                forward_batch.hidden_states, forward_batch.residual
+            )
+
+            # logits process
+            result = self.logits_processor(
+                input_ids,
+                forward_batch.hidden_states,
+                self.model.embed_tokens,
+                forward_batch,
+            )
+        else:
+            result = None
+
+        return result
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
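The added forward_split_prefill runs the prefill pass over a half-open interval of decoder layers, carries the hidden states between calls, and only produces logits once the interval reaches the last layer, so a caller can walk the model in chunks. A hedged sketch of that driving loop is below; the SplitPrefillToy module, the chunk size, and the carried hidden_states field are illustrative assumptions, not GemmaForCausalLM or sglang's ForwardBatch.

# Toy layer-interval prefill sketch (illustrative module only).
import torch


class SplitPrefillToy(torch.nn.Module):
    def __init__(self, num_layers: int = 6, hidden: int = 32, vocab: int = 100):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab, hidden)
        self.layers = torch.nn.ModuleList(
            torch.nn.Linear(hidden, hidden) for _ in range(num_layers)
        )
        self.norm = torch.nn.LayerNorm(hidden)
        self.head = torch.nn.Linear(hidden, vocab)
        self.hidden_states = None  # carried across chunks, like forward_batch.hidden_states

    @torch.no_grad()
    def forward_split_prefill(self, input_ids, split_interval):
        start, end = split_interval  # [start, end), 0-based
        if start == 0:
            self.hidden_states = self.embed(input_ids)
        for i in range(start, end):
            self.hidden_states = torch.relu(self.layers[i](self.hidden_states))
        if end == len(self.layers):
            # Logits are only computed on the final chunk.
            return self.head(self.norm(self.hidden_states))
        return None


model = SplitPrefillToy()
tokens = torch.randint(0, 100, (4, 16))
chunk = 2
result = None
for s in range(0, len(model.layers), chunk):
    result = model.forward_split_prefill(tokens, (s, min(s + chunk, len(model.layers))))
# Earlier chunks return None; the last chunk returns the logits.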