sglang 0.4.8.post1__py3-none-any.whl → 0.4.9.post1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- sglang/bench_one_batch_server.py +17 -2
- sglang/bench_serving.py +170 -24
- sglang/srt/configs/internvl.py +4 -2
- sglang/srt/configs/janus_pro.py +1 -1
- sglang/srt/configs/model_config.py +60 -1
- sglang/srt/configs/update_config.py +119 -0
- sglang/srt/conversation.py +69 -1
- sglang/srt/disaggregation/decode.py +21 -5
- sglang/srt/disaggregation/mooncake/conn.py +35 -4
- sglang/srt/disaggregation/nixl/conn.py +6 -6
- sglang/srt/disaggregation/prefill.py +2 -2
- sglang/srt/disaggregation/utils.py +1 -1
- sglang/srt/distributed/parallel_state.py +44 -17
- sglang/srt/entrypoints/EngineBase.py +8 -0
- sglang/srt/entrypoints/engine.py +40 -6
- sglang/srt/entrypoints/http_server.py +111 -24
- sglang/srt/entrypoints/http_server_engine.py +1 -1
- sglang/srt/entrypoints/openai/protocol.py +4 -2
- sglang/srt/eplb/__init__.py +0 -0
- sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
- sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
- sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
- sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
- sglang/srt/{managers → eplb}/expert_location.py +1 -1
- sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
- sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
- sglang/srt/hf_transformers_utils.py +2 -1
- sglang/srt/layers/activation.py +2 -2
- sglang/srt/layers/amx_utils.py +86 -0
- sglang/srt/layers/attention/ascend_backend.py +219 -0
- sglang/srt/layers/attention/flashattention_backend.py +32 -9
- sglang/srt/layers/attention/tbo_backend.py +37 -9
- sglang/srt/layers/communicator.py +20 -2
- sglang/srt/layers/dp_attention.py +9 -3
- sglang/srt/layers/elementwise.py +76 -12
- sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
- sglang/srt/layers/layernorm.py +26 -0
- sglang/srt/layers/linear.py +84 -14
- sglang/srt/layers/logits_processor.py +4 -4
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +81 -8
- sglang/srt/layers/moe/ep_moe/layer.py +176 -15
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +211 -74
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
- sglang/srt/layers/moe/router.py +60 -22
- sglang/srt/layers/moe/topk.py +10 -28
- sglang/srt/layers/parameter.py +67 -7
- sglang/srt/layers/quantization/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
- sglang/srt/layers/quantization/fp8.py +72 -7
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +1 -2
- sglang/srt/layers/quantization/gptq.py +5 -1
- sglang/srt/layers/quantization/modelopt_quant.py +244 -1
- sglang/srt/layers/quantization/moe_wna16.py +1 -1
- sglang/srt/layers/quantization/quant_utils.py +166 -0
- sglang/srt/layers/quantization/w4afp8.py +264 -0
- sglang/srt/layers/quantization/w8a8_int8.py +52 -1
- sglang/srt/layers/rotary_embedding.py +2 -2
- sglang/srt/layers/vocab_parallel_embedding.py +20 -10
- sglang/srt/lora/lora.py +4 -5
- sglang/srt/lora/lora_manager.py +73 -20
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
- sglang/srt/managers/cache_controller.py +41 -195
- sglang/srt/managers/configure_logging.py +1 -1
- sglang/srt/managers/io_struct.py +58 -14
- sglang/srt/managers/mm_utils.py +77 -61
- sglang/srt/managers/multimodal_processor.py +2 -6
- sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
- sglang/srt/managers/schedule_batch.py +78 -85
- sglang/srt/managers/scheduler.py +130 -64
- sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
- sglang/srt/managers/session_controller.py +12 -3
- sglang/srt/managers/tokenizer_manager.py +314 -103
- sglang/srt/managers/tp_worker.py +13 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
- sglang/srt/mem_cache/allocator.py +290 -0
- sglang/srt/mem_cache/chunk_cache.py +34 -2
- sglang/srt/mem_cache/hiradix_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +402 -66
- sglang/srt/mem_cache/memory_pool_host.py +6 -109
- sglang/srt/mem_cache/multimodal_cache.py +3 -0
- sglang/srt/mem_cache/radix_cache.py +8 -4
- sglang/srt/model_executor/cuda_graph_runner.py +2 -1
- sglang/srt/model_executor/forward_batch_info.py +17 -4
- sglang/srt/model_executor/model_runner.py +297 -56
- sglang/srt/model_loader/loader.py +41 -0
- sglang/srt/model_loader/weight_utils.py +72 -4
- sglang/srt/models/deepseek_nextn.py +1 -3
- sglang/srt/models/deepseek_v2.py +195 -45
- sglang/srt/models/deepseek_vl2.py +3 -5
- sglang/srt/models/gemma3_causal.py +1 -2
- sglang/srt/models/gemma3n_causal.py +4 -3
- sglang/srt/models/gemma3n_mm.py +4 -20
- sglang/srt/models/hunyuan.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -2
- sglang/srt/models/llama.py +10 -4
- sglang/srt/models/llama4.py +32 -45
- sglang/srt/models/llama_eagle3.py +61 -11
- sglang/srt/models/llava.py +5 -5
- sglang/srt/models/minicpmo.py +2 -2
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mllama4.py +402 -89
- sglang/srt/models/phi4mm.py +1 -3
- sglang/srt/models/pixtral.py +3 -7
- sglang/srt/models/qwen2.py +31 -3
- sglang/srt/models/qwen2_5_vl.py +1 -3
- sglang/srt/models/qwen2_audio.py +200 -0
- sglang/srt/models/qwen2_moe.py +32 -6
- sglang/srt/models/qwen2_vl.py +1 -4
- sglang/srt/models/qwen3.py +94 -25
- sglang/srt/models/qwen3_moe.py +68 -21
- sglang/srt/models/vila.py +3 -8
- sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +2 -2
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +65 -66
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
- sglang/srt/operations_strategy.py +6 -2
- sglang/srt/reasoning_parser.py +26 -0
- sglang/srt/sampling/sampling_batch_info.py +39 -1
- sglang/srt/server_args.py +84 -22
- sglang/srt/speculative/build_eagle_tree.py +57 -18
- sglang/srt/speculative/eagle_worker.py +6 -4
- sglang/srt/two_batch_overlap.py +203 -27
- sglang/srt/utils.py +343 -163
- sglang/srt/warmup.py +12 -3
- sglang/test/runners.py +10 -1
- sglang/test/test_cutlass_w4a8_moe.py +281 -0
- sglang/test/test_utils.py +15 -3
- sglang/utils.py +5 -5
- sglang/version.py +1 -1
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +12 -8
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +157 -146
- sglang/math_utils.py +0 -8
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
- /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/flashinfer_comm_fusion.py
ADDED
@@ -0,0 +1,202 @@
+import logging
+from typing import Tuple
+
+import torch
+import torch.distributed as dist
+
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.utils import is_flashinfer_available
+
+logger = logging.getLogger(__name__)
+
+_flashinfer_comm = None
+_workspace_manager = None
+
+if is_flashinfer_available():
+    try:
+        import flashinfer.comm as comm
+
+        _flashinfer_comm = comm
+    except ImportError:
+        logger.warning(
+            "flashinfer.comm is not available, falling back to standard "
+            "implementation"
+        )
+
+
+class FlashInferWorkspaceManager:
+    def __init__(self):
+        self.workspace_tensor = None
+        self.ipc_handles = None
+        self.world_size = None
+        self.rank = None
+        self.initialized = False
+
+    def initialize(
+        self,
+        world_size: int,
+        rank: int,
+        max_token_num: int,
+        hidden_dim: int,
+        group=None,
+        use_fp32_lamport: bool = False,
+    ):
+        """Initialize workspace"""
+        if self.initialized and self.world_size == world_size:
+            return
+
+        if _flashinfer_comm is None:
+            logger.warning(
+                "FlashInfer comm not available, skipping workspace " "initialization"
+            )
+            return
+
+        self.cleanup()
+
+        self.ipc_handles, self.workspace_tensor = (
+            comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
+                rank,
+                world_size,
+                max_token_num,
+                hidden_dim,
+                group=group,
+                use_fp32_lamport=use_fp32_lamport,
+            )
+        )
+
+        self.world_size = world_size
+        self.rank = rank
+        self.initialized = True
+
+        logger.info(
+            f"FlashInfer workspace initialized for rank {rank}, "
+            f"world_size {world_size}"
+        )
+
+    def cleanup(self):
+        """Clean up workspace"""
+        if self.initialized and self.ipc_handles is not None:
+            try:
+                _flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce(
+                    self.ipc_handles, group=dist.group.WORLD
+                )
+            except Exception as e:
+                logger.warning(f"Failed to cleanup FlashInfer workspace: {e}")
+            finally:
+                self.workspace_tensor = None
+                self.ipc_handles = None
+                self.initialized = False
+
+
+_workspace_manager = FlashInferWorkspaceManager()
+
+
+def ensure_workspace_initialized(
+    max_token_num: int = 128, hidden_dim: int = 4096, use_fp32_lamport: bool = False
+):
+    """Ensure workspace is initialized"""
+    if not is_flashinfer_available() or _flashinfer_comm is None:
+        return False
+
+    world_size = get_tensor_model_parallel_world_size()
+    if world_size <= 1:
+        return False
+
+    rank = dist.get_rank()
+
+    if (
+        not _workspace_manager.initialized
+        or _workspace_manager.world_size != world_size
+    ):
+        _workspace_manager.initialize(
+            world_size=world_size,
+            rank=rank,
+            max_token_num=max_token_num,
+            hidden_dim=hidden_dim,
+            use_fp32_lamport=use_fp32_lamport,
+        )
+
+    return _workspace_manager.initialized
+
+
+def flashinfer_allreduce_residual_rmsnorm(
+    input_tensor: torch.Tensor,
+    residual: torch.Tensor,
+    weight: torch.Tensor,
+    eps: float = 1e-6,
+    max_token_num: int = 128,
+    use_oneshot: bool = True,
+    trigger_completion_at_end: bool = False,
+    fp32_acc: bool = False,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Use FlashInfer's fused allreduce + residual + RMS norm operation
+
+    Args:
+        input_tensor: Input tensor that needs allreduce
+        residual: Residual tensor
+        weight: RMS norm weight
+        eps: RMS norm epsilon
+        max_token_num: Maximum token number
+        use_oneshot: Whether to use oneshot mode
+        trigger_completion_at_end: Whether to trigger completion at end
+        fp32_acc: Whether to use fp32 precision
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: (norm_output, residual_output)
+    """
+    if not is_flashinfer_available() or _flashinfer_comm is None:
+        logger.debug(
+            "FlashInfer not available, falling back to standard " "implementation"
+        )
+        return None, None
+
+    world_size = get_tensor_model_parallel_world_size()
+    if world_size <= 1:
+        logger.debug("Single GPU, no need for allreduce fusion")
+        return None, None
+
+    if not ensure_workspace_initialized(
+        max_token_num=max_token_num,
+        hidden_dim=input_tensor.shape[-1],
+        use_fp32_lamport=(input_tensor.dtype == torch.float32),
+    ):
+        logger.debug("FlashInfer workspace not available")
+        return None, None
+
+    token_num, hidden_dim = input_tensor.shape
+
+    residual_out = torch.empty_like(residual)
+    norm_out = torch.empty_like(input_tensor)
+
+    _flashinfer_comm.trtllm_allreduce_fusion(
+        allreduce_in=input_tensor,
+        world_size=world_size,
+        world_rank=dist.get_rank(),
+        token_num=token_num,
+        hidden_dim=hidden_dim,
+        workspace_ptrs=_workspace_manager.workspace_tensor,
+        launch_with_pdl=True,
+        use_oneshot=use_oneshot,
+        trigger_completion_at_end=trigger_completion_at_end,
+        fp32_acc=fp32_acc,
+        pattern_code=(_flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm),
+        allreduce_out=None,
+        residual_in=residual,
+        residual_out=residual_out,
+        norm_out=norm_out,
+        quant_out=None,
+        scale_out=None,
+        rms_gamma=weight,
+        rms_eps=eps,
+        scale_factor=None,
+        layout_code=None,
+    )
+
+    return norm_out, residual_out
+
+
+def cleanup_flashinfer_workspace():
+    global _workspace_manager
+    if _workspace_manager is not None:
+        _workspace_manager.cleanup()
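For orientation, here is a minimal sketch of how a caller can use the fused helper and fall back when it returns (None, None). The tensors, the process-group setup, and the unfused math are illustrative assumptions, not sglang code; in the package the actual caller is RMSNorm.forward_with_allreduce_fusion, shown in the next hunk.

# Hedged sketch (not sglang code): try the fused path, fall back to the
# unfused allreduce -> add-residual -> RMSNorm sequence when it is unavailable.
import torch
import torch.distributed as dist

from sglang.srt.layers.flashinfer_comm_fusion import (
    flashinfer_allreduce_residual_rmsnorm,
)


def fused_or_fallback(hidden_states, residual, rms_weight, eps=1e-6):
    norm_out, residual_out = flashinfer_allreduce_residual_rmsnorm(
        input_tensor=hidden_states, residual=residual, weight=rms_weight, eps=eps
    )
    if norm_out is None:
        # Fusion unavailable (no flashinfer.comm, single GPU, or no workspace):
        # do the three steps separately. Assumes an initialized process group.
        dist.all_reduce(hidden_states)
        residual_out = hidden_states + residual
        variance = residual_out.float().pow(2).mean(dim=-1, keepdim=True)
        norm_out = (
            residual_out.float() * torch.rsqrt(variance + eps)
        ).to(hidden_states.dtype) * rms_weight
    return norm_out, residual_out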
sglang/srt/layers/layernorm.py
CHANGED
@@ -163,6 +163,32 @@ class RMSNorm(CustomOp):
         else:
             return self.forward_native(x, residual)
 
+    def forward_with_allreduce_fusion(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        """
+        Forward method with allreduce fusion, prioritizing flashinfer fused operations
+        """
+        if residual is not None:
+            from sglang.srt.distributed import get_tensor_model_parallel_world_size
+            from sglang.srt.layers.flashinfer_comm_fusion import (
+                flashinfer_allreduce_residual_rmsnorm,
+            )
+
+            if get_tensor_model_parallel_world_size() > 1:
+                fused_result = flashinfer_allreduce_residual_rmsnorm(
+                    input_tensor=x,
+                    residual=residual,
+                    weight=self.weight,
+                    eps=self.variance_epsilon,
+                )
+                if fused_result[0] is not None:
+                    return fused_result
+
+        return self.forward(x, residual)
+
 
 class GemmaRMSNorm(CustomOp):
     def __init__(
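A short usage sketch for the new method, assuming a plain RMSNorm module and illustrative shapes; the surrounding layer wiring is hypothetical, and the method itself falls back to self.forward when fusion cannot be used.

# Hedged sketch (not sglang's actual call site): route a decoder block's
# post-attention normalization through the fused allreduce + residual + RMSNorm path.
import torch

from sglang.srt.layers.layernorm import RMSNorm

hidden_size = 4096
post_attention_layernorm = RMSNorm(hidden_size, eps=1e-6)


def normalize_after_attention(hidden_states: torch.Tensor, residual: torch.Tensor):
    # hidden_states holds this rank's partial sum before the tensor-parallel
    # all-reduce; the method fuses allreduce + residual add + RMSNorm when it can.
    return post_attention_layernorm.forward_with_allreduce_fusion(
        hidden_states, residual
    )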
sglang/srt/layers/linear.py
CHANGED
@@ -17,6 +17,7 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_gather,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
 from sglang.srt.layers.parameter import (
     BasevLLMParameter,
     BlockQuantScaleParameter,
@@ -31,10 +32,10 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.utils import (
-    _process_weight_after_loading,
     cpu_has_amx_support,
     is_cpu,
     set_weight_attrs,
+    use_intel_amx_backend,
 )
 
 logger = logging.getLogger(__name__)
@@ -175,7 +176,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         if _is_cpu and _is_cpu_amx_available:
-
+            _amx_process_weight_after_loading(layer, ["weight"])
 
     def apply(
         self,
@@ -184,7 +185,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
 
-        if
+        if use_intel_amx_backend(layer):
             return torch.ops.sgl_kernel.weight_packed_linear(
                 x, layer.weight, bias, True  # is_vnni
             )
@@ -425,8 +426,26 @@ class ColumnParallelLinear(LinearBase):
         if output_dim is not None and not use_bitsandbytes_4bit:
             shard_size = param_data.shape[output_dim]
             start_idx = self.tp_rank * shard_size
-
-
+
+            if _is_cpu:
+                from sglang.srt.model_loader.weight_utils import (
+                    narrow_padded_param_and_loaded_weight,
+                )
+
+                param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                    param_data,
+                    loaded_weight,
+                    0,  # param_data_start
+                    start_idx,
+                    output_dim,
+                    shard_size,
+                    not self.use_presharded_weights,
+                )
+            else:
+                if not self.use_presharded_weights:
+                    loaded_weight = loaded_weight.narrow(
+                        output_dim, start_idx, shard_size
+                    )
 
         # Special case for loading scales off disk, which often do not
         # have a shape (such as in the case of AutoFP8).
@@ -643,10 +662,29 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
 
             param_data = param_data.narrow(output_dim, shard_offset, shard_size)
             start_idx = self.tp_rank * shard_size
-
-
-
-
+
+            if _is_cpu:
+                from sglang.srt.model_loader.weight_utils import (
+                    narrow_padded_param_and_loaded_weight,
+                )
+
+                param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                    param_data,
+                    loaded_weight,
+                    0,  # param_data_start
+                    start_idx,
+                    output_dim,
+                    shard_size,
+                    not use_bitsandbytes_4bit and not self.use_presharded_weights,
+                )
+            else:
+                # bitsandbytes loads the weights of the specific portion
+                # no need to narrow here
+                if not use_bitsandbytes_4bit and not self.use_presharded_weights:
+                    loaded_weight = loaded_weight.narrow(
+                        output_dim, start_idx, shard_size
+                    )
+
         # Special case for AQLM codebooks.
         elif is_metadata:
             # metadata indicates fixed size concatenated along dim 0
@@ -1111,10 +1149,27 @@ class QKVParallelLinear(ColumnParallelLinear):
                 shard_id = self.tp_rank // self.num_kv_head_replicas
             start_idx = shard_id * shard_size
 
-
-
-
-
+            if _is_cpu:
+                from sglang.srt.model_loader.weight_utils import (
+                    narrow_padded_param_and_loaded_weight,
+                )
+
+                param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                    param_data,
+                    loaded_weight,
+                    0,  # param_data_start
+                    start_idx,
+                    output_dim,
+                    shard_size,
+                    not use_bitsandbytes_4bit and not self.use_presharded_weights,
+                )
+            else:
+                # bitsandbytes loads the weights of the specific portion
+                # no need to narrow here
+                if not use_bitsandbytes_4bit and not self.use_presharded_weights:
+                    loaded_weight = loaded_weight.narrow(
+                        output_dim, start_idx, shard_size
+                    )
 
         # Special case for for AQLM codebooks.
         elif is_metadata:
@@ -1256,7 +1311,22 @@ class RowParallelLinear(LinearBase):
         ):
             shard_size = param_data.shape[input_dim]
             start_idx = self.tp_rank * shard_size
-
+
+            if _is_cpu:
+                from sglang.srt.model_loader.weight_utils import (
+                    narrow_padded_param_and_loaded_weight,
+                )
+
+                param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                    param_data,
+                    loaded_weight,
+                    0,  # param_data_start
+                    start_idx,
+                    input_dim,
+                    shard_size,
+                )
+            else:
+                loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)
 
         # Special case for loading scales off disk, which often do not
         # have a shape (such as in the case of AutoFP8).
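All four weight-loader hunks follow the same shape: on CPU (`_is_cpu`) a shard may be padded, so `narrow_padded_param_and_loaded_weight` replaces the plain `Tensor.narrow`; otherwise the original narrowing is kept. A self-contained sketch of that non-padded branch with toy sizes (no sglang imports; names are illustrative):

# Hedged sketch: slice a full weight into one tensor-parallel rank's shard,
# the way the else: branches above do with Tensor.narrow.
import torch


def shard_for_rank(full_weight: torch.Tensor, tp_rank: int, tp_size: int,
                   output_dim: int = 0) -> torch.Tensor:
    shard_size = full_weight.shape[output_dim] // tp_size
    start_idx = tp_rank * shard_size
    return full_weight.narrow(output_dim, start_idx, shard_size)


full_weight = torch.arange(32.0).reshape(8, 4)  # toy output-parallel projection
print(shard_for_rank(full_weight, tp_rank=1, tp_size=2).shape)  # torch.Size([4, 4])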
sglang/srt/layers/logits_processor.py
CHANGED
@@ -42,7 +42,7 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,
 )
-from sglang.srt.utils import dump_to_file
+from sglang.srt.utils import dump_to_file, use_intel_amx_backend
 
 logger = logging.getLogger(__name__)
 
@@ -436,13 +436,13 @@ class LogitsProcessor(nn.Module):
         if self.do_tensor_parallel_all_gather_dp_attn:
             logits_metadata.compute_dp_attention_metadata(hidden_states)
             hidden_states, local_hidden_states = (
-                logits_metadata.gathered_buffer,
-                hidden_states
+                torch.empty_like(logits_metadata.gathered_buffer),
+                hidden_states,
             )
             dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
 
         if hasattr(lm_head, "weight"):
-            if
+            if use_intel_amx_backend(lm_head):
                 logits = torch.ops.sgl_kernel.weight_packed_linear(
                     hidden_states.to(lm_head.weight.dtype),
                     lm_head.weight,
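For context, the branch guarded by use_intel_amx_backend specializes the final lm_head projection. A toy sketch of that projection as an ordinary matmul; the sizes and the plain-matmul reading of the generic path are my assumptions, not taken from the diff:

# Hedged sketch: the dense logits projection that the AMX branch accelerates,
# written as a plain matmul with toy sizes.
import torch

vocab_size, hidden_size, num_tokens = 1000, 512, 4
lm_head_weight = torch.randn(vocab_size, hidden_size)
hidden_states = torch.randn(num_tokens, hidden_size)

logits = torch.matmul(hidden_states, lm_head_weight.T)  # [num_tokens, vocab_size]
print(logits.shape)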
sglang/srt/layers/moe/cutlass_w4a8_moe.py
ADDED
@@ -0,0 +1,215 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Cutlass W4A8 MoE kernel."""
+from typing import Optional
+
+import torch
+from sgl_kernel import (
+    cutlass_w4a8_moe_mm,
+    get_cutlass_w4a8_moe_mm_data,
+    sgl_per_tensor_quant_fp8,
+    silu_and_mul,
+)
+
+from sglang.srt.layers.moe.ep_moe.kernels import (
+    post_reorder_triton_kernel,
+    pre_reorder_triton_kernel_for_cutlass_moe,
+    run_cutlass_moe_ep_preproess,
+)
+
+
+def cutlass_w4a8_moe(
+    start_expert_id: int,
+    end_expert_id: int,
+    total_num_experts: int,
+    a: torch.Tensor,
+    w1_q: torch.Tensor,
+    w2_q: torch.Tensor,
+    w1_scale: torch.Tensor,
+    w2_scale: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids_: torch.Tensor,
+    local_topk_ids: torch.Tensor,
+    a_strides1: torch.Tensor,
+    b_strides1: torch.Tensor,
+    c_strides1: torch.Tensor,
+    a_strides2: torch.Tensor,
+    b_strides2: torch.Tensor,
+    c_strides2: torch.Tensor,
+    s_strides13: torch.Tensor,
+    s_strides2: torch.Tensor,
+    expert_offsets: torch.Tensor,
+    problem_sizes1: torch.Tensor,
+    problem_sizes2: torch.Tensor,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+    apply_router_weight_on_input: bool = False,
+) -> torch.Tensor:
+    """
+    This function computes a w4a8-quantized Mixture of Experts (MoE) layer
+    using two sets of quantized weights, w1_q and w2_q, and top-k gating
+    mechanism. The matrix multiplications are implemented with CUTLASS
+    grouped gemm.
+
+    Parameters:
+    - a (torch.Tensor): The input tensor to the MoE layer.
+        Shape: [M, K]
+    - w1_q (torch.Tensor): The first set of int4-quantized expert weights.
+        Shape: [num_experts, N * 2, K // 2]
+        (the weights are passed transposed and int4-packed)
+    - w2_q (torch.Tensor): The second set of int4-quantized expert weights.
+        Shape: [num_experts, K, N // 2]
+        (the weights are passed transposed and int4-packed)
+    - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
+        Shape: [num_experts, K // 512, N * 8]
+    - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
+        Shape: [num_experts, N // 512, K * 4]
+    - topk_weights (torch.Tensor): The weights of each token->expert mapping.
+    - a_strides1 (torch.Tensor): The input strides of the first grouped gemm.
+    - b_strides1 (torch.Tensor): The weights strides of the first grouped gemm.
+    - c_strides1 (torch.Tensor): The output strides of the first grouped gemm.
+    - a_strides2 (torch.Tensor): The input strides of the second grouped gemm.
+    - b_strides2 (torch.Tensor): The weights strides of the second grouped gemm.
+    - c_strides2 (torch.Tensor): The output strides of the second grouped gemm.
+    - s_strides13 (torch.Tensor): The input and scale strides of the first grouped gemm.
+    - s_strides2 (torch.Tensor): The scale strides of the second grouped gemm.
+    - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
+        Shape: scalar or [1, K]
+    - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
+        quantize the intermediate result between the gemms.
+        Shape: scalar or [1, N]
+    - apply_router_weight_on_input (bool): When true, the topk weights are
+        applied directly on the inputs. This is only applicable when topk is 1.
+
+    Returns:
+    - torch.Tensor: The fp8 output tensor after applying the MoE layer.
+    """
+    assert topk_weights.shape == topk_ids_.shape, "topk shape mismatch"
+    assert w1_q.dtype == torch.int8
+    assert w2_q.dtype == torch.int8
+    assert a.shape[1] // 2 == w1_q.shape[2], "Hidden size mismatch w1"
+    assert w1_q.shape[2] * 2 == w2_q.shape[1], "Hidden size mismatch w2"
+    assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
+    assert w1_q.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
+    assert w1_q.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
+    assert (
+        w1_scale.shape[1] == w1_q.shape[2] * 2 / 512
+        and w1_scale.shape[2] == w1_q.shape[1] * 4
+    ), "W1 scale shape mismatch"
+    assert (
+        w2_scale.shape[1] == w2_q.shape[2] * 2 / 512
+        and w2_scale.shape[2] == w2_q.shape[1] * 4
+    ), "W2 scale shape mismatch"
+
+    assert a_strides1.shape[0] == w1_q.shape[0], "A Strides 1 expert number mismatch"
+    assert b_strides1.shape[0] == w1_q.shape[0], "B Strides 1 expert number mismatch"
+    assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number mismatch"
+    assert b_strides2.shape[0] == w2_q.shape[0], "B Strides 2 expert number mismatch"
+    num_experts = w1_q.size(0)
+    m = a.size(0)
+    k = w1_q.size(2) * 2  # w1_q is transposed and packed
+    n = w2_q.size(2) * 2  # w2_q is transposed and packed
+    topk = topk_ids_.size(1)
+
+    if apply_router_weight_on_input:
+        assert topk == 1, "apply_router_weight_on_input is only implemented for topk=1"
+
+    device = a.device
+
+    _, src2dst, _ = run_cutlass_moe_ep_preproess(
+        local_topk_ids,
+        num_experts,
+    )
+
+    gateup_input = torch.empty(
+        (m * topk, k),
+        device=device,
+        dtype=torch.float8_e4m3fn,
+    )
+
+    pre_reorder_triton_kernel_for_cutlass_moe[(m,)](
+        a,
+        gateup_input,
+        src2dst,
+        local_topk_ids,
+        a1_scale,
+        total_num_experts,
+        topk,
+        k,
+        BLOCK_SIZE=512,
+    )
+
+    # NOTE: a_map and c_map are not used in the get_cutlass_w4a8_moe_mm_data kernel,
+    # they are kept to allow for a quick switch of the permutation logic
+    # from the current triton kernel implementation to the cutlass-based one if needed.
+    a_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device)
+    c_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device)
+    get_cutlass_w4a8_moe_mm_data(
+        local_topk_ids,
+        expert_offsets,
+        problem_sizes1,
+        problem_sizes2,
+        a_map,
+        c_map,
+        num_experts,
+        n,
+        k,
+    )
+
+    c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.half)
+    c2 = torch.zeros((m * topk, k), device=device, dtype=torch.half)
+
+    cutlass_w4a8_moe_mm(
+        c1,
+        gateup_input,
+        w1_q,
+        a1_scale.float(),
+        w1_scale,
+        expert_offsets[:-1],
+        problem_sizes1,
+        a_strides1,
+        b_strides1,
+        c_strides1,
+        s_strides13,
+        128,
+        topk,
+    )
+
+    intermediate = torch.empty((m * topk, n), device=device, dtype=torch.half)
+    silu_and_mul(c1, intermediate)
+
+    intermediate_q = torch.empty(
+        intermediate.shape, dtype=torch.float8_e4m3fn, device=device
+    )
+    sgl_per_tensor_quant_fp8(intermediate, intermediate_q, a2_scale.float(), True)
+
+    cutlass_w4a8_moe_mm(
+        c2,
+        intermediate_q,
+        w2_q,
+        a2_scale.float(),
+        w2_scale,
+        expert_offsets[:-1],
+        problem_sizes2,
+        a_strides2,
+        b_strides2,
+        c_strides2,
+        s_strides2,
+        128,
+        topk,
+    )
+
+    output = torch.empty_like(a)
+    post_reorder_triton_kernel[(m,)](
+        c2,
+        output,
+        src2dst,
+        topk_ids_,
+        topk_weights,
+        start_expert_id,
+        end_expert_id,
+        topk,
+        k,
+        0,
+        BLOCK_SIZE=512,
+    )
+    return output
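The docstring above fixes the expected tensor shapes for the packed int4 weights and their group scales. A quick CPU-only sketch that instantiates toy tensors with those shapes and re-checks the same relationships the function asserts; it does not call the CUTLASS kernel, and the sizes are illustrative:

# Hedged sketch: verify the documented w4a8 shape relationships with toy sizes.
import torch

num_experts, K, N = 4, 1024, 512          # K, N chosen as multiples of 512 here
w1_q = torch.zeros(num_experts, N * 2, K // 2, dtype=torch.int8)  # int4-packed
w2_q = torch.zeros(num_experts, K, N // 2, dtype=torch.int8)      # int4-packed
w1_scale = torch.zeros(num_experts, K // 512, N * 8)
w2_scale = torch.zeros(num_experts, N // 512, K * 4)
a = torch.zeros(8, K)                                              # [M, K]

assert a.shape[1] // 2 == w1_q.shape[2]            # hidden size matches w1
assert w1_q.shape[2] * 2 == w2_q.shape[1]          # hidden size matches w2
assert w1_scale.shape[1] == w1_q.shape[2] * 2 / 512 and w1_scale.shape[2] == w1_q.shape[1] * 4
assert w2_scale.shape[1] == w2_q.shape[2] * 2 / 512 and w2_scale.shape[2] == w2_q.shape[1] * 4
print("shapes consistent")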