sglang 0.4.2.post3__py3-none-any.whl → 0.4.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. sglang/check_env.py +1 -0
  2. sglang/global_config.py +2 -0
  3. sglang/srt/constrained/outlines_backend.py +4 -1
  4. sglang/srt/entrypoints/engine.py +2 -2
  5. sglang/srt/layers/attention/flashinfer_backend.py +265 -147
  6. sglang/srt/layers/attention/triton_backend.py +358 -72
  7. sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
  8. sglang/srt/layers/linear.py +12 -5
  9. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +2 -2
  10. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  11. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +2 -2
  12. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +200 -0
  13. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +200 -0
  14. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +200 -0
  15. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +178 -0
  16. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +200 -0
  17. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +175 -0
  18. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +27 -5
  19. sglang/srt/layers/moe/fused_moe_triton/layer.py +2 -0
  20. sglang/srt/layers/moe/topk.py +1 -1
  21. sglang/srt/layers/quantization/__init__.py +51 -5
  22. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  23. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  24. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +30 -30
  25. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  26. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  27. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  29. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  30. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  31. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +29 -29
  32. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  33. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  34. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +33 -33
  35. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  36. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +31 -31
  37. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  38. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +27 -27
  39. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  40. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +31 -31
  41. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  44. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  45. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  46. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  47. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  48. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +24 -24
  49. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  50. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +30 -30
  51. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  52. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +42 -42
  53. sglang/srt/layers/quantization/fp8_kernel.py +123 -17
  54. sglang/srt/layers/quantization/fp8_utils.py +33 -4
  55. sglang/srt/lora/backend/__init__.py +25 -5
  56. sglang/srt/lora/backend/base_backend.py +31 -9
  57. sglang/srt/lora/backend/flashinfer_backend.py +41 -4
  58. sglang/srt/lora/backend/triton_backend.py +34 -4
  59. sglang/srt/lora/layers.py +293 -0
  60. sglang/srt/lora/lora.py +101 -326
  61. sglang/srt/lora/lora_manager.py +101 -269
  62. sglang/srt/lora/mem_pool.py +174 -0
  63. sglang/srt/lora/triton_ops/__init__.py +7 -1
  64. sglang/srt/lora/triton_ops/gate_up_lora_b.py +170 -0
  65. sglang/srt/lora/triton_ops/qkv_lora_b.py +5 -5
  66. sglang/srt/lora/triton_ops/sgemm_lora_a.py +2 -2
  67. sglang/srt/lora/triton_ops/sgemm_lora_b.py +2 -2
  68. sglang/srt/lora/utils.py +141 -0
  69. sglang/srt/managers/detokenizer_manager.py +1 -0
  70. sglang/srt/managers/io_struct.py +4 -0
  71. sglang/srt/managers/schedule_batch.py +16 -3
  72. sglang/srt/managers/scheduler.py +29 -0
  73. sglang/srt/managers/tokenizer_manager.py +6 -0
  74. sglang/srt/managers/tp_worker_overlap_thread.py +4 -0
  75. sglang/srt/model_executor/cuda_graph_runner.py +16 -1
  76. sglang/srt/model_executor/model_runner.py +12 -2
  77. sglang/srt/models/deepseek_v2.py +17 -7
  78. sglang/srt/server_args.py +20 -1
  79. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
  80. sglang/srt/speculative/eagle_utils.py +64 -21
  81. sglang/srt/speculative/eagle_worker.py +29 -8
  82. sglang/srt/utils.py +7 -0
  83. sglang/version.py +1 -1
  84. {sglang-0.4.2.post3.dist-info → sglang-0.4.3.dist-info}/METADATA +6 -5
  85. {sglang-0.4.2.post3.dist-info → sglang-0.4.3.dist-info}/RECORD +88 -55
  86. {sglang-0.4.2.post3.dist-info → sglang-0.4.3.dist-info}/LICENSE +0 -0
  87. {sglang-0.4.2.post3.dist-info → sglang-0.4.3.dist-info}/WHEEL +0 -0
  88. {sglang-0.4.2.post3.dist-info → sglang-0.4.3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,175 @@
+ {
+     "4": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 256,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 0,
+         "waves_per_eu": 4,
+         "matrix_instr_nonkdim": 16,
+         "kpack": 2
+     },
+     "8": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 256,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 0,
+         "waves_per_eu": 1,
+         "matrix_instr_nonkdim": 16,
+         "kpack": 1
+     },
+     "16": {
+         "BLOCK_SIZE_M": 32,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 256,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 8,
+         "num_stages": 0,
+         "waves_per_eu": 2,
+         "matrix_instr_nonkdim": 16,
+         "kpack": 2
+     },
+     "32": {
+         "BLOCK_SIZE_M": 32,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 256,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 8,
+         "num_stages": 0,
+         "waves_per_eu": 1,
+         "matrix_instr_nonkdim": 16,
+         "kpack": 2
+     },
+     "64": {
+         "BLOCK_SIZE_M": 32,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 256,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 0,
+         "waves_per_eu": 2,
+         "matrix_instr_nonkdim": 16,
+         "kpack": 2
+     },
+     "128": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 256,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 0,
+         "waves_per_eu": 1,
+         "matrix_instr_nonkdim": 16,
+         "kpack": 1
+     },
+     "256": {
+         "BLOCK_SIZE_M": 128,
+         "BLOCK_SIZE_N": 256,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 8,
+         "num_stages": 4
+     },
+     "512": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 256,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 0,
+         "waves_per_eu": 2,
+         "matrix_instr_nonkdim": 16,
+         "kpack": 2
+     },
+     "1024": {
+         "BLOCK_SIZE_M": 128,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 8,
+         "num_stages": 0,
+         "waves_per_eu": 4,
+         "matrix_instr_nonkdim": 16,
+         "kpack": 2
+     },
+     "2048": {
+         "BLOCK_SIZE_M": 128,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 8,
+         "num_stages": 0,
+         "waves_per_eu": 2,
+         "matrix_instr_nonkdim": 16,
+         "kpack": 2
+     },
+     "4096": {
+         "BLOCK_SIZE_M": 128,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 8,
+         "num_stages": 0,
+         "waves_per_eu": 2,
+         "matrix_instr_nonkdim": 16,
+         "kpack": 2
+     },
+     "8192": {
+         "BLOCK_SIZE_M": 256,
+         "BLOCK_SIZE_N": 256,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 8,
+         "num_stages": 0,
+         "waves_per_eu": 2,
+         "matrix_instr_nonkdim": 16,
+         "kpack": 1
+     },
+     "16384": {
+         "BLOCK_SIZE_M": 256,
+         "BLOCK_SIZE_N": 256,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 8,
+         "num_stages": 0,
+         "waves_per_eu": 1,
+         "matrix_instr_nonkdim": 16,
+         "kpack": 1
+     },
+     "32768": {
+         "BLOCK_SIZE_M": 256,
+         "BLOCK_SIZE_N": 256,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 8,
+         "num_stages": 0,
+         "waves_per_eu": 0,
+         "matrix_instr_nonkdim": 16,
+         "kpack": 1
+     },
+     "65536": {
+         "BLOCK_SIZE_M": 256,
+         "BLOCK_SIZE_N": 256,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 8,
+         "num_stages": 0,
+         "waves_per_eu": 1,
+         "matrix_instr_nonkdim": 16,
+         "kpack": 1
+     },
+     "131072": {
+         "BLOCK_SIZE_M": 256,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 0,
+         "waves_per_eu": 2,
+         "matrix_instr_nonkdim": 16,
+         "kpack": 2
+     }
+ }
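These tuned-kernel JSON files are keyed by M, the per-expert token count, and each entry carries the Triton launch parameters for that batch size (block sizes, group size, warps, stages, plus ROCm-specific fields such as waves_per_eu, matrix_instr_nonkdim, and kpack). Below is a minimal sketch of how such a table can be consumed, assuming the JSON layout above; the helper name and the nearest-key fallback are illustrative, not the exact loader in fused_moe.py:

import json

def pick_moe_kernel_config(config_path: str, num_tokens: int) -> dict:
    # Keys are token counts (M); values are Triton launch parameters.
    with open(config_path) as f:
        configs = {int(k): v for k, v in json.load(f).items()}
    if num_tokens in configs:
        return configs[num_tokens]
    # Fall back to the nearest tuned batch size when there is no exact match.
    return configs[min(configs, key=lambda k: abs(k - num_tokens))]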
@@ -18,7 +18,7 @@ from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
  from sglang.srt.utils import direct_register_custom_op, get_device_name, is_hip
 
  is_hip_flag = is_hip()
- from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
+
 
  logger = logging.getLogger(__name__)
  padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
@@ -27,6 +27,19 @@ enable_moe_align_block_size_triton = bool(
      int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
  )
 
+ _is_cuda = torch.cuda.is_available() and torch.version.cuda
+ _is_rocm = torch.cuda.is_available() and torch.version.hip
+
+ if _is_cuda:
+     from sgl_kernel import gelu_and_mul, silu_and_mul
+
+     from sglang.srt.layers.quantization.fp8_kernel import (
+         sglang_per_token_group_quant_fp8,
+     )
+
+ if _is_cuda or _is_rocm:
+     from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
+
 
  @triton.jit
  def fused_moe_kernel(
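For context, torch.version.cuda is a version string on CUDA builds and None on ROCm builds, while torch.version.hip is the reverse; that is what makes these truthiness checks a clean CUDA/ROCm split. A standalone illustration of the check the new flags rely on (not part of the diff):

import torch

# On a CUDA wheel: torch.version.cuda is e.g. "12.1" and torch.version.hip is None.
# On a ROCm wheel: torch.version.cuda is None and torch.version.hip is set.
_is_cuda = torch.cuda.is_available() and torch.version.cuda
_is_rocm = torch.cuda.is_available() and torch.version.hip
print("CUDA build:", bool(_is_cuda), "ROCm build:", bool(_is_rocm))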
@@ -479,7 +492,10 @@ def invoke_fused_moe_kernel(
          else:
              assert len(block_shape) == 2
              block_n, block_k = block_shape[0], block_shape[1]
-             A, A_scale = per_token_group_quant_fp8(A, block_k)
+             if _is_cuda:
+                 A, A_scale = sglang_per_token_group_quant_fp8(A, block_k)
+             else:
+                 A, A_scale = per_token_group_quant_fp8(A, block_k)
              assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
              assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
              assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
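Both branches perform per-token-group FP8 quantization: each contiguous group of block_k values along the last dimension of A gets its own scale, which is why the assert on A_scale above expects ceil(K / block_k) scales per token. A pure-PyTorch sketch of that scheme for reference only; the actual per_token_group_quant_fp8 and sglang_per_token_group_quant_fp8 kernels are Triton/CUDA implementations and may differ in details such as clamping:

import torch

def per_token_group_quant_fp8_ref(x: torch.Tensor, group_size: int):
    # One scale per group of `group_size` values along the last dimension.
    assert x.shape[-1] % group_size == 0
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0
    groups = x.float().reshape(*x.shape[:-1], -1, group_size)
    scale = groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-10) / fp8_max
    q = (groups / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return q.reshape_as(x), scale.squeeze(-1)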
@@ -989,9 +1005,15 @@ def fused_experts_impl(
          )
 
          if activation == "silu":
-             ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+             if _is_cuda:
+                 silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
+             else:
+                 ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
          elif activation == "gelu":
-             ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+             if _is_cuda:
+                 gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
+             else:
+                 ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
          else:
              raise ValueError(f"Unsupported activation: {activation=}")
 
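Note the argument order flips between the two call paths: vLLM's ops.silu_and_mul(out, x) takes the output buffer first, while the sgl_kernel functions are called here as silu_and_mul(x, out) with the input first. The underlying operation is the same gated activation; a hedged pure-PyTorch equivalent (a reference sketch, not the fused kernels themselves):

import torch
import torch.nn.functional as F

def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    # Split the last dim in half: activate the first half, gate with the second.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

def gelu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    d = x.shape[-1] // 2
    return F.gelu(x[..., :d]) * x[..., d:]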
@@ -1079,7 +1101,7 @@ def fused_moe(
      - num_expert_group: Optional[int]: additional parameter for grouped_topk
      - topk_group: Optional[int]: additional parameter for grouped_topk
      - use_grouped_topk: If True, use grouped_topk instead of fused_topk
-         note: Deepseekv2 model uses grouped_topk
+         note: Deepseek V2/V3/R1 series models use grouped_topk
      - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
          products for w1 and w2. Defaults to False.
      - use_int8_w8a16 (bool): If True, use fp8 arithmetic to compute the inner
@@ -298,7 +298,9 @@ class FusedMoE(torch.nn.Module):
              layer=self,
              num_experts=num_experts,
              hidden_size=hidden_size,
+             # FIXME: figure out which intermediate_size to use
              intermediate_size=self.intermediate_size_per_partition,
+             intermediate_size_per_partition=self.intermediate_size_per_partition,
              params_dtype=params_dtype,
              weight_loader=self.weight_loader,
          )
@@ -75,7 +75,7 @@ def fused_topk(
      return topk_weights, topk_ids
 
 
- # This is used by the Deepseek-V2 model
+ # This is used by the Deepseek V2/V3/R1 series models
  @torch.compile(dynamic=True, backend=get_compiler_backend())
  def grouped_topk(
      hidden_states: torch.Tensor,
@@ -1,10 +1,13 @@
  # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py
+ from typing import Callable, Dict, Optional, Type
 
- from typing import Dict, Type
-
+ import torch
  from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
  from vllm.model_executor.layers.quantization.awq import AWQConfig
- from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
+ from vllm.model_executor.layers.quantization.awq_marlin import (
+     AWQMarlinConfig,
+     AWQMoEMethod,
+ )
  from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
  from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
      CompressedTensorsConfig,
@@ -73,21 +76,61 @@ def gptq_get_quant_method(self, layer, prefix):
 
 
  def awq_get_quant_method(self, layer, prefix):
+     from vllm.model_executor.layers.quantization.awq import is_layer_skipped_awq
      from vllm.model_executor.layers.quantization.awq_marlin import (
          AWQMarlinLinearMethod,
          AWQMoEMethod,
      )
 
-     from sglang.srt.layers.linear import LinearBase
+     from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
      from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
+     from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 
-     if isinstance(layer, LinearBase):
+     if isinstance(layer, LinearBase) or (
+         isinstance(layer, ParallelLMHead) and self.lm_head_quantized
+     ):
+         if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
+             return UnquantizedLinearMethod()
          return AWQMarlinLinearMethod(self)
      elif isinstance(layer, FusedMoE):
          return AWQMoEMethod(self)
      return None
 
 
+ original_awq_moe_method_apply = AWQMoEMethod.apply
+
+
+ def awq_moe_method_apply(
+     self,
+     layer: torch.nn.Module,
+     x: torch.Tensor,
+     router_logits: torch.Tensor,
+     top_k: int,
+     renormalize: bool,
+     use_grouped_topk: bool = False,
+     topk_group: Optional[int] = None,
+     num_expert_group: Optional[int] = None,
+     custom_routing_function: Optional[Callable] = None,
+     scoring_func: str = "softmax",
+     e_score_correction_bias: Optional[torch.Tensor] = None,
+     **kwargs,
+ ):
+     return original_awq_moe_method_apply(
+         self,
+         layer,
+         x,
+         router_logits,
+         top_k,
+         renormalize,
+         use_grouped_topk,
+         topk_group,
+         num_expert_group,
+         custom_routing_function,
+         scoring_func,
+         e_score_correction_bias,
+     )
+
+
  def patch_vllm_linear_base_isinstance():
      import builtins
 
@@ -107,8 +150,11 @@ def patch_vllm_linear_base_isinstance():
 
 
  def apply_monkey_patches():
      """Apply all monkey patches in one place."""
+     from vllm.model_executor.layers.quantization.awq_marlin import AWQMoEMethod
+
      setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
      setattr(AWQMarlinConfig, "get_quant_method", awq_get_quant_method)
+     setattr(AWQMoEMethod, "apply", awq_moe_method_apply)
 
 
  patch_vllm_linear_base_isinstance()
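The awq_moe_method_apply wrapper above keeps a reference to vLLM's original AWQMoEMethod.apply, forwards the arguments it names, and silently drops anything else that lands in **kwargs, presumably so callers can pass extra keyword arguments without raising TypeError. The same save-then-wrap monkey-patch pattern in isolation, with hypothetical names:

class ThirdPartyMethod:
    def apply(self, x, top_k):
        return f"apply(x={x!r}, top_k={top_k})"

# Keep a reference to the original before patching so the wrapper can forward to it.
_original_apply = ThirdPartyMethod.apply

def patched_apply(self, x, top_k, **extra_kwargs):
    # Accept and discard keyword arguments the original signature does not know about.
    return _original_apply(self, x, top_k)

ThirdPartyMethod.apply = patched_apply
print(ThirdPartyMethod().apply("tokens", top_k=2, scoring_func="sigmoid"))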
@@ -0,0 +1,146 @@
+ {
+     "1": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 8,
+         "num_stages": 4
+     },
+     "2": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 32,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "4": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 32,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "8": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 32,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "16": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 32,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "24": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 32,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 5
+     },
+     "32": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 32,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 8,
+         "num_stages": 5
+     },
+     "48": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 32,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "64": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 32,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "96": {
+         "BLOCK_SIZE_M": 128,
+         "BLOCK_SIZE_N": 32,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 8,
+         "num_stages": 3
+     },
+     "128": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 5
+     },
+     "256": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 32,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "512": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 32,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 2
+     },
+     "1024": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "1536": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "2048": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "3072": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "4096": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     }
+ }
@@ -0,0 +1,164 @@
+ {
+     "1": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 16,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 4,
+         "num_warps": 4,
+         "num_stages": 2,
+         "waves_per_eu": 0
+     },
+     "2": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 16,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 8,
+         "num_warps": 4,
+         "num_stages": 2,
+         "waves_per_eu": 0
+     },
+     "4": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 16,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 2,
+         "waves_per_eu": 0
+     },
+     "8": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 16,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 2,
+         "waves_per_eu": 0
+     },
+     "16": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 16,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 2,
+         "waves_per_eu": 0
+     },
+     "24": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 16,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 2,
+         "waves_per_eu": 0
+     },
+     "32": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 16,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 2,
+         "waves_per_eu": 0
+     },
+     "48": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 16,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 2,
+         "waves_per_eu": 0
+     },
+     "64": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 16,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 2,
+         "waves_per_eu": 0
+     },
+     "96": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 16,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 2,
+         "waves_per_eu": 0
+     },
+     "128": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 16,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 2,
+         "waves_per_eu": 0
+     },
+     "256": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 16,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 4,
+         "num_warps": 4,
+         "num_stages": 2,
+         "waves_per_eu": 0
+     },
+     "512": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 16,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 4,
+         "num_warps": 4,
+         "num_stages": 2,
+         "waves_per_eu": 0
+     },
+     "1024": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 16,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 4,
+         "num_warps": 4,
+         "num_stages": 2,
+         "waves_per_eu": 0
+     },
+     "1536": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 8,
+         "num_warps": 4,
+         "num_stages": 2,
+         "waves_per_eu": 0
+     },
+     "2048": {
+         "BLOCK_SIZE_M": 128,
+         "BLOCK_SIZE_N": 32,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 8,
+         "num_warps": 4,
+         "num_stages": 2,
+         "waves_per_eu": 0
+     },
+     "3072": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 2,
+         "waves_per_eu": 0
+     },
+     "4096": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 2,
+         "waves_per_eu": 0
+     }
+ }