sglang 0.4.8.post1__py3-none-any.whl → 0.4.9.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +17 -2
- sglang/bench_serving.py +170 -24
- sglang/srt/configs/internvl.py +4 -2
- sglang/srt/configs/janus_pro.py +1 -1
- sglang/srt/configs/model_config.py +60 -1
- sglang/srt/configs/update_config.py +119 -0
- sglang/srt/conversation.py +69 -1
- sglang/srt/disaggregation/decode.py +21 -5
- sglang/srt/disaggregation/mooncake/conn.py +35 -4
- sglang/srt/disaggregation/nixl/conn.py +6 -6
- sglang/srt/disaggregation/prefill.py +2 -2
- sglang/srt/disaggregation/utils.py +1 -1
- sglang/srt/distributed/parallel_state.py +44 -17
- sglang/srt/entrypoints/EngineBase.py +8 -0
- sglang/srt/entrypoints/engine.py +40 -6
- sglang/srt/entrypoints/http_server.py +111 -24
- sglang/srt/entrypoints/http_server_engine.py +1 -1
- sglang/srt/entrypoints/openai/protocol.py +4 -2
- sglang/srt/eplb/__init__.py +0 -0
- sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
- sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
- sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
- sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
- sglang/srt/{managers → eplb}/expert_location.py +1 -1
- sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
- sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
- sglang/srt/hf_transformers_utils.py +2 -1
- sglang/srt/layers/activation.py +2 -2
- sglang/srt/layers/amx_utils.py +86 -0
- sglang/srt/layers/attention/ascend_backend.py +219 -0
- sglang/srt/layers/attention/flashattention_backend.py +32 -9
- sglang/srt/layers/attention/tbo_backend.py +37 -9
- sglang/srt/layers/communicator.py +20 -2
- sglang/srt/layers/dp_attention.py +9 -3
- sglang/srt/layers/elementwise.py +76 -12
- sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
- sglang/srt/layers/layernorm.py +26 -0
- sglang/srt/layers/linear.py +84 -14
- sglang/srt/layers/logits_processor.py +4 -4
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +81 -8
- sglang/srt/layers/moe/ep_moe/layer.py +176 -15
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +211 -74
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
- sglang/srt/layers/moe/router.py +60 -22
- sglang/srt/layers/moe/topk.py +10 -28
- sglang/srt/layers/parameter.py +67 -7
- sglang/srt/layers/quantization/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
- sglang/srt/layers/quantization/fp8.py +72 -7
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +1 -2
- sglang/srt/layers/quantization/gptq.py +5 -1
- sglang/srt/layers/quantization/modelopt_quant.py +244 -1
- sglang/srt/layers/quantization/moe_wna16.py +1 -1
- sglang/srt/layers/quantization/quant_utils.py +166 -0
- sglang/srt/layers/quantization/w4afp8.py +264 -0
- sglang/srt/layers/quantization/w8a8_int8.py +52 -1
- sglang/srt/layers/rotary_embedding.py +2 -2
- sglang/srt/layers/vocab_parallel_embedding.py +20 -10
- sglang/srt/lora/lora.py +4 -5
- sglang/srt/lora/lora_manager.py +73 -20
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
- sglang/srt/managers/cache_controller.py +41 -195
- sglang/srt/managers/configure_logging.py +1 -1
- sglang/srt/managers/io_struct.py +58 -14
- sglang/srt/managers/mm_utils.py +77 -61
- sglang/srt/managers/multimodal_processor.py +2 -6
- sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
- sglang/srt/managers/schedule_batch.py +78 -85
- sglang/srt/managers/scheduler.py +130 -64
- sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
- sglang/srt/managers/session_controller.py +12 -3
- sglang/srt/managers/tokenizer_manager.py +314 -103
- sglang/srt/managers/tp_worker.py +13 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
- sglang/srt/mem_cache/allocator.py +290 -0
- sglang/srt/mem_cache/chunk_cache.py +34 -2
- sglang/srt/mem_cache/hiradix_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +402 -66
- sglang/srt/mem_cache/memory_pool_host.py +6 -109
- sglang/srt/mem_cache/multimodal_cache.py +3 -0
- sglang/srt/mem_cache/radix_cache.py +8 -4
- sglang/srt/model_executor/cuda_graph_runner.py +2 -1
- sglang/srt/model_executor/forward_batch_info.py +17 -4
- sglang/srt/model_executor/model_runner.py +297 -56
- sglang/srt/model_loader/loader.py +41 -0
- sglang/srt/model_loader/weight_utils.py +72 -4
- sglang/srt/models/deepseek_nextn.py +1 -3
- sglang/srt/models/deepseek_v2.py +195 -45
- sglang/srt/models/deepseek_vl2.py +3 -5
- sglang/srt/models/gemma3_causal.py +1 -2
- sglang/srt/models/gemma3n_causal.py +4 -3
- sglang/srt/models/gemma3n_mm.py +4 -20
- sglang/srt/models/hunyuan.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -2
- sglang/srt/models/llama.py +10 -4
- sglang/srt/models/llama4.py +32 -45
- sglang/srt/models/llama_eagle3.py +61 -11
- sglang/srt/models/llava.py +5 -5
- sglang/srt/models/minicpmo.py +2 -2
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mllama4.py +402 -89
- sglang/srt/models/phi4mm.py +1 -3
- sglang/srt/models/pixtral.py +3 -7
- sglang/srt/models/qwen2.py +31 -3
- sglang/srt/models/qwen2_5_vl.py +1 -3
- sglang/srt/models/qwen2_audio.py +200 -0
- sglang/srt/models/qwen2_moe.py +32 -6
- sglang/srt/models/qwen2_vl.py +1 -4
- sglang/srt/models/qwen3.py +94 -25
- sglang/srt/models/qwen3_moe.py +68 -21
- sglang/srt/models/vila.py +3 -8
- sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +2 -2
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +65 -66
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
- sglang/srt/operations_strategy.py +6 -2
- sglang/srt/reasoning_parser.py +26 -0
- sglang/srt/sampling/sampling_batch_info.py +39 -1
- sglang/srt/server_args.py +84 -22
- sglang/srt/speculative/build_eagle_tree.py +57 -18
- sglang/srt/speculative/eagle_worker.py +6 -4
- sglang/srt/two_batch_overlap.py +203 -27
- sglang/srt/utils.py +343 -163
- sglang/srt/warmup.py +12 -3
- sglang/test/runners.py +10 -1
- sglang/test/test_cutlass_w4a8_moe.py +281 -0
- sglang/test/test_utils.py +15 -3
- sglang/utils.py +5 -5
- sglang/version.py +1 -1
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +12 -8
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +157 -146
- sglang/math_utils.py +0 -8
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
- /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/modelopt_quant.py
@@ -26,6 +26,7 @@ from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
 from sglang.srt.layers.quantization.utils import (
     convert_to_channelwise,
     is_layer_skipped,
+    per_tensor_dequantize,
     requantize_with_max_scale,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
@@ -110,7 +111,12 @@ class ModelOptFp8Config(QuantizationConfig):
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
         if self.exclude_modules and any(
-            module in prefix
+            module in prefix
+            or (
+                prefix.startswith("language_model.")
+                and module in prefix.removeprefix("language_model.")
+            )
+            for module in self.exclude_modules
         ):
             return None
 
@@ -119,6 +125,12 @@ class ModelOptFp8Config(QuantizationConfig):
         if self.kv_cache_quant_method and isinstance(layer, RadixAttention):
             return ModelOptFp8KVCacheMethod(self)
 
+        # Add MoE support
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+
+        if isinstance(layer, FusedMoE):
+            return ModelOptFp8MoEMethod(self)
+
         return None
 
     def get_scaled_act_names(self) -> List[str]:
@@ -234,6 +246,237 @@ class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
         super().__init__(quant_config)
 
 
+class ModelOptFp8MoEMethod:
+    """MoE method for ModelOpt FP8.
+    Supports loading FP8 checkpoints with static weight scale and activation scale.
+
+    Args:
+        quant_config: The ModelOpt quantization config.
+    """
+
+    def __new__(cls, *args, **kwargs):
+        """
+        Dynamic class composition pattern.
+
+        This allows us to effectively "inject" FusedMoEMethodBase as a parent class
+        at runtime while avoiding circular import issues.
+        """
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
+
+        if not hasattr(cls, "_initialized"):
+            original_init = cls.__init__
+            new_cls = type(
+                cls.__name__,
+                (FusedMoEMethodBase,),
+                {
+                    "__init__": original_init,
+                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
+                },
+            )
+            obj = super(new_cls, new_cls).__new__(new_cls)
+            obj.__init__(*args, **kwargs)
+            return obj
+        return super().__new__(cls)
+
+    def __init__(self, quant_config: ModelOptFp8Config):
+        self.quant_config = quant_config
+        self.cutlass_fp8_supported = cutlass_fp8_supported()
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
+
+        # Use FP8 dtype if checkpoint is serialized, otherwise use the default dtype
+        weight_dtype = (
+            torch.float8_e4m3fn
+            if self.quant_config.is_checkpoint_fp8_serialized
+            else params_dtype
+        )
+        weight_loader = extra_weight_attrs.get("weight_loader")
+
+        w13_weight = ModelWeightParameter(
+            data=torch.empty(
+                num_experts, 2 * intermediate_size, hidden_size, dtype=weight_dtype
+            ),
+            input_dim=2,
+            output_dim=1,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("w13_weight", w13_weight)
+
+        w2_weight = ModelWeightParameter(
+            data=torch.empty(
+                num_experts, hidden_size, intermediate_size, dtype=weight_dtype
+            ),
+            input_dim=2,
+            output_dim=1,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("w2_weight", w2_weight)
+
+        if self.quant_config.is_checkpoint_fp8_serialized:
+            # WEIGHT SCALES - Per-tensor scaling for ModelOpts
+            # Allocate 2 scales for w1 and w3 respectively.
+            # They will be combined to a single scale after weight loading.
+            w13_weight_scale = PerTensorScaleParameter(
+                data=torch.full(
+                    (num_experts, 2),
+                    torch.finfo(torch.float32).min,
+                    dtype=torch.float32,
+                ),
+                weight_loader=weight_loader,
+            )
+            w2_weight_scale = PerTensorScaleParameter(
+                data=torch.full(
+                    (num_experts,), torch.finfo(torch.float32).min, dtype=torch.float32
+                ),
+                weight_loader=weight_loader,
+            )
+            layer.register_parameter("w13_weight_scale", w13_weight_scale)
+            layer.register_parameter("w2_weight_scale", w2_weight_scale)
+
+            # Set weight loader attributes for scales
+            extra_weight_attrs.update(
+                {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
+            )
+
+            # INPUT SCALES - Per-tensor scaling for ModelOpt
+            w13_input_scale = PerTensorScaleParameter(
+                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
+                weight_loader=weight_loader,
+            )
+            w2_input_scale = PerTensorScaleParameter(
+                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
+                weight_loader=weight_loader,
+            )
+            layer.register_parameter("w13_input_scale", w13_input_scale)
+            layer.register_parameter("w2_input_scale", w2_input_scale)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        """Process FP8 MoE weights after loading from serialized checkpoint.
+
+        Only supports pre-quantized checkpoints with FP8 weights and scales.
+        """
+
+        layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
+        layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
+
+        # Handle scale parameters
+        if hasattr(layer, "w13_weight_scale") and layer.w13_weight_scale is not None:
+            # Fp8 moe kernel needs single weight scale for w13 per expert.
+            # We take the max of the w1 and w3 scales then dequant and requant each expert.
+            if layer.w13_weight_scale.dim() == 2:  # Shape: (num_experts, 2)
+                from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
+
+                # Get the maximum scale across w1 and w3 for each expert
+                max_w13_scales = layer.w13_weight_scale.max(dim=1).values
+
+                # Requantize each expert's weights using the combined scale
+                # w13_weight has shape (num_experts, 2 * intermediate_size, hidden_size)
+                # where the first intermediate_size rows are w1, the next are w3
+                intermediate_size = layer.w13_weight.shape[1] // 2
+                for expert_id in range(layer.w13_weight.shape[0]):
+                    start = 0
+                    for shard_id in range(2):  # w1 and w3
+                        # Dequantize using the original scale for this shard
+                        dq_weight = per_tensor_dequantize(
+                            layer.w13_weight[expert_id][
+                                start : start + intermediate_size, :
+                            ],
+                            layer.w13_weight_scale[expert_id][shard_id],
+                        )
+                        # Requantize using the combined max scale
+                        (
+                            layer.w13_weight[expert_id][
+                                start : start + intermediate_size, :
+                            ],
+                            _,
+                        ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
+
+                        start += intermediate_size
+
+                # Update the scale parameter to be per-expert instead of per-shard
+                layer.w13_weight_scale = Parameter(max_w13_scales, requires_grad=False)
+            else:
+                layer.w13_weight_scale = Parameter(
+                    layer.w13_weight_scale.data, requires_grad=False
+                )
+
+        if hasattr(layer, "w2_weight_scale") and layer.w2_weight_scale is not None:
+            layer.w2_weight_scale = Parameter(
+                layer.w2_weight_scale.data, requires_grad=False
+            )
+        if hasattr(layer, "w13_input_scale") and layer.w13_input_scale is not None:
+            layer.w13_input_scale = Parameter(
+                layer.w13_input_scale.max(), requires_grad=False
+            )
+        if hasattr(layer, "w2_input_scale") and layer.w2_input_scale is not None:
+            layer.w2_input_scale = Parameter(
+                layer.w2_input_scale.max(), requires_grad=False
+            )
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        num_fused_shared_experts: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+        correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
+        inplace: bool = True,
+        no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
+    ) -> torch.Tensor:
+        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+        from sglang.srt.layers.moe.topk import select_experts
+
+        # Expert selection
+        topk_weights, topk_ids = select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            num_fused_shared_experts=num_fused_shared_experts,
+            custom_routing_function=custom_routing_function,
+            correction_bias=correction_bias,
+            routed_scaling_factor=routed_scaling_factor,
+        )
+
+        return fused_experts(
+            x,
+            layer.w13_weight,
+            layer.w2_weight,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            inplace=inplace,
+            activation=activation,
+            use_fp8_w8a8=True,
+            per_channel_quant=False,  # ModelOpt uses per-tensor quantization
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            a1_scale=layer.w13_input_scale,
+            a2_scale=layer.w2_input_scale,
+            no_combine=no_combine,
+        )
+
+
 class ModelOptFp4Config(QuantizationConfig):
     """Config class for FP4."""
 
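For reference, the w13 rescaling performed in process_weights_after_loading above can be illustrated with a small standalone sketch. This is a float32 stand-in for the FP8 path (the helper name requant_w13_with_max_scale is hypothetical and not part of sglang): each shard (w1, w3) is dequantized with its own per-tensor scale and requantized with the shared per-expert max scale, so the fused kernel only needs one weight scale per expert.

import torch

def requant_w13_with_max_scale(q_w13: torch.Tensor, scales: torch.Tensor):
    # q_w13:  (2 * intermediate_size, hidden_size) quantized values for one expert
    # scales: (2,) per-tensor scales for the w1 and w3 shards
    intermediate_size = q_w13.shape[0] // 2
    max_scale = scales.max()
    out = torch.empty_like(q_w13)
    for shard_id in range(2):
        rows = slice(shard_id * intermediate_size, (shard_id + 1) * intermediate_size)
        dq = q_w13[rows] * scales[shard_id]      # per-tensor dequantize
        out[rows] = torch.round(dq / max_scale)  # requantize with the combined max scale
    return out, max_scale

q = torch.randint(-8, 8, (8, 4)).float()
requantized, combined_scale = requant_w13_with_max_scale(q, torch.tensor([0.5, 1.0]))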
sglang/srt/layers/quantization/moe_wna16.py
@@ -131,7 +131,7 @@ class MoeWNA16Config(QuantizationConfig):
         capability_tuple = get_device_capability()
         device_capability = (
             -1
-            if
+            if all(capability is None for capability in capability_tuple)
             else capability_tuple[0] * 10 + capability_tuple[1]
         )
         # Avoid circular import
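The updated condition above falls back to -1 only when get_device_capability() reports no capability at all. A minimal sketch of the encoding (hypothetical helper name, assuming the tuple is (major, minor) or (None, None)):

from typing import Optional, Tuple

def encode_device_capability(capability_tuple: Tuple[Optional[int], Optional[int]]) -> int:
    # Mirrors the expression above: -1 when both entries are None,
    # otherwise major * 10 + minor (e.g. (8, 6) -> 86).
    return (
        -1
        if all(capability is None for capability in capability_tuple)
        else capability_tuple[0] * 10 + capability_tuple[1]
    )

assert encode_device_capability((8, 6)) == 86
assert encode_device_capability((None, None)) == -1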
sglang/srt/layers/quantization/quant_utils.py (new file)
@@ -0,0 +1,166 @@
+# SPDX-License-Identifier: Apache-2.0
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
+
+from typing import Optional
+
+import numpy
+import torch
+from sgl_kernel.scalar_type import ScalarType
+
+
+def get_pack_factor(num_bits):
+    assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}"
+    return 32 // num_bits
+
+
+def pack_cols(
+    q_w: torch.Tensor,
+    num_bits: int,
+    size_k: int,
+    size_n: int,
+):
+    assert q_w.shape == (size_k, size_n)
+
+    pack_factor = get_pack_factor(num_bits)
+    assert size_n % pack_factor == 0
+
+    orig_device = q_w.device
+
+    q_w = q_w.cpu().numpy().astype(numpy.uint32)
+
+    q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)
+
+    for i in range(pack_factor):
+        q_res |= q_w[:, i::pack_factor] << num_bits * i
+
+    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
+    q_res = q_res.contiguous()
+
+    return q_res
+
+
+def unpack_cols(
+    packed_q_w: torch.Tensor,
+    num_bits: int,
+    size_k: int,
+    size_n: int,
+):
+    pack_factor = get_pack_factor(num_bits)
+    assert size_n % pack_factor == 0
+    assert packed_q_w.shape == (
+        size_k,
+        size_n // pack_factor,
+    ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format(
+        packed_q_w.shape, size_k, size_n, pack_factor
+    )
+
+    orig_device = packed_q_w.device
+
+    packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32)
+    q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32)
+
+    mask = (1 << num_bits) - 1
+    for i in range(pack_factor):
+        vals = packed_q_w_cpu & mask
+        packed_q_w_cpu >>= num_bits
+        q_res[:, i::pack_factor] = vals
+
+    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
+    q_res = q_res.contiguous()
+
+    return q_res
+
+
+def quantize_weights(
+    w: torch.Tensor,
+    quant_type: ScalarType,
+    group_size: Optional[int],
+    zero_points: bool = False,
+    ref_zero_points_after_scales: bool = False,
+):
+    assert (
+        quant_type.is_integer()
+    ), "Floating point quantization may work but has not been tested"
+    assert not zero_points or group_size is not None, (
+        "to have group zero points, group_size must be provided "
+        "(-1 group_size is channelwise)"
+    )
+
+    orig_device = w.device
+    orig_type = w.dtype
+    size_k, size_n = w.shape
+
+    assert w.is_floating_point(), "w must be float"
+
+    if group_size == -1:
+        group_size = size_k
+
+    # Reshape to [groupsize, -1]
+    if group_size is not None and group_size < size_k:
+        w = w.reshape((-1, group_size, size_n))
+        w = w.permute(1, 0, 2)
+        w = w.reshape((group_size, -1))
+
+    # Compute scale for each group
+    max_val = torch.max(w, 0, keepdim=True).values
+    min_val = torch.min(w, 0, keepdim=True).values
+
+    max_q_val = quant_type.max()
+    min_q_val = quant_type.min()
+
+    w_s = torch.Tensor([1.0]).to(w.device)  # unscaled case
+    maybe_w_zp = None
+    if group_size is not None:
+        if zero_points:
+            assert not quant_type.is_signed() and quant_type.max() > 0
+            w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
+            maybe_w_zp = (
+                torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int()
+            )
+        else:
+            # If the bias is such that there are no possible negative/positive
+            # values, set the max value to inf to avoid divide by 0
+            w_s = torch.max(
+                abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
+                abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)),
+            )
+
+    # Quantize
+    w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
+    w_q = torch.clamp(w_q, min_q_val, max_q_val)
+
+    # Compute ref (dequantized)
+    # For some kernels (namely Machete) the zero-points are applied after the
+    # scales are applied, for this case computing the reference in similar way
+    # allows us to use tighter error tolerances in our unit tests.
+    if ref_zero_points_after_scales and maybe_w_zp is not None:
+        w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s
+    else:
+        w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s
+
+    if quant_type.has_bias():
+        w_q += quant_type.bias
+
+    # Restore original shapes
+    if group_size is not None and group_size < size_k:
+
+        def reshape_w(w):
+            w = w.reshape((group_size, -1, size_n))
+            w = w.permute(1, 0, 2)
+            w = w.reshape((size_k, size_n)).contiguous()
+            return w
+
+        w_q = reshape_w(w_q)
+        w_ref = reshape_w(w_ref)
+        w_s = w_s.reshape((-1, size_n)).contiguous()
+
+    if maybe_w_zp is not None:
+        maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
+        maybe_w_zp = maybe_w_zp.to(device=orig_device)
+
+    return (
+        w_ref.to(device=orig_device),
+        w_q.to(device=orig_device),
+        w_s if group_size is not None else None,
+        maybe_w_zp,
+    )