sglang 0.4.8.post1__py3-none-any.whl → 0.4.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +17 -2
- sglang/bench_serving.py +168 -22
- sglang/srt/configs/internvl.py +4 -2
- sglang/srt/configs/janus_pro.py +1 -1
- sglang/srt/configs/model_config.py +48 -0
- sglang/srt/configs/update_config.py +119 -0
- sglang/srt/conversation.py +34 -0
- sglang/srt/disaggregation/decode.py +21 -5
- sglang/srt/disaggregation/nixl/conn.py +6 -6
- sglang/srt/disaggregation/prefill.py +2 -2
- sglang/srt/disaggregation/utils.py +1 -1
- sglang/srt/distributed/parallel_state.py +44 -17
- sglang/srt/entrypoints/EngineBase.py +8 -0
- sglang/srt/entrypoints/engine.py +40 -6
- sglang/srt/entrypoints/http_server.py +111 -24
- sglang/srt/entrypoints/openai/protocol.py +4 -2
- sglang/srt/eplb/__init__.py +0 -0
- sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
- sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
- sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
- sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
- sglang/srt/{managers → eplb}/expert_location.py +1 -1
- sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
- sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
- sglang/srt/hf_transformers_utils.py +2 -1
- sglang/srt/layers/activation.py +2 -2
- sglang/srt/layers/amx_utils.py +86 -0
- sglang/srt/layers/attention/ascend_backend.py +219 -0
- sglang/srt/layers/attention/flashattention_backend.py +32 -9
- sglang/srt/layers/attention/tbo_backend.py +37 -9
- sglang/srt/layers/communicator.py +18 -2
- sglang/srt/layers/dp_attention.py +9 -3
- sglang/srt/layers/elementwise.py +76 -12
- sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
- sglang/srt/layers/layernorm.py +26 -0
- sglang/srt/layers/linear.py +84 -14
- sglang/srt/layers/logits_processor.py +4 -4
- sglang/srt/layers/moe/ep_moe/kernels.py +23 -8
- sglang/srt/layers/moe/ep_moe/layer.py +36 -13
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -16
- sglang/srt/layers/moe/router.py +60 -22
- sglang/srt/layers/moe/topk.py +10 -28
- sglang/srt/layers/parameter.py +67 -7
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
- sglang/srt/layers/quantization/fp8.py +44 -0
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +1 -2
- sglang/srt/layers/quantization/gptq.py +5 -1
- sglang/srt/layers/quantization/moe_wna16.py +1 -1
- sglang/srt/layers/quantization/quant_utils.py +166 -0
- sglang/srt/layers/quantization/w8a8_int8.py +52 -1
- sglang/srt/layers/rotary_embedding.py +2 -2
- sglang/srt/layers/vocab_parallel_embedding.py +11 -7
- sglang/srt/lora/lora.py +4 -5
- sglang/srt/lora/lora_manager.py +73 -20
- sglang/srt/managers/configure_logging.py +1 -1
- sglang/srt/managers/io_struct.py +50 -13
- sglang/srt/managers/mm_utils.py +73 -59
- sglang/srt/managers/multimodal_processor.py +2 -6
- sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
- sglang/srt/managers/schedule_batch.py +77 -84
- sglang/srt/managers/scheduler.py +113 -59
- sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
- sglang/srt/managers/session_controller.py +12 -3
- sglang/srt/managers/tokenizer_manager.py +314 -103
- sglang/srt/managers/tp_worker.py +13 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
- sglang/srt/mem_cache/allocator.py +290 -0
- sglang/srt/mem_cache/chunk_cache.py +34 -2
- sglang/srt/mem_cache/memory_pool.py +289 -3
- sglang/srt/mem_cache/multimodal_cache.py +3 -0
- sglang/srt/model_executor/cuda_graph_runner.py +2 -1
- sglang/srt/model_executor/forward_batch_info.py +17 -4
- sglang/srt/model_executor/model_runner.py +297 -56
- sglang/srt/model_loader/loader.py +41 -0
- sglang/srt/model_loader/weight_utils.py +72 -4
- sglang/srt/models/deepseek_nextn.py +1 -3
- sglang/srt/models/deepseek_v2.py +181 -45
- sglang/srt/models/deepseek_vl2.py +3 -5
- sglang/srt/models/gemma3_causal.py +1 -2
- sglang/srt/models/gemma3n_causal.py +4 -3
- sglang/srt/models/gemma3n_mm.py +4 -20
- sglang/srt/models/hunyuan.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -2
- sglang/srt/models/llama.py +10 -4
- sglang/srt/models/llama4.py +32 -45
- sglang/srt/models/llama_eagle3.py +61 -11
- sglang/srt/models/llava.py +5 -5
- sglang/srt/models/minicpmo.py +2 -2
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mllama4.py +43 -11
- sglang/srt/models/phi4mm.py +1 -3
- sglang/srt/models/pixtral.py +3 -7
- sglang/srt/models/qwen2.py +31 -3
- sglang/srt/models/qwen2_5_vl.py +1 -3
- sglang/srt/models/qwen2_audio.py +200 -0
- sglang/srt/models/qwen2_moe.py +32 -6
- sglang/srt/models/qwen2_vl.py +1 -4
- sglang/srt/models/qwen3.py +94 -25
- sglang/srt/models/qwen3_moe.py +68 -21
- sglang/srt/models/vila.py +3 -8
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +3 -6
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
- sglang/srt/operations_strategy.py +6 -2
- sglang/srt/reasoning_parser.py +26 -0
- sglang/srt/sampling/sampling_batch_info.py +39 -1
- sglang/srt/server_args.py +69 -22
- sglang/srt/speculative/build_eagle_tree.py +57 -18
- sglang/srt/speculative/eagle_worker.py +6 -4
- sglang/srt/two_batch_overlap.py +200 -27
- sglang/srt/utils.py +306 -146
- sglang/srt/warmup.py +12 -3
- sglang/test/runners.py +10 -1
- sglang/test/test_utils.py +15 -3
- sglang/version.py +1 -1
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/METADATA +9 -6
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/RECORD +140 -133
- sglang/math_utils.py +0 -8
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
- /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
- /sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/WHEEL +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/top_level.txt +0 -0
sglang/srt/model_loader/weight_utils.py
CHANGED
@@ -209,6 +209,17 @@ def get_quant_config(
             config["adapter_name_or_path"] = model_name_or_path
         elif model_config.quantization == "modelopt":
             if config["producer"]["name"] == "modelopt":
+                # (yizhang2077) workaround for nvidia/Llama-4-Maverick-17B-128E-Eagle3
+                if config["quantization"]["quant_algo"] is None:
+                    if (
+                        model_config.hf_config.architectures[0]
+                        != "LlamaForCausalLMEagle3"
+                    ):
+                        raise ValueError(
+                            f"Invalid quant_config, quantization method: {model_config.quantization},"
+                            f"hf architectures: {model_config.hf_config.architectures[0]}. "
+                        )
+                    return None
                 if "FP4" in config["quantization"]["quant_algo"]:
                     return ModelOptFp4Config.from_config(config)
                 else:
@@ -449,10 +460,12 @@ def safetensors_weights_iterator(
         if disable_mmap:
            with open(st_file, "rb") as f:
                result = safetensors.torch.load(f.read())
+               for name, param in result.items():
+                   yield name, param
        else:
-
-
-
+           with safetensors.safe_open(st_file, framework="pt", device="cpu") as f:
+               for name in f.keys():
+                   yield name, f.get_tensor(name)


 def multi_thread_safetensors_weights_iterator(
@@ -485,7 +498,8 @@ def multi_thread_safetensors_weights_iterator(
        with open(st_file, "rb") as f:
            result = safetensors.torch.load(f.read())
    else:
-
+       with safetensors.safe_open(st_file, framework="pt", device="cpu") as f:
+           result = {k: f.get_tensor(k) for k in f.keys()}

    return result

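The two hunks above replace eager whole-file loads with lazy reads through safetensors.safe_open, so tensors are materialized one at a time instead of all at once. A minimal, self-contained sketch of that access pattern (it writes a throwaway checkpoint to a temporary directory; the file name is made up for the example):

import os
import tempfile

import torch
import safetensors
import safetensors.torch

# Write a tiny checkpoint so the example runs on its own.
ckpt = os.path.join(tempfile.mkdtemp(), "toy.safetensors")
safetensors.torch.save_file({"w1": torch.ones(2, 2), "w2": torch.zeros(3)}, ckpt)

# Lazy iteration, mirroring the updated safetensors_weights_iterator:
# each tensor is fetched on demand instead of loading the whole file.
with safetensors.safe_open(ckpt, framework="pt", device="cpu") as f:
    for name in f.keys():
        tensor = f.get_tensor(name)
        print(name, tuple(tensor.shape))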
@@ -947,3 +961,57 @@ def kv_cache_scales_loader(
             tp_rank,
         )
     return []
+
+
+def get_actual_shard_size(shard_size, weight_start, weight_end):
+    if weight_end < weight_start:
+        return 0
+
+    return min(shard_size, weight_end - weight_start)
+
+
+def reset_param_data_if_needed(param_data, dim, start, length):
+    if length == 0:
+        return
+
+    assert length > 0, f"Length should be positive, but got {length}"
+
+    param_data.narrow(dim, start, length).zero_()
+    return
+
+
+def narrow_padded_param_and_loaded_weight(
+    param_data,
+    loaded_weight,
+    param_data_start,
+    weight_start,
+    dim,
+    shard_size,
+    narrow_weight=True,
+):
+    actual_shard_size = get_actual_shard_size(
+        shard_size, weight_start, loaded_weight.size(dim)
+    )
+
+    if narrow_weight:
+        if actual_shard_size > 0:
+            loaded_weight = loaded_weight.narrow(dim, weight_start, actual_shard_size)
+        else:
+            # No real data to load; create a dummy tensor filled with zeros
+            loaded_weight = torch.zeros_like(
+                param_data.narrow(dim, param_data_start, actual_shard_size)
+            )
+
+    # [Note] Reset padded weights to zero.
+    # If the actual shard size is less than the shard size, we need to reset
+    # the padded param_data to zero and then copy the loaded_weight into it.
+    reset_param_data_if_needed(
+        param_data,
+        dim,
+        param_data_start + actual_shard_size,
+        shard_size - actual_shard_size,
+    )
+
+    param_data = param_data.narrow(dim, param_data_start, actual_shard_size)
+
+    return param_data, loaded_weight
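For intuition, the helper added above zero-fills the padded tail of a parameter shard whenever the checkpoint provides fewer elements than the shard reserves. A small hand-rolled equivalent in plain torch (not the sglang helper itself, just the same bookkeeping):

import torch

# Padded parameter shard with room for 8 rows; the checkpoint only has 5.
param_data = torch.full((8, 4), float("nan"))
loaded_weight = torch.arange(20, dtype=torch.float32).reshape(5, 4)

shard_size = 8                                   # rows reserved in the parameter
actual = min(shard_size, loaded_weight.size(0))  # rows actually available (5)

# Copy the real rows, then zero the padded remainder (the job that
# reset_param_data_if_needed performs in the hunk above).
param_data.narrow(0, 0, actual).copy_(loaded_weight)
param_data.narrow(0, actual, shard_size - actual).zero_()

assert torch.equal(param_data[:actual], loaded_weight)
assert torch.all(param_data[actual:] == 0)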
sglang/srt/models/deepseek_nextn.py
CHANGED
@@ -21,6 +21,7 @@ from torch import nn
 from transformers import PretrainedConfig

 from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -28,9 +29,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.expert_distribution import (
-    get_global_expert_distribution_recorder,
-)
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -32,7 +32,11 @@ from sglang.srt.distributed import (
     parallel_state,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
+from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
+from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.amx_utils import PackWeightMethod
 from sglang.srt.layers.communicator import (
     LayerCommunicator,
     LayerScatterModes,
@@ -77,11 +81,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.expert_distribution import (
-    get_global_expert_distribution_recorder,
-)
-from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
-from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -93,17 +92,18 @@ from sglang.srt.utils import (
     BumpAllocator,
     DeepEPMode,
     LazyValue,
-    PackWeightMethod,
     add_prefix,
     bind_or_assign,
     cpu_has_amx_support,
     get_bool_env_var,
+    get_device_sm,
     get_int_env_var,
     is_cpu,
     is_cuda,
     is_hip,
     is_non_idle_and_non_empty,
     log_info_on_rank0,
+    use_intel_amx_backend,
 )

 _is_hip = is_hip()
@@ -112,9 +112,16 @@ _is_fp8_fnuz = is_fp8_fnuz()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
+_device_sm = get_device_sm()

 if _is_cuda:
-    from sgl_kernel import
+    from sgl_kernel import (
+        awq_dequantize,
+        bmm_fp8,
+        dsv3_fused_a_gemm,
+        dsv3_router_gemm,
+        merge_state_v2,
+    )
 elif _is_cpu and _is_cpu_amx_available:
     pass
 else:
@@ -218,7 +225,7 @@ class MoEGate(nn.Module):
             self.quant_method = PackWeightMethod(weight_names=["weight"])

     def forward(self, hidden_states):
-        if
+        if use_intel_amx_backend(self):
             return torch.ops.sgl_kernel.weight_packed_linear(
                 hidden_states,
                 self.weight,
@@ -226,7 +233,19 @@ class MoEGate(nn.Module):
                 True,  # is_vnni
             )

-
+        if (
+            _is_cuda
+            and hidden_states.shape[0] < 4
+            and hidden_states.shape[1] == 7168
+            and self.weight.shape[0] == 256
+            and _device_sm >= 90
+        ):
+            logits = dsv3_router_gemm(hidden_states, self.weight).to(
+                hidden_states.dtype
+            )
+        else:
+            logits = F.linear(hidden_states, self.weight, None)
+
         return logits


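Both branches of the new gate compute the same router logits, hidden_states @ weight.T with shape (num_tokens, n_experts); the sgl_kernel dsv3_router_gemm path is just a specialized small-batch kernel for SM90+ GPUs. A minimal sketch of the fallback branch in plain PyTorch (shapes are illustrative, chosen to mirror the DeepSeek-V3 gate):

import torch
import torch.nn.functional as F

# Illustrative shapes: 2 decode tokens, hidden size 7168, 256 routed experts.
hidden_states = torch.randn(2, 7168)
gate_weight = torch.randn(256, 7168)

# Fallback path from the diff: router logits via a plain linear projection.
logits = F.linear(hidden_states, gate_weight, None)  # (num_tokens, n_experts)
assert logits.shape == (2, 256)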
@@ -300,6 +319,9 @@ class DeepseekV2MoE(nn.Module):
             ),
         )

+        self.shared_experts_is_int8 = False
+        self.shared_experts_is_fp8 = False
+        self.shared_experts_weight_block_size = None
         if config.n_shared_experts is not None and self.num_fused_shared_experts == 0:
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
             # disable tp for shared experts when enable deepep moe
@@ -316,6 +338,28 @@ class DeepseekV2MoE(nn.Module):
                     else {}
                 ),
             )
+            is_packed_weight = hasattr(
+                self.shared_experts.gate_up_proj.quant_method, "quant_config"
+            ) and self.shared_experts.gate_up_proj.quant_method.quant_config.get_name() in {
+                "awq",
+                "moe_wna16",
+            }
+            self.shared_experts_is_int8 = (
+                not is_packed_weight
+                and self.shared_experts.gate_up_proj.weight.dtype == torch.int8
+            )
+            self.shared_experts_is_fp8 = (
+                not is_packed_weight
+                and self.shared_experts.gate_up_proj.weight.dtype == torch.float8_e4m3fn
+            )
+            if self.shared_experts_is_fp8:
+                assert (
+                    self.shared_experts.gate_up_proj.quant_method.quant_config.weight_block_size
+                    == self.shared_experts.down_proj.quant_method.quant_config.weight_block_size
+                )
+                self.shared_experts_weight_block_size = (
+                    self.shared_experts.gate_up_proj.quant_method.quant_config.weight_block_size
+                )

         self.top_k = config.num_experts_per_tok

@@ -394,6 +438,11 @@ class DeepseekV2MoE(nn.Module):
         return final_hidden_states

     def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        if hasattr(self, "shared_experts") and use_intel_amx_backend(
+            self.shared_experts.gate_up_proj
+        ):
+            return self.forward_cpu(hidden_states)
+
         shared_output = self._forward_shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
@@ -409,6 +458,59 @@ class DeepseekV2MoE(nn.Module):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states

+    def forward_cpu(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        # router_logits: (num_tokens, n_experts)
+        router_logits = self.gate(hidden_states)
+        fused_experts_out = self.experts(
+            hidden_states=hidden_states, router_logits=router_logits
+        )
+
+        assert use_intel_amx_backend(
+            self.shared_experts.gate_up_proj
+        ) == use_intel_amx_backend(self.shared_experts.down_proj)
+        # [Note] inplace should be False in fused_experts.
+        # If inplace is True in fused_experts (self.experts), hidden_states will be changed after fused_experts
+        # While hidden_states is still needed in shared_expert.
+        final_hidden_states = torch.ops.sgl_kernel.shared_expert_cpu(
+            hidden_states,
+            self.shared_experts.gate_up_proj.weight,
+            self.shared_experts.down_proj.weight,
+            fused_experts_out,
+            self.routed_scaling_factor,
+            True,  # inplace
+            self.shared_experts_is_int8,  # use_int8_w8a8
+            self.shared_experts_is_fp8,  # use_fp8_w8a16
+            (
+                self.shared_experts.gate_up_proj.weight_scale
+                if self.shared_experts_is_int8
+                else (
+                    self.shared_experts.gate_up_proj.weight_scale_inv
+                    if self.shared_experts_is_fp8
+                    else None
+                )
+            ),  # w1_scale
+            (
+                self.shared_experts.down_proj.weight_scale
+                if self.shared_experts_is_int8
+                else (
+                    self.shared_experts.down_proj.weight_scale_inv
+                    if self.shared_experts_is_fp8
+                    else None
+                )
+            ),  # w2_scale
+            (
+                self.shared_experts_weight_block_size
+                if self.shared_experts_is_fp8
+                else None
+            ),  # block_size
+            None,  # a1_scale
+            None,  # a2_scale
+            True,  # is_vnni
+        )
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+        return final_hidden_states
+
     def forward_deepep(
         self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
     ) -> torch.Tensor:
@@ -456,7 +558,7 @@ class DeepseekV2MoE(nn.Module):
                 hidden_states=hidden_states,
                 topk_idx=topk_idx,
                 topk_weights=topk_weights,
-
+                forward_batch=forward_batch,
             )
             final_hidden_states = self.experts(
                 hidden_states=hidden_states,
@@ -467,14 +569,14 @@ class DeepseekV2MoE(nn.Module):
             masked_m=masked_m,
             expected_m=expected_m,
             num_recv_tokens_per_expert=num_recv_tokens_per_expert,
-
+            forward_batch=forward_batch,
         )
         if self.ep_size > 1:
             final_hidden_states = self.deepep_dispatcher.combine(
                 hidden_states=final_hidden_states,
                 topk_idx=topk_idx,
                 topk_weights=topk_weights,
-
+                forward_batch=forward_batch,
             )

         if shared_output is not None:
@@ -549,7 +651,7 @@ class DeepseekV2MoE(nn.Module):
             hidden_states=state.hidden_states_mlp_input,
             topk_idx=state.pop("topk_idx_local"),
             topk_weights=state.pop("topk_weights_local"),
-
+            forward_batch=state.forward_batch,
             tbo_subbatch_index=state.get("tbo_subbatch_index"),
         )

@@ -581,7 +683,7 @@ class DeepseekV2MoE(nn.Module):
             masked_m=state.pop("masked_m"),
             expected_m=state.pop("expected_m"),
             num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
-
+            forward_batch=state.forward_batch,
         )

     def op_combine_a(self, state):
@@ -590,7 +692,7 @@ class DeepseekV2MoE(nn.Module):
             hidden_states=state.pop("hidden_states_experts_output"),
             topk_idx=state.pop("topk_idx_dispatched"),
             topk_weights=state.pop("topk_weights_dispatched"),
-
+            forward_batch=state.forward_batch,
             tbo_subbatch_index=state.get("tbo_subbatch_index"),
         )

@@ -793,33 +895,56 @@ class DeepseekV2AttentionMLA(nn.Module):
         # If we have self.fused_qkv_a_proj_with_mqa and we're running on CPU, we will choose the torch.ops.sgl_kernel.qkv_proj_with_rope_fused_weight kernel
         # which requires self.w_kc and self.w_vc to be packed.
         # If not, we will use torch.bmm and weight shouldn't be packed in this case
-
-
-            and _is_cpu
-            and _is_cpu_amx_available
-        ):
+        has_fused_proj = hasattr(self, "fused_qkv_a_proj_with_mqa")
+        if has_fused_proj and _is_cpu and _is_cpu_amx_available:
             self.quant_method = PackWeightMethod(
                 weight_names=["w_kc", "w_vc"], transpose_dims=[[1, 2], [1, 2]]
             )

+        is_packed_weight = (
+            has_fused_proj
+            and hasattr(self.fused_qkv_a_proj_with_mqa.quant_method, "quant_config")
+            and self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.get_name()
+            in {"awq", "moe_wna16"}
+        )
+        self.use_min_latency_fused_a_gemm = (
+            has_fused_proj
+            and not is_packed_weight
+            and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.bfloat16
+            and self.fused_qkv_a_proj_with_mqa.weight.shape[0] == 2112
+            and self.fused_qkv_a_proj_with_mqa.weight.shape[1] == 7168
+            and _is_cuda
+            and _device_sm >= 90
+        )
+
         self.qkv_proj_with_rope_is_int8 = (
-
+            has_fused_proj
+            and not is_packed_weight
             and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.int8
         )
         self.qkv_proj_with_rope_is_fp8 = (
-
+            has_fused_proj
+            and not is_packed_weight
             and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.float8_e4m3fn
         )

         self.weight_block_size = None
-        if self.qkv_proj_with_rope_is_fp8:
-            assert (
-                self.fused_qkv_a_proj_with_mqa.quant_method
-
-
-
-
-
+        if self.qkv_proj_with_rope_is_fp8 and _is_cpu and _is_cpu_amx_available:
+            assert getattr(
+                self.fused_qkv_a_proj_with_mqa.quant_method, "block_quant", False
+            ) == getattr(self.q_b_proj.quant_method, "block_quant", False)
+            use_block_quant = getattr(
+                self.fused_qkv_a_proj_with_mqa.quant_method, "block_quant", False
+            )
+
+            if use_block_quant:
+                assert (
+                    self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
+                    == self.q_b_proj.quant_method.quant_config.weight_block_size
+                )
+                self.weight_block_size = (
+                    self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
+                )

     def dispatch_attn_forward_method(
         self, forward_batch: ForwardBatch
@@ -834,14 +959,16 @@ class DeepseekV2AttentionMLA(nn.Module):
             else:
                 return AttnForwardMethod.MLA
         else:
-            if hasattr(self, "fused_qkv_a_proj_with_mqa") and
-                self
+            if hasattr(self, "fused_qkv_a_proj_with_mqa") and use_intel_amx_backend(
+                self
             ):
                 return AttnForwardMethod.MLA_FUSED_ROPE_CPU
            else:
                return AttnForwardMethod.MLA

-        if self.attention_backend == "
+        if self.attention_backend == "ascend":
+            return AttnForwardMethod.MLA
+        elif self.attention_backend == "flashinfer":
            # Flashinfer MLA: Do not absorb when enabling ragged prefill
            if (
                not self.flashinfer_mla_disable_ragged
@@ -1041,7 +1168,13 @@ class DeepseekV2AttentionMLA(nn.Module):
         from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode

         if self.q_lora_rank is not None:
-
+            if hidden_states.shape[0] <= 16 and self.use_min_latency_fused_a_gemm:
+                fused_qkv_a_proj_out = dsv3_fused_a_gemm(
+                    hidden_states, self.fused_qkv_a_proj_with_mqa.weight.T
+                )
+            else:
+                fused_qkv_a_proj_out = self.fused_qkv_a_proj_with_mqa(hidden_states)[0]
+            q, latent_cache = fused_qkv_a_proj_out.split(
                 [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
             )
             k_nope = latent_cache[..., : self.kv_lora_rank]
@@ -1302,8 +1435,8 @@ class DeepseekV2AttentionMLA(nn.Module):
         forward_batch: ForwardBatch,
         zero_allocator: BumpAllocator,
     ):
-        assert self.q_lora_rank is not None and
-        self
+        assert self.q_lora_rank is not None and use_intel_amx_backend(
+            self
         ), "forward_absorb_fused_mla_rope_cpu_prepare requires q_lora_rank is not None and use_intel_amx_backend"

         q_input, k_input, v_input = (
@@ -1422,8 +1555,8 @@ class DeepseekV2AttentionMLA(nn.Module):
     def forward_absorb_fused_mla_rope_cpu_core(
         self, q_input, k_input, v_input, forward_batch, zero_allocator
     ):
-        assert self.q_lora_rank is not None and
-        self
+        assert self.q_lora_rank is not None and use_intel_amx_backend(
+            self
         ), "forward_absorb_fused_mla_rope_cpu_core requires q_lora_rank is not None and use_intel_amx_backend"

         attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
@@ -1707,11 +1840,6 @@ class DeepseekV2DecoderLayer(nn.Module):
             hidden_states, residual, forward_batch
         )

-        if self.enable_dp_attention and self.speculative_algorithm.is_eagle():
-            # NOTE: this line resolves the degradation of MTP reception rate for non-zero DP ranks.
-            # See discussion here (https://github.com/sgl-project/sglang/pull/6081#discussion_r2147452251).
-            hidden_states = hidden_states.clone()
-
         return hidden_states, residual

     def op_comm_prepare_attn(
@@ -1753,7 +1881,7 @@ class DeepseekV2DecoderLayer(nn.Module):
             and hidden_states.shape[0] == 0
         ):
             state.hidden_states_mlp_output = self.mlp(
-                hidden_states, state.forward_batch
+                hidden_states, state.forward_batch
             )
         else:
             state.hidden_states_mlp_output = hidden_states
@@ -1802,7 +1930,7 @@ class DeepseekV2Model(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
-
+            use_attn_tp_group=True,
         )
         self.alt_stream = torch.cuda.Stream() if _is_cuda else None
         self.layers = nn.ModuleList(
@@ -2107,6 +2235,14 @@ class DeepseekV2ForCausalLM(nn.Module):
                 )
                 if _is_hip:
                     self_attn.w_scale *= 2.0
+                # TODO: remove this after adding FP8 support in bmm cpu kernel
+                if _is_cpu and _is_cpu_amx_available and w.dtype == torch.float8_e4m3fn:
+                    self_attn.w_kc = (
+                        self_attn.w_kc.to(torch.bfloat16) * self_attn.w_scale
+                    )
+                    self_attn.w_vc = (
+                        self_attn.w_vc.to(torch.bfloat16) * self_attn.w_scale
+                    )
             else:
                 num_tiles_k = self_attn.qk_nope_head_dim // weight_block_size[1]
                 num_tiles_n = self_attn.v_head_dim // weight_block_size[0]
sglang/srt/models/deepseek_vl2.py
CHANGED
@@ -253,11 +253,9 @@ class DeepseekVL2ForCausalLM(nn.Module):
             weights_loader = getattr(param, "weight_loader", default_weight_loader)
             weights_loader(param, loaded_weight)

-    def pad_input_ids(self, input_ids: List[int],
-
-
-    )
-        return helper.pad_input_tokens(input_ids, image_inputs)
+    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
+        return pattern.pad_input_tokens(input_ids, mm_inputs)

     def get_image_feature(self, items: List[MultimodalDataItem]):

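Several models in this release (deepseek_vl2, gemma3n_mm, kimi_vl) switch pad_input_ids to the argument-free MultiModalityDataPaddingPatternMultimodalTokens. Its internals are not shown in this diff; the toy below is only a hypothetical illustration of what padding multimodal tokens means in general (rewrite each placeholder run with a per-item pad value so downstream caching can tell different media apart), not the sglang implementation:

from typing import List

def toy_pad_multimodal_tokens(
    input_ids: List[int], placeholder_id: int, pad_values: List[int]
) -> List[int]:
    # Hypothetical stand-in: replace every run of `placeholder_id` with the
    # pad value of the corresponding multimodal item, in order of appearance.
    out, item_idx, in_run = [], 0, False
    for tok in input_ids:
        if tok == placeholder_id:
            out.append(pad_values[item_idx])
            in_run = True
        else:
            if in_run:
                item_idx += 1
                in_run = False
            out.append(tok)
    return out

# Two images, each represented by a run of placeholder tokens (id 99).
print(toy_pad_multimodal_tokens([1, 99, 99, 2, 99, 3], 99, [7001, 7002]))
# -> [1, 7001, 7001, 2, 7002, 3]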
sglang/srt/models/gemma3_causal.py
CHANGED
@@ -166,8 +166,7 @@ class Gemma3Attention(nn.Module):
             prefix=add_prefix("o_proj", prefix),
         )

-
-        self.is_sliding = bool((layer_id + 1) % config.sliding_window_pattern)
+        self.is_sliding = config.layer_types[layer_id] == "sliding_attention"

         # Initialize the rotary embedding.
         if self.is_sliding:
sglang/srt/models/gemma3n_causal.py
CHANGED
@@ -62,7 +62,7 @@ class Gemma3nTextScaledWordEmbedding(Gemma3TextScaledWordEmbedding):
     pass


-class
+class Gemma3nTextMLP(nn.Module):
     def __init__(
         self,
         hidden_size: int,
@@ -514,10 +514,11 @@ class Gemma3nDecoderLayer(nn.Module):
             prefix=add_prefix("self_attn", prefix),
         )

+        intermediate_size = config.intermediate_size[layer_id]
         activation_sparsity = config.activation_sparsity_pattern[layer_id]
-        self.mlp =
+        self.mlp = Gemma3nTextMLP(
             hidden_size=self.hidden_size,
-            intermediate_size=
+            intermediate_size=intermediate_size,
             hidden_activation=config.hidden_activation,
             activation_sparsity=activation_sparsity,
             quant_config=quant_config,
sglang/srt/models/gemma3n_mm.py
CHANGED
@@ -21,7 +21,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.managers.mm_utils import (
-
+    MultiModalityDataPaddingPatternMultimodalTokens,
     general_mm_embed_routine,
 )
 from sglang.srt.managers.schedule_batch import (
@@ -244,26 +244,11 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
     def pad_input_ids(
         self,
         input_ids: List[int],
-        mm_inputs:
+        mm_inputs: MultimodalInputs,
     ) -> List[int]:
         """Pad input IDs with image and audio tokens."""
-
-
-
-        # Collect available media token pairs
-        media_token_pairs = []
-        for attr_name in ["im_start_id", "audio_start_id"]:
-            if hasattr(mm_inputs, attr_name):
-                start_id = getattr(mm_inputs, attr_name)
-                end_id = getattr(mm_inputs, attr_name.replace("start", "end"))
-                media_token_pairs.append((start_id, end_id))
-
-        # Apply padding pattern if we have media tokens
-        if media_token_pairs:
-            pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
-            return pattern.pad_input_tokens(input_ids, mm_inputs)
-
-        return input_ids
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
+        return pattern.pad_input_tokens(input_ids, mm_inputs)

     def get_input_embeddings(self) -> nn.Embedding:
         return self.language_model.get_input_embeddings()
@@ -431,7 +416,6 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
         )

         positions += 1
-
         if input_ids is not None:
             # Prepare per-layer inputs from inputs_ids
             per_layer_inputs_mask = torch.logical_and(
sglang/srt/models/hunyuan.py
CHANGED
@@ -28,6 +28,7 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.eplb.expert_distribution import ExpertDistributionRecorder
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -48,7 +49,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
sglang/srt/models/kimi_vl.py
CHANGED
@@ -154,8 +154,7 @@ class KimiVLForConditionalGeneration(nn.Module):
         return res

     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
-
-        pattern = MultiModalityDataPaddingPatternMultimodalTokens(mm_inputs.im_token_id)
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
         return pattern.pad_input_tokens(input_ids, mm_inputs)

     def forward(
|