sglang 0.4.6.post4__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +6 -6
- sglang/bench_one_batch.py +5 -4
- sglang/bench_one_batch_server.py +23 -15
- sglang/bench_serving.py +133 -57
- sglang/compile_deep_gemm.py +4 -4
- sglang/srt/configs/model_config.py +39 -28
- sglang/srt/conversation.py +1 -1
- sglang/srt/disaggregation/decode.py +122 -133
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
- sglang/srt/disaggregation/fake/conn.py +3 -13
- sglang/srt/disaggregation/kv_events.py +357 -0
- sglang/srt/disaggregation/mini_lb.py +57 -24
- sglang/srt/disaggregation/mooncake/conn.py +11 -2
- sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
- sglang/srt/disaggregation/nixl/conn.py +9 -19
- sglang/srt/disaggregation/prefill.py +126 -44
- sglang/srt/disaggregation/utils.py +116 -5
- sglang/srt/distributed/utils.py +3 -3
- sglang/srt/entrypoints/EngineBase.py +5 -0
- sglang/srt/entrypoints/engine.py +28 -8
- sglang/srt/entrypoints/http_server.py +6 -4
- sglang/srt/entrypoints/http_server_engine.py +5 -2
- sglang/srt/function_call/base_format_detector.py +250 -0
- sglang/srt/function_call/core_types.py +34 -0
- sglang/srt/function_call/deepseekv3_detector.py +157 -0
- sglang/srt/function_call/ebnf_composer.py +234 -0
- sglang/srt/function_call/function_call_parser.py +175 -0
- sglang/srt/function_call/llama32_detector.py +74 -0
- sglang/srt/function_call/mistral_detector.py +84 -0
- sglang/srt/function_call/pythonic_detector.py +163 -0
- sglang/srt/function_call/qwen25_detector.py +67 -0
- sglang/srt/function_call/utils.py +35 -0
- sglang/srt/hf_transformers_utils.py +46 -7
- sglang/srt/layers/attention/aiter_backend.py +513 -0
- sglang/srt/layers/attention/flashattention_backend.py +63 -17
- sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
- sglang/srt/layers/attention/flashmla_backend.py +340 -78
- sglang/srt/layers/attention/triton_backend.py +3 -0
- sglang/srt/layers/attention/utils.py +2 -2
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/communicator.py +451 -0
- sglang/srt/layers/dp_attention.py +0 -10
- sglang/srt/layers/moe/cutlass_moe.py +207 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +33 -11
- sglang/srt/layers/moe/ep_moe/layer.py +104 -50
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
- sglang/srt/layers/moe/topk.py +66 -9
- sglang/srt/layers/multimodal.py +70 -0
- sglang/srt/layers/quantization/__init__.py +7 -2
- sglang/srt/layers/quantization/deep_gemm.py +5 -3
- sglang/srt/layers/quantization/fp8.py +90 -0
- sglang/srt/layers/quantization/fp8_utils.py +6 -0
- sglang/srt/layers/quantization/gptq.py +298 -6
- sglang/srt/layers/quantization/int8_kernel.py +18 -5
- sglang/srt/layers/quantization/qoq.py +244 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/deepseek_eplb.py +278 -0
- sglang/srt/managers/eplb_manager.py +55 -0
- sglang/srt/managers/expert_distribution.py +704 -56
- sglang/srt/managers/expert_location.py +394 -0
- sglang/srt/managers/expert_location_dispatch.py +91 -0
- sglang/srt/managers/io_struct.py +16 -3
- sglang/srt/managers/mm_utils.py +293 -139
- sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
- sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
- sglang/srt/managers/multimodal_processors/internvl.py +14 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
- sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
- sglang/srt/managers/multimodal_processors/llava.py +3 -3
- sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
- sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +9 -9
- sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
- sglang/srt/managers/schedule_batch.py +49 -21
- sglang/srt/managers/schedule_policy.py +4 -5
- sglang/srt/managers/scheduler.py +92 -50
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +99 -24
- sglang/srt/mem_cache/base_prefix_cache.py +3 -0
- sglang/srt/mem_cache/chunk_cache.py +3 -1
- sglang/srt/mem_cache/hiradix_cache.py +4 -4
- sglang/srt/mem_cache/memory_pool.py +74 -52
- sglang/srt/mem_cache/multimodal_cache.py +45 -0
- sglang/srt/mem_cache/radix_cache.py +58 -5
- sglang/srt/metrics/collector.py +2 -2
- sglang/srt/mm_utils.py +10 -0
- sglang/srt/model_executor/cuda_graph_runner.py +20 -9
- sglang/srt/model_executor/expert_location_updater.py +422 -0
- sglang/srt/model_executor/forward_batch_info.py +4 -0
- sglang/srt/model_executor/model_runner.py +144 -54
- sglang/srt/model_loader/loader.py +10 -6
- sglang/srt/models/clip.py +5 -1
- sglang/srt/models/deepseek_v2.py +297 -343
- sglang/srt/models/exaone.py +8 -3
- sglang/srt/models/gemma3_mm.py +70 -33
- sglang/srt/models/llama4.py +10 -2
- sglang/srt/models/llava.py +26 -18
- sglang/srt/models/mimo_mtp.py +220 -0
- sglang/srt/models/minicpmo.py +5 -12
- sglang/srt/models/mistral.py +71 -1
- sglang/srt/models/mllama.py +3 -3
- sglang/srt/models/qwen2.py +95 -26
- sglang/srt/models/qwen2_5_vl.py +8 -0
- sglang/srt/models/qwen2_moe.py +330 -60
- sglang/srt/models/qwen2_vl.py +6 -0
- sglang/srt/models/qwen3.py +52 -10
- sglang/srt/models/qwen3_moe.py +411 -48
- sglang/srt/models/siglip.py +294 -0
- sglang/srt/openai_api/adapter.py +28 -16
- sglang/srt/openai_api/protocol.py +6 -0
- sglang/srt/operations.py +154 -0
- sglang/srt/operations_strategy.py +31 -0
- sglang/srt/server_args.py +134 -24
- sglang/srt/speculative/eagle_utils.py +131 -0
- sglang/srt/speculative/eagle_worker.py +47 -2
- sglang/srt/utils.py +68 -12
- sglang/test/test_cutlass_moe.py +278 -0
- sglang/test/test_utils.py +2 -36
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +20 -11
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +128 -102
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
- sglang/srt/function_call_parser.py +0 -858
- sglang/srt/platforms/interface.py +0 -371
- /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/gptq.py
CHANGED
@@ -1,21 +1,28 @@
 import logging
 from fractions import Fraction
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union

 import torch

-from sglang.srt.layers.linear import LinearBase
-from sglang.srt.layers.quantization.base_config import
+from sglang.srt.layers.linear import LinearBase, set_weight_attrs
+from sglang.srt.layers.quantization.base_config import (
+    QuantizationConfig,
+    QuantizeMethodBase,
+)
+from sglang.srt.layers.quantization.utils import replace_parameter
 from sglang.srt.utils import is_cuda

 _is_cuda = is_cuda()

 try:
-    from vllm
+    from vllm import _custom_ops as ops
     from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
     from vllm.model_executor.layers.quantization.gptq_marlin import (
+        FusedMoE,
+        FusedMoEMethodBase,
+        FusedMoeWeightScaleSupported,
         GPTQMarlinLinearMethod,
-
+        marlin_moe_permute_scales,
     )
     from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod
     from vllm.model_executor.layers.quantization.utils.marlin_utils import (
@@ -27,7 +34,9 @@ try:
 except ImportError:
     VLLM_AVAILABLE = False

-    GPTQLinearMethod = MarlinLinearMethod =
+    GPTQLinearMethod = MarlinLinearMethod = Any
+
+    FusedMoEMethodBase = QuantizeMethodBase

     class scalar_types:
         uint4b8 = "uint4b8"
@@ -437,3 +446,286 @@ class MarlinConfig(QuantizationConfig):
         ):
             return MarlinLinearMethod(self)
         return None
+
+
+class GPTQMarlinMoEMethod(FusedMoEMethodBase):
+    """MoE Marlin method with quantization."""
+
+    def __init__(self, quant_config: GPTQMarlinConfig) -> None:
+        self.quant_config = quant_config
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size_per_partition: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        intermediate_size = extra_weight_attrs.pop("intermediate_size")
+
+        self.is_k_full = (not self.quant_config.desc_act) or (
+            intermediate_size_per_partition == intermediate_size
+        )
+
+        if self.quant_config.group_size != -1:
+            scales_size13 = hidden_size // self.quant_config.group_size
+            w2_scales_size = (
+                intermediate_size
+                if self.quant_config.desc_act
+                else intermediate_size_per_partition
+            )
+            scales_size2 = w2_scales_size // self.quant_config.group_size
+            strategy = FusedMoeWeightScaleSupported.GROUP.value
+        else:
+            scales_size13 = 1
+            scales_size2 = 1
+            strategy = FusedMoeWeightScaleSupported.CHANNEL.value
+
+        extra_weight_attrs.update({"quant_method": strategy, "is_transposed": True})
+        # Fused gate_up_proj (column parallel)
+        w13_qweight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size // self.quant_config.pack_factor,
+                2 * intermediate_size_per_partition,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_qweight", w13_qweight)
+        set_weight_attrs(w13_qweight, extra_weight_attrs)
+        # down_proj (row parallel)
+        w2_qweight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                intermediate_size_per_partition // self.quant_config.pack_factor,
+                hidden_size,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_qweight", w2_qweight)
+        set_weight_attrs(w2_qweight, extra_weight_attrs)
+        # up_proj scales
+        w13_scales = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                scales_size13,
+                2 * intermediate_size_per_partition,
+                dtype=torch.half,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_scales", w13_scales)
+        set_weight_attrs(w13_scales, extra_weight_attrs)
+        # down_proj scales
+        w2_scales = torch.nn.Parameter(
+            torch.empty(num_experts, scales_size2, hidden_size, dtype=torch.half),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_scales", w2_scales)
+        set_weight_attrs(w2_scales, extra_weight_attrs)
+        # dont shard the w2 scales when running act order
+        set_weight_attrs(w2_scales, {"load_full_w2": self.quant_config.desc_act})
+        # up_proj scales
+        w13_qzeros = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                scales_size13,
+                2 * intermediate_size_per_partition // self.quant_config.pack_factor,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_qzeros", w13_qzeros)
+        set_weight_attrs(w13_qzeros, extra_weight_attrs)
+        # down_proj scales
+        w2_qzeros = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                scales_size2,
+                hidden_size // self.quant_config.pack_factor,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_qzeros", w2_qzeros)
+        set_weight_attrs(w2_qzeros, extra_weight_attrs)
+        # dont shard the w2 scales when running act order
+        set_weight_attrs(w2_qzeros, {"load_full_w2": self.quant_config.desc_act})
+        w13_g_idx = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_g_idx", w13_g_idx)
+        set_weight_attrs(w13_g_idx, extra_weight_attrs)
+        w2_g_idx = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                intermediate_size_per_partition,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_g_idx", w2_g_idx)
+        set_weight_attrs(w2_g_idx, extra_weight_attrs)
+        w13_g_idx_sort_indices = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_g_idx_sort_indices", w13_g_idx_sort_indices)
+        set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs)
+        w2_g_idx_sort_indices = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                intermediate_size_per_partition,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_g_idx_sort_indices", w2_g_idx_sort_indices)
+        set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+
+        # Process act_order
+        if self.quant_config.desc_act:
+            # Get sorting based on g_idx
+            num_experts = layer.w13_g_idx.shape[0]
+            w13_g_idx_sort_indices = torch.empty_like(layer.w13_g_idx)
+            w2_g_idx_sort_indices = torch.empty_like(layer.w2_g_idx)
+            w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx)
+            w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx)
+            for e in range(num_experts):
+                w13_g_idx_sort_indices[e] = torch.argsort(layer.w13_g_idx[e]).to(
+                    torch.int32
+                )
+                w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_g_idx[e]).to(
+                    torch.int32
+                )
+                w13_sorted_g_idx[e] = layer.w13_g_idx[e][w13_g_idx_sort_indices[e]]
+                w2_sorted_g_idx[e] = layer.w2_g_idx[e][w2_g_idx_sort_indices[e]]
+            replace_parameter(layer, "w13_g_idx", w13_sorted_g_idx)
+            replace_parameter(layer, "w2_g_idx", w2_sorted_g_idx)
+            replace_parameter(layer, "w13_g_idx_sort_indices", w13_g_idx_sort_indices)
+            replace_parameter(layer, "w2_g_idx_sort_indices", w2_g_idx_sort_indices)
+        else:
+            # Reset g_idx related tensors
+            num_experts = layer.w13_g_idx.shape[0]
+            device = layer.w13_g_idx.device
+            layer.w13_g_idx = torch.nn.Parameter(
+                torch.empty((num_experts, 0), dtype=torch.int32, device=device),
+                requires_grad=False,
+            )
+            layer.w2_g_idx = torch.nn.Parameter(
+                torch.empty((num_experts, 0), dtype=torch.int32, device=device),
+                requires_grad=False,
+            )
+            layer.w13_g_idx_sort_indices = torch.nn.Parameter(
+                torch.empty((num_experts, 0), dtype=torch.int32, device=device),
+                requires_grad=False,
+            )
+            layer.w2_g_idx_sort_indices = torch.nn.Parameter(
+                torch.empty((num_experts, 0), dtype=torch.int32, device=device),
+                requires_grad=False,
+            )
+        # Repack weights
+        marlin_w13_qweight = ops.gptq_marlin_moe_repack(
+            layer.w13_qweight,
+            layer.w13_g_idx_sort_indices,
+            layer.w13_qweight.shape[1] * self.quant_config.pack_factor,
+            layer.w13_qweight.shape[2],
+            self.quant_config.quant_type.size_bits,
+        )
+        replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
+        marlin_w2_qweight = ops.gptq_marlin_moe_repack(
+            layer.w2_qweight,
+            layer.w2_g_idx_sort_indices,
+            layer.w2_qweight.shape[1] * self.quant_config.pack_factor,
+            layer.w2_qweight.shape[2],
+            self.quant_config.quant_type.size_bits,
+        )
+        replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
+        # Repack scales
+        marlin_w13_scales = marlin_moe_permute_scales(
+            s=layer.w13_scales,
+            size_k=layer.intermediate_size_per_partition,
+            size_n=layer.w13_scales.shape[2],
+            group_size=self.quant_config.group_size,
+        )
+        replace_parameter(layer, "w13_scales", marlin_w13_scales)
+        marlin_w2_scales = marlin_moe_permute_scales(
+            s=layer.w2_scales,
+            size_k=layer.w2_scales.shape[1]
+            * (
+                self.quant_config.group_size
+                if self.quant_config.group_size != -1
+                else self.quant_config.pack_factor
+            ),
+            size_n=layer.w2_scales.shape[2],
+            group_size=self.quant_config.group_size,
+        )
+        replace_parameter(layer, "w2_scales", marlin_w2_scales)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool = False,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
+    ) -> torch.Tensor:
+        assert activation == "silu", "Only SiLU activation is supported."
+
+        # The input must currently be float16
+        orig_dtype = x.dtype
+        x = x.half()
+
+        topk_weights, topk_ids = FusedMoE.select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias,
+        )
+
+        return torch.ops.vllm.fused_marlin_moe(
+            x,
+            layer.w13_qweight,
+            layer.w2_qweight,
+            layer.w13_scales,
+            layer.w2_scales,
+            router_logits,
+            topk_weights,
+            topk_ids,
+            g_idx1=layer.w13_g_idx,
+            g_idx2=layer.w2_g_idx,
+            sort_indices1=layer.w13_g_idx_sort_indices,
+            sort_indices2=layer.w2_g_idx_sort_indices,
+            num_bits=self.quant_config.quant_type.size_bits,
+            is_k_full=self.is_k_full,
+        ).to(orig_dtype)
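As a reading aid for `create_weights` in the new `GPTQMarlinMoEMethod` above, the shape arithmetic below plugs in hypothetical dimensions (the expert count, hidden size, intermediate size, and the int32/4-bit packing factor are illustrative assumptions, not values from the diff):

```python
# Hypothetical model dimensions, for illustration only.
num_experts = 8
hidden_size = 4096
intermediate_size_per_partition = 11008
group_size = 128
pack_factor = 32 // 4  # assuming eight 4-bit weights packed per int32 word

# Grouped scales: one row of scales per group_size input channels.
scales_size13 = hidden_size // group_size  # 32

# Fused gate_up_proj weights, packed along the input dimension.
w13_qweight_shape = (
    num_experts,
    hidden_size // pack_factor,           # 512
    2 * intermediate_size_per_partition,  # 22016
)
w13_scales_shape = (num_experts, scales_size13, 2 * intermediate_size_per_partition)

print(w13_qweight_shape)  # (8, 512, 22016)
print(w13_scales_shape)   # (8, 32, 22016)
```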
sglang/srt/layers/quantization/int8_kernel.py
CHANGED
@@ -22,9 +22,11 @@ def _per_token_quant_int8(
     x_ptr,
     xq_ptr,
     scale_ptr,
+    x_sum_ptr,
     stride_x,
     stride_xq,
     N,
+    CAL_SUM: tl.constexpr,
     BLOCK: tl.constexpr,
 ):
     # Adapted from https://github.com/InternLM/lmdeploy/blob/086481ed84b59bee3b8e4274e5fc69620040c048/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py#L282
@@ -38,16 +40,23 @@ def _per_token_quant_int8(
     scale_x = absmax / 127
     x_q = x * (127 / absmax)
     x_q = tl.extra.cuda.libdevice.round(x_q).to(tl.int8)
+    if CAL_SUM:
+        x_sum = tl.sum(x, axis=0)
+        tl.store(x_sum_ptr + row_id, x_sum.to(x_sum_ptr.dtype.element_ty))

     tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
-    tl.store(scale_ptr + row_id, scale_x)
+    tl.store(scale_ptr + row_id, scale_x.to(scale_ptr.dtype.element_ty))


-def per_token_quant_int8(x):
+def per_token_quant_int8(x, scale_dtype=torch.float32, cal_sum=False):
     M = x.numel() // x.shape[-1]
     N = x.shape[-1]
     x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
-    scales = torch.empty(x.shape[:-1] + (1,), device=x.device, dtype=
+    scales = torch.empty(x.shape[:-1] + (1,), device=x.device, dtype=scale_dtype)
+    if cal_sum:
+        x_sum = torch.empty(x.shape[:-1], device=x.device, dtype=x.dtype)
+    else:
+        x_sum = None
     BLOCK = triton.next_power_of_2(N)
     # heuristics for number of warps
     num_warps = min(max(BLOCK // 256, 1), 8)
@@ -57,15 +66,19 @@ def per_token_quant_int8(x):
         x,
         x_q,
         scales,
+        x_sum,
         stride_x=x.stride(-2),
         stride_xq=x_q.stride(-2),
         N=N,
+        CAL_SUM=cal_sum,
         BLOCK=BLOCK,
         num_warps=num_warps,
         num_stages=1,
     )
-
-
+    if cal_sum:
+        return x_q, scales, x_sum
+    else:
+        return x_q, scales


 @triton.jit
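A minimal calling sketch for the extended `per_token_quant_int8` (assuming a CUDA device with Triton available; the tensor shape is made up): without `cal_sum` it returns the quantized tensor and per-token scales, while `cal_sum=True` additionally returns per-token sums, which the QoQ per-channel GEMM path in qoq.py below consumes.

```python
import torch

from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8

# Hypothetical activation tensor; quantization is per token over the last dim.
x = torch.randn(8, 4096, dtype=torch.float16, device="cuda")

# Previous behavior: int8 values plus one scale per token (float32 by default).
x_q, scales = per_token_quant_int8(x)

# New options: keep scales in the input dtype and also compute per-token sums.
x_q, scales, x_sum = per_token_quant_int8(x, scale_dtype=x.dtype, cal_sum=True)
```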
sglang/srt/layers/quantization/qoq.py
ADDED
@@ -0,0 +1,244 @@
+from typing import Any, Callable, Dict, List, Optional
+
+import torch
+from torch.nn.parameter import Parameter
+
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.layers.linear import LinearMethodBase
+from sglang.srt.layers.parameter import (
+    ChannelQuantScaleParameter,
+    GroupQuantScaleParameter,
+    ModelWeightParameter,
+)
+from sglang.srt.layers.quantization.base_config import (
+    QuantizationConfig,
+    QuantizeMethodBase,
+)
+from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
+from sglang.srt.utils import is_cuda
+
+_is_cuda = is_cuda()
+if _is_cuda:
+    from sgl_kernel import qserve_w4a8_per_chn_gemm, qserve_w4a8_per_group_gemm
+
+
+QoQ_SUPPORTED_WEIGHT_BITS = [4]
+QoQ_SUPPORTED_GROUP_SIZES = [-1, 128]
+
+
+class QoQConfig(QuantizationConfig):
+    """Config class for QoQ Quantization.
+
+    - Weight: static, per-channel/group, asymmetric
+    - Activation: dynamic, per-token, symmetric
+
+    Reference: https://arxiv.org/abs/2405.04532
+               https://github.com/mit-han-lab/omniserve
+    """
+
+    def __init__(self, weight_bits: int, group_size: int) -> None:
+        self.weight_bits = weight_bits
+        self.group_size = group_size
+
+        # Verify
+        if self.weight_bits not in QoQ_SUPPORTED_WEIGHT_BITS:
+            raise ValueError(
+                f"QoQ does not support weight_bits = {self.weight_bits}. "
+                f"Only weight_bits = {QoQ_SUPPORTED_WEIGHT_BITS} "
+                "are supported."
+            )
+        if self.group_size not in QoQ_SUPPORTED_GROUP_SIZES:
+            raise ValueError(
+                f"QoQ does not support group_size = {self.group_size}. "
+                f"Only group_sizes = {QoQ_SUPPORTED_GROUP_SIZES} "
+                "are supported."
+            )
+
+        # 4 bits packed into 8 bit datatype.
+        self.pack_factor = 8 // self.weight_bits
+
+    def __repr__(self) -> str:
+        return "QoQConfig(weight_bits={}, group_size={})".format(
+            self.weight_bits, self.group_size
+        )
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        return [torch.float16]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 80
+
+    @classmethod
+    def get_name(self) -> str:
+        return "qoq"
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        """List of filenames to search for in the model directory."""
+        return [
+            "quant_config.json",
+            "quantize_config.json",
+        ]
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "QoQConfig":
+        weight_bits = cls.get_from_keys(config, ["wbits"])
+        group_size = cls.get_from_keys(config, ["group_size"])
+        return cls(weight_bits, group_size)
+
+    def get_quant_method(
+        self,
+        layer: torch.nn.Module,
+        prefix: str,
+    ) -> Optional["QuantizeMethodBase"]:
+        from sglang.srt.layers.linear import LinearBase
+
+        if isinstance(layer, LinearBase):
+            return QoQLinearMethod(self)
+        return None
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
+
+class QoQLinearMethod(LinearMethodBase):
+    """Linear method for QoQ.
+
+    Args:
+        quant_config: The QoQ quantization config.
+    """
+
+    def __init__(self, quant_config: QoQConfig):
+        self.quant_config = quant_config
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: List[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+
+        weight_loader = extra_weight_attrs.get("weight_loader")
+
+        # Validate output_size_per_partition
+        output_size_per_partition = sum(output_partition_sizes)
+        if output_size_per_partition % 32 != 0:
+            raise ValueError(
+                f"Weight output_size_per_partition = "
+                f"{output_size_per_partition} is not divisible by 32."
+            )
+
+        # Validate input_size_per_partition
+        if input_size_per_partition % self.quant_config.pack_factor != 0:
+            raise ValueError(
+                f"Weight input_size_per_partition = "
+                f"{input_size_per_partition} is not divisible by "
+                f"pack_factor = {self.quant_config.pack_factor}."
+            )
+        if (
+            self.quant_config.group_size != -1
+            and input_size_per_partition % self.quant_config.group_size != 0
+        ):
+            raise ValueError(
+                f"Weight input_size_per_partition = "
+                f"{input_size_per_partition} is not divisible by "
+                f"group_size = {self.quant_config.group_size}."
+            )
+
+        qweight = ModelWeightParameter(
+            data=torch.empty(
+                output_size_per_partition,
+                input_size_per_partition // self.quant_config.pack_factor,
+                dtype=torch.int8,
+            ),
+            input_dim=1,
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("qweight", qweight)
+
+        s1_scales = ChannelQuantScaleParameter(
+            data=torch.empty(output_size_per_partition, dtype=torch.float16),
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("s1_scales", s1_scales)
+
+        if self.quant_config.group_size == -1:
+            s1_szeros = ChannelQuantScaleParameter(
+                data=torch.empty(output_size_per_partition, dtype=torch.float16),
+                output_dim=0,
+                weight_loader=weight_loader,
+            )
+            layer.register_parameter("s1_szeros", s1_szeros)
+        else:
+            s2_scales = GroupQuantScaleParameter(
+                data=torch.empty(
+                    (
+                        input_size_per_partition // self.quant_config.group_size,
+                        output_size_per_partition,
+                    ),
+                    dtype=torch.int8,
+                ),
+                input_dim=0,
+                output_dim=1,
+                weight_loader=weight_loader,
+            )
+            layer.register_parameter("s2_scales", s2_scales)
+
+            s2_zeros = GroupQuantScaleParameter(
+                data=torch.empty(
+                    (
+                        input_size_per_partition // self.quant_config.group_size,
+                        output_size_per_partition,
+                    ),
+                    dtype=torch.int8,
+                ),
+                input_dim=0,
+                output_dim=1,
+                weight_loader=weight_loader,
+            )
+            layer.register_parameter("s2_zeros", s2_zeros)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
+        layer.s1_scales = Parameter(layer.s1_scales.data, requires_grad=False)
+        if self.quant_config.group_size == -1:
+            layer.s1_szeros = Parameter(layer.s1_szeros.data, requires_grad=False)
+        else:
+            layer.s2_scales = Parameter(layer.s2_scales.data, requires_grad=False)
+            layer.s2_zeros = Parameter(layer.s2_zeros.data, requires_grad=False)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ):
+        assert x.dtype == torch.float16, "QoQ only supports float16 input now"
+        if self.quant_config.group_size == -1:
+            x_q, x_scale, x_sum = per_token_quant_int8(
+                x, scale_dtype=x.dtype, cal_sum=True
+            )
+            out = qserve_w4a8_per_chn_gemm(
+                x_q, layer.qweight, layer.s1_scales, x_scale, layer.s1_szeros, x_sum
+            )
+        else:
+            x_q, x_scale = per_token_quant_int8(x, scale_dtype=x.dtype)
+            out = qserve_w4a8_per_group_gemm(
+                x_q,
+                layer.qweight,
+                layer.s2_zeros,
+                layer.s2_scales,
+                layer.s1_scales,
+                x_scale,
+            )
+        if bias is not None:
+            out = out + bias
+        return out
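A small usage sketch for the new config class (the dict below is invented; `from_config` reads the `wbits` and `group_size` keys from `quant_config.json` / `quantize_config.json`, and only 4-bit weights with group size -1 or 128 are accepted):

```python
from sglang.srt.layers.quantization.qoq import QoQConfig

# Hypothetical contents of a checkpoint's quant_config.json.
quant_cfg = {"wbits": 4, "group_size": 128}

cfg = QoQConfig.from_config(quant_cfg)
print(cfg)              # QoQConfig(weight_bits=4, group_size=128)
print(cfg.pack_factor)  # 2 -- two 4-bit weights packed per int8
```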
sglang/srt/lora/lora_manager.py
CHANGED
@@ -170,9 +170,7 @@ class LoRAManager:
             dim=0,
             out=self.cuda_graph_batch_info.seg_indptr[1 : bs + 1],
         )
-        self.cuda_graph_batch_info.max_len =
-            torch.max(self.cuda_graph_batch_info.seg_lens[:bs])
-        )
+        self.cuda_graph_batch_info.max_len = 1

         for i, lora_path in enumerate(forward_batch.lora_paths):
             self.cuda_graph_batch_info.weight_indices[i] = (
|