sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. sglang/bench_offline_throughput.py +10 -8
  2. sglang/bench_one_batch.py +7 -6
  3. sglang/bench_one_batch_server.py +157 -21
  4. sglang/bench_serving.py +137 -59
  5. sglang/compile_deep_gemm.py +5 -5
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +78 -78
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +2 -2
  11. sglang/srt/configs/model_config.py +40 -28
  12. sglang/srt/constrained/base_grammar_backend.py +55 -72
  13. sglang/srt/constrained/llguidance_backend.py +25 -21
  14. sglang/srt/constrained/outlines_backend.py +27 -26
  15. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  16. sglang/srt/constrained/xgrammar_backend.py +69 -43
  17. sglang/srt/conversation.py +49 -44
  18. sglang/srt/disaggregation/base/conn.py +1 -0
  19. sglang/srt/disaggregation/decode.py +129 -135
  20. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
  21. sglang/srt/disaggregation/fake/conn.py +3 -13
  22. sglang/srt/disaggregation/kv_events.py +357 -0
  23. sglang/srt/disaggregation/mini_lb.py +57 -24
  24. sglang/srt/disaggregation/mooncake/conn.py +238 -122
  25. sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
  26. sglang/srt/disaggregation/nixl/conn.py +10 -19
  27. sglang/srt/disaggregation/prefill.py +132 -47
  28. sglang/srt/disaggregation/utils.py +123 -6
  29. sglang/srt/distributed/utils.py +3 -3
  30. sglang/srt/entrypoints/EngineBase.py +5 -0
  31. sglang/srt/entrypoints/engine.py +44 -9
  32. sglang/srt/entrypoints/http_server.py +23 -6
  33. sglang/srt/entrypoints/http_server_engine.py +5 -2
  34. sglang/srt/function_call/base_format_detector.py +250 -0
  35. sglang/srt/function_call/core_types.py +34 -0
  36. sglang/srt/function_call/deepseekv3_detector.py +157 -0
  37. sglang/srt/function_call/ebnf_composer.py +234 -0
  38. sglang/srt/function_call/function_call_parser.py +175 -0
  39. sglang/srt/function_call/llama32_detector.py +74 -0
  40. sglang/srt/function_call/mistral_detector.py +84 -0
  41. sglang/srt/function_call/pythonic_detector.py +163 -0
  42. sglang/srt/function_call/qwen25_detector.py +67 -0
  43. sglang/srt/function_call/utils.py +35 -0
  44. sglang/srt/hf_transformers_utils.py +46 -7
  45. sglang/srt/layers/attention/aiter_backend.py +513 -0
  46. sglang/srt/layers/attention/flashattention_backend.py +64 -18
  47. sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
  48. sglang/srt/layers/attention/flashmla_backend.py +340 -78
  49. sglang/srt/layers/attention/triton_backend.py +3 -0
  50. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  51. sglang/srt/layers/attention/utils.py +6 -4
  52. sglang/srt/layers/attention/vision.py +1 -1
  53. sglang/srt/layers/communicator.py +451 -0
  54. sglang/srt/layers/dp_attention.py +61 -21
  55. sglang/srt/layers/layernorm.py +1 -1
  56. sglang/srt/layers/logits_processor.py +46 -11
  57. sglang/srt/layers/moe/cutlass_moe.py +207 -0
  58. sglang/srt/layers/moe/ep_moe/kernels.py +34 -12
  59. sglang/srt/layers/moe/ep_moe/layer.py +105 -51
  60. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
  63. sglang/srt/layers/moe/topk.py +67 -10
  64. sglang/srt/layers/multimodal.py +70 -0
  65. sglang/srt/layers/quantization/__init__.py +8 -3
  66. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  67. sglang/srt/layers/quantization/deep_gemm.py +77 -74
  68. sglang/srt/layers/quantization/fp8.py +92 -2
  69. sglang/srt/layers/quantization/fp8_kernel.py +3 -3
  70. sglang/srt/layers/quantization/fp8_utils.py +6 -0
  71. sglang/srt/layers/quantization/gptq.py +298 -6
  72. sglang/srt/layers/quantization/int8_kernel.py +20 -7
  73. sglang/srt/layers/quantization/qoq.py +244 -0
  74. sglang/srt/layers/sampler.py +0 -4
  75. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  76. sglang/srt/lora/lora_manager.py +2 -4
  77. sglang/srt/lora/mem_pool.py +4 -4
  78. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  79. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  80. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  81. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  82. sglang/srt/lora/utils.py +1 -1
  83. sglang/srt/managers/data_parallel_controller.py +3 -3
  84. sglang/srt/managers/deepseek_eplb.py +278 -0
  85. sglang/srt/managers/detokenizer_manager.py +21 -8
  86. sglang/srt/managers/eplb_manager.py +55 -0
  87. sglang/srt/managers/expert_distribution.py +704 -56
  88. sglang/srt/managers/expert_location.py +394 -0
  89. sglang/srt/managers/expert_location_dispatch.py +91 -0
  90. sglang/srt/managers/io_struct.py +19 -4
  91. sglang/srt/managers/mm_utils.py +294 -140
  92. sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
  93. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
  94. sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
  95. sglang/srt/managers/multimodal_processors/internvl.py +14 -5
  96. sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
  97. sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
  98. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  99. sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
  100. sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
  101. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  102. sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
  103. sglang/srt/managers/schedule_batch.py +122 -42
  104. sglang/srt/managers/schedule_policy.py +1 -5
  105. sglang/srt/managers/scheduler.py +205 -138
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/tokenizer_manager.py +232 -58
  109. sglang/srt/managers/tp_worker.py +12 -9
  110. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  111. sglang/srt/mem_cache/base_prefix_cache.py +3 -0
  112. sglang/srt/mem_cache/chunk_cache.py +3 -1
  113. sglang/srt/mem_cache/hiradix_cache.py +4 -4
  114. sglang/srt/mem_cache/memory_pool.py +76 -52
  115. sglang/srt/mem_cache/multimodal_cache.py +45 -0
  116. sglang/srt/mem_cache/radix_cache.py +58 -5
  117. sglang/srt/metrics/collector.py +314 -39
  118. sglang/srt/mm_utils.py +10 -0
  119. sglang/srt/model_executor/cuda_graph_runner.py +29 -19
  120. sglang/srt/model_executor/expert_location_updater.py +422 -0
  121. sglang/srt/model_executor/forward_batch_info.py +5 -1
  122. sglang/srt/model_executor/model_runner.py +163 -68
  123. sglang/srt/model_loader/loader.py +10 -6
  124. sglang/srt/models/clip.py +5 -1
  125. sglang/srt/models/deepseek_janus_pro.py +2 -2
  126. sglang/srt/models/deepseek_v2.py +308 -351
  127. sglang/srt/models/exaone.py +8 -3
  128. sglang/srt/models/gemma3_mm.py +70 -33
  129. sglang/srt/models/llama.py +2 -0
  130. sglang/srt/models/llama4.py +15 -8
  131. sglang/srt/models/llava.py +258 -7
  132. sglang/srt/models/mimo_mtp.py +220 -0
  133. sglang/srt/models/minicpmo.py +5 -12
  134. sglang/srt/models/mistral.py +71 -1
  135. sglang/srt/models/mixtral.py +98 -34
  136. sglang/srt/models/mllama.py +3 -3
  137. sglang/srt/models/pixtral.py +467 -0
  138. sglang/srt/models/qwen2.py +95 -26
  139. sglang/srt/models/qwen2_5_vl.py +8 -0
  140. sglang/srt/models/qwen2_moe.py +330 -60
  141. sglang/srt/models/qwen2_vl.py +6 -0
  142. sglang/srt/models/qwen3.py +52 -10
  143. sglang/srt/models/qwen3_moe.py +411 -48
  144. sglang/srt/models/roberta.py +1 -1
  145. sglang/srt/models/siglip.py +294 -0
  146. sglang/srt/models/torch_native_llama.py +1 -1
  147. sglang/srt/openai_api/adapter.py +58 -20
  148. sglang/srt/openai_api/protocol.py +6 -8
  149. sglang/srt/operations.py +154 -0
  150. sglang/srt/operations_strategy.py +31 -0
  151. sglang/srt/reasoning_parser.py +3 -3
  152. sglang/srt/sampling/custom_logit_processor.py +18 -3
  153. sglang/srt/sampling/sampling_batch_info.py +4 -56
  154. sglang/srt/sampling/sampling_params.py +2 -2
  155. sglang/srt/server_args.py +162 -22
  156. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  157. sglang/srt/speculative/eagle_utils.py +138 -7
  158. sglang/srt/speculative/eagle_worker.py +69 -21
  159. sglang/srt/utils.py +74 -17
  160. sglang/test/few_shot_gsm8k.py +2 -2
  161. sglang/test/few_shot_gsm8k_engine.py +2 -2
  162. sglang/test/run_eval.py +2 -2
  163. sglang/test/runners.py +8 -1
  164. sglang/test/send_one.py +13 -3
  165. sglang/test/simple_eval_common.py +1 -1
  166. sglang/test/simple_eval_humaneval.py +1 -1
  167. sglang/test/test_cutlass_moe.py +278 -0
  168. sglang/test/test_programs.py +5 -5
  169. sglang/test/test_utils.py +55 -14
  170. sglang/utils.py +3 -3
  171. sglang/version.py +1 -1
  172. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +23 -13
  173. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +178 -149
  174. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
  175. sglang/srt/function_call_parser.py +0 -858
  176. sglang/srt/platforms/interface.py +0 -371
  177. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  178. /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
  179. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/fp8.py

@@ -52,6 +52,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
     apply_w8a8_block_fp8_linear,
     cutlass_fp8_supported,
     input_to_float8,
+    is_sm100_supported,
     normalize_e4m3fn_to_e4m3fnuz,
 )
 from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
@@ -235,7 +236,7 @@ class Fp8LinearMethod(LinearMethodBase):
                         f"{input_size_per_partition} is not divisible by "
                         f"weight quantization block_k = {block_k}."
                     )
-            # Required by collum parallel or enabling merged weights
+            # Required by column parallel or enabling merged weights
             if (
                 tp_size > 1 and output_size // output_size_per_partition == tp_size
             ) or len(output_partition_sizes) > 1:
@@ -470,6 +471,7 @@ class Fp8MoEMethod:
     def __init__(self, quant_config):
         self.quant_config = quant_config
        self.block_quant = self.quant_config.weight_block_size is not None
+        self.cutlass_fp8_supported = cutlass_fp8_supported()

     def create_weights(
         self,
@@ -491,7 +493,7 @@ class Fp8MoEMethod:
                 self.quant_config.weight_block_size[1],
             )
             # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
-            # Required by collum parallel or enabling merged weights
+            # Required by column parallel or enabling merged weights
             if intermediate_size % block_n != 0:
                 raise ValueError(
                     f"The output_size of gate's and up's weight = "
@@ -568,6 +570,63 @@ class Fp8MoEMethod:
             layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
             layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
             assert self.quant_config.activation_scheme == "dynamic"
+            if (
+                get_bool_env_var("CUTLASS_MOE")
+                and self.cutlass_fp8_supported
+                and is_sm100_supported()
+            ):
+                self.ab_strides1 = torch.full(
+                    (num_experts,),
+                    hidden_size,
+                    device=w13_weight.device,
+                    dtype=torch.int64,
+                )
+                self.c_strides1 = torch.full(
+                    (num_experts,),
+                    2 * intermediate_size,
+                    device=w13_weight.device,
+                    dtype=torch.int64,
+                )
+                self.ab_strides2 = torch.full(
+                    (num_experts,),
+                    intermediate_size,
+                    device=w2_weight.device,
+                    dtype=torch.int64,
+                )
+                self.c_strides2 = torch.full(
+                    (num_experts,),
+                    hidden_size,
+                    device=w2_weight.device,
+                    dtype=torch.int64,
+                )
+                self.workspace = torch.empty(
+                    90000, device=w13_weight.device, dtype=torch.uint8
+                )
+                self.a_ptr = torch.empty(
+                    num_experts, device=w13_weight.device, dtype=torch.int64
+                )
+                self.b_ptr = torch.empty(
+                    num_experts, device=w13_weight.device, dtype=torch.int64
+                )
+                self.out_ptr = torch.empty(
+                    num_experts, device=w13_weight.device, dtype=torch.int64
+                )
+                self.a_scales_ptr = torch.empty(
+                    num_experts, device=w13_weight.device, dtype=torch.int64
+                )
+                self.b_scales_ptr = torch.empty(
+                    num_experts, device=w13_weight.device, dtype=torch.int64
+                )
+                self.expert_offsets = torch.empty(
+                    num_experts + 1, device=w13_weight.device, dtype=torch.int32
+                )
+                self.problem_sizes1 = torch.empty(
+                    num_experts, 3, device=w13_weight.device, dtype=torch.int32
+                )
+                self.problem_sizes2 = torch.empty(
+                    num_experts, 3, device=w13_weight.device, dtype=torch.int32
+                )
+
         else:
             # Allocate 2 scales for w1 and w3 respectively.
             # They will be combined to a single scale after weight loading.
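For orientation (not part of the upstream diff): the tensors preallocated above describe the per-expert grouped GEMMs that the new CUTLASS path runs. A hedged annotation, inferred from the stride values rather than taken from the source:

    # Per expert e with m_e routed tokens (illustrative reading, not from the diff):
    #   GEMM 1: [m_e, hidden_size] @ [hidden_size, 2 * intermediate_size]
    #            -> ab_strides1 = hidden_size, c_strides1 = 2 * intermediate_size
    #   GEMM 2: [m_e, intermediate_size] @ [intermediate_size, hidden_size]
    #            -> ab_strides2 = intermediate_size, c_strides2 = hidden_size
    # problem_sizes1 / problem_sizes2 hold one (m, n, k) triple per expert, and
    # expert_offsets (num_experts + 1 entries) marks each expert's token range.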
@@ -913,6 +972,37 @@ class Fp8MoEMethod:
         if ret is not None:
             return ret

+        if (
+            get_bool_env_var("CUTLASS_MOE")
+            and self.cutlass_fp8_supported
+            and self.block_quant
+            and is_sm100_supported()
+        ):
+            from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts
+
+            return cutlass_fused_experts(
+                x,
+                layer.w13_weight.transpose(1, 2),
+                layer.w2_weight.transpose(1, 2),
+                layer.w13_weight_scale_inv.transpose(1, 2),
+                layer.w2_weight_scale_inv.transpose(1, 2),
+                topk_weights,
+                topk_ids,
+                self.ab_strides1,
+                self.c_strides1,
+                self.ab_strides2,
+                self.c_strides2,
+                self.workspace,
+                self.a_ptr,
+                self.b_ptr,
+                self.out_ptr,
+                self.a_scales_ptr,
+                self.b_scales_ptr,
+                self.expert_offsets,
+                self.problem_sizes1,
+                self.problem_sizes2,
+                use_fp8_blockscale=True,
+            )
         # Expert fusion with FP8 quantization
         return fused_experts(
             x,
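For orientation (not part of the upstream diff): the CUTLASS fused-MoE path is opt-in and only replaces the existing fused_experts call when every condition in the guard above holds. A minimal sketch of the same gating, assuming get_bool_env_var is the helper exported from sglang.srt.utils:

    from sglang.srt.utils import get_bool_env_var  # assumed import path
    from sglang.srt.layers.quantization.fp8_utils import (
        cutlass_fp8_supported,
        is_sm100_supported,
    )

    # The block-quantized FP8 MoE takes the CUTLASS kernel only when all of
    # these hold; otherwise it falls through to the existing fused_experts path.
    use_cutlass_moe = (
        get_bool_env_var("CUTLASS_MOE")  # opt-in via the CUTLASS_MOE environment variable
        and cutlass_fp8_supported()      # CUTLASS FP8 kernels available in this build
        and is_sm100_supported()         # compute capability 10.x GPU with CUDA >= 12.8
    )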
sglang/srt/layers/quantization/fp8_kernel.py

@@ -104,7 +104,7 @@ def _per_token_group_quant_fp8(
     y_s_ptr,
     # Stride of input
     y_stride,
-    # Collums of input
+    # Columns of input
     N,
     # Avoid to divide zero
     eps,
@@ -342,7 +342,7 @@ def _static_quant_fp8(
     y_s_repeat_ptr,
     # Stride of input
     y_stride,
-    # Collums of input
+    # Columns of input
     N,
     # Information for float8
     fp8_min,
@@ -794,7 +794,7 @@ def w8a8_block_fp8_matmul(
         config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
     else:
         # Default config
-        # Block-wise quant: BLOCK_SIZE_K must be divisable by block_size[1]
+        # Block-wise quant: BLOCK_SIZE_K must be divisible by block_size[1]
         config = {
             "BLOCK_SIZE_M": 64,
             "BLOCK_SIZE_N": block_size[0],
sglang/srt/layers/quantization/fp8_utils.py

@@ -80,6 +80,12 @@ def cutlass_fp8_supported():
     return False


+def is_sm100_supported(device=None) -> bool:
+    return (torch.cuda.get_device_capability(device)[0] == 10) and (
+        torch.version.cuda >= "12.8"
+    )
+
+
 def normalize_e4m3fn_to_e4m3fnuz(
     weight: torch.Tensor,
     weight_scale: torch.Tensor,
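For orientation (not part of the upstream diff): the new helper reports whether the current device is an SM100-class (compute capability 10.x) GPU on CUDA 12.8 or newer, which is the hardware check used by the CUTLASS MoE path in fp8.py. A hedged illustration of the two values it inspects:

    import torch

    # is_sm100_supported() requires both of the following; on e.g. an H100
    # (capability (9, 0)) it returns False regardless of the CUDA version.
    print(torch.cuda.get_device_capability())  # e.g. (10, 0) on SM100 hardware
    print(torch.version.cuda)                   # e.g. "12.8"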
sglang/srt/layers/quantization/gptq.py

@@ -1,21 +1,28 @@
 import logging
 from fractions import Fraction
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union

 import torch

-from sglang.srt.layers.linear import LinearBase
-from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.linear import LinearBase, set_weight_attrs
+from sglang.srt.layers.quantization.base_config import (
+    QuantizationConfig,
+    QuantizeMethodBase,
+)
+from sglang.srt.layers.quantization.utils import replace_parameter
 from sglang.srt.utils import is_cuda

 _is_cuda = is_cuda()

 try:
-    from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
+    from vllm import _custom_ops as ops
     from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
     from vllm.model_executor.layers.quantization.gptq_marlin import (
+        FusedMoE,
+        FusedMoEMethodBase,
+        FusedMoeWeightScaleSupported,
         GPTQMarlinLinearMethod,
-        GPTQMarlinMoEMethod,
+        marlin_moe_permute_scales,
     )
     from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod
     from vllm.model_executor.layers.quantization.utils.marlin_utils import (
@@ -27,7 +34,9 @@ try:
 except ImportError:
     VLLM_AVAILABLE = False

-    GPTQLinearMethod = MarlinLinearMethod = QuantizeMethodBase = Any
+    GPTQLinearMethod = MarlinLinearMethod = Any
+
+    FusedMoEMethodBase = QuantizeMethodBase

     class scalar_types:
         uint4b8 = "uint4b8"
@@ -437,3 +446,286 @@ class MarlinConfig(QuantizationConfig):
         ):
             return MarlinLinearMethod(self)
         return None
+
+
+class GPTQMarlinMoEMethod(FusedMoEMethodBase):
+    """MoE Marlin method with quantization."""
+
+    def __init__(self, quant_config: GPTQMarlinConfig) -> None:
+        self.quant_config = quant_config
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size_per_partition: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        intermediate_size = extra_weight_attrs.pop("intermediate_size")
+
+        self.is_k_full = (not self.quant_config.desc_act) or (
+            intermediate_size_per_partition == intermediate_size
+        )
+
+        if self.quant_config.group_size != -1:
+            scales_size13 = hidden_size // self.quant_config.group_size
+            w2_scales_size = (
+                intermediate_size
+                if self.quant_config.desc_act
+                else intermediate_size_per_partition
+            )
+            scales_size2 = w2_scales_size // self.quant_config.group_size
+            strategy = FusedMoeWeightScaleSupported.GROUP.value
+        else:
+            scales_size13 = 1
+            scales_size2 = 1
+            strategy = FusedMoeWeightScaleSupported.CHANNEL.value
+
+        extra_weight_attrs.update({"quant_method": strategy, "is_transposed": True})
+        # Fused gate_up_proj (column parallel)
+        w13_qweight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size // self.quant_config.pack_factor,
+                2 * intermediate_size_per_partition,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_qweight", w13_qweight)
+        set_weight_attrs(w13_qweight, extra_weight_attrs)
+        # down_proj (row parallel)
+        w2_qweight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                intermediate_size_per_partition // self.quant_config.pack_factor,
+                hidden_size,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_qweight", w2_qweight)
+        set_weight_attrs(w2_qweight, extra_weight_attrs)
+        # up_proj scales
+        w13_scales = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                scales_size13,
+                2 * intermediate_size_per_partition,
+                dtype=torch.half,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_scales", w13_scales)
+        set_weight_attrs(w13_scales, extra_weight_attrs)
+        # down_proj scales
+        w2_scales = torch.nn.Parameter(
+            torch.empty(num_experts, scales_size2, hidden_size, dtype=torch.half),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_scales", w2_scales)
+        set_weight_attrs(w2_scales, extra_weight_attrs)
+        # dont shard the w2 scales when running act order
+        set_weight_attrs(w2_scales, {"load_full_w2": self.quant_config.desc_act})
+        # up_proj scales
+        w13_qzeros = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                scales_size13,
+                2 * intermediate_size_per_partition // self.quant_config.pack_factor,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_qzeros", w13_qzeros)
+        set_weight_attrs(w13_qzeros, extra_weight_attrs)
+        # down_proj scales
+        w2_qzeros = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                scales_size2,
+                hidden_size // self.quant_config.pack_factor,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_qzeros", w2_qzeros)
+        set_weight_attrs(w2_qzeros, extra_weight_attrs)
+        # dont shard the w2 scales when running act order
+        set_weight_attrs(w2_qzeros, {"load_full_w2": self.quant_config.desc_act})
+        w13_g_idx = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_g_idx", w13_g_idx)
+        set_weight_attrs(w13_g_idx, extra_weight_attrs)
+        w2_g_idx = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                intermediate_size_per_partition,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_g_idx", w2_g_idx)
+        set_weight_attrs(w2_g_idx, extra_weight_attrs)
+        w13_g_idx_sort_indices = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_g_idx_sort_indices", w13_g_idx_sort_indices)
+        set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs)
+        w2_g_idx_sort_indices = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                intermediate_size_per_partition,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_g_idx_sort_indices", w2_g_idx_sort_indices)
+        set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+
+        # Process act_order
+        if self.quant_config.desc_act:
+            # Get sorting based on g_idx
+            num_experts = layer.w13_g_idx.shape[0]
+            w13_g_idx_sort_indices = torch.empty_like(layer.w13_g_idx)
+            w2_g_idx_sort_indices = torch.empty_like(layer.w2_g_idx)
+            w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx)
+            w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx)
+            for e in range(num_experts):
+                w13_g_idx_sort_indices[e] = torch.argsort(layer.w13_g_idx[e]).to(
+                    torch.int32
+                )
+                w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_g_idx[e]).to(
+                    torch.int32
+                )
+                w13_sorted_g_idx[e] = layer.w13_g_idx[e][w13_g_idx_sort_indices[e]]
+                w2_sorted_g_idx[e] = layer.w2_g_idx[e][w2_g_idx_sort_indices[e]]
+            replace_parameter(layer, "w13_g_idx", w13_sorted_g_idx)
+            replace_parameter(layer, "w2_g_idx", w2_sorted_g_idx)
+            replace_parameter(layer, "w13_g_idx_sort_indices", w13_g_idx_sort_indices)
+            replace_parameter(layer, "w2_g_idx_sort_indices", w2_g_idx_sort_indices)
+        else:
+            # Reset g_idx related tensors
+            num_experts = layer.w13_g_idx.shape[0]
+            device = layer.w13_g_idx.device
+            layer.w13_g_idx = torch.nn.Parameter(
+                torch.empty((num_experts, 0), dtype=torch.int32, device=device),
+                requires_grad=False,
+            )
+            layer.w2_g_idx = torch.nn.Parameter(
+                torch.empty((num_experts, 0), dtype=torch.int32, device=device),
+                requires_grad=False,
+            )
+            layer.w13_g_idx_sort_indices = torch.nn.Parameter(
+                torch.empty((num_experts, 0), dtype=torch.int32, device=device),
+                requires_grad=False,
+            )
+            layer.w2_g_idx_sort_indices = torch.nn.Parameter(
+                torch.empty((num_experts, 0), dtype=torch.int32, device=device),
+                requires_grad=False,
+            )
+        # Repack weights
+        marlin_w13_qweight = ops.gptq_marlin_moe_repack(
+            layer.w13_qweight,
+            layer.w13_g_idx_sort_indices,
+            layer.w13_qweight.shape[1] * self.quant_config.pack_factor,
+            layer.w13_qweight.shape[2],
+            self.quant_config.quant_type.size_bits,
+        )
+        replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
+        marlin_w2_qweight = ops.gptq_marlin_moe_repack(
+            layer.w2_qweight,
+            layer.w2_g_idx_sort_indices,
+            layer.w2_qweight.shape[1] * self.quant_config.pack_factor,
+            layer.w2_qweight.shape[2],
+            self.quant_config.quant_type.size_bits,
+        )
+        replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
+        # Repack scales
+        marlin_w13_scales = marlin_moe_permute_scales(
+            s=layer.w13_scales,
+            size_k=layer.intermediate_size_per_partition,
+            size_n=layer.w13_scales.shape[2],
+            group_size=self.quant_config.group_size,
+        )
+        replace_parameter(layer, "w13_scales", marlin_w13_scales)
+        marlin_w2_scales = marlin_moe_permute_scales(
+            s=layer.w2_scales,
+            size_k=layer.w2_scales.shape[1]
+            * (
+                self.quant_config.group_size
+                if self.quant_config.group_size != -1
+                else self.quant_config.pack_factor
+            ),
+            size_n=layer.w2_scales.shape[2],
+            group_size=self.quant_config.group_size,
+        )
+        replace_parameter(layer, "w2_scales", marlin_w2_scales)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool = False,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
+    ) -> torch.Tensor:
+        assert activation == "silu", "Only SiLU activation is supported."
+
+        # The input must currently be float16
+        orig_dtype = x.dtype
+        x = x.half()
+
+        topk_weights, topk_ids = FusedMoE.select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias,
+        )
+
+        return torch.ops.vllm.fused_marlin_moe(
+            x,
+            layer.w13_qweight,
+            layer.w2_qweight,
+            layer.w13_scales,
+            layer.w2_scales,
+            router_logits,
+            topk_weights,
+            topk_ids,
+            g_idx1=layer.w13_g_idx,
+            g_idx2=layer.w2_g_idx,
+            sort_indices1=layer.w13_g_idx_sort_indices,
+            sort_indices2=layer.w2_g_idx_sort_indices,
+            num_bits=self.quant_config.quant_type.size_bits,
+            is_k_full=self.is_k_full,
+        ).to(orig_dtype)
sglang/srt/layers/quantization/int8_kernel.py

@@ -22,9 +22,11 @@ def _per_token_quant_int8(
     x_ptr,
     xq_ptr,
     scale_ptr,
+    x_sum_ptr,
     stride_x,
     stride_xq,
     N,
+    CAL_SUM: tl.constexpr,
     BLOCK: tl.constexpr,
 ):
     # Adapted from https://github.com/InternLM/lmdeploy/blob/086481ed84b59bee3b8e4274e5fc69620040c048/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py#L282
@@ -38,16 +40,23 @@ def _per_token_quant_int8(
     scale_x = absmax / 127
     x_q = x * (127 / absmax)
     x_q = tl.extra.cuda.libdevice.round(x_q).to(tl.int8)
+    if CAL_SUM:
+        x_sum = tl.sum(x, axis=0)
+        tl.store(x_sum_ptr + row_id, x_sum.to(x_sum_ptr.dtype.element_ty))

     tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
-    tl.store(scale_ptr + row_id, scale_x)
+    tl.store(scale_ptr + row_id, scale_x.to(scale_ptr.dtype.element_ty))


-def per_token_quant_int8(x):
+def per_token_quant_int8(x, scale_dtype=torch.float32, cal_sum=False):
     M = x.numel() // x.shape[-1]
     N = x.shape[-1]
     x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
-    scales = torch.empty(x.shape[:-1] + (1,), device=x.device, dtype=torch.float32)
+    scales = torch.empty(x.shape[:-1] + (1,), device=x.device, dtype=scale_dtype)
+    if cal_sum:
+        x_sum = torch.empty(x.shape[:-1], device=x.device, dtype=x.dtype)
+    else:
+        x_sum = None
     BLOCK = triton.next_power_of_2(N)
     # heuristics for number of warps
     num_warps = min(max(BLOCK // 256, 1), 8)
@@ -57,15 +66,19 @@ def per_token_quant_int8(x):
         x,
         x_q,
         scales,
+        x_sum,
         stride_x=x.stride(-2),
         stride_xq=x_q.stride(-2),
         N=N,
+        CAL_SUM=cal_sum,
         BLOCK=BLOCK,
         num_warps=num_warps,
         num_stages=1,
     )
-
-    return x_q, scales
+    if cal_sum:
+        return x_q, scales, x_sum
+    else:
+        return x_q, scales


 @triton.jit
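For orientation (not part of the upstream diff): a hedged usage sketch of the extended per_token_quant_int8 signature, with the module path taken from the file list above:

    import torch
    from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8

    x = torch.randn(4, 128, device="cuda", dtype=torch.float16)

    # Default behaviour is unchanged: int8 values plus per-token scales.
    x_q, scales = per_token_quant_int8(x)

    # New in this version: optionally also return the per-token sum of the
    # unquantized input (and choose the dtype used for the scales).
    x_q, scales, x_sum = per_token_quant_int8(x, scale_dtype=torch.float32, cal_sum=True)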
@@ -76,7 +89,7 @@ def _per_token_group_quant_int8(
     y_s_ptr,
     # Stride of input
     y_stride,
-    # Collums of input
+    # Columns of input
     N,
     # Avoid to divide zero
     eps,
@@ -370,7 +383,7 @@ def w8a8_block_int8_matmul(
         config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
     else:
         # Default config
-        # Block-wise quant: BLOCK_SIZE_K must be divisable by block_size[1]
+        # Block-wise quant: BLOCK_SIZE_K must be divisible by block_size[1]
         config = {
             "BLOCK_SIZE_M": 64,
             "BLOCK_SIZE_N": block_size[0],