sglang 0.4.9.post3__py3-none-any.whl → 0.4.9.post5__py3-none-any.whl
This diff represents the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
- sglang/lang/chat_template.py +21 -0
- sglang/srt/_custom_ops.py +29 -1
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/model_config.py +5 -1
- sglang/srt/constrained/base_grammar_backend.py +10 -2
- sglang/srt/constrained/xgrammar_backend.py +7 -5
- sglang/srt/conversation.py +17 -2
- sglang/srt/debug_utils/__init__.py +0 -0
- sglang/srt/debug_utils/dump_comparator.py +131 -0
- sglang/srt/debug_utils/dumper.py +108 -0
- sglang/srt/debug_utils/text_comparator.py +172 -0
- sglang/srt/disaggregation/common/conn.py +34 -6
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +13 -1
- sglang/srt/disaggregation/mini_lb.py +3 -2
- sglang/srt/disaggregation/mooncake/conn.py +65 -20
- sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
- sglang/srt/disaggregation/nixl/conn.py +17 -13
- sglang/srt/disaggregation/prefill.py +13 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
- sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
- sglang/srt/distributed/parallel_state.py +70 -15
- sglang/srt/entrypoints/engine.py +5 -9
- sglang/srt/entrypoints/http_server.py +20 -32
- sglang/srt/entrypoints/openai/protocol.py +3 -3
- sglang/srt/entrypoints/openai/serving_chat.py +148 -72
- sglang/srt/function_call/base_format_detector.py +74 -12
- sglang/srt/function_call/deepseekv3_detector.py +26 -11
- sglang/srt/function_call/ebnf_composer.py +105 -66
- sglang/srt/function_call/function_call_parser.py +6 -4
- sglang/srt/function_call/glm4_moe_detector.py +164 -0
- sglang/srt/function_call/kimik2_detector.py +41 -16
- sglang/srt/function_call/llama32_detector.py +6 -3
- sglang/srt/function_call/mistral_detector.py +11 -3
- sglang/srt/function_call/pythonic_detector.py +16 -14
- sglang/srt/function_call/qwen25_detector.py +12 -3
- sglang/srt/function_call/{qwen3_detector.py → qwen3_coder_detector.py} +11 -9
- sglang/srt/layers/activation.py +11 -3
- sglang/srt/layers/attention/base_attn_backend.py +3 -1
- sglang/srt/layers/attention/hybrid_attn_backend.py +100 -0
- sglang/srt/layers/attention/vision.py +56 -8
- sglang/srt/layers/communicator.py +12 -12
- sglang/srt/layers/dp_attention.py +72 -24
- sglang/srt/layers/layernorm.py +26 -1
- sglang/srt/layers/logits_processor.py +46 -25
- sglang/srt/layers/moe/ep_moe/layer.py +172 -206
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +25 -224
- sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -48
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +11 -8
- sglang/srt/layers/moe/topk.py +88 -34
- sglang/srt/layers/multimodal.py +11 -8
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -9
- sglang/srt/layers/quantization/fp8.py +25 -247
- sglang/srt/layers/quantization/fp8_kernel.py +78 -48
- sglang/srt/layers/quantization/modelopt_quant.py +33 -14
- sglang/srt/layers/quantization/unquant.py +24 -76
- sglang/srt/layers/quantization/utils.py +0 -9
- sglang/srt/layers/quantization/w4afp8.py +68 -17
- sglang/srt/layers/radix_attention.py +5 -3
- sglang/srt/lora/lora_manager.py +133 -169
- sglang/srt/lora/lora_registry.py +188 -0
- sglang/srt/lora/mem_pool.py +2 -2
- sglang/srt/managers/cache_controller.py +62 -13
- sglang/srt/managers/io_struct.py +19 -1
- sglang/srt/managers/mm_utils.py +154 -35
- sglang/srt/managers/multimodal_processor.py +3 -14
- sglang/srt/managers/schedule_batch.py +27 -11
- sglang/srt/managers/scheduler.py +48 -26
- sglang/srt/managers/tokenizer_manager.py +62 -28
- sglang/srt/managers/tp_worker.py +5 -4
- sglang/srt/mem_cache/allocator.py +67 -7
- sglang/srt/mem_cache/hicache_storage.py +17 -1
- sglang/srt/mem_cache/hiradix_cache.py +35 -18
- sglang/srt/mem_cache/memory_pool_host.py +3 -0
- sglang/srt/model_executor/cuda_graph_runner.py +61 -25
- sglang/srt/model_executor/forward_batch_info.py +201 -29
- sglang/srt/model_executor/model_runner.py +109 -37
- sglang/srt/models/deepseek_v2.py +63 -30
- sglang/srt/models/glm4_moe.py +1035 -0
- sglang/srt/models/glm4_moe_nextn.py +167 -0
- sglang/srt/models/interns1.py +328 -0
- sglang/srt/models/internvl.py +143 -47
- sglang/srt/models/llava.py +9 -5
- sglang/srt/models/minicpmo.py +4 -1
- sglang/srt/models/mllama4.py +10 -3
- sglang/srt/models/qwen2_moe.py +2 -6
- sglang/srt/models/qwen3_moe.py +6 -8
- sglang/srt/multimodal/processors/base_processor.py +20 -6
- sglang/srt/multimodal/processors/clip.py +2 -2
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +2 -2
- sglang/srt/multimodal/processors/gemma3.py +2 -2
- sglang/srt/multimodal/processors/gemma3n.py +2 -2
- sglang/srt/multimodal/processors/internvl.py +21 -8
- sglang/srt/multimodal/processors/janus_pro.py +2 -2
- sglang/srt/multimodal/processors/kimi_vl.py +2 -2
- sglang/srt/multimodal/processors/llava.py +4 -4
- sglang/srt/multimodal/processors/minicpm.py +2 -3
- sglang/srt/multimodal/processors/mlama.py +2 -2
- sglang/srt/multimodal/processors/mllama4.py +18 -111
- sglang/srt/multimodal/processors/phi4mm.py +2 -2
- sglang/srt/multimodal/processors/pixtral.py +2 -2
- sglang/srt/multimodal/processors/qwen_audio.py +2 -2
- sglang/srt/multimodal/processors/qwen_vl.py +2 -2
- sglang/srt/multimodal/processors/vila.py +3 -1
- sglang/srt/reasoning_parser.py +48 -5
- sglang/srt/sampling/sampling_batch_info.py +6 -5
- sglang/srt/server_args.py +132 -60
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +37 -36
- sglang/srt/speculative/eagle_utils.py +51 -23
- sglang/srt/speculative/eagle_worker.py +59 -44
- sglang/srt/two_batch_overlap.py +9 -5
- sglang/srt/utils.py +113 -69
- sglang/srt/weight_sync/utils.py +119 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_activation.py +50 -1
- sglang/test/test_utils.py +65 -5
- sglang/utils.py +19 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/METADATA +6 -6
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/RECORD +127 -114
- sglang/srt/debug_utils.py +0 -74
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/top_level.txt +0 -0
**sglang/srt/layers/quantization/unquant.py**

```diff
@@ -24,6 +24,7 @@ from sglang.srt.utils import (
 )
 
 if TYPE_CHECKING:
+    from sglang.srt.layers.moe.ep_moe.layer import EPMoE
     from sglang.srt.layers.moe.topk import TopKOutput
 
 has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
```
```diff
@@ -129,6 +130,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         super().__init__()
         self.use_triton_kernels = use_triton_kernels
 
+        self.triton_kernel_moe_forward = None
+        if torch.cuda.is_available() and has_triton_kernels:
+            from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
+                triton_kernel_moe_forward as _tk_forward,
+            )
+
+            self.triton_kernel_moe_forward = _tk_forward
+
     def create_weights(
         self,
         layer: torch.nn.Module,
```
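The hunk above resolves the optional `triton_kernels` dependency once in `__init__` and caches the callable, rather than importing inside the hot forward path. A minimal, self-contained sketch of the same probe-and-cache pattern (module and attribute names below are illustrative, not sglang's):

```python
import importlib
import importlib.util

# Probe for the optional package once, at import time (same check the diff uses).
HAS_TRITON_KERNELS = importlib.util.find_spec("triton_kernels") is not None


class FusedForward:
    """Caches an optional fast path at construction and falls back otherwise."""

    def __init__(self) -> None:
        self.fast_forward = None
        if HAS_TRITON_KERNELS:
            # Deferred import: only runs when the package is actually present,
            # so this module still imports cleanly without the dependency.
            mod = importlib.import_module("triton_kernels")
            # "fused_forward" is a hypothetical attribute used for illustration.
            self.fast_forward = getattr(mod, "fused_forward", None)

    def __call__(self, x):
        if self.fast_forward is not None:
            return self.fast_forward(x)
        return x * 2  # stand-in for the non-Triton fallback path
```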
```diff
@@ -194,6 +203,15 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
+
+        from sglang.srt.layers.moe.ep_moe.layer import EPMoE
+
+        if isinstance(layer, EPMoE):
+            return layer.run_moe(
+                hidden_states=x,
+                topk_output=topk_output,
+            )
+
         return self.forward(
             x=x,
             layer=layer,
```
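`apply` imports `EPMoE` inside the function body rather than at module top level; together with the `TYPE_CHECKING` import added in the first hunk, this is the usual way to keep a runtime `isinstance` dispatch while avoiding a circular import at module load time. A self-contained sketch of the dispatch itself (both classes are stand-ins):

```python
class FusedMoELayer:
    """Stand-in for the generic fused MoE layer."""

    def forward(self, x: float) -> float:
        return x


class EPMoE(FusedMoELayer):
    """Stand-in for the expert-parallel variant with its own fused path."""

    def run_moe(self, x: float) -> float:
        return x * 2.0


def apply(layer: FusedMoELayer, x: float) -> float:
    # In the real code the EPMoE import sits here, inside the function, so it
    # is only evaluated at call time; that is what breaks the import cycle.
    if isinstance(layer, EPMoE):
        return layer.run_moe(x)  # expert-parallel path
    return layer.forward(x)  # default fused path


assert apply(EPMoE(), 3.0) == 6.0          # dispatched to run_moe
assert apply(FusedMoELayer(), 3.0) == 3.0  # default forward
```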
```diff
@@ -219,16 +237,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     ) -> torch.Tensor:
 
         if self.use_triton_kernels:
-            return triton_kernel_moe_forward(
-                hidden_states=x,
-                w1=layer.w13_weight,
-                w2=layer.w2_weight,
-                topk_output=topk_output,
-            )
-            # gating_output=router_logits,
-            # topk=top_k,
-            # renormalize=renormalize,
-            # )
+            return self.triton_kernel_moe_forward(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_output=topk_output,
+            )
         else:
             if _use_aiter:
                 assert not no_combine, "unsupported"
```
```diff
@@ -354,69 +368,3 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         raise NotImplementedError("The TPU backend currently does not support MoE.")
 
     forward_native = forward_cpu
-
-
-class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
-
-    def create_weights(
-        self,
-        layer: torch.nn.Module,
-        num_experts_per_partition: int,
-        hidden_size: int,
-        intermediate_size: int,
-        params_dtype: torch.dtype,
-        **extra_weight_attrs,
-    ):
-        # Fused gate_up_proj (column parallel)
-        w13_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts_per_partition,
-                2 * intermediate_size,
-                hidden_size,
-                dtype=params_dtype,
-            ),
-            requires_grad=False,
-        )
-        layer.register_parameter("w13_weight", w13_weight)
-        set_weight_attrs(w13_weight, extra_weight_attrs)
-
-        # down_proj (row parallel)
-        w2_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts_per_partition,
-                hidden_size,
-                intermediate_size,
-                dtype=params_dtype,
-            ),
-            requires_grad=False,
-        )
-        layer.register_parameter("w2_weight", w2_weight)
-        set_weight_attrs(w2_weight, extra_weight_attrs)
-
-        # scale
-        layer.register_parameter("w13_input_scale", None)
-        layer.register_parameter("w13_weight_scale", None)
-
-        ones_tensor = torch.ones(num_experts_per_partition, dtype=torch.float32)
-
-        w2_input_scale = torch.nn.Parameter(
-            ones_tensor,
-            requires_grad=False,
-        )
-        layer.register_parameter("w2_input_scale", w2_input_scale)
-        set_weight_attrs(w2_input_scale, extra_weight_attrs)
-
-        w2_weight_scale = torch.nn.Parameter(
-            ones_tensor,
-            requires_grad=False,
-        )
-        layer.register_parameter("w2_weight_scale", w2_weight_scale)
-        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
-
-    def apply(
-        self,
-        layer: torch.nn.Module,
-        hidden_states: torch.Tensor,
-        topk_output: TopKOutput,
-    ) -> torch.Tensor:
-        raise NotImplementedError
```
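The deleted `UnquantizedEPMoEMethod` followed the standard weight-creation protocol: allocate `requires_grad=False` parameters and attach them with `register_parameter` so checkpoint loaders can find them by name. For reference, a runnable sketch of that protocol with invented sizes:

```python
import torch


def create_weights(layer: torch.nn.Module, num_experts: int, hidden: int, inter: int) -> None:
    # Inference-only weights: requires_grad=False keeps autograd bookkeeping away.
    w13 = torch.nn.Parameter(
        torch.empty(num_experts, 2 * inter, hidden), requires_grad=False
    )
    # register_parameter exposes the tensor via state_dict()/named_parameters(),
    # which is how weight loaders locate it by name.
    layer.register_parameter("w13_weight", w13)


layer = torch.nn.Module()
create_weights(layer, num_experts=4, hidden=8, inter=16)
print(dict(layer.named_parameters())["w13_weight"].shape)  # torch.Size([4, 32, 8])
```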
**sglang/srt/layers/quantization/utils.py**

```diff
@@ -17,15 +17,6 @@ from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip, is_npu
 if TYPE_CHECKING:
     from sglang.srt.layers.quantization.base_config import QuantizationConfig
 
-_is_cuda = is_cuda()
-_is_npu = is_npu()
-_is_cpu_amx_available = cpu_has_amx_support()
-_is_cpu = is_cpu()
-_is_hip = is_hip()
-
-if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip):
-    from vllm._custom_ops import scaled_fp8_quant
-
 
 def is_layer_skipped(
     prefix: str,
```
**sglang/srt/layers/quantization/w4afp8.py**

```diff
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import logging
-from typing import Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 import torch
 from torch.nn import Module
```
```diff
@@ -17,6 +17,9 @@ from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
 from sglang.srt.layers.quantization.utils import is_layer_skipped
 from sglang.srt.utils import set_weight_attrs
 
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.ep_moe.layer import EPMoE, TopKOutput
+
 ACTIVATION_SCHEMES = ["static", "dynamic"]
 
 logger = logging.getLogger(__name__)
```
```diff
@@ -84,13 +87,14 @@ class W4AFp8Config(QuantizationConfig):
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional[QuantizeMethodBase]:
         from sglang.srt.layers.linear import LinearBase
+        from sglang.srt.layers.moe.ep_moe.layer import EPMoE
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 
         if isinstance(layer, LinearBase):
             if is_layer_skipped(prefix, self.ignored_layers):
                 return UnquantizedLinearMethod()
             return Fp8LinearMethod(self)
-        elif isinstance(layer, FusedMoE):
+        elif isinstance(layer, EPMoE):
            return W4AFp8MoEMethod(self)
         return None
 
```
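`get_quant_method` acts as a factory: the config inspects the layer and returns the matching per-layer method object, or `None` for layers the scheme does not handle. A stripped-down sketch of that dispatch shape (all class names here are stand-ins, and the skip check is simplified to set membership):

```python
from typing import Optional


class LinearBase: ...
class EPMoE: ...
class UnquantizedLinearMethod: ...
class Fp8LinearMethod: ...
class W4AFp8MoEMethod: ...


class W4AFp8ConfigSketch:
    def __init__(self, ignored_layers: set):
        self.ignored_layers = ignored_layers

    def get_quant_method(self, layer: object, prefix: str) -> Optional[object]:
        # Dispatch on layer type, mirroring the LinearBase/EPMoE split above.
        if isinstance(layer, LinearBase):
            if prefix in self.ignored_layers:
                return UnquantizedLinearMethod()  # explicitly left unquantized
            return Fp8LinearMethod()
        if isinstance(layer, EPMoE):
            return W4AFp8MoEMethod()
        return None  # layer type not handled by this scheme


cfg = W4AFp8ConfigSketch(ignored_layers={"lm_head"})
assert isinstance(cfg.get_quant_method(LinearBase(), "layers.0.qkv"), Fp8LinearMethod)
assert isinstance(cfg.get_quant_method(LinearBase(), "lm_head"), UnquantizedLinearMethod)
```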
```diff
@@ -105,8 +109,8 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
 
     def create_weights(
         self,
-        layer: Module,
-        num_experts_per_partition: int,
+        layer: EPMoE,
+        num_experts: int,
         hidden_size: int,
         intermediate_size: int,
         params_dtype: torch.dtype,
```
```diff
@@ -117,7 +121,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         # Fused gate_up_proj (column parallel)
         w13_weight = torch.nn.Parameter(
             torch.empty(
-                num_experts_per_partition,
+                num_experts,
                 intermediate_size * 2,
                 hidden_size // 2,
                 dtype=torch.int8,
```
```diff
@@ -130,7 +134,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         # down_proj (row parallel)
         w2_weight = torch.nn.Parameter(
             torch.empty(
-                num_experts_per_partition,
+                num_experts,
                 hidden_size,
                 intermediate_size // 2,
                 dtype=torch.int8,
```
```diff
@@ -142,7 +146,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
 
         w13_weight_scale = torch.nn.Parameter(
             torch.zeros(
-                num_experts_per_partition,
+                num_experts,
                 2 * intermediate_size,
                 hidden_size // self.quant_config.group_size,
                 dtype=torch.float32,
```
```diff
@@ -154,7 +158,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
 
         w2_weight_scale = torch.nn.Parameter(
             torch.zeros(
-                num_experts_per_partition,
+                num_experts,
                 hidden_size,
                 intermediate_size // self.quant_config.group_size,
                 dtype=torch.float32,
```
```diff
@@ -166,14 +170,14 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
 
         # Input scales
         w13_input_scale = torch.nn.Parameter(
-            torch.ones((num_experts_per_partition, 2), dtype=torch.bfloat16),
+            torch.ones((num_experts, 2), dtype=torch.bfloat16),
             requires_grad=False,
         )
         layer.register_parameter("w13_input_scale", w13_input_scale)
         set_weight_attrs(w13_input_scale, extra_weight_attrs)
 
         w2_input_scale = torch.nn.Parameter(
-            torch.ones(num_experts_per_partition, dtype=torch.bfloat16),
+            torch.ones(num_experts, dtype=torch.bfloat16),
             requires_grad=False,
         )
         layer.register_parameter("w2_input_scale", w2_input_scale)
```
```diff
@@ -183,25 +187,25 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         device = layer.w13_weight.device
 
         self.a_strides1 = torch.full(
-            (num_experts_per_partition, 3),
+            (num_experts, 3),
             hidden_size,
             device=device,
             dtype=torch.int64,
         )
         self.c_strides1 = torch.full(
-            (num_experts_per_partition, 3),
+            (num_experts, 3),
             2 * intermediate_size,
             device=device,
             dtype=torch.int64,
         )
         self.a_strides2 = torch.full(
-            (num_experts_per_partition, 3),
+            (num_experts, 3),
             intermediate_size,
             device=device,
             dtype=torch.int64,
         )
         self.c_strides2 = torch.full(
-            (num_experts_per_partition, 3),
+            (num_experts, 3),
             hidden_size,
             device=device,
             dtype=torch.int64,
```
```diff
@@ -212,13 +216,13 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         self.s_strides2 = self.c_strides2
 
         self.expert_offsets = torch.empty(
-            (num_experts_per_partition + 1), dtype=torch.int32, device=device
+            (num_experts + 1), dtype=torch.int32, device=device
         )
         self.problem_sizes1 = torch.empty(
-            (num_experts_per_partition, 3), dtype=torch.int32, device=device
+            (num_experts, 3), dtype=torch.int32, device=device
         )
         self.problem_sizes2 = torch.empty(
-            (num_experts_per_partition, 3), dtype=torch.int32, device=device
+            (num_experts, 3), dtype=torch.int32, device=device
         )
 
         return
```
```diff
@@ -266,3 +270,50 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
             [w2_input_scale_max], dtype=dtype, device=device
         )
         layer.w2_input_scale = Parameter(new_w2_input_scale, requires_grad=False)
+
+    def apply(
+        self,
+        layer: EPMoE,
+        hidden_states: torch.Tensor,
+        topk_output: TopKOutput,
+    ) -> torch.Tensor:
+
+        # TODO(ch-wan): move it out of this class
+        from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
+
+        topk_ids, topk_weights, _ = topk_output
+        local_topk_ids = topk_ids
+        if layer.expert_map is not None:
+            "Translate info from expert_map to topk_ids"
+            local_topk_ids = torch.where(
+                layer.expert_map[topk_ids] != layer.num_experts,
+                layer.expert_map[topk_ids],
+                layer.num_experts,
+            )
+
+        return cutlass_w4a8_moe(
+            layer.start_expert_id,
+            layer.end_expert_id,
+            layer.num_experts,
+            hidden_states,
+            layer.w13_weight,
+            layer.w2_weight,
+            layer.w13_weight_scale_inv,
+            layer.w2_weight_scale_inv,
+            topk_weights,
+            topk_ids,
+            local_topk_ids,
+            self.a_strides1,
+            self.b_strides1,
+            self.c_strides1,
+            self.a_strides2,
+            self.b_strides2,
+            self.c_strides2,
+            self.s_strides13,
+            self.s_strides2,
+            self.expert_offsets,
+            self.problem_sizes1,
+            self.problem_sizes2,
+            layer.w13_input_scale,
+            layer.w2_input_scale,
+        )
```
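In the new `apply`, `expert_map` translates globally numbered top-k expert ids into this rank's local expert slots before the grouped GEMM. A toy illustration of that translation (sizes invented; here −1 marks experts hosted on other ranks, a common convention, whereas the diff's variant compares against `num_experts` instead):

```python
import torch

num_global_experts = 8          # experts across all expert-parallel ranks
sentinel = num_global_experts   # value meaning "expert not on this rank"

# expert_map[g] = local slot of global expert g, or -1 if it lives elsewhere.
# This rank hosts global experts 4..7 in local slots 0..3.
expert_map = torch.tensor([-1, -1, -1, -1, 0, 1, 2, 3])

# Per-token top-k choices, in global expert ids.
topk_ids = torch.tensor([[4, 1],
                         [7, 5]])

local_topk_ids = torch.where(
    expert_map[topk_ids] != -1,  # is the expert hosted here?
    expert_map[topk_ids],        # yes: use its local slot
    sentinel,                    # no: sentinel so the kernel can skip it
)
print(local_topk_ids)
# tensor([[0, 8],
#         [3, 1]])
```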
**sglang/srt/layers/radix_attention.py**

```diff
@@ -12,14 +12,16 @@
 # limitations under the License.
 # ==============================================================================
 """Radix attention."""
+from __future__ import annotations
 
 from enum import Enum
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
 
 from torch import nn
 
-from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+if TYPE_CHECKING:
+    from sglang.srt.layers.quantization.base_config import QuantizationConfig
+    from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 
 class AttentionType(Enum):
```
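
This radix_attention.py hunk is the standard recipe for type-only imports: `from __future__ import annotations` turns every annotation into a lazily evaluated string, so names imported under `TYPE_CHECKING` can appear in signatures without ever being imported at runtime, which breaks import cycles and trims startup cost. A minimal self-contained demonstration (`Decimal` stands in for a heavyweight import such as `ForwardBatch`):

```python
from __future__ import annotations  # annotations become strings, never evaluated eagerly

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only static type checkers execute this branch; at runtime the import
    # never happens, so the module below is never actually loaded here.
    from decimal import Decimal


def describe(batch: Decimal) -> str:
    # At runtime the annotation is just the string "Decimal".
    return f"got {batch}"


print(describe.__annotations__)  # {'batch': 'Decimal', 'return': 'str'}
```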
|