sglang 0.4.8__py3-none-any.whl → 0.4.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +17 -2
- sglang/bench_serving.py +168 -22
- sglang/srt/configs/internvl.py +4 -2
- sglang/srt/configs/janus_pro.py +1 -1
- sglang/srt/configs/model_config.py +49 -0
- sglang/srt/configs/update_config.py +119 -0
- sglang/srt/conversation.py +35 -0
- sglang/srt/custom_op.py +7 -1
- sglang/srt/disaggregation/base/conn.py +2 -0
- sglang/srt/disaggregation/decode.py +22 -6
- sglang/srt/disaggregation/mooncake/conn.py +289 -48
- sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
- sglang/srt/disaggregation/nixl/conn.py +100 -52
- sglang/srt/disaggregation/prefill.py +5 -4
- sglang/srt/disaggregation/utils.py +13 -12
- sglang/srt/distributed/parallel_state.py +44 -17
- sglang/srt/entrypoints/EngineBase.py +8 -0
- sglang/srt/entrypoints/engine.py +45 -9
- sglang/srt/entrypoints/http_server.py +111 -24
- sglang/srt/entrypoints/openai/protocol.py +51 -6
- sglang/srt/entrypoints/openai/serving_chat.py +52 -76
- sglang/srt/entrypoints/openai/serving_completions.py +1 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/eplb/__init__.py +0 -0
- sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
- sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
- sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
- sglang/srt/{managers → eplb}/expert_distribution.py +18 -1
- sglang/srt/{managers → eplb}/expert_location.py +1 -1
- sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
- sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
- sglang/srt/hf_transformers_utils.py +2 -1
- sglang/srt/layers/activation.py +7 -0
- sglang/srt/layers/amx_utils.py +86 -0
- sglang/srt/layers/attention/ascend_backend.py +219 -0
- sglang/srt/layers/attention/flashattention_backend.py +56 -23
- sglang/srt/layers/attention/tbo_backend.py +37 -9
- sglang/srt/layers/communicator.py +18 -2
- sglang/srt/layers/dp_attention.py +9 -3
- sglang/srt/layers/elementwise.py +76 -12
- sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
- sglang/srt/layers/layernorm.py +41 -0
- sglang/srt/layers/linear.py +99 -12
- sglang/srt/layers/logits_processor.py +15 -6
- sglang/srt/layers/moe/ep_moe/kernels.py +23 -8
- sglang/srt/layers/moe/ep_moe/layer.py +115 -25
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +42 -19
- sglang/srt/layers/moe/fused_moe_native.py +7 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +8 -4
- sglang/srt/layers/moe/fused_moe_triton/layer.py +129 -10
- sglang/srt/layers/moe/router.py +60 -22
- sglang/srt/layers/moe/topk.py +36 -28
- sglang/srt/layers/parameter.py +67 -7
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
- sglang/srt/layers/quantization/fp8.py +44 -0
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +6 -6
- sglang/srt/layers/quantization/gptq.py +5 -1
- sglang/srt/layers/quantization/moe_wna16.py +1 -1
- sglang/srt/layers/quantization/quant_utils.py +166 -0
- sglang/srt/layers/quantization/w8a8_int8.py +52 -1
- sglang/srt/layers/rotary_embedding.py +105 -13
- sglang/srt/layers/vocab_parallel_embedding.py +19 -2
- sglang/srt/lora/lora.py +4 -5
- sglang/srt/lora/lora_manager.py +73 -20
- sglang/srt/managers/configure_logging.py +1 -1
- sglang/srt/managers/io_struct.py +60 -15
- sglang/srt/managers/mm_utils.py +73 -59
- sglang/srt/managers/multimodal_processor.py +2 -6
- sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
- sglang/srt/managers/schedule_batch.py +80 -79
- sglang/srt/managers/scheduler.py +153 -63
- sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
- sglang/srt/managers/session_controller.py +12 -3
- sglang/srt/managers/tokenizer_manager.py +314 -103
- sglang/srt/managers/tp_worker.py +13 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
- sglang/srt/mem_cache/allocator.py +290 -0
- sglang/srt/mem_cache/chunk_cache.py +34 -2
- sglang/srt/mem_cache/memory_pool.py +289 -3
- sglang/srt/mem_cache/multimodal_cache.py +3 -0
- sglang/srt/model_executor/cuda_graph_runner.py +3 -2
- sglang/srt/model_executor/forward_batch_info.py +17 -4
- sglang/srt/model_executor/model_runner.py +302 -58
- sglang/srt/model_loader/loader.py +86 -10
- sglang/srt/model_loader/weight_utils.py +160 -3
- sglang/srt/models/deepseek_nextn.py +5 -4
- sglang/srt/models/deepseek_v2.py +305 -26
- sglang/srt/models/deepseek_vl2.py +3 -5
- sglang/srt/models/gemma3_causal.py +1 -2
- sglang/srt/models/gemma3n_audio.py +949 -0
- sglang/srt/models/gemma3n_causal.py +1010 -0
- sglang/srt/models/gemma3n_mm.py +495 -0
- sglang/srt/models/hunyuan.py +771 -0
- sglang/srt/models/kimi_vl.py +1 -2
- sglang/srt/models/llama.py +10 -4
- sglang/srt/models/llama4.py +32 -45
- sglang/srt/models/llama_eagle3.py +61 -11
- sglang/srt/models/llava.py +5 -5
- sglang/srt/models/minicpmo.py +2 -2
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mllama4.py +43 -11
- sglang/srt/models/phi4mm.py +1 -3
- sglang/srt/models/pixtral.py +3 -7
- sglang/srt/models/qwen2.py +31 -3
- sglang/srt/models/qwen2_5_vl.py +1 -3
- sglang/srt/models/qwen2_audio.py +200 -0
- sglang/srt/models/qwen2_moe.py +32 -6
- sglang/srt/models/qwen2_vl.py +1 -4
- sglang/srt/models/qwen3.py +94 -25
- sglang/srt/models/qwen3_moe.py +68 -21
- sglang/srt/models/vila.py +3 -8
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +150 -133
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
- sglang/srt/multimodal/processors/gemma3n.py +82 -0
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +3 -6
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
- sglang/srt/operations_strategy.py +6 -2
- sglang/srt/reasoning_parser.py +26 -0
- sglang/srt/sampling/sampling_batch_info.py +39 -1
- sglang/srt/server_args.py +85 -24
- sglang/srt/speculative/build_eagle_tree.py +57 -18
- sglang/srt/speculative/eagle_worker.py +6 -4
- sglang/srt/two_batch_overlap.py +204 -28
- sglang/srt/utils.py +369 -138
- sglang/srt/warmup.py +12 -3
- sglang/test/runners.py +10 -1
- sglang/test/test_utils.py +15 -3
- sglang/version.py +1 -1
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/METADATA +9 -6
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/RECORD +149 -137
- sglang/math_utils.py +0 -8
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
- /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
- /sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/WHEEL +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/top_level.txt +0 -0
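
Several expert-parallel load-balancing modules moved from sglang.srt.managers (and sglang.srt.model_executor) into the new sglang.srt.eplb package in this release, as the rename entries above show. Below is a minimal, hypothetical compatibility shim for downstream code that imported one of these symbols directly; the new path is taken from the deepseek_v2.py hunk further down, and the try/except fallback is illustrative, not part of sglang:

    # Hypothetical downstream shim for the managers -> eplb move (not shipped with sglang).
    try:
        # sglang 0.4.9 layout
        from sglang.srt.eplb.expert_distribution import (
            get_global_expert_distribution_recorder,
        )
    except ImportError:
        # sglang 0.4.8 layout
        from sglang.srt.managers.expert_distribution import (
            get_global_expert_distribution_recorder,
        )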
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -32,7 +32,11 @@ from sglang.srt.distributed import (
     parallel_state,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
+from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
+from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.amx_utils import PackWeightMethod
 from sglang.srt.layers.communicator import (
     LayerCommunicator,
     LayerScatterModes,
@@ -77,11 +81,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.expert_distribution import (
-    get_global_expert_distribution_recorder,
-)
-from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
-from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -97,12 +96,14 @@ from sglang.srt.utils import (
     bind_or_assign,
     cpu_has_amx_support,
     get_bool_env_var,
+    get_device_sm,
     get_int_env_var,
     is_cpu,
     is_cuda,
     is_hip,
     is_non_idle_and_non_empty,
     log_info_on_rank0,
+    use_intel_amx_backend,
 )
 
 _is_hip = is_hip()
@@ -111,9 +112,16 @@ _is_fp8_fnuz = is_fp8_fnuz()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
+_device_sm = get_device_sm()
 
 if _is_cuda:
-    from sgl_kernel import
+    from sgl_kernel import (
+        awq_dequantize,
+        bmm_fp8,
+        dsv3_fused_a_gemm,
+        dsv3_router_gemm,
+        merge_state_v2,
+    )
 elif _is_cpu and _is_cpu_amx_available:
     pass
 else:
@@ -124,8 +132,6 @@ if _is_hip:
         decode_attention_fwd_grouped_rope,
     )
 
-    if _use_aiter:
-        from aiter.rotary_embedding import get_rope
 
 logger = logging.getLogger(__name__)
 
@@ -144,6 +150,9 @@ class AttnForwardMethod(IntEnum):
     # Use MLA but with fused RoPE
     MLA_FUSED_ROPE = auto()
 
+    # Use MLA with fused RoPE kernel for CPU
+    MLA_FUSED_ROPE_CPU = auto()
+
 
 class DeepseekV2MLP(nn.Module):
     def __init__(
@@ -212,9 +221,31 @@ class MoEGate(nn.Module):
             )
         else:
             self.e_score_correction_bias = None
+        if _is_cpu and _is_cpu_amx_available:
+            self.quant_method = PackWeightMethod(weight_names=["weight"])
 
     def forward(self, hidden_states):
-
+        if use_intel_amx_backend(self):
+            return torch.ops.sgl_kernel.weight_packed_linear(
+                hidden_states,
+                self.weight,
+                None,  # bias
+                True,  # is_vnni
+            )
+
+        if (
+            _is_cuda
+            and hidden_states.shape[0] < 4
+            and hidden_states.shape[1] == 7168
+            and self.weight.shape[0] == 256
+            and _device_sm >= 90
+        ):
+            logits = dsv3_router_gemm(hidden_states, self.weight).to(
+                hidden_states.dtype
+            )
+        else:
+            logits = F.linear(hidden_states, self.weight, None)
+
         return logits
 
 
@@ -288,6 +319,9 @@ class DeepseekV2MoE(nn.Module):
             ),
         )
 
+        self.shared_experts_is_int8 = False
+        self.shared_experts_is_fp8 = False
+        self.shared_experts_weight_block_size = None
         if config.n_shared_experts is not None and self.num_fused_shared_experts == 0:
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
             # disable tp for shared experts when enable deepep moe
@@ -304,6 +338,28 @@ class DeepseekV2MoE(nn.Module):
                     else {}
                 ),
             )
+            is_packed_weight = hasattr(
+                self.shared_experts.gate_up_proj.quant_method, "quant_config"
+            ) and self.shared_experts.gate_up_proj.quant_method.quant_config.get_name() in {
+                "awq",
+                "moe_wna16",
+            }
+            self.shared_experts_is_int8 = (
+                not is_packed_weight
+                and self.shared_experts.gate_up_proj.weight.dtype == torch.int8
+            )
+            self.shared_experts_is_fp8 = (
+                not is_packed_weight
+                and self.shared_experts.gate_up_proj.weight.dtype == torch.float8_e4m3fn
+            )
+            if self.shared_experts_is_fp8:
+                assert (
+                    self.shared_experts.gate_up_proj.quant_method.quant_config.weight_block_size
+                    == self.shared_experts.down_proj.quant_method.quant_config.weight_block_size
+                )
+                self.shared_experts_weight_block_size = (
+                    self.shared_experts.gate_up_proj.quant_method.quant_config.weight_block_size
+                )
 
         self.top_k = config.num_experts_per_tok
 
@@ -382,13 +438,19 @@ class DeepseekV2MoE(nn.Module):
         return final_hidden_states
 
     def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        if hasattr(self, "shared_experts") and use_intel_amx_backend(
+            self.shared_experts.gate_up_proj
+        ):
+            return self.forward_cpu(hidden_states)
+
         shared_output = self._forward_shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
         final_hidden_states = self.experts(
            hidden_states=hidden_states, router_logits=router_logits
         )
-        if not _is_cuda:
+        if not _is_cuda and not _use_aiter:
+            # fused in biased_grouped_topk so we can skip here
            final_hidden_states *= self.routed_scaling_factor
         if shared_output is not None:
            final_hidden_states = final_hidden_states + shared_output
@@ -396,6 +458,59 @@ class DeepseekV2MoE(nn.Module):
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states
 
+    def forward_cpu(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        # router_logits: (num_tokens, n_experts)
+        router_logits = self.gate(hidden_states)
+        fused_experts_out = self.experts(
+            hidden_states=hidden_states, router_logits=router_logits
+        )
+
+        assert use_intel_amx_backend(
+            self.shared_experts.gate_up_proj
+        ) == use_intel_amx_backend(self.shared_experts.down_proj)
+        # [Note] inplace should be False in fused_experts.
+        # If inplace is True in fused_experts (self.experts), hidden_states will be changed after fused_experts
+        # While hidden_states is still needed in shared_expert.
+        final_hidden_states = torch.ops.sgl_kernel.shared_expert_cpu(
+            hidden_states,
+            self.shared_experts.gate_up_proj.weight,
+            self.shared_experts.down_proj.weight,
+            fused_experts_out,
+            self.routed_scaling_factor,
+            True,  # inplace
+            self.shared_experts_is_int8,  # use_int8_w8a8
+            self.shared_experts_is_fp8,  # use_fp8_w8a16
+            (
+                self.shared_experts.gate_up_proj.weight_scale
+                if self.shared_experts_is_int8
+                else (
+                    self.shared_experts.gate_up_proj.weight_scale_inv
+                    if self.shared_experts_is_fp8
+                    else None
+                )
+            ),  # w1_scale
+            (
+                self.shared_experts.down_proj.weight_scale
+                if self.shared_experts_is_int8
+                else (
+                    self.shared_experts.down_proj.weight_scale_inv
+                    if self.shared_experts_is_fp8
+                    else None
+                )
+            ),  # w2_scale
+            (
+                self.shared_experts_weight_block_size
+                if self.shared_experts_is_fp8
+                else None
+            ),  # block_size
+            None,  # a1_scale
+            None,  # a2_scale
+            True,  # is_vnni
+        )
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+        return final_hidden_states
+
     def forward_deepep(
         self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
     ) -> torch.Tensor:
@@ -443,7 +558,7 @@ class DeepseekV2MoE(nn.Module):
                 hidden_states=hidden_states,
                 topk_idx=topk_idx,
                 topk_weights=topk_weights,
-
+                forward_batch=forward_batch,
             )
         final_hidden_states = self.experts(
             hidden_states=hidden_states,
@@ -454,14 +569,14 @@ class DeepseekV2MoE(nn.Module):
             masked_m=masked_m,
             expected_m=expected_m,
             num_recv_tokens_per_expert=num_recv_tokens_per_expert,
-
+            forward_batch=forward_batch,
         )
         if self.ep_size > 1:
             final_hidden_states = self.deepep_dispatcher.combine(
                 hidden_states=final_hidden_states,
                 topk_idx=topk_idx,
                 topk_weights=topk_weights,
-
+                forward_batch=forward_batch,
             )
 
         if shared_output is not None:
@@ -536,7 +651,7 @@ class DeepseekV2MoE(nn.Module):
                 hidden_states=state.hidden_states_mlp_input,
                 topk_idx=state.pop("topk_idx_local"),
                 topk_weights=state.pop("topk_weights_local"),
-
+                forward_batch=state.forward_batch,
                 tbo_subbatch_index=state.get("tbo_subbatch_index"),
             )
 
@@ -568,7 +683,7 @@ class DeepseekV2MoE(nn.Module):
             masked_m=state.pop("masked_m"),
             expected_m=state.pop("expected_m"),
             num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
-
+            forward_batch=state.forward_batch,
         )
 
     def op_combine_a(self, state):
@@ -577,7 +692,7 @@ class DeepseekV2MoE(nn.Module):
                 hidden_states=state.pop("hidden_states_experts_output"),
                 topk_idx=state.pop("topk_idx_dispatched"),
                 topk_weights=state.pop("topk_weights_dispatched"),
-
+                forward_batch=state.forward_batch,
                 tbo_subbatch_index=state.get("tbo_subbatch_index"),
             )
 
@@ -777,6 +892,60 @@ class DeepseekV2AttentionMLA(nn.Module):
             "SGL_CHUNKED_PREFIX_CACHE_THRESHOLD", 8192
         )
 
+        # If we have self.fused_qkv_a_proj_with_mqa and we're running on CPU, we will choose the torch.ops.sgl_kernel.qkv_proj_with_rope_fused_weight kernel
+        # which requires self.w_kc and self.w_vc to be packed.
+        # If not, we will use torch.bmm and weight shouldn't be packed in this case
+        has_fused_proj = hasattr(self, "fused_qkv_a_proj_with_mqa")
+        if has_fused_proj and _is_cpu and _is_cpu_amx_available:
+            self.quant_method = PackWeightMethod(
+                weight_names=["w_kc", "w_vc"], transpose_dims=[[1, 2], [1, 2]]
+            )
+
+        is_packed_weight = (
+            has_fused_proj
+            and hasattr(self.fused_qkv_a_proj_with_mqa.quant_method, "quant_config")
+            and self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.get_name()
+            in {"awq", "moe_wna16"}
+        )
+        self.use_min_latency_fused_a_gemm = (
+            has_fused_proj
+            and not is_packed_weight
+            and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.bfloat16
+            and self.fused_qkv_a_proj_with_mqa.weight.shape[0] == 2112
+            and self.fused_qkv_a_proj_with_mqa.weight.shape[1] == 7168
+            and _is_cuda
+            and _device_sm >= 90
+        )
+
+        self.qkv_proj_with_rope_is_int8 = (
+            has_fused_proj
+            and not is_packed_weight
+            and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.int8
+        )
+        self.qkv_proj_with_rope_is_fp8 = (
+            has_fused_proj
+            and not is_packed_weight
+            and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.float8_e4m3fn
+        )
+
+        self.weight_block_size = None
+        if self.qkv_proj_with_rope_is_fp8 and _is_cpu and _is_cpu_amx_available:
+            assert getattr(
+                self.fused_qkv_a_proj_with_mqa.quant_method, "block_quant", False
+            ) == getattr(self.q_b_proj.quant_method, "block_quant", False)
+            use_block_quant = getattr(
+                self.fused_qkv_a_proj_with_mqa.quant_method, "block_quant", False
+            )
+
+            if use_block_quant:
+                assert (
+                    self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
+                    == self.q_b_proj.quant_method.quant_config.weight_block_size
+                )
+                self.weight_block_size = (
+                    self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
+                )
+
     def dispatch_attn_forward_method(
         self, forward_batch: ForwardBatch
     ) -> AttnForwardMethod:
@@ -790,9 +959,16 @@ class DeepseekV2AttentionMLA(nn.Module):
             else:
                 return AttnForwardMethod.MLA
         else:
-
+            if hasattr(self, "fused_qkv_a_proj_with_mqa") and use_intel_amx_backend(
+                self
+            ):
+                return AttnForwardMethod.MLA_FUSED_ROPE_CPU
+            else:
+                return AttnForwardMethod.MLA
 
-        if self.attention_backend == "
+        if self.attention_backend == "ascend":
+            return AttnForwardMethod.MLA
+        elif self.attention_backend == "flashinfer":
             # Flashinfer MLA: Do not absorb when enabling ragged prefill
             if (
                 not self.flashinfer_mla_disable_ragged
@@ -904,6 +1080,10 @@ class DeepseekV2AttentionMLA(nn.Module):
            inner_state = self.forward_absorb_fused_mla_rope_prepare(
                positions, hidden_states, forward_batch, zero_allocator
            )
+        elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE_CPU:
+            inner_state = self.forward_absorb_fused_mla_rope_cpu_prepare(
+                positions, hidden_states, forward_batch, zero_allocator
+            )
         else:
            raise NotImplementedError
         return None, attn_forward_method, forward_batch, inner_state
@@ -923,6 +1103,8 @@ class DeepseekV2AttentionMLA(nn.Module):
            return self.forward_absorb_core(*inner_state)
         elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
            return self.forward_absorb_fused_mla_rope_core(*inner_state)
+        elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE_CPU:
+            return self.forward_absorb_fused_mla_rope_cpu_core(*inner_state)
         else:
            raise NotImplementedError
 
@@ -986,7 +1168,13 @@ class DeepseekV2AttentionMLA(nn.Module):
         from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 
         if self.q_lora_rank is not None:
-
+            if hidden_states.shape[0] <= 16 and self.use_min_latency_fused_a_gemm:
+                fused_qkv_a_proj_out = dsv3_fused_a_gemm(
+                    hidden_states, self.fused_qkv_a_proj_with_mqa.weight.T
+                )
+            else:
+                fused_qkv_a_proj_out = self.fused_qkv_a_proj_with_mqa(hidden_states)[0]
+            q, latent_cache = fused_qkv_a_proj_out.split(
                 [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
             )
             k_nope = latent_cache[..., : self.kv_lora_rank]
@@ -1240,6 +1428,57 @@ class DeepseekV2AttentionMLA(nn.Module):
             zero_allocator,
         )
 
+    def forward_absorb_fused_mla_rope_cpu_prepare(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        zero_allocator: BumpAllocator,
+    ):
+        assert self.q_lora_rank is not None and use_intel_amx_backend(
+            self
+        ), "forward_absorb_fused_mla_rope_cpu_prepare requires q_lora_rank is not None and use_intel_amx_backend"
+
+        q_input, k_input, v_input = (
+            torch.ops.sgl_kernel.qkv_proj_with_rope_fused_weight(
+                hidden_states,
+                self.fused_qkv_a_proj_with_mqa.weight,
+                self.q_b_proj.weight,
+                self.w_kc,
+                self.q_a_layernorm.weight,
+                self.kv_a_layernorm.weight,
+                positions,
+                self.rotary_emb.cos_sin_cache,
+                self.kv_a_layernorm.variance_epsilon,
+                self.qkv_proj_with_rope_is_int8,
+                self.qkv_proj_with_rope_is_fp8,
+                (
+                    self.fused_qkv_a_proj_with_mqa.weight_scale
+                    if self.qkv_proj_with_rope_is_int8
+                    else (
+                        self.fused_qkv_a_proj_with_mqa.weight_scale_inv
+                        if self.qkv_proj_with_rope_is_fp8
+                        else None
+                    )
+                ),
+                (
+                    self.q_b_proj.weight_scale
+                    if self.qkv_proj_with_rope_is_int8
+                    else (
+                        self.q_b_proj.weight_scale_inv
+                        if self.qkv_proj_with_rope_is_fp8
+                        else None
+                    )
+                ),
+                True,  # is_vnni
+                self.weight_block_size,
+                self.q_lora_rank,
+                self.kv_lora_rank,
+                self.qk_rope_head_dim,
+            )
+        )
+        return (q_input, k_input, v_input, forward_batch, zero_allocator)
+
     def forward_absorb_fused_mla_rope_core(
         self,
         q_input,
@@ -1313,6 +1552,43 @@ class DeepseekV2AttentionMLA(nn.Module):
 
         return output
 
+    def forward_absorb_fused_mla_rope_cpu_core(
+        self, q_input, k_input, v_input, forward_batch, zero_allocator
+    ):
+        assert self.q_lora_rank is not None and use_intel_amx_backend(
+            self
+        ), "forward_absorb_fused_mla_rope_cpu_core requires q_lora_rank is not None and use_intel_amx_backend"
+
+        attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
+        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
+
+        # [Note] Align shapes of bmm inputs.
+        # Shapes of inputs:
+        # q_nope: [M, B, K]
+        # original self.w_kc: [B, K, N]
+        # current self.w_kc (which has been converted in PackWeightMethod): [B, N, K]
+
+        # Shapes of inputs to sgl_kernel.cpu.bmm:
+        # out: [B, M, N]
+        # mat1: [B, M, K]
+        # mat2: [B, N, K]
+        B = self.w_vc.size(0)
+        N = self.w_vc.size(1)
+        M = attn_output.size(0)
+        output = torch.empty([M, int(B * N)], dtype=attn_output.dtype)
+        attn_bmm_output = output.view([M, B, N]).transpose_(0, 1)
+        torch.ops.sgl_kernel.bmm_cpu(
+            attn_bmm_output,
+            attn_output.transpose(0, 1),
+            self.w_vc,
+            True,  # is_vnni
+            None,  # scale
+        )
+        attn_output = output
+        output, _ = self.o_proj(attn_output)
+
+        return output
+
     def _chunked_prefix_attn_mha(
         self,
         q: torch.Tensor,
@@ -1564,11 +1840,6 @@ class DeepseekV2DecoderLayer(nn.Module):
             hidden_states, residual, forward_batch
         )
 
-        if self.enable_dp_attention and self.speculative_algorithm.is_eagle():
-            # NOTE: this line resolves the degradation of MTP reception rate for non-zero DP ranks.
-            # See discussion here (https://github.com/sgl-project/sglang/pull/6081#discussion_r2147452251).
-            hidden_states = hidden_states.clone()
-
         return hidden_states, residual
 
     def op_comm_prepare_attn(
@@ -1610,7 +1881,7 @@ class DeepseekV2DecoderLayer(nn.Module):
            and hidden_states.shape[0] == 0
         ):
            state.hidden_states_mlp_output = self.mlp(
-                hidden_states, state.forward_batch
+                hidden_states, state.forward_batch
            )
         else:
            state.hidden_states_mlp_output = hidden_states
@@ -1659,7 +1930,7 @@ class DeepseekV2Model(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
-
+            use_attn_tp_group=True,
         )
         self.alt_stream = torch.cuda.Stream() if _is_cuda else None
         self.layers = nn.ModuleList(
@@ -1964,6 +2235,14 @@ class DeepseekV2ForCausalLM(nn.Module):
                )
                if _is_hip:
                    self_attn.w_scale *= 2.0
+                # TODO: remove this after adding FP8 support in bmm cpu kernel
+                if _is_cpu and _is_cpu_amx_available and w.dtype == torch.float8_e4m3fn:
+                    self_attn.w_kc = (
+                        self_attn.w_kc.to(torch.bfloat16) * self_attn.w_scale
+                    )
+                    self_attn.w_vc = (
+                        self_attn.w_vc.to(torch.bfloat16) * self_attn.w_scale
+                    )
            else:
                num_tiles_k = self_attn.qk_nope_head_dim // weight_block_size[1]
                num_tiles_n = self_attn.v_head_dim // weight_block_size[0]

sglang/srt/models/deepseek_vl2.py
CHANGED
@@ -253,11 +253,9 @@ class DeepseekVL2ForCausalLM(nn.Module):
         weights_loader = getattr(param, "weight_loader", default_weight_loader)
         weights_loader(param, loaded_weight)
 
-    def pad_input_ids(self, input_ids: List[int],
-
-
-    )
-        return helper.pad_input_tokens(input_ids, image_inputs)
+    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
+        return pattern.pad_input_tokens(input_ids, mm_inputs)
 
     def get_image_feature(self, items: List[MultimodalDataItem]):
 

sglang/srt/models/gemma3_causal.py
CHANGED
@@ -166,8 +166,7 @@ class Gemma3Attention(nn.Module):
            prefix=add_prefix("o_proj", prefix),
         )
 
-
-        self.is_sliding = bool((layer_id + 1) % config.sliding_window_pattern)
+        self.is_sliding = config.layer_types[layer_id] == "sliding_attention"
 
         # Initialize the rotary embedding.
         if self.is_sliding: