sglang 0.4.8.post1__py3-none-any.whl → 0.4.9.post1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registry.
- sglang/bench_one_batch_server.py +17 -2
- sglang/bench_serving.py +170 -24
- sglang/srt/configs/internvl.py +4 -2
- sglang/srt/configs/janus_pro.py +1 -1
- sglang/srt/configs/model_config.py +60 -1
- sglang/srt/configs/update_config.py +119 -0
- sglang/srt/conversation.py +69 -1
- sglang/srt/disaggregation/decode.py +21 -5
- sglang/srt/disaggregation/mooncake/conn.py +35 -4
- sglang/srt/disaggregation/nixl/conn.py +6 -6
- sglang/srt/disaggregation/prefill.py +2 -2
- sglang/srt/disaggregation/utils.py +1 -1
- sglang/srt/distributed/parallel_state.py +44 -17
- sglang/srt/entrypoints/EngineBase.py +8 -0
- sglang/srt/entrypoints/engine.py +40 -6
- sglang/srt/entrypoints/http_server.py +111 -24
- sglang/srt/entrypoints/http_server_engine.py +1 -1
- sglang/srt/entrypoints/openai/protocol.py +4 -2
- sglang/srt/eplb/__init__.py +0 -0
- sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
- sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
- sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
- sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
- sglang/srt/{managers → eplb}/expert_location.py +1 -1
- sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
- sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
- sglang/srt/hf_transformers_utils.py +2 -1
- sglang/srt/layers/activation.py +2 -2
- sglang/srt/layers/amx_utils.py +86 -0
- sglang/srt/layers/attention/ascend_backend.py +219 -0
- sglang/srt/layers/attention/flashattention_backend.py +32 -9
- sglang/srt/layers/attention/tbo_backend.py +37 -9
- sglang/srt/layers/communicator.py +20 -2
- sglang/srt/layers/dp_attention.py +9 -3
- sglang/srt/layers/elementwise.py +76 -12
- sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
- sglang/srt/layers/layernorm.py +26 -0
- sglang/srt/layers/linear.py +84 -14
- sglang/srt/layers/logits_processor.py +4 -4
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +81 -8
- sglang/srt/layers/moe/ep_moe/layer.py +176 -15
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +211 -74
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
- sglang/srt/layers/moe/router.py +60 -22
- sglang/srt/layers/moe/topk.py +10 -28
- sglang/srt/layers/parameter.py +67 -7
- sglang/srt/layers/quantization/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
- sglang/srt/layers/quantization/fp8.py +72 -7
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +1 -2
- sglang/srt/layers/quantization/gptq.py +5 -1
- sglang/srt/layers/quantization/modelopt_quant.py +244 -1
- sglang/srt/layers/quantization/moe_wna16.py +1 -1
- sglang/srt/layers/quantization/quant_utils.py +166 -0
- sglang/srt/layers/quantization/w4afp8.py +264 -0
- sglang/srt/layers/quantization/w8a8_int8.py +52 -1
- sglang/srt/layers/rotary_embedding.py +2 -2
- sglang/srt/layers/vocab_parallel_embedding.py +20 -10
- sglang/srt/lora/lora.py +4 -5
- sglang/srt/lora/lora_manager.py +73 -20
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
- sglang/srt/managers/cache_controller.py +41 -195
- sglang/srt/managers/configure_logging.py +1 -1
- sglang/srt/managers/io_struct.py +58 -14
- sglang/srt/managers/mm_utils.py +77 -61
- sglang/srt/managers/multimodal_processor.py +2 -6
- sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
- sglang/srt/managers/schedule_batch.py +78 -85
- sglang/srt/managers/scheduler.py +130 -64
- sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
- sglang/srt/managers/session_controller.py +12 -3
- sglang/srt/managers/tokenizer_manager.py +314 -103
- sglang/srt/managers/tp_worker.py +13 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
- sglang/srt/mem_cache/allocator.py +290 -0
- sglang/srt/mem_cache/chunk_cache.py +34 -2
- sglang/srt/mem_cache/hiradix_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +402 -66
- sglang/srt/mem_cache/memory_pool_host.py +6 -109
- sglang/srt/mem_cache/multimodal_cache.py +3 -0
- sglang/srt/mem_cache/radix_cache.py +8 -4
- sglang/srt/model_executor/cuda_graph_runner.py +2 -1
- sglang/srt/model_executor/forward_batch_info.py +17 -4
- sglang/srt/model_executor/model_runner.py +297 -56
- sglang/srt/model_loader/loader.py +41 -0
- sglang/srt/model_loader/weight_utils.py +72 -4
- sglang/srt/models/deepseek_nextn.py +1 -3
- sglang/srt/models/deepseek_v2.py +195 -45
- sglang/srt/models/deepseek_vl2.py +3 -5
- sglang/srt/models/gemma3_causal.py +1 -2
- sglang/srt/models/gemma3n_causal.py +4 -3
- sglang/srt/models/gemma3n_mm.py +4 -20
- sglang/srt/models/hunyuan.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -2
- sglang/srt/models/llama.py +10 -4
- sglang/srt/models/llama4.py +32 -45
- sglang/srt/models/llama_eagle3.py +61 -11
- sglang/srt/models/llava.py +5 -5
- sglang/srt/models/minicpmo.py +2 -2
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mllama4.py +402 -89
- sglang/srt/models/phi4mm.py +1 -3
- sglang/srt/models/pixtral.py +3 -7
- sglang/srt/models/qwen2.py +31 -3
- sglang/srt/models/qwen2_5_vl.py +1 -3
- sglang/srt/models/qwen2_audio.py +200 -0
- sglang/srt/models/qwen2_moe.py +32 -6
- sglang/srt/models/qwen2_vl.py +1 -4
- sglang/srt/models/qwen3.py +94 -25
- sglang/srt/models/qwen3_moe.py +68 -21
- sglang/srt/models/vila.py +3 -8
- sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +2 -2
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +65 -66
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
- sglang/srt/operations_strategy.py +6 -2
- sglang/srt/reasoning_parser.py +26 -0
- sglang/srt/sampling/sampling_batch_info.py +39 -1
- sglang/srt/server_args.py +84 -22
- sglang/srt/speculative/build_eagle_tree.py +57 -18
- sglang/srt/speculative/eagle_worker.py +6 -4
- sglang/srt/two_batch_overlap.py +203 -27
- sglang/srt/utils.py +343 -163
- sglang/srt/warmup.py +12 -3
- sglang/test/runners.py +10 -1
- sglang/test/test_cutlass_w4a8_moe.py +281 -0
- sglang/test/test_utils.py +15 -3
- sglang/utils.py +5 -5
- sglang/version.py +1 -1
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +12 -8
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +157 -146
- sglang/math_utils.py +0 -8
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
- /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0
@@ -209,6 +209,17 @@ def get_quant_config(
             config["adapter_name_or_path"] = model_name_or_path
         elif model_config.quantization == "modelopt":
             if config["producer"]["name"] == "modelopt":
+                # (yizhang2077) workaround for nvidia/Llama-4-Maverick-17B-128E-Eagle3
+                if config["quantization"]["quant_algo"] is None:
+                    if (
+                        model_config.hf_config.architectures[0]
+                        != "LlamaForCausalLMEagle3"
+                    ):
+                        raise ValueError(
+                            f"Invalid quant_config, quantization method: {model_config.quantization},"
+                            f"hf architectures: {model_config.hf_config.architectures[0]}. "
+                        )
+                    return None
                 if "FP4" in config["quantization"]["quant_algo"]:
                     return ModelOptFp4Config.from_config(config)
                 else:
@@ -449,10 +460,12 @@ def safetensors_weights_iterator(
         if disable_mmap:
             with open(st_file, "rb") as f:
                 result = safetensors.torch.load(f.read())
+            for name, param in result.items():
+                yield name, param
         else:
-
-
-
+            with safetensors.safe_open(st_file, framework="pt", device="cpu") as f:
+                for name in f.keys():
+                    yield name, f.get_tensor(name)


 def multi_thread_safetensors_weights_iterator(
@@ -485,7 +498,8 @@ def multi_thread_safetensors_weights_iterator(
             with open(st_file, "rb") as f:
                 result = safetensors.torch.load(f.read())
         else:
-
+            with safetensors.safe_open(st_file, framework="pt", device="cpu") as f:
+                result = {k: f.get_tensor(k) for k in f.keys()}

         return result

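The two hunks above route the default path through safetensors.safe_open, which yields tensors one at a time. A standalone sketch (not sglang code; the file name and toy tensors are invented for illustration) of the two loading styles used in the new safetensors_weights_iterator:

# Minimal sketch of the two loading paths; requires torch and safetensors.
import safetensors
import safetensors.torch
import torch

path = "toy.safetensors"
safetensors.torch.save_file({"a": torch.ones(2, 2), "b": torch.zeros(3)}, path)

def iterate_weights(st_file, disable_mmap=False):
    if disable_mmap:
        # Read the whole file into host memory, then yield its tensors.
        with open(st_file, "rb") as f:
            result = safetensors.torch.load(f.read())
        for name, param in result.items():
            yield name, param
    else:
        # Memory-mapped path: tensors are materialized lazily, one at a time.
        with safetensors.safe_open(st_file, framework="pt", device="cpu") as f:
            for name in f.keys():
                yield name, f.get_tensor(name)

for name, tensor in iterate_weights(path):
    print(name, tuple(tensor.shape))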
@@ -947,3 +961,57 @@ def kv_cache_scales_loader(
         tp_rank,
     )
     return []
+
+
+def get_actual_shard_size(shard_size, weight_start, weight_end):
+    if weight_end < weight_start:
+        return 0
+
+    return min(shard_size, weight_end - weight_start)
+
+
+def reset_param_data_if_needed(param_data, dim, start, length):
+    if length == 0:
+        return
+
+    assert length > 0, f"Length should be positive, but got {length}"
+
+    param_data.narrow(dim, start, length).zero_()
+    return
+
+
+def narrow_padded_param_and_loaded_weight(
+    param_data,
+    loaded_weight,
+    param_data_start,
+    weight_start,
+    dim,
+    shard_size,
+    narrow_weight=True,
+):
+    actual_shard_size = get_actual_shard_size(
+        shard_size, weight_start, loaded_weight.size(dim)
+    )
+
+    if narrow_weight:
+        if actual_shard_size > 0:
+            loaded_weight = loaded_weight.narrow(dim, weight_start, actual_shard_size)
+        else:
+            # No real data to load; create a dummy tensor filled with zeros
+            loaded_weight = torch.zeros_like(
+                param_data.narrow(dim, param_data_start, actual_shard_size)
+            )
+
+    # [Note] Reset padded weights to zero.
+    # If the actual shard size is less than the shard size, we need to reset
+    # the padded param_data to zero and then copy the loaded_weight into it.
+    reset_param_data_if_needed(
+        param_data,
+        dim,
+        param_data_start + actual_shard_size,
+        shard_size - actual_shard_size,
+    )
+
+    param_data = param_data.narrow(dim, param_data_start, actual_shard_size)
+
+    return param_data, loaded_weight
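The helpers added above zero out the padded tail of a parameter shard when the checkpoint tensor is shorter than the shard. A toy example (plain PyTorch, not sglang code; shapes and offsets are invented) of that padding logic:

import torch

param_data = torch.full((8,), -1.0)   # padded parameter shard (size 8)
loaded_weight = torch.arange(6.0)      # checkpoint tensor (only 6 values)

dim, param_start, weight_start, shard_size = 0, 0, 0, 8
actual = min(shard_size, loaded_weight.size(dim) - weight_start)  # -> 6

# Zero the padded region beyond the data that actually exists ...
param_data.narrow(dim, param_start + actual, shard_size - actual).zero_()
# ... then copy the real slice into the front of the shard.
param_data.narrow(dim, param_start, actual).copy_(
    loaded_weight.narrow(dim, weight_start, actual)
)

print(param_data)  # tensor([0., 1., 2., 3., 4., 5., 0., 0.])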
@@ -21,6 +21,7 @@ from torch import nn
 from transformers import PretrainedConfig

 from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -28,9 +29,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.expert_distribution import (
-    get_global_expert_distribution_recorder,
-)
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
sglang/srt/models/deepseek_v2.py  CHANGED
@@ -32,7 +32,11 @@ from sglang.srt.distributed import (
     parallel_state,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
+from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
+from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.amx_utils import PackWeightMethod
 from sglang.srt.layers.communicator import (
     LayerCommunicator,
     LayerScatterModes,
@@ -77,11 +81,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.expert_distribution import (
-    get_global_expert_distribution_recorder,
-)
-from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
-from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -93,17 +92,18 @@ from sglang.srt.utils import (
     BumpAllocator,
     DeepEPMode,
     LazyValue,
-    PackWeightMethod,
     add_prefix,
     bind_or_assign,
     cpu_has_amx_support,
     get_bool_env_var,
+    get_device_sm,
     get_int_env_var,
     is_cpu,
     is_cuda,
     is_hip,
     is_non_idle_and_non_empty,
     log_info_on_rank0,
+    use_intel_amx_backend,
 )

 _is_hip = is_hip()
@@ -112,9 +112,16 @@ _is_fp8_fnuz = is_fp8_fnuz()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
+_device_sm = get_device_sm()

 if _is_cuda:
-    from sgl_kernel import
+    from sgl_kernel import (
+        awq_dequantize,
+        bmm_fp8,
+        dsv3_fused_a_gemm,
+        dsv3_router_gemm,
+        merge_state_v2,
+    )
 elif _is_cpu and _is_cpu_amx_available:
     pass
 else:
@@ -203,8 +210,10 @@ class MoEGate(nn.Module):
         self,
         config,
         prefix: str = "",
+        is_nextn: bool = False,
     ):
         super().__init__()
+        self.is_nextn = is_nextn
         self.weight = nn.Parameter(
             torch.empty((config.n_routed_experts, config.hidden_size))
         )
@@ -218,7 +227,7 @@ class MoEGate(nn.Module):
             self.quant_method = PackWeightMethod(weight_names=["weight"])

     def forward(self, hidden_states):
-        if
+        if use_intel_amx_backend(self):
             return torch.ops.sgl_kernel.weight_packed_linear(
                 hidden_states,
                 self.weight,
@@ -226,7 +235,21 @@ class MoEGate(nn.Module):
                 True,  # is_vnni
             )

-
+        # NOTE: For some unknown reason, router_gemm seems degrade accept length.
+        if (
+            _is_cuda
+            and not self.is_nextn
+            and hidden_states.shape[0] < 4
+            and hidden_states.shape[1] == 7168
+            and self.weight.shape[0] == 256
+            and _device_sm >= 90
+        ):
+            logits = dsv3_router_gemm(hidden_states, self.weight).to(
+                hidden_states.dtype
+            )
+        else:
+            logits = F.linear(hidden_states, self.weight, None)
+
         return logits

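The MoEGate.forward change above gates a specialized router GEMM on batch size, weight shape, and SM version. A simplified standalone sketch of that dispatch, where fast_router_gemm is a hypothetical stand-in for sgl_kernel.dsv3_router_gemm and the gating thresholds are copied from the hunk:

import torch
import torch.nn.functional as F

def fast_router_gemm(hidden_states, weight):
    # Stand-in for the fused kernel; numerically equivalent to F.linear here.
    return hidden_states @ weight.t()

def route(hidden_states, weight, is_cuda=False, is_nextn=False, device_sm=0):
    use_fast_path = (
        is_cuda
        and not is_nextn
        and hidden_states.shape[0] < 4      # tiny decode batches only
        and hidden_states.shape[1] == 7168  # DeepSeek-V3 hidden size
        and weight.shape[0] == 256          # 256 routed experts
        and device_sm >= 90                 # Hopper or newer
    )
    if use_fast_path:
        return fast_router_gemm(hidden_states, weight).to(hidden_states.dtype)
    return F.linear(hidden_states, weight, None)

x = torch.randn(2, 7168)
w = torch.randn(256, 7168)
print(route(x, w, is_cuda=False).shape)  # torch.Size([2, 256]) via F.linear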
@@ -239,6 +262,7 @@ class DeepseekV2MoE(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         alt_stream: Optional[torch.cuda.Stream] = None,
+        is_nextn: bool = False,
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
@@ -265,7 +289,9 @@ class DeepseekV2MoE(nn.Module):
             "Only silu is supported for now."
         )

-        self.gate = MoEGate(
+        self.gate = MoEGate(
+            config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
+        )

         self.experts = get_moe_impl_class()(
             num_experts=config.n_routed_experts
@@ -300,6 +326,9 @@ class DeepseekV2MoE(nn.Module):
             ),
         )

+        self.shared_experts_is_int8 = False
+        self.shared_experts_is_fp8 = False
+        self.shared_experts_weight_block_size = None
         if config.n_shared_experts is not None and self.num_fused_shared_experts == 0:
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
             # disable tp for shared experts when enable deepep moe
@@ -316,6 +345,28 @@ class DeepseekV2MoE(nn.Module):
                     else {}
                 ),
             )
+            is_packed_weight = hasattr(
+                self.shared_experts.gate_up_proj.quant_method, "quant_config"
+            ) and self.shared_experts.gate_up_proj.quant_method.quant_config.get_name() in {
+                "awq",
+                "moe_wna16",
+            }
+            self.shared_experts_is_int8 = (
+                not is_packed_weight
+                and self.shared_experts.gate_up_proj.weight.dtype == torch.int8
+            )
+            self.shared_experts_is_fp8 = (
+                not is_packed_weight
+                and self.shared_experts.gate_up_proj.weight.dtype == torch.float8_e4m3fn
+            )
+            if self.shared_experts_is_fp8:
+                assert (
+                    self.shared_experts.gate_up_proj.quant_method.quant_config.weight_block_size
+                    == self.shared_experts.down_proj.quant_method.quant_config.weight_block_size
+                )
+                self.shared_experts_weight_block_size = (
+                    self.shared_experts.gate_up_proj.quant_method.quant_config.weight_block_size
+                )

         self.top_k = config.num_experts_per_tok

@@ -394,6 +445,11 @@ class DeepseekV2MoE(nn.Module):
         return final_hidden_states

     def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        if hasattr(self, "shared_experts") and use_intel_amx_backend(
+            self.shared_experts.gate_up_proj
+        ):
+            return self.forward_cpu(hidden_states)
+
         shared_output = self._forward_shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
@@ -409,6 +465,59 @@ class DeepseekV2MoE(nn.Module):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states

+    def forward_cpu(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        # router_logits: (num_tokens, n_experts)
+        router_logits = self.gate(hidden_states)
+        fused_experts_out = self.experts(
+            hidden_states=hidden_states, router_logits=router_logits
+        )
+
+        assert use_intel_amx_backend(
+            self.shared_experts.gate_up_proj
+        ) == use_intel_amx_backend(self.shared_experts.down_proj)
+        # [Note] inplace should be False in fused_experts.
+        # If inplace is True in fused_experts (self.experts), hidden_states will be changed after fused_experts
+        # While hidden_states is still needed in shared_expert.
+        final_hidden_states = torch.ops.sgl_kernel.shared_expert_cpu(
+            hidden_states,
+            self.shared_experts.gate_up_proj.weight,
+            self.shared_experts.down_proj.weight,
+            fused_experts_out,
+            self.routed_scaling_factor,
+            True,  # inplace
+            self.shared_experts_is_int8,  # use_int8_w8a8
+            self.shared_experts_is_fp8,  # use_fp8_w8a16
+            (
+                self.shared_experts.gate_up_proj.weight_scale
+                if self.shared_experts_is_int8
+                else (
+                    self.shared_experts.gate_up_proj.weight_scale_inv
+                    if self.shared_experts_is_fp8
+                    else None
+                )
+            ),  # w1_scale
+            (
+                self.shared_experts.down_proj.weight_scale
+                if self.shared_experts_is_int8
+                else (
+                    self.shared_experts.down_proj.weight_scale_inv
+                    if self.shared_experts_is_fp8
+                    else None
+                )
+            ),  # w2_scale
+            (
+                self.shared_experts_weight_block_size
+                if self.shared_experts_is_fp8
+                else None
+            ),  # block_size
+            None,  # a1_scale
+            None,  # a2_scale
+            True,  # is_vnni
+        )
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+        return final_hidden_states
+
     def forward_deepep(
         self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
     ) -> torch.Tensor:
@@ -456,7 +565,7 @@ class DeepseekV2MoE(nn.Module):
                 hidden_states=hidden_states,
                 topk_idx=topk_idx,
                 topk_weights=topk_weights,
-
+                forward_batch=forward_batch,
             )
         final_hidden_states = self.experts(
             hidden_states=hidden_states,
@@ -467,14 +576,14 @@ class DeepseekV2MoE(nn.Module):
             masked_m=masked_m,
             expected_m=expected_m,
             num_recv_tokens_per_expert=num_recv_tokens_per_expert,
-
+            forward_batch=forward_batch,
         )
         if self.ep_size > 1:
             final_hidden_states = self.deepep_dispatcher.combine(
                 hidden_states=final_hidden_states,
                 topk_idx=topk_idx,
                 topk_weights=topk_weights,
-
+                forward_batch=forward_batch,
             )

         if shared_output is not None:
@@ -549,7 +658,7 @@ class DeepseekV2MoE(nn.Module):
             hidden_states=state.hidden_states_mlp_input,
             topk_idx=state.pop("topk_idx_local"),
             topk_weights=state.pop("topk_weights_local"),
-
+            forward_batch=state.forward_batch,
             tbo_subbatch_index=state.get("tbo_subbatch_index"),
         )

@@ -581,7 +690,7 @@ class DeepseekV2MoE(nn.Module):
             masked_m=state.pop("masked_m"),
             expected_m=state.pop("expected_m"),
             num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
-
+            forward_batch=state.forward_batch,
         )

     def op_combine_a(self, state):
@@ -590,7 +699,7 @@ class DeepseekV2MoE(nn.Module):
             hidden_states=state.pop("hidden_states_experts_output"),
             topk_idx=state.pop("topk_idx_dispatched"),
             topk_weights=state.pop("topk_weights_dispatched"),
-
+            forward_batch=state.forward_batch,
             tbo_subbatch_index=state.get("tbo_subbatch_index"),
         )

@@ -793,33 +902,56 @@ class DeepseekV2AttentionMLA(nn.Module):
         # If we have self.fused_qkv_a_proj_with_mqa and we're running on CPU, we will choose the torch.ops.sgl_kernel.qkv_proj_with_rope_fused_weight kernel
         # which requires self.w_kc and self.w_vc to be packed.
         # If not, we will use torch.bmm and weight shouldn't be packed in this case
-
-
-            and _is_cpu
-            and _is_cpu_amx_available
-        ):
+        has_fused_proj = hasattr(self, "fused_qkv_a_proj_with_mqa")
+        if has_fused_proj and _is_cpu and _is_cpu_amx_available:
             self.quant_method = PackWeightMethod(
                 weight_names=["w_kc", "w_vc"], transpose_dims=[[1, 2], [1, 2]]
             )

+        is_packed_weight = (
+            has_fused_proj
+            and hasattr(self.fused_qkv_a_proj_with_mqa.quant_method, "quant_config")
+            and self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.get_name()
+            in {"awq", "moe_wna16"}
+        )
+        self.use_min_latency_fused_a_gemm = (
+            has_fused_proj
+            and not is_packed_weight
+            and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.bfloat16
+            and self.fused_qkv_a_proj_with_mqa.weight.shape[0] == 2112
+            and self.fused_qkv_a_proj_with_mqa.weight.shape[1] == 7168
+            and _is_cuda
+            and _device_sm >= 90
+        )
+
         self.qkv_proj_with_rope_is_int8 = (
-
+            has_fused_proj
+            and not is_packed_weight
             and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.int8
         )
         self.qkv_proj_with_rope_is_fp8 = (
-
+            has_fused_proj
+            and not is_packed_weight
             and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.float8_e4m3fn
         )

         self.weight_block_size = None
-        if self.qkv_proj_with_rope_is_fp8:
-            assert (
-                self.fused_qkv_a_proj_with_mqa.quant_method
-
-
-
-
-
+        if self.qkv_proj_with_rope_is_fp8 and _is_cpu and _is_cpu_amx_available:
+            assert getattr(
+                self.fused_qkv_a_proj_with_mqa.quant_method, "block_quant", False
+            ) == getattr(self.q_b_proj.quant_method, "block_quant", False)
+            use_block_quant = getattr(
+                self.fused_qkv_a_proj_with_mqa.quant_method, "block_quant", False
+            )
+
+            if use_block_quant:
+                assert (
+                    self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
+                    == self.q_b_proj.quant_method.quant_config.weight_block_size
+                )
+                self.weight_block_size = (
+                    self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
+                )

     def dispatch_attn_forward_method(
         self, forward_batch: ForwardBatch
@@ -834,14 +966,16 @@ class DeepseekV2AttentionMLA(nn.Module):
             else:
                 return AttnForwardMethod.MLA
         else:
-            if hasattr(self, "fused_qkv_a_proj_with_mqa") and
-                self
+            if hasattr(self, "fused_qkv_a_proj_with_mqa") and use_intel_amx_backend(
+                self
             ):
                 return AttnForwardMethod.MLA_FUSED_ROPE_CPU
             else:
                 return AttnForwardMethod.MLA

-        if self.attention_backend == "
+        if self.attention_backend == "ascend":
+            return AttnForwardMethod.MLA
+        elif self.attention_backend == "flashinfer":
             # Flashinfer MLA: Do not absorb when enabling ragged prefill
             if (
                 not self.flashinfer_mla_disable_ragged
@@ -1041,7 +1175,13 @@ class DeepseekV2AttentionMLA(nn.Module):
         from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode

         if self.q_lora_rank is not None:
-
+            if hidden_states.shape[0] <= 16 and self.use_min_latency_fused_a_gemm:
+                fused_qkv_a_proj_out = dsv3_fused_a_gemm(
+                    hidden_states, self.fused_qkv_a_proj_with_mqa.weight.T
+                )
+            else:
+                fused_qkv_a_proj_out = self.fused_qkv_a_proj_with_mqa(hidden_states)[0]
+            q, latent_cache = fused_qkv_a_proj_out.split(
                 [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
             )
             k_nope = latent_cache[..., : self.kv_lora_rank]
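The hunk above picks dsv3_fused_a_gemm only for small batches and then splits the fused projection output into the q LoRA part and the MLA latent. A toy illustration (plain PyTorch, not sglang code) of that split; the rank values below are assumptions chosen to match the 2112 x 7168 weight shape checked earlier in the diff:

import torch

q_lora_rank, kv_lora_rank, qk_rope_head_dim = 1536, 512, 64  # 1536 + 512 + 64 = 2112
hidden = torch.randn(4, 7168)
fused_weight = torch.randn(q_lora_rank + kv_lora_rank + qk_rope_head_dim, 7168)

fused_out = hidden @ fused_weight.t()                      # (4, 2112)
q, latent_cache = fused_out.split(
    [q_lora_rank, kv_lora_rank + qk_rope_head_dim], dim=-1
)
k_nope = latent_cache[..., :kv_lora_rank]
print(q.shape, latent_cache.shape, k_nope.shape)
# torch.Size([4, 1536]) torch.Size([4, 576]) torch.Size([4, 512])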
@@ -1302,8 +1442,8 @@ class DeepseekV2AttentionMLA(nn.Module):
         forward_batch: ForwardBatch,
         zero_allocator: BumpAllocator,
     ):
-        assert self.q_lora_rank is not None and
-            self
+        assert self.q_lora_rank is not None and use_intel_amx_backend(
+            self
         ), "forward_absorb_fused_mla_rope_cpu_prepare requires q_lora_rank is not None and use_intel_amx_backend"

         q_input, k_input, v_input = (
@@ -1422,8 +1562,8 @@ class DeepseekV2AttentionMLA(nn.Module):
     def forward_absorb_fused_mla_rope_cpu_core(
         self, q_input, k_input, v_input, forward_batch, zero_allocator
     ):
-        assert self.q_lora_rank is not None and
-            self
+        assert self.q_lora_rank is not None and use_intel_amx_backend(
+            self
         ), "forward_absorb_fused_mla_rope_cpu_core requires q_lora_rank is not None and use_intel_amx_backend"

         attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
@@ -1643,6 +1783,7 @@ class DeepseekV2DecoderLayer(nn.Module):
                 prefix=add_prefix("mlp", prefix),
                 layer_id=self.layer_id,
                 alt_stream=alt_stream,
+                is_nextn=is_nextn,
             )
         else:
             if enable_moe_dense_fully_dp():
@@ -1707,11 +1848,6 @@ class DeepseekV2DecoderLayer(nn.Module):
             hidden_states, residual, forward_batch
         )

-        if self.enable_dp_attention and self.speculative_algorithm.is_eagle():
-            # NOTE: this line resolves the degradation of MTP reception rate for non-zero DP ranks.
-            # See discussion here (https://github.com/sgl-project/sglang/pull/6081#discussion_r2147452251).
-            hidden_states = hidden_states.clone()
-
         return hidden_states, residual

     def op_comm_prepare_attn(
@@ -1753,7 +1889,7 @@ class DeepseekV2DecoderLayer(nn.Module):
             and hidden_states.shape[0] == 0
         ):
             state.hidden_states_mlp_output = self.mlp(
-                hidden_states, state.forward_batch
+                hidden_states, state.forward_batch
             )
         else:
             state.hidden_states_mlp_output = hidden_states
@@ -2107,6 +2243,14 @@ class DeepseekV2ForCausalLM(nn.Module):
                     )
                     if _is_hip:
                         self_attn.w_scale *= 2.0
+                # TODO: remove this after adding FP8 support in bmm cpu kernel
+                if _is_cpu and _is_cpu_amx_available and w.dtype == torch.float8_e4m3fn:
+                    self_attn.w_kc = (
+                        self_attn.w_kc.to(torch.bfloat16) * self_attn.w_scale
+                    )
+                    self_attn.w_vc = (
+                        self_attn.w_vc.to(torch.bfloat16) * self_attn.w_scale
+                    )
             else:
                 num_tiles_k = self_attn.qk_nope_head_dim // weight_block_size[1]
                 num_tiles_n = self_attn.v_head_dim // weight_block_size[0]
@@ -2219,6 +2363,12 @@ class DeepseekV2ForCausalLM(nn.Module):
             ckpt_up_proj_name="up_proj",
             num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
         )
+        if self.quant_config and self.quant_config.get_name() == "w4afp8":
+            expert_params_mapping += (
+                get_moe_impl_class().make_expert_input_scale_params_mapping(
+                    num_experts=self.config.n_routed_experts
+                )
+            )

         # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
         fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
@@ -253,11 +253,9 @@ class DeepseekVL2ForCausalLM(nn.Module):
             weights_loader = getattr(param, "weight_loader", default_weight_loader)
             weights_loader(param, loaded_weight)

-    def pad_input_ids(self, input_ids: List[int],
-
-
-    )
-        return helper.pad_input_tokens(input_ids, image_inputs)
+    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
+        return pattern.pad_input_tokens(input_ids, mm_inputs)

     def get_image_feature(self, items: List[MultimodalDataItem]):

@@ -166,8 +166,7 @@ class Gemma3Attention(nn.Module):
             prefix=add_prefix("o_proj", prefix),
         )

-
-        self.is_sliding = bool((layer_id + 1) % config.sliding_window_pattern)
+        self.is_sliding = config.layer_types[layer_id] == "sliding_attention"

         # Initialize the rotary embedding.
         if self.is_sliding:
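The Gemma3Attention change above reads sliding-window layers from config.layer_types instead of deriving them from sliding_window_pattern. A small illustration (not sglang code; the toy config is an assumption constructed so both checks agree):

from types import SimpleNamespace

config = SimpleNamespace(
    sliding_window_pattern=6,
    layer_types=["sliding_attention"] * 5 + ["full_attention"]
    + ["sliding_attention"] * 5 + ["full_attention"],
)

for layer_id in range(len(config.layer_types)):
    old_is_sliding = bool((layer_id + 1) % config.sliding_window_pattern)
    new_is_sliding = config.layer_types[layer_id] == "sliding_attention"
    assert old_is_sliding == new_is_sliding  # same layers, now config-driven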
@@ -62,7 +62,7 @@ class Gemma3nTextScaledWordEmbedding(Gemma3TextScaledWordEmbedding):
     pass


-class
+class Gemma3nTextMLP(nn.Module):
     def __init__(
         self,
         hidden_size: int,
@@ -514,10 +514,11 @@ class Gemma3nDecoderLayer(nn.Module):
             prefix=add_prefix("self_attn", prefix),
         )

+        intermediate_size = config.intermediate_size[layer_id]
         activation_sparsity = config.activation_sparsity_pattern[layer_id]
-        self.mlp =
+        self.mlp = Gemma3nTextMLP(
             hidden_size=self.hidden_size,
-            intermediate_size=
+            intermediate_size=intermediate_size,
             hidden_activation=config.hidden_activation,
             activation_sparsity=activation_sparsity,
             quant_config=quant_config,
sglang/srt/models/gemma3n_mm.py  CHANGED
@@ -21,7 +21,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.managers.mm_utils import (
-
+    MultiModalityDataPaddingPatternMultimodalTokens,
     general_mm_embed_routine,
 )
 from sglang.srt.managers.schedule_batch import (
@@ -244,26 +244,11 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
     def pad_input_ids(
         self,
         input_ids: List[int],
-        mm_inputs:
+        mm_inputs: MultimodalInputs,
     ) -> List[int]:
         """Pad input IDs with image and audio tokens."""
-
-
-
-        # Collect available media token pairs
-        media_token_pairs = []
-        for attr_name in ["im_start_id", "audio_start_id"]:
-            if hasattr(mm_inputs, attr_name):
-                start_id = getattr(mm_inputs, attr_name)
-                end_id = getattr(mm_inputs, attr_name.replace("start", "end"))
-                media_token_pairs.append((start_id, end_id))
-
-        # Apply padding pattern if we have media tokens
-        if media_token_pairs:
-            pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
-            return pattern.pad_input_tokens(input_ids, mm_inputs)
-
-        return input_ids
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
+        return pattern.pad_input_tokens(input_ids, mm_inputs)

     def get_input_embeddings(self) -> nn.Embedding:
         return self.language_model.get_input_embeddings()
@@ -431,7 +416,6 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
         )

         positions += 1
-
         if input_ids is not None:
             # Prepare per-layer inputs from inputs_ids
             per_layer_inputs_mask = torch.logical_and(