sglang 0.4.6.post4__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl

This diff compares the contents of two publicly released package versions as they appear in their public registry. It is provided for informational purposes only.
Files changed (130)
  1. sglang/bench_offline_throughput.py +6 -6
  2. sglang/bench_one_batch.py +5 -4
  3. sglang/bench_one_batch_server.py +23 -15
  4. sglang/bench_serving.py +133 -57
  5. sglang/compile_deep_gemm.py +4 -4
  6. sglang/srt/configs/model_config.py +39 -28
  7. sglang/srt/conversation.py +1 -1
  8. sglang/srt/disaggregation/decode.py +122 -133
  9. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
  10. sglang/srt/disaggregation/fake/conn.py +3 -13
  11. sglang/srt/disaggregation/kv_events.py +357 -0
  12. sglang/srt/disaggregation/mini_lb.py +57 -24
  13. sglang/srt/disaggregation/mooncake/conn.py +11 -2
  14. sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
  15. sglang/srt/disaggregation/nixl/conn.py +9 -19
  16. sglang/srt/disaggregation/prefill.py +126 -44
  17. sglang/srt/disaggregation/utils.py +116 -5
  18. sglang/srt/distributed/utils.py +3 -3
  19. sglang/srt/entrypoints/EngineBase.py +5 -0
  20. sglang/srt/entrypoints/engine.py +28 -8
  21. sglang/srt/entrypoints/http_server.py +6 -4
  22. sglang/srt/entrypoints/http_server_engine.py +5 -2
  23. sglang/srt/function_call/base_format_detector.py +250 -0
  24. sglang/srt/function_call/core_types.py +34 -0
  25. sglang/srt/function_call/deepseekv3_detector.py +157 -0
  26. sglang/srt/function_call/ebnf_composer.py +234 -0
  27. sglang/srt/function_call/function_call_parser.py +175 -0
  28. sglang/srt/function_call/llama32_detector.py +74 -0
  29. sglang/srt/function_call/mistral_detector.py +84 -0
  30. sglang/srt/function_call/pythonic_detector.py +163 -0
  31. sglang/srt/function_call/qwen25_detector.py +67 -0
  32. sglang/srt/function_call/utils.py +35 -0
  33. sglang/srt/hf_transformers_utils.py +46 -7
  34. sglang/srt/layers/attention/aiter_backend.py +513 -0
  35. sglang/srt/layers/attention/flashattention_backend.py +63 -17
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
  37. sglang/srt/layers/attention/flashmla_backend.py +340 -78
  38. sglang/srt/layers/attention/triton_backend.py +3 -0
  39. sglang/srt/layers/attention/utils.py +2 -2
  40. sglang/srt/layers/attention/vision.py +1 -1
  41. sglang/srt/layers/communicator.py +451 -0
  42. sglang/srt/layers/dp_attention.py +0 -10
  43. sglang/srt/layers/moe/cutlass_moe.py +207 -0
  44. sglang/srt/layers/moe/ep_moe/kernels.py +33 -11
  45. sglang/srt/layers/moe/ep_moe/layer.py +104 -50
  46. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
  47. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
  48. sglang/srt/layers/moe/topk.py +66 -9
  49. sglang/srt/layers/multimodal.py +70 -0
  50. sglang/srt/layers/quantization/__init__.py +7 -2
  51. sglang/srt/layers/quantization/deep_gemm.py +5 -3
  52. sglang/srt/layers/quantization/fp8.py +90 -0
  53. sglang/srt/layers/quantization/fp8_utils.py +6 -0
  54. sglang/srt/layers/quantization/gptq.py +298 -6
  55. sglang/srt/layers/quantization/int8_kernel.py +18 -5
  56. sglang/srt/layers/quantization/qoq.py +244 -0
  57. sglang/srt/lora/lora_manager.py +1 -3
  58. sglang/srt/managers/deepseek_eplb.py +278 -0
  59. sglang/srt/managers/eplb_manager.py +55 -0
  60. sglang/srt/managers/expert_distribution.py +704 -56
  61. sglang/srt/managers/expert_location.py +394 -0
  62. sglang/srt/managers/expert_location_dispatch.py +91 -0
  63. sglang/srt/managers/io_struct.py +16 -3
  64. sglang/srt/managers/mm_utils.py +293 -139
  65. sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
  66. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
  67. sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
  68. sglang/srt/managers/multimodal_processors/internvl.py +14 -5
  69. sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
  70. sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
  71. sglang/srt/managers/multimodal_processors/llava.py +3 -3
  72. sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
  73. sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
  74. sglang/srt/managers/multimodal_processors/pixtral.py +9 -9
  75. sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
  76. sglang/srt/managers/schedule_batch.py +49 -21
  77. sglang/srt/managers/schedule_policy.py +4 -5
  78. sglang/srt/managers/scheduler.py +92 -50
  79. sglang/srt/managers/session_controller.py +1 -1
  80. sglang/srt/managers/tokenizer_manager.py +99 -24
  81. sglang/srt/mem_cache/base_prefix_cache.py +3 -0
  82. sglang/srt/mem_cache/chunk_cache.py +3 -1
  83. sglang/srt/mem_cache/hiradix_cache.py +4 -4
  84. sglang/srt/mem_cache/memory_pool.py +74 -52
  85. sglang/srt/mem_cache/multimodal_cache.py +45 -0
  86. sglang/srt/mem_cache/radix_cache.py +58 -5
  87. sglang/srt/metrics/collector.py +2 -2
  88. sglang/srt/mm_utils.py +10 -0
  89. sglang/srt/model_executor/cuda_graph_runner.py +20 -9
  90. sglang/srt/model_executor/expert_location_updater.py +422 -0
  91. sglang/srt/model_executor/forward_batch_info.py +4 -0
  92. sglang/srt/model_executor/model_runner.py +144 -54
  93. sglang/srt/model_loader/loader.py +10 -6
  94. sglang/srt/models/clip.py +5 -1
  95. sglang/srt/models/deepseek_v2.py +297 -343
  96. sglang/srt/models/exaone.py +8 -3
  97. sglang/srt/models/gemma3_mm.py +70 -33
  98. sglang/srt/models/llama4.py +10 -2
  99. sglang/srt/models/llava.py +26 -18
  100. sglang/srt/models/mimo_mtp.py +220 -0
  101. sglang/srt/models/minicpmo.py +5 -12
  102. sglang/srt/models/mistral.py +71 -1
  103. sglang/srt/models/mllama.py +3 -3
  104. sglang/srt/models/qwen2.py +95 -26
  105. sglang/srt/models/qwen2_5_vl.py +8 -0
  106. sglang/srt/models/qwen2_moe.py +330 -60
  107. sglang/srt/models/qwen2_vl.py +6 -0
  108. sglang/srt/models/qwen3.py +52 -10
  109. sglang/srt/models/qwen3_moe.py +411 -48
  110. sglang/srt/models/siglip.py +294 -0
  111. sglang/srt/openai_api/adapter.py +28 -16
  112. sglang/srt/openai_api/protocol.py +6 -0
  113. sglang/srt/operations.py +154 -0
  114. sglang/srt/operations_strategy.py +31 -0
  115. sglang/srt/server_args.py +134 -24
  116. sglang/srt/speculative/eagle_utils.py +131 -0
  117. sglang/srt/speculative/eagle_worker.py +47 -2
  118. sglang/srt/utils.py +68 -12
  119. sglang/test/test_cutlass_moe.py +278 -0
  120. sglang/test/test_utils.py +2 -36
  121. sglang/utils.py +2 -2
  122. sglang/version.py +1 -1
  123. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +20 -11
  124. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +128 -102
  125. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
  126. sglang/srt/function_call_parser.py +0 -858
  127. sglang/srt/platforms/interface.py +0 -371
  128. /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
  129. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
  130. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/gptq.py

@@ -1,21 +1,28 @@
 import logging
 from fractions import Fraction
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
 
-from sglang.srt.layers.linear import LinearBase
-from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.linear import LinearBase, set_weight_attrs
+from sglang.srt.layers.quantization.base_config import (
+    QuantizationConfig,
+    QuantizeMethodBase,
+)
+from sglang.srt.layers.quantization.utils import replace_parameter
 from sglang.srt.utils import is_cuda
 
 _is_cuda = is_cuda()
 
 try:
-    from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
+    from vllm import _custom_ops as ops
     from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
     from vllm.model_executor.layers.quantization.gptq_marlin import (
+        FusedMoE,
+        FusedMoEMethodBase,
+        FusedMoeWeightScaleSupported,
         GPTQMarlinLinearMethod,
-        GPTQMarlinMoEMethod,
+        marlin_moe_permute_scales,
     )
     from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod
     from vllm.model_executor.layers.quantization.utils.marlin_utils import (
@@ -27,7 +34,9 @@ try:
 except ImportError:
     VLLM_AVAILABLE = False
 
-    GPTQLinearMethod = MarlinLinearMethod = QuantizeMethodBase = Any
+    GPTQLinearMethod = MarlinLinearMethod = Any
+
+    FusedMoEMethodBase = QuantizeMethodBase
 
     class scalar_types:
         uint4b8 = "uint4b8"
@@ -437,3 +446,286 @@ class MarlinConfig(QuantizationConfig):
         ):
             return MarlinLinearMethod(self)
         return None
+
+
+class GPTQMarlinMoEMethod(FusedMoEMethodBase):
+    """MoE Marlin method with quantization."""
+
+    def __init__(self, quant_config: GPTQMarlinConfig) -> None:
+        self.quant_config = quant_config
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size_per_partition: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        intermediate_size = extra_weight_attrs.pop("intermediate_size")
+
+        self.is_k_full = (not self.quant_config.desc_act) or (
+            intermediate_size_per_partition == intermediate_size
+        )
+
+        if self.quant_config.group_size != -1:
+            scales_size13 = hidden_size // self.quant_config.group_size
+            w2_scales_size = (
+                intermediate_size
+                if self.quant_config.desc_act
+                else intermediate_size_per_partition
+            )
+            scales_size2 = w2_scales_size // self.quant_config.group_size
+            strategy = FusedMoeWeightScaleSupported.GROUP.value
+        else:
+            scales_size13 = 1
+            scales_size2 = 1
+            strategy = FusedMoeWeightScaleSupported.CHANNEL.value
+
+        extra_weight_attrs.update({"quant_method": strategy, "is_transposed": True})
+        # Fused gate_up_proj (column parallel)
+        w13_qweight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size // self.quant_config.pack_factor,
+                2 * intermediate_size_per_partition,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_qweight", w13_qweight)
+        set_weight_attrs(w13_qweight, extra_weight_attrs)
+        # down_proj (row parallel)
+        w2_qweight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                intermediate_size_per_partition // self.quant_config.pack_factor,
+                hidden_size,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_qweight", w2_qweight)
+        set_weight_attrs(w2_qweight, extra_weight_attrs)
+        # up_proj scales
+        w13_scales = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                scales_size13,
+                2 * intermediate_size_per_partition,
+                dtype=torch.half,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_scales", w13_scales)
+        set_weight_attrs(w13_scales, extra_weight_attrs)
+        # down_proj scales
+        w2_scales = torch.nn.Parameter(
+            torch.empty(num_experts, scales_size2, hidden_size, dtype=torch.half),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_scales", w2_scales)
+        set_weight_attrs(w2_scales, extra_weight_attrs)
+        # dont shard the w2 scales when running act order
+        set_weight_attrs(w2_scales, {"load_full_w2": self.quant_config.desc_act})
+        # up_proj scales
+        w13_qzeros = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                scales_size13,
+                2 * intermediate_size_per_partition // self.quant_config.pack_factor,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_qzeros", w13_qzeros)
+        set_weight_attrs(w13_qzeros, extra_weight_attrs)
+        # down_proj scales
+        w2_qzeros = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                scales_size2,
+                hidden_size // self.quant_config.pack_factor,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_qzeros", w2_qzeros)
+        set_weight_attrs(w2_qzeros, extra_weight_attrs)
+        # dont shard the w2 scales when running act order
+        set_weight_attrs(w2_qzeros, {"load_full_w2": self.quant_config.desc_act})
+        w13_g_idx = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_g_idx", w13_g_idx)
+        set_weight_attrs(w13_g_idx, extra_weight_attrs)
+        w2_g_idx = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                intermediate_size_per_partition,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_g_idx", w2_g_idx)
+        set_weight_attrs(w2_g_idx, extra_weight_attrs)
+        w13_g_idx_sort_indices = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_g_idx_sort_indices", w13_g_idx_sort_indices)
+        set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs)
+        w2_g_idx_sort_indices = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                intermediate_size_per_partition,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_g_idx_sort_indices", w2_g_idx_sort_indices)
+        set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+
+        # Process act_order
+        if self.quant_config.desc_act:
+            # Get sorting based on g_idx
+            num_experts = layer.w13_g_idx.shape[0]
+            w13_g_idx_sort_indices = torch.empty_like(layer.w13_g_idx)
+            w2_g_idx_sort_indices = torch.empty_like(layer.w2_g_idx)
+            w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx)
+            w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx)
+            for e in range(num_experts):
+                w13_g_idx_sort_indices[e] = torch.argsort(layer.w13_g_idx[e]).to(
+                    torch.int32
+                )
+                w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_g_idx[e]).to(
+                    torch.int32
+                )
+                w13_sorted_g_idx[e] = layer.w13_g_idx[e][w13_g_idx_sort_indices[e]]
+                w2_sorted_g_idx[e] = layer.w2_g_idx[e][w2_g_idx_sort_indices[e]]
+            replace_parameter(layer, "w13_g_idx", w13_sorted_g_idx)
+            replace_parameter(layer, "w2_g_idx", w2_sorted_g_idx)
+            replace_parameter(layer, "w13_g_idx_sort_indices", w13_g_idx_sort_indices)
+            replace_parameter(layer, "w2_g_idx_sort_indices", w2_g_idx_sort_indices)
+        else:
+            # Reset g_idx related tensors
+            num_experts = layer.w13_g_idx.shape[0]
+            device = layer.w13_g_idx.device
+            layer.w13_g_idx = torch.nn.Parameter(
+                torch.empty((num_experts, 0), dtype=torch.int32, device=device),
+                requires_grad=False,
+            )
+            layer.w2_g_idx = torch.nn.Parameter(
+                torch.empty((num_experts, 0), dtype=torch.int32, device=device),
+                requires_grad=False,
+            )
+            layer.w13_g_idx_sort_indices = torch.nn.Parameter(
+                torch.empty((num_experts, 0), dtype=torch.int32, device=device),
+                requires_grad=False,
+            )
+            layer.w2_g_idx_sort_indices = torch.nn.Parameter(
+                torch.empty((num_experts, 0), dtype=torch.int32, device=device),
+                requires_grad=False,
+            )
+        # Repack weights
+        marlin_w13_qweight = ops.gptq_marlin_moe_repack(
+            layer.w13_qweight,
+            layer.w13_g_idx_sort_indices,
+            layer.w13_qweight.shape[1] * self.quant_config.pack_factor,
+            layer.w13_qweight.shape[2],
+            self.quant_config.quant_type.size_bits,
+        )
+        replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
+        marlin_w2_qweight = ops.gptq_marlin_moe_repack(
+            layer.w2_qweight,
+            layer.w2_g_idx_sort_indices,
+            layer.w2_qweight.shape[1] * self.quant_config.pack_factor,
+            layer.w2_qweight.shape[2],
+            self.quant_config.quant_type.size_bits,
+        )
+        replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
+        # Repack scales
+        marlin_w13_scales = marlin_moe_permute_scales(
+            s=layer.w13_scales,
+            size_k=layer.intermediate_size_per_partition,
+            size_n=layer.w13_scales.shape[2],
+            group_size=self.quant_config.group_size,
+        )
+        replace_parameter(layer, "w13_scales", marlin_w13_scales)
+        marlin_w2_scales = marlin_moe_permute_scales(
+            s=layer.w2_scales,
+            size_k=layer.w2_scales.shape[1]
+            * (
+                self.quant_config.group_size
+                if self.quant_config.group_size != -1
+                else self.quant_config.pack_factor
+            ),
+            size_n=layer.w2_scales.shape[2],
+            group_size=self.quant_config.group_size,
+        )
+        replace_parameter(layer, "w2_scales", marlin_w2_scales)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool = False,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
+    ) -> torch.Tensor:
+        assert activation == "silu", "Only SiLU activation is supported."
+
+        # The input must currently be float16
+        orig_dtype = x.dtype
+        x = x.half()
+
+        topk_weights, topk_ids = FusedMoE.select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias,
+        )
+
+        return torch.ops.vllm.fused_marlin_moe(
+            x,
+            layer.w13_qweight,
+            layer.w2_qweight,
+            layer.w13_scales,
+            layer.w2_scales,
+            router_logits,
+            topk_weights,
+            topk_ids,
+            g_idx1=layer.w13_g_idx,
+            g_idx2=layer.w2_g_idx,
+            sort_indices1=layer.w13_g_idx_sort_indices,
+            sort_indices2=layer.w2_g_idx_sort_indices,
+            num_bits=self.quant_config.quant_type.size_bits,
+            is_k_full=self.is_k_full,
+        ).to(orig_dtype)
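
Note on the new GPTQMarlinMoEMethod.create_weights above: the packed-weight and scale shapes follow directly from the layer geometry, the GPTQ group_size, and the pack_factor. The snippet below is only a minimal sketch of that shape arithmetic using illustrative, assumed dimensions (and desc_act=False); none of the concrete numbers come from this diff.

# Hypothetical MoE geometry, chosen only to illustrate the shape math in create_weights.
num_experts = 8
hidden_size = 4096                      # assumed model width
intermediate_size_per_partition = 1408  # assumed per-rank expert FFN width
group_size = 128                        # GPTQ group size; -1 switches to per-channel scales
pack_factor = 8                         # one int32 packs eight 4-bit weights

scales_size13 = hidden_size // group_size                     # 32 scale groups along K for w13
scales_size2 = intermediate_size_per_partition // group_size  # 11 groups (desc_act=False case)

w13_qweight_shape = (num_experts, hidden_size // pack_factor, 2 * intermediate_size_per_partition)
w13_scales_shape = (num_experts, scales_size13, 2 * intermediate_size_per_partition)
w2_qweight_shape = (num_experts, intermediate_size_per_partition // pack_factor, hidden_size)
w2_scales_shape = (num_experts, scales_size2, hidden_size)
print(w13_qweight_shape, w13_scales_shape, w2_qweight_shape, w2_scales_shape)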
sglang/srt/layers/quantization/int8_kernel.py

@@ -22,9 +22,11 @@ def _per_token_quant_int8(
     x_ptr,
     xq_ptr,
     scale_ptr,
+    x_sum_ptr,
     stride_x,
     stride_xq,
     N,
+    CAL_SUM: tl.constexpr,
     BLOCK: tl.constexpr,
 ):
     # Adapted from https://github.com/InternLM/lmdeploy/blob/086481ed84b59bee3b8e4274e5fc69620040c048/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py#L282
@@ -38,16 +40,23 @@ def _per_token_quant_int8(
     scale_x = absmax / 127
     x_q = x * (127 / absmax)
     x_q = tl.extra.cuda.libdevice.round(x_q).to(tl.int8)
+    if CAL_SUM:
+        x_sum = tl.sum(x, axis=0)
+        tl.store(x_sum_ptr + row_id, x_sum.to(x_sum_ptr.dtype.element_ty))
 
     tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
-    tl.store(scale_ptr + row_id, scale_x)
+    tl.store(scale_ptr + row_id, scale_x.to(scale_ptr.dtype.element_ty))
 
 
-def per_token_quant_int8(x):
+def per_token_quant_int8(x, scale_dtype=torch.float32, cal_sum=False):
     M = x.numel() // x.shape[-1]
     N = x.shape[-1]
     x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
-    scales = torch.empty(x.shape[:-1] + (1,), device=x.device, dtype=torch.float32)
+    scales = torch.empty(x.shape[:-1] + (1,), device=x.device, dtype=scale_dtype)
+    if cal_sum:
+        x_sum = torch.empty(x.shape[:-1], device=x.device, dtype=x.dtype)
+    else:
+        x_sum = None
     BLOCK = triton.next_power_of_2(N)
     # heuristics for number of warps
     num_warps = min(max(BLOCK // 256, 1), 8)
@@ -57,15 +66,19 @@ def per_token_quant_int8(x):
         x,
         x_q,
         scales,
+        x_sum,
         stride_x=x.stride(-2),
         stride_xq=x_q.stride(-2),
         N=N,
+        CAL_SUM=cal_sum,
         BLOCK=BLOCK,
         num_warps=num_warps,
         num_stages=1,
     )
-
-    return x_q, scales
+    if cal_sum:
+        return x_q, scales, x_sum
+    else:
+        return x_q, scales
 
 
 @triton.jit
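
The updated per_token_quant_int8 keeps its original two-value return by default and only adds the per-token sum when cal_sum=True (that sum is what the QoQ per-channel GEMM below consumes). A minimal usage sketch, assuming a CUDA device with Triton available:

import torch

from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8

x = torch.randn(8, 4096, dtype=torch.float16, device="cuda")

# Default path: int8 values plus one float32 scale per token.
x_q, scales = per_token_quant_int8(x)

# QoQ-style path: scales kept in the input dtype plus a per-token sum of x.
x_q, scales, x_sum = per_token_quant_int8(x, scale_dtype=x.dtype, cal_sum=True)

assert x_q.dtype == torch.int8
assert scales.shape == (8, 1) and x_sum.shape == (8,)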
sglang/srt/layers/quantization/qoq.py (new file)

@@ -0,0 +1,244 @@
+from typing import Any, Callable, Dict, List, Optional
+
+import torch
+from torch.nn.parameter import Parameter
+
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.layers.linear import LinearMethodBase
+from sglang.srt.layers.parameter import (
+    ChannelQuantScaleParameter,
+    GroupQuantScaleParameter,
+    ModelWeightParameter,
+)
+from sglang.srt.layers.quantization.base_config import (
+    QuantizationConfig,
+    QuantizeMethodBase,
+)
+from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
+from sglang.srt.utils import is_cuda
+
+_is_cuda = is_cuda()
+if _is_cuda:
+    from sgl_kernel import qserve_w4a8_per_chn_gemm, qserve_w4a8_per_group_gemm
+
+
+QoQ_SUPPORTED_WEIGHT_BITS = [4]
+QoQ_SUPPORTED_GROUP_SIZES = [-1, 128]
+
+
+class QoQConfig(QuantizationConfig):
+    """Config class for QoQ Quantization.
+
+    - Weight: static, per-channel/group, asymmetric
+    - Activation: dynamic, per-token, symmetric
+
+    Reference: https://arxiv.org/abs/2405.04532
+    https://github.com/mit-han-lab/omniserve
+    """
+
+    def __init__(self, weight_bits: int, group_size: int) -> None:
+        self.weight_bits = weight_bits
+        self.group_size = group_size
+
+        # Verify
+        if self.weight_bits not in QoQ_SUPPORTED_WEIGHT_BITS:
+            raise ValueError(
+                f"QoQ does not support weight_bits = {self.weight_bits}. "
+                f"Only weight_bits = {QoQ_SUPPORTED_WEIGHT_BITS} "
+                "are supported."
+            )
+        if self.group_size not in QoQ_SUPPORTED_GROUP_SIZES:
+            raise ValueError(
+                f"QoQ does not support group_size = {self.group_size}. "
+                f"Only group_sizes = {QoQ_SUPPORTED_GROUP_SIZES} "
+                "are supported."
+            )
+
+        # 4 bits packed into 8 bit datatype.
+        self.pack_factor = 8 // self.weight_bits
+
+    def __repr__(self) -> str:
+        return "QoQConfig(weight_bits={}, group_size={})".format(
+            self.weight_bits, self.group_size
+        )
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        return [torch.float16]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 80
+
+    @classmethod
+    def get_name(self) -> str:
+        return "qoq"
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        """List of filenames to search for in the model directory."""
+        return [
+            "quant_config.json",
+            "quantize_config.json",
+        ]
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "QoQConfig":
+        weight_bits = cls.get_from_keys(config, ["wbits"])
+        group_size = cls.get_from_keys(config, ["group_size"])
+        return cls(weight_bits, group_size)
+
+    def get_quant_method(
+        self,
+        layer: torch.nn.Module,
+        prefix: str,
+    ) -> Optional["QuantizeMethodBase"]:
+        from sglang.srt.layers.linear import LinearBase
+
+        if isinstance(layer, LinearBase):
+            return QoQLinearMethod(self)
+        return None
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
+
+class QoQLinearMethod(LinearMethodBase):
+    """Linear method for QoQ.
+
+    Args:
+        quant_config: The QoQ quantization config.
+    """
+
+    def __init__(self, quant_config: QoQConfig):
+        self.quant_config = quant_config
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: List[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+
+        weight_loader = extra_weight_attrs.get("weight_loader")
+
+        # Validate output_size_per_partition
+        output_size_per_partition = sum(output_partition_sizes)
+        if output_size_per_partition % 32 != 0:
+            raise ValueError(
+                f"Weight output_size_per_partition = "
+                f"{output_size_per_partition} is not divisible by 32."
+            )
+
+        # Validate input_size_per_partition
+        if input_size_per_partition % self.quant_config.pack_factor != 0:
+            raise ValueError(
+                f"Weight input_size_per_partition = "
+                f"{input_size_per_partition} is not divisible by "
+                f"pack_factor = {self.quant_config.pack_factor}."
+            )
+        if (
+            self.quant_config.group_size != -1
+            and input_size_per_partition % self.quant_config.group_size != 0
+        ):
+            raise ValueError(
+                f"Weight input_size_per_partition = "
+                f"{input_size_per_partition} is not divisible by "
+                f"group_size = {self.quant_config.group_size}."
+            )
+
+        qweight = ModelWeightParameter(
+            data=torch.empty(
+                output_size_per_partition,
+                input_size_per_partition // self.quant_config.pack_factor,
+                dtype=torch.int8,
+            ),
+            input_dim=1,
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("qweight", qweight)
+
+        s1_scales = ChannelQuantScaleParameter(
+            data=torch.empty(output_size_per_partition, dtype=torch.float16),
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("s1_scales", s1_scales)
+
+        if self.quant_config.group_size == -1:
+            s1_szeros = ChannelQuantScaleParameter(
+                data=torch.empty(output_size_per_partition, dtype=torch.float16),
+                output_dim=0,
+                weight_loader=weight_loader,
+            )
+            layer.register_parameter("s1_szeros", s1_szeros)
+        else:
+            s2_scales = GroupQuantScaleParameter(
+                data=torch.empty(
+                    (
+                        input_size_per_partition // self.quant_config.group_size,
+                        output_size_per_partition,
+                    ),
+                    dtype=torch.int8,
+                ),
+                input_dim=0,
+                output_dim=1,
+                weight_loader=weight_loader,
+            )
+            layer.register_parameter("s2_scales", s2_scales)
+
+            s2_zeros = GroupQuantScaleParameter(
+                data=torch.empty(
+                    (
+                        input_size_per_partition // self.quant_config.group_size,
+                        output_size_per_partition,
+                    ),
+                    dtype=torch.int8,
+                ),
+                input_dim=0,
+                output_dim=1,
+                weight_loader=weight_loader,
+            )
+            layer.register_parameter("s2_zeros", s2_zeros)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
+        layer.s1_scales = Parameter(layer.s1_scales.data, requires_grad=False)
+        if self.quant_config.group_size == -1:
+            layer.s1_szeros = Parameter(layer.s1_szeros.data, requires_grad=False)
+        else:
+            layer.s2_scales = Parameter(layer.s2_scales.data, requires_grad=False)
+            layer.s2_zeros = Parameter(layer.s2_zeros.data, requires_grad=False)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ):
+        assert x.dtype == torch.float16, "QoQ only supports float16 input now"
+        if self.quant_config.group_size == -1:
+            x_q, x_scale, x_sum = per_token_quant_int8(
+                x, scale_dtype=x.dtype, cal_sum=True
+            )
+            out = qserve_w4a8_per_chn_gemm(
+                x_q, layer.qweight, layer.s1_scales, x_scale, layer.s1_szeros, x_sum
+            )
+        else:
+            x_q, x_scale = per_token_quant_int8(x, scale_dtype=x.dtype)
+            out = qserve_w4a8_per_group_gemm(
+                x_q,
+                layer.qweight,
+                layer.s2_zeros,
+                layer.s2_scales,
+                layer.s1_scales,
+                x_scale,
+            )
+        if bias is not None:
+            out = out + bias
+        return out
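
For context, QoQConfig.from_config reads the wbits and group_size keys from the quantization config (via the get_from_keys helper inherited from QuantizationConfig), and QoQLinearMethod.apply dispatches on group_size: the per-channel variant (group_size == -1) quantizes activations with cal_sum=True and calls qserve_w4a8_per_chn_gemm, while the grouped variant calls qserve_w4a8_per_group_gemm. A minimal sketch of constructing the config, with illustrative values only:

from sglang.srt.layers.quantization.qoq import QoQConfig

# Per-channel weight scales: -1 selects the qserve_w4a8_per_chn_gemm path in apply().
chn_cfg = QoQConfig.from_config({"wbits": 4, "group_size": -1})

# Grouped weight scales (group size 128): selects qserve_w4a8_per_group_gemm.
grp_cfg = QoQConfig.from_config({"wbits": 4, "group_size": 128})

print(chn_cfg)  # QoQConfig(weight_bits=4, group_size=-1)
print(grp_cfg)  # QoQConfig(weight_bits=4, group_size=128)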
sglang/srt/lora/lora_manager.py

@@ -170,9 +170,7 @@ class LoRAManager:
             dim=0,
             out=self.cuda_graph_batch_info.seg_indptr[1 : bs + 1],
         )
-        self.cuda_graph_batch_info.max_len = int(
-            torch.max(self.cuda_graph_batch_info.seg_lens[:bs])
-        )
+        self.cuda_graph_batch_info.max_len = 1
 
         for i, lora_path in enumerate(forward_batch.lora_paths):
            self.cuda_graph_batch_info.weight_indices[i] = (