sglang 0.4.9__py3-none-any.whl → 0.4.9.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +2 -2
- sglang/srt/configs/model_config.py +36 -2
- sglang/srt/conversation.py +56 -3
- sglang/srt/disaggregation/ascend/__init__.py +6 -0
- sglang/srt/disaggregation/ascend/conn.py +44 -0
- sglang/srt/disaggregation/ascend/transfer_engine.py +58 -0
- sglang/srt/disaggregation/mooncake/conn.py +50 -18
- sglang/srt/disaggregation/mooncake/transfer_engine.py +17 -8
- sglang/srt/disaggregation/utils.py +25 -3
- sglang/srt/entrypoints/engine.py +1 -1
- sglang/srt/entrypoints/http_server.py +1 -0
- sglang/srt/entrypoints/http_server_engine.py +1 -1
- sglang/srt/entrypoints/openai/protocol.py +11 -0
- sglang/srt/entrypoints/openai/serving_chat.py +7 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/kimik2_detector.py +220 -0
- sglang/srt/hf_transformers_utils.py +18 -0
- sglang/srt/jinja_template_utils.py +8 -0
- sglang/srt/layers/communicator.py +20 -5
- sglang/srt/layers/flashinfer_comm_fusion.py +3 -3
- sglang/srt/layers/layernorm.py +2 -2
- sglang/srt/layers/linear.py +12 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +60 -1
- sglang/srt/layers/moe/ep_moe/layer.py +141 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +141 -59
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
- sglang/srt/layers/moe/topk.py +8 -2
- sglang/srt/layers/parameter.py +19 -3
- sglang/srt/layers/quantization/__init__.py +2 -0
- sglang/srt/layers/quantization/fp8.py +28 -7
- sglang/srt/layers/quantization/fp8_kernel.py +2 -2
- sglang/srt/layers/quantization/modelopt_quant.py +244 -1
- sglang/srt/layers/quantization/moe_wna16.py +1 -2
- sglang/srt/layers/quantization/w4afp8.py +264 -0
- sglang/srt/layers/quantization/w8a8_int8.py +738 -14
- sglang/srt/layers/vocab_parallel_embedding.py +9 -3
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
- sglang/srt/managers/cache_controller.py +41 -195
- sglang/srt/managers/io_struct.py +35 -3
- sglang/srt/managers/mm_utils.py +59 -96
- sglang/srt/managers/schedule_batch.py +17 -6
- sglang/srt/managers/scheduler.py +38 -6
- sglang/srt/managers/tokenizer_manager.py +16 -0
- sglang/srt/mem_cache/hiradix_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +176 -101
- sglang/srt/mem_cache/memory_pool_host.py +6 -109
- sglang/srt/mem_cache/radix_cache.py +8 -4
- sglang/srt/model_executor/forward_batch_info.py +13 -1
- sglang/srt/model_loader/loader.py +23 -12
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_v2.py +78 -19
- sglang/srt/models/deepseek_vl2.py +1 -1
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +6 -3
- sglang/srt/models/internvl.py +8 -2
- sglang/srt/models/kimi_vl.py +8 -2
- sglang/srt/models/llama.py +2 -0
- sglang/srt/models/llava.py +3 -1
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpmo.py +1 -2
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mixtral_quant.py +4 -0
- sglang/srt/models/mllama4.py +372 -82
- sglang/srt/models/phi4mm.py +8 -2
- sglang/srt/models/phimoe.py +553 -0
- sglang/srt/models/qwen2.py +2 -0
- sglang/srt/models/qwen2_5_vl.py +10 -7
- sglang/srt/models/qwen2_vl.py +12 -1
- sglang/srt/models/vila.py +8 -2
- sglang/srt/multimodal/mm_utils.py +2 -2
- sglang/srt/multimodal/processors/base_processor.py +197 -137
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +1 -1
- sglang/srt/multimodal/processors/gemma3.py +4 -2
- sglang/srt/multimodal/processors/gemma3n.py +1 -1
- sglang/srt/multimodal/processors/internvl.py +1 -1
- sglang/srt/multimodal/processors/janus_pro.py +1 -1
- sglang/srt/multimodal/processors/kimi_vl.py +1 -1
- sglang/srt/multimodal/processors/minicpm.py +4 -3
- sglang/srt/multimodal/processors/mllama4.py +63 -61
- sglang/srt/multimodal/processors/phi4mm.py +1 -1
- sglang/srt/multimodal/processors/pixtral.py +1 -1
- sglang/srt/multimodal/processors/qwen_vl.py +203 -80
- sglang/srt/multimodal/processors/vila.py +1 -1
- sglang/srt/server_args.py +26 -4
- sglang/srt/two_batch_overlap.py +3 -0
- sglang/srt/utils.py +191 -48
- sglang/test/test_cutlass_w4a8_moe.py +281 -0
- sglang/utils.py +5 -5
- sglang/version.py +1 -1
- {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/METADATA +6 -4
- {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/RECORD +99 -90
- {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/top_level.txt +0 -0
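
The listing above is the complete set of modules touched by the post-release; `sglang/version.py` carries the actual version bump. A quick way to confirm which build is installed after upgrading (standard-library call, nothing sglang-specific):

```python
from importlib.metadata import version

print(version("sglang"))  # expected to print "0.4.9.post2" once the new wheel is installed
```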
sglang/srt/layers/quantization/modelopt_quant.py
@@ -26,6 +26,7 @@ from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
 from sglang.srt.layers.quantization.utils import (
     convert_to_channelwise,
     is_layer_skipped,
+    per_tensor_dequantize,
     requantize_with_max_scale,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
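
The newly imported `per_tensor_dequantize` is used by the FP8 MoE method added later in this file, which folds the separate w1/w3 weight scales into a single per-expert scale by dequantizing each shard with its original scale and requantizing with the max. A minimal sketch of that round trip using plain torch ops (the helper names below are stand-ins, not the sglang implementations; requires a PyTorch build with float8_e4m3fn, 2.1+):

```python
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn


def dequantize_per_tensor(w_fp8: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Stand-in for per_tensor_dequantize: one float32 scale per tensor.
    return w_fp8.to(torch.float32) * scale


def quantize_per_tensor(w: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Stand-in for scaled_fp8_quant with a fixed scale.
    return (w / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)


# w1 and w3 were quantized with their own scales; re-express both under the max.
w1, w3 = torch.randn(4, 8), torch.randn(4, 8)
s1, s3 = w1.abs().max() / FP8_MAX, w3.abs().max() / FP8_MAX
max_scale = torch.max(s1, s3)
w1_fp8, w3_fp8 = quantize_per_tensor(w1, s1), quantize_per_tensor(w3, s3)
w1_rescaled = quantize_per_tensor(dequantize_per_tensor(w1_fp8, s1), max_scale)
w3_rescaled = quantize_per_tensor(dequantize_per_tensor(w3_fp8, s3), max_scale)
```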
@@ -110,7 +111,12 @@ class ModelOptFp8Config(QuantizationConfig):
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
         if self.exclude_modules and any(
-            module in prefix for module in self.exclude_modules
+            module in prefix
+            or (
+                prefix.startswith("language_model.")
+                and module in prefix.removeprefix("language_model.")
+            )
+            for module in self.exclude_modules
         ):
             return None
 
@@ -119,6 +125,12 @@ class ModelOptFp8Config(QuantizationConfig):
         if self.kv_cache_quant_method and isinstance(layer, RadixAttention):
             return ModelOptFp8KVCacheMethod(self)
 
+        # Add MoE support
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+
+        if isinstance(layer, FusedMoE):
+            return ModelOptFp8MoEMethod(self)
+
         return None
 
     def get_scaled_act_names(self) -> List[str]:
@@ -234,6 +246,237 @@ class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
         super().__init__(quant_config)
 
 
+class ModelOptFp8MoEMethod:
+    """MoE method for ModelOpt FP8.
+    Supports loading FP8 checkpoints with static weight scale and activation scale.
+
+    Args:
+        quant_config: The ModelOpt quantization config.
+    """
+
+    def __new__(cls, *args, **kwargs):
+        """
+        Dynamic class composition pattern.
+
+        This allows us to effectively "inject" FusedMoEMethodBase as a parent class
+        at runtime while avoiding circular import issues.
+        """
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
+
+        if not hasattr(cls, "_initialized"):
+            original_init = cls.__init__
+            new_cls = type(
+                cls.__name__,
+                (FusedMoEMethodBase,),
+                {
+                    "__init__": original_init,
+                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
+                },
+            )
+            obj = super(new_cls, new_cls).__new__(new_cls)
+            obj.__init__(*args, **kwargs)
+            return obj
+        return super().__new__(cls)
+
+    def __init__(self, quant_config: ModelOptFp8Config):
+        self.quant_config = quant_config
+        self.cutlass_fp8_supported = cutlass_fp8_supported()
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
+
+        # Use FP8 dtype if checkpoint is serialized, otherwise use the default dtype
+        weight_dtype = (
+            torch.float8_e4m3fn
+            if self.quant_config.is_checkpoint_fp8_serialized
+            else params_dtype
+        )
+        weight_loader = extra_weight_attrs.get("weight_loader")
+
+        w13_weight = ModelWeightParameter(
+            data=torch.empty(
+                num_experts, 2 * intermediate_size, hidden_size, dtype=weight_dtype
+            ),
+            input_dim=2,
+            output_dim=1,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("w13_weight", w13_weight)
+
+        w2_weight = ModelWeightParameter(
+            data=torch.empty(
+                num_experts, hidden_size, intermediate_size, dtype=weight_dtype
+            ),
+            input_dim=2,
+            output_dim=1,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("w2_weight", w2_weight)
+
+        if self.quant_config.is_checkpoint_fp8_serialized:
+            # WEIGHT SCALES - Per-tensor scaling for ModelOpts
+            # Allocate 2 scales for w1 and w3 respectively.
+            # They will be combined to a single scale after weight loading.
+            w13_weight_scale = PerTensorScaleParameter(
+                data=torch.full(
+                    (num_experts, 2),
+                    torch.finfo(torch.float32).min,
+                    dtype=torch.float32,
+                ),
+                weight_loader=weight_loader,
+            )
+            w2_weight_scale = PerTensorScaleParameter(
+                data=torch.full(
+                    (num_experts,), torch.finfo(torch.float32).min, dtype=torch.float32
+                ),
+                weight_loader=weight_loader,
+            )
+            layer.register_parameter("w13_weight_scale", w13_weight_scale)
+            layer.register_parameter("w2_weight_scale", w2_weight_scale)
+
+            # Set weight loader attributes for scales
+            extra_weight_attrs.update(
+                {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
+            )
+
+            # INPUT SCALES - Per-tensor scaling for ModelOpt
+            w13_input_scale = PerTensorScaleParameter(
+                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
+                weight_loader=weight_loader,
+            )
+            w2_input_scale = PerTensorScaleParameter(
+                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
+                weight_loader=weight_loader,
+            )
+            layer.register_parameter("w13_input_scale", w13_input_scale)
+            layer.register_parameter("w2_input_scale", w2_input_scale)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        """Process FP8 MoE weights after loading from serialized checkpoint.
+
+        Only supports pre-quantized checkpoints with FP8 weights and scales.
+        """
+
+        layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
+        layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
+
+        # Handle scale parameters
+        if hasattr(layer, "w13_weight_scale") and layer.w13_weight_scale is not None:
+            # Fp8 moe kernel needs single weight scale for w13 per expert.
+            # We take the max of the w1 and w3 scales then dequant and requant each expert.
+            if layer.w13_weight_scale.dim() == 2:  # Shape: (num_experts, 2)
+                from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
+
+                # Get the maximum scale across w1 and w3 for each expert
+                max_w13_scales = layer.w13_weight_scale.max(dim=1).values
+
+                # Requantize each expert's weights using the combined scale
+                # w13_weight has shape (num_experts, 2 * intermediate_size, hidden_size)
+                # where the first intermediate_size rows are w1, the next are w3
+                intermediate_size = layer.w13_weight.shape[1] // 2
+                for expert_id in range(layer.w13_weight.shape[0]):
+                    start = 0
+                    for shard_id in range(2):  # w1 and w3
+                        # Dequantize using the original scale for this shard
+                        dq_weight = per_tensor_dequantize(
+                            layer.w13_weight[expert_id][
+                                start : start + intermediate_size, :
+                            ],
+                            layer.w13_weight_scale[expert_id][shard_id],
+                        )
+                        # Requantize using the combined max scale
+                        (
+                            layer.w13_weight[expert_id][
+                                start : start + intermediate_size, :
+                            ],
+                            _,
+                        ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
+
+                        start += intermediate_size
+
+                # Update the scale parameter to be per-expert instead of per-shard
+                layer.w13_weight_scale = Parameter(max_w13_scales, requires_grad=False)
+            else:
+                layer.w13_weight_scale = Parameter(
+                    layer.w13_weight_scale.data, requires_grad=False
+                )
+
+        if hasattr(layer, "w2_weight_scale") and layer.w2_weight_scale is not None:
+            layer.w2_weight_scale = Parameter(
+                layer.w2_weight_scale.data, requires_grad=False
+            )
+        if hasattr(layer, "w13_input_scale") and layer.w13_input_scale is not None:
+            layer.w13_input_scale = Parameter(
+                layer.w13_input_scale.max(), requires_grad=False
+            )
+        if hasattr(layer, "w2_input_scale") and layer.w2_input_scale is not None:
+            layer.w2_input_scale = Parameter(
+                layer.w2_input_scale.max(), requires_grad=False
+            )
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        num_fused_shared_experts: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+        correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
+        inplace: bool = True,
+        no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
+    ) -> torch.Tensor:
+        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+        from sglang.srt.layers.moe.topk import select_experts
+
+        # Expert selection
+        topk_weights, topk_ids = select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            num_fused_shared_experts=num_fused_shared_experts,
+            custom_routing_function=custom_routing_function,
+            correction_bias=correction_bias,
+            routed_scaling_factor=routed_scaling_factor,
+        )
+
+        return fused_experts(
+            x,
+            layer.w13_weight,
+            layer.w2_weight,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            inplace=inplace,
+            activation=activation,
+            use_fp8_w8a8=True,
+            per_channel_quant=False,  # ModelOpt uses per-tensor quantization
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            a1_scale=layer.w13_input_scale,
+            a2_scale=layer.w2_input_scale,
+            no_combine=no_combine,
+        )
+
+
 class ModelOptFp4Config(QuantizationConfig):
     """Config class for FP4."""
 
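
The `__new__` override in `ModelOptFp8MoEMethod` is the least obvious part of this hunk: instead of importing `FusedMoEMethodBase` at module scope (which would create a circular import), the class rebuilds itself with `type()` at instantiation time and injects the base class then. A self-contained sketch of the same pattern with placeholder classes (`Base` and `Method` are illustrative, not sglang types):

```python
class Base:
    """Stands in for FusedMoEMethodBase, which sglang imports lazily."""

    def shared_helper(self) -> str:
        return "helper from injected base"


class Method:
    """Stands in for ModelOptFp8MoEMethod: gains Base as a parent at runtime."""

    def __new__(cls, *args, **kwargs):
        # Build a new class object with Base injected as the parent, copying
        # this class's own attributes (methods, docstring) onto it.
        new_cls = type(
            cls.__name__,
            (Base,),
            {k: v for k, v in cls.__dict__.items() if k != "__dict__"},
        )
        obj = super(new_cls, new_cls).__new__(new_cls)
        obj.__init__(*args, **kwargs)
        return obj

    def __init__(self, name: str):
        self.name = name


m = Method("fp8-moe")
print(isinstance(m, Base))  # True: the parent was injected at construction time
print(m.shared_helper())    # "helper from injected base"
print(m.name)               # "fp8-moe"
```

Because `__new__` returns an instance of the freshly built class rather than of `Method`, Python does not call `__init__` a second time, which is why the sketch (and the sglang code) calls `obj.__init__` explicitly.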
sglang/srt/layers/quantization/moe_wna16.py
@@ -116,8 +116,7 @@ class MoeWNA16Config(QuantizationConfig):
 
     @classmethod
     def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
-
-        if can_convert and user_quant == "moe_wna16":
+        if user_quant == "moe_wna16" and cls.is_moe_wna16_compatible(hf_quant_cfg):
             return cls.get_name()
         return None
 
sglang/srt/layers/quantization/w4afp8.py (new file)
@@ -0,0 +1,264 @@
+import logging
+from typing import Any, Dict, List, Optional
+
+import torch
+from torch.nn import Module
+from torch.nn.parameter import Parameter
+
+from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
+from sglang.srt.layers.quantization.base_config import (
+    QuantizationConfig,
+    QuantizeMethodBase,
+)
+from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
+from sglang.srt.layers.quantization.utils import is_layer_skipped
+from sglang.srt.utils import set_weight_attrs
+
+ACTIVATION_SCHEMES = ["static", "dynamic"]
+
+logger = logging.getLogger(__name__)
+
+
+class W4AFp8Config(QuantizationConfig):
+    """Config class for MIXED_PRECISION W4AFp8."""
+
+    def __init__(
+        self,
+        is_checkpoint_fp8_serialized: bool = True,
+        is_checkpoint_w4afp8_serialized: bool = True,
+        linear_activation_scheme: str = "dynamic",
+        moe_activation_scheme: str = "static",
+        ignored_layers: Optional[List[str]] = None,
+        weight_block_size: Optional[List[int]] = None,
+        group_size: int = 128,
+    ) -> None:
+        super().__init__()
+        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
+        self.is_checkpoint_w4afp8_serialized = is_checkpoint_w4afp8_serialized
+        if is_checkpoint_w4afp8_serialized:
+            logger.warning("Detected w4afp8 checkpoint. Please note that")
+        if moe_activation_scheme not in ACTIVATION_SCHEMES:
+            raise ValueError(f"Unsupported activation scheme {moe_activation_scheme}")
+        self.linear_activation_scheme = linear_activation_scheme
+        self.moe_activation_scheme = moe_activation_scheme
+        self.ignored_layers = ignored_layers or []
+        self.weight_block_size = [128, 128]
+        self.group_size = group_size
+
+    @classmethod
+    def get_name(cls) -> str:
+        return "w4afp8"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        return [torch.bfloat16, torch.float8_e4m3fn]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 90
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        return []
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "W4AFp8Config":
+        quant_method = cls.get_from_keys(config, ["quant_method"])
+        is_checkpoint_fp8_serialized = "fp8" in quant_method
+        is_checkpoint_w4afp8_serialized = "w4afp8" in quant_method
+        linear_activation_scheme = "dynamic"
+        moe_activation_scheme = "static"
+        weight_block_size = [128, 128]
+        return cls(
+            is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
+            is_checkpoint_w4afp8_serialized=is_checkpoint_w4afp8_serialized,
+            linear_activation_scheme=linear_activation_scheme,
+            moe_activation_scheme=moe_activation_scheme,
+            weight_block_size=weight_block_size,
+        )
+
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional["QuantizeMethodBase"]:
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+
+        if isinstance(layer, LinearBase):
+            if is_layer_skipped(prefix, self.ignored_layers):
+                return UnquantizedLinearMethod()
+            return Fp8LinearMethod(self)
+        elif isinstance(layer, FusedMoE):
+            return W4AFp8MoEMethod(self)
+        return None
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
+
+class W4AFp8MoEMethod:
+
+    def __init__(self, quant_config: W4AFp8Config):
+        self.quant_config = quant_config
+
+    def create_weights(
+        self,
+        layer: Module,
+        num_experts_per_partition: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        assert "weight_loader" in extra_weight_attrs
+
+        # Fused gate_up_proj (column parallel)
+        w13_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts_per_partition,
+                intermediate_size * 2,
+                hidden_size // 2,
+                dtype=torch.int8,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight", w13_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+
+        # down_proj (row parallel)
+        w2_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts_per_partition,
+                hidden_size,
+                intermediate_size // 2,
+                dtype=torch.int8,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_weight", w2_weight)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+        w13_weight_scale = torch.nn.Parameter(
+            torch.zeros(
+                num_experts_per_partition,
+                2 * intermediate_size,
+                hidden_size // self.quant_config.group_size,
+                dtype=torch.float32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
+        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+
+        w2_weight_scale = torch.nn.Parameter(
+            torch.zeros(
+                num_experts_per_partition,
+                hidden_size,
+                intermediate_size // self.quant_config.group_size,
+                dtype=torch.float32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
+        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+
+        # Input scales
+        w13_input_scale = torch.nn.Parameter(
+            torch.ones((num_experts_per_partition, 2), dtype=torch.bfloat16),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_input_scale", w13_input_scale)
+        set_weight_attrs(w13_input_scale, extra_weight_attrs)
+
+        w2_input_scale = torch.nn.Parameter(
+            torch.ones(num_experts_per_partition, dtype=torch.bfloat16),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_input_scale", w2_input_scale)
+        set_weight_attrs(w2_input_scale, extra_weight_attrs)
+
+        # Pre-populate the strides
+        device = layer.w13_weight.device
+
+        self.a_strides1 = torch.full(
+            (num_experts_per_partition, 3),
+            hidden_size,
+            device=device,
+            dtype=torch.int64,
+        )
+        self.c_strides1 = torch.full(
+            (num_experts_per_partition, 3),
+            2 * intermediate_size,
+            device=device,
+            dtype=torch.int64,
+        )
+        self.a_strides2 = torch.full(
+            (num_experts_per_partition, 3),
+            intermediate_size,
+            device=device,
+            dtype=torch.int64,
+        )
+        self.c_strides2 = torch.full(
+            (num_experts_per_partition, 3),
+            hidden_size,
+            device=device,
+            dtype=torch.int64,
+        )
+        self.b_strides1 = self.a_strides1
+        self.s_strides13 = self.c_strides1
+        self.b_strides2 = self.a_strides2
+        self.s_strides2 = self.c_strides2
+
+        self.expert_offsets = torch.empty(
+            (num_experts_per_partition + 1), dtype=torch.int32, device=device
+        )
+        self.problem_sizes1 = torch.empty(
+            (num_experts_per_partition, 3), dtype=torch.int32, device=device
+        )
+        self.problem_sizes2 = torch.empty(
+            (num_experts_per_partition, 3), dtype=torch.int32, device=device
+        )
+
+        return
+
+    def _interleave_scales(self, scales: torch.Tensor) -> torch.Tensor:
+        """Interleave scales in groups of 4 similar to TRT-LLM implementation."""
+        s_shape = scales.shape
+        # Reshape to separate groups of 4
+        scales_interleaved = scales.reshape(
+            s_shape[0], s_shape[1], (s_shape[2] // 4), 4
+        )
+        # Permute dimensions to interleave
+        scales_interleaved = scales_interleaved.permute(0, 2, 1, 3)
+        # Reshape back to original dimensions but with interleaved values
+        scales_interleaved = scales_interleaved.reshape(
+            s_shape[0], s_shape[2] // 4, s_shape[1] * 4
+        )
+        return scales_interleaved.contiguous()
+
+    def process_weights_after_loading(self, layer: Module) -> None:
+        dtype = torch.bfloat16
+        device = layer.w2_weight.device
+
+        # Interleave w13_weight_scale (gate_up_proj)
+        w13_weight_scale = layer.w13_weight_scale_inv.to(dtype)
+        w13_weight_scale = self._interleave_scales(w13_weight_scale)
+        layer.w13_weight_scale_inv = Parameter(w13_weight_scale, requires_grad=False)
+
+        # Interleave w2_weight_scale (down_proj)
+        w2_weight_scale = layer.w2_weight_scale_inv.to(dtype)
+        w2_weight_scale = self._interleave_scales(w2_weight_scale)
+        layer.w2_weight_scale_inv = Parameter(w2_weight_scale, requires_grad=False)
+
+        # Process input scales
+        w13_input_scale_max = layer.w13_input_scale.max().to(dtype).item()
+        new_w13_input_scale = torch.tensor(
+            [w13_input_scale_max],
+            dtype=dtype,
+            device=device,
+        )
+        layer.w13_input_scale = Parameter(new_w13_input_scale, requires_grad=False)
+
+        w2_input_scale_max = layer.w2_input_scale.max().to(dtype).item()
+        new_w2_input_scale = torch.tensor(
+            [w2_input_scale_max], dtype=dtype, device=device
+        )
+        layer.w2_input_scale = Parameter(new_w2_input_scale, requires_grad=False)